# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge. # Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. """ Client functionality for the KakaoTalk API. Currently a wrapper around a Node backend, but the abstraction used here should be compatible with any other potential backend. """ from __future__ import annotations from typing import TYPE_CHECKING, cast, Type, Optional, Union import logging import urllib.request from aiohttp import ClientSession from aiohttp.client import _RequestContextManager from yarl import URL from mautrix.util.logging import TraceLogger from ...config import Config from ...rpc import RPCClient from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct from ..types.bson import Long from ..types.client.client_session import LoginResult from ..types.chat.chat import Chatlog from ..types.oauth import OAuthCredential, OAuthInfo from ..types.request import ( deserialize_result, ResultType, ResultListType, RootCommandResult, CommandResultDoneValue ) from .types import PortalChannelInfo, UserInfoUnion from .errors import InvalidAccessToken from .error_helper import raise_unsuccessful_response try: from aiohttp_socks import ProxyConnector except ImportError: ProxyConnector = None if TYPE_CHECKING: from mautrix.types import JSON from ...user import User # TODO Consider defining an interface for this, with
# node/native backend as swappable implementations
# TODO If no state is stored, consider using free functions instead of classmethods
class Client:
    """
    Client for the KakaoTalk API, backed by RPC calls to the Node backend.

    The RPC connection (``_rpc_client``) is stored on the class and is set up
    once via ``init_cls``. Each instance wraps one bridge user and additionally
    owns an HTTP session and a logger.
    """

    _rpc_client: RPCClient

    @classmethod
    async def init_cls(cls, config: Config) -> None:
        """Initialize RPC to the Node backend."""
        cls._rpc_client = RPCClient(config)
        await cls._rpc_client.connect()

    @classmethod
    async def stop_cls(cls) -> None:
        """Stop and disconnect from the Node backend."""
        await cls._rpc_client.disconnect()

    # region tokenless commands

    @classmethod
    async def generate_uuid(cls, used_uuids: set[str]) -> str:
        """
        Randomly generate a UUID for a (fake) device.

        Asks the backend for a UUID up to 10 times and returns the first one
        that is not in ``used_uuids``.

        :raises Exception: if every attempt produced an already-used UUID.
        """
        for _ in range(10):
            uuid = await cls._rpc_client.request("generate_uuid")
            if uuid not in used_uuids:
                return uuid
        raise Exception(
            "Unable to generate a UUID that hasn't been used before. "
            "Either use a different RNG, or buy a lottery ticket"
        )

    @classmethod
    async def register_device(cls, passcode: str, **req: JSON) -> None:
        """Register a (fake) device that will be associated with the provided login credentials."""
        # NOTE(review): is_secret is forwarded to the RPC layer as-is; presumably it keeps
        #               the payload out of logs -- confirm against RPCClient.request.
        await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)

    @classmethod
    async def login(cls, **req: JSON) -> OAuthCredential:
        """
        Obtain a session token by logging in with user-provided credentials.
        Must have first called register_device with these credentials.
        """
        # NOTE Actually returns a LoginData object, but this only needs an OAuthCredential
        return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)

    # endregion

    http: ClientSession
    log: TraceLogger

    def __init__(self, user: User, log: Optional[TraceLogger] = None):
        """
        Create a per-user client object for user-specific client functionality.

        :param user: The bridge user this client acts on behalf of.
        :param log: Logger to use; defaults to the "mw.ktclient" logger.
        """
        self.user = user
        # The logger must exist before proxy detection, which may emit a warning.
        self.log = log or logging.getLogger("mw.ktclient")

        # TODO Let the Node backend use a proxy too!
        connector = None
        try:
            http_proxy = urllib.request.getproxies()["http"]
        except KeyError:
            # No HTTP proxy configured: connect directly.
            pass
        else:
            if ProxyConnector:
                connector = ProxyConnector.from_url(http_proxy)
            else:
                self.log.warning("http_proxy is set, but aiohttp-socks is not installed")

        self.http = ClientSession(connector=connector)

    @property
    def _oauth_credential(self) -> JSON:
        """The user's current OAuth credential, serialized for RPC."""
        return self.user.oauth_credential.serialize()

    def _get_user_data(self) -> JSON:
        """Build the user-identifying arguments sent along with every per-user command."""
        return dict(
            mxid=self.user.mxid,
            oauth_credential=self._oauth_credential
        )

    # region HTTP

    def get(
        self,
        url: Union[str, URL],
        headers: Optional[dict[str, str]] = None,
        **kwargs,
    ) -> _RequestContextManager:
        """
        Start an HTTP GET request on this client's session.

        :param url: The URL to fetch.
        :param headers: Optional request headers.
        :param kwargs: Passed through to ``ClientSession.get``.
        :return: The request context manager returned by ``ClientSession.get``.
        """
        # TODO Is auth ever needed?
        headers = {
            # TODO Are any default headers needed?
            #**self._headers,
            **(headers or {}),
        }
        url = URL(url)
        return self.http.get(url, headers=headers, **kwargs)

    # endregion

    # region post-token commands

    async def renew(self) -> OAuthInfo:
        """Get a new set of tokens from a refresh token."""
        return await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)

    async def renew_and_save(self) -> None:
        """Renew and save the user's session tokens."""
        oauth_info = await self.renew()
        self.user.oauth_credential = oauth_info.credential
        await self.user.save()

    async def connect(self) -> LoginResult:
        """
        Start a new session by providing a token obtained from a prior login.
        Receive a snapshot of account state in response.

        :raises AssertionError: if the backend reports a different user ID than the stored one.
        """
        login_result = await self._api_user_request_result(LoginResult, "connect")
        assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
        # TODO Skip if handlers are already listening. But this is idempotent and thus probably safe
        self._start_listen()
        return login_result

    async def disconnect(self) -> bool:
        """
        Ask the backend to drop this user's session and stop listening for its events.

        :return: The backend's response to the "disconnect" command
                 (named ``connection_existed`` here; presumably whether a session was open).
        """
        connection_existed = await self._rpc_client.request("disconnect", mxid=self.user.mxid)
        self._stop_listen()
        return connection_existed

    async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
        """
        Fetch the profile of the logged-in user.

        :param post_login: Currently unused; kept for interface compatibility.
        """
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        """Fetch the profile of the KakaoTalk user with the given ID."""
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize()
        )
        return profile_req_struct.profile

    async def get_portal_channel_info(self, channel_id: Long) -> PortalChannelInfo:
        """Fetch the channel information needed to set up a portal for the given channel."""
        return await self._api_user_request_result(
            PortalChannelInfo,
            "get_portal_channel_info",
            channel_id=channel_id.serialize()
        )

    async def get_participants(self, channel_id: Long) -> list[UserInfoUnion]:
        """Fetch the participants of the given channel."""
        return await self._api_user_request_result(
            ResultListType(UserInfoUnion),
            "get_participants",
            channel_id=channel_id.serialize()
        )

    async def get_chats(self, channel_id: Long, sync_from: Long | None, limit: int | None) -> list[Chatlog]:
        """
        Fetch chat logs of the given channel.

        :param channel_id: The channel to fetch from.
        :param sync_from: Sent to the backend as the starting point; a falsy value is sent as None.
        :param limit: Sent to the backend unchanged.
        """
        return await self._api_user_request_result(
            ResultListType(Chatlog),
            "get_chats",
            channel_id=channel_id.serialize(),
            sync_from=sync_from.serialize() if sync_from else None,
            limit=limit
        )

    async def send_message(self, channel_id: Long, text: str) -> Chatlog:
        """Send a text message to the given channel and return the resulting chat log entry."""
        return await self._api_user_request_result(
            Chatlog,
            "send_message",
            channel_id=channel_id.serialize(),
            text=text
        )

    # TODO Combine these into one

    async def _api_user_request_result(
        self, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a per-user command and return its result object.

        If the backend rejects the access token, the tokens are renewed and saved,
        and the command is retried exactly once; a second rejection propagates.
        """
        try:
            return await self._api_request_result(result_type, command, **self._get_user_data(), **data)
        except InvalidAccessToken:
            await self.renew_and_save()
        # User data is rebuilt here so that the renewed credential is sent.
        return await self._api_request_result(result_type, command, **self._get_user_data(), **data)

    async def _api_user_request_void(self, command: str, **data: JSON) -> None:
        """
        Call a per-user command that has no result object.

        If the backend rejects the access token, the tokens are renewed and saved,
        and the command is retried exactly once; a second rejection propagates.
        """
        try:
            return await self._api_request_void(command, **self._get_user_data(), **data)
        except InvalidAccessToken:
            await self.renew_and_save()
        # User data is rebuilt here so that the renewed credential is sent.
        return await self._api_request_void(command, **self._get_user_data(), **data)

    # endregion

    # region listeners

    async def _on_message(self, data: dict[str, JSON]) -> None:
        """Deserialize a "message" event from the backend and hand it to the user."""
        await self.user.on_message(
            Chatlog.deserialize(data["chatlog"]),
            Long.deserialize(data["channelId"]),
            data["channelType"]
        )

    # TODO
    # async def _on_receipt(self, data: Dict[str, JSON]) -> None:
    #     await self.user.on_receipt(Receipt.deserialize(data["receipt"]))

    def _start_listen(self) -> None:
        """Register this user's event handlers with the RPC client."""
        # TODO Automate this somehow, like with a fancy enum
        self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [self._on_message])
        # TODO many more listeners

    def _stop_listen(self) -> None:
        """Clear this user's event handlers from the RPC client."""
        # TODO Automate this somehow, like with a fancy enum
        self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [])
        # TODO many more listeners

    def _get_user_cmd(self, command) -> str:
        """Return the per-user event name for a command: ``<command>:<mxid>``."""
        return f"{command}:{self.user.mxid}"

    # endregion

    @classmethod
    async def _api_request_result(
        cls, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a command via RPC, and return its result object.
        On failure, raise an appropriate exception.
        """
        resp = deserialize_result(
            result_type,
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)
        # NOTE Not asserting against CommandResultDoneValue because it's generic!
        # TODO Check if there really is no way to do it.
        assert type(resp) is not RootCommandResult, "Result object missing from successful response"
        return cast(CommandResultDoneValue[ResultType], resp).result

    @classmethod
    async def _api_request_void(cls, command: str, **data: JSON) -> None:
        """
        Call a command via RPC that has no result object.
        On failure, raise an appropriate exception.
        """
        resp = RootCommandResult.deserialize(
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)