747 lines
26 KiB
Python
747 lines
26 KiB
Python
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
|
|
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
Client functionality for the KakaoTalk API.
|
|
Currently a wrapper around a Node backend, but
|
|
the abstraction used here should be compatible
|
|
with any other potential backend.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
import logging
import urllib.request
from contextlib import asynccontextmanager
from typing import (
    TYPE_CHECKING,
    AsyncContextManager,
    AsyncIterator,
    Awaitable,
    Optional,
    Type,
    Union,
    cast,
)

from aiohttp import ClientResponse, ClientSession
from aiohttp.client import _RequestContextManager
from yarl import URL
|
|
|
|
from mautrix.types import SerializerError
|
|
from mautrix.util.logging import TraceLogger
|
|
|
|
from ...config import Config
|
|
from ...rpc import EventHandler, RPCClient
|
|
|
|
from ..types.api.struct import (
|
|
FriendListStruct,
|
|
FriendReqStruct,
|
|
FriendStruct,
|
|
ProfileReqStruct,
|
|
ProfileStruct,
|
|
)
|
|
from ..types.bson import Long
|
|
from ..types.channel.channel_info import ChannelInfo
|
|
from ..types.chat import Chatlog, KnownChatType
|
|
from ..types.chat.attachment import MentionStruct, ReplyAttachment
|
|
from ..types.client.client_session import LoginResult
|
|
from ..types.oauth import OAuthCredential, OAuthInfo
|
|
from ..types.openlink.open_link_type import OpenChannelUserPerm
|
|
from ..types.openlink.open_link_user_info import OpenLinkChannelUserInfo
|
|
from ..types.packet.chat.kickout import KnownKickoutType, KickoutRes
|
|
from ..types.request import (
|
|
deserialize_result,
|
|
ResultType,
|
|
ResultListType,
|
|
ResultRawType,
|
|
RootCommandResult,
|
|
CommandResultDoneValue
|
|
)
|
|
|
|
from .types import (
|
|
ChannelProps,
|
|
PortalChannelInfo,
|
|
PortalChannelParticipantInfo,
|
|
Receipt,
|
|
SettingsStruct,
|
|
UserInfoUnion,
|
|
)
|
|
|
|
from .errors import InvalidAccessToken, CommandException
|
|
from .error_helper import raise_unsuccessful_response
|
|
|
|
try:
|
|
from aiohttp_socks import ProxyConnector
|
|
except ImportError:
|
|
ProxyConnector = None
|
|
|
|
if TYPE_CHECKING:
|
|
from mautrix.types import JSON
|
|
from ... import user as u
|
|
|
|
|
|
@asynccontextmanager
async def sandboxed_get(url: URL) -> AsyncIterator[ClientResponse]:
    """
    Make a GET request to ``url`` using a fresh, temporary ClientSession.

    The session is created with default settings (e.g. no proxy connector), independent of
    any Client's shared session, and is closed together with the response when the
    context exits.

    :param url: The URL to request.
    :return: An async context manager that yields the response.
    """
    async with ClientSession() as sess, sess.get(url) as resp:
        yield resp
|
|
|
|
|
|
# TODO Consider defining an interface for this, with node/native backend as swappable implementations
|
|
# TODO If no state is stored, consider using free functions instead of classmethods
|
|
class Client:
    """
    KakaoTalk API client, implemented as RPC calls to a Node backend.

    Class-level members manage the single RPC connection shared by every user
    (see ``init_cls``). Instances issue per-user commands over that connection and
    receive per-user events from it.
    """

    # Shared RPC connection to the Node backend; assigned by init_cls.
    _rpc_client: RPCClient

    # How many UUIDs generate_uuid requests from the backend before giving up.
    _UUID_GENERATION_ATTEMPTS = 10

    @classmethod
    def init_cls(cls, config: Config) -> None:
        """Initialize RPC to the Node backend."""
        cls._rpc_client = RPCClient(config, "kakaotalk")
        # NOTE No need to store this, as cancelling the RPCClient will cancel this too
        asyncio.create_task(cls._keep_connected())

    @classmethod
    async def _keep_connected(cls) -> None:
        """Connect to the Node backend, and reconnect every time the connection drops. Never returns."""
        while True:
            await cls._rpc_client.connect()
            await cls._rpc_client.wait_for_disconnection()

    @classmethod
    def wait_for_connection(cls) -> Awaitable[None]:
        """Wait for the shared RPC client to be connected (delegates to RPCClient)."""
        return cls._rpc_client.wait_for_connection()

    @classmethod
    def stop_cls(cls) -> None:
        """Stop and disconnect from the Node backend."""
        cls._rpc_client.cancel()

    # region tokenless commands

    @classmethod
    async def generate_uuid(cls, used_uuids: set[str]) -> str:
        """
        Randomly generate a UUID for a (fake) device.

        The backend is asked for a new UUID up to ``_UUID_GENERATION_ATTEMPTS`` times.
        The first one not already in ``used_uuids`` is returned.

        :param used_uuids: UUIDs that must not be returned.
        :raises Exception: if every generated UUID was already in ``used_uuids``.
        """
        for _ in range(cls._UUID_GENERATION_ATTEMPTS):
            uuid = await cls._rpc_client.request("generate_uuid")
            if uuid not in used_uuids:
                return uuid
        raise Exception(
            "Unable to generate a UUID that hasn't been used before. "
            "Either use a different RNG, or buy a lottery ticket"
        )

    @classmethod
    def register_device(cls, passcode: str, **req: JSON) -> Awaitable[None]:
        """Register a (fake) device that will be associated with the provided login credentials."""
        return cls._api_request_void("register_device", passcode=passcode, **req)

    @classmethod
    def login(cls, uuid: str, form: JSON, forced: bool) -> Awaitable[OAuthCredential]:
        """
        Obtain a session token by logging in with user-provided credentials.

        Must have first called register_device with these credentials.
        """
        # NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential
        return cls._api_request_result(
            OAuthCredential,
            "login",
            uuid=uuid,
            form=form,
            forced=forced,
        )

    # endregion

    user: u.User
    # Watches for the RPC connection dropping; set between "start" and "stop"/disconnection.
    _rpc_disconnection_task: asyncio.Task | None
    # Strong references to fire-and-forget tasks. The event loop only holds weak references,
    # so without these a task could be garbage-collected before it finishes.
    _background_tasks: set[asyncio.Task]
    http: ClientSession
    log: TraceLogger
    # Event names whose per-user handlers are currently registered with the RPC client.
    _handler_methods: list[str]

    def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
        """Create a per-user client object for user-specific client functionality."""
        self.user = user
        # Assigned first: the proxy setup below may need to log a warning.
        self.log = log or logging.getLogger("mw.ktclient")
        self._rpc_disconnection_task = None
        self._background_tasks = set()
        self._handler_methods = []

        # TODO Let the Node backend use a proxy too!
        connector = None
        http_proxy = urllib.request.getproxies().get("http")
        if http_proxy is not None:
            if ProxyConnector is not None:
                connector = ProxyConnector.from_url(http_proxy)
            else:
                self.log.warning("http_proxy is set, but aiohttp-socks is not installed")
        self.http = ClientSession(connector=connector)

    @property
    def _oauth_credential(self) -> JSON:
        """The user's stored OAuth credential, serialized for sending over RPC."""
        return self.user.oauth_credential.serialize()

    # region HTTP

    def get(
        self,
        url: Union[str, URL],
        headers: Optional[dict[str, str]] = None,
        sandbox: bool = False,
        **kwargs,
    ) -> _RequestContextManager | AsyncContextManager[ClientResponse]:
        """
        Make an HTTP GET request.

        :param url: The URL to request.
        :param headers: Extra request headers.
        :param sandbox: If true, use a fresh, temporary ClientSession (via ``sandboxed_get``)
            instead of this client's shared session.
            NOTE: in sandbox mode, ``headers`` and ``kwargs`` are currently not forwarded.
        :return: An async context manager that yields the response.
        """
        # TODO Is auth ever needed?
        # TODO Are any default headers needed?
        headers = dict(headers or {})
        url = URL(url)
        if sandbox:
            return sandboxed_get(url)
        return self.http.get(url, headers=headers, **kwargs)

    # endregion

    # region post-token commands

    # Unless noted otherwise, each command below forwards to the Node backend command of the
    # same name for this user, serializing Long/ChannelProps arguments. On InvalidAccessToken
    # the session is renewed once and the command retried (see _api_user_request_*).

    async def start(self) -> SettingsStruct | None:
        """
        Initialize user-specific bridging & state by providing a token obtained from a prior login.

        Receive the user's settings in response, or None if they could not be deserialized.
        """
        try:
            settings_struct = await self._api_user_cred_request_result(SettingsStruct, "start")
        except SerializerError:
            self.log.exception("Unable to deserialize settings struct, but starting client anyways")
            settings_struct = None
        if self._rpc_disconnection_task is None:
            self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler())
        else:
            self.log.warning("Called \"start\" on an already-started client")
        return settings_struct

    async def stop(self) -> None:
        """Immediately stop bridging this user."""
        self._stop_listen()
        if self._rpc_disconnection_task:
            self._rpc_disconnection_task.cancel()
            # Forget the cancelled task so that a later "start" creates a fresh watcher
            # instead of treating this client as still started.
            self._rpc_disconnection_task = None
        else:
            self.log.warning("Called \"stop\" on an already-stopped client")
        await self._rpc_client.request("stop", mxid=self.user.mxid)

    async def _rpc_disconnection_handler(self) -> None:
        """Once the RPC connection drops, stop listening and notify the user object."""
        await self._rpc_client.wait_for_disconnection()
        self._rpc_disconnection_task = None
        self._stop_listen()
        task = asyncio.create_task(self.user.on_client_disconnect())
        # Keep a reference until the task finishes (asyncio only holds weak references).
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def renew_and_save(self) -> None:
        """Renew and save the user's session tokens."""
        # renew=False: a failed renewal must not recursively trigger another renewal.
        oauth_info = await self._api_user_cred_request_result(OAuthInfo, "renew", renew=False)
        self.user.oauth_credential = oauth_info.credential
        await self.user.save()

    async def connect(self) -> LoginResult | None:
        """
        Start a new talk session by providing a token obtained from a prior login.

        Receive a snapshot of account state in response, or None if it could not be deserialized.

        :raises AssertionError: if the backend reports a different user ID than ``self.user.ktid``.
        """
        try:
            login_result = await self._api_user_cred_request_result(LoginResult, "connect")
        except SerializerError:
            self.log.exception("Unable to deserialize login result, but connecting anyways")
            login_result = None
        else:
            # Raised explicitly rather than via `assert`, so the check also runs under `python -O`.
            if self.user.ktid != login_result.userId:
                raise AssertionError(
                    f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
                )
        # TODO Skip if handlers are already listening. But this is idempotent and thus probably safe
        self._start_listen()
        return login_result

    async def disconnect(self) -> None:
        """Disconnect from the talk session, but remain logged in."""
        await self._rpc_client.request("disconnect", mxid=self.user.mxid)
        await self._on_disconnect(None)

    def is_connected(self) -> Awaitable[bool]:
        """Ask the backend whether this user's talk session is connected (raw RPC result)."""
        return self._rpc_client.request("is_connected", mxid=self.user.mxid)

    def get_settings(self) -> Awaitable[SettingsStruct]:
        return self._api_user_request_result(SettingsStruct, "get_settings")

    async def get_own_profile(self) -> ProfileStruct:
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize(),
        )
        return profile_req_struct.profile

    def get_portal_channel_info(self, channel_props: ChannelProps) -> Awaitable[PortalChannelInfo]:
        return self._api_user_request_result(
            PortalChannelInfo,
            "get_portal_channel_info",
            channel_props=channel_props.serialize(),
        )

    def get_portal_channel_participant_info(
        self, channel_props: ChannelProps
    ) -> Awaitable[PortalChannelParticipantInfo]:
        return self._api_user_request_result(
            PortalChannelParticipantInfo,
            "get_portal_channel_participant_info",
            channel_props=channel_props.serialize(),
        )

    def get_participants(self, channel_props: ChannelProps) -> Awaitable[list[UserInfoUnion]]:
        return self._api_user_request_result(
            ResultListType(UserInfoUnion),
            "get_participants",
            channel_props=channel_props.serialize(),
        )

    def get_chats(
        self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None
    ) -> Awaitable[list[Chatlog]]:
        """Fetch chats of a channel, optionally starting from ``sync_from`` and capped at ``limit``."""
        return self._api_user_request_result(
            ResultListType(Chatlog),
            "get_chats",
            channel_props=channel_props.serialize(),
            # Explicit None check, so that a falsy ID is still sent to the backend.
            sync_from=sync_from.serialize() if sync_from is not None else None,
            limit=limit,
        )

    def get_read_receipts(
        self, channel_props: ChannelProps, unread_chat_ids: list[Long]
    ) -> Awaitable[list[Receipt]]:
        return self._api_user_request_result(
            ResultListType(Receipt),
            "get_read_receipts",
            channel_props=channel_props.serialize(),
            unread_chat_ids=[c.serialize() for c in unread_chat_ids],
        )

    async def can_change_uuid(self, uuid: str) -> bool:
        """Return whether the backend accepts ``uuid`` (False if it raises CommandException)."""
        try:
            await self._api_user_request_void("can_change_uuid", uuid=uuid)
        except CommandException:
            return False
        return True

    def change_uuid(self, uuid: str) -> Awaitable[None]:
        return self._api_user_request_void("change_uuid", uuid=uuid)

    def set_uuid_searchable(self, searchable: bool) -> Awaitable[None]:
        return self._api_user_request_void("set_uuid_searchable", searchable=searchable)

    def list_friends(self) -> Awaitable[FriendListStruct]:
        return self._api_user_request_result(
            FriendListStruct,
            "list_friends",
        )

    async def edit_friend(self, ktid: Long, add: bool) -> FriendStruct | None:
        """Edit friendship with the user ``ktid``; ``add`` is passed through to the backend."""
        return await self._edit_friend_request("edit_friend", user_id=ktid.serialize(), add=add)

    async def edit_friend_by_uuid(self, uuid: str, add: bool) -> FriendStruct | None:
        """Edit friendship with the user whose UUID is ``uuid``; ``add`` is passed through."""
        return await self._edit_friend_request("edit_friend_by_uuid", uuid=uuid, add=add)

    async def _edit_friend_request(self, command: str, **data: JSON) -> FriendStruct | None:
        """
        Run a friend-editing command and return the friend struct from its response.

        Returns None (after logging) if the response could not be deserialized, since the
        edit should have been applied regardless.
        """
        try:
            friend_req_struct = await self._api_user_request_result(FriendReqStruct, command, **data)
        except SerializerError:
            self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless")
            return None
        return friend_req_struct.friend

    async def get_friend_dm_id(self, friend_id: Long) -> Long | None:
        """Return the ID of the DM channel with ``friend_id``, or None (logged) on CommandException."""
        try:
            return await self._api_user_request_result(
                Long,
                "get_friend_dm_id",
                friend_id=friend_id.serialize(),
            )
        except CommandException:
            self.log.exception(f"Could not find friend with ID {friend_id}")
            return None

    async def get_memo_ids(self) -> list[Long]:
        # NOTE This bypasses _api_user_request_result: the raw RPC response is deserialized
        # directly as a list, with no success check or token renewal.
        return ResultListType(Long).deserialize(
            await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid)
        )

    def download_file(
        self,
        channel_props: ChannelProps,
        key: str,
    ) -> Awaitable[bytes]:
        return self._api_user_request_result(
            ResultRawType(bytes),
            "download_file",
            channel_props=channel_props.serialize(),
            key=key,
        )

    def send_chat(
        self,
        channel_props: ChannelProps,
        text: str,
        reply_to: ReplyAttachment | None,
        mentions: list[MentionStruct] | None,
    ) -> Awaitable[Long]:
        """Send a text message, returning the ID the backend reports for it."""
        return self._api_user_request_result(
            Long,
            "send_chat",
            channel_props=channel_props.serialize(),
            text=text,
            reply_to=reply_to.serialize() if reply_to is not None else None,
            mentions=[m.serialize() for m in mentions] if mentions is not None else None,
        )

    def send_media(
        self,
        channel_props: ChannelProps,
        media_type: KnownChatType,
        data: bytes,
        filename: str,
        *,
        width: int | None = None,
        height: int | None = None,
        ext: str | None = None,
    ) -> Awaitable[Long]:
        """Send a media message, returning the ID the backend reports for it."""
        return self._api_user_request_result(
            Long,
            "send_media",
            channel_props=channel_props.serialize(),
            type=media_type,
            # Sent as a list of byte values (presumably because the RPC payload has no bytes type).
            data=list(data),
            name=filename,
            width=width,
            height=height,
            ext=ext,
        )

    def delete_chat(
        self,
        channel_props: ChannelProps,
        chat_id: Long,
    ) -> Awaitable[None]:
        return self._api_user_request_void(
            "delete_chat",
            channel_props=channel_props.serialize(),
            chat_id=chat_id.serialize(),
        )

    def mark_read(
        self,
        channel_props: ChannelProps,
        read_until_chat_id: Long,
    ) -> Awaitable[None]:
        return self._api_user_request_void(
            "mark_read",
            channel_props=channel_props.serialize(),
            read_until_chat_id=read_until_chat_id.serialize(),
        )

    def send_perm(
        self,
        channel_props: ChannelProps,
        user_id: Long,
        perm: OpenChannelUserPerm,
    ) -> Awaitable[None]:
        return self._api_user_request_void(
            "send_perm",
            channel_props=channel_props.serialize(),
            user_id=user_id.serialize(),
            perm=perm,
        )

    def set_channel_name(
        self,
        channel_props: ChannelProps,
        name: str,
    ) -> Awaitable[None]:
        return self._api_user_request_void(
            "set_channel_name",
            channel_props=channel_props.serialize(),
            name=name,
        )

    def set_channel_description(
        self,
        channel_props: ChannelProps,
        description: str,
    ) -> Awaitable[None]:
        return self._api_user_request_void(
            "set_channel_description",
            channel_props=channel_props.serialize(),
            description=description,
        )

    def set_channel_photo(
        self,
        channel_props: ChannelProps,
        photo_url: str,
    ) -> Awaitable[None]:
        return self._api_user_request_void(
            "set_channel_photo",
            channel_props=channel_props.serialize(),
            photo_url=photo_url,
        )

    def create_direct_chat(self, ktid: Long) -> Awaitable[Long]:
        """Create a direct chat with ``ktid``, returning the ID the backend reports for it."""
        return self._api_user_request_result(
            Long,
            "create_direct_chat",
            user_id=ktid.serialize(),
        )

    def leave_channel(
        self,
        channel_props: ChannelProps,
    ) -> Awaitable[None]:
        return self._api_user_request_void(
            "leave_channel",
            channel_props=channel_props.serialize(),
        )

    # TODO Combine each of these pairs into one

    async def _api_user_request_result(
        self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
    ) -> ResultType:
        """
        Run a per-user command and return its deserialized result.

        On InvalidAccessToken, if ``renew`` is true, renew the session once and retry.
        """
        while True:
            try:
                return await self._api_request_result(result_type, command, mxid=self.user.mxid, **data)
            except InvalidAccessToken:
                if not renew:
                    raise
                await self.renew_and_save()
                renew = False

    async def _api_user_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
        """
        Run a per-user command that has no result.

        On InvalidAccessToken, if ``renew`` is true, renew the session once and retry.
        """
        while True:
            try:
                return await self._api_request_void(command, mxid=self.user.mxid, **data)
            except InvalidAccessToken:
                if not renew:
                    raise
                await self.renew_and_save()
                renew = False

    async def _api_user_cred_request_result(
        self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
    ) -> ResultType:
        """
        Like _api_user_request_result, but also send the user's stored OAuth credential.

        The credential is re-read on every attempt, so a retry after renewal sends the new one.
        """
        while True:
            try:
                return await self._api_user_request_result(
                    result_type, command, oauth_credential=self._oauth_credential, renew=False, **data
                )
            except InvalidAccessToken:
                if not renew:
                    raise
                await self.renew_and_save()
                renew = False

    async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
        """
        Like _api_user_request_void, but also send the user's stored OAuth credential.

        The credential is re-read on every attempt, so a retry after renewal sends the new one.
        """
        while True:
            try:
                # Return on success; otherwise the loop would reissue the command forever.
                return await self._api_user_request_void(
                    command, oauth_credential=self._oauth_credential, renew=False, **data
                )
            except InvalidAccessToken:
                if not renew:
                    raise
                await self.renew_and_save()
                renew = False

    # endregion

    # region listeners

    # Each _on_* method below receives the payload dict of one per-user backend event,
    # converts its fields to typed values, and forwards them to the matching
    # ``self.user.on_*`` callback, returning that callback's awaitable.

    def _on_chat(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_chat(
            Chatlog.deserialize(data["chatlog"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_chat_deleted(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_chat_deleted(
            Long.deserialize(data["chatId"]),
            Long.deserialize(data["senderId"]),
            int(data["timestamp"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_chat_read(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_chat_read(
            Long.deserialize(data["chatId"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_profile_changed(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_profile_changed(
            OpenLinkChannelUserInfo.deserialize(data["info"]),
        )

    def _on_perm_changed(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_perm_changed(
            Long.deserialize(data["userId"]),
            OpenChannelUserPerm(data["perm"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_channel_added(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_added(
            ChannelInfo.deserialize(data["channelInfo"]),
        )

    def _on_channel_join(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_join(
            ChannelInfo.deserialize(data["channelInfo"]),
        )

    def _on_channel_left(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_left(
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_channel_kicked(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_kicked(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_user_join(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_user_join(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_user_left(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_user_left(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_channel_meta_change(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_meta_change(
            PortalChannelInfo.deserialize(data["info"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_listen_disconnect(self, data: dict[str, JSON]) -> Awaitable[None]:
        """Handle a "disconnected" event, parsing its payload as a kickout reason if possible."""
        try:
            res = KickoutRes.deserialize(data)
        except Exception:
            # Best-effort: an unparseable reason must not prevent handling the disconnect.
            self.log.exception("Invalid kickout reason, defaulting to None")
            res = None
        return self._on_disconnect(res)

    def _on_switch_server(self, _: dict[str, JSON]) -> Awaitable[None]:
        # TODO Reconnect automatically instead
        return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))

    def _on_disconnect(self, res: KickoutRes | None) -> Awaitable[None]:
        """Stop listening for events, then notify the user with the (optional) kickout reason."""
        self._stop_listen()
        return self.user.on_disconnect(res)

    def _on_error(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_error(data)

    def _start_listen(self) -> None:
        """Register this client's handler for every per-user backend event."""
        handlers = (
            ("chat", self._on_chat),
            ("chat_deleted", self._on_chat_deleted),
            ("chat_read", self._on_chat_read),
            ("profile_changed", self._on_profile_changed),
            ("perm_changed", self._on_perm_changed),
            ("channel_added", self._on_channel_added),
            ("channel_join", self._on_channel_join),
            ("channel_left", self._on_channel_left),
            ("channel_kicked", self._on_channel_kicked),
            ("user_join", self._on_user_join),
            ("user_left", self._on_user_left),
            ("channel_meta_change", self._on_channel_meta_change),
            ("disconnected", self._on_listen_disconnect),
            ("switch_server", self._on_switch_server),
            ("error", self._on_error),
        )
        for method, handler in handlers:
            self._add_event_handler(method, handler)

    def _stop_listen(self) -> None:
        """Clear every per-user event handler registered through _add_event_handler."""
        for method in self._handler_methods:
            self._rpc_client.set_event_handlers(self._get_user_cmd(method), [])
        # Forget the cleared events so repeated connect/disconnect cycles don't accumulate entries.
        self._handler_methods.clear()

    def _add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Set ``handler`` as the only handler of this user's ``method`` event."""
        self._rpc_client.set_event_handlers(self._get_user_cmd(method), [handler])
        # The handler list is set, not appended to (presumably replacing any previous one),
        # so each event name only needs to be tracked once.
        if method not in self._handler_methods:
            self._handler_methods.append(method)

    def _get_user_cmd(self, command: str) -> str:
        """Return the per-user RPC event name for ``command``, i.e. "<command>:<mxid>"."""
        return f"{command}:{self.user.mxid}"

    # endregion

    @classmethod
    async def _api_request_result(
        cls, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a command via RPC, and return its result object.

        On failure, raise an appropriate exception.
        """
        resp = deserialize_result(
            result_type,
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)
        # NOTE Not asserting against CommandResultDoneValue because it's generic!
        # TODO Check if there really is no way to do it.
        assert type(resp) is not RootCommandResult, "Result object missing from successful response"
        return cast(CommandResultDoneValue[ResultType], resp).result

    @classmethod
    async def _api_request_void(cls, command: str, **data: JSON) -> None:
        """
        Call a command via RPC that has no result object.

        On failure, raise an appropriate exception.
        """
        resp = RootCommandResult.deserialize(
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)