Handle disconnections from the node module

Make the bridge module auto-reconnect to the node module in case the
latter ever crashes (or is started after the bridge module).

Also work towards more holistic auto-reconnect logic in general.
This commit is contained in:
Andrew Ferrazzutti 2022-04-08 05:04:46 -04:00
parent d452735691
commit 8ac16e00fc
6 changed files with 182 additions and 57 deletions

View File

@ -67,19 +67,17 @@ class KakaoTalkBridge(Bridge):
self.public_website = None
def prepare_stop(self) -> None:
self.log.debug("Stopping RPC connection")
KakaoTalkClient.stop_cls()
self.log.debug("Stopping puppet syncers")
for puppet in Puppet.by_custom_mxid.values():
puppet.stop()
self.log.debug("Stopping kakaotalk listeners")
User.shutdown = True
self.add_shutdown_actions(user.save() for user in User.by_mxid.values())
self.add_shutdown_actions(KakaoTalkClient.stop_cls())
async def start(self) -> None:
# Block all other startup actions until RPC is ready
# TODO Remove when/if node backend is replaced with native
await KakaoTalkClient.init_cls(self.config)
KakaoTalkClient.init_cls(self.config)
self.add_startup_actions(User.init_cls(self))
self.add_startup_actions(Puppet.init_cls(self))
Portal.init_cls(self)

View File

@ -23,7 +23,6 @@ from mautrix.types import UserID
from mautrix.util.config import ConfigUpdateHelper, ForbiddenDefault, ForbiddenKey
# TODO Remove unneeded configs!!
class Config(BaseBridgeConfig):
def __getitem__(self, key: str) -> Any:
try:
@ -94,31 +93,15 @@ class Config(BaseBridgeConfig):
copy("bridge.backfill.initial_limit")
copy("bridge.backfill.missed_limit")
copy("bridge.backfill.disable_notifications")
if "bridge.periodic_reconnect_interval" in self:
base["bridge.periodic_reconnect.interval"] = self["bridge.periodic_reconnect_interval"]
base["bridge.periodic_reconnect.mode"] = self["bridge.periodic_reconnect_mode"]
else:
""" TODO
copy("bridge.periodic_reconnect.interval")
copy("bridge.periodic_reconnect.mode")
copy("bridge.periodic_reconnect.always")
copy("bridge.periodic_reconnect.min_connected_time")
"""
copy("bridge.resync_max_disconnected_time")
copy("bridge.sync_on_startup")
copy("bridge.temporary_disconnect_notices")
copy("bridge.disable_bridge_notices")
if "bridge.refresh_on_reconnection_fail" in self:
base["bridge.on_reconnection_fail.action"] = (
"refresh" if self["bridge.refresh_on_reconnection_fail"] else None
)
base["bridge.on_reconnection_fail.wait_for"] = 0
elif "bridge.on_reconnection_fail.refresh" in self:
base["bridge.on_reconnection_fail.action"] = (
"refresh" if self["bridge.on_reconnection_fail.refresh"] else None
)
copy("bridge.on_reconnection_fail.wait_for")
else:
copy("bridge.on_reconnection_fail.action")
copy("bridge.on_reconnection_fail.wait_for")
copy("bridge.resend_bridge_info")
copy("bridge.mute_bridging")
copy("bridge.tag_only_on_create")
@ -126,13 +109,14 @@ class Config(BaseBridgeConfig):
copy_dict("bridge.permissions")
""" TODO
for key in (
"bridge.periodic_reconnect.interval",
"bridge.on_reconnection_fail.wait_for",
):
value = base.get(key, None)
if isinstance(value, list) and len(value) != 2:
raise ValueError(f"{key} must only be a list of two items")
"""
copy("rpc.connection.type")
if base["rpc.connection.type"] == "unix":

View File

@ -200,15 +200,13 @@ bridge:
# If using double puppeting, should notifications be disabled
# while the initial backfill is in progress?
disable_notifications: false
# TODO Confirm this isn't needed
# TODO Implement this
#periodic_reconnect:
# # Interval in seconds in which to automatically reconnect all users.
# # This can be used to automatically mitigate the bug where KakaoTalk stops sending messages.
# # This may prevent KakaoTalk from "switching servers".
# # Set to -1 to disable periodic reconnections entirely.
# # Set to a list of two items to randomize the interval (min, max).
# interval: -1
# # What to do in periodic reconnects. Either "refresh" or "reconnect"
# mode: refresh
# # Should even disconnected users be reconnected?
# always: false
# # Only reconnect if the user has been connected for longer than this value
@ -216,6 +214,7 @@ bridge:
# The number of seconds that a disconnection can last without triggering an automatic re-sync
# and missed message backfilling when reconnecting.
# Set to 0 to always re-sync, or -1 to never re-sync automatically.
# TODO Actually use this setting
resync_max_disconnected_time: 5
# Should the bridge do a resync on startup?
sync_on_startup: true

View File

@ -22,7 +22,8 @@ with any other potential backend.
from __future__ import annotations
from typing import TYPE_CHECKING, cast, Type, Optional, Union
from typing import TYPE_CHECKING, cast, ClassVar, Type, Optional, Union
import asyncio
from contextlib import asynccontextmanager
import logging
import urllib.request
@ -64,7 +65,7 @@ except ImportError:
if TYPE_CHECKING:
from mautrix.types import JSON
from ...user import User
from ... import user as u
@asynccontextmanager
@ -79,15 +80,22 @@ class Client:
_rpc_client: RPCClient
@classmethod
async def init_cls(cls, config: Config) -> None:
def init_cls(cls, config: Config) -> None:
"""Initialize RPC to the Node backend."""
cls._rpc_client = RPCClient(config)
await cls._rpc_client.connect()
# NOTE No need to store this, as cancelling the RPCClient will cancel this too
asyncio.create_task(cls._keep_connected())
@classmethod
async def stop_cls(cls) -> None:
async def _keep_connected(cls) -> None:
while True:
await cls._rpc_client.connect()
await cls._rpc_client.wait_for_disconnection()
@classmethod
def stop_cls(cls) -> None:
"""Stop and disconnect from the Node backend."""
await cls._rpc_client.disconnect()
cls._rpc_client.cancel()
# region tokenless commands
@ -124,12 +132,15 @@ class Client:
# endregion
user: u.User
_rpc_disconnection_task: asyncio.Task | None
http: ClientSession
log: TraceLogger
def __init__(self, user: User, log: Optional[TraceLogger] = None):
def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
"""Create a per-user client object for user-specific client functionality."""
self.user = user
self._rpc_disconnection_task = None
# TODO Let the Node backend use a proxy too!
connector = None
@ -188,13 +199,27 @@ class Client:
Receive the user's profile info in response.
"""
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "start")
if not self._rpc_disconnection_task:
self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler())
else:
self.log.warning("Called \"start\" on an already-started client")
return profile_req_struct.profile
async def stop(self) -> None:
"""Immediately stop bridging this user."""
self._stop_listen()
if self._rpc_disconnection_task:
self._rpc_disconnection_task.cancel()
else:
self.log.warning("Called \"stop\" on an already-stopped client")
await self._rpc_client.request("stop", mxid=self.user.mxid)
async def _rpc_disconnection_handler(self) -> None:
await self._rpc_client.wait_for_disconnection()
self._rpc_disconnection_task = None
self._stop_listen()
asyncio.create_task(self.user.on_client_disconnect())
async def renew_and_save(self) -> None:
"""Renew and save the user's session tokens."""
oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)

View File

@ -29,6 +29,46 @@ from .types import RPCError
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
class CancelableEvent:
_event: asyncio.Event
_task: asyncio.Task | None
_cancelled: bool
_loop: asyncio.AbstractEventLoop
def __init__(self, loop: asyncio.AbstractEventLoop | None):
self._event = asyncio.Event()
self._task = None
self._cancelled = False
self._loop = loop or asyncio.get_running_loop()
def is_set(self) -> bool:
return self._event.is_set()
def set(self) -> None:
self._event.set()
self._task = None
def clear(self) -> None:
self._event.clear()
async def wait(self) -> None:
if self._cancelled:
raise asyncio.CancelledError()
if self._event.is_set():
return
if not self._task:
self._task = asyncio.create_task(self._event.wait())
await self._task
def cancel(self) -> None:
self._cancelled = True
if self._task is not None:
self._task.cancel()
def cancelled(self) -> bool:
return self._cancelled
class RPCClient:
config: Config
loop: asyncio.AbstractEventLoop
@ -41,6 +81,11 @@ class RPCClient:
_response_waiters: dict[int, asyncio.Future[JSON]]
_event_handlers: dict[str, list[EventHandler]]
_command_queue: asyncio.Queue
_read_task: asyncio.Task | None
_connection_task: asyncio.Task | None
_is_connected: CancelableEvent
_is_disconnected: CancelableEvent
_connection_lock: asyncio.Lock
def __init__(self, config: Config) -> None:
self.config = config
@ -52,16 +97,34 @@ class RPCClient:
self._writer = None
self._reader = None
self._command_queue = asyncio.Queue()
self.loop.create_task(self._command_loop())
self._read_task = None
self._connection_task = None
self._is_connected = CancelableEvent(self.loop)
self._is_disconnected = CancelableEvent(self.loop)
self._is_disconnected.set()
self._connection_lock = asyncio.Lock()
async def connect(self) -> None:
if self._writer is not None:
async with self._connection_lock:
if self._is_connected.cancelled():
raise asyncio.CancelledError()
if self._is_connected.is_set():
return
self._connection_task = self.loop.create_task(self._connect())
try:
await self._connection_task
finally:
self._connection_task = None
async def _connect(self) -> None:
if self.config["rpc.connection.type"] == "unix":
while True:
try:
r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"])
break
except asyncio.CancelledError:
raise
except:
self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...')
await asyncio.sleep(10)
@ -71,6 +134,8 @@ class RPCClient:
r, w = await asyncio.open_connection(self.config["rpc.connection.host"],
self.config["rpc.connection.port"])
break
except asyncio.CancelledError:
raise
except:
self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.path"]}, wait for it to become available...')
await asyncio.sleep(10)
@ -78,16 +143,46 @@ class RPCClient:
raise RuntimeError("invalid rpc connection type")
self._reader = r
self._writer = w
self.loop.create_task(self._try_read_loop())
self.loop.create_task(self._command_loop())
self._read_task = self.loop.create_task(self._try_read_loop())
self._is_connected.set()
self._is_disconnected.clear()
await self.request("register", peer_id=self.config["appservice.address"])
async def disconnect(self) -> None:
async with self._connection_lock:
if self._is_disconnected.cancelled():
raise asyncio.CancelledError()
if self._is_disconnected.is_set():
return
await self._disconnect()
async def _disconnect(self) -> None:
if self._writer is not None:
self._writer.write_eof()
await self._writer.drain()
self._writer = None
if self._read_task is not None:
self._read_task.cancel()
self._read_task = None
self._on_disconnect()
def _on_disconnect(self) -> None:
self._reader = None
self._writer = None
self._is_connected.clear()
self._is_disconnected.set()
def wait_for_connection(self) -> Awaitable[None]:
return self._is_connected.wait()
def wait_for_disconnection(self) -> Awaitable[None]:
return self._is_disconnected.wait()
def cancel(self) -> None:
self._is_connected.cancel()
self._is_disconnected.cancel()
if self._connection_task is not None:
self._connection_task.cancel()
asyncio.run(self._disconnect())
@property
def _next_req_id(self) -> int:
@ -119,7 +214,7 @@ class RPCClient:
for handler in handlers:
try:
await handler(req)
except Exception:
except:
self.log.exception("Exception in event handler")
async def _handle_incoming_line(self, line: str) -> None:
@ -162,7 +257,9 @@ class RPCClient:
async def _try_read_loop(self) -> None:
try:
await self._read_loop()
except Exception:
except asyncio.CancelledError:
pass
except:
self.log.exception("Fatal error in read loop")
async def _read_loop(self) -> None:
@ -178,6 +275,8 @@ class RPCClient:
except asyncio.LimitOverrunError as e:
self.log.warning(f"Buffer overrun: {e}")
line += await self._reader.read(self._reader._limit)
except asyncio.CancelledError:
raise
if not line:
continue
try:
@ -187,11 +286,12 @@ class RPCClient:
continue
try:
await self._handle_incoming_line(line_str)
except Exception:
except asyncio.CancelledError:
raise
except:
self.log.exception("Failed to handle incoming request %s", line_str)
self.log.debug("Reader disconnected")
self._reader = None
self._writer = None
self._on_disconnect()
async def _raw_request(self, command: str, is_secret: bool = False, **data: JSON) -> asyncio.Future[JSON]:
req_id = self._next_req_id
@ -205,5 +305,6 @@ class RPCClient:
return future
async def request(self, command: str, **data: JSON) -> JSON:
await self.wait_for_connection()
future = await self._raw_request(command, **data)
return await future

View File

@ -85,6 +85,7 @@ class User(DBUser, BaseUser):
_connection_time: float
_db_instance: DBUser | None
_sync_lock: SimpleLock
_is_rpc_reconnecting: bool
_logged_in_info: ProfileStruct | None
_logged_in_info_time: float
@ -121,6 +122,7 @@ class User(DBUser, BaseUser):
self._sync_lock = SimpleLock(
"Waiting for thread sync to finish before handling %s", log=self.log
)
self._is_rpc_reconnecting = False
self._logged_in_info = None
self._logged_in_info_time = 0
@ -332,6 +334,8 @@ class User(DBUser, BaseUser):
state_event=BridgeStateEvent.UNKNOWN_ERROR,
error_code="kt-reconnection-error",
)
finally:
self._is_rpc_reconnecting = False
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
if self.client:
@ -545,9 +549,8 @@ class User(DBUser, BaseUser):
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
if self.is_connected:
state.state_event = BridgeStateEvent.CONNECTED
# TODO
#elif self._is_logged_in and self._is_reconnecting:
# state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
elif self._is_rpc_reconnecting or self.client:
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
return [state]
async def get_puppet(self) -> pu.Puppet | None:
@ -582,16 +585,18 @@ class User(DBUser, BaseUser):
# region KakaoTalk event handling
async def on_connect(self) -> None:
self.is_connected = True
self._track_metric(METRIC_CONNECTED, True)
""" TODO Don't auto-resync channels if disconnection was too short
now = time.monotonic()
disconnected_at = self._connection_time
max_delay = self.config["bridge.resync_max_disconnected_time"]
first_connect = self.is_connected is None
self.is_connected = True
self._track_metric(METRIC_CONNECTED, True)
if not first_connect and disconnected_at + max_delay < now:
duration = int(now - disconnected_at)
self.log.debug(f"Disconnection lasted {duration} seconds")
elif self.temp_disconnect_notices:
self.log.debug(f"Disconnection lasted {duration} seconds, not re-syncing channels...")
"""
if self.temp_disconnect_notices:
await self.send_bridge_notice("Connected to KakaoTalk chats")
await self.push_bridge_state(BridgeStateEvent.CONNECTED)
@ -618,6 +623,19 @@ class User(DBUser, BaseUser):
await self.logout()
await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}")
async def on_client_disconnect(self) -> None:
self.is_connected = False
self._track_metric(METRIC_CONNECTED, False)
self.client = None
if self._is_logged_in:
if self.temp_disconnect_notices:
await self.send_bridge_notice(
"Disconnected from KakaoTalk: backend helper module exited. "
"Will reconnect once module resumes."
)
self._is_rpc_reconnecting = True
asyncio.create_task(self.reload_session())
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
self.oauth_credential = oauth_credential