diff --git a/ROADMAP.md b/ROADMAP.md index c8e8af3..93105ee 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -47,9 +47,7 @@ * [x] Message deletion/hiding * [ ] Message reactions * [x] Message history - * [ ] Read receipts - * [ ] On backfill - * [x] On live event + * [x] Read receipts * [x] Admin status * [ ] Membership actions * [ ] Invite diff --git a/matrix_appservice_kakaotalk/db/message.py b/matrix_appservice_kakaotalk/db/message.py index 893d9c0..c46b21a 100644 --- a/matrix_appservice_kakaotalk/db/message.py +++ b/matrix_appservice_kakaotalk/db/message.py @@ -36,7 +36,7 @@ class Message: mx_room: RoomID ktid: Long | None = field(converter=to_optional_long) index: int - kt_chat: Long = field(converter=Long) + kt_channel: Long = field(converter=Long) kt_receiver: Long = field(converter=Long) timestamp: int @@ -48,18 +48,18 @@ class Message: def _from_optional_row(cls, row: Record | None) -> Message | None: return cls._from_row(row) if row is not None else None - columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp' + columns = 'mxid, mx_room, ktid, "index", kt_channel, kt_receiver, timestamp' @classmethod - async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]: - q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2" - rows = await cls.db.fetch(q, ktid, kt_receiver) + async def get_all_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int) -> list[Message]: + q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3" + rows = await cls.db.fetch(q, ktid, kt_channel, kt_receiver) return [cls._from_row(row) for row in rows if row] @classmethod - async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None: - q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' - row = await cls.db.fetchrow(q, ktid, kt_receiver, index) + async def get_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int, index: int = 0) -> Message | None: + q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3 AND "index"=$4' + row = await cls.db.fetchrow(q, ktid, kt_channel, kt_receiver, index) return cls._from_optional_row(row) @classmethod @@ -73,30 +73,39 @@ class Message: return cls._from_optional_row(row) @classmethod - async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None: + async def get_most_recent(cls, kt_channel: int, kt_receiver: int) -> Message | None: q = ( f"SELECT {cls.columns} " - "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL " - "ORDER BY timestamp DESC LIMIT 1" + "FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid IS NOT NULL " + "ORDER BY ktid DESC LIMIT 1" ) - row = await cls.db.fetchrow(q, kt_chat, kt_receiver) + row = await cls.db.fetchrow(q, kt_channel, kt_receiver) return cls._from_optional_row(row) @classmethod async def get_closest_before( - cls, kt_chat: int, kt_receiver: int, ktid: Long + cls, kt_channel: int, kt_receiver: int, ktid: int ) -> Message | None: q = ( f"SELECT {cls.columns} " - "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid<=$3 AND " - " ktid IS NOT NULL " + "FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid<=$3 " "ORDER BY ktid DESC LIMIT 1" ) - row = await cls.db.fetchrow(q, kt_chat, kt_receiver, ktid) + row = await cls.db.fetchrow(q, kt_channel, kt_receiver, ktid) return cls._from_optional_row(row) + @classmethod + async def get_all_since(cls, kt_channel: int, kt_receiver: int, since_ktid: Long | None) -> list[Message]: + q = ( + f"SELECT {cls.columns} " + "FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid>=$3 " + "ORDER BY ktid" + ) + rows = await cls.db.fetch(q, kt_channel, kt_receiver, since_ktid or 0) + return [cls._from_row(row) for row in rows if row] + _insert_query = ( - 'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, ' + 'INSERT INTO message (mxid, mx_room, ktid, "index", kt_channel, kt_receiver, ' " timestamp) " "VALUES ($1, $2, $3, $4, $5, $6, $7)" ) @@ -105,7 +114,7 @@ class Message: async def bulk_create( cls, ktid: Long, - kt_chat: Long, + kt_channel: Long, kt_receiver: Long, event_ids: list[EventID], timestamp: int, @@ -115,7 +124,7 @@ class Message: return [] columns = [col.strip('"') for col in cls.columns.split(", ")] records = [ - (mxid, mx_room, ktid, index, kt_chat, kt_receiver, timestamp) + (mxid, mx_room, ktid, index, kt_channel, kt_receiver, timestamp) for index, mxid in enumerate(event_ids) ] async with cls.db.acquire() as conn, conn.transaction(): @@ -134,7 +143,7 @@ class Message: self.mx_room, self.ktid, self.index, - self.kt_chat, + self.kt_channel, self.kt_receiver, self.timestamp, ) diff --git a/matrix_appservice_kakaotalk/db/portal.py b/matrix_appservice_kakaotalk/db/portal.py index 7f6734c..221675e 100644 --- a/matrix_appservice_kakaotalk/db/portal.py +++ b/matrix_appservice_kakaotalk/db/portal.py @@ -23,7 +23,7 @@ from attr import dataclass, field from mautrix.types import ContentURI, RoomID, UserID from mautrix.util.async_db import Database -from ..kt.types.bson import Long +from ..kt.types.bson import Long, to_optional_long from ..kt.types.channel.channel_type import ChannelType fake_db = Database.create("") if TYPE_CHECKING else None @@ -45,6 +45,7 @@ class Portal: name_set: bool topic_set: bool avatar_set: bool + fully_read_kt_chat: Long | None = field(converter=to_optional_long) relay_user_id: UserID | None @classmethod @@ -55,43 +56,32 @@ class Portal: def _from_optional_row(cls, row: Record | None) -> Portal | None: return cls._from_row(row) if row is not None else None + _columns = ( + "ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted, " + "name_set, avatar_set, fully_read_kt_chat, relay_user_id" + ) + @classmethod async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None: - q = """ - SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted, - name_set, avatar_set, relay_user_id - FROM portal WHERE ktid=$1 AND kt_receiver=$2 - """ + q = f"SELECT {cls._columns} FROM portal WHERE ktid=$1 AND kt_receiver=$2" row = await cls.db.fetchrow(q, ktid, kt_receiver) return cls._from_optional_row(row) @classmethod async def get_by_mxid(cls, mxid: RoomID) -> Portal | None: - q = """ - SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted, - name_set, avatar_set, relay_user_id - FROM portal WHERE mxid=$1 - """ + q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1" row = await cls.db.fetchrow(q, mxid) return cls._from_optional_row(row) @classmethod async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]: - q = """ - SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted, - name_set, avatar_set, relay_user_id - FROM portal WHERE kt_receiver=$1 - """ + q = f"SELECT {cls._columns} FROM portal WHERE kt_receiver=$1" rows = await cls.db.fetch(q, kt_receiver) return [cls._from_row(row) for row in rows if row] @classmethod async def all(cls) -> list[Portal]: - q = """ - SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted, - name_set, avatar_set, relay_user_id - FROM portal - """ + q = f"SELECT {cls._columns} FROM portal" rows = await cls.db.fetch(q) return [cls._from_row(row) for row in rows if row] @@ -109,25 +99,24 @@ class Portal: self.encrypted, self.name_set, self.avatar_set, + self.fully_read_kt_chat, self.relay_user_id, ) + _args = "$" + ", $".join(str(i + 1) for i in range(_columns.count(',') + 1)) async def insert(self) -> None: - q = """ - INSERT INTO portal (ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, - encrypted, name_set, avatar_set, relay_user_id) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) - """ + q = f"INSERT INTO portal ({self._columns}) VALUES ({self._args})" await self.db.execute(q, *self._values) async def delete(self) -> None: q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2" await self.db.execute(q, self.ktid, self.kt_receiver) + _nonkey_column_asgns = ", ".join( + map(lambda t: (i:=t[0], word:=t[1], f"{word}=${i + 3}")[-1], + enumerate(_columns.split(", ")[2:]) + ) + ) async def save(self) -> None: - q = """ - UPDATE portal SET kt_type=$3, mxid=$4, name=$5, description=$6, photo_id=$7, avatar_url=$8, - encrypted=$9, name_set=$10, avatar_set=$11, relay_user_id=$12 - WHERE ktid=$1 AND kt_receiver=$2 - """ + q = f"UPDATE portal SET {self._nonkey_column_asgns} WHERE ktid=$1 AND kt_receiver=$2" await self.db.execute(q, *self._values) diff --git a/matrix_appservice_kakaotalk/db/upgrade/__init__.py b/matrix_appservice_kakaotalk/db/upgrade/__init__.py index af8ed7f..a79510b 100644 --- a/matrix_appservice_kakaotalk/db/upgrade/__init__.py +++ b/matrix_appservice_kakaotalk/db/upgrade/__init__.py @@ -20,3 +20,4 @@ upgrade_table = UpgradeTable() from . import v01_initial_revision from . import v02_channel_meta from . import v03_user_connection +from . import v04_read_receipt_sync diff --git a/matrix_appservice_kakaotalk/db/upgrade/v04_read_receipt_sync.py b/matrix_appservice_kakaotalk/db/upgrade/v04_read_receipt_sync.py new file mode 100644 index 0000000..eab2f05 --- /dev/null +++ b/matrix_appservice_kakaotalk/db/upgrade/v04_read_receipt_sync.py @@ -0,0 +1,52 @@ +# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge. +# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from mautrix.util.async_db import Connection, Scheme + +from . import upgrade_table + + +@upgrade_table.register(description="Fix message table and add tracking to assist with backfilling read receipts") +async def upgrade_v4(conn: Connection, scheme: Scheme) -> None: + if scheme != Scheme.SQLITE: + await conn.execute("ALTER TABLE message RENAME COLUMN kt_chat TO kt_channel") + await conn.execute("ALTER TABLE message DROP CONSTRAINT message_pkey") + await conn.execute('ALTER TABLE message ADD PRIMARY KEY (ktid, kt_channel, kt_receiver, "index")') + else: + await conn.execute( + """CREATE TABLE message_v4 ( + mxid TEXT NOT NULL, + mx_room TEXT NOT NULL, + ktid BIGINT, + kt_receiver BIGINT NOT NULL, + "index" SMALLINT NOT NULL, + kt_channel BIGINT NOT NULL, + timestamp BIGINT NOT NULL, + PRIMARY KEY (ktid, kt_channel, kt_receiver, "index"), + FOREIGN KEY (kt_channel, kt_receiver) REFERENCES portal(ktid, kt_receiver) + ON UPDATE CASCADE ON DELETE CASCADE, + UNIQUE (mxid, mx_room) + )""" + ) + await conn.execute( + """ + INSERT INTO message_v4 (mxid, mx_room, ktid, kt_receiver, "index", kt_channel, timestamp) + SELECT mxid, mx_room, ktid, kt_receiver, "index", kt_chat, timestamp FROM message + """ + ) + await conn.execute("DROP TABLE message") + await conn.execute("ALTER TABLE message_v4 RENAME TO message") + + await conn.execute("ALTER TABLE portal ADD COLUMN fully_read_kt_chat BIGINT") diff --git a/matrix_appservice_kakaotalk/kt/client/client.py b/matrix_appservice_kakaotalk/kt/client/client.py index 7a09486..020bb91 100644 --- a/matrix_appservice_kakaotalk/kt/client/client.py +++ b/matrix_appservice_kakaotalk/kt/client/client.py @@ -67,6 +67,7 @@ from .types import ( ChannelProps, PortalChannelInfo, PortalChannelParticipantInfo, + Receipt, SettingsStruct, UserInfoUnion, ) @@ -315,6 +316,14 @@ class Client: limit=limit, ) + def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: list[Long]) -> Awaitable[list[Receipt]]: + return self._api_user_request_result( + ResultListType(Receipt), + "get_read_receipts", + channel_props=channel_props.serialize(), + unread_chat_ids=[c.serialize() for c in unread_chat_ids], + ) + def list_friends(self) -> Awaitable[FriendListStruct]: return self._api_user_request_result( FriendListStruct, diff --git a/matrix_appservice_kakaotalk/kt/client/types.py b/matrix_appservice_kakaotalk/kt/client/types.py index 983fbdc..1a4e193 100644 --- a/matrix_appservice_kakaotalk/kt/client/types.py +++ b/matrix_appservice_kakaotalk/kt/client/types.py @@ -66,6 +66,11 @@ def deserialize_user_info_union(data: JSON) -> UserInfoUnion: setattr(UserInfoUnion, "deserialize", deserialize_user_info_union) +@dataclass +class Receipt(SerializableAttrs): + userId: Long + chatId: Long + @dataclass class PortalChannelParticipantInfo(SerializableAttrs): participants: list[UserInfoUnion] diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py index 5364535..b8d7370 100644 --- a/matrix_appservice_kakaotalk/portal.py +++ b/matrix_appservice_kakaotalk/portal.py @@ -180,6 +180,7 @@ class Portal(DBPortal, BasePortal): name_set: bool = False, topic_set: bool = False, avatar_set: bool = False, + fully_read_kt_chat: Long | None = None, relay_user_id: UserID | None = None, ) -> None: super().__init__( @@ -195,6 +196,7 @@ class Portal(DBPortal, BasePortal): name_set, topic_set, avatar_set, + fully_read_kt_chat, relay_user_id, ) self.log = self.log.getChild(self.ktid_log) @@ -750,17 +752,23 @@ class Portal(DBPortal, BasePortal): info = await self.update_info(source, info) - # TODO Sync read receipts? + await self._sync_read_receipts(source) - """ - async def _sync_read_receipts(self, receipts: list[None]) -> None: + async def _sync_read_receipts(self, source: u.User) -> None: + messages = await DBMessage.get_all_since(self.ktid, self.kt_receiver, self.fully_read_kt_chat) + receipts = await source.client.get_read_receipts( + self.channel_props, + [m.ktid for m in messages if m.ktid] + ) + if not receipts: + return for receipt in receipts: message = await DBMessage.get_closest_before( - self.ktid, self.kt_receiver, receipt.timestamp + self.ktid, self.kt_receiver, receipt.chatId ) if not message: continue - puppet = await p.Puppet.get_by_ktid(receipt.actor.id, create=False) + puppet = await p.Puppet.get_by_ktid(receipt.userId, create=False) if not puppet: continue try: @@ -771,7 +779,10 @@ class Portal(DBPortal, BasePortal): f"as read by {puppet.intent.mxid}", exc_info=True, ) - """ + fully_read_kt_chat = min(receipt.chatId for receipt in receipts) + if not self.fully_read_kt_chat or self.fully_read_kt_chat < fully_read_kt_chat: + self.fully_read_kt_chat = fully_read_kt_chat + await self.save() async def create_matrix_room( self, source: u.User, info: PortalChannelInfo | None = None @@ -830,7 +841,7 @@ class Portal(DBPortal, BasePortal): self, source: u.User, info: PortalChannelInfo | None = None ) -> RoomID: if self.mxid: - await self._update_matrix_room(source, info) + await self._update_matrix_room(source, info=info) return self.mxid self.log.debug(f"Creating Matrix room") @@ -947,11 +958,10 @@ class Portal(DBPortal, BasePortal): if info.channel_info: try: await self.backfill(source, is_initial=True, channel_info=info.channel_info) + # NOTE This also syncs read receipts except Exception: self.log.exception("Failed to backfill new portal") - # TODO Sync read receipts? - return self.mxid # endregion @@ -1061,7 +1071,7 @@ class Portal(DBPortal, BasePortal): mx_room=self.mxid, ktid=ktid, index=0, - kt_chat=self.ktid, + kt_channel=self.ktid, kt_receiver=self.kt_receiver, timestamp=int(time.time() * 1000), ) @@ -1397,7 +1407,7 @@ class Portal(DBPortal, BasePortal): async def _add_kakaotalk_reply( self, content: MessageEventContent, reply_to: ReplyAttachment ) -> None: - message = await DBMessage.get_by_ktid(reply_to.src_logId, self.kt_receiver) + message = await DBMessage.get_by_ktid(reply_to.src_logId, *self.ktid_full) if not message: self.log.warning( f"Couldn't find reply target {reply_to.src_logId} to bridge reply metadata to Matrix" @@ -1488,7 +1498,7 @@ class Portal(DBPortal, BasePortal): # TODO Might have to handle remote reactions on messages created by bulk_create await DBMessage.bulk_create( ktid=chat.logId, - kt_chat=self.ktid, + kt_channel=self.ktid, kt_receiver=self.kt_receiver, mx_room=self.mxid, timestamp=chat.sendAt, @@ -1705,7 +1715,7 @@ class Portal(DBPortal, BasePortal): ) -> None: if not self.mxid: return - for message in await DBMessage.get_all_by_ktid(chat_id, self.kt_receiver): + for message in await DBMessage.get_all_by_ktid(chat_id, *self.ktid_full): try: await sender.intent_for(self).redact( message.mx_room, message.mxid, timestamp=timestamp @@ -1856,6 +1866,7 @@ class Portal(DBPortal, BasePortal): self.log.trace("Leaving room with %s post-backfill", intent.mxid) await intent.leave_room(self.mxid) self.log.info("Backfilled %d messages through %s", len(chats), source.mxid) + self._sync_read_receipts(source) # region Database getters diff --git a/node/src/client.js b/node/src/client.js index e2c3515..8024007 100644 --- a/node/src/client.js +++ b/node/src/client.js @@ -793,6 +793,47 @@ export default class PeerClient { return res } + /** + * @param {Object} req + * @param {string} req.mxid + * @param {ChannelProps} req.channel_props + * @param {[Long]} req.unread_chat_ids Must be in DECREASING order + */ + getReadReceipts = async (req) => { + const talkChannel = await this.#getUserChannel(req.mxid, req.channel_props) + // TODO Is any pre-syncing needed? + const userCount = talkChannel.userCount + if (userCount == 1) return makeCommandResult([]) + /** @type {Map */ + const latestReceiptByUser = new Map() + let fullyRead = false + for (const chatId of req.unread_chat_ids) { + const chatReaders = talkChannel.getReaders({ logId: chatId }) + for (const chatReader of chatReaders) { + if (!latestReceiptByUser.has(chatReader.userId)) { + latestReceiptByUser.set(chatReader.userId, chatId) + if (latestReceiptByUser.size == userCount) { + fullyRead = true + break + } + } + } + if (fullyRead) { + break + } + } + + /** + * @typedef {Object} Receipt + * @property {Long} userId + * @property {Long} chatId + */ + /** @type {[Receipt]} */ + const receipts = [] + latestReceiptByUser.forEach((value, key) => receipts.push({ "userId": key, "chatId": value })) + return makeCommandResult(receipts) + } + /** * @param {Object} req * @param {string} req.mxid @@ -1081,6 +1122,7 @@ export default class PeerClient { get_portal_channel_participant_info: this.getPortalChannelParticipantInfo, get_participants: this.getParticipants, get_chats: this.getChats, + get_read_receipts: this.getReadReceipts, list_friends: this.listFriends, get_friend_dm_id: req => this.getFriendProperty(req, "directChatId"), get_memo_ids: this.getMemoIds,