From d30402a98f763aaa933a1a43d5bbad0e36eb8420 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Thu, 22 Apr 2021 02:39:52 -0400 Subject: [PATCH] More robust message syncing and room cleaning --- matrix_puppeteer_line/db/message.py | 4 +- matrix_puppeteer_line/db/portal.py | 6 ++- matrix_puppeteer_line/matrix.py | 17 +++++-- matrix_puppeteer_line/portal.py | 35 ++++++------- matrix_puppeteer_line/rpc/client.py | 10 ++-- matrix_puppeteer_line/user.py | 1 + puppet/src/puppet.js | 78 ++++++++++++++++------------- 7 files changed, 88 insertions(+), 63 deletions(-) diff --git a/matrix_puppeteer_line/db/message.py b/matrix_puppeteer_line/db/message.py index f536c8f..8b27284 100644 --- a/matrix_puppeteer_line/db/message.py +++ b/matrix_puppeteer_line/db/message.py @@ -30,7 +30,7 @@ class Message: mxid: EventID mx_room: RoomID mid: int - chat_id: int + chat_id: str async def insert(self) -> None: q = "INSERT INTO message (mxid, mx_room, mid, chat_id) VALUES ($1, $2, $3, $4)" @@ -49,7 +49,7 @@ class Message: return await cls.db.fetchval("SELECT MAX(mid) FROM message WHERE mx_room=$1", room_id) @classmethod - async def get_max_mids(cls) -> Dict[int, int]: + async def get_max_mids(cls) -> Dict[str, int]: rows = await cls.db.fetch("SELECT chat_id, MAX(mid) AS max_mid " "FROM message GROUP BY chat_id") data = {} diff --git a/matrix_puppeteer_line/db/portal.py b/matrix_puppeteer_line/db/portal.py index b5d764c..b1a587d 100644 --- a/matrix_puppeteer_line/db/portal.py +++ b/matrix_puppeteer_line/db/portal.py @@ -50,6 +50,10 @@ class Portal: self.icon_path, self.icon_mxc, self.encrypted) + async def delete(self) -> None: + q = "DELETE FROM portal WHERE chat_id=$1" + await self.db.execute(q, self.chat_id) + @classmethod async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']: q = ("SELECT chat_id, other_user, mxid, name, icon_path, icon_mxc, encrypted " @@ -60,7 +64,7 @@ class Portal: return cls(**row) @classmethod - async def get_by_chat_id(cls, chat_id: int) -> Optional['Portal']: + async def get_by_chat_id(cls, chat_id: str) -> Optional['Portal']: q = ("SELECT chat_id, other_user, mxid, name, icon_path, icon_mxc, encrypted " "FROM portal WHERE chat_id=$1") row = await cls.db.fetchrow(q, chat_id) diff --git a/matrix_puppeteer_line/matrix.py b/matrix_puppeteer_line/matrix.py index 85da96d..a2553a8 100644 --- a/matrix_puppeteer_line/matrix.py +++ b/matrix_puppeteer_line/matrix.py @@ -16,10 +16,10 @@ from typing import TYPE_CHECKING from mautrix.bridge import BaseMatrixHandler -from mautrix.types import (Event, ReactionEvent, MessageEvent, StateEvent, EncryptedEvent, RoomID, - RedactionEvent) +from mautrix.types import (Event, ReactionEvent, MessageEvent, StateEvent, EncryptedEvent, RedactionEvent, + EventID, RoomID, UserID) -from . import puppet as pu, user as u +from . import portal as po, puppet as pu, user as u if TYPE_CHECKING: from .__main__ import MessagesBridge @@ -48,3 +48,14 @@ class MatrixHandler(BaseMatrixHandler): await inviter.update() await self.az.intent.send_notice(room_id, "This room has been marked as your " "LINE bridge notice room.") + + async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None: + portal = await po.Portal.get_by_mxid(room_id) + if not portal: + return + + user = await u.User.get_by_mxid(user_id, create=False) + if not user: + return + + await portal.handle_matrix_leave(user) diff --git a/matrix_puppeteer_line/portal.py b/matrix_puppeteer_line/portal.py index da318c5..32af678 100644 --- a/matrix_puppeteer_line/portal.py +++ b/matrix_puppeteer_line/portal.py @@ -56,7 +56,7 @@ ReuploadedMediaInfo = NamedTuple('ReuploadedMediaInfo', mxc=Optional[ContentURI] class Portal(DBPortal, BasePortal): invite_own_puppet_to_pm: bool = False by_mxid: Dict[RoomID, 'Portal'] = {} - by_chat_id: Dict[int, 'Portal'] = {} + by_chat_id: Dict[str, 'Portal'] = {} config: Config matrix: 'm.MatrixHandler' az: AppService @@ -66,7 +66,7 @@ class Portal(DBPortal, BasePortal): backfill_lock: SimpleLock _last_participant_update: Set[str] - def __init__(self, chat_id: int, other_user: Optional[str] = None, + def __init__(self, chat_id: str, other_user: Optional[str] = None, mxid: Optional[RoomID] = None, name: Optional[str] = None, icon_path: Optional[str] = None, icon_mxc: Optional[ContentURI] = None, encrypted: bool = False) -> None: @@ -100,6 +100,7 @@ class Portal(DBPortal, BasePortal): @classmethod def init_cls(cls, bridge: 'MessagesBridge') -> None: + BasePortal.bridge = bridge cls.config = bridge.config cls.matrix = bridge.matrix cls.az = bridge.az @@ -163,13 +164,15 @@ class Portal(DBPortal, BasePortal): self.log.warning(f"Handled Matrix message {event_id} -> {message_id}") async def handle_matrix_leave(self, user: 'u.User') -> None: - if self.is_direct: - self.log.info(f"{user.mxid} left private chat portal with {self.other_user}, " - f"cleaning up and deleting...") - await self.cleanup_and_delete() - else: - self.log.debug(f"{user.mxid} left portal to {self.chat_id}") - # TODO cleanup if empty + self.log.info(f"{user.mxid} left portal to {self.chat_id}, " + f"cleaning up and deleting...") + if self.invite_own_puppet_to_pm: + # TODO Use own puppet instead of bridge bot. Then cleanup_and_delete will handle it + try: + await self.az.intent.leave_room(self.mxid) + except MatrixError: + pass + await self.cleanup_and_delete() async def _bridge_own_message_pm(self, source: 'u.User', sender: Optional['p.Puppet'], mid: str, invite: bool = True) -> Optional[IntentAPI]: @@ -586,14 +589,12 @@ class Portal(DBPortal, BasePortal): self._main_intent = self.az.intent async def delete(self) -> None: - await DBMessage.delete_all(self.mxid) + if self.mxid: + # TODO Handle this with db foreign keys instead + await DBMessage.delete_all(self.mxid) + self.by_chat_id.pop(self.chat_id, None) self.by_mxid.pop(self.mxid, None) - self.mxid = None - self.name = None - self.icon_path = None - self.icon_mxc = None - self.encrypted = False - await self.update() + await super().delete() async def save(self) -> None: await self.update() @@ -624,7 +625,7 @@ class Portal(DBPortal, BasePortal): return None @classmethod - async def get_by_chat_id(cls, chat_id: int, create: bool = False) -> Optional['Portal']: + async def get_by_chat_id(cls, chat_id: str, create: bool = False) -> Optional['Portal']: try: return cls.by_chat_id[chat_id] except KeyError: diff --git a/matrix_puppeteer_line/rpc/client.py b/matrix_puppeteer_line/rpc/client.py index b919cde..44edc42 100644 --- a/matrix_puppeteer_line/rpc/client.py +++ b/matrix_puppeteer_line/rpc/client.py @@ -45,10 +45,10 @@ class Client(RPCClient): resp = await self.request("get_chats") return [ChatListInfo.deserialize(data) for data in resp] - async def get_chat(self, chat_id: int) -> ChatInfo: + async def get_chat(self, chat_id: str) -> ChatInfo: return ChatInfo.deserialize(await self.request("get_chat", chat_id=chat_id)) - async def get_messages(self, chat_id: int) -> List[Message]: + async def get_messages(self, chat_id: str) -> List[Message]: resp = await self.request("get_messages", chat_id=chat_id) return [Message.deserialize(data) for data in resp] @@ -75,15 +75,15 @@ class Client(RPCClient): resp = await self.request("is_connected") return resp["is_connected"] - async def send(self, chat_id: int, text: str) -> int: + async def send(self, chat_id: str, text: str) -> int: resp = await self.request("send", chat_id=chat_id, text=text) return resp["id"] - async def send_file(self, chat_id: int, file_path: str) -> int: + async def send_file(self, chat_id: str, file_path: str) -> int: resp = await self.request("send_file", chat_id=chat_id, file_path=file_path) return resp["id"] - async def set_last_message_ids(self, msg_ids: Dict[int, int]) -> None: + async def set_last_message_ids(self, msg_ids: Dict[str, int]) -> None: await self.request("set_last_message_ids", msg_ids=msg_ids) async def on_message(self, func: Callable[[Message], Awaitable[None]]) -> None: diff --git a/matrix_puppeteer_line/user.py b/matrix_puppeteer_line/user.py index 50fd3d9..32d3a53 100644 --- a/matrix_puppeteer_line/user.py +++ b/matrix_puppeteer_line/user.py @@ -148,6 +148,7 @@ class User(DBUser, BaseUser): portal = await po.Portal.get_by_chat_id(evt.chat_id, create=True) puppet = await pu.Puppet.get_by_mid(evt.sender.id) if not portal.is_direct else None if not portal.mxid: + await self.client.set_last_message_ids(await DBMessage.get_max_mids()) chat_info = await self.client.get_chat(evt.chat_id) await portal.create_matrix_room(self, chat_info) await portal.handle_remote_message(self, puppet, evt) diff --git a/puppet/src/puppet.js b/puppet/src/puppet.js index 97adc1d..4524ef3 100644 --- a/puppet/src/puppet.js +++ b/puppet/src/puppet.js @@ -344,17 +344,17 @@ export default class MessagesPuppeteer { /** * Get info about a chat. * - * @param {number} id - The chat ID whose info to get. + * @param {string} chatID - The chat ID whose info to get. * @return {Promise} - Info about the chat. */ - async getChatInfo(id) { - return await this.taskQueue.push(() => this._getChatInfoUnsafe(id)) + async getChatInfo(chatID) { + return await this.taskQueue.push(() => this._getChatInfoUnsafe(chatID)) } /** * Send a message to a chat. * - * @param {number} chatID - The ID of the chat to send a message to. + * @param {string} chatID - The ID of the chat to send a message to. * @param {string} text - The text to send. * @return {Promise<{id: number}>} - The ID of the sent message. */ @@ -365,25 +365,31 @@ export default class MessagesPuppeteer { /** * Get messages in a chat. * - * @param {number} id The ID of the chat whose messages to get. + * @param {string} chatID The ID of the chat whose messages to get. * @return {Promise<[MessageData]>} - The messages visible in the chat. */ - async getMessages(id) { + async getMessages(chatID) { return this.taskQueue.push(async () => { - const messages = await this._getMessagesUnsafe(id) + const messages = await this._getMessagesUnsafe(chatID) if (messages.length > 0) { - this.mostRecentMessages.set(id, messages[messages.length - 1].id) + // TODO Commonize this + const newFirstID = messages[0].id + const newLastID = messages[messages.length - 1].id + this.mostRecentMessages.set(chatID, newLastID) + const range = newFirstID === newLastID ? newFirstID : `${newFirstID}-${newLastID}` + this.log(`Loaded ${messages.length} messages in ${chatID}: got ${range}`) } for (const message of messages) { - message.chat_id = id + message.chat_id = chatID } return messages }) } setLastMessageIDs(ids) { + this.mostRecentMessages.clear() for (const [chatID, messageID] of Object.entries(ids)) { - this.mostRecentMessages.set(+chatID, messageID) + this.mostRecentMessages.set(chatID, messageID) } this.log("Updated most recent message ID map:", this.mostRecentMessages) } @@ -449,10 +455,10 @@ export default class MessagesPuppeteer { return `#_chat_list_body div[data-chatid="${id}"]` } - async _switchChat(id) { + async _switchChat(chatID) { // TODO Allow passing in an element directly - this.log(`Switching to chat ${id}`) - const chatListItem = await this.page.$(this._listItemSelector(id)) + this.log(`Switching to chat ${chatID}`) + const chatListItem = await this.page.$(this._listItemSelector(chatID)) const chatName = await chatListItem.evaluate( element => window.__mautrixController.getChatListItemName(element)) @@ -499,14 +505,14 @@ export default class MessagesPuppeteer { //return participantList } - async _getChatInfoUnsafe(id) { - const chatListItem = await this.page.$(this._listItemSelector(id)) + async _getChatInfoUnsafe(chatID) { + const chatListItem = await this.page.$(this._listItemSelector(chatID)) const chatListInfo = await chatListItem.evaluate( - (element, id) => window.__mautrixController.parseChatListItem(element, id), - id) + (element, chatID) => window.__mautrixController.parseChatListItem(element, chatID), + chatID) let [isDirect, isGroup, isRoom] = [false,false,false] - switch (id.charAt(0)) { + switch (chatID.charAt(0)) { case "u": isDirect = true break @@ -522,18 +528,18 @@ export default class MessagesPuppeteer { if (!isDirect) { this.log("Found multi-user chat, so clicking chat header to get participants") // TODO This will mark the chat as "read"! - await this._switchChat(id) + await this._switchChat(chatID) const participantList = await this.getParticipantList() // TODO Is a group not actually created until a message is sent(?) // If so, maybe don't create a portal until there is a message. participants = await participantList.evaluate( element => window.__mautrixController.parseParticipantList(element)) } else { - this.log(`Found direct chat with ${id}`) + this.log(`Found direct chat with ${chatID}`) //const chatDetailArea = await this.page.waitForSelector("#_chat_detail_area > .mdRGT02Info") //await chatDetailArea.$(".MdTxtDesc02") || // 1:1 chat with custom title - get participant's real name participants = [{ - id: id, + id: chatID, avatar: chatListInfo.icon, name: chatListInfo.name, }] @@ -576,35 +582,37 @@ export default class MessagesPuppeteer { // TODO Inbound read receipts // Probably use a MutationObserver mapped to msgID - async _getMessagesUnsafe(id, minID = 0) { + async _getMessagesUnsafe(chatID) { // TODO Also handle "decrypting" state // TODO Handle unloaded messages. Maybe scroll up // TODO This will mark the chat as "read"! - await this._switchChat(id) - this.log("Waiting for messages to load") + await this._switchChat(chatID) + const minID = this.mostRecentMessages.get(chatID) || 0 + this.log(`Waiting for messages newer than ${minID}`) const messages = await this.page.evaluate( - id => window.__mautrixController.parseMessageList(id), id) - return messages.filter(msg => msg.id > minID && !this.sentMessageIDs.has(msg.id)) + chatID => window.__mautrixController.parseMessageList(chatID), chatID) + const filtered_messages = messages.filter(msg => msg.id > minID && !this.sentMessageIDs.has(msg.id)) + this.log(`Found messages: ${messages.length} total, ${filtered_messages.length} new`) + return filtered_messages } - async _processChatListChangeUnsafe(id) { - this.updatedChats.delete(id) - this.log("Processing change to", id) - const lastMsgID = this.mostRecentMessages.get(id) || 0 - const messages = await this._getMessagesUnsafe(id, lastMsgID) + async _processChatListChangeUnsafe(chatID) { + this.updatedChats.delete(chatID) + this.log("Processing change to", chatID) + const messages = await this._getMessagesUnsafe(chatID) if (messages.length === 0) { - this.log("No new messages found in", id) + this.log("No new messages found in", chatID) return } const newFirstID = messages[0].id const newLastID = messages[messages.length - 1].id - this.mostRecentMessages.set(id, newLastID) + this.mostRecentMessages.set(chatID, newLastID) const range = newFirstID === newLastID ? newFirstID : `${newFirstID}-${newLastID}` - this.log(`Loaded ${messages.length} messages in ${id} after ${lastMsgID}: got ${range}`) + this.log(`Loaded ${messages.length} messages in ${chatID}: got ${range}`) if (this.client) { for (const message of messages) { - message.chat_id = id + message.chat_id = chatID await this.client.sendMessage(message).catch(err => this.error("Failed to send message", message.id, "to client:", err)) }