Backfill inbound read receipts

Also rename "chat" to "channel" for the Message DB, and make its primary
key include channel IDs
This commit is contained in:
Andrew Ferrazzutti 2022-05-05 03:00:41 -04:00
parent 63fe843724
commit 9a82db2257
9 changed files with 183 additions and 67 deletions

View File

@ -47,9 +47,7 @@
* [x] Message deletion/hiding
* [ ] Message reactions
* [x] Message history
* [ ] Read receipts
* [ ] On backfill
* [x] On live event
* [x] Read receipts
* [x] Admin status
* [ ] Membership actions
* [ ] Invite

View File

@ -36,7 +36,7 @@ class Message:
mx_room: RoomID
ktid: Long | None = field(converter=to_optional_long)
index: int
kt_chat: Long = field(converter=Long)
kt_channel: Long = field(converter=Long)
kt_receiver: Long = field(converter=Long)
timestamp: int
@ -48,18 +48,18 @@ class Message:
def _from_optional_row(cls, row: Record | None) -> Message | None:
return cls._from_row(row) if row is not None else None
columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
columns = 'mxid, mx_room, ktid, "index", kt_channel, kt_receiver, timestamp'
@classmethod
async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
rows = await cls.db.fetch(q, ktid, kt_receiver)
async def get_all_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int) -> list[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3"
rows = await cls.db.fetch(q, ktid, kt_channel, kt_receiver)
return [cls._from_row(row) for row in rows if row]
@classmethod
async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
async def get_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int, index: int = 0) -> Message | None:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3 AND "index"=$4'
row = await cls.db.fetchrow(q, ktid, kt_channel, kt_receiver, index)
return cls._from_optional_row(row)
@classmethod
@ -73,30 +73,39 @@ class Message:
return cls._from_optional_row(row)
@classmethod
async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
async def get_most_recent(cls, kt_channel: int, kt_receiver: int) -> Message | None:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1"
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
"ORDER BY ktid DESC LIMIT 1"
)
row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
row = await cls.db.fetchrow(q, kt_channel, kt_receiver)
return cls._from_optional_row(row)
@classmethod
async def get_closest_before(
cls, kt_chat: int, kt_receiver: int, ktid: Long
cls, kt_channel: int, kt_receiver: int, ktid: int
) -> Message | None:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid<=$3 AND "
" ktid IS NOT NULL "
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid<=$3 "
"ORDER BY ktid DESC LIMIT 1"
)
row = await cls.db.fetchrow(q, kt_chat, kt_receiver, ktid)
row = await cls.db.fetchrow(q, kt_channel, kt_receiver, ktid)
return cls._from_optional_row(row)
@classmethod
async def get_all_since(cls, kt_channel: int, kt_receiver: int, since_ktid: Long | None) -> list[Message]:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid>=$3 "
"ORDER BY ktid"
)
rows = await cls.db.fetch(q, kt_channel, kt_receiver, since_ktid or 0)
return [cls._from_row(row) for row in rows if row]
_insert_query = (
'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '
'INSERT INTO message (mxid, mx_room, ktid, "index", kt_channel, kt_receiver, '
" timestamp) "
"VALUES ($1, $2, $3, $4, $5, $6, $7)"
)
@ -105,7 +114,7 @@ class Message:
async def bulk_create(
cls,
ktid: Long,
kt_chat: Long,
kt_channel: Long,
kt_receiver: Long,
event_ids: list[EventID],
timestamp: int,
@ -115,7 +124,7 @@ class Message:
return []
columns = [col.strip('"') for col in cls.columns.split(", ")]
records = [
(mxid, mx_room, ktid, index, kt_chat, kt_receiver, timestamp)
(mxid, mx_room, ktid, index, kt_channel, kt_receiver, timestamp)
for index, mxid in enumerate(event_ids)
]
async with cls.db.acquire() as conn, conn.transaction():
@ -134,7 +143,7 @@ class Message:
self.mx_room,
self.ktid,
self.index,
self.kt_chat,
self.kt_channel,
self.kt_receiver,
self.timestamp,
)

View File

@ -23,7 +23,7 @@ from attr import dataclass, field
from mautrix.types import ContentURI, RoomID, UserID
from mautrix.util.async_db import Database
from ..kt.types.bson import Long
from ..kt.types.bson import Long, to_optional_long
from ..kt.types.channel.channel_type import ChannelType
fake_db = Database.create("") if TYPE_CHECKING else None
@ -45,6 +45,7 @@ class Portal:
name_set: bool
topic_set: bool
avatar_set: bool
fully_read_kt_chat: Long | None = field(converter=to_optional_long)
relay_user_id: UserID | None
@classmethod
@ -55,43 +56,32 @@ class Portal:
def _from_optional_row(cls, row: Record | None) -> Portal | None:
return cls._from_row(row) if row is not None else None
_columns = (
"ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted, "
"name_set, avatar_set, fully_read_kt_chat, relay_user_id"
)
@classmethod
async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None:
q = """
SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted,
name_set, avatar_set, relay_user_id
FROM portal WHERE ktid=$1 AND kt_receiver=$2
"""
q = f"SELECT {cls._columns} FROM portal WHERE ktid=$1 AND kt_receiver=$2"
row = await cls.db.fetchrow(q, ktid, kt_receiver)
return cls._from_optional_row(row)
@classmethod
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
q = """
SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted,
name_set, avatar_set, relay_user_id
FROM portal WHERE mxid=$1
"""
q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1"
row = await cls.db.fetchrow(q, mxid)
return cls._from_optional_row(row)
@classmethod
async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]:
q = """
SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted,
name_set, avatar_set, relay_user_id
FROM portal WHERE kt_receiver=$1
"""
q = f"SELECT {cls._columns} FROM portal WHERE kt_receiver=$1"
rows = await cls.db.fetch(q, kt_receiver)
return [cls._from_row(row) for row in rows if row]
@classmethod
async def all(cls) -> list[Portal]:
q = """
SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted,
name_set, avatar_set, relay_user_id
FROM portal
"""
q = f"SELECT {cls._columns} FROM portal"
rows = await cls.db.fetch(q)
return [cls._from_row(row) for row in rows if row]
@ -109,25 +99,24 @@ class Portal:
self.encrypted,
self.name_set,
self.avatar_set,
self.fully_read_kt_chat,
self.relay_user_id,
)
_args = "$" + ", $".join(str(i + 1) for i in range(_columns.count(',') + 1))
async def insert(self) -> None:
q = """
INSERT INTO portal (ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url,
encrypted, name_set, avatar_set, relay_user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
"""
q = f"INSERT INTO portal ({self._columns}) VALUES ({self._args})"
await self.db.execute(q, *self._values)
async def delete(self) -> None:
q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2"
await self.db.execute(q, self.ktid, self.kt_receiver)
_nonkey_column_asgns = ", ".join(
map(lambda t: (i:=t[0], word:=t[1], f"{word}=${i + 3}")[-1],
enumerate(_columns.split(", ")[2:])
)
)
async def save(self) -> None:
q = """
UPDATE portal SET kt_type=$3, mxid=$4, name=$5, description=$6, photo_id=$7, avatar_url=$8,
encrypted=$9, name_set=$10, avatar_set=$11, relay_user_id=$12
WHERE ktid=$1 AND kt_receiver=$2
"""
q = f"UPDATE portal SET {self._nonkey_column_asgns} WHERE ktid=$1 AND kt_receiver=$2"
await self.db.execute(q, *self._values)

View File

@ -20,3 +20,4 @@ upgrade_table = UpgradeTable()
from . import v01_initial_revision
from . import v02_channel_meta
from . import v03_user_connection
from . import v04_read_receipt_sync

View File

@ -0,0 +1,52 @@
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Connection, Scheme
from . import upgrade_table
@upgrade_table.register(description="Fix message table and add tracking to assist with backfilling read receipts")
async def upgrade_v4(conn: Connection, scheme: Scheme) -> None:
if scheme != Scheme.SQLITE:
await conn.execute("ALTER TABLE message RENAME COLUMN kt_chat TO kt_channel")
await conn.execute("ALTER TABLE message DROP CONSTRAINT message_pkey")
await conn.execute('ALTER TABLE message ADD PRIMARY KEY (ktid, kt_channel, kt_receiver, "index")')
else:
await conn.execute(
"""CREATE TABLE message_v4 (
mxid TEXT NOT NULL,
mx_room TEXT NOT NULL,
ktid BIGINT,
kt_receiver BIGINT NOT NULL,
"index" SMALLINT NOT NULL,
kt_channel BIGINT NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (ktid, kt_channel, kt_receiver, "index"),
FOREIGN KEY (kt_channel, kt_receiver) REFERENCES portal(ktid, kt_receiver)
ON UPDATE CASCADE ON DELETE CASCADE,
UNIQUE (mxid, mx_room)
)"""
)
await conn.execute(
"""
INSERT INTO message_v4 (mxid, mx_room, ktid, kt_receiver, "index", kt_channel, timestamp)
SELECT mxid, mx_room, ktid, kt_receiver, "index", kt_chat, timestamp FROM message
"""
)
await conn.execute("DROP TABLE message")
await conn.execute("ALTER TABLE message_v4 RENAME TO message")
await conn.execute("ALTER TABLE portal ADD COLUMN fully_read_kt_chat BIGINT")

View File

@ -67,6 +67,7 @@ from .types import (
ChannelProps,
PortalChannelInfo,
PortalChannelParticipantInfo,
Receipt,
SettingsStruct,
UserInfoUnion,
)
@ -315,6 +316,14 @@ class Client:
limit=limit,
)
def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: list[Long]) -> Awaitable[list[Receipt]]:
return self._api_user_request_result(
ResultListType(Receipt),
"get_read_receipts",
channel_props=channel_props.serialize(),
unread_chat_ids=[c.serialize() for c in unread_chat_ids],
)
def list_friends(self) -> Awaitable[FriendListStruct]:
return self._api_user_request_result(
FriendListStruct,

View File

@ -66,6 +66,11 @@ def deserialize_user_info_union(data: JSON) -> UserInfoUnion:
setattr(UserInfoUnion, "deserialize", deserialize_user_info_union)
@dataclass
class Receipt(SerializableAttrs):
userId: Long
chatId: Long
@dataclass
class PortalChannelParticipantInfo(SerializableAttrs):
participants: list[UserInfoUnion]

View File

@ -180,6 +180,7 @@ class Portal(DBPortal, BasePortal):
name_set: bool = False,
topic_set: bool = False,
avatar_set: bool = False,
fully_read_kt_chat: Long | None = None,
relay_user_id: UserID | None = None,
) -> None:
super().__init__(
@ -195,6 +196,7 @@ class Portal(DBPortal, BasePortal):
name_set,
topic_set,
avatar_set,
fully_read_kt_chat,
relay_user_id,
)
self.log = self.log.getChild(self.ktid_log)
@ -750,17 +752,23 @@ class Portal(DBPortal, BasePortal):
info = await self.update_info(source, info)
# TODO Sync read receipts?
await self._sync_read_receipts(source)
"""
async def _sync_read_receipts(self, receipts: list[None]) -> None:
async def _sync_read_receipts(self, source: u.User) -> None:
messages = await DBMessage.get_all_since(self.ktid, self.kt_receiver, self.fully_read_kt_chat)
receipts = await source.client.get_read_receipts(
self.channel_props,
[m.ktid for m in messages if m.ktid]
)
if not receipts:
return
for receipt in receipts:
message = await DBMessage.get_closest_before(
self.ktid, self.kt_receiver, receipt.timestamp
self.ktid, self.kt_receiver, receipt.chatId
)
if not message:
continue
puppet = await p.Puppet.get_by_ktid(receipt.actor.id, create=False)
puppet = await p.Puppet.get_by_ktid(receipt.userId, create=False)
if not puppet:
continue
try:
@ -771,7 +779,10 @@ class Portal(DBPortal, BasePortal):
f"as read by {puppet.intent.mxid}",
exc_info=True,
)
"""
fully_read_kt_chat = min(receipt.chatId for receipt in receipts)
if not self.fully_read_kt_chat or self.fully_read_kt_chat < fully_read_kt_chat:
self.fully_read_kt_chat = fully_read_kt_chat
await self.save()
async def create_matrix_room(
self, source: u.User, info: PortalChannelInfo | None = None
@ -830,7 +841,7 @@ class Portal(DBPortal, BasePortal):
self, source: u.User, info: PortalChannelInfo | None = None
) -> RoomID:
if self.mxid:
await self._update_matrix_room(source, info)
await self._update_matrix_room(source, info=info)
return self.mxid
self.log.debug(f"Creating Matrix room")
@ -947,11 +958,10 @@ class Portal(DBPortal, BasePortal):
if info.channel_info:
try:
await self.backfill(source, is_initial=True, channel_info=info.channel_info)
# NOTE This also syncs read receipts
except Exception:
self.log.exception("Failed to backfill new portal")
# TODO Sync read receipts?
return self.mxid
# endregion
@ -1061,7 +1071,7 @@ class Portal(DBPortal, BasePortal):
mx_room=self.mxid,
ktid=ktid,
index=0,
kt_chat=self.ktid,
kt_channel=self.ktid,
kt_receiver=self.kt_receiver,
timestamp=int(time.time() * 1000),
)
@ -1397,7 +1407,7 @@ class Portal(DBPortal, BasePortal):
async def _add_kakaotalk_reply(
self, content: MessageEventContent, reply_to: ReplyAttachment
) -> None:
message = await DBMessage.get_by_ktid(reply_to.src_logId, self.kt_receiver)
message = await DBMessage.get_by_ktid(reply_to.src_logId, *self.ktid_full)
if not message:
self.log.warning(
f"Couldn't find reply target {reply_to.src_logId} to bridge reply metadata to Matrix"
@ -1488,7 +1498,7 @@ class Portal(DBPortal, BasePortal):
# TODO Might have to handle remote reactions on messages created by bulk_create
await DBMessage.bulk_create(
ktid=chat.logId,
kt_chat=self.ktid,
kt_channel=self.ktid,
kt_receiver=self.kt_receiver,
mx_room=self.mxid,
timestamp=chat.sendAt,
@ -1705,7 +1715,7 @@ class Portal(DBPortal, BasePortal):
) -> None:
if not self.mxid:
return
for message in await DBMessage.get_all_by_ktid(chat_id, self.kt_receiver):
for message in await DBMessage.get_all_by_ktid(chat_id, *self.ktid_full):
try:
await sender.intent_for(self).redact(
message.mx_room, message.mxid, timestamp=timestamp
@ -1856,6 +1866,7 @@ class Portal(DBPortal, BasePortal):
self.log.trace("Leaving room with %s post-backfill", intent.mxid)
await intent.leave_room(self.mxid)
self.log.info("Backfilled %d messages through %s", len(chats), source.mxid)
self._sync_read_receipts(source)
# region Database getters

View File

@ -793,6 +793,47 @@ export default class PeerClient {
return res
}
/**
* @param {Object} req
* @param {string} req.mxid
* @param {ChannelProps} req.channel_props
* @param {[Long]} req.unread_chat_ids Must be in DECREASING order
*/
getReadReceipts = async (req) => {
const talkChannel = await this.#getUserChannel(req.mxid, req.channel_props)
// TODO Is any pre-syncing needed?
const userCount = talkChannel.userCount
if (userCount == 1) return makeCommandResult([])
/** @type {Map<Long, Long> */
const latestReceiptByUser = new Map()
let fullyRead = false
for (const chatId of req.unread_chat_ids) {
const chatReaders = talkChannel.getReaders({ logId: chatId })
for (const chatReader of chatReaders) {
if (!latestReceiptByUser.has(chatReader.userId)) {
latestReceiptByUser.set(chatReader.userId, chatId)
if (latestReceiptByUser.size == userCount) {
fullyRead = true
break
}
}
}
if (fullyRead) {
break
}
}
/**
* @typedef {Object} Receipt
* @property {Long} userId
* @property {Long} chatId
*/
/** @type {[Receipt]} */
const receipts = []
latestReceiptByUser.forEach((value, key) => receipts.push({ "userId": key, "chatId": value }))
return makeCommandResult(receipts)
}
/**
* @param {Object} req
* @param {string} req.mxid
@ -1081,6 +1122,7 @@ export default class PeerClient {
get_portal_channel_participant_info: this.getPortalChannelParticipantInfo,
get_participants: this.getParticipants,
get_chats: this.getChats,
get_read_receipts: this.getReadReceipts,
list_friends: this.listFriends,
get_friend_dm_id: req => this.getFriendProperty(req, "directChatId"),
get_memo_ids: this.getMemoIds,