diff --git a/ROADMAP.md b/ROADMAP.md
index c8e8af3..93105ee 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -47,9 +47,7 @@
* [x] Message deletion/hiding
* [ ] Message reactions
* [x] Message history
- * [ ] Read receipts
- * [ ] On backfill
- * [x] On live event
+ * [x] Read receipts
* [x] Admin status
* [ ] Membership actions
* [ ] Invite
diff --git a/matrix_appservice_kakaotalk/db/message.py b/matrix_appservice_kakaotalk/db/message.py
index 893d9c0..c46b21a 100644
--- a/matrix_appservice_kakaotalk/db/message.py
+++ b/matrix_appservice_kakaotalk/db/message.py
@@ -36,7 +36,7 @@ class Message:
mx_room: RoomID
ktid: Long | None = field(converter=to_optional_long)
index: int
- kt_chat: Long = field(converter=Long)
+ kt_channel: Long = field(converter=Long)
kt_receiver: Long = field(converter=Long)
timestamp: int
@@ -48,18 +48,18 @@ class Message:
def _from_optional_row(cls, row: Record | None) -> Message | None:
return cls._from_row(row) if row is not None else None
- columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
+ columns = 'mxid, mx_room, ktid, "index", kt_channel, kt_receiver, timestamp'
@classmethod
- async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]:
- q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
- rows = await cls.db.fetch(q, ktid, kt_receiver)
+ async def get_all_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int) -> list[Message]:
+ q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3"
+ rows = await cls.db.fetch(q, ktid, kt_channel, kt_receiver)
return [cls._from_row(row) for row in rows if row]
@classmethod
- async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None:
- q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
- row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
+ async def get_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int, index: int = 0) -> Message | None:
+ q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3 AND "index"=$4'
+ row = await cls.db.fetchrow(q, ktid, kt_channel, kt_receiver, index)
return cls._from_optional_row(row)
@classmethod
@@ -73,30 +73,39 @@ class Message:
return cls._from_optional_row(row)
@classmethod
- async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
+ async def get_most_recent(cls, kt_channel: int, kt_receiver: int) -> Message | None:
q = (
f"SELECT {cls.columns} "
- "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
- "ORDER BY timestamp DESC LIMIT 1"
+ "FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
+ "ORDER BY ktid DESC LIMIT 1"
)
- row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
+ row = await cls.db.fetchrow(q, kt_channel, kt_receiver)
return cls._from_optional_row(row)
@classmethod
async def get_closest_before(
- cls, kt_chat: int, kt_receiver: int, ktid: Long
+ cls, kt_channel: int, kt_receiver: int, ktid: int
) -> Message | None:
q = (
f"SELECT {cls.columns} "
- "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid<=$3 AND "
- " ktid IS NOT NULL "
+ "FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid<=$3 "
"ORDER BY ktid DESC LIMIT 1"
)
- row = await cls.db.fetchrow(q, kt_chat, kt_receiver, ktid)
+ row = await cls.db.fetchrow(q, kt_channel, kt_receiver, ktid)
return cls._from_optional_row(row)
+ @classmethod
+ async def get_all_since(cls, kt_channel: int, kt_receiver: int, since_ktid: Long | None) -> list[Message]:
+ q = (
+ f"SELECT {cls.columns} "
+ "FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid>=$3 "
+ "ORDER BY ktid"
+ )
+ rows = await cls.db.fetch(q, kt_channel, kt_receiver, since_ktid or 0)
+ return [cls._from_row(row) for row in rows if row]
+
_insert_query = (
- 'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '
+ 'INSERT INTO message (mxid, mx_room, ktid, "index", kt_channel, kt_receiver, '
" timestamp) "
"VALUES ($1, $2, $3, $4, $5, $6, $7)"
)
@@ -105,7 +114,7 @@ class Message:
async def bulk_create(
cls,
ktid: Long,
- kt_chat: Long,
+ kt_channel: Long,
kt_receiver: Long,
event_ids: list[EventID],
timestamp: int,
@@ -115,7 +124,7 @@ class Message:
return []
columns = [col.strip('"') for col in cls.columns.split(", ")]
records = [
- (mxid, mx_room, ktid, index, kt_chat, kt_receiver, timestamp)
+ (mxid, mx_room, ktid, index, kt_channel, kt_receiver, timestamp)
for index, mxid in enumerate(event_ids)
]
async with cls.db.acquire() as conn, conn.transaction():
@@ -134,7 +143,7 @@ class Message:
self.mx_room,
self.ktid,
self.index,
- self.kt_chat,
+ self.kt_channel,
self.kt_receiver,
self.timestamp,
)
diff --git a/matrix_appservice_kakaotalk/db/portal.py b/matrix_appservice_kakaotalk/db/portal.py
index 7f6734c..221675e 100644
--- a/matrix_appservice_kakaotalk/db/portal.py
+++ b/matrix_appservice_kakaotalk/db/portal.py
@@ -23,7 +23,7 @@ from attr import dataclass, field
from mautrix.types import ContentURI, RoomID, UserID
from mautrix.util.async_db import Database
-from ..kt.types.bson import Long
+from ..kt.types.bson import Long, to_optional_long
from ..kt.types.channel.channel_type import ChannelType
fake_db = Database.create("") if TYPE_CHECKING else None
@@ -45,6 +45,7 @@ class Portal:
name_set: bool
topic_set: bool
avatar_set: bool
+ fully_read_kt_chat: Long | None = field(converter=to_optional_long)
relay_user_id: UserID | None
@classmethod
@@ -55,43 +56,32 @@ class Portal:
def _from_optional_row(cls, row: Record | None) -> Portal | None:
return cls._from_row(row) if row is not None else None
+ _columns = (
+ "ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted, "
+ "name_set, avatar_set, fully_read_kt_chat, relay_user_id"
+ )
+
@classmethod
async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None:
- q = """
- SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted,
- name_set, avatar_set, relay_user_id
- FROM portal WHERE ktid=$1 AND kt_receiver=$2
- """
+ q = f"SELECT {cls._columns} FROM portal WHERE ktid=$1 AND kt_receiver=$2"
row = await cls.db.fetchrow(q, ktid, kt_receiver)
return cls._from_optional_row(row)
@classmethod
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
- q = """
- SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted,
- name_set, avatar_set, relay_user_id
- FROM portal WHERE mxid=$1
- """
+ q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1"
row = await cls.db.fetchrow(q, mxid)
return cls._from_optional_row(row)
@classmethod
async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]:
- q = """
- SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted,
- name_set, avatar_set, relay_user_id
- FROM portal WHERE kt_receiver=$1
- """
+ q = f"SELECT {cls._columns} FROM portal WHERE kt_receiver=$1"
rows = await cls.db.fetch(q, kt_receiver)
return [cls._from_row(row) for row in rows if row]
@classmethod
async def all(cls) -> list[Portal]:
- q = """
- SELECT ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url, encrypted,
- name_set, avatar_set, relay_user_id
- FROM portal
- """
+ q = f"SELECT {cls._columns} FROM portal"
rows = await cls.db.fetch(q)
return [cls._from_row(row) for row in rows if row]
@@ -109,25 +99,24 @@ class Portal:
self.encrypted,
self.name_set,
self.avatar_set,
+ self.fully_read_kt_chat,
self.relay_user_id,
)
+ _args = "$" + ", $".join(str(i + 1) for i in range(_columns.count(',') + 1))
async def insert(self) -> None:
- q = """
- INSERT INTO portal (ktid, kt_receiver, kt_type, mxid, name, description, photo_id, avatar_url,
- encrypted, name_set, avatar_set, relay_user_id)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
- """
+ q = f"INSERT INTO portal ({self._columns}) VALUES ({self._args})"
await self.db.execute(q, *self._values)
async def delete(self) -> None:
q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2"
await self.db.execute(q, self.ktid, self.kt_receiver)
+ _nonkey_column_asgns = ", ".join(
+ map(lambda t: (i:=t[0], word:=t[1], f"{word}=${i + 3}")[-1],
+ enumerate(_columns.split(", ")[2:])
+ )
+ )
async def save(self) -> None:
- q = """
- UPDATE portal SET kt_type=$3, mxid=$4, name=$5, description=$6, photo_id=$7, avatar_url=$8,
- encrypted=$9, name_set=$10, avatar_set=$11, relay_user_id=$12
- WHERE ktid=$1 AND kt_receiver=$2
- """
+ q = f"UPDATE portal SET {self._nonkey_column_asgns} WHERE ktid=$1 AND kt_receiver=$2"
await self.db.execute(q, *self._values)
diff --git a/matrix_appservice_kakaotalk/db/upgrade/__init__.py b/matrix_appservice_kakaotalk/db/upgrade/__init__.py
index af8ed7f..a79510b 100644
--- a/matrix_appservice_kakaotalk/db/upgrade/__init__.py
+++ b/matrix_appservice_kakaotalk/db/upgrade/__init__.py
@@ -20,3 +20,4 @@ upgrade_table = UpgradeTable()
from . import v01_initial_revision
from . import v02_channel_meta
from . import v03_user_connection
+from . import v04_read_receipt_sync
diff --git a/matrix_appservice_kakaotalk/db/upgrade/v04_read_receipt_sync.py b/matrix_appservice_kakaotalk/db/upgrade/v04_read_receipt_sync.py
new file mode 100644
index 0000000..eab2f05
--- /dev/null
+++ b/matrix_appservice_kakaotalk/db/upgrade/v04_read_receipt_sync.py
@@ -0,0 +1,52 @@
+# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
+# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from mautrix.util.async_db import Connection, Scheme
+
+from . import upgrade_table
+
+
+@upgrade_table.register(description="Fix message table and add tracking to assist with backfilling read receipts")
+async def upgrade_v4(conn: Connection, scheme: Scheme) -> None:
+ if scheme != Scheme.SQLITE:
+ await conn.execute("ALTER TABLE message RENAME COLUMN kt_chat TO kt_channel")
+ await conn.execute("ALTER TABLE message DROP CONSTRAINT message_pkey")
+ await conn.execute('ALTER TABLE message ADD PRIMARY KEY (ktid, kt_channel, kt_receiver, "index")')
+ else:
+ await conn.execute(
+ """CREATE TABLE message_v4 (
+ mxid TEXT NOT NULL,
+ mx_room TEXT NOT NULL,
+ ktid BIGINT,
+ kt_receiver BIGINT NOT NULL,
+ "index" SMALLINT NOT NULL,
+ kt_channel BIGINT NOT NULL,
+ timestamp BIGINT NOT NULL,
+ PRIMARY KEY (ktid, kt_channel, kt_receiver, "index"),
+ FOREIGN KEY (kt_channel, kt_receiver) REFERENCES portal(ktid, kt_receiver)
+ ON UPDATE CASCADE ON DELETE CASCADE,
+ UNIQUE (mxid, mx_room)
+ )"""
+ )
+ await conn.execute(
+ """
+ INSERT INTO message_v4 (mxid, mx_room, ktid, kt_receiver, "index", kt_channel, timestamp)
+ SELECT mxid, mx_room, ktid, kt_receiver, "index", kt_chat, timestamp FROM message
+ """
+ )
+ await conn.execute("DROP TABLE message")
+ await conn.execute("ALTER TABLE message_v4 RENAME TO message")
+
+ await conn.execute("ALTER TABLE portal ADD COLUMN fully_read_kt_chat BIGINT")
diff --git a/matrix_appservice_kakaotalk/kt/client/client.py b/matrix_appservice_kakaotalk/kt/client/client.py
index 7a09486..020bb91 100644
--- a/matrix_appservice_kakaotalk/kt/client/client.py
+++ b/matrix_appservice_kakaotalk/kt/client/client.py
@@ -67,6 +67,7 @@ from .types import (
ChannelProps,
PortalChannelInfo,
PortalChannelParticipantInfo,
+ Receipt,
SettingsStruct,
UserInfoUnion,
)
@@ -315,6 +316,14 @@ class Client:
limit=limit,
)
+ def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: list[Long]) -> Awaitable[list[Receipt]]:
+ return self._api_user_request_result(
+ ResultListType(Receipt),
+ "get_read_receipts",
+ channel_props=channel_props.serialize(),
+ unread_chat_ids=[c.serialize() for c in unread_chat_ids],
+ )
+
def list_friends(self) -> Awaitable[FriendListStruct]:
return self._api_user_request_result(
FriendListStruct,
diff --git a/matrix_appservice_kakaotalk/kt/client/types.py b/matrix_appservice_kakaotalk/kt/client/types.py
index 983fbdc..1a4e193 100644
--- a/matrix_appservice_kakaotalk/kt/client/types.py
+++ b/matrix_appservice_kakaotalk/kt/client/types.py
@@ -66,6 +66,11 @@ def deserialize_user_info_union(data: JSON) -> UserInfoUnion:
setattr(UserInfoUnion, "deserialize", deserialize_user_info_union)
+@dataclass
+class Receipt(SerializableAttrs):
+ userId: Long
+ chatId: Long
+
@dataclass
class PortalChannelParticipantInfo(SerializableAttrs):
participants: list[UserInfoUnion]
diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py
index 5364535..b8d7370 100644
--- a/matrix_appservice_kakaotalk/portal.py
+++ b/matrix_appservice_kakaotalk/portal.py
@@ -180,6 +180,7 @@ class Portal(DBPortal, BasePortal):
name_set: bool = False,
topic_set: bool = False,
avatar_set: bool = False,
+ fully_read_kt_chat: Long | None = None,
relay_user_id: UserID | None = None,
) -> None:
super().__init__(
@@ -195,6 +196,7 @@ class Portal(DBPortal, BasePortal):
name_set,
topic_set,
avatar_set,
+ fully_read_kt_chat,
relay_user_id,
)
self.log = self.log.getChild(self.ktid_log)
@@ -750,17 +752,23 @@ class Portal(DBPortal, BasePortal):
info = await self.update_info(source, info)
- # TODO Sync read receipts?
+ await self._sync_read_receipts(source)
- """
- async def _sync_read_receipts(self, receipts: list[None]) -> None:
+ async def _sync_read_receipts(self, source: u.User) -> None:
+ messages = await DBMessage.get_all_since(self.ktid, self.kt_receiver, self.fully_read_kt_chat)
+ receipts = await source.client.get_read_receipts(
+ self.channel_props,
+ [m.ktid for m in messages if m.ktid]
+ )
+ if not receipts:
+ return
for receipt in receipts:
message = await DBMessage.get_closest_before(
- self.ktid, self.kt_receiver, receipt.timestamp
+ self.ktid, self.kt_receiver, receipt.chatId
)
if not message:
continue
- puppet = await p.Puppet.get_by_ktid(receipt.actor.id, create=False)
+ puppet = await p.Puppet.get_by_ktid(receipt.userId, create=False)
if not puppet:
continue
try:
@@ -771,7 +779,10 @@ class Portal(DBPortal, BasePortal):
f"as read by {puppet.intent.mxid}",
exc_info=True,
)
- """
+ fully_read_kt_chat = min(receipt.chatId for receipt in receipts)
+ if not self.fully_read_kt_chat or self.fully_read_kt_chat < fully_read_kt_chat:
+ self.fully_read_kt_chat = fully_read_kt_chat
+ await self.save()
async def create_matrix_room(
self, source: u.User, info: PortalChannelInfo | None = None
@@ -830,7 +841,7 @@ class Portal(DBPortal, BasePortal):
self, source: u.User, info: PortalChannelInfo | None = None
) -> RoomID:
if self.mxid:
- await self._update_matrix_room(source, info)
+ await self._update_matrix_room(source, info=info)
return self.mxid
self.log.debug(f"Creating Matrix room")
@@ -947,11 +958,10 @@ class Portal(DBPortal, BasePortal):
if info.channel_info:
try:
await self.backfill(source, is_initial=True, channel_info=info.channel_info)
+ # NOTE This also syncs read receipts
except Exception:
self.log.exception("Failed to backfill new portal")
- # TODO Sync read receipts?
-
return self.mxid
# endregion
@@ -1061,7 +1071,7 @@ class Portal(DBPortal, BasePortal):
mx_room=self.mxid,
ktid=ktid,
index=0,
- kt_chat=self.ktid,
+ kt_channel=self.ktid,
kt_receiver=self.kt_receiver,
timestamp=int(time.time() * 1000),
)
@@ -1397,7 +1407,7 @@ class Portal(DBPortal, BasePortal):
async def _add_kakaotalk_reply(
self, content: MessageEventContent, reply_to: ReplyAttachment
) -> None:
- message = await DBMessage.get_by_ktid(reply_to.src_logId, self.kt_receiver)
+ message = await DBMessage.get_by_ktid(reply_to.src_logId, *self.ktid_full)
if not message:
self.log.warning(
f"Couldn't find reply target {reply_to.src_logId} to bridge reply metadata to Matrix"
@@ -1488,7 +1498,7 @@ class Portal(DBPortal, BasePortal):
# TODO Might have to handle remote reactions on messages created by bulk_create
await DBMessage.bulk_create(
ktid=chat.logId,
- kt_chat=self.ktid,
+ kt_channel=self.ktid,
kt_receiver=self.kt_receiver,
mx_room=self.mxid,
timestamp=chat.sendAt,
@@ -1705,7 +1715,7 @@ class Portal(DBPortal, BasePortal):
) -> None:
if not self.mxid:
return
- for message in await DBMessage.get_all_by_ktid(chat_id, self.kt_receiver):
+ for message in await DBMessage.get_all_by_ktid(chat_id, *self.ktid_full):
try:
await sender.intent_for(self).redact(
message.mx_room, message.mxid, timestamp=timestamp
@@ -1856,6 +1866,7 @@ class Portal(DBPortal, BasePortal):
self.log.trace("Leaving room with %s post-backfill", intent.mxid)
await intent.leave_room(self.mxid)
self.log.info("Backfilled %d messages through %s", len(chats), source.mxid)
+ self._sync_read_receipts(source)
# region Database getters
diff --git a/node/src/client.js b/node/src/client.js
index e2c3515..8024007 100644
--- a/node/src/client.js
+++ b/node/src/client.js
@@ -793,6 +793,47 @@ export default class PeerClient {
return res
}
+ /**
+ * @param {Object} req
+ * @param {string} req.mxid
+ * @param {ChannelProps} req.channel_props
+ * @param {[Long]} req.unread_chat_ids Must be in DECREASING order
+ */
+ getReadReceipts = async (req) => {
+ const talkChannel = await this.#getUserChannel(req.mxid, req.channel_props)
+ // TODO Is any pre-syncing needed?
+ const userCount = talkChannel.userCount
+ if (userCount == 1) return makeCommandResult([])
+ /** @type {Map */
+ const latestReceiptByUser = new Map()
+ let fullyRead = false
+ for (const chatId of req.unread_chat_ids) {
+ const chatReaders = talkChannel.getReaders({ logId: chatId })
+ for (const chatReader of chatReaders) {
+ if (!latestReceiptByUser.has(chatReader.userId)) {
+ latestReceiptByUser.set(chatReader.userId, chatId)
+ if (latestReceiptByUser.size == userCount) {
+ fullyRead = true
+ break
+ }
+ }
+ }
+ if (fullyRead) {
+ break
+ }
+ }
+
+ /**
+ * @typedef {Object} Receipt
+ * @property {Long} userId
+ * @property {Long} chatId
+ */
+ /** @type {[Receipt]} */
+ const receipts = []
+ latestReceiptByUser.forEach((value, key) => receipts.push({ "userId": key, "chatId": value }))
+ return makeCommandResult(receipts)
+ }
+
/**
* @param {Object} req
* @param {string} req.mxid
@@ -1081,6 +1122,7 @@ export default class PeerClient {
get_portal_channel_participant_info: this.getPortalChannelParticipantInfo,
get_participants: this.getParticipants,
get_chats: this.getChats,
+ get_read_receipts: this.getReadReceipts,
list_friends: this.listFriends,
get_friend_dm_id: req => this.getFriendProperty(req, "directChatId"),
get_memo_ids: this.getMemoIds,