matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/kt/client/client.py
2022-04-10 04:38:25 -04:00

489 lines
17 KiB
Python

# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Client functionality for the KakaoTalk API.
Currently a wrapper around a Node backend, but
the abstraction used here should be compatible
with any other potential backend.
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Optional, Type, Union, cast
import asyncio
from contextlib import asynccontextmanager
import logging
import urllib.request

from aiohttp import ClientSession
from aiohttp.client import _RequestContextManager
from yarl import URL

from mautrix.util.logging import TraceLogger

from ...config import Config
from ...rpc import EventHandler, RPCClient
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
from ..types.api.struct import FriendListStruct
from ..types.bson import Long
from ..types.client.client_session import LoginResult
from ..types.chat import Chatlog, KnownChatType
from ..types.chat.attachment import MentionStruct, ReplyAttachment
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.openlink.open_link_user_info import OpenLinkChannelUserInfo
from ..types.packet.chat.kickout import KnownKickoutType, KickoutRes
from ..types.request import (
    deserialize_result,
    ResultType,
    ResultListType,
    RootCommandResult,
    CommandResultDoneValue,
)

from .types import PortalChannelInfo, UserInfoUnion, ChannelProps
from .errors import InvalidAccessToken, CommandException
from .error_helper import raise_unsuccessful_response

try:
    from aiohttp_socks import ProxyConnector
except ImportError:
    ProxyConnector = None

if TYPE_CHECKING:
    from mautrix.types import JSON

    from ... import user as u
@asynccontextmanager
async def sandboxed_get(url: URL) -> _RequestContextManager:
    """
    GET ``url`` through a fresh, short-lived ClientSession.

    The session is created for this one request and closed when the context exits.
    It is separate from any long-lived session a Client holds.
    The yielded value is the response object from ``session.get``.
    """
    async with ClientSession() as throwaway_session:
        async with throwaway_session.get(url) as response:
            yield response
# TODO Consider defining an interface for this, with node/native backend as swappable implementations
# TODO If no state is stored, consider using free functions instead of classmethods
class Client:
    """
    Bridge-side handle to the KakaoTalk backend, which is reached over RPC.

    Class-level (tokenless) methods use the shared RPC client without any user context.
    Instances are bound to one bridge user. They send that user's mxid and OAuth
    credential with every user-scoped command. If the backend reports the access
    token as invalid, the token is renewed and the command is retried once.
    """

    _rpc_client: RPCClient
    # Strong references to fire-and-forget tasks. The event loop keeps only weak
    # references to tasks, so an otherwise-unreferenced task can be garbage-collected
    # before it finishes. Entries remove themselves when their task completes.
    _background_tasks: set[asyncio.Task] = set()

    @classmethod
    def init_cls(cls, config: Config) -> None:
        """Initialize RPC to the Node backend."""
        cls._rpc_client = RPCClient(config)
        # NOTE No need to cancel this explicitly, as cancelling the RPCClient will cancel this too.
        # It is only tracked so that it isn't garbage-collected while running.
        cls._spawn(cls._keep_connected())

    @classmethod
    def _spawn(cls, coro: Coroutine[Any, Any, None]) -> asyncio.Task:
        """Schedule ``coro`` as a background task and hold a reference to it until it is done."""
        task = asyncio.create_task(coro)
        cls._background_tasks.add(task)
        task.add_done_callback(cls._background_tasks.discard)
        return task

    @classmethod
    async def _keep_connected(cls) -> None:
        """(Re)connect to the backend every time the RPC connection drops, forever."""
        while True:
            await cls._rpc_client.connect()
            await cls._rpc_client.wait_for_disconnection()

    @classmethod
    def stop_cls(cls) -> None:
        """Stop and disconnect from the Node backend."""
        cls._rpc_client.cancel()

    # region tokenless commands

    @classmethod
    async def generate_uuid(cls, used_uuids: set[str]) -> str:
        """
        Randomly generate a UUID for a (fake) device.

        Asks the backend for a UUID up to 10 times. Returns the first one that is
        not in ``used_uuids``.

        :raises Exception: if every attempt returned an already-used UUID.
        """
        for _ in range(10):
            uuid = await cls._rpc_client.request("generate_uuid")
            if uuid not in used_uuids:
                return uuid
        raise Exception(
            "Unable to generate a UUID that hasn't been used before. "
            "Either use a different RNG, or buy a lottery ticket"
        )

    @classmethod
    async def register_device(cls, passcode: str, **req: JSON) -> None:
        """Register a (fake) device that will be associated with the provided login credentials."""
        await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)

    @classmethod
    async def login(cls, **req: JSON) -> OAuthCredential:
        """
        Obtain a session token by logging in with user-provided credentials.
        Must have first called register_device with these credentials.
        """
        # NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential
        return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)

    # endregion

    user: u.User
    _rpc_disconnection_task: asyncio.Task | None
    http: ClientSession
    log: TraceLogger
    # Event names (without the ":<mxid>" suffix) that currently have a handler registered.
    _handler_methods: list[str]

    def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
        """Create a per-user client object for user-specific client functionality."""
        self.user = user
        # Set the logger first: the proxy setup below may need to emit a warning.
        self.log = log or logging.getLogger("mw.ktclient")
        self._rpc_disconnection_task = None
        self._handler_methods = []

        # TODO Let the Node backend use a proxy too!
        connector = None
        http_proxy = urllib.request.getproxies().get("http")
        if http_proxy is not None:
            if ProxyConnector:
                connector = ProxyConnector.from_url(http_proxy)
            else:
                self.log.warning("http_proxy is set, but aiohttp-socks is not installed")
        self.http = ClientSession(connector=connector)

    @property
    def _oauth_credential(self) -> JSON:
        """The user's current OAuth credential, serialized for the RPC layer."""
        return self.user.oauth_credential.serialize()

    @property
    def _user_data(self) -> JSON:
        """Identifying data sent with every user-scoped command (re-read on each access)."""
        return {
            "mxid": self.user.mxid,
            "oauth_credential": self._oauth_credential,
        }

    # region HTTP

    def get(
        self,
        url: Union[str, URL],
        headers: Optional[dict[str, str]] = None,
        sandbox: bool = False,
        **kwargs,
    ) -> _RequestContextManager:
        """
        Start a GET request for ``url``.

        With ``sandbox=True``, the request goes through a throwaway session
        (``sandboxed_get``) instead of ``self.http``.
        NOTE: in sandbox mode, ``headers`` and ``kwargs`` are currently not forwarded.
        """
        # TODO Is auth ever needed?
        headers = {
            # TODO Are any default headers needed?
            #**self._headers,
            **(headers or {}),
        }
        url = URL(url)
        if sandbox:
            return sandboxed_get(url)
        return self.http.get(url, headers=headers, **kwargs)

    # endregion

    # region post-token commands

    async def start(self) -> ProfileStruct:
        """
        Initialize user-specific bridging & state by providing a token obtained from a prior login.
        Receive the user's profile info in response.
        """
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "start")
        if not self._rpc_disconnection_task:
            self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler())
        else:
            self.log.warning("Called \"start\" on an already-started client")
        return profile_req_struct.profile

    async def stop(self) -> None:
        """Immediately stop bridging this user."""
        self._stop_listen()
        if self._rpc_disconnection_task:
            self._rpc_disconnection_task.cancel()
            # Forget the cancelled task. Otherwise a later start() would wrongly treat
            # the client as already started and never re-arm the disconnection handler.
            self._rpc_disconnection_task = None
        else:
            self.log.warning("Called \"stop\" on an already-stopped client")
        await self._rpc_client.request("stop", mxid=self.user.mxid)

    async def _rpc_disconnection_handler(self) -> None:
        """Wait for the RPC connection to drop, then tear down listeners and notify the user."""
        await self._rpc_client.wait_for_disconnection()
        self._rpc_disconnection_task = None
        self._stop_listen()
        self._spawn(self.user.on_client_disconnect())

    async def renew_and_save(self) -> None:
        """Renew and save the user's session tokens."""
        oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)
        self.user.oauth_credential = oauth_info.credential
        await self.user.save()

    async def connect(self) -> LoginResult:
        """
        Start a new talk session by providing a token obtained from a prior login.
        Receive a snapshot of account state in response.
        """
        login_result = await self._api_user_request_result(LoginResult, "connect")
        assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
        # Re-registering handlers is idempotent: each event name maps to exactly one handler.
        self._start_listen()
        return login_result

    async def disconnect(self) -> None:
        """Disconnect from the talk session, but remain logged in."""
        await self._rpc_client.request("disconnect", mxid=self.user.mxid)
        await self._on_disconnect(None)

    async def get_own_profile(self) -> ProfileStruct:
        """Fetch the profile of the user this client is bound to."""
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        """Fetch the profile of the KakaoTalk user with ID ``user_id``."""
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize(),
        )
        return profile_req_struct.profile

    async def get_portal_channel_info(self, channel_props: ChannelProps) -> PortalChannelInfo:
        """Fetch the info needed to bridge the channel identified by ``channel_props``."""
        return await self._api_user_request_result(
            PortalChannelInfo,
            "get_portal_channel_info",
            channel_props=channel_props.serialize(),
        )

    async def get_participants(self, channel_props: ChannelProps) -> list[UserInfoUnion]:
        """Fetch the members of the channel identified by ``channel_props``."""
        return await self._api_user_request_result(
            ResultListType(UserInfoUnion),
            "get_participants",
            channel_props=channel_props.serialize(),
        )

    async def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> list[Chatlog]:
        """
        Fetch chat logs of a channel.

        :param sync_from: chat ID to sync from, or None/falsy to send no starting point.
        :param limit: maximum number of chats, passed through to the backend as-is.
        """
        return await self._api_user_request_result(
            ResultListType(Chatlog),
            "get_chats",
            channel_props=channel_props.serialize(),
            sync_from=sync_from.serialize() if sync_from else None,
            limit=limit,
        )

    async def list_friends(self) -> FriendListStruct:
        """Fetch the user's friend list."""
        return await self._api_user_request_result(
            FriendListStruct,
            "list_friends",
        )

    async def get_friend_dm_id(self, friend_id: Long) -> Long | None:
        """Return the ID of the DM channel with ``friend_id``, or None (logged) if the backend errors."""
        try:
            return await self._api_user_request_result(
                Long,
                "get_friend_dm_id",
                friend_id=friend_id.serialize(),
            )
        except CommandException:
            self.log.exception(f"Could not find friend with ID {friend_id}")
            return None

    async def get_memo_ids(self) -> list[Long]:
        """Fetch the list of memo channel IDs; only the mxid is sent with this request."""
        return ResultListType(Long).deserialize(
            await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid)
        )

    async def send_chat(
        self,
        channel_props: ChannelProps,
        text: str,
        reply_to: ReplyAttachment | None,
        mentions: list[MentionStruct] | None,
    ) -> Chatlog:
        """Send a text message, optionally as a reply and/or with mentions; return the sent chat."""
        return await self._api_user_request_result(
            Chatlog,
            "send_chat",
            channel_props=channel_props.serialize(),
            text=text,
            reply_to=reply_to.serialize() if reply_to is not None else None,
            mentions=[m.serialize() for m in mentions] if mentions is not None else None,
        )

    async def send_media(
        self,
        channel_props: ChannelProps,
        media_type: KnownChatType,
        data: bytes,
        filename: str,
        *,
        width: int | None = None,
        height: int | None = None,
        ext: str | None = None,
    ) -> Chatlog:
        """Send a media message; ``data`` is sent as a list of byte values. Return the sent chat."""
        return await self._api_user_request_result(
            Chatlog,
            "send_media",
            channel_props=channel_props.serialize(),
            type=media_type,
            data=list(data),
            name=filename,
            width=width,
            height=height,
            ext=ext,
            # Don't log the bytes
            # TODO Disable logging per-argument, not per-command
            is_secret=True,
        )

    async def delete_chat(
        self,
        channel_props: ChannelProps,
        chat_id: Long,
    ) -> None:
        """Delete the chat with ID ``chat_id`` from the given channel."""
        return await self._api_user_request_void(
            "delete_chat",
            channel_props=channel_props.serialize(),
            chat_id=chat_id.serialize(),
        )

    async def _api_user_request_result(
        self, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """Run a user-scoped command that returns a ``result_type`` object (see _with_token_renewal)."""
        return await self._with_token_renewal(
            lambda: self._api_request_result(result_type, command, **self._user_data, **data)
        )

    async def _api_user_request_void(self, command: str, **data: JSON) -> None:
        """Run a user-scoped command with no result object (see _with_token_renewal)."""
        await self._with_token_renewal(
            lambda: self._api_request_void(command, **self._user_data, **data)
        )

    async def _with_token_renewal(
        self, make_request: Callable[[], Awaitable[ResultType]]
    ) -> ResultType:
        """
        Await ``make_request()``, renewing the user's tokens once on InvalidAccessToken.

        ``make_request`` is called again for the retry, so it rebuilds ``_user_data``
        and picks up the renewed credential. A second InvalidAccessToken propagates.
        """
        try:
            return await make_request()
        except InvalidAccessToken:
            await self.renew_and_save()
        return await make_request()

    # endregion

    # region listeners

    async def _on_chat(self, data: dict[str, JSON]) -> None:
        await self.user.on_chat(
            Chatlog.deserialize(data["chatlog"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_chat_deleted(self, data: dict[str, JSON]) -> None:
        await self.user.on_chat_deleted(
            Long.deserialize(data["chatId"]),
            Long.deserialize(data["senderId"]),
            int(data["timestamp"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_chat_read(self, data: dict[str, JSON]) -> None:
        await self.user.on_chat_read(
            Long.deserialize(data["chatId"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_profile_changed(self, data: dict[str, JSON]) -> None:
        await self.user.on_profile_changed(
            OpenLinkChannelUserInfo.deserialize(data["info"]),
        )

    async def _on_listen_disconnect(self, data: dict[str, JSON]) -> None:
        """Handle a backend-initiated disconnect; an undecodable kickout reason becomes None."""
        try:
            res = KickoutRes.deserialize(data)
        except Exception:
            self.log.exception("Invalid kickout reason, defaulting to None")
            res = None
        await self._on_disconnect(res)

    async def _on_switch_server(self, _: dict[str, JSON] | None = None) -> None:
        """
        Handle a "switch_server" event by treating it as a CHANGE_SERVER kickout.

        Accepts (and ignores) the event payload, like every other registered handler.
        """
        # TODO Reconnect automatically instead
        await self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))

    async def _on_disconnect(self, res: KickoutRes | None) -> None:
        self._stop_listen()
        await self.user.on_disconnect(res)

    def _on_error(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_error(data)

    def _start_listen(self) -> None:
        """Register this user's handlers for all backend events."""
        self._add_event_handler("chat", self._on_chat)
        self._add_event_handler("chat_deleted", self._on_chat_deleted)
        self._add_event_handler("chat_read", self._on_chat_read)
        self._add_event_handler("profile_changed", self._on_profile_changed)
        self._add_event_handler("disconnected", self._on_listen_disconnect)
        self._add_event_handler("switch_server", self._on_switch_server)
        self._add_event_handler("error", self._on_error)

    def _stop_listen(self) -> None:
        """Unregister every handler registered by _start_listen."""
        for method in self._handler_methods:
            self._rpc_client.set_event_handlers(self._get_user_cmd(method), [])
        # Everything is unregistered now. Forget it, so that repeated connect/disconnect
        # cycles don't grow this list without bound.
        self._handler_methods.clear()

    def _add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Make ``handler`` the sole handler for this user's ``method`` event."""
        self._rpc_client.set_event_handlers(self._get_user_cmd(method), [handler])
        if method not in self._handler_methods:
            self._handler_methods.append(method)

    def _get_user_cmd(self, command) -> str:
        """Scope an event/command name to this user, as "<command>:<mxid>"."""
        return f"{command}:{self.user.mxid}"

    # endregion

    @classmethod
    async def _api_request_result(
        cls, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a command via RPC, and return its result object.
        On failure, raise an appropriate exception.
        """
        resp = deserialize_result(
            result_type,
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)
        # NOTE Not asserting against CommandResultDoneValue because it's generic!
        # TODO Check if there really is no way to do it.
        assert type(resp) is not RootCommandResult, "Result object missing from successful response"
        return cast(CommandResultDoneValue[ResultType], resp).result

    @classmethod
    async def _api_request_void(cls, command: str, **data: JSON) -> None:
        """Call a command via RPC that returns no result object; raise on failure."""
        resp = RootCommandResult.deserialize(
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)