From 8ac16e00fcf6cc346fa24de431eb44e7d7968c51 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Fri, 8 Apr 2022 05:04:46 -0400 Subject: [PATCH] Handle disconnections from the node module Make the bridge module auto-reconnect to the node module in case the latter ever crashes (or is started after the bridge module). Also work towards more holistic auto-reconnect logic in general. --- matrix_appservice_kakaotalk/__main__.py | 8 +- matrix_appservice_kakaotalk/config.py | 30 +---- .../example-config.yaml | 7 +- .../kt/client/client.py | 39 +++++- matrix_appservice_kakaotalk/rpc/rpc.py | 123 ++++++++++++++++-- matrix_appservice_kakaotalk/user.py | 32 ++++- 6 files changed, 182 insertions(+), 57 deletions(-) diff --git a/matrix_appservice_kakaotalk/__main__.py b/matrix_appservice_kakaotalk/__main__.py index 8a03485..301d04e 100644 --- a/matrix_appservice_kakaotalk/__main__.py +++ b/matrix_appservice_kakaotalk/__main__.py @@ -67,19 +67,17 @@ class KakaoTalkBridge(Bridge): self.public_website = None def prepare_stop(self) -> None: + self.log.debug("Stopping RPC connection") + KakaoTalkClient.stop_cls() self.log.debug("Stopping puppet syncers") for puppet in Puppet.by_custom_mxid.values(): puppet.stop() self.log.debug("Stopping kakaotalk listeners") User.shutdown = True self.add_shutdown_actions(user.save() for user in User.by_mxid.values()) - self.add_shutdown_actions(KakaoTalkClient.stop_cls()) async def start(self) -> None: - # Block all other startup actions until RPC is ready - # TODO Remove when/if node backend is replaced with native - await KakaoTalkClient.init_cls(self.config) - + KakaoTalkClient.init_cls(self.config) self.add_startup_actions(User.init_cls(self)) self.add_startup_actions(Puppet.init_cls(self)) Portal.init_cls(self) diff --git a/matrix_appservice_kakaotalk/config.py b/matrix_appservice_kakaotalk/config.py index d844471..fe2c6c2 100644 --- a/matrix_appservice_kakaotalk/config.py +++ b/matrix_appservice_kakaotalk/config.py @@ -23,7 +23,6 @@ from mautrix.types import UserID from mautrix.util.config import ConfigUpdateHelper, ForbiddenDefault, ForbiddenKey -# TODO Remove unneeded configs!! class Config(BaseBridgeConfig): def __getitem__(self, key: str) -> Any: try: @@ -94,31 +93,15 @@ class Config(BaseBridgeConfig): copy("bridge.backfill.initial_limit") copy("bridge.backfill.missed_limit") copy("bridge.backfill.disable_notifications") - if "bridge.periodic_reconnect_interval" in self: - base["bridge.periodic_reconnect.interval"] = self["bridge.periodic_reconnect_interval"] - base["bridge.periodic_reconnect.mode"] = self["bridge.periodic_reconnect_mode"] - else: - copy("bridge.periodic_reconnect.interval") - copy("bridge.periodic_reconnect.mode") - copy("bridge.periodic_reconnect.always") - copy("bridge.periodic_reconnect.min_connected_time") + """ TODO + copy("bridge.periodic_reconnect.interval") + copy("bridge.periodic_reconnect.always") + copy("bridge.periodic_reconnect.min_connected_time") + """ copy("bridge.resync_max_disconnected_time") copy("bridge.sync_on_startup") copy("bridge.temporary_disconnect_notices") copy("bridge.disable_bridge_notices") - if "bridge.refresh_on_reconnection_fail" in self: - base["bridge.on_reconnection_fail.action"] = ( - "refresh" if self["bridge.refresh_on_reconnection_fail"] else None - ) - base["bridge.on_reconnection_fail.wait_for"] = 0 - elif "bridge.on_reconnection_fail.refresh" in self: - base["bridge.on_reconnection_fail.action"] = ( - "refresh" if self["bridge.on_reconnection_fail.refresh"] else None - ) - copy("bridge.on_reconnection_fail.wait_for") - else: - copy("bridge.on_reconnection_fail.action") - copy("bridge.on_reconnection_fail.wait_for") copy("bridge.resend_bridge_info") copy("bridge.mute_bridging") copy("bridge.tag_only_on_create") @@ -126,13 +109,14 @@ class Config(BaseBridgeConfig): copy_dict("bridge.permissions") + """ TODO for key in ( "bridge.periodic_reconnect.interval", - "bridge.on_reconnection_fail.wait_for", ): value = base.get(key, None) if isinstance(value, list) and len(value) != 2: raise ValueError(f"{key} must only be a list of two items") + """ copy("rpc.connection.type") if base["rpc.connection.type"] == "unix": diff --git a/matrix_appservice_kakaotalk/example-config.yaml b/matrix_appservice_kakaotalk/example-config.yaml index 632344d..b05ecb1 100644 --- a/matrix_appservice_kakaotalk/example-config.yaml +++ b/matrix_appservice_kakaotalk/example-config.yaml @@ -200,15 +200,13 @@ bridge: # If using double puppeting, should notifications be disabled # while the initial backfill is in progress? disable_notifications: false - # TODO Confirm this isn't needed + # TODO Implement this #periodic_reconnect: # # Interval in seconds in which to automatically reconnect all users. - # # This can be used to automatically mitigate the bug where KakaoTalk stops sending messages. + # # This may prevent KakaoTalk from "switching servers". # # Set to -1 to disable periodic reconnections entirely. # # Set to a list of two items to randomize the interval (min, max). # interval: -1 - # # What to do in periodic reconnects. Either "refresh" or "reconnect" - # mode: refresh # # Should even disconnected users be reconnected? # always: false # # Only reconnect if the user has been connected for longer than this value @@ -216,6 +214,7 @@ bridge: # The number of seconds that a disconnection can last without triggering an automatic re-sync # and missed message backfilling when reconnecting. # Set to 0 to always re-sync, or -1 to never re-sync automatically. + # TODO Actually use this setting resync_max_disconnected_time: 5 # Should the bridge do a resync on startup? sync_on_startup: true diff --git a/matrix_appservice_kakaotalk/kt/client/client.py b/matrix_appservice_kakaotalk/kt/client/client.py index 99fc427..dfea810 100644 --- a/matrix_appservice_kakaotalk/kt/client/client.py +++ b/matrix_appservice_kakaotalk/kt/client/client.py @@ -22,7 +22,8 @@ with any other potential backend. from __future__ import annotations -from typing import TYPE_CHECKING, cast, Type, Optional, Union +from typing import TYPE_CHECKING, cast, ClassVar, Type, Optional, Union +import asyncio from contextlib import asynccontextmanager import logging import urllib.request @@ -64,7 +65,7 @@ except ImportError: if TYPE_CHECKING: from mautrix.types import JSON - from ...user import User + from ... import user as u @asynccontextmanager @@ -79,15 +80,22 @@ class Client: _rpc_client: RPCClient @classmethod - async def init_cls(cls, config: Config) -> None: + def init_cls(cls, config: Config) -> None: """Initialize RPC to the Node backend.""" cls._rpc_client = RPCClient(config) - await cls._rpc_client.connect() + # NOTE No need to store this, as cancelling the RPCClient will cancel this too + asyncio.create_task(cls._keep_connected()) @classmethod - async def stop_cls(cls) -> None: + async def _keep_connected(cls) -> None: + while True: + await cls._rpc_client.connect() + await cls._rpc_client.wait_for_disconnection() + + @classmethod + def stop_cls(cls) -> None: """Stop and disconnect from the Node backend.""" - await cls._rpc_client.disconnect() + cls._rpc_client.cancel() # region tokenless commands @@ -124,12 +132,15 @@ class Client: # endregion + user: u.User + _rpc_disconnection_task: asyncio.Task | None http: ClientSession log: TraceLogger - def __init__(self, user: User, log: Optional[TraceLogger] = None): + def __init__(self, user: u.User, log: Optional[TraceLogger] = None): """Create a per-user client object for user-specific client functionality.""" self.user = user + self._rpc_disconnection_task = None # TODO Let the Node backend use a proxy too! connector = None @@ -188,13 +199,27 @@ class Client: Receive the user's profile info in response. """ profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "start") + if not self._rpc_disconnection_task: + self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler()) + else: + self.log.warning("Called \"start\" on an already-started client") return profile_req_struct.profile async def stop(self) -> None: """Immediately stop bridging this user.""" self._stop_listen() + if self._rpc_disconnection_task: + self._rpc_disconnection_task.cancel() + else: + self.log.warning("Called \"stop\" on an already-stopped client") await self._rpc_client.request("stop", mxid=self.user.mxid) + async def _rpc_disconnection_handler(self) -> None: + await self._rpc_client.wait_for_disconnection() + self._rpc_disconnection_task = None + self._stop_listen() + asyncio.create_task(self.user.on_client_disconnect()) + async def renew_and_save(self) -> None: """Renew and save the user's session tokens.""" oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential) diff --git a/matrix_appservice_kakaotalk/rpc/rpc.py b/matrix_appservice_kakaotalk/rpc/rpc.py index 80d6bb5..6c5bbf2 100644 --- a/matrix_appservice_kakaotalk/rpc/rpc.py +++ b/matrix_appservice_kakaotalk/rpc/rpc.py @@ -29,6 +29,46 @@ from .types import RPCError EventHandler = Callable[[dict[str, Any]], Awaitable[None]] +class CancelableEvent: + _event: asyncio.Event + _task: asyncio.Task | None + _cancelled: bool + _loop: asyncio.AbstractEventLoop + + def __init__(self, loop: asyncio.AbstractEventLoop | None): + self._event = asyncio.Event() + self._task = None + self._cancelled = False + self._loop = loop or asyncio.get_running_loop() + + def is_set(self) -> bool: + return self._event.is_set() + + def set(self) -> None: + self._event.set() + self._task = None + + def clear(self) -> None: + self._event.clear() + + async def wait(self) -> None: + if self._cancelled: + raise asyncio.CancelledError() + if self._event.is_set(): + return + if not self._task: + self._task = asyncio.create_task(self._event.wait()) + await self._task + + def cancel(self) -> None: + self._cancelled = True + if self._task is not None: + self._task.cancel() + + def cancelled(self) -> bool: + return self._cancelled + + class RPCClient: config: Config loop: asyncio.AbstractEventLoop @@ -41,6 +81,11 @@ class RPCClient: _response_waiters: dict[int, asyncio.Future[JSON]] _event_handlers: dict[str, list[EventHandler]] _command_queue: asyncio.Queue + _read_task: asyncio.Task | None + _connection_task: asyncio.Task | None + _is_connected: CancelableEvent + _is_disconnected: CancelableEvent + _connection_lock: asyncio.Lock def __init__(self, config: Config) -> None: self.config = config @@ -52,16 +97,34 @@ class RPCClient: self._writer = None self._reader = None self._command_queue = asyncio.Queue() + self.loop.create_task(self._command_loop()) + self._read_task = None + self._connection_task = None + self._is_connected = CancelableEvent(self.loop) + self._is_disconnected = CancelableEvent(self.loop) + self._is_disconnected.set() + self._connection_lock = asyncio.Lock() async def connect(self) -> None: - if self._writer is not None: - return + async with self._connection_lock: + if self._is_connected.cancelled(): + raise asyncio.CancelledError() + if self._is_connected.is_set(): + return + self._connection_task = self.loop.create_task(self._connect()) + try: + await self._connection_task + finally: + self._connection_task = None + async def _connect(self) -> None: if self.config["rpc.connection.type"] == "unix": while True: try: r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"]) break + except asyncio.CancelledError: + raise except: self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...') await asyncio.sleep(10) @@ -71,6 +134,8 @@ class RPCClient: r, w = await asyncio.open_connection(self.config["rpc.connection.host"], self.config["rpc.connection.port"]) break + except asyncio.CancelledError: + raise except: self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.path"]}, wait for it to become available...') await asyncio.sleep(10) @@ -78,16 +143,46 @@ class RPCClient: raise RuntimeError("invalid rpc connection type") self._reader = r self._writer = w - self.loop.create_task(self._try_read_loop()) - self.loop.create_task(self._command_loop()) + self._read_task = self.loop.create_task(self._try_read_loop()) + self._is_connected.set() + self._is_disconnected.clear() await self.request("register", peer_id=self.config["appservice.address"]) async def disconnect(self) -> None: + async with self._connection_lock: + if self._is_disconnected.cancelled(): + raise asyncio.CancelledError() + if self._is_disconnected.is_set(): + return + await self._disconnect() + + async def _disconnect(self) -> None: if self._writer is not None: self._writer.write_eof() await self._writer.drain() - self._writer = None - self._reader = None + if self._read_task is not None: + self._read_task.cancel() + self._read_task = None + self._on_disconnect() + + def _on_disconnect(self) -> None: + self._reader = None + self._writer = None + self._is_connected.clear() + self._is_disconnected.set() + + def wait_for_connection(self) -> Awaitable[None]: + return self._is_connected.wait() + + def wait_for_disconnection(self) -> Awaitable[None]: + return self._is_disconnected.wait() + + def cancel(self) -> None: + self._is_connected.cancel() + self._is_disconnected.cancel() + if self._connection_task is not None: + self._connection_task.cancel() + asyncio.run(self._disconnect()) @property def _next_req_id(self) -> int: @@ -119,7 +214,7 @@ class RPCClient: for handler in handlers: try: await handler(req) - except Exception: + except: self.log.exception("Exception in event handler") async def _handle_incoming_line(self, line: str) -> None: @@ -162,7 +257,9 @@ class RPCClient: async def _try_read_loop(self) -> None: try: await self._read_loop() - except Exception: + except asyncio.CancelledError: + pass + except: self.log.exception("Fatal error in read loop") async def _read_loop(self) -> None: @@ -178,6 +275,8 @@ class RPCClient: except asyncio.LimitOverrunError as e: self.log.warning(f"Buffer overrun: {e}") line += await self._reader.read(self._reader._limit) + except asyncio.CancelledError: + raise if not line: continue try: @@ -187,11 +286,12 @@ class RPCClient: continue try: await self._handle_incoming_line(line_str) - except Exception: + except asyncio.CancelledError: + raise + except: self.log.exception("Failed to handle incoming request %s", line_str) self.log.debug("Reader disconnected") - self._reader = None - self._writer = None + self._on_disconnect() async def _raw_request(self, command: str, is_secret: bool = False, **data: JSON) -> asyncio.Future[JSON]: req_id = self._next_req_id @@ -205,5 +305,6 @@ class RPCClient: return future async def request(self, command: str, **data: JSON) -> JSON: + await self.wait_for_connection() future = await self._raw_request(command, **data) return await future diff --git a/matrix_appservice_kakaotalk/user.py b/matrix_appservice_kakaotalk/user.py index 34449ff..22632c2 100644 --- a/matrix_appservice_kakaotalk/user.py +++ b/matrix_appservice_kakaotalk/user.py @@ -85,6 +85,7 @@ class User(DBUser, BaseUser): _connection_time: float _db_instance: DBUser | None _sync_lock: SimpleLock + _is_rpc_reconnecting: bool _logged_in_info: ProfileStruct | None _logged_in_info_time: float @@ -121,6 +122,7 @@ class User(DBUser, BaseUser): self._sync_lock = SimpleLock( "Waiting for thread sync to finish before handling %s", log=self.log ) + self._is_rpc_reconnecting = False self._logged_in_info = None self._logged_in_info_time = 0 @@ -332,6 +334,8 @@ class User(DBUser, BaseUser): state_event=BridgeStateEvent.UNKNOWN_ERROR, error_code="kt-reconnection-error", ) + finally: + self._is_rpc_reconnecting = False async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None: if self.client: @@ -545,9 +549,8 @@ class User(DBUser, BaseUser): state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR) if self.is_connected: state.state_event = BridgeStateEvent.CONNECTED - # TODO - #elif self._is_logged_in and self._is_reconnecting: - # state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT + elif self._is_rpc_reconnecting or self.client: + state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT return [state] async def get_puppet(self) -> pu.Puppet | None: @@ -582,16 +585,18 @@ class User(DBUser, BaseUser): # region KakaoTalk event handling async def on_connect(self) -> None: + self.is_connected = True + self._track_metric(METRIC_CONNECTED, True) + """ TODO Don't auto-resync channels if disconnection was too short now = time.monotonic() disconnected_at = self._connection_time max_delay = self.config["bridge.resync_max_disconnected_time"] first_connect = self.is_connected is None - self.is_connected = True - self._track_metric(METRIC_CONNECTED, True) if not first_connect and disconnected_at + max_delay < now: duration = int(now - disconnected_at) - self.log.debug(f"Disconnection lasted {duration} seconds") - elif self.temp_disconnect_notices: + self.log.debug(f"Disconnection lasted {duration} seconds, not re-syncing channels...") + """ + if self.temp_disconnect_notices: await self.send_bridge_notice("Connected to KakaoTalk chats") await self.push_bridge_state(BridgeStateEvent.CONNECTED) @@ -618,6 +623,19 @@ class User(DBUser, BaseUser): await self.logout() await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}") + async def on_client_disconnect(self) -> None: + self.is_connected = False + self._track_metric(METRIC_CONNECTED, False) + self.client = None + if self._is_logged_in: + if self.temp_disconnect_notices: + await self.send_bridge_notice( + "Disconnected from KakaoTalk: backend helper module exited. " + "Will reconnect once module resumes." + ) + self._is_rpc_reconnecting = True + asyncio.create_task(self.reload_session()) + async def on_logged_in(self, oauth_credential: OAuthCredential) -> None: self.log.debug(f"Successfully logged in as {oauth_credential.userId}") self.oauth_credential = oauth_credential