diff --git a/matrix_appservice_kakaotalk/__main__.py b/matrix_appservice_kakaotalk/__main__.py index 3a05681..8a03485 100644 --- a/matrix_appservice_kakaotalk/__main__.py +++ b/matrix_appservice_kakaotalk/__main__.py @@ -72,8 +72,6 @@ class KakaoTalkBridge(Bridge): puppet.stop() self.log.debug("Stopping kakaotalk listeners") User.shutdown = True - for user in User.by_ktid.values(): - user.stop_listen() self.add_shutdown_actions(user.save() for user in User.by_mxid.values()) self.add_shutdown_actions(KakaoTalkClient.stop_cls()) diff --git a/matrix_appservice_kakaotalk/commands/conn.py b/matrix_appservice_kakaotalk/commands/conn.py index 6a1336e..3a171d6 100644 --- a/matrix_appservice_kakaotalk/commands/conn.py +++ b/matrix_appservice_kakaotalk/commands/conn.py @@ -17,6 +17,8 @@ from mautrix.bridge.commands import HelpSection, command_handler from .typehint import CommandEvent +from ..kt.client.errors import CommandException + SECTION_CONNECTION = HelpSection("Connection management", 15, "") @@ -32,35 +34,6 @@ async def set_notice_room(evt: CommandEvent) -> None: await evt.reply("This room has been marked as your bridge notice room") -""" -@command_handler( - needs_auth=True, - management_only=True, - help_section=SECTION_CONNECTION, - help_text="Disconnect from KakaoTalk", -) -async def disconnect(evt: CommandEvent) -> None: - if not evt.sender.mqtt: - await evt.reply("You don't have a KakaoTalk MQTT connection") - return - evt.sender.mqtt.disconnect() - - -@command_handler( - needs_auth=True, - management_only=True, - help_section=SECTION_CONNECTION, - help_text="Connect to KakaoTalk", - aliases=["reconnect"], -) -async def connect(evt: CommandEvent) -> None: - if evt.sender.listen_task and not evt.sender.listen_task.done(): - await evt.reply("You already have a KakaoTalk MQTT connection") - return - evt.sender.start_listen() -""" - - @command_handler( needs_auth=True, management_only=True, @@ -72,34 +45,11 @@ async def ping(evt: CommandEvent) -> None: await evt.reply("You're not logged into KakaoTalk") return await evt.mark_read() - # try: - own_info = await evt.sender.get_own_info() - # TODO catch errors - # except fbchat.PleaseRefresh as e: - # await evt.reply(f"{e}\n\nUse `$cmdprefix+sp refresh` refresh the session.") - # return - await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})") - - """ - if not evt.sender.listen_task or evt.sender.listen_task.done(): - await evt.reply("You don't have a KakaoTalk MQTT connection. Use `connect` to connect.") - elif not evt.sender.is_connected: - await evt.reply("The KakaoTalk MQTT listener is **disconnected**.") - else: - await evt.reply("The KakaoTalk MQTT listener is connected.") - """ - - -""" -@command_handler( - needs_auth=True, - management_only=True, - help_section=SECTION_CONNECTION, - help_text="Resync chats and reconnect to MQTT", -) -async def refresh(evt: CommandEvent) -> None: - await evt.sender.refresh(force_notice=True) -""" + try: + own_info = await evt.sender.get_own_info() + await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})") + except CommandException as e: + await evt.reply(f"Error from KakaoTalk: {e}") @command_handler( @@ -107,10 +57,19 @@ async def refresh(evt: CommandEvent) -> None: management_only=True, help_section=SECTION_CONNECTION, help_text="Resync chats", + help_args="[count]", ) async def sync(evt: CommandEvent) -> None: + try: + sync_count = int(evt.args[0]) + except IndexError: + sync_count = None + except ValueError: + await evt.reply("**Usage:** `$cmdprefix+sp logout [--reset-device]`") + return + await evt.mark_read() - if await evt.sender.post_login(is_startup=False): + if await evt.sender.connect_and_sync(sync_count): await evt.reply("Sync complete") else: await evt.reply("Sync failed") diff --git a/matrix_appservice_kakaotalk/db/user.py b/matrix_appservice_kakaotalk/db/user.py index 22f8dbe..476e186 100644 --- a/matrix_appservice_kakaotalk/db/user.py +++ b/matrix_appservice_kakaotalk/db/user.py @@ -47,24 +47,23 @@ class User: def _from_optional_row(cls, row: Record | None) -> User | None: return cls._from_row(row) if row is not None else None + _columns = "mxid, ktid, uuid, access_token, refresh_token, notice_room" + @classmethod async def all_logged_in(cls) -> List[User]: - q = """ - SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" - WHERE ktid<>0 - """ + q = f'SELECT {cls._columns} FROM "user" WHERE ktid<>0' rows = await cls.db.fetch(q) return [cls._from_row(row) for row in rows if row] @classmethod async def get_by_ktid(cls, ktid: int) -> User | None: - q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1' + q = f'SELECT {cls._columns} FROM "user" WHERE ktid=$1' row = await cls.db.fetchrow(q, ktid) return cls._from_optional_row(row) @classmethod async def get_by_mxid(cls, mxid: UserID) -> User | None: - q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE mxid=$1' + q = f'SELECT {cls._columns} FROM "user" WHERE mxid=$1' row = await cls.db.fetchrow(q, mxid) return cls._from_optional_row(row) @@ -73,23 +72,30 @@ class User: q = 'SELECT uuid FROM "user" WHERE uuid IS NOT NULL' return {tuple(record)[0] for record in await cls.db.fetch(q)} + @property + def _values(self): + return ( + self.mxid, + self.ktid, + self.uuid, + self.access_token, + self.refresh_token, + self.notice_room, + ) + async def insert(self) -> None: q = """ INSERT INTO "user" (mxid, ktid, uuid, access_token, refresh_token, notice_room) VALUES ($1, $2, $3, $4, $5, $6) """ - await self.db.execute( - q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room - ) + await self.db.execute(q, *self._values) async def delete(self) -> None: await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid) async def save(self) -> None: q = """ - UPDATE "user" SET ktid=$1, uuid=$2, access_token=$3, refresh_token=$4, notice_room=$5 - WHERE mxid=$6 + UPDATE "user" SET ktid=$2, uuid=$3, access_token=$4, refresh_token=$5, notice_room=$6 + WHERE mxid=$1 """ - await self.db.execute( - q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid - ) + await self.db.execute(q, *self._values) diff --git a/matrix_appservice_kakaotalk/example-config.yaml b/matrix_appservice_kakaotalk/example-config.yaml index 1f6334a..1425d6a 100644 --- a/matrix_appservice_kakaotalk/example-config.yaml +++ b/matrix_appservice_kakaotalk/example-config.yaml @@ -122,7 +122,7 @@ bridge: command_prefix: "!kt" # Number of chats to sync (and create portals for) on startup/login. - # Set 0 to disable automatic syncing. + # Set to 0 to disable automatic syncing, or -1 to sync as much as possible. initial_chat_sync: 20 # Whether or not the KakaoTalk users of logged in Matrix users should be # invited to private chats when the user sends a message from another client. @@ -188,11 +188,11 @@ bridge: # usually needed to prevent rate limits and to allow timestamp massaging. invite_own_puppet: true # Maximum number of messages to backfill initially. - # Set to 0 to disable backfilling when creating portal. + # Set to 0 to disable backfilling when creating portal, or -1 to backfill as much as possible. initial_limit: 0 # Maximum number of messages to backfill if messages were missed while # the bridge was disconnected. - # Set to 0 to disable backfilling missed messages. + # Set to 0 to disable backfilling missed messages, or -1 to backfill as much as possible. missed_limit: 1000 # If using double puppeting, should notifications be disabled # while the initial backfill is in progress? @@ -213,7 +213,7 @@ bridge: # The number of seconds that a disconnection can last without triggering an automatic re-sync # and missed message backfilling when reconnecting. # Set to 0 to always re-sync, or -1 to never re-sync automatically. - resync_max_disconnected_time: 5 + #resync_max_disconnected_time: 5 # Should the bridge do a resync on startup? sync_on_startup: true # Whether or not temporary disconnections should send notices to the notice room. diff --git a/matrix_appservice_kakaotalk/kt/client/client.py b/matrix_appservice_kakaotalk/kt/client/client.py index 7ee83cb..b231993 100644 --- a/matrix_appservice_kakaotalk/kt/client/client.py +++ b/matrix_appservice_kakaotalk/kt/client/client.py @@ -22,7 +22,7 @@ with any other potential backend. from __future__ import annotations -from typing import TYPE_CHECKING, cast, Awaitable, Callable, Type, Optional, Union +from typing import TYPE_CHECKING, cast, Type, Optional, Union import logging import urllib.request @@ -39,7 +39,6 @@ from ...rpc import RPCClient from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct from ..types.bson import Long from ..types.client.client_session import LoginResult -from ..types.channel.channel_type import ChannelType from ..types.chat.chat import Chatlog from ..types.oauth import OAuthCredential, OAuthInfo from ..types.request import ( @@ -63,7 +62,6 @@ except ImportError: if TYPE_CHECKING: from mautrix.types import JSON from ...user import User - from ...rpc.rpc import EventHandler # TODO Consider defining an interface for this, with node/native backend as swappable implementations @@ -80,7 +78,6 @@ class Client: @classmethod async def stop_cls(cls) -> None: """Stop and disconnect from the Node backend.""" - await cls._rpc_client.request("stop") await cls._rpc_client.disconnect() @@ -102,16 +99,17 @@ class Client: ) @classmethod - async def register_device(cls, passcode: str, **req) -> None: + async def register_device(cls, passcode: str, **req: JSON) -> None: """Register a (fake) device that will be associated with the provided login credentials.""" await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req) @classmethod - async def login(cls, **req) -> OAuthCredential: + async def login(cls, **req: JSON) -> OAuthCredential: """ Obtain a session token by logging in with user-provided credentials. Must have first called register_device with these credentials. """ + # NOTE Actually returns a LoginData object, but this only needs an OAuthCredential return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req) # endregion @@ -181,20 +179,21 @@ class Client: self.user.oauth_credential = oauth_info.credential await self.user.save() - async def start(self) -> LoginResult: + async def connect(self) -> LoginResult: """ Start a new session by providing a token obtained from a prior login. Receive a snapshot of account state in response. """ - login_result = await self._api_user_request_result(LoginResult, "start") + login_result = await self._api_user_request_result(LoginResult, "connect") assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}" + # TODO Skip if handlers are already listening. But this is idempotent and thus probably safe + self._start_listen() return login_result - """ - async def is_connected(self) -> bool: - resp = await self._rpc_client.request("is_connected") - return resp["is_connected"] - """ + async def disconnect(self) -> bool: + connection_existed = await self._rpc_client.request("disconnect", mxid=self.user.mxid) + self._stop_listen() + return connection_existed async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct: profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile") @@ -232,14 +231,6 @@ class Client: text=text ) - async def start_listen(self) -> None: - # TODO Connect all listeners here? - await self._api_user_request_void("start_listen") - - async def stop(self) -> None: - # TODO Stop all event handlers - await self._api_user_request_void("stop") - # TODO Combine these into one @@ -272,19 +263,31 @@ class Client: # region listeners - async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None: - async def wrapper(data: dict[str, JSON]) -> None: - await func( - Chatlog.deserialize(data["chatlog"]), - Long.deserialize(data["channelId"]), - data["channelType"] - ) + async def _on_message(self, data: dict[str, JSON]) -> None: + await self.user.on_message( + Chatlog.deserialize(data["chatlog"]), + Long.deserialize(data["channelId"]), + data["channelType"] + ) - self._add_user_handler("message", wrapper) + """ TODO + async def _on_receipt(self, data: Dict[str, JSON]) -> None: + await self.user.on_receipt(Receipt.deserialize(data["receipt"])) + """ - def _add_user_handler(self, command: str, handler: EventHandler) -> str: - self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler) + def _start_listen(self) -> None: + # TODO Automate this somehow, like with a fancy enum + self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [self._on_message]) + # TODO many more listeners + + def _stop_listen(self) -> None: + # TODO Automate this somehow, like with a fancy enum + self._rpc_client.set_event_handlers(self._get_user_cmd("message"), []) + # TODO many more listeners + + def _get_user_cmd(self, command) -> str: + return f"{command}:{self.user.mxid}" # endregion diff --git a/matrix_appservice_kakaotalk/kt/types/request.py b/matrix_appservice_kakaotalk/kt/types/request.py index eb862a0..ac92693 100644 --- a/matrix_appservice_kakaotalk/kt/types/request.py +++ b/matrix_appservice_kakaotalk/kt/types/request.py @@ -79,6 +79,16 @@ class RootCommandResult(ResponseState): """For brevity, this also encompasses CommandResultFailed and CommandResultDoneVoid""" success: bool + @classmethod + def deserialize(cls, data: JSON) -> "RootCommandResult": + if not data or "success" not in data or "status" not in data: + return RootCommandResult( + success=True, + status=KnownDataStatusCode.SUCCESS + ) + else: + return super().deserialize(data) + ResultType = TypeVar("ResultType", bound=Serializable) diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py index 5b39445..6e7fbf9 100644 --- a/matrix_appservice_kakaotalk/portal.py +++ b/matrix_appservice_kakaotalk/portal.py @@ -16,7 +16,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, AsyncGenerator, Pattern, cast -from collections import deque import asyncio import re import time @@ -845,11 +844,6 @@ class Portal(DBPortal, BasePortal): return False return True - async def _add_kakaotalk_reply( - self, content: MessageEventContent, reply_to: None - ) -> None: - self.log.info("TODO") - async def handle_remote_message( self, source: u.User, @@ -969,7 +963,7 @@ class Portal(DBPortal, BasePortal): if not messages: self.log.debug("Didn't get any messages from server") return - self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server") + self.log.debug(f"Got {len(messages)} message{'s' if len(messages) > 1 else ''} from server") self._backfill_leave = set() async with NotificationDisabler(self.mxid, source): for message in messages: diff --git a/matrix_appservice_kakaotalk/rpc/__init__.py b/matrix_appservice_kakaotalk/rpc/__init__.py index 33792bd..f5125b1 100644 --- a/matrix_appservice_kakaotalk/rpc/__init__.py +++ b/matrix_appservice_kakaotalk/rpc/__init__.py @@ -1 +1,2 @@ from .rpc import RPCClient +from .types import RPCError diff --git a/matrix_appservice_kakaotalk/rpc/rpc.py b/matrix_appservice_kakaotalk/rpc/rpc.py index 3820317..59d8d84 100644 --- a/matrix_appservice_kakaotalk/rpc/rpc.py +++ b/matrix_appservice_kakaotalk/rpc/rpc.py @@ -15,7 +15,7 @@ # along with this program. If not, see . from __future__ import annotations -from typing import Any, Callable, Awaitable, List +from typing import Any, Callable, Awaitable import asyncio import json @@ -39,7 +39,7 @@ class RPCClient: _req_id: int _min_broadcast_id: int _response_waiters: dict[int, asyncio.Future[JSON]] - _event_handlers: dict[str, List[EventHandler]] + _event_handlers: dict[str, list[EventHandler]] _command_queue: asyncio.Queue def __init__(self, config: Config) -> None: @@ -98,7 +98,13 @@ class RPCClient: self._event_handlers.setdefault(method, []).append(handler) def remove_event_handler(self, method: str, handler: EventHandler) -> None: - self._event_handlers.setdefault(method, []).remove(handler) + try: + self._event_handlers.setdefault(method, []).remove(handler) + except ValueError: + pass + + def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None: + self._event_handlers[method] = handlers async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None: if req_id > self._min_broadcast_id: diff --git a/matrix_appservice_kakaotalk/user.py b/matrix_appservice_kakaotalk/user.py index 136d647..b8b7c89 100644 --- a/matrix_appservice_kakaotalk/user.py +++ b/matrix_appservice_kakaotalk/user.py @@ -46,9 +46,7 @@ from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult from .kt.types.oauth import OAuthCredential from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo -METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels") -METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync") -METRIC_UNKNOWN_EVENT = Summary("bridge_on_unknown_event", "calls to on_unknown_event") +METRIC_CONNECT_AND_SYNC = Summary("bridge_sync_channels", "calls to connect_and_sync") METRIC_MEMBERS_ADDED = Summary("bridge_on_members_added", "calls to on_members_added") METRIC_MEMBER_REMOVED = Summary("bridge_on_member_removed", "calls to on_member_removed") METRIC_TYPING = Summary("bridge_on_typing", "calls to on_typing") @@ -58,7 +56,6 @@ METRIC_MESSAGE_UNSENT = Summary("bridge_on_unsent", "calls to on_unsent") METRIC_MESSAGE_SEEN = Summary("bridge_on_message_seen", "calls to on_message_seen") METRIC_TITLE_CHANGE = Summary("bridge_on_title_change", "calls to on_title_change") METRIC_AVATAR_CHANGE = Summary("bridge_on_avatar_change", "calls to on_avatar_change") -METRIC_THREAD_CHANGE = Summary("bridge_on_thread_change", "calls to on_thread_change") METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message") METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge") METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk") @@ -71,7 +68,6 @@ BridgeState.human_readable_errors.update( "kt-reconnection-error": "Failed to reconnect to KakaoTalk", "kt-connection-error": "KakaoTalk disconnected unexpectedly", "kt-auth-error": "Authentication error from KakaoTalk: {message}", - "kt-start-error": "Startup error from KakaoTalk: {message}", "kt-disconnected": None, "logged-out": "You're not logged into KakaoTalk", } @@ -79,7 +75,6 @@ BridgeState.human_readable_errors.update( class User(DBUser, BaseUser): - #temp_disconnect_notices: bool = True shutdown: bool = False config: Config @@ -90,16 +85,13 @@ class User(DBUser, BaseUser): _notice_room_lock: asyncio.Lock _notice_send_lock: asyncio.Lock - command_status: dict | None is_admin: bool permission_level: str _is_logged_in: bool | None - #_is_connected: bool | None - #_connection_time: float - _prev_reconnect_fail_refresh: float + _is_connected: bool | None + _connection_time: float _db_instance: DBUser | None _sync_lock: SimpleLock - _is_refreshing: bool _logged_in_info: ProfileStruct | None _logged_in_info_time: float @@ -124,7 +116,6 @@ class User(DBUser, BaseUser): self.notice_room = notice_room self._notice_room_lock = asyncio.Lock() self._notice_send_lock = asyncio.Lock() - self.command_status = None ( self.relay_whitelisted, self.is_whitelisted, @@ -132,13 +123,11 @@ class User(DBUser, BaseUser): self.permission_level, ) = self.config.get_permissions(mxid) self._is_logged_in = None - #self._is_connected = None - #self._connection_time = time.monotonic() - self._prev_reconnect_fail_refresh = time.monotonic() + self._is_connected = None + self._connection_time = time.monotonic() self._sync_lock = SimpleLock( "Waiting for thread sync to finish before handling %s", log=self.log ) - self._is_refreshing = False self._logged_in_info = None self._logged_in_info_time = 0 @@ -150,10 +139,8 @@ class User(DBUser, BaseUser): cls.config = bridge.config cls.az = bridge.az cls.loop = bridge.loop - #cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"] return (user.reload_session(is_startup=True) async for user in cls.all_logged_in()) - """ @property def is_connected(self) -> bool | None: return self._is_connected @@ -167,11 +154,11 @@ class User(DBUser, BaseUser): @property def connection_time(self) -> float: return self._connection_time - """ @property def has_state(self) -> bool: - return bool(self.uuid and self.ktid and self.access_token and self.refresh_token) + # TODO If more state is needed, consider returning a saved LoginResult + return bool(self.access_token and self.refresh_token) # region Database getters @@ -233,11 +220,13 @@ class User(DBUser, BaseUser): async def get_uuid(self, force: bool = False) -> str: if self.uuid is None or force: self.uuid = await self._generate_uuid() + # TODO Maybe don't save yet await self.save() return self.uuid - async def _generate_uuid(self) -> str: - return await Client.generate_uuid(await self.get_all_uuids()) + @classmethod + async def _generate_uuid(cls) -> str: + return await Client.generate_uuid(await super().get_all_uuids()) # endregion @@ -285,8 +274,7 @@ class User(DBUser, BaseUser): self._logged_in_info_time = time.monotonic() self._track_metric(METRIC_LOGGED_IN, True) self._is_logged_in = True - #self.is_connected = None - self.stop_listen() + self.is_connected = None asyncio.create_task(self.post_login(is_startup=is_startup)) return True return False @@ -310,10 +298,9 @@ class User(DBUser, BaseUser): ) -> ProfileStruct: if not client: client = self.client - # TODO Retry network connection failures here, or in the client? + # TODO Retry network connection failures here, or in the client (like token refreshes are)? try: return await client.fetch_logged_in_user() - # NOTE Not catching InvalidAccessToken here, as client handles it & tries to refresh the token except AuthenticationRequired as e: if action != "restore session": await self._send_reset_notice(e) @@ -352,7 +339,7 @@ class User(DBUser, BaseUser): state_event=BridgeStateEvent.TRANSIENT_DISCONNECT, ) await asyncio.sleep(60) - await self.reload_session(event_id, retries - 1) + await self.reload_session(event_id, retries - 1, is_startup) else: await self.send_bridge_notice( notice, @@ -362,44 +349,43 @@ class User(DBUser, BaseUser): error_code="kt-reconnection-error", ) except Exception: + self.log.exception("Error connecting to KakaoTalk") await self.send_bridge_notice( "Failed to connect to KakaoTalk: unknown error (see logs for more details)", edit=event_id, state_event=BridgeStateEvent.UNKNOWN_ERROR, error_code="kt-reconnection-error", ) - finally: - self._is_refreshing = False - async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> bool: - ok = True - self.stop_listen() - if self.has_state: - # TODO Log out of KakaoTalk if an API exists for it - pass + async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None: + if self.client: + # TODO Look for a logout API call + was_connected = await self.client.disconnect() + if was_connected != self._is_connected: + self.log.warn( + f"Node backend was{' not' if not was_connected else ''} connected, " + f"but we thought it was{' not' if not self._is_connected else ''}") if remove_ktid: await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT) self._track_metric(METRIC_LOGGED_IN, False) - self._is_logged_in = False - #self.is_connected = None - if self.client: - await self.client.stop() - self.client = None - - if self.ktid and remove_ktid: - #await UserPortal.delete_all(self.ktid) - del self.by_ktid[self.ktid] - self.ktid = None if reset_device: self.uuid = await self._generate_uuid() self.access_token = None self.refresh_token = None - await self.save() - return ok + self._is_logged_in = False + self.is_connected = None + self.client = None - async def post_login(self, is_startup: bool) -> bool: + if self.ktid and remove_ktid: + #await UserPortal.delete_all(self.ktid) + del self.by_ktid[self.ktid] + self.ktid = None + + await self.save() + + async def post_login(self, is_startup: bool) -> None: self.log.info("Running post-login actions") self._add_to_cache() @@ -412,12 +398,27 @@ class User(DBUser, BaseUser): except Exception: self.log.exception("Failed to automatically enable custom puppet") - assert self.client + # TODO Check if things break when a live message comes in during syncing + if self.config["bridge.sync_on_startup"] or not is_startup: + sync_count = self.config["bridge.initial_chat_sync"] + else: + sync_count = None + await self.connect_and_sync(sync_count) + + async def get_direct_chats(self) -> dict[UserID, list[RoomID]]: + return { + pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid] + async for portal in po.Portal.get_all_by_receiver(self.ktid) + if portal.mxid + } + + @async_time(METRIC_CONNECT_AND_SYNC) + async def connect_and_sync(self, sync_count: int | None) -> bool: + # TODO Look for a way to sync all channels without (re-)logging in try: - # TODO if not is_startup, close existing listeners - login_result = await self.client.start() - await self._sync_channels(login_result, is_startup) - self.start_listen() + login_result = await self.client.connect() + await self.push_bridge_state(BridgeStateEvent.CONNECTED) + await self._sync_channels(login_result, sync_count) return True except AuthenticationRequired as e: await self.send_bridge_notice( @@ -429,41 +430,34 @@ class User(DBUser, BaseUser): ) await self.logout(remove_ktid=False) except Exception as e: - self.log.exception("Failed to start client") - await self.send_bridge_notice( - f"Got error from KakaoTalk:\n\n> {e!s}\n\n", - important=True, - state_event=BridgeStateEvent.UNKNOWN_ERROR, - error_code="kt-start-error", - error_message=str(e), - ) + self.log.exception("Failed to connect and sync") + await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e)) return False - async def get_direct_chats(self) -> dict[UserID, list[RoomID]]: - return { - pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid] - async for portal in po.Portal.get_all_by_receiver(self.ktid) - if portal.mxid - } - - @async_time(METRIC_SYNC_CHANNELS) - async def _sync_channels(self, login_result: LoginResult, is_startup: bool) -> None: - # TODO Look for a way to sync all channels without (re-)logging in - sync_count = self.config["bridge.initial_chat_sync"] - if sync_count <= 0 or not self.config["bridge.sync_on_startup"] and is_startup: - self.log.debug(f"Skipping channel syncing{' on startup' if sync_count > 0 else ''}") + async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None: + if sync_count is None: + sync_count = self.config["bridge.initial_chat_sync"] + if not sync_count: + self.log.debug("Skipping channel syncing") return if not login_result.channelList: self.log.debug("No channels to sync") return # TODO What about removed channels? Don't early-return then - sync_count = min(sync_count, len(login_result.channelList)) + num_channels = len(login_result.channelList) + sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels) await self.push_bridge_state(BridgeStateEvent.BACKFILLING) - self.log.debug(f"Syncing {sync_count} of {len(login_result.channelList)} channels...") + self.log.debug(f"Syncing {sync_count} of {num_channels} channels...") for channel_item in login_result.channelList[:sync_count]: - # TODO try-except here, above, below? - await self._sync_channel(channel_item) + try: + await self._sync_channel(channel_item) + except AuthenticationRequired: + raise + except Exception: + self.log.exception(f"Failed to sync channel {channel_item.channel.channelId}") + + await self.update_direct_chats() async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None: channel_data = channel_item.channel @@ -573,18 +567,12 @@ class User(DBUser, BaseUser): state.remote_name = puppet.name async def get_bridge_states(self) -> list[BridgeState]: - self.log.info("TODO: get_bridge_states") - return [] - """ if not self.state: return [] state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR) if self.is_connected: state.state_event = BridgeStateEvent.CONNECTED - elif self._is_refreshing or self.mqtt: - state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT return [state] - """ async def get_puppet(self) -> pu.Puppet | None: if not self.ktid: @@ -593,36 +581,10 @@ class User(DBUser, BaseUser): # region KakaoTalk event handling - def start_listen(self) -> None: - self.listen_task = asyncio.create_task(self._try_listen()) - - def _disconnect_listener_after_error(self) -> None: - self.log.info("TODO: _disconnect_listener_after_error") - - async def _try_listen(self) -> None: - try: - # TODO Pass all listeners to start_listen instead of registering them one-by-one? - await self.client.start_listen() - await self.client.on_message(self.on_message) - # TODO Handle auth errors specially? - #except AuthenticationRequired as e: - except Exception: - #self.is_connected = False - self.log.exception("Fatal error in listener") - await self.send_bridge_notice( - "Fatal error in listener (see logs for more info)", - state_event=BridgeStateEvent.UNKNOWN_ERROR, - important=True, - error_code="kt-connection-error", - ) - self._disconnect_listener_after_error() - - def stop_listen(self) -> None: - self.log.info("TODO: stop_listen") - async def on_logged_in(self, oauth_credential: OAuthCredential) -> None: self.log.debug(f"Successfully logged in as {oauth_credential.userId}") self.oauth_credential = oauth_credential + await self.push_bridge_state(BridgeStateEvent.CONNECTING) self.client = Client(self, log=self.log.getChild("ktclient")) await self.save() self._is_logged_in = True @@ -631,7 +593,6 @@ class User(DBUser, BaseUser): self._logged_in_info_time = time.monotonic() except Exception: self.log.exception("Failed to fetch post-login info") - self.stop_listen() asyncio.create_task(self.post_login(is_startup=True)) @async_time(METRIC_MESSAGE) diff --git a/node/src/client.js b/node/src/client.js index 6c203ac..f05d8de 100644 --- a/node/src/client.js +++ b/node/src/client.js @@ -14,7 +14,6 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . import { Long } from "bson" -import { emitLines, promisify } from "./util.js" import { AuthApiClient, OAuthApiClient, @@ -28,6 +27,8 @@ const { KnownChatType } = chat /** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */ /** @typedef {import("./clientmanager.js").default} ClientManager} */ +import { emitLines, promisify } from "./util.js" + class UserClient { static #initializing = false @@ -36,7 +37,7 @@ class UserClient { get talkClient() { return this.#talkClient } /** @type {ServiceApiClient} */ - #serviceClient = null + #serviceClient get serviceClient() { return this.#serviceClient } /** @@ -80,7 +81,6 @@ class UserClient { } export default class PeerClient { - /** * @param {ClientManager} manager * @param {import("net").Socket} socket @@ -136,6 +136,7 @@ export default class PeerClient { return } this.stopped = true + this.#closeUsers() try { await this.#write({ id: --this.notificationID, command: "quit", error }) await promisify(cb => this.socket.end(cb)) @@ -145,20 +146,21 @@ export default class PeerClient { } } - handleEnd = async () => { - // TODO Persist clients across bridge disconnections. - // But then have to queue received events until bridge acks them! + handleEnd = () => { + this.stopped = true + this.#closeUsers() + if (this.peerID && this.manager.clients.get(this.peerID) === this) { + this.manager.clients.delete(this.peerID) + } + this.log(`Connection closed (peer: ${this.peerID})`) + } + + #closeUsers() { this.log("Closing all API clients for", this.peerID) for (const userClient of this.userClients.values()) { userClient.close() } this.userClients.clear() - - this.stopped = true - if (this.peerID && this.manager.clients.get(this.peerID) === this) { - this.manager.clients.delete(this.peerID) - } - this.log(`Connection closed (peer: ${this.peerID})`) } /** @@ -178,6 +180,7 @@ export default class PeerClient { * @param {Object} req.form */ registerDevice = async (req) => { + // TODO Look for a deregister API call const authClient = await this.#createAuthClient(req.uuid) return await authClient.registerDevice(req.form, req.passcode, true) } @@ -192,6 +195,7 @@ export default class PeerClient { * request failed, its status is stored here. */ handleLogin = async (req) => { + // TODO Look for a logout API call const authClient = await this.#createAuthClient(req.uuid) const loginRes = await authClient.login(req.form, true) if (loginRes.status === KnownAuthStatusCode.DEVICE_NOT_REGISTERED) { @@ -226,6 +230,15 @@ export default class PeerClient { return userClient } + /** + * Unchecked lookup of a UserClient for a given mxid. + * @param {string} mxid + * @returns {UserClient | undefined} + */ + #tryGetUser(mxid) { + return this.userClients.get(mxid) + } + /** * Get the service client for the specified user ID, or create * and return a new service client if no user ID is provided. @@ -233,7 +246,7 @@ export default class PeerClient { * @param {OAuthCredential} oauth_credential */ async #getServiceClient(mxid, oauth_credential) { - return this.userClients.get(mxid)?.serviceClient || + return this.#tryGetUser(mxid)?.serviceClient || await ServiceApiClient.create(oauth_credential) } @@ -251,26 +264,15 @@ export default class PeerClient { * @param {string} req.mxid * @param {OAuthCredential} req.oauth_credential */ - handleStart = async (req) => { + handleConnect = async (req) => { // TODO Don't re-login if possible. But must still return a LoginResult! - { - const oldUserClient = this.userClients.get(req.mxid) - if (oldUserClient !== undefined) { - oldUserClient.close() - this.userClients.delete(req.mxid) - } - } + this.handleDisconnect(req) const userClient = await UserClient.create(req.mxid, req.oauth_credential) const res = await userClient.talkClient.login(req.oauth_credential) if (!res.success) return res this.userClients.set(req.mxid, userClient) - return res - } - - startListen = async (req) => { - const userClient = this.#getUser(req.mxid) userClient.talkClient.on("chat", (data, channel) => { this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`) @@ -291,7 +293,22 @@ export default class PeerClient { }) */ - return this.#voidCommandResult + return res + } + + /** + * @param {Object} req + * @param {string} req.mxid + */ + handleDisconnect = (req) => { + const userClient = this.#tryGetUser(req.mxid) + if (!!userClient) { + userClient.close() + this.userClients.delete(req.mxid) + return true + } else { + return false + } } /** @@ -368,16 +385,6 @@ export default class PeerClient { }) } - /** - * @param {Object} req - * @param {string} req.mxid - */ - handleStop = async (req) => { - this.#getUser(req.mxid).close() - this.userClients.delete(req.mxid) - return this.#voidCommandResult - } - #makeCommandResult(result) { return { success: true, @@ -386,11 +393,6 @@ export default class PeerClient { } } - #voidCommandResult = { - success: true, - status: 0, - } - handleUnknownCommand = () => { throw new Error("Unknown command") } @@ -431,9 +433,6 @@ export default class PeerClient { this.log("Ignoring old request", req.id) return } - if (req.command != "is_connected") { - this.log("Received request", req.id, "with command", req.command) - } this.maxCommandID = req.id let handler if (!this.peerID) { @@ -449,21 +448,18 @@ export default class PeerClient { handler = this.handleRegister } else { handler = { - // TODO Subclass / object for KakaoTalk-specific handlers? - start: this.handleStart, - stop: this.handleStop, - disconnect: () => this.stop(), - login: this.handleLogin, - renew: this.handleRenew, + // TODO Wrapper for per-user commands generate_uuid: util.randomAndroidSubDeviceUUID, register_device: this.registerDevice, - start_listen: this.startListen, + login: this.handleLogin, + renew: this.handleRenew, + connect: this.handleConnect, + disconnect: this.handleDisconnect, get_own_profile: this.getOwnProfile, + get_profile: this.getProfile, get_portal_channel_info: this.getPortalChannelInfo, get_chats: this.getChats, - get_profile: this.getProfile, send_message: this.sendMessage, - //is_connected: async () => ({ is_connected: !await this.puppet.isDisconnected() }), }[req.command] || this.handleUnknownCommand } const resp = { id: req.id } @@ -482,6 +478,7 @@ export default class PeerClient { resp.command = "error" resp.error = err.toString() this.log(`Error handling request ${resp.id} ${err.stack}`) + // TODO Check if session is broken. If it is, close the PeerClient } } await this.#write(resp) diff --git a/node/src/clientmanager.js b/node/src/clientmanager.js index c993d26..988dc63 100644 --- a/node/src/clientmanager.js +++ b/node/src/clientmanager.js @@ -20,6 +20,7 @@ import path from "path" import PeerClient from "./client.js" import { promisify } from "./util.js" + export default class ClientManager { constructor(listenConfig) { this.listenConfig = listenConfig diff --git a/node/src/main.js b/node/src/main.js index 6fcc49a..d258d9f 100644 --- a/node/src/main.js +++ b/node/src/main.js @@ -15,12 +15,13 @@ // along with this program. If not, see . import process from "process" import fs from "fs" -import sd from "systemd-daemon" import arg from "arg" +import sd from "systemd-daemon" import ClientManager from "./clientmanager.js" + const args = arg({ "--config": String, "-c": "--config", @@ -31,10 +32,10 @@ const configPath = args["--config"] || "config.json" console.log("[Main] Reading config from", configPath) const config = JSON.parse(fs.readFileSync(configPath).toString()) -const api = new ClientManager(config.listen) +const manager = new ClientManager(config.listen) function stop() { - api.stop().then(() => { + manager.stop().then(() => { console.log("[Main] Everything stopped") process.exit(0) }, err => { @@ -43,7 +44,7 @@ function stop() { }) } -api.start().then(() => { +manager.start().then(() => { process.once("SIGINT", stop) process.once("SIGTERM", stop) sd.notify("READY=1") diff --git a/requirements.txt b/requirements.txt index 8ea98d6..ca5978a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ aiohttp>=3,<4 asyncpg>=0.20,<0.26 -bson>=0.5,<0.6 commonmark>=0.8,<0.10 mautrix==0.15.0rc4 pycryptodome>=3,<4