matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/kt/client/client.py

317 lines
11 KiB
Python

# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Client functionality for the KakaoTalk API.
Currently a wrapper around a Node backend, but
the abstraction used here should be compatible
with any other potential backend.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, cast, Awaitable, Callable, Type, Optional, Union
import logging
import urllib.request
from aiohttp import ClientSession
from aiohttp.client import _RequestContextManager
from yarl import URL
from mautrix.util.logging import TraceLogger
from ...config import Config
from ...rpc import RPCClient
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
from ..types.bson import Long
from ..types.client.client_session import LoginResult
from ..types.channel.channel_type import ChannelType
from ..types.chat.chat import Chatlog
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import (
deserialize_result,
ResultType,
ResultListType,
RootCommandResult,
CommandResultDoneValue
)
from .types import PortalChannelInfo
from .errors import InvalidAccessToken
from .error_helper import raise_unsuccessful_response
try:
from aiohttp_socks import ProxyConnector
except ImportError:
ProxyConnector = None
if TYPE_CHECKING:
from mautrix.types import JSON
from ...user import User
from ...rpc.rpc import EventHandler
# TODO Consider defining an interface for this, with node/native backend as swappable implementations
# TODO If no state is stored, consider using free functions instead of classmethods
class Client:
    """
    Client for the KakaoTalk API, implemented as RPC calls to the Node backend.

    Classmethods issue commands that need no user session (UUID generation,
    device registration, login). Instances are bound to one bridge user and
    attach that user's mxid and serialized OAuth credential to each command.
    """

    _rpc_client: RPCClient
    http: ClientSession
    log: TraceLogger

    @classmethod
    async def init_cls(cls, config: Config) -> None:
        """Initialize RPC to the Node backend."""
        cls._rpc_client = RPCClient(config)
        await cls._rpc_client.connect()

    @classmethod
    async def stop_cls(cls) -> None:
        """Stop and disconnect from the Node backend."""
        await cls._rpc_client.request("stop")
        await cls._rpc_client.disconnect()

    # region tokenless commands

    @classmethod
    async def generate_uuid(cls, used_uuids: set[str]) -> str:
        """
        Randomly generate a UUID for a (fake) device.

        :param used_uuids: UUIDs that must not be returned.
        :return: A backend-generated UUID not contained in used_uuids.
        :raises Exception: If 10 attempts in a row all produced an already-used UUID.
        """
        for _ in range(10):
            uuid = await cls._rpc_client.request("generate_uuid")
            if uuid not in used_uuids:
                return uuid
        raise Exception(
            "Unable to generate a UUID that hasn't been used before. "
            "Either use a different RNG, or buy a lottery ticket"
        )

    @classmethod
    async def register_device(cls, passcode: str, **req) -> None:
        """
        Register a (fake) device that will be associated with the provided login credentials.

        The request is flagged with is_secret=True for the RPC layer.
        """
        await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)

    @classmethod
    async def login(cls, **req) -> OAuthCredential:
        """
        Obtain a session token by logging in with user-provided credentials.
        Must have first called register_device with these credentials.
        """
        return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)

    # endregion

    def __init__(self, user: User, log: Optional[TraceLogger] = None):
        """
        Create a per-user client object for user-specific client functionality.

        :param user: The bridge user whose KakaoTalk session this client acts on.
        :param log: Logger to use; defaults to the "mw.ktclient" logger.
        """
        self.user = user
        # Must be assigned before the proxy setup below, which may log a warning.
        # (Previously it was assigned afterwards, so that warning raised AttributeError.)
        self.log = log or logging.getLogger("mw.ktclient")

        # TODO Let the Node backend use a proxy too!
        connector = None
        http_proxy = urllib.request.getproxies().get("http")
        if http_proxy is not None:
            if ProxyConnector:
                connector = ProxyConnector.from_url(http_proxy)
            else:
                self.log.warning("http_proxy is set, but aiohttp-socks is not installed")
        self.http = ClientSession(connector=connector)

    @property
    def _oauth_credential(self) -> JSON:
        """The user's OAuth credential, serialized for sending over RPC."""
        return self.user.oauth_credential.serialize()

    def _get_user_data(self) -> JSON:
        """Build the identity fields attached to every user-scoped command."""
        return dict(
            mxid=self.user.mxid,
            oauth_credential=self._oauth_credential,
        )

    # region HTTP

    def get(
        self,
        url: Union[str, URL],
        headers: Optional[dict[str, str]] = None,
        **kwargs,
    ) -> _RequestContextManager:
        """
        Send an HTTP GET request through this client's aiohttp session.

        :param url: Target URL; plain strings are converted to yarl.URL.
        :param headers: Extra request headers, if any.
        :return: aiohttp's request context manager for the response.
        """
        # TODO Is auth ever needed?
        headers = {
            # TODO Are any default headers needed?
            # **self._headers,
            **(headers or {}),
        }
        url = URL(url)
        return self.http.get(url, headers=headers, **kwargs)

    # endregion

    # region post-token commands

    async def renew(self) -> OAuthInfo:
        """Get a new set of tokens from a refresh token."""
        return await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)

    async def renew_and_save(self) -> None:
        """Renew and save the user's session tokens."""
        oauth_info = await self.renew()
        self.user.oauth_credential = oauth_info.credential
        await self.user.save()

    async def start(self) -> LoginResult:
        """
        Start a new session by providing a token obtained from a prior login.
        Receive a snapshot of account state in response.
        """
        login_result = await self._api_user_request_result(LoginResult, "start")
        assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
        return login_result

    # Disabled:
    # async def is_connected(self) -> bool:
    #     resp = await self._rpc_client.request("is_connected")
    #     return resp["is_connected"]

    async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
        """
        Fetch the profile of the user this client is logged in as.

        post_login is currently unused; it is kept for caller compatibility.
        """
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        """Fetch the profile of the KakaoTalk user with the given ID."""
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize(),
        )
        return profile_req_struct.profile

    async def get_portal_channel_info(self, channel_id: Long) -> PortalChannelInfo:
        """Fetch the information needed to bridge the given channel into a portal."""
        return await self._api_user_request_result(
            PortalChannelInfo,
            "get_portal_channel_info",
            channel_id=channel_id.serialize(),
        )

    async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]:
        """
        Fetch chat logs of a channel, starting from sync_from if it is given.

        :param limit: If truthy, only the last `limit` fetched entries are returned;
            otherwise all fetched entries are returned. The limit is applied locally,
            after the backend responds.
        """
        chats = await self._api_user_request_result(
            ResultListType(Chatlog),
            "get_chats",
            channel_id=channel_id.serialize(),
            sync_from=sync_from.serialize() if sync_from else None,
        )
        return chats[-limit if limit else 0:]

    async def send_message(self, channel_id: Long, text: str) -> Chatlog:
        """Send a text message to a channel and return the resulting chat log entry."""
        return await self._api_user_request_result(
            Chatlog,
            "send_message",
            channel_id=channel_id.serialize(),
            text=text,
        )

    async def start_listen(self) -> None:
        """Ask the backend to start listening for this user's events."""
        # TODO Connect all listeners here?
        await self._api_user_request_void("start_listen")

    async def stop(self) -> None:
        """Ask the backend to stop this user's session."""
        # TODO Stop all event handlers
        # NOTE(review): nothing in this class closes self.http; if the Client is
        # discarded after stop(), its aiohttp session stays open. Confirm ownership.
        await self._api_user_request_void("stop")

    async def _api_user_request_result(
        self, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """Call a user-scoped command and return its result object, renewing tokens if needed."""
        return await self._api_user_request(
            lambda user_data: self._api_request_result(result_type, command, **user_data, **data)
        )

    async def _api_user_request_void(self, command: str, **data: JSON) -> None:
        """Call a user-scoped command that has no result object, renewing tokens if needed."""
        await self._api_user_request(
            lambda user_data: self._api_request_void(command, **user_data, **data)
        )

    async def _api_user_request(
        self, send: Callable[[JSON], Awaitable[ResultType]]
    ) -> ResultType:
        """
        Run a request built from the user's identity data, with one token-renewal retry.

        If the first attempt raises InvalidAccessToken, the session tokens are renewed
        and saved, then the request is sent once more. If the second attempt fails
        too, its exception propagates to the caller.
        """
        try:
            return await send(self._get_user_data())
        except InvalidAccessToken:
            await self.renew_and_save()
        # Rebuild the user data so that the renewed credential is the one sent.
        return await send(self._get_user_data())

    # endregion

    # region listeners

    async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None:
        """
        Register func to receive this user's "message" events.

        Each event's "chatlog" and "channelId" fields are deserialized before being
        passed on; "channelType" is passed through unchanged.
        """
        async def wrapper(data: dict[str, JSON]) -> None:
            await func(
                Chatlog.deserialize(data["chatlog"]),
                Long.deserialize(data["channelId"]),
                data["channelType"],
            )

        self._add_user_handler("message", wrapper)

    def _add_user_handler(self, command: str, handler: EventHandler) -> None:
        """Subscribe handler to the RPC event named "<command>:<user mxid>"."""
        self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler)

    # endregion

    @classmethod
    async def _api_request_result(
        cls, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a command via RPC, and return its result object.
        On failure, raise an appropriate exception.
        """
        resp = deserialize_result(
            result_type,
            await cls._rpc_client.request(command, **data),
        )
        if not resp.success:
            raise_unsuccessful_response(resp)
        # NOTE Not asserting against CommandResultDoneValue because it's generic!
        # TODO Check if there really is no way to do it.
        assert type(resp) is not RootCommandResult, "Result object missing from successful response"
        return cast(CommandResultDoneValue[ResultType], resp).result

    @classmethod
    async def _api_request_void(cls, command: str, **data: JSON) -> None:
        """
        Call a command via RPC that returns no result object.
        On failure, raise an appropriate exception.
        """
        resp = RootCommandResult.deserialize(
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)