# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
import asyncio
import time

from mautrix.bridge import BaseUser, async_getter_lock
from mautrix.types import (
    EventID,
    MessageType,
    RoomID,
    TextMessageEventContent,
    UserID,
)
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.opt_prometheus import Gauge, Summary, async_time
from mautrix.util.simple_lock import SimpleLock

from . import portal as po, puppet as pu
from .config import Config
from .db import User as DBUser
from .kt.client import Client
from .kt.client.errors import AuthenticationRequired, ResponseError
from .kt.types.api.struct.profile import ProfileStruct
from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import LoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
from .kt.types.packet.chat.kickout import KnownKickoutType, KickoutRes

METRIC_CONNECT_AND_SYNC = Summary("bridge_connect_and_sync", "calls to connect_and_sync")
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")

if TYPE_CHECKING:
    from .__main__ import KakaoTalkBridge

BridgeState.human_readable_errors.update(
    {
        "kt-reconnection-error": "Failed to reconnect to KakaoTalk",
        # TODO Use for errors in Node backend that cause session to be lost
        #"kt-connection-error": "KakaoTalk disconnected unexpectedly",
        "kt-auth-error": "Authentication error from KakaoTalk: {message}",
        "logged-out": "You're not logged into KakaoTalk",
    }
)


class User(DBUser, BaseUser):
    """A Matrix user of the bridge, optionally logged into KakaoTalk.

    Combines the database row (``DBUser``) with mautrix's ``BaseUser`` and
    holds the per-user KakaoTalk ``Client`` plus login/connection bookkeeping.
    """

    temp_disconnect_notices: bool = True
    shutdown: bool = False
    config: Config

    # Class-level caches shared by every instance, keyed by Matrix ID and
    # by KakaoTalk user ID respectively.
    by_mxid: dict[UserID, User] = {}
    by_ktid: dict[int, User] = {}

    client: Client | None

    _notice_room_lock: asyncio.Lock
    _notice_send_lock: asyncio.Lock
    is_admin: bool
    permission_level: str
    _is_logged_in: bool | None
    _is_connected: bool | None
    _connection_time: float
    _db_instance: DBUser | None
    _sync_lock: SimpleLock
    _logged_in_info: ProfileStruct | None
    _logged_in_info_time: float

    def __init__(
        self,
        mxid: UserID,
        ktid: Long | None = None,
        uuid: str | None = None,
        access_token: str | None = None,
        refresh_token: str | None = None,
        notice_room: RoomID | None = None,
    ) -> None:
        """Initialise the DB fields, the ``BaseUser`` part and runtime state."""
        super().__init__(
            mxid=mxid,
            ktid=ktid,
            uuid=uuid,
            access_token=access_token,
            refresh_token=refresh_token,
            notice_room=notice_room
        )
        BaseUser.__init__(self)
        self.notice_room = notice_room
        self._notice_room_lock = asyncio.Lock()
        self._notice_send_lock = asyncio.Lock()
        (
            self.relay_whitelisted,
            self.is_whitelisted,
            self.is_admin,
            self.permission_level,
        ) = self.config.get_permissions(mxid)
        # None means "not determined yet" for both flags below.
        self._is_logged_in = None
        self._is_connected = None
        self._connection_time = time.monotonic()
        self._sync_lock = SimpleLock(
            "Waiting for thread sync to finish before handling %s", log=self.log
        )
        self._logged_in_info = None
        self._logged_in_info_time = 0
        self.client = None

    @classmethod
    def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[bool]]:
        """Store bridge-wide references on the class.

        Returns an async iterable yielding one ``reload_session`` coroutine
        (with ``is_startup=True``) per user returned by ``all_logged_in``.
        """
        cls.bridge = bridge
        cls.config = bridge.config
        cls.az = bridge.az
        cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
        return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())

    @property
    def is_connected(self) -> bool | None:
        """Current connection flag (``None`` until first set)."""
        return self._is_connected

    @is_connected.setter
    def is_connected(self, val: bool | None) -> None:
        # Only a real change of value resets the connection timestamp.
        if self._is_connected != val:
            self._is_connected = val
            self._connection_time = time.monotonic()

    @property
    def connection_time(self) -> float:
        """Monotonic time at which ``is_connected`` last changed value."""
        return self._connection_time

    @property
    def has_state(self) -> bool:
        """Whether both an access token and a refresh token are stored."""
        # TODO If more state is needed, consider returning a saved LoginResult
        return bool(self.access_token and self.refresh_token)

    # region Database getters

    def _add_to_cache(self) -> None:
        """Register this instance in ``by_mxid`` and, if it has a ktid, ``by_ktid``."""
        self.by_mxid[self.mxid] = self
        if self.ktid:
            self.by_ktid[self.ktid] = self

    @classmethod
    async def all_logged_in(cls) -> AsyncGenerator["User", None]:
        """Yield every logged-in user, preferring already-cached instances."""
        users = await super().all_logged_in()
        user: cls
        for user in users:
            try:
                yield cls.by_mxid[user.mxid]
            except KeyError:
                user._add_to_cache()
                yield user

    @classmethod
    @async_getter_lock
    async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> User | None:
        """Look up a user by Matrix ID: cache first, then DB, then optional insert.

        Returns ``None`` for puppet MXIDs and for the appservice bot, and when
        the user does not exist and ``create`` is false.
        """
        if pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid:
            return None
        try:
            return cls.by_mxid[mxid]
        except KeyError:
            pass

        user = cast(cls, await super().get_by_mxid(mxid))
        if user is not None:
            user._add_to_cache()
            return user

        if create:
            cls.log.debug(f"Creating user instance for {mxid}")
            user = cls(mxid)
            await user.insert()
            user._add_to_cache()
            return user

        return None

    @classmethod
    @async_getter_lock
    async def get_by_ktid(cls, ktid: int) -> User | None:
        """Look up a user by KakaoTalk ID: cache first, then DB."""
        try:
            return cls.by_ktid[ktid]
        except KeyError:
            pass

        user = cast(cls, await super().get_by_ktid(ktid))
        if user is not None:
            user._add_to_cache()
            return user

        return None

    async def get_uuid(self, force: bool = False) -> str:
        """Return the device UUID, generating and saving one if missing or forced."""
        if self.uuid is None or force:
            self.uuid = await self._generate_uuid()
            # TODO Maybe don't save yet
            await self.save()
        return self.uuid

    @classmethod
    async def _generate_uuid(cls) -> str:
        """Ask the client to generate a UUID, passing it all UUIDs from the DB."""
        return await Client.generate_uuid(await super().get_all_uuids())

    # endregion

    @property
    def oauth_credential(self) -> OAuthCredential:
        """The stored ktid, uuid and tokens packaged as an ``OAuthCredential``."""
        return OAuthCredential(
            self.ktid,
            self.uuid,
            self.access_token,
            self.refresh_token,
        )

    @oauth_credential.setter
    def oauth_credential(self, oauth_credential: OAuthCredential) -> None:
        self.ktid = oauth_credential.userId
        self.access_token = oauth_credential.accessToken
        self.refresh_token = oauth_credential.refreshToken
        if self.uuid != oauth_credential.deviceUUID:
            # The credential's UUID wins; the mismatch is only logged.
            self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
            self.uuid = oauth_credential.deviceUUID

    async def get_own_info(self) -> ProfileStruct:
        """Return the user's own profile, re-fetching it when older than an hour."""
        if not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic():
            self._logged_in_info = await self.client.get_own_profile()
            self._logged_in_info_time = time.monotonic()
        return self._logged_in_info

    async def _load_session(self, is_startup: bool) -> bool:
        """Start a client from stored tokens and schedule post-login actions.

        Returns ``True`` on success (or if already logged in at startup) and
        ``False`` when no tokens are stored. Exceptions raised by
        ``client.start()`` propagate to the caller.
        """
        if self._is_logged_in and is_startup:
            return True
        elif not self.has_state:
            # If we have a user in the DB with no state, we can assume
            # KT logged us out and the bridge has restarted
            await self.push_bridge_state(
                BridgeStateEvent.BAD_CREDENTIALS,
                error="logged-out",
            )
            return False
        client = Client(self, log=self.log.getChild("ktclient"))
        user_info = await client.start()
        # NOTE On failure, client.start throws instead of returning False
        self.log.info("Loaded session successfully")
        self.client = client
        self._logged_in_info = user_info
        self._logged_in_info_time = time.monotonic()
        self._track_metric(METRIC_LOGGED_IN, True)
        self._is_logged_in = True
        self.is_connected = None
        asyncio.create_task(self.post_login(is_startup=is_startup))
        return True

    async def _send_reset_notice(self, e: AuthenticationRequired, edit: EventID | None = None) -> None:
        """Tell the user about an authentication error, then log out (keeping the ktid)."""
        await self.send_bridge_notice(
            "Got authentication error from KakaoTalk:\n\n"
            f"> {e.message}\n\n"
            "If you changed your KakaoTalk password, this "
            "is normal and you just need to log in again.",
            edit=edit,
            important=True,
            state_event=BridgeStateEvent.BAD_CREDENTIALS,
            error_code="kt-auth-error",
            error_message=str(e),
        )
        await self.logout(remove_ktid=False)

    async def is_logged_in(self, _override: bool = False) -> bool:
        """Return whether the user is logged in.

        Without tokens or a client the answer is ``False``. Otherwise the
        cached flag is used unless it is unset or ``_override`` is true, in
        which case it is refreshed by fetching the user's own profile.
        """
        if not self.has_state or not self.client:
            return False
        if self._is_logged_in is None or _override:
            try:
                self._is_logged_in = bool(await self.get_own_info())
            except Exception:
                self.log.exception("Exception checking login status")
                self._is_logged_in = False
        return self._is_logged_in

    async def reload_session(
        self, event_id: EventID | None = None, retries: int = 3, is_startup: bool = False
    ) -> None:
        """Load the session, reporting failures to the user as bridge notices.

        On ``ResponseError`` the call waits 60 seconds and retries itself
        while ``retries`` is positive; other failures are reported once.
        """
        try:
            await self._load_session(is_startup=is_startup)
        except AuthenticationRequired as e:
            await self._send_reset_notice(e, edit=event_id)
        # TODO Throw a ResponseError on network failures
        except ResponseError as e:
            will_retry = retries > 0
            retry = "Retrying in 1 minute" if will_retry else "Not retrying"
            notice = f"Failed to connect to KakaoTalk: unknown response error {e}. {retry}"
            if will_retry:
                await self.send_bridge_notice(
                    notice,
                    edit=event_id,
                    state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
                )
                await asyncio.sleep(60)
                await self.reload_session(event_id, retries - 1, is_startup)
            else:
                await self.send_bridge_notice(
                    notice,
                    edit=event_id,
                    important=True,
                    state_event=BridgeStateEvent.UNKNOWN_ERROR,
                    error_code="kt-reconnection-error",
                )
        except Exception:
            self.log.exception("Error connecting to KakaoTalk")
            await self.send_bridge_notice(
                "Failed to connect to KakaoTalk: unknown error (see logs for more details)",
                edit=event_id,
                state_event=BridgeStateEvent.UNKNOWN_ERROR,
                error_code="kt-reconnection-error",
            )

    async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
        """Stop the client, clear the tokens and save.

        :param remove_ktid: also push ``LOGGED_OUT``, drop the ``by_ktid``
            cache entry and clear ``ktid``.
        :param reset_device: replace the device UUID with a newly generated one.
        """
        if self.client:
            # TODO Look for a logout API call
            await self.client.stop()
        if remove_ktid:
            await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
        self._track_metric(METRIC_LOGGED_IN, False)

        if reset_device:
            self.uuid = await self._generate_uuid()
        self.access_token = None
        self.refresh_token = None

        self._is_logged_in = False
        self.is_connected = None
        self.client = None

        if self.ktid and remove_ktid:
            #await UserPortal.delete_all(self.ktid)
            # pop() rather than del: the user may never have been cached under
            # this ktid (e.g. login failed before post_login ran), and a
            # KeyError here would abort the logout before save().
            self.by_ktid.pop(self.ktid, None)
            self.ktid = None

        await self.save()

    async def post_login(self, is_startup: bool) -> None:
        """Cache the user, try to enable the custom puppet, then connect and sync."""
        self.log.info("Running post-login actions")
        self._add_to_cache()

        try:
            puppet = await pu.Puppet.get_by_ktid(self.ktid)

            if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
                self.log.info("Automatically enabling custom puppet")
                await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
        except Exception:
            # Best effort: a puppet failure must not prevent connecting.
            self.log.exception("Failed to automatically enable custom puppet")

        # TODO Check if things break when a live message comes in during syncing
        if self.config["bridge.sync_on_startup"] or not is_startup:
            sync_count = self.config["bridge.initial_chat_sync"]
        else:
            # NOTE(review): _sync_channels replaces None with
            # bridge.initial_chat_sync, so this branch currently syncs the same
            # number of channels as the one above - confirm that is intended.
            sync_count = None
        # TODO Don't auto-connect on startup if user's last state was disconnected
        await self.connect_and_sync(sync_count)

    async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
        """Map puppet MXIDs to the room of each of this user's portals that has one."""
        return {
            pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
            async for portal in po.Portal.get_all_by_receiver(self.ktid)
            if portal.mxid
        }

    @async_time(METRIC_CONNECT_AND_SYNC)
    async def connect_and_sync(self, sync_count: int | None) -> bool:
        """Connect the client and sync channels; return whether that succeeded.

        An authentication error is reported to the user and followed by a
        logout that keeps the ktid; any other error is logged and pushed as
        an ``UNKNOWN_ERROR`` bridge state.
        """
        # TODO Look for a way to sync all channels without (re-)logging in
        try:
            login_result = await self.client.connect()
            await self.on_connect()
            await self._sync_channels(login_result, sync_count)
            return True
        except AuthenticationRequired as e:
            await self.send_bridge_notice(
                f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n",
                important=True,
                state_event=BridgeStateEvent.BAD_CREDENTIALS,
                error_code="kt-auth-error",
                error_message=str(e),
            )
            await self.logout(remove_ktid=False)
        except Exception as e:
            self.log.exception("Failed to connect and sync")
            await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e))
        return False

    async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None:
        """Sync up to ``sync_count`` channels from ``login_result.channelList``.

        ``None`` means "use ``bridge.initial_chat_sync``", a falsy count skips
        syncing, and a negative count syncs every channel.
        """
        if sync_count is None:
            sync_count = self.config["bridge.initial_chat_sync"]
        if not sync_count:
            self.log.debug("Skipping channel syncing")
            return
        if not login_result.channelList:
            self.log.debug("No channels to sync")
            return
        # TODO What about removed channels? Don't early-return then

        num_channels = len(login_result.channelList)
        sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels)
        await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
        self.log.debug(f"Syncing {sync_count} of {num_channels} channels...")
        for login_data in login_result.channelList[:sync_count]:
            try:
                await self._sync_channel(login_data)
            except AuthenticationRequired:
                # Let the caller handle auth failures; other per-channel
                # errors are logged and the loop moves on.
                raise
            except Exception:
                self.log.exception(f"Failed to sync channel {login_data.channel.channelId}")

        await self.update_direct_chats()

    async def _sync_channel(self, login_data: LoginDataItem) -> None:
        """Log details of one channel, then create or update its Matrix room."""
        channel_data = login_data.channel
        self.log.debug(f"Syncing channel {channel_data.channelId} (last updated at {login_data.lastUpdate})")
        channel_info = channel_data.info
        # The bare annotations below only narrow types for readers/checkers.
        if isinstance(channel_data, NormalChannelData):
            channel_data: NormalChannelData
            channel_info: NormalChannelInfo
            self.log.debug(f"Join time: {channel_info.joinTime}")
        elif isinstance(channel_data, OpenChannelData):
            channel_data: OpenChannelData
            self.log.debug(f"channel_data link ID: {channel_data.linkId}")
            channel_info: OpenChannelInfo
            self.log.debug(f"channel_info link ID: {channel_info.linkId}")
            self.log.debug(f"openToken: {channel_info.openToken}")
            self.log.debug(f"Is direct channel: {channel_info.directChannel}")
            self.log.debug(f"Has OpenLink: {channel_info.openLink is not None}")
        else:
            self.log.error(f"Unexpected channel type: {type(channel_data)}")

        channel_info: ChannelInfo
        self.log.debug(f"channel_info channel ID: {channel_info.channelId}")
        self.log.debug(f"Channel data/info IDs match: {channel_data.channelId == channel_info.channelId}")
        self.log.debug(f"Channel type: {channel_info.type}")
        self.log.debug(f"Active user count: {channel_info.activeUserCount}")
        self.log.debug(f"New chat count: {channel_info.newChatCount}")
        self.log.debug(f"New chat count invalid: {channel_info.newChatCountInvalid}")
        self.log.debug(f"Last chat log ID: {channel_info.lastChatLogId}")
        self.log.debug(f"Last seen log ID: {channel_info.lastSeenLogId}")
        self.log.debug(f"Has last chat log: {channel_info.lastChatLog is not None}")
        self.log.debug(f"metaMap: {channel_info.metaMap}")
        self.log.debug(f"User count: {len(channel_info.displayUserList)}")
        self.log.debug(f"Has push alert: {channel_info.pushAlert}")
        for display_user_info in channel_info.displayUserList:
            self.log.debug(f"Member: {display_user_info.nickname} - {display_user_info.profileURL} - {display_user_info.userId}")

        portal = await po.Portal.get_by_ktid(
            channel_info.channelId,
            kt_receiver=self.ktid,
            kt_type=channel_info.type
        )
        portal_info = await self.client.get_portal_channel_info(portal.channel_props)
        portal_info.channel_info = channel_info
        if not portal.mxid:
            await portal.create_matrix_room(self, portal_info)
        else:
            await portal.update_matrix_room(self, portal_info)
            await portal.backfill(self, is_initial=False, channel_info=channel_info)

    async def get_notice_room(self) -> RoomID:
        """Return the bridge-notice room, creating and saving it on first use."""
        if not self.notice_room:
            async with self._notice_room_lock:
                # If someone already created the room while this call was waiting,
                # don't make a new room
                if self.notice_room:
                    return self.notice_room
                creation_content = {}
                if not self.config["bridge.federate_rooms"]:
                    creation_content["m.federate"] = False
                self.notice_room = await self.az.intent.create_room(
                    is_direct=True,
                    invitees=[self.mxid],
                    topic="KakaoTalk bridge notices",
                    creation_content=creation_content,
                )
                await self.save()
        return self.notice_room

    async def send_bridge_notice(
        self,
        text: str,
        edit: EventID | None = None,
        state_event: BridgeStateEvent | None = None,
        important: bool = False,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> EventID | None:
        """Optionally push a bridge state, then post ``text`` to the notice room.

        :param edit: event to edit instead of sending a new message.
        :param state_event: bridge state to push before sending, if any.
        :param important: send as a text message rather than a notice.
        :param error_code: error code for the pushed state; when set,
            ``error_message`` is used as the state message instead of ``text``.
        :return: ``edit`` if given, else the new event ID; ``None`` when
            notices are disabled or sending failed.
        """
        if state_event:
            await self.push_bridge_state(
                state_event,
                error=error_code,
                message=error_message if error_code else text,
            )
        if self.config["bridge.disable_bridge_notices"]:
            return None
        event_id = None
        try:
            self.log.debug("Sending bridge notice: %s", text)
            content = TextMessageEventContent(
                body=text,
                msgtype=(MessageType.TEXT if important else MessageType.NOTICE),
            )
            if edit:
                content.set_edit(edit)
            # This is locked to prevent notices going out in the wrong order
            async with self._notice_send_lock:
                event_id = await self.az.intent.send_message(await self.get_notice_room(), content)
        except Exception:
            # Deliberately best-effort: a failed notice is only logged.
            self.log.warning("Failed to send bridge notice", exc_info=True)
        return edit or event_id

    async def fill_bridge_state(self, state: BridgeState) -> None:
        """Add the KakaoTalk ID and puppet name to ``state`` when a ktid is set."""
        await super().fill_bridge_state(state)
        if self.ktid:
            state.remote_id = str(self.ktid)
            puppet = await pu.Puppet.get_by_ktid(self.ktid)
            state.remote_name = puppet.name

    async def get_bridge_states(self) -> list[BridgeState]:
        """Return the current bridge state, or nothing when no tokens are stored."""
        # Uses has_state, the same "do we hold credentials" check that
        # _load_session and is_logged_in use; this file defines no `state`
        # attribute on the user.
        if not self.has_state:
            return []
        state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
        if self.is_connected:
            state.state_event = BridgeStateEvent.CONNECTED
        # TODO
        #elif self._is_logged_in and self._is_reconnecting:
        #    state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
        return [state]

    async def get_puppet(self) -> pu.Puppet | None:
        """Return this user's own puppet, or ``None`` without a ktid."""
        if not self.ktid:
            return None
        return await pu.Puppet.get_by_ktid(self.ktid)

    async def get_portal_with(self, puppet: pu.Puppet, create: bool = True) -> po.Portal | None:
        """Not implemented yet: always returns ``None``."""
        # TODO
        return None
        # Planned implementation, kept for reference:
        #
        # if not self.ktid or not self.client:
        #     return None
        # return await po.Portal.get_by_ktid(
        #     await self.client.get_dm_channel_id_for(puppet.ktid),
        #     kt_receiver=self.ktid,
        #     create=create,
        #     kt_type=KnownChannelType.DirectChat
        #     if puppet.ktid != self.ktid
        #     else KnownChannelType.MemoChat
        # )

    # region KakaoTalk event handling

    async def on_connect(self) -> None:
        """Mark the user connected, update the metric and push ``CONNECTED``."""
        now = time.monotonic()
        # Read before the setter below overwrites the timestamp.
        disconnected_at = self._connection_time
        max_delay = self.config["bridge.resync_max_disconnected_time"]
        first_connect = self.is_connected is None
        self.is_connected = True
        self._track_metric(METRIC_CONNECTED, True)
        if not first_connect and disconnected_at + max_delay < now:
            duration = int(now - disconnected_at)
            self.log.debug(f"Disconnection lasted {duration} seconds")
        elif self.temp_disconnect_notices:
            await self.send_bridge_notice("Connected to KakaoTalk chats")
        await self.push_bridge_state(BridgeStateEvent.CONNECTED)

    async def on_disconnect(self, res: KickoutRes | None) -> None:
        """Mark the user disconnected and, given a kickout response, notify them.

        A deleted account additionally triggers a full logout.
        """
        self.is_connected = False
        self._track_metric(METRIC_CONNECTED, False)
        if res:
            # Everything that reads reason_str/logout stays inside this branch,
            # as those names only exist when a kickout response was received.
            logout = False
            if res.reason == KnownKickoutType.LOGIN_ANOTHER:
                reason_str = "Logged in from another desktop client."
            elif res.reason == KnownKickoutType.CHANGE_SERVER:
                # TODO Reconnect automatically instead
                reason_str = "KakaoTalk backend server changed."
            elif res.reason == KnownKickoutType.ACCOUNT_DELETED:
                reason_str = "Your KakaoTalk account was deleted!"
                logout = True
            else:
                reason_str = f"Unknown reason ({res.reason})."
            if not logout:
                reason_suffix = " To reconnect, use the `sync` command."
                # TODO What bridge state to push?
            else:
                reason_suffix = " You are now logged out."
                await self.logout()
            await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}")

    async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
        """Store new credentials, start a client and schedule post-login actions."""
        self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
        self.oauth_credential = oauth_credential
        await self.push_bridge_state(BridgeStateEvent.CONNECTING)
        self.client = Client(self, log=self.log.getChild("ktclient"))
        await self.save()
        self._is_logged_in = True
        # TODO Retry network connection failures here, or in the client (like token refreshes are)?
        #      Should also catch unlikely authentication errors
        self._logged_in_info = await self.client.start()
        self._logged_in_info_time = time.monotonic()
        asyncio.create_task(self.post_login(is_startup=True))

    @async_time(METRIC_MESSAGE)
    async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None:
        """Hand an incoming chat log to its portal, after waiting on the backfill lock."""
        portal = await po.Portal.get_by_ktid(
            channel_id,
            kt_receiver=self.ktid,
            kt_type=channel_type
        )
        puppet = await pu.Puppet.get_by_ktid(evt.sender.userId)
        await portal.backfill_lock.wait(evt.logId)
        if not puppet.name:
            # Sender has no name stored yet: schedule a resync.
            portal.schedule_resync(self, puppet)
        await portal.handle_remote_message(self, puppet, evt)

    # TODO Many more handlers

    # endregion