# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
Client functionality for the KakaoTalk API.
Currently a wrapper around a Node backend, but
the abstraction used here should be compatible with any other potential backend.
"""
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    AsyncContextManager,
    AsyncIterator,
    cast,
    Type,
    Optional,
    Union,
)
from contextlib import asynccontextmanager
import logging
import urllib.request

from aiohttp import ClientResponse, ClientSession
from aiohttp.client import _RequestContextManager
from yarl import URL

from mautrix.util.logging import TraceLogger

from ...config import Config
from ...rpc import RPCClient

from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
from ..types.api.struct import FriendListStruct
from ..types.bson import Long
from ..types.client.client_session import LoginResult
from ..types.chat import Chatlog, KnownChatType
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import (
    deserialize_result,
    ResultType,
    ResultListType,
    RootCommandResult,
    CommandResultDoneValue,
)

from .types import PortalChannelInfo, UserInfoUnion, ChannelProps

from .errors import InvalidAccessToken
from .error_helper import raise_unsuccessful_response

try:
    from aiohttp_socks import ProxyConnector
except ImportError:
    # Optional dependency: without it, an http proxy cannot be used (see Client.__init__).
    ProxyConnector = None

if TYPE_CHECKING:
    from mautrix.types import JSON
    from ...user import User


@asynccontextmanager
async def sandboxed_get(url: URL, **kwargs) -> AsyncIterator[ClientResponse]:
    """
    Perform a GET request inside a brand-new, short-lived ClientSession.

    The session is created for this one request and closed when the surrounding
    ``async with`` block exits, so it shares nothing (cookies, connector, proxy)
    with any per-user ``Client.http`` session.

    :param url: The URL to fetch.
    :param kwargs: Forwarded to ``ClientSession.get`` (e.g. ``headers``).
    """
    async with ClientSession() as sess, sess.get(url, **kwargs) as resp:
        yield resp


# TODO Consider defining an interface for this, with node/native backend as swappable implementations
# TODO If no state is stored, consider using free functions instead of classmethods
class Client:
    """
    KakaoTalk client, implemented as RPC calls to a Node backend.

    Classmethods cover commands that need no per-user session token; instances
    wrap a single bridge ``User`` and send user-scoped commands on its behalf.
    """

    # Shared by all instances; set up by init_cls and torn down by stop_cls.
    _rpc_client: RPCClient

    @classmethod
    async def init_cls(cls, config: Config) -> None:
        """Initialize RPC to the Node backend."""
        cls._rpc_client = RPCClient(config)
        await cls._rpc_client.connect()

    @classmethod
    async def stop_cls(cls) -> None:
        """Stop and disconnect from the Node backend."""
        await cls._rpc_client.disconnect()

    # region tokenless commands

    @classmethod
    async def generate_uuid(cls, used_uuids: set[str]) -> str:
        """
        Randomly generate a UUID for a (fake) device.

        Asks the backend for a UUID up to 10 times, returning the first one that
        is not in ``used_uuids``.

        :raises Exception: If all 10 attempts produced an already-used UUID.
        """
        tries_remaining = 10
        while True:
            uuid = await cls._rpc_client.request("generate_uuid")
            if uuid not in used_uuids:
                return uuid
            tries_remaining -= 1
            if tries_remaining == 0:
                raise Exception(
                    "Unable to generate a UUID that hasn't been used before. "
                    "Either use a different RNG, or buy a lottery ticket"
                )

    @classmethod
    async def register_device(cls, passcode: str, **req: JSON) -> None:
        """Register a (fake) device that will be associated with the provided login credentials."""
        # is_secret=True is passed through to the RPC layer; presumably it suppresses
        # logging of the credentials -- confirm against RPCClient.
        await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)

    @classmethod
    async def login(cls, **req: JSON) -> OAuthCredential:
        """
        Obtain a session token by logging in with user-provided credentials.
        Must have first called register_device with these credentials.
        """
        # NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential
        return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)

    # endregion

    # Per-instance HTTP session (proxied if an http proxy is configured) and logger.
    http: ClientSession
    log: TraceLogger

    def __init__(self, user: User, log: Optional[TraceLogger] = None):
        """Create a per-user client object for user-specific client functionality."""
        self.user = user
        # Assign the logger before anything that may need to log (the proxy setup below).
        self.log = log or logging.getLogger("mw.ktclient")

        # TODO Let the Node backend use a proxy too!
        connector = None
        try:
            http_proxy = urllib.request.getproxies()["http"]
        except KeyError:
            pass
        else:
            if ProxyConnector:
                connector = ProxyConnector.from_url(http_proxy)
            else:
                self.log.warning("http_proxy is set, but aiohttp-socks is not installed")
        self.http = ClientSession(connector=connector)

    @property
    def _oauth_credential(self) -> JSON:
        """The user's stored OAuth credential, serialized for sending over RPC."""
        return self.user.oauth_credential.serialize()

    def _get_user_data(self) -> JSON:
        """Arguments that identify this user to the backend for user-scoped commands."""
        return dict(
            mxid=self.user.mxid,
            oauth_credential=self._oauth_credential,
        )

    # region HTTP

    def get(
        self,
        url: Union[str, URL],
        headers: Optional[dict[str, str]] = None,
        sandbox: bool = False,
        **kwargs,
    ) -> AsyncContextManager[ClientResponse]:
        """
        Start an HTTP GET request; use as ``async with client.get(...) as resp``.

        :param url: The URL to fetch.
        :param headers: Extra request headers.
        :param sandbox: If true, make the request from a fresh, short-lived session
                        (see ``sandboxed_get``) instead of this client's ``http`` session.
        :param kwargs: Forwarded to ``ClientSession.get``.
        """
        # TODO Is auth ever needed?
        headers = {
            # TODO Are any default headers needed?
            # **self._headers,
            **(headers or {}),
        }
        url = URL(url)
        if sandbox:
            # Forward headers/kwargs so a sandboxed request matches a regular one.
            return sandboxed_get(url, headers=headers, **kwargs)
        return self.http.get(url, headers=headers, **kwargs)

    # endregion

    # region post-token commands

    async def renew(self) -> OAuthInfo:
        """Get a new set of tokens from a refresh token."""
        return await self._api_request_result(
            OAuthInfo, "renew", oauth_credential=self._oauth_credential
        )

    async def renew_and_save(self) -> None:
        """Renew and save the user's session tokens."""
        oauth_info = await self.renew()
        self.user.oauth_credential = oauth_info.credential
        await self.user.save()

    async def connect(self) -> LoginResult:
        """
        Start a new session by providing a token obtained from a prior login.
        Receive a snapshot of account state in response.

        Also registers this client's event handlers with the RPC client.
        """
        login_result = await self._api_user_request_result(LoginResult, "connect")
        assert self.user.ktid == login_result.userId, \
            f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
        # TODO Skip if handlers are already listening. But this is idempotent and thus probably safe
        self._start_listen()
        return login_result

    async def disconnect(self) -> bool:
        """
        Ask the backend to disconnect this user's session, then stop listening for its events.

        :return: The backend's ``disconnect`` response; presumably whether a
                 connection existed -- confirm against the backend.
        """
        connection_existed = await self._rpc_client.request("disconnect", mxid=self.user.mxid)
        self._stop_listen()
        return connection_existed

    async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
        """Fetch the profile of the logged-in user. ``post_login`` is currently unused."""
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        """Fetch the profile of the user with the given ID."""
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize(),
        )
        return profile_req_struct.profile

    async def get_portal_channel_info(self, channel_props: ChannelProps) -> PortalChannelInfo:
        """Fetch the info needed to build a portal for the given channel."""
        return await self._api_user_request_result(
            PortalChannelInfo,
            "get_portal_channel_info",
            channel_props=channel_props.serialize(),
        )

    async def get_participants(self, channel_props: ChannelProps) -> list[UserInfoUnion]:
        """Fetch the members of the given channel."""
        return await self._api_user_request_result(
            ResultListType(UserInfoUnion),
            "get_participants",
            channel_props=channel_props.serialize(),
        )

    async def get_chats(
        self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None
    ) -> list[Chatlog]:
        """
        Fetch chat logs from the given channel.

        :param sync_from: Chat log ID to sync from, or None to let the backend choose.
        :param limit: Maximum number of chat logs to fetch, or None for the backend default.
        """
        return await self._api_user_request_result(
            ResultListType(Chatlog),
            "get_chats",
            channel_props=channel_props.serialize(),
            # Compare against None explicitly so a falsy-but-valid ID is still sent.
            sync_from=sync_from.serialize() if sync_from is not None else None,
            limit=limit,
        )

    async def list_friends(self) -> FriendListStruct:
        """Fetch the user's friend list."""
        return await self._api_user_request_result(
            FriendListStruct,
            "list_friends",
        )

    async def send_message(self, channel_props: ChannelProps, text: str) -> Chatlog:
        """Send a text message to the given channel and return the resulting chat log."""
        return await self._api_user_request_result(
            Chatlog,
            "send_message",
            channel_props=channel_props.serialize(),
            text=text,
        )

    async def send_media(
        self,
        channel_props: ChannelProps,
        media_type: KnownChatType,
        data: bytes,
        filename: str,
        *,
        width: int | None = None,
        height: int | None = None,
        ext: str | None = None,
    ) -> Chatlog:
        """Send a media file to the given channel and return the resulting chat log."""
        return await self._api_user_request_result(
            Chatlog,
            "send_media",
            channel_props=channel_props.serialize(),
            type=media_type,
            # Sent as a list of ints; presumably because the RPC transport cannot carry raw bytes.
            data=list(data),
            name=filename,
            width=width,
            height=height,
            ext=ext,
            # Don't log the bytes
            # TODO Disable logging per-argument, not per-command
            is_secret=True,
        )

    # TODO Combine these into one

    async def _api_user_request_result(
        self, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a user-scoped command and return its result object.

        If the backend reports InvalidAccessToken, renew and save the session
        tokens and retry once; a second InvalidAccessToken is re-raised.
        """
        renewed = False
        while True:
            try:
                return await self._api_request_result(
                    result_type, command, **self._get_user_data(), **data
                )
            except InvalidAccessToken:
                if renewed:
                    raise
                await self.renew_and_save()
                renewed = True

    async def _api_user_request_void(self, command: str, **data: JSON) -> None:
        """
        Call a user-scoped command that returns no result object.

        Token renewal and retry behave as in ``_api_user_request_result``.
        """
        renewed = False
        while True:
            try:
                return await self._api_request_void(command, **self._get_user_data(), **data)
            except InvalidAccessToken:
                if renewed:
                    raise
                await self.renew_and_save()
                renewed = True

    # endregion

    # region listeners

    async def _on_message(self, data: dict[str, JSON]) -> None:
        """Forward a "message" event from the backend to the bridge user."""
        await self.user.on_message(
            Chatlog.deserialize(data["chatlog"]),
            Long.deserialize(data["channelId"]),
            data["channelType"],
        )

    # TODO
    # async def _on_receipt(self, data: Dict[str, JSON]) -> None:
    #     await self.user.on_receipt(Receipt.deserialize(data["receipt"]))

    def _start_listen(self) -> None:
        """Register this client's event handlers for its user's events."""
        # TODO Automate this somehow, like with a fancy enum
        self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [self._on_message])
        # TODO many more listeners

    def _stop_listen(self) -> None:
        """Clear this client's event handlers for its user's events."""
        # TODO Automate this somehow, like with a fancy enum
        self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [])
        # TODO many more listeners

    def _get_user_cmd(self, command: str) -> str:
        """Return the per-user event name, formatted as ``"<command>:<mxid>"``."""
        return f"{command}:{self.user.mxid}"

    # endregion

    @classmethod
    async def _api_request_result(
        cls, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a command via RPC, and return its result object.
        On failure, raise an appropriate exception.
        """
        resp = deserialize_result(
            result_type,
            await cls._rpc_client.request(command, **data),
        )
        if not resp.success:
            raise_unsuccessful_response(resp)
        # NOTE Not asserting against CommandResultDoneValue because it's generic!
        # TODO Check if there really is no way to do it.
        assert type(resp) is not RootCommandResult, "Result object missing from successful response"
        return cast(CommandResultDoneValue[ResultType], resp).result

    @classmethod
    async def _api_request_void(cls, command: str, **data: JSON) -> None:
        """
        Call a command via RPC that returns no result object.
        On failure, raise an appropriate exception.
        """
        resp = RootCommandResult.deserialize(
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)