# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Bridge user model: KakaoTalk session lifecycle, channel sync and event handling."""
from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
import asyncio
import time

from mautrix.bridge import BaseUser, async_getter_lock
from mautrix.types import (
    EventID,
    MessageType,
    RoomID,
    TextMessageEventContent,
    UserID,
)
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.opt_prometheus import Gauge, Summary, async_time
from mautrix.util.simple_lock import SimpleLock

# Sibling modules are imported as modules (not names) and accessed as po./pu.
# at call time; NOTE(review): presumably to tolerate circular imports — confirm.
from . import portal as po, puppet as pu
from .config import Config
from .db import User as DBUser
from .kt.client import Client
from .kt.client.errors import AuthenticationRequired, ResponseError
from .kt.types.api.struct.profile import ProfileStruct
from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo

# Per-handler Prometheus summaries, meant to wrap the matching User handler via
# @async_time(...). Only METRIC_SYNC_CHANNELS and METRIC_MESSAGE appear to be
# wired up so far; the rest look reserved for handlers not yet implemented.
METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels")
METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync")
METRIC_UNKNOWN_EVENT = Summary("bridge_on_unknown_event", "calls to on_unknown_event")
METRIC_MEMBERS_ADDED = Summary("bridge_on_members_added", "calls to on_members_added")
METRIC_MEMBER_REMOVED = Summary("bridge_on_member_removed", "calls to on_member_removed")
METRIC_TYPING = Summary("bridge_on_typing", "calls to on_typing")
METRIC_PRESENCE = Summary("bridge_on_presence", "calls to on_presence")
METRIC_REACTION = Summary("bridge_on_reaction", "calls to on_reaction")
METRIC_MESSAGE_UNSENT = Summary("bridge_on_unsent", "calls to on_unsent")
METRIC_MESSAGE_SEEN = Summary("bridge_on_message_seen", "calls to on_message_seen")
METRIC_TITLE_CHANGE = Summary("bridge_on_title_change", "calls to on_title_change")
METRIC_AVATAR_CHANGE = Summary("bridge_on_avatar_change", "calls to on_avatar_change")
METRIC_THREAD_CHANGE = Summary("bridge_on_thread_change", "calls to on_thread_change")
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
# Gauges toggled via BaseUser._track_metric(). METRIC_LOGGED_IN is updated on
# login/logout; METRIC_CONNECTED is not updated in this file while the
# is_connected tracking remains commented out.
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")

# Needed only for annotations (made lazy by `from __future__ import annotations`),
# so the bridge entry point is never imported at runtime from here.
if TYPE_CHECKING:
    from .__main__ import KakaoTalkBridge
# Human-readable text for the bridge-specific error codes this module passes as
# `error=` to push_bridge_state (directly, or as send_bridge_notice's error_code).
# "{message}" is presumably filled in from the state's message — confirm in mautrix.
BridgeState.human_readable_errors.update(
    {
        "kt-reconnection-error": "Failed to reconnect to KakaoTalk",
        "kt-connection-error": "KakaoTalk disconnected unexpectedly",
        "kt-auth-error": "Authentication error from KakaoTalk: {message}",
        "kt-start-error": "Startup error from KakaoTalk: {message}",
        "kt-disconnected": None,
        "logged-out": "You're not logged into KakaoTalk",
    }
)


class User(DBUser, BaseUser):
    """A Matrix user of the bridge and the KakaoTalk session it puppets.

    Instances are cached by Matrix ID (``by_mxid``) and, once known, by KakaoTalk
    user ID (``by_ktid``). The persisted fields (ktid, device uuid, OAuth tokens,
    notice room) come from the ``DBUser`` base; ``client`` holds the live
    KakaoTalk client while a session is loaded and is reset to None on logout.
    """

    #temp_disconnect_notices: bool = True
    shutdown: bool = False
    config: Config

    by_mxid: dict[UserID, User] = {}
    by_ktid: dict[int, User] = {}

    # Strong references to fire-and-forget tasks. The event loop only keeps weak
    # references to tasks, so an unreferenced task can be garbage-collected mid-run.
    _background_tasks: set[asyncio.Task] = set()

    client: Client | None

    _notice_room_lock: asyncio.Lock
    _notice_send_lock: asyncio.Lock
    command_status: dict | None
    is_admin: bool
    permission_level: str
    _is_logged_in: bool | None
    #_is_connected: bool | None
    #_connection_time: float
    _prev_reconnect_fail_refresh: float
    _db_instance: DBUser | None
    _sync_lock: SimpleLock
    _is_refreshing: bool
    _logged_in_info: ProfileStruct | None
    _logged_in_info_time: float

    def __init__(
        self,
        mxid: UserID,
        ktid: Long | None = None,
        uuid: str | None = None,
        access_token: str | None = None,
        refresh_token: str | None = None,
        notice_room: RoomID | None = None,
    ) -> None:
        super().__init__(
            mxid=mxid,
            ktid=ktid,
            uuid=uuid,
            access_token=access_token,
            refresh_token=refresh_token,
            notice_room=notice_room,
        )
        # The DB model's __init__ does not chain to BaseUser, so call it explicitly.
        BaseUser.__init__(self)
        self.notice_room = notice_room
        self._notice_room_lock = asyncio.Lock()
        self._notice_send_lock = asyncio.Lock()
        self.command_status = None
        (
            self.relay_whitelisted,
            self.is_whitelisted,
            self.is_admin,
            self.permission_level,
        ) = self.config.get_permissions(mxid)
        self._is_logged_in = None
        #self._is_connected = None
        #self._connection_time = time.monotonic()
        self._prev_reconnect_fail_refresh = time.monotonic()
        self._sync_lock = SimpleLock(
            "Waiting for thread sync to finish before handling %s", log=self.log
        )
        self._is_refreshing = False
        self._logged_in_info = None
        self._logged_in_info_time = 0
        self.client = None

    @classmethod
    def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[None]]:
        """Bind bridge-wide objects to the class.

        Returns an async iterator yielding one ``reload_session(is_startup=True)``
        coroutine per logged-in user; the caller is responsible for awaiting them.
        """
        cls.bridge = bridge
        cls.config = bridge.config
        cls.az = bridge.az
        cls.loop = bridge.loop
        #cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
        return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())

    """
    @property
    def is_connected(self) -> bool | None:
        return self._is_connected

    @is_connected.setter
    def is_connected(self, val: bool | None) -> None:
        if self._is_connected != val:
            self._is_connected = val
            self._connection_time = time.monotonic()

    @property
    def connection_time(self) -> float:
        return self._connection_time
    """

    @property
    def has_state(self) -> bool:
        """True when every credential needed to restore a session is stored."""
        return bool(self.uuid and self.ktid and self.access_token and self.refresh_token)

    def _create_background_task(self, coro: Awaitable[bool]) -> asyncio.Task:
        """Schedule *coro* as a task, holding a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    # region Database getters

    def _add_to_cache(self) -> None:
        """Register this instance in the MXID cache, and the ktid cache if known."""
        self.by_mxid[self.mxid] = self
        if self.ktid:
            self.by_ktid[self.ktid] = self

    @classmethod
    async def all_logged_in(cls) -> AsyncGenerator["User", None]:
        """Yield every stored logged-in user, preferring already-cached instances."""
        users = await super().all_logged_in()
        user: cls
        for user in users:
            try:
                yield cls.by_mxid[user.mxid]
            except KeyError:
                user._add_to_cache()
                yield user

    @classmethod
    @async_getter_lock
    async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> User | None:
        """Look up a user by Matrix ID (cache first, then database).

        Returns None for puppet MXIDs and for the bridge bot. When no user exists
        and *create* is true, a new one is created and inserted into the database.
        """
        if pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid:
            return None
        try:
            return cls.by_mxid[mxid]
        except KeyError:
            pass

        user = cast(cls, await super().get_by_mxid(mxid))
        if user is not None:
            user._add_to_cache()
            return user

        if create:
            cls.log.debug(f"Creating user instance for {mxid}")
            user = cls(mxid)
            await user.insert()
            user._add_to_cache()
            return user

        return None

    @classmethod
    @async_getter_lock
    async def get_by_ktid(cls, ktid: int) -> User | None:
        """Look up a user by KakaoTalk user ID (cache first, then database)."""
        try:
            return cls.by_ktid[ktid]
        except KeyError:
            pass

        user = cast(cls, await super().get_by_ktid(ktid))
        if user is not None:
            user._add_to_cache()
            return user

        return None

    async def get_uuid(self, force: bool = False) -> str:
        """Return the device UUID, generating and saving one if unset or *force*."""
        if self.uuid is None or force:
            self.uuid = await self._generate_uuid()
            await self.save()
        return self.uuid

    async def _generate_uuid(self) -> str:
        # All stored UUIDs are passed in, presumably so the new one is unique.
        return await Client.generate_uuid(await self.get_all_uuids())

    # endregion

    @property
    def oauth_credential(self) -> OAuthCredential:
        """OAuth credential built from the stored ktid, device UUID and tokens."""
        return OAuthCredential(
            self.ktid,
            self.uuid,
            self.access_token,
            self.refresh_token,
        )

    @oauth_credential.setter
    def oauth_credential(self, oauth_credential: OAuthCredential) -> None:
        """Store the IDs/tokens of *oauth_credential*; warns if the device UUID changed."""
        self.ktid = oauth_credential.userId
        self.access_token = oauth_credential.accessToken
        self.refresh_token = oauth_credential.refreshToken
        if self.uuid != oauth_credential.deviceUUID:
            self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
            self.uuid = oauth_credential.deviceUUID

    async def get_own_info(self) -> ProfileStruct:
        """Return the logged-in user's profile, refetching it once it is over an hour old."""
        if not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic():
            self._logged_in_info = await self.client.fetch_logged_in_user()
            self._logged_in_info_time = time.monotonic()
        return self._logged_in_info

    async def _load_session(self, is_startup: bool) -> bool:
        """Restore the KakaoTalk session from the stored credentials.

        Returns True if already logged in (outside startup) or if the session was
        loaded; in the latter case post_login() is scheduled in the background.
        Returns False if no credentials are stored (after pushing a
        BAD_CREDENTIALS state) or if no profile came back.
        Raises whatever fetch_logged_in_user() raises.
        """
        if self._is_logged_in and not is_startup:
            return True
        elif not self.has_state:
            # If we have a user in the DB with no state, we can assume
            # KT logged us out and the bridge has restarted
            await self.push_bridge_state(
                BridgeStateEvent.BAD_CREDENTIALS,
                error="logged-out",
            )
            return False
        client = Client(self, log=self.log.getChild("ktclient"))
        user_info = await self.fetch_logged_in_user(client)
        if user_info:
            self.log.info("Loaded session successfully")
            self.client = client
            self._logged_in_info = user_info
            self._logged_in_info_time = time.monotonic()
            self._track_metric(METRIC_LOGGED_IN, True)
            self._is_logged_in = True
            #self.is_connected = None
            self.stop_listen()
            self._create_background_task(self.post_login(is_startup=is_startup))
            return True
        return False

    async def _send_reset_notice(self, e: AuthenticationRequired, edit: EventID | None = None) -> None:
        """Report an authentication error to the user, then log out (keeping the ktid)."""
        await self.send_bridge_notice(
            "Got authentication error from KakaoTalk:\n\n"
            f"> {e.message}\n\n"
            "If you changed your KakaoTalk password, this "
            "is normal and you just need to log in again.",
            edit=edit,
            important=True,
            state_event=BridgeStateEvent.BAD_CREDENTIALS,
            error_code="kt-auth-error",
            error_message=str(e),
        )
        await self.logout(remove_ktid=False)

    async def fetch_logged_in_user(
        self, client: Client | None = None, action: str = "restore session"
    ) -> ProfileStruct:
        """Fetch the logged-in user's profile via *client* (default: self.client).

        AuthenticationRequired is re-raised; unless *action* is "restore session"
        (whose caller, reload_session, reports it itself) a reset notice is sent
        first. Any other exception is logged and re-raised.
        """
        if not client:
            client = self.client
        # TODO Retry network connection failures here, or in the client?
        try:
            return await client.fetch_logged_in_user()
            # NOTE Not catching InvalidAccessToken here, as client handles it & tries to refresh the token
        except AuthenticationRequired as e:
            if action != "restore session":
                await self._send_reset_notice(e)
            raise
        except Exception:
            self.log.exception(f"Failed to {action}")
            raise

    async def is_logged_in(self, _override: bool = False) -> bool:
        """Whether a working session exists.

        The answer is cached after the first check; pass ``_override`` to re-check.
        Any error during the check is logged and counts as "not logged in".
        """
        if not self.has_state or not self.client:
            return False
        if self._is_logged_in is None or _override:
            try:
                self._is_logged_in = bool(await self.get_own_info())
            except Exception:
                self.log.exception("Exception checking login status")
                self._is_logged_in = False
        return self._is_logged_in

    async def reload_session(
        self, event_id: EventID | None = None, retries: int = 3, is_startup: bool = False
    ) -> None:
        """(Re)load the KakaoTalk session, reporting failures as bridge notices.

        ResponseErrors are retried up to *retries* more times, one minute apart,
        with the same *is_startup* flag. Authentication errors trigger a reset
        notice and logout. Notices edit *event_id* when given.
        """
        try:
            await self._load_session(is_startup=is_startup)
        except AuthenticationRequired as e:
            await self._send_reset_notice(e, edit=event_id)
        # TODO Throw a ResponseError on network failures
        except ResponseError as e:
            will_retry = retries > 0
            retry = "Retrying in 1 minute" if will_retry else "Not retrying"
            notice = f"Failed to connect to KakaoTalk: unknown response error {e}. {retry}"
            if will_retry:
                await self.send_bridge_notice(
                    notice,
                    edit=event_id,
                    state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
                )
                await asyncio.sleep(60)
                # Forward is_startup so a startup retry still honors sync_on_startup.
                await self.reload_session(event_id, retries - 1, is_startup=is_startup)
            else:
                await self.send_bridge_notice(
                    notice,
                    edit=event_id,
                    important=True,
                    state_event=BridgeStateEvent.UNKNOWN_ERROR,
                    error_code="kt-reconnection-error",
                )
        except Exception:
            # The notice below points the user at the logs, so make sure it's there.
            self.log.exception("Failed to reload session")
            await self.send_bridge_notice(
                "Failed to connect to KakaoTalk: unknown error (see logs for more details)",
                edit=event_id,
                state_event=BridgeStateEvent.UNKNOWN_ERROR,
                error_code="kt-reconnection-error",
            )
        finally:
            self._is_refreshing = False

    async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> bool:
        """Stop listening, drop the client and clear the stored tokens, then save.

        Args:
            remove_ktid: Also forget the KakaoTalk user ID and push a LOGGED_OUT state.
            reset_device: Also generate a new device UUID.

        Returns:
            Whether logout succeeded (currently always True; there is no remote
            logout call yet).
        """
        ok = True
        self.stop_listen()
        if self.has_state:
            # TODO Log out of KakaoTalk if an API exists for it
            pass
        if remove_ktid:
            await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
        self._track_metric(METRIC_LOGGED_IN, False)
        self._is_logged_in = False
        #self.is_connected = None
        if self.client:
            await self.client.stop()
            self.client = None

        if self.ktid and remove_ktid:
            #await UserPortal.delete_all(self.ktid)
            # pop() rather than del: the user may not be cached under this ID yet,
            # e.g. if post_login() (which calls _add_to_cache) never ran.
            self.by_ktid.pop(self.ktid, None)
            self.ktid = None

        if reset_device:
            self.uuid = await self._generate_uuid()
        self.access_token = None
        self.refresh_token = None

        await self.save()
        return ok

    async def post_login(self, is_startup: bool) -> bool:
        """Finish logging in: double-puppet, start the client, sync channels, listen.

        Returns True on success. On failure a bridge notice is sent and False is
        returned; authentication errors also log the user out (keeping the ktid).
        """
        self.log.info("Running post-login actions")
        self._add_to_cache()

        try:
            puppet = await pu.Puppet.get_by_ktid(self.ktid)

            if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
                self.log.info("Automatically enabling custom puppet")
                await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
        except Exception:
            self.log.exception("Failed to automatically enable custom puppet")

        assert self.client
        try:
            # TODO if not is_startup, close existing listeners
            login_result = await self.client.start()
            await self._sync_channels(login_result, is_startup)
            self.start_listen()
            return True
        except AuthenticationRequired as e:
            await self.send_bridge_notice(
                f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n",
                important=True,
                state_event=BridgeStateEvent.BAD_CREDENTIALS,
                error_code="kt-auth-error",
                error_message=str(e),
            )
            await self.logout(remove_ktid=False)
        except Exception as e:
            self.log.exception("Failed to start client")
            await self.send_bridge_notice(
                f"Got error from KakaoTalk:\n\n> {e!s}\n\n",
                important=True,
                state_event=BridgeStateEvent.UNKNOWN_ERROR,
                error_code="kt-start-error",
                error_message=str(e),
            )
        return False

    async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
        """Map puppet MXIDs to room IDs for this user's portals that have a room."""
        # NOTE(review): portal.ktid looks like a channel ID (_sync_channel looks
        # portals up by channelId), and get_all_by_receiver is not limited to DMs,
        # so get_mxid_from_id(portal.ktid) may not yield the DM partner's MXID.
        # Verify against the Portal model.
        return {
            pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
            async for portal in po.Portal.get_all_by_receiver(self.ktid)
            if portal.mxid
        }

    @async_time(METRIC_SYNC_CHANNELS)
    async def _sync_channels(self, login_result: LoginResult, is_startup: bool) -> None:
        """Create/update portals for the first ``bridge.initial_chat_sync`` channels.

        Skipped entirely when that setting is <= 0, or on startup when
        ``bridge.sync_on_startup`` is disabled.
        """
        # TODO Look for a way to sync all channels without (re-)logging in
        sync_count = self.config["bridge.initial_chat_sync"]
        if sync_count <= 0 or (not self.config["bridge.sync_on_startup"] and is_startup):
            self.log.debug(f"Skipping channel syncing{' on startup' if sync_count > 0 else ''}")
            return
        if not login_result.channelList:
            self.log.debug("No channels to sync")
            return
        # TODO What about removed channels? Don't early-return then

        total_count = len(login_result.channelList)
        sync_count = min(sync_count, total_count)
        await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
        self.log.debug(f"Syncing {sync_count} of {total_count} channels...")
        for channel_item in login_result.channelList[:sync_count]:
            # TODO try-except here, above, below?
            await self._sync_channel(channel_item)

    async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None:
        """Create (or update and backfill) the Matrix portal for one channel.

        Channel metadata is logged at debug level first; which fields are logged
        depends on whether the channel is a normal or an open channel.
        """
        channel_data = channel_item.channel
        self.log.debug(f"Syncing channel {channel_data.channelId} (last updated at {channel_item.lastUpdate})")
        channel_info = channel_data.info
        if isinstance(channel_data, NormalChannelData):
            channel_data: NormalChannelData
            channel_info: NormalChannelInfo
            self.log.debug(f"Join time: {channel_info.joinTime}")
        elif isinstance(channel_data, OpenChannelData):
            channel_data: OpenChannelData
            self.log.debug(f"channel_data link ID: {channel_data.linkId}")
            channel_info: OpenChannelInfo
            self.log.debug(f"channel_info link ID: {channel_info.linkId}")
            self.log.debug(f"openToken: {channel_info.openToken}")
            self.log.debug(f"Is direct channel: {channel_info.directChannel}")
            self.log.debug(f"Has OpenLink: {channel_info.openLink is not None}")
        else:
            self.log.error(f"Unexpected channel type: {type(channel_data)}")

        channel_info: ChannelInfo
        self.log.debug(f"channel_info channel ID: {channel_info.channelId}")
        self.log.debug(f"Channel data/info IDs match: {channel_data.channelId == channel_info.channelId}")
        self.log.debug(f"Channel type: {channel_info.type}")
        self.log.debug(f"Active user count: {channel_info.activeUserCount}")
        self.log.debug(f"New chat count: {channel_info.newChatCount}")
        self.log.debug(f"New chat count invalid: {channel_info.newChatCountInvalid}")
        self.log.debug(f"Last chat log ID: {channel_info.lastChatLogId}")
        self.log.debug(f"Last seen log ID: {channel_info.lastSeenLogId}")
        self.log.debug(f"Has last chat log: {channel_info.lastChatLog is not None}")
        self.log.debug(f"metaMap: {channel_info.metaMap}")
        self.log.debug(f"User count: {len(channel_info.displayUserList)}")
        self.log.debug(f"Has push alert: {channel_info.pushAlert}")
        for display_user_info in channel_info.displayUserList:
            self.log.debug(f"Member: {display_user_info.nickname} - {display_user_info.profileURL} - {display_user_info.userId}")

        portal = await po.Portal.get_by_ktid(
            channel_info.channelId,
            kt_receiver=self.ktid,
            kt_type=channel_info.type
        )
        portal_info = await self.client.get_portal_channel_info(channel_info.channelId)
        portal_info.channel_info = channel_info
        if not portal.mxid:
            await portal.create_matrix_room(self, portal_info)
        else:
            await portal.update_matrix_room(self, portal_info)
            await portal.backfill(self, is_initial=False, channel_info=channel_info)

    async def get_notice_room(self) -> RoomID:
        """Return the bridge notice room, creating (and saving) it on first use."""
        if not self.notice_room:
            async with self._notice_room_lock:
                # If someone already created the room while this call was waiting,
                # don't make a new room
                if self.notice_room:
                    return self.notice_room
                creation_content = {}
                if not self.config["bridge.federate_rooms"]:
                    creation_content["m.federate"] = False
                self.notice_room = await self.az.intent.create_room(
                    is_direct=True,
                    invitees=[self.mxid],
                    topic="KakaoTalk bridge notices",
                    creation_content=creation_content,
                )
                await self.save()
        return self.notice_room

    async def send_bridge_notice(
        self,
        text: str,
        edit: EventID | None = None,
        state_event: BridgeStateEvent | None = None,
        important: bool = False,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> EventID | None:
        """Optionally push a bridge state, then post *text* to the notice room.

        Important notices are sent as m.text, others as m.notice; *edit* turns the
        message into an edit of that event. Returns None when notices are disabled
        in the config, otherwise *edit* or the new event ID (None if sending failed;
        send failures are logged, not raised).
        """
        if state_event:
            await self.push_bridge_state(
                state_event,
                error=error_code,
                message=error_message if error_code else text,
            )
        if self.config["bridge.disable_bridge_notices"]:
            return None
        event_id = None
        try:
            self.log.debug("Sending bridge notice: %s", text)
            content = TextMessageEventContent(
                body=text,
                msgtype=(MessageType.TEXT if important else MessageType.NOTICE),
            )
            if edit:
                content.set_edit(edit)
            # This is locked to prevent notices going out in the wrong order
            async with self._notice_send_lock:
                event_id = await self.az.intent.send_message(await self.get_notice_room(), content)
        except Exception:
            self.log.warning("Failed to send bridge notice", exc_info=True)
        return edit or event_id

    async def fill_bridge_state(self, state: BridgeState) -> None:
        """Add the KakaoTalk user ID and puppet name to an outgoing bridge state."""
        await super().fill_bridge_state(state)
        if self.ktid:
            state.remote_id = str(self.ktid)
            puppet = await pu.Puppet.get_by_ktid(self.ktid)
            state.remote_name = puppet.name

    async def get_bridge_states(self) -> list[BridgeState]:
        """Not implemented yet: logs a TODO and reports no states."""
        self.log.info("TODO: get_bridge_states")
        return []
        """
        if not self.state:
            return []
        state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
        if self.is_connected:
            state.state_event = BridgeStateEvent.CONNECTED
        elif self._is_refreshing or self.mqtt:
            state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
        return [state]
        """

    async def get_puppet(self) -> pu.Puppet | None:
        """Return this user's own puppet, or None if the KakaoTalk ID is unknown."""
        if not self.ktid:
            return None
        return await pu.Puppet.get_by_ktid(self.ktid)

    # region KakaoTalk event handling

    def start_listen(self) -> None:
        """Run _try_listen() in a task kept as ``self.listen_task``."""
        self.listen_task = asyncio.create_task(self._try_listen())

    def _disconnect_listener_after_error(self) -> None:
        self.log.info("TODO: _disconnect_listener_after_error")

    async def _try_listen(self) -> None:
        """Start the client's listener and register handlers; report fatal errors."""
        try:
            # TODO Pass all listeners to start_listen instead of registering them one-by-one?
            await self.client.start_listen()
            await self.client.on_message(self.on_message)
        # TODO Handle auth errors specially?
        #except AuthenticationRequired as e:
        except Exception:
            #self.is_connected = False
            self.log.exception("Fatal error in listener")
            await self.send_bridge_notice(
                "Fatal error in listener (see logs for more info)",
                state_event=BridgeStateEvent.UNKNOWN_ERROR,
                important=True,
                error_code="kt-connection-error",
            )
            self._disconnect_listener_after_error()

    def stop_listen(self) -> None:
        # Placeholder: nothing is stopped yet.
        self.log.info("TODO: stop_listen")

    async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
        """Handle a fresh login: store credentials, create the client, run post-login.

        Fetching the profile is best-effort (errors are logged); post_login() is
        scheduled in the background with is_startup=True.
        """
        self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
        self.oauth_credential = oauth_credential
        self.client = Client(self, log=self.log.getChild("ktclient"))
        await self.save()
        self._is_logged_in = True
        try:
            self._logged_in_info = await self.client.fetch_logged_in_user(post_login=True)
            self._logged_in_info_time = time.monotonic()
        except Exception:
            self.log.exception("Failed to fetch post-login info")
        self.stop_listen()
        self._create_background_task(self.post_login(is_startup=True))

    @async_time(METRIC_MESSAGE)
    async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None:
        """Forward an incoming KakaoTalk chat log to its portal."""
        portal = await po.Portal.get_by_ktid(
            channel_id,
            kt_receiver=self.ktid,
            kt_type=channel_type
        )
        puppet = await pu.Puppet.get_by_ktid(evt.sender.userId)
        # Presumably waits until any in-progress backfill has passed this log ID,
        # so live messages are not bridged out of order — confirm in Portal.
        await portal.backfill_lock.wait(evt.logId)
        if not puppet.name:
            portal.schedule_resync(self, puppet)
        # TODO reply_to
        await portal.handle_remote_message(self, puppet, evt)

    # TODO Many more handlers

    # endregion