# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge. # Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. """ Client functionality for the KakaoTalk API. Currently a wrapper around a Node backend, but the abstraction used here should be compatible with any other potential backend. """ from __future__ import annotations from typing import TYPE_CHECKING, cast, Awaitable, Type, Optional, Union import asyncio from contextlib import asynccontextmanager import logging import urllib.request from aiohttp import ClientSession from aiohttp.client import _RequestContextManager from yarl import URL from mautrix.types import SerializerError from mautrix.util.logging import TraceLogger from ...config import Config from ...rpc import EventHandler, RPCClient from ..types.api.struct import ( FriendListStruct, FriendReqStruct, FriendStruct, ProfileReqStruct, ProfileStruct, ) from ..types.bson import Long from ..types.channel.channel_info import ChannelInfo from ..types.chat import Chatlog, KnownChatType from ..types.chat.attachment import MentionStruct, ReplyAttachment from ..types.client.client_session import LoginResult from ..types.oauth import OAuthCredential, OAuthInfo from ..types.openlink.open_link_type import OpenChannelUserPerm from ..types.openlink.open_link_user_info import OpenLinkChannelUserInfo from ..types.packet.chat.kickout import KnownKickoutType, KickoutRes from ..types.request import ( deserialize_result, ResultType, ResultListType, ResultRawType, RootCommandResult, CommandResultDoneValue ) from .types import ( ChannelProps, PortalChannelInfo, PortalChannelParticipantInfo, Receipt, SettingsStruct, UserInfoUnion, ) from .errors import InvalidAccessToken, CommandException from .error_helper import raise_unsuccessful_response try: from aiohttp_socks import ProxyConnector except ImportError: ProxyConnector = None if TYPE_CHECKING: from mautrix.types import JSON from ... import user as u @asynccontextmanager async def sandboxed_get(url: URL) -> _RequestContextManager: async with ClientSession() as sess, sess.get(url) as resp: yield resp # TODO Consider defining an interface for this, with node/native backend as swappable implementations # TODO If no state is stored, consider using free functions instead of classmethods class Client: _rpc_client: RPCClient @classmethod def init_cls(cls, config: Config) -> None: """Initialize RPC to the Node backend.""" cls._rpc_client = RPCClient(config, "kakaotalk") # NOTE No need to store this, as cancelling the RPCClient will cancel this too asyncio.create_task(cls._keep_connected()) @classmethod async def _keep_connected(cls) -> None: while True: await cls._rpc_client.connect() await cls._rpc_client.wait_for_disconnection() @classmethod def wait_for_connection(cls) -> Awaitable[None]: return cls._rpc_client.wait_for_connection() @classmethod def stop_cls(cls) -> None: """Stop and disconnect from the Node backend.""" cls._rpc_client.cancel() # region tokenless commands @classmethod async def generate_uuid(cls, used_uuids: set[str]) -> str: """Randomly generate a UUID for a (fake) device.""" tries_remaining = 10 while True: uuid = await cls._rpc_client.request("generate_uuid") if uuid not in used_uuids: return uuid tries_remaining -= 1 if tries_remaining == 0: raise Exception( "Unable to generate a UUID that hasn't been used before. " "Either use a different RNG, or buy a lottery ticket" ) @classmethod def register_device(cls, passcode: str, **req: JSON) -> Awaitable[None]: """Register a (fake) device that will be associated with the provided login credentials.""" return cls._api_request_void("register_device", passcode=passcode, **req) @classmethod def login(cls, uuid: str, form: JSON, forced: bool) -> Awaitable[OAuthCredential]: """ Obtain a session token by logging in with user-provided credentials. Must have first called register_device with these credentials. """ # NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential return cls._api_request_result( OAuthCredential, "login", uuid=uuid, form=form, forced=forced, ) # endregion user: u.User _rpc_disconnection_task: asyncio.Task | None http: ClientSession log: TraceLogger _handler_methods: list[str] def __init__(self, user: u.User, log: Optional[TraceLogger] = None): """Create a per-user client object for user-specific client functionality.""" self.user = user self._rpc_disconnection_task = None self._handler_methods = [] # TODO Let the Node backend use a proxy too! connector = None try: http_proxy = urllib.request.getproxies()["http"] except KeyError: pass else: if ProxyConnector: connector = ProxyConnector.from_url(http_proxy) else: self.log.warning("http_proxy is set, but aiohttp-socks is not installed") self.http = ClientSession(connector=connector) self.log = log or logging.getLogger("mw.ktclient") @property def _oauth_credential(self) -> JSON: return self.user.oauth_credential.serialize() # region HTTP def get( self, url: Union[str, URL], headers: Optional[dict[str, str]] = None, sandbox: bool = False, **kwargs, ) -> _RequestContextManager: # TODO Is auth ever needed? headers = { # TODO Are any default headers needed? #**self._headers, **(headers or {}), } url = URL(url) if sandbox: return sandboxed_get(url) return self.http.get(url, headers=headers, **kwargs) # endregion # region post-token commands async def start(self) -> SettingsStruct | None: """ Initialize user-specific bridging & state by providing a token obtained from a prior login. Receive the user's profile info in response. """ try: settings_struct = await self._api_user_cred_request_result(SettingsStruct, "start") except SerializerError: self.log.exception("Unable to deserialize settings struct, but starting client anyways") settings_struct = None if not self._rpc_disconnection_task: self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler()) else: self.log.warning("Called \"start\" on an already-started client") return settings_struct async def stop(self) -> None: """Immediately stop bridging this user.""" self._stop_listen() if self._rpc_disconnection_task: self._rpc_disconnection_task.cancel() else: self.log.warning("Called \"stop\" on an already-stopped client") await self._rpc_client.request("stop", mxid=self.user.mxid) async def _rpc_disconnection_handler(self) -> None: await self._rpc_client.wait_for_disconnection() self._rpc_disconnection_task = None self._stop_listen() asyncio.create_task(self.user.on_client_disconnect()) async def renew_and_save(self) -> None: """Renew and save the user's session tokens.""" oauth_info = await self._api_user_cred_request_result(OAuthInfo, "renew", renew=False) self.user.oauth_credential = oauth_info.credential await self.user.save() async def connect(self) -> LoginResult | None: """ Start a new talk session by providing a token obtained from a prior login. Receive a snapshot of account state in response. """ try: login_result = await self._api_user_cred_request_result(LoginResult, "connect") assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}" except SerializerError: self.log.exception("Unable to deserialize login result, but connecting anyways") login_result = None # TODO Skip if handlers are already listening. But this is idempotent and thus probably safe self._start_listen() return login_result async def disconnect(self) -> None: """Disconnect from the talk session, but remain logged in.""" await self._rpc_client.request("disconnect", mxid=self.user.mxid) await self._on_disconnect(None) def is_connected(self) -> Awaitable[bool]: return self._rpc_client.request("is_connected", mxid=self.user.mxid) def get_settings(self) -> Awaitable[SettingsStruct]: return self._api_user_request_result(SettingsStruct, "get_settings") async def get_own_profile(self) -> ProfileStruct: profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile") return profile_req_struct.profile async def get_profile(self, user_id: Long) -> ProfileStruct: profile_req_struct = await self._api_user_request_result( ProfileReqStruct, "get_profile", user_id=user_id.serialize(), ) return profile_req_struct.profile def get_portal_channel_info(self, channel_props: ChannelProps) -> Awaitable[PortalChannelInfo]: return self._api_user_request_result( PortalChannelInfo, "get_portal_channel_info", channel_props=channel_props.serialize(), ) def get_portal_channel_participant_info(self, channel_props: ChannelProps) -> Awaitable[PortalChannelParticipantInfo]: return self._api_user_request_result( PortalChannelParticipantInfo, "get_portal_channel_participant_info", channel_props=channel_props.serialize(), ) def get_participants(self, channel_props: ChannelProps) -> Awaitable[list[UserInfoUnion]]: return self._api_user_request_result( ResultListType(UserInfoUnion), "get_participants", channel_props=channel_props.serialize(), ) def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> Awaitable[list[Chatlog]]: return self._api_user_request_result( ResultListType(Chatlog), "get_chats", channel_props=channel_props.serialize(), sync_from=sync_from.serialize() if sync_from else None, limit=limit, ) def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: list[Long]) -> Awaitable[list[Receipt]]: return self._api_user_request_result( ResultListType(Receipt), "get_read_receipts", channel_props=channel_props.serialize(), unread_chat_ids=[c.serialize() for c in unread_chat_ids], ) async def can_change_uuid(self, uuid: str) -> bool: try: await self._api_user_request_void("can_change_uuid", uuid=uuid) except CommandException: return False else: return True def change_uuid(self, uuid: str) -> Awaitable[None]: return self._api_user_request_void("change_uuid", uuid=uuid) def set_uuid_searchable(self, searchable: bool) -> Awaitable[None]: return self._api_user_request_void("set_uuid_searchable", searchable=searchable) def list_friends(self) -> Awaitable[FriendListStruct]: return self._api_user_request_result( FriendListStruct, "list_friends", ) async def edit_friend(self, ktid: Long, add: bool) -> FriendStruct | None: try: friend_req_struct = await self._api_user_request_result( FriendReqStruct, "edit_friend", user_id=ktid.serialize(), add=add, ) return friend_req_struct.friend except SerializerError: self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless") return None async def edit_friend_by_uuid(self, uuid: str, add: bool) -> FriendStruct | None: try: friend_req_struct = await self._api_user_request_result( FriendReqStruct, "edit_friend_by_uuid", uuid=uuid, add=add, ) return friend_req_struct.friend except SerializerError: self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless") return None async def get_friend_dm_id(self, friend_id: Long) -> Long | None: try: return await self._api_user_request_result( Long, "get_friend_dm_id", friend_id=friend_id.serialize(), ) except CommandException: self.log.exception(f"Could not find friend with ID {friend_id}") return None async def get_memo_ids(self) -> list[Long]: return ResultListType(Long).deserialize( await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid) ) def download_file( self, channel_props: ChannelProps, key: str, ) -> Awaitable[bytes]: return self._api_user_request_result( ResultRawType(bytes), "download_file", channel_props=channel_props.serialize(), key=key, ) def send_chat( self, channel_props: ChannelProps, text: str, reply_to: ReplyAttachment | None, mentions: list[MentionStruct] | None, ) -> Awaitable[Long]: return self._api_user_request_result( Long, "send_chat", channel_props=channel_props.serialize(), text=text, reply_to=reply_to.serialize() if reply_to is not None else None, mentions=[m.serialize() for m in mentions] if mentions is not None else None, ) def send_media( self, channel_props: ChannelProps, media_type: KnownChatType, data: bytes, filename: str, *, width: int | None = None, height: int | None = None, ext: str | None = None, ) -> Awaitable[Long]: return self._api_user_request_result( Long, "send_media", channel_props=channel_props.serialize(), type=media_type, data=list(data), name=filename, width=width, height=height, ext=ext, ) def delete_chat( self, channel_props: ChannelProps, chat_id: Long, ) -> Awaitable[None]: return self._api_user_request_void( "delete_chat", channel_props=channel_props.serialize(), chat_id=chat_id.serialize(), ) def mark_read( self, channel_props: ChannelProps, read_until_chat_id: Long, ) -> Awaitable[None]: return self._api_user_request_void( "mark_read", channel_props=channel_props.serialize(), read_until_chat_id=read_until_chat_id.serialize(), ) def send_perm( self, channel_props: ChannelProps, user_id: Long, perm: OpenChannelUserPerm, ) -> Awaitable[None]: return self._api_user_request_void( "send_perm", channel_props=channel_props.serialize(), user_id=user_id.serialize(), perm=perm, ) def set_channel_name( self, channel_props: ChannelProps, name: str, ) -> Awaitable[None]: return self._api_user_request_void( "set_channel_name", channel_props=channel_props.serialize(), name=name, ) def set_channel_description( self, channel_props: ChannelProps, description: str, ) -> Awaitable[None]: return self._api_user_request_void( "set_channel_description", channel_props=channel_props.serialize(), description=description, ) """ TODO def set_channel_photo( self, channel_props: ChannelProps, photo_url: str, ) -> Awaitable[None]: return self._api_user_request_void( "set_channel_photo", channel_props=channel_props.serialize(), photo_url=photo_url, ) """ def create_direct_chat(self, ktid: Long) -> Awaitable[Long]: return self._api_user_request_result( Long, "create_direct_chat", user_id=ktid.serialize(), ) def leave_channel( self, channel_props: ChannelProps, ) -> Awaitable[None]: return self._api_user_request_void( "leave_channel", channel_props=channel_props.serialize(), ) # TODO Combine each of these pairs into one async def _api_user_request_result( self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON ) -> ResultType: while True: try: return await self._api_request_result(result_type, command, mxid=self.user.mxid, **data) except InvalidAccessToken: if not renew: raise await self.renew_and_save() renew = False async def _api_user_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None: while True: try: return await self._api_request_void(command, mxid=self.user.mxid, **data) except InvalidAccessToken: if not renew: raise await self.renew_and_save() renew = False async def _api_user_cred_request_result( self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON ) -> ResultType: while True: try: return await self._api_user_request_result( result_type, command, oauth_credential=self._oauth_credential, renew=False, **data ) except InvalidAccessToken: if not renew: raise await self.renew_and_save() renew = False async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None: while True: try: await self._api_user_request_void( command, oauth_credential=self._oauth_credential, renew=False, **data ) except InvalidAccessToken: if not renew: raise await self.renew_and_save() renew = False # endregion # region listeners def _on_chat(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_chat( Chatlog.deserialize(data["chatlog"]), Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_chat_deleted(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_chat_deleted( Long.deserialize(data["chatId"]), Long.deserialize(data["senderId"]), int(data["timestamp"]), Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_chat_read(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_chat_read( Long.deserialize(data["chatId"]), Long.deserialize(data["senderId"]), Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_profile_changed(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_profile_changed( OpenLinkChannelUserInfo.deserialize(data["info"]), ) def _on_perm_changed(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_perm_changed( Long.deserialize(data["userId"]), OpenChannelUserPerm(data["perm"]), Long.deserialize(data["senderId"]), Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_channel_added(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_channel_added( ChannelInfo.deserialize(data["channelInfo"]), ) def _on_channel_join(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_channel_join( ChannelInfo.deserialize(data["channelInfo"]), ) def _on_channel_left(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_channel_left( Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_channel_kicked(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_channel_kicked( Long.deserialize(data["userId"]), Long.deserialize(data["senderId"]), Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_user_join(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_user_join( Long.deserialize(data["userId"]), Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_user_left(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_user_left( Long.deserialize(data["userId"]), Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_channel_meta_change(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_channel_meta_change( PortalChannelInfo.deserialize(data["info"]), Long.deserialize(data["channelId"]), str(data["channelType"]), ) def _on_listen_disconnect(self, data: dict[str, JSON]) -> Awaitable[None]: try: res = KickoutRes.deserialize(data) except Exception: self.log.exception("Invalid kickout reason, defaulting to None") res = None return self._on_disconnect(res) def _on_switch_server(self, _: dict[str, JSON]) -> Awaitable[None]: # TODO Reconnect automatically instead return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER)) def _on_disconnect(self, res: KickoutRes | None) -> Awaitable[None]: self._stop_listen() return self.user.on_disconnect(res) def _on_error(self, data: dict[str, JSON]) -> Awaitable[None]: return self.user.on_error(data) def _start_listen(self) -> None: self._add_event_handler("chat", self._on_chat) self._add_event_handler("chat_deleted", self._on_chat_deleted) self._add_event_handler("chat_read", self._on_chat_read) self._add_event_handler("profile_changed", self._on_profile_changed) self._add_event_handler("perm_changed", self._on_perm_changed) self._add_event_handler("channel_added", self._on_channel_added) self._add_event_handler("channel_join", self._on_channel_join) self._add_event_handler("channel_left", self._on_channel_left) self._add_event_handler("channel_kicked", self._on_channel_kicked) self._add_event_handler("user_join", self._on_user_join) self._add_event_handler("user_left", self._on_user_left) self._add_event_handler("channel_meta_change", self._on_channel_meta_change) self._add_event_handler("disconnected", self._on_listen_disconnect) self._add_event_handler("switch_server", self._on_switch_server) self._add_event_handler("error", self._on_error) def _stop_listen(self) -> None: for method in self._handler_methods: self._rpc_client.set_event_handlers(self._get_user_cmd(method), []) def _add_event_handler(self, method: str, handler: EventHandler): self._rpc_client.set_event_handlers(self._get_user_cmd(method), [handler]) self._handler_methods.append(method) def _get_user_cmd(self, command) -> str: return f"{command}:{self.user.mxid}" # endregion @classmethod async def _api_request_result( cls, result_type: Type[ResultType], command: str, **data: JSON ) -> ResultType: """ Call a command via RPC, and return its result object. On failure, raise an appropriate exception. """ resp = deserialize_result( result_type, await cls._rpc_client.request(command, **data) ) if not resp.success: raise_unsuccessful_response(resp) # NOTE Not asserting against CommandResultDoneValue because it's generic! # TODO Check if there really is no way to do it. assert type(resp) is not RootCommandResult, "Result object missing from successful response" return cast(CommandResultDoneValue[ResultType], resp).result @classmethod async def _api_request_void(cls, command: str, **data: JSON) -> None: resp = RootCommandResult.deserialize( await cls._rpc_client.request(command, **data) ) if not resp.success: raise_unsuccessful_response(resp)