matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/kt/client/client.py

553 lines
19 KiB
Python
Raw Normal View History

2022-02-25 02:22:50 -05:00
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Client functionality for the KakaoTalk API.
Currently a wrapper around a Node backend, but
the abstraction used here should be compatible
with any other potential backend.
"""
from __future__ import annotations
2022-04-10 02:18:53 -04:00
from typing import TYPE_CHECKING, cast, Awaitable, Type, Optional, Union
import asyncio
from contextlib import asynccontextmanager
2022-02-25 02:22:50 -05:00
import logging
import urllib.request
from aiohttp import ClientSession
from aiohttp.client import _RequestContextManager
from yarl import URL
from mautrix.types import SerializerError
2022-02-25 02:22:50 -05:00
from mautrix.util.logging import TraceLogger
from ...config import Config
from ...rpc import EventHandler, RPCClient
2022-02-25 02:22:50 -05:00
2022-03-23 03:09:30 -04:00
from ..types.api.struct import FriendListStruct
2022-04-10 04:30:26 -04:00
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
2022-02-25 02:22:50 -05:00
from ..types.bson import Long
2022-04-10 04:30:26 -04:00
from ..types.channel.channel_info import ChannelInfo
2022-03-26 03:37:53 -04:00
from ..types.chat import Chatlog, KnownChatType
2022-04-06 12:49:23 -04:00
from ..types.chat.attachment import MentionStruct, ReplyAttachment
2022-04-10 04:30:26 -04:00
from ..types.client.client_session import LoginResult
2022-02-25 02:22:50 -05:00
from ..types.oauth import OAuthCredential, OAuthInfo
2022-04-10 02:23:50 -04:00
from ..types.openlink.open_link_user_info import OpenLinkChannelUserInfo
from ..types.packet.chat.kickout import KnownKickoutType, KickoutRes
2022-02-25 02:22:50 -05:00
from ..types.request import (
deserialize_result,
2022-03-09 02:25:28 -05:00
ResultType,
ResultListType,
RootCommandResult,
CommandResultDoneValue
)
2022-02-25 02:22:50 -05:00
from .types import (
ChannelProps,
PortalChannelInfo,
SettingsStruct,
UserInfoUnion,
)
2022-02-25 02:22:50 -05:00
from .errors import InvalidAccessToken, CommandException
2022-02-25 02:22:50 -05:00
from .error_helper import raise_unsuccessful_response
try:
from aiohttp_socks import ProxyConnector
except ImportError:
ProxyConnector = None
if TYPE_CHECKING:
from mautrix.types import JSON
from ... import user as u
2022-02-25 02:22:50 -05:00
@asynccontextmanager
async def sandboxed_get(url: URL) -> _RequestContextManager:
    """
    Perform a GET request using a brand-new, throwaway ClientSession.

    The session and the response are both closed when the context exits.
    Callers receive the response object itself.
    (Note: the return annotation names _RequestContextManager, but what the
    context yields is the object produced by entering ``session.get(url)``.)
    """
    async with ClientSession() as session:
        async with session.get(url) as response:
            yield response
2022-02-25 02:22:50 -05:00
# TODO Consider defining an interface for this, with node/native backend as swappable implementations
# TODO If no state is stored, consider using free functions instead of classmethods
class Client:
    """
    Wrapper around the RPC connection to the KakaoTalk Node backend.

    Classmethods ("tokenless" commands) share a single class-wide RPCClient
    created by init_cls; instances add per-user commands and event listeners
    on top of that same RPCClient.
    """

    _rpc_client: RPCClient
    # Strong reference to the reconnection task spawned by init_cls (see init_cls).
    _keep_connected_task: asyncio.Task | None = None

    @classmethod
    def init_cls(cls, config: Config) -> None:
        """Initialize RPC to the Node backend."""
        cls._rpc_client = RPCClient(config)
        # NOTE Per the original design, cancelling the RPCClient cancels this task too.
        # A reference is still kept because the event loop only holds weak references
        # to tasks, so an unreferenced task could be garbage-collected mid-execution.
        cls._keep_connected_task = asyncio.create_task(cls._keep_connected())

    @classmethod
    async def _keep_connected(cls) -> None:
        """Connect to the Node backend, and reconnect each time the connection drops."""
        while True:
            await cls._rpc_client.connect()
            await cls._rpc_client.wait_for_disconnection()

    @classmethod
    def stop_cls(cls) -> None:
        """Stop and disconnect from the Node backend."""
        cls._rpc_client.cancel()

    # region tokenless commands

    @classmethod
    async def generate_uuid(cls, used_uuids: set[str]) -> str:
        """
        Randomly generate a UUID for a (fake) device.

        Requests up to 10 UUIDs from the backend and returns the first one
        that is not in used_uuids.

        :raises Exception: if every generated UUID was already in used_uuids.
        """
        for _ in range(10):
            uuid = await cls._rpc_client.request("generate_uuid")
            if uuid not in used_uuids:
                return uuid
        raise Exception(
            "Unable to generate a UUID that hasn't been used before. "
            "Either use a different RNG, or buy a lottery ticket"
        )

    @classmethod
    async def register_device(cls, passcode: str, **req: JSON) -> None:
        """Register a (fake) device that will be associated with the provided login credentials."""
        await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)

    @classmethod
    async def login(cls, **req: JSON) -> OAuthCredential:
        """
        Obtain a session token by logging in with user-provided credentials.
        Must have first called register_device with these credentials.
        """
        # NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential
        return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)

    # endregion

    user: u.User
    _rpc_disconnection_task: asyncio.Task | None
    http: ClientSession
    log: TraceLogger
    # Event names whose handlers are currently registered for this user
    _handler_methods: list[str]
    # Strong references to fire-and-forget tasks, dropped once each task finishes
    _background_tasks: set[asyncio.Task]

    def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
        """Create a per-user client object for user-specific client functionality."""
        self.user = user
        # Assigned before anything else: the proxy setup below may need to log a warning.
        self.log = log or logging.getLogger("mw.ktclient")
        self._rpc_disconnection_task = None
        self._handler_methods = []
        self._background_tasks = set()

        # TODO Let the Node backend use a proxy too!
        connector = None
        try:
            http_proxy = urllib.request.getproxies()["http"]
        except KeyError:
            pass
        else:
            if ProxyConnector:
                connector = ProxyConnector.from_url(http_proxy)
            else:
                self.log.warning("http_proxy is set, but aiohttp-socks is not installed")
        self.http = ClientSession(connector=connector)

    @property
    def _oauth_credential(self) -> JSON:
        """The user's stored OAuth credential, serialized for sending over RPC."""
        return self.user.oauth_credential.serialize()

    @property
    def _user_data(self) -> JSON:
        """Arguments identifying this user, sent with every per-user command."""
        return {
            "mxid": self.user.mxid,
            "oauth_credential": self._oauth_credential,
        }

    # region HTTP

    def get(
        self,
        url: Union[str, URL],
        headers: Optional[dict[str, str]] = None,
        sandbox: bool = False,
        **kwargs,
    ) -> _RequestContextManager:
        """
        Start an HTTP GET request.

        :param url: The URL to fetch.
        :param headers: Extra request headers.
        :param sandbox: If true, use a throwaway ClientSession (via sandboxed_get)
            instead of this client's shared session.
        :return: An async context manager that yields the response.
        """
        # TODO Is auth ever needed?
        # TODO Are any default headers needed? (e.g. merging in **self._headers)
        headers = {**(headers or {})}
        url = URL(url)
        if sandbox:
            # NOTE(review): headers and kwargs are not forwarded in sandbox mode,
            # because sandboxed_get only accepts a URL.
            return sandboxed_get(url)
        return self.http.get(url, headers=headers, **kwargs)

    # endregion

    # region post-token commands

    async def start(self) -> SettingsStruct | None:
        """
        Initialize user-specific bridging & state by providing a token obtained from a prior login.
        Receive the user's settings in response, or None if they could not be deserialized.
        """
        try:
            settings_struct = await self._api_user_request_result(SettingsStruct, "start")
        except SerializerError:
            self.log.exception("Unable to deserialize settings struct, but starting client anyways")
            settings_struct = None
        if self._rpc_disconnection_task is None:
            self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler())
        else:
            self.log.warning("Called \"start\" on an already-started client")
        return settings_struct

    async def stop(self) -> None:
        """Immediately stop bridging this user."""
        self._stop_listen()
        if self._rpc_disconnection_task:
            self._rpc_disconnection_task.cancel()
            # Forget the cancelled task, so that a later start() can create a new one.
            self._rpc_disconnection_task = None
        else:
            self.log.warning("Called \"stop\" on an already-stopped client")
        await self._rpc_client.request("stop", mxid=self.user.mxid)

    async def _rpc_disconnection_handler(self) -> None:
        """Wait for the RPC connection to drop, then stop listening and notify the user."""
        await self._rpc_client.wait_for_disconnection()
        self._rpc_disconnection_task = None
        self._stop_listen()
        # Run the user callback as its own task, holding a reference until it finishes
        # (the event loop only keeps weak references to tasks).
        task = asyncio.create_task(self.user.on_client_disconnect())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def renew_and_save(self) -> None:
        """Renew and save the user's session tokens."""
        oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)
        self.user.oauth_credential = oauth_info.credential
        await self.user.save()

    async def connect(self) -> LoginResult:
        """
        Start a new talk session by providing a token obtained from a prior login.

        Receive a snapshot of account state in response.
        """
        login_result = await self._api_user_request_result(LoginResult, "connect")
        assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
        # Safe even if handlers are already listening: registering replaces the existing
        # handler list, and _add_event_handler does not record an event name twice.
        self._start_listen()
        return login_result

    async def disconnect(self) -> None:
        """Disconnect from the talk session, but remain logged in."""
        await self._rpc_client.request("disconnect", mxid=self.user.mxid)
        await self._on_disconnect(None)

    async def get_settings(self) -> SettingsStruct:
        """Fetch the user's account settings."""
        return await self._api_user_request_result(SettingsStruct, "get_settings")

    async def get_own_profile(self) -> ProfileStruct:
        """Fetch the logged-in user's own profile."""
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        """Fetch the profile of the user with the given ID."""
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize(),
        )
        return profile_req_struct.profile

    async def get_portal_channel_info(self, channel_props: ChannelProps) -> PortalChannelInfo:
        """Fetch the information needed to bridge the given channel to a portal room."""
        return await self._api_user_request_result(
            PortalChannelInfo,
            "get_portal_channel_info",
            channel_props=channel_props.serialize(),
        )

    async def get_participants(self, channel_props: ChannelProps) -> list[UserInfoUnion]:
        """Fetch info on the members of the given channel."""
        return await self._api_user_request_result(
            ResultListType(UserInfoUnion),
            "get_participants",
            channel_props=channel_props.serialize(),
        )

    async def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> list[Chatlog]:
        """
        Fetch chat logs of the given channel.

        :param sync_from: The chat ID to fetch from, or None to omit it from the request.
        :param limit: The maximum number of chats to fetch, or None to omit it from the request.
        """
        return await self._api_user_request_result(
            ResultListType(Chatlog),
            "get_chats",
            channel_props=channel_props.serialize(),
            # Compare against None explicitly, so that a falsy ID is still sent
            sync_from=sync_from.serialize() if sync_from is not None else None,
            limit=limit,
        )

    async def list_friends(self) -> FriendListStruct:
        """Fetch the user's friend list."""
        return await self._api_user_request_result(
            FriendListStruct,
            "list_friends",
        )

    async def get_friend_dm_id(self, friend_id: Long) -> Long | None:
        """Fetch the ID of the DM channel with the given friend, or None if the command fails."""
        try:
            return await self._api_user_request_result(
                Long,
                "get_friend_dm_id",
                friend_id=friend_id.serialize(),
            )
        except CommandException:
            self.log.exception(f"Could not find friend with ID {friend_id}")
            return None

    async def get_memo_ids(self) -> list[Long]:
        """Fetch channel IDs from the backend's "get_memo_ids" command."""
        # NOTE This sends only the mxid and deserializes the raw response directly,
        # unlike the other per-user commands, which go through _api_user_request_result.
        return ResultListType(Long).deserialize(
            await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid)
        )

    async def send_chat(
        self,
        channel_props: ChannelProps,
        text: str,
        reply_to: ReplyAttachment | None,
        mentions: list[MentionStruct] | None,
    ) -> Chatlog:
        """Send a text message, optionally as a reply and/or with mentions, and return its chat log."""
        return await self._api_user_request_result(
            Chatlog,
            "send_chat",
            channel_props=channel_props.serialize(),
            text=text,
            reply_to=reply_to.serialize() if reply_to is not None else None,
            mentions=[m.serialize() for m in mentions] if mentions is not None else None,
        )

    async def send_media(
        self,
        channel_props: ChannelProps,
        media_type: KnownChatType,
        data: bytes,
        filename: str,
        *,
        width: int | None = None,
        height: int | None = None,
        ext: str | None = None,
    ) -> Chatlog:
        """Upload and send a media message, and return its chat log."""
        return await self._api_user_request_result(
            Chatlog,
            "send_media",
            channel_props=channel_props.serialize(),
            type=media_type,
            # Sent as a list of byte values
            data=list(data),
            name=filename,
            width=width,
            height=height,
            ext=ext,
            # Don't log the bytes
            # TODO Disable logging per-argument, not per-command
            is_secret=True
        )

    async def delete_chat(
        self,
        channel_props: ChannelProps,
        chat_id: Long,
    ) -> None:
        """Delete the given chat from the given channel."""
        return await self._api_user_request_void(
            "delete_chat",
            channel_props=channel_props.serialize(),
            chat_id=chat_id.serialize(),
        )

    async def mark_read(
        self,
        channel_props: ChannelProps,
        read_until_chat_id: Long,
    ) -> None:
        """Mark the given channel as read up to the given chat."""
        return await self._api_user_request_void(
            "mark_read",
            channel_props=channel_props.serialize(),
            read_until_chat_id=read_until_chat_id.serialize(),
        )

    # TODO Combine these into one
    async def _api_user_request_result(
        self, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Run a per-user command and return its result.

        On InvalidAccessToken, renew the user's tokens and retry once.
        A second InvalidAccessToken is re-raised.
        """
        renewed = False
        while True:
            try:
                return await self._api_request_result(result_type, command, **self._user_data, **data)
            except InvalidAccessToken:
                if renewed:
                    raise
                await self.renew_and_save()
                renewed = True

    async def _api_user_request_void(self, command: str, **data: JSON) -> None:
        """Like _api_user_request_result, for commands with no result value."""
        renewed = False
        while True:
            try:
                return await self._api_request_void(command, **self._user_data, **data)
            except InvalidAccessToken:
                if renewed:
                    raise
                await self.renew_and_save()
                renewed = True

    # endregion

    # region listeners
    # Each _on_* handler deserializes an event payload from the Node backend
    # and forwards it to the corresponding User callback.

    async def _on_chat(self, data: dict[str, JSON]) -> None:
        await self.user.on_chat(
            Chatlog.deserialize(data["chatlog"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_chat_deleted(self, data: dict[str, JSON]) -> None:
        await self.user.on_chat_deleted(
            Long.deserialize(data["chatId"]),
            Long.deserialize(data["senderId"]),
            int(data["timestamp"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_chat_read(self, data: dict[str, JSON]) -> None:
        await self.user.on_chat_read(
            Long.deserialize(data["chatId"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_profile_changed(self, data: dict[str, JSON]) -> None:
        await self.user.on_profile_changed(
            OpenLinkChannelUserInfo.deserialize(data["info"]),
        )

    async def _on_channel_join(self, data: dict[str, JSON]) -> None:
        await self.user.on_channel_join(
            ChannelInfo.deserialize(data["channelInfo"]),
        )

    async def _on_channel_left(self, data: dict[str, JSON]) -> None:
        await self.user.on_channel_left(
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_channel_kicked(self, data: dict[str, JSON]) -> None:
        await self.user.on_channel_kicked(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_user_join(self, data: dict[str, JSON]) -> None:
        await self.user.on_user_join(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_user_left(self, data: dict[str, JSON]) -> None:
        await self.user.on_user_left(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    async def _on_listen_disconnect(self, data: dict[str, JSON]) -> None:
        """Handle a "disconnected" event; an unparseable kickout reason is treated as None."""
        try:
            res = KickoutRes.deserialize(data)
        except Exception:
            self.log.exception("Invalid kickout reason, defaulting to None")
            res = None
        await self._on_disconnect(res)

    async def _on_switch_server(self, _: dict[str, JSON] | None = None) -> None:
        """Handle a "switch_server" event by treating it as a CHANGE_SERVER kickout."""
        # The payload parameter is accepted (and ignored) so this handler has the same
        # signature as the other registered event handlers.
        # NOTE(review): presumably the RPC client passes event data to every handler; confirm.
        # TODO Reconnect automatically instead
        await self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))

    async def _on_disconnect(self, res: KickoutRes | None) -> None:
        """Stop listening for events and notify the user of the disconnection."""
        self._stop_listen()
        await self.user.on_disconnect(res)

    def _on_error(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_error(data)

    def _start_listen(self) -> None:
        """Register this user's handlers for every event the backend emits."""
        self._add_event_handler("chat", self._on_chat)
        self._add_event_handler("chat_deleted", self._on_chat_deleted)
        self._add_event_handler("chat_read", self._on_chat_read)
        self._add_event_handler("profile_changed", self._on_profile_changed)
        self._add_event_handler("channel_join", self._on_channel_join)
        self._add_event_handler("channel_left", self._on_channel_left)
        self._add_event_handler("channel_kicked", self._on_channel_kicked)
        self._add_event_handler("user_join", self._on_user_join)
        self._add_event_handler("user_left", self._on_user_left)
        self._add_event_handler("disconnected", self._on_listen_disconnect)
        self._add_event_handler("switch_server", self._on_switch_server)
        self._add_event_handler("error", self._on_error)

    def _stop_listen(self) -> None:
        """Clear the handlers of every event registered through _add_event_handler."""
        for method in self._handler_methods:
            self._rpc_client.set_event_handlers(self._get_user_cmd(method), [])
        # Forget the cleared events, so repeated start/stop cycles don't grow the list
        self._handler_methods.clear()

    def _add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Set [handler] as the handler list for this user's `method` event."""
        self._rpc_client.set_event_handlers(self._get_user_cmd(method), [handler])
        if method not in self._handler_methods:
            self._handler_methods.append(method)

    def _get_user_cmd(self, command: str) -> str:
        """Return the per-user event name, formatted as "<command>:<mxid>"."""
        return f"{command}:{self.user.mxid}"

    # endregion

    @classmethod
    async def _api_request_result(
        cls, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a command via RPC, and return its result object.
        On failure, raise an appropriate exception.
        """
        resp = deserialize_result(
            result_type,
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)
        # NOTE Not asserting against CommandResultDoneValue because it's generic!
        # TODO Check if there really is no way to do it.
        assert type(resp) is not RootCommandResult, "Result object missing from successful response"
        return cast(CommandResultDoneValue[ResultType], resp).result

    @classmethod
    async def _api_request_void(cls, command: str, **data: JSON) -> None:
        """
        Call a command via RPC that returns no result value.
        On failure, raise an appropriate exception.
        """
        resp = RootCommandResult.deserialize(
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)