diff --git a/matrix_appservice_kakaotalk/db/message.py b/matrix_appservice_kakaotalk/db/message.py index 79c69b2..5c8d878 100644 --- a/matrix_appservice_kakaotalk/db/message.py +++ b/matrix_appservice_kakaotalk/db/message.py @@ -42,7 +42,7 @@ class Message: timestamp: int @classmethod - def _from_row(cls, row: Record | None) -> Message | None: + def _from_row(cls, row: Record) -> Message | None: data = {**row} ktid = data.pop("ktid") kt_chat = data.pop("kt_chat") @@ -64,13 +64,13 @@ class Message: async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]: q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2" rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver)) - return [cls._from_row(row) for row in rows] + return [cls._from_row(row) for row in rows if row] @classmethod async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None: q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index) - return cls._from_row(row) + return cls._from_optional_row(row) @classmethod async def delete_all_by_room(cls, room_id: RoomID) -> None: @@ -80,7 +80,7 @@ class Message: async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None: q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2" row = await cls.db.fetchrow(q, mxid, mx_room) - return cls._from_row(row) + return cls._from_optional_row(row) @classmethod async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None: @@ -90,7 +90,7 @@ class Message: "ORDER BY timestamp DESC LIMIT 1" ) row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver)) - return cls._from_row(row) + return cls._from_optional_row(row) @classmethod async def get_closest_before( @@ -103,7 +103,7 @@ class Message: "ORDER BY timestamp DESC LIMIT 1" ) row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp) - return cls._from_row(row) + return cls._from_optional_row(row) _insert_query = ( 'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, ' diff --git a/matrix_appservice_kakaotalk/kt/client/client.py b/matrix_appservice_kakaotalk/kt/client/client.py index d1729ed..678e226 100644 --- a/matrix_appservice_kakaotalk/kt/client/client.py +++ b/matrix_appservice_kakaotalk/kt/client/client.py @@ -39,6 +39,7 @@ from ...rpc import RPCClient from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct from ..types.bson import Long from ..types.client.client_session import LoginResult +from ..types.channel.channel_type import ChannelType from ..types.chat.chat import Chatlog from ..types.oauth import OAuthCredential, OAuthInfo from ..types.request import ( @@ -208,14 +209,12 @@ class Client: ) return profile_req_struct.profile - async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo: - req = await self._api_user_request_result( + async def get_portal_channel_info(self, channel_id: Long) -> PortalChannelInfo: + return await self._api_user_request_result( PortalChannelInfo, "get_portal_channel_info", - channel_id=channel_info.channelId.serialize() + channel_id=channel_id.serialize() ) - req.channel_info = channel_info - return req async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]: return (await self._api_user_request_result( @@ -233,6 +232,10 @@ class Client: text=text ) + async def start_listen(self) -> None: + # TODO Connect all listeners here? + await self._api_user_request_void("start_listen") + async def stop(self) -> None: # TODO Stop all event handlers await self._api_user_request_void("stop") @@ -269,15 +272,19 @@ class Client: # region listeners - async def on_message(self, func: Callable[[Chatlog, Long], Awaitable[None]]) -> None: + async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None: async def wrapper(data: dict[str, JSON]) -> None: - await func(Chatlog.deserialize(data["chatlog"]), data["channelId"]) + await func( + Chatlog.deserialize(data["chatlog"]), + Long.deserialize(data["channelId"]), + data["channelType"] + ) - self._add_user_handler("chat", wrapper) + self._add_user_handler("message", wrapper) def _add_user_handler(self, command: str, handler: EventHandler) -> str: - self._rpc_client.add_event_handler(f"{command}:{self.mxid}", handler) + self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler) # endregion diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py index 8dc8913..1187461 100644 --- a/matrix_appservice_kakaotalk/portal.py +++ b/matrix_appservice_kakaotalk/portal.py @@ -93,13 +93,14 @@ class Portal(DBPortal, BasePortal): _main_intent: IntentAPI | None _create_room_lock: asyncio.Lock - _dedup: deque[str] - _oti_dedup: dict[int, DBMessage] _send_locks: dict[int, asyncio.Lock] _noop_lock: FakeLock = FakeLock() _typing: set[UserID] backfill_lock: SimpleLock _backfill_leave: set[IntentAPI] | None + _sleeping_to_resync: bool + _scheduled_resync: asyncio.Task | None + _resync_targets: dict[int, p.Puppet] def __init__( self, @@ -132,10 +133,11 @@ class Portal(DBPortal, BasePortal): self._main_intent = None self._create_room_lock = asyncio.Lock() - self._dedup = deque(maxlen=100) - self._oti_dedup = {} self._send_locks = {} self._typing = set() + self._sleeping_to_resync = False + self._scheduled_resync = None + self._resync_targets = {} self.backfill_lock = SimpleLock( "Waiting for backfilling to finish before handling %s", log=self.log @@ -190,12 +192,45 @@ class Portal(DBPortal, BasePortal): # endregion # region Chat info updating + def schedule_resync(self, source: u.User, target: p.Puppet) -> None: + self._resync_targets[target.ktid] = target + if ( + self._sleeping_to_resync + and self._scheduled_resync + and not self._scheduled_resync.done() + ): + return + self._sleeping_to_resync = True + self.log.debug(f"Scheduling resync through {source.mxid}/{source.ktid}") + self._scheduled_resync = asyncio.create_task(self._sleep_and_resync(source, 10)) + + async def _sleep_and_resync(self, source: u.User, sleep: int) -> None: + await asyncio.sleep(sleep) + targets = self._resync_targets + self._sleeping_to_resync = False + self._resync_targets = {} + for puppet in targets.values(): + if not puppet.name or not puppet.name_set: + break + else: + self.log.debug( + f"Cancelled resync through {source.mxid}/{source.ktid}, all puppets have names" + ) + return + self.log.debug(f"Resyncing chat through {source.mxid}/{source.ktid} after sleeping") + await self.update_info(source) + self._scheduled_resync = None + self.log.debug(f"Completed scheduled resync through {source.mxid}/{source.ktid}") + async def update_info( self, source: u.User, - info: PortalChannelInfo, + info: PortalChannelInfo | None = None, force_save: bool = False, - ) -> None: + ) -> PortalChannelInfo | None: + if not info: + self.log.debug("Called update_info with no info, fetching channel info...") + info = await source.client.get_portal_channel_info(self.ktid) changed = False if not self.is_direct: changed = any( @@ -209,6 +244,7 @@ class Portal(DBPortal, BasePortal): if changed or force_save: await self.update_bridge_info() await self.save() + return info """ @classmethod @@ -365,7 +401,7 @@ class Portal(DBPortal, BasePortal): # endregion # region Matrix room creation - async def update_matrix_room(self, source: u.User, info: PortalChannelInfo) -> None: + async def update_matrix_room(self, source: u.User, info: PortalChannelInfo | None = None) -> None: try: await self._update_matrix_room(source, info) except Exception: @@ -380,7 +416,7 @@ class Portal(DBPortal, BasePortal): return invite_content async def _update_matrix_room( - self, source: u.User, info: PortalChannelInfo + self, source: u.User, info: PortalChannelInfo | None = None ) -> None: puppet = await p.Puppet.get_by_custom_mxid(source.mxid) await self.main_intent.invite_user( @@ -394,7 +430,10 @@ class Portal(DBPortal, BasePortal): if did_join and self.is_direct: await source.update_direct_chats({self.main_intent.mxid: [self.mxid]}) - await self.update_info(source, info) + info = await self.update_info(source, info) + if not info: + self.log.warning("Canceling _update_matrix_room as update_info didn't return info") + return # TODO #await self._sync_read_receipts(info.read_receipts.nodes) @@ -421,7 +460,7 @@ class Portal(DBPortal, BasePortal): """ async def create_matrix_room( - self, source: u.User, info: PortalChannelInfo + self, source: u.User, info: PortalChannelInfo | None = None ) -> RoomID | None: if self.mxid: try: @@ -474,7 +513,7 @@ class Portal(DBPortal, BasePortal): self.log.warning("Failed to update bridge info", exc_info=True) async def _create_matrix_room( - self, source: u.User, info: PortalChannelInfo + self, source: u.User, info: PortalChannelInfo | None = None ) -> RoomID | None: if self.mxid: await self._update_matrix_room(source, info) @@ -507,7 +546,10 @@ class Portal(DBPortal, BasePortal): if self.is_direct: invites.append(self.az.bot_mxid) - await self.update_info(source=source, info=info) + info = await self.update_info(source=source, info=info) + if not info: + self.log.debug("update_info() didn't return info, cancelling room creation") + return None if self.encrypted or not self.is_direct: name = self.name diff --git a/matrix_appservice_kakaotalk/user.py b/matrix_appservice_kakaotalk/user.py index ba75fcd..40a8e61 100644 --- a/matrix_appservice_kakaotalk/user.py +++ b/matrix_appservice_kakaotalk/user.py @@ -39,14 +39,12 @@ from .kt.client import Client from .kt.client.errors import AuthenticationRequired, ResponseError from .kt.types.api.struct.profile import ProfileStruct from .kt.types.bson import Long -from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo -from .kt.types.channel.channel_info import NormalChannelData +from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData +from .kt.types.channel.channel_type import ChannelType from .kt.types.chat.chat import Chatlog -from .kt.types.client.client_session import ChannelLoginDataItem -from .kt.types.client.client_session import LoginResult +from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult from .kt.types.oauth import OAuthCredential -from .kt.types.openlink.open_channel_info import OpenChannelData -from .kt.types.openlink.open_channel_info import OpenChannelInfo +from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels") METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync") @@ -418,7 +416,7 @@ class User(DBUser, BaseUser): # TODO if not is_startup, close existing listeners login_result = await self.client.start() await self._sync_channels(login_result, is_startup) - # TODO connect listeners, even if channel sync fails (except if it's an auth failure) + self.start_listen() except AuthenticationRequired as e: await self.send_bridge_notice( f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n", @@ -498,7 +496,8 @@ class User(DBUser, BaseUser): kt_receiver=self.ktid, kt_type=channel_info.type ) - portal_info = await self.client.get_portal_channel_info(channel_info) + portal_info = await self.client.get_portal_channel_info(channel_info.channelId) + portal_info.channel_info = channel_info if not portal.mxid: await portal.create_matrix_room(self, portal_info) else: @@ -585,6 +584,30 @@ class User(DBUser, BaseUser): # region KakaoTalk event handling + def start_listen(self) -> None: + self.listen_task = asyncio.create_task(self._try_listen()) + + def _disconnect_listener_after_error(self) -> None: + self.log.info("TODO: _disconnect_listener_after_error") + + async def _try_listen(self) -> None: + try: + # TODO Pass all listeners to start_listen instead of registering them one-by-one? + await self.client.start_listen() + await self.client.on_message(self.on_message) + # TODO Handle auth errors specially? + #except AuthenticationRequired as e: + except Exception: + #self.is_connected = False + self.log.exception("Fatal error in listener") + await self.send_bridge_notice( + "Fatal error in listener (see logs for more info)", + state_event=BridgeStateEvent.UNKNOWN_ERROR, + important=True, + error_code="kt-connection-error", + ) + self._disconnect_listener_after_error() + def stop_listen(self) -> None: self.log.info("TODO: stop_listen") @@ -603,8 +626,18 @@ class User(DBUser, BaseUser): asyncio.create_task(self.post_login(is_startup=True)) @async_time(METRIC_MESSAGE) - async def on_message(self, evt: Chatlog, channel_id: Long) -> None: - self.log.info("TODO: on_message") + async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None: + portal = await po.Portal.get_by_ktid( + channel_id, + kt_receiver=self.ktid, + kt_type=channel_type + ) + puppet = await pu.Puppet.get_by_ktid(evt.sender.userId) + await portal.backfill_lock.wait(evt.logId) + if not puppet.name: + portal.schedule_resync(self, puppet) + # TODO reply_to + await portal.handle_remote_message(self, puppet, evt) # TODO Many more handlers diff --git a/node/src/client.js b/node/src/client.js index 474af59..e20c40f 100644 --- a/node/src/client.js +++ b/node/src/client.js @@ -266,28 +266,33 @@ export default class PeerClient { const res = await userClient.talkClient.login(req.oauth_credential) if (!res.success) return res - // Attach listeners in something like start_listen - /* + this.userClients.set(req.mxid, userClient) + return res + } + + startListen = async (req) => { + const userClient = this.#getUser(req.mxid) + userClient.talkClient.on("chat", (data, channel) => { - this.log(`Found message in channel ${channel.channelId}`) + this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`) return this.#write({ id: --this.notificationID, - command: userClient.getCmd("chat"), - //is_sequential: true, // TODO make sequential per user! - chatlog: data.chat(), + command: userClient.getCmd("message"), + //is_sequential: true, // TODO Make sequential per user & channel (if it isn't already) + chatlog: data.chat, channelId: channel.channelId, + channelType: channel.info.type, }) }) - /* + /* TODO Many more listeners userClient.talkClient.on("chat_read", (chat, channel, reader) => { this.log(`chat_read in channel ${channel.channelId}`) //chat.logId }) */ - this.userClients.set(req.mxid, userClient) - return res + return this.#voidCommandResult } /** @@ -448,6 +453,7 @@ export default class PeerClient { renew: this.handleRenew, generate_uuid: util.randomAndroidSubDeviceUUID, register_device: this.registerDevice, + start_listen: this.startListen, get_own_profile: this.getOwnProfile, get_portal_channel_info: this.getPortalChannelInfo, get_chats: this.getChats,