# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge. # Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . """ Client functionality for the KakaoTalk API. Currently a wrapper around a Node backend, but the abstraction used here should be compatible with any other potential backend. """ from __future__ import annotations from typing import TYPE_CHECKING, cast, Type, Optional, Union import asyncio from contextlib import asynccontextmanager import logging import urllib.request from aiohttp import ClientSession from aiohttp.client import _RequestContextManager from yarl import URL from mautrix.util.logging import TraceLogger from ...config import Config from ...rpc import EventHandler, RPCClient from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct from ..types.api.struct import FriendListStruct from ..types.bson import Long from ..types.client.client_session import LoginResult from ..types.chat import Chatlog, KnownChatType from ..types.chat.attachment import MentionStruct, ReplyAttachment from ..types.oauth import OAuthCredential, OAuthInfo from ..types.packet.chat.kickout import KnownKickoutType, KickoutRes from ..types.request import ( deserialize_result, ResultType, ResultListType, RootCommandResult, CommandResultDoneValue ) from .types import PortalChannelInfo, UserInfoUnion, ChannelProps from .errors import InvalidAccessToken, CommandException from .error_helper import raise_unsuccessful_response try: from aiohttp_socks import ProxyConnector except ImportError: ProxyConnector = None if TYPE_CHECKING: from mautrix.types import JSON from ... import user as u @asynccontextmanager async def sandboxed_get(url: URL) -> _RequestContextManager: async with ClientSession() as sess, sess.get(url) as resp: yield resp # TODO Consider defining an interface for this, with node/native backend as swappable implementations # TODO If no state is stored, consider using free functions instead of classmethods class Client: _rpc_client: RPCClient @classmethod def init_cls(cls, config: Config) -> None: """Initialize RPC to the Node backend.""" cls._rpc_client = RPCClient(config) # NOTE No need to store this, as cancelling the RPCClient will cancel this too asyncio.create_task(cls._keep_connected()) @classmethod async def _keep_connected(cls) -> None: while True: await cls._rpc_client.connect() await cls._rpc_client.wait_for_disconnection() @classmethod def stop_cls(cls) -> None: """Stop and disconnect from the Node backend.""" cls._rpc_client.cancel() # region tokenless commands @classmethod async def generate_uuid(cls, used_uuids: set[str]) -> str: """Randomly generate a UUID for a (fake) device.""" tries_remaining = 10 while True: uuid = await cls._rpc_client.request("generate_uuid") if uuid not in used_uuids: return uuid tries_remaining -= 1 if tries_remaining == 0: raise Exception( "Unable to generate a UUID that hasn't been used before. " "Either use a different RNG, or buy a lottery ticket" ) @classmethod async def register_device(cls, passcode: str, **req: JSON) -> None: """Register a (fake) device that will be associated with the provided login credentials.""" await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req) @classmethod async def login(cls, **req: JSON) -> OAuthCredential: """ Obtain a session token by logging in with user-provided credentials. Must have first called register_device with these credentials. """ # NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req) # endregion user: u.User _rpc_disconnection_task: asyncio.Task | None http: ClientSession log: TraceLogger _handler_methods: list[str] def __init__(self, user: u.User, log: Optional[TraceLogger] = None): """Create a per-user client object for user-specific client functionality.""" self.user = user self._rpc_disconnection_task = None self._handler_methods = [] # TODO Let the Node backend use a proxy too! connector = None try: http_proxy = urllib.request.getproxies()["http"] except KeyError: pass else: if ProxyConnector: connector = ProxyConnector.from_url(http_proxy) else: self.log.warning("http_proxy is set, but aiohttp-socks is not installed") self.http = ClientSession(connector=connector) self.log = log or logging.getLogger("mw.ktclient") @property def _oauth_credential(self) -> JSON: return self.user.oauth_credential.serialize() @property def _user_data(self) -> JSON: return { "mxid": self.user.mxid, "oauth_credential": self._oauth_credential, } # region HTTP def get( self, url: Union[str, URL], headers: Optional[dict[str, str]] = None, sandbox: bool = False, **kwargs, ) -> _RequestContextManager: # TODO Is auth ever needed? headers = { # TODO Are any default headers needed? #**self._headers, **(headers or {}), } url = URL(url) if sandbox: return sandboxed_get(url) return self.http.get(url, headers=headers, **kwargs) # endregion # region post-token commands async def start(self) -> ProfileStruct: """ Initialize user-specific bridging & state by providing a token obtained from a prior login. Receive the user's profile info in response. """ profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "start") if not self._rpc_disconnection_task: self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler()) else: self.log.warning("Called \"start\" on an already-started client") return profile_req_struct.profile async def stop(self) -> None: """Immediately stop bridging this user.""" self._stop_listen() if self._rpc_disconnection_task: self._rpc_disconnection_task.cancel() else: self.log.warning("Called \"stop\" on an already-stopped client") await self._rpc_client.request("stop", mxid=self.user.mxid) async def _rpc_disconnection_handler(self) -> None: await self._rpc_client.wait_for_disconnection() self._rpc_disconnection_task = None self._stop_listen() asyncio.create_task(self.user.on_client_disconnect()) async def renew_and_save(self) -> None: """Renew and save the user's session tokens.""" oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential) self.user.oauth_credential = oauth_info.credential await self.user.save() async def connect(self) -> LoginResult: """ Start a new talk session by providing a token obtained from a prior login. Receive a snapshot of account state in response. """ login_result = await self._api_user_request_result(LoginResult, "connect") assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}" # TODO Skip if handlers are already listening. But this is idempotent and thus probably safe self._start_listen() return login_result async def disconnect(self) -> None: """Disconnect from the talk session, but remain logged in.""" await self._rpc_client.request("disconnect", mxid=self.user.mxid) await self._on_disconnect(None) async def get_own_profile(self) -> ProfileStruct: profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile") return profile_req_struct.profile async def get_profile(self, user_id: Long) -> ProfileStruct: profile_req_struct = await self._api_user_request_result( ProfileReqStruct, "get_profile", user_id=user_id.serialize(), ) return profile_req_struct.profile async def get_portal_channel_info(self, channel_props: ChannelProps) -> PortalChannelInfo: return await self._api_user_request_result( PortalChannelInfo, "get_portal_channel_info", channel_props=channel_props.serialize(), ) async def get_participants(self, channel_props: ChannelProps) -> list[UserInfoUnion]: return await self._api_user_request_result( ResultListType(UserInfoUnion), "get_participants", channel_props=channel_props.serialize(), ) async def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> list[Chatlog]: return await self._api_user_request_result( ResultListType(Chatlog), "get_chats", channel_props=channel_props.serialize(), sync_from=sync_from.serialize() if sync_from else None, limit=limit, ) async def list_friends(self) -> FriendListStruct: return await self._api_user_request_result( FriendListStruct, "list_friends", ) async def get_friend_dm_id(self, friend_id: Long) -> Long | None: try: return await self._api_user_request_result( Long, "get_friend_dm_id", friend_id=friend_id.serialize(), ) except CommandException: self.log.exception(f"Could not find friend with ID {friend_id}") return None async def get_memo_ids(self) -> list[Long]: return ResultListType(Long).deserialize( await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid) ) async def send_message( self, channel_props: ChannelProps, text: str, reply_to: ReplyAttachment | None, mentions: list[MentionStruct] | None, ) -> Chatlog: return await self._api_user_request_result( Chatlog, "send_chat", channel_props=channel_props.serialize(), text=text, reply_to=reply_to.serialize() if reply_to is not None else None, mentions=[m.serialize() for m in mentions] if mentions is not None else None, ) async def send_media( self, channel_props: ChannelProps, media_type: KnownChatType, data: bytes, filename: str, *, width: int | None = None, height: int | None = None, ext: str | None = None, ) -> Chatlog: return await self._api_user_request_result( Chatlog, "send_media", channel_props=channel_props.serialize(), type=media_type, data=list(data), name=filename, width=width, height=height, ext=ext, # Don't log the bytes # TODO Disable logging per-argument, not per-command is_secret=True ) # TODO Combine these into one async def _api_user_request_result( self, result_type: Type[ResultType], command: str, **data: JSON ) -> ResultType: renewed = False while True: try: return await self._api_request_result(result_type, command, **self._user_data, **data) except InvalidAccessToken: if renewed: raise await self.renew_and_save() renewed = True async def _api_user_request_void(self, command: str, **data: JSON) -> None: renewed = False while True: try: return await self._api_request_void(command, **self._user_data, **data) except InvalidAccessToken: if renewed: raise await self.renew_and_save() renewed = True # endregion # region listeners async def _on_message(self, data: dict[str, JSON]) -> None: await self.user.on_message( Chatlog.deserialize(data["chatlog"]), Long.deserialize(data["channelId"]), data["channelType"], ) """ TODO async def _on_receipt(self, data: Dict[str, JSON]) -> None: await self.user.on_receipt(Receipt.deserialize(data["receipt"])) """ async def _on_listen_disconnect(self, data: dict[str, JSON]) -> None: try: res = KickoutRes.deserialize(data) except Exception: self.log.exception("Invalid kickout reason, defaulting to None") res = None await self._on_disconnect(res) async def _on_switch_server(self) -> None: # TODO Reconnect automatically instead await self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER)) async def _on_disconnect(self, res: KickoutRes | None) -> None: self._stop_listen() await self.user.on_disconnect(res) def _start_listen(self) -> None: self._add_event_handler("chat", self._on_message) # TODO many more listeners self._add_event_handler("disconnected", self._on_listen_disconnect) self._add_event_handler("switch_server", self._on_switch_server) def _stop_listen(self) -> None: for method in self._handler_methods: self._rpc_client.set_event_handlers(self._get_user_cmd(method), []) def _add_event_handler(self, method: str, handler: EventHandler): self._rpc_client.set_event_handlers(self._get_user_cmd(method), [handler]) self._handler_methods.append(method) def _get_user_cmd(self, command) -> str: return f"{command}:{self.user.mxid}" # endregion @classmethod async def _api_request_result( cls, result_type: Type[ResultType], command: str, **data: JSON ) -> ResultType: """ Call a command via RPC, and return its result object. On failure, raise an appropriate exception. """ resp = deserialize_result( result_type, await cls._rpc_client.request(command, **data) ) if not resp.success: raise_unsuccessful_response(resp) # NOTE Not asserting against CommandResultDoneValue because it's generic! # TODO Check if there really is no way to do it. assert type(resp) is not RootCommandResult, "Result object missing from successful response" return cast(CommandResultDoneValue[ResultType], resp).result @classmethod async def _api_request_void(cls, command: str, **data: JSON) -> None: resp = RootCommandResult.deserialize( await cls._rpc_client.request(command, **data) ) if not resp.success: raise_unsuccessful_response(resp)