Handle disconnections from the node module
Make the bridge module auto-reconnect to the node module in case the latter ever crashes (or is started after the bridge module). Also work towards more holistic auto-reconnect logic in general.
This commit is contained in:
parent
d452735691
commit
8ac16e00fc
|
@ -67,19 +67,17 @@ class KakaoTalkBridge(Bridge):
|
|||
self.public_website = None
|
||||
|
||||
def prepare_stop(self) -> None:
|
||||
self.log.debug("Stopping RPC connection")
|
||||
KakaoTalkClient.stop_cls()
|
||||
self.log.debug("Stopping puppet syncers")
|
||||
for puppet in Puppet.by_custom_mxid.values():
|
||||
puppet.stop()
|
||||
self.log.debug("Stopping kakaotalk listeners")
|
||||
User.shutdown = True
|
||||
self.add_shutdown_actions(user.save() for user in User.by_mxid.values())
|
||||
self.add_shutdown_actions(KakaoTalkClient.stop_cls())
|
||||
|
||||
async def start(self) -> None:
|
||||
# Block all other startup actions until RPC is ready
|
||||
# TODO Remove when/if node backend is replaced with native
|
||||
await KakaoTalkClient.init_cls(self.config)
|
||||
|
||||
KakaoTalkClient.init_cls(self.config)
|
||||
self.add_startup_actions(User.init_cls(self))
|
||||
self.add_startup_actions(Puppet.init_cls(self))
|
||||
Portal.init_cls(self)
|
||||
|
|
|
@ -23,7 +23,6 @@ from mautrix.types import UserID
|
|||
from mautrix.util.config import ConfigUpdateHelper, ForbiddenDefault, ForbiddenKey
|
||||
|
||||
|
||||
# TODO Remove unneeded configs!!
|
||||
class Config(BaseBridgeConfig):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
try:
|
||||
|
@ -94,31 +93,15 @@ class Config(BaseBridgeConfig):
|
|||
copy("bridge.backfill.initial_limit")
|
||||
copy("bridge.backfill.missed_limit")
|
||||
copy("bridge.backfill.disable_notifications")
|
||||
if "bridge.periodic_reconnect_interval" in self:
|
||||
base["bridge.periodic_reconnect.interval"] = self["bridge.periodic_reconnect_interval"]
|
||||
base["bridge.periodic_reconnect.mode"] = self["bridge.periodic_reconnect_mode"]
|
||||
else:
|
||||
copy("bridge.periodic_reconnect.interval")
|
||||
copy("bridge.periodic_reconnect.mode")
|
||||
copy("bridge.periodic_reconnect.always")
|
||||
copy("bridge.periodic_reconnect.min_connected_time")
|
||||
""" TODO
|
||||
copy("bridge.periodic_reconnect.interval")
|
||||
copy("bridge.periodic_reconnect.always")
|
||||
copy("bridge.periodic_reconnect.min_connected_time")
|
||||
"""
|
||||
copy("bridge.resync_max_disconnected_time")
|
||||
copy("bridge.sync_on_startup")
|
||||
copy("bridge.temporary_disconnect_notices")
|
||||
copy("bridge.disable_bridge_notices")
|
||||
if "bridge.refresh_on_reconnection_fail" in self:
|
||||
base["bridge.on_reconnection_fail.action"] = (
|
||||
"refresh" if self["bridge.refresh_on_reconnection_fail"] else None
|
||||
)
|
||||
base["bridge.on_reconnection_fail.wait_for"] = 0
|
||||
elif "bridge.on_reconnection_fail.refresh" in self:
|
||||
base["bridge.on_reconnection_fail.action"] = (
|
||||
"refresh" if self["bridge.on_reconnection_fail.refresh"] else None
|
||||
)
|
||||
copy("bridge.on_reconnection_fail.wait_for")
|
||||
else:
|
||||
copy("bridge.on_reconnection_fail.action")
|
||||
copy("bridge.on_reconnection_fail.wait_for")
|
||||
copy("bridge.resend_bridge_info")
|
||||
copy("bridge.mute_bridging")
|
||||
copy("bridge.tag_only_on_create")
|
||||
|
@ -126,13 +109,14 @@ class Config(BaseBridgeConfig):
|
|||
|
||||
copy_dict("bridge.permissions")
|
||||
|
||||
""" TODO
|
||||
for key in (
|
||||
"bridge.periodic_reconnect.interval",
|
||||
"bridge.on_reconnection_fail.wait_for",
|
||||
):
|
||||
value = base.get(key, None)
|
||||
if isinstance(value, list) and len(value) != 2:
|
||||
raise ValueError(f"{key} must only be a list of two items")
|
||||
"""
|
||||
|
||||
copy("rpc.connection.type")
|
||||
if base["rpc.connection.type"] == "unix":
|
||||
|
|
|
@ -200,15 +200,13 @@ bridge:
|
|||
# If using double puppeting, should notifications be disabled
|
||||
# while the initial backfill is in progress?
|
||||
disable_notifications: false
|
||||
# TODO Confirm this isn't needed
|
||||
# TODO Implement this
|
||||
#periodic_reconnect:
|
||||
# # Interval in seconds in which to automatically reconnect all users.
|
||||
# # This can be used to automatically mitigate the bug where KakaoTalk stops sending messages.
|
||||
# # This may prevent KakaoTalk from "switching servers".
|
||||
# # Set to -1 to disable periodic reconnections entirely.
|
||||
# # Set to a list of two items to randomize the interval (min, max).
|
||||
# interval: -1
|
||||
# # What to do in periodic reconnects. Either "refresh" or "reconnect"
|
||||
# mode: refresh
|
||||
# # Should even disconnected users be reconnected?
|
||||
# always: false
|
||||
# # Only reconnect if the user has been connected for longer than this value
|
||||
|
@ -216,6 +214,7 @@ bridge:
|
|||
# The number of seconds that a disconnection can last without triggering an automatic re-sync
|
||||
# and missed message backfilling when reconnecting.
|
||||
# Set to 0 to always re-sync, or -1 to never re-sync automatically.
|
||||
# TODO Actually use this setting
|
||||
resync_max_disconnected_time: 5
|
||||
# Should the bridge do a resync on startup?
|
||||
sync_on_startup: true
|
||||
|
|
|
@ -22,7 +22,8 @@ with any other potential backend.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, cast, Type, Optional, Union
|
||||
from typing import TYPE_CHECKING, cast, ClassVar, Type, Optional, Union
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import urllib.request
|
||||
|
@ -64,7 +65,7 @@ except ImportError:
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from mautrix.types import JSON
|
||||
from ...user import User
|
||||
from ... import user as u
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
@ -79,15 +80,22 @@ class Client:
|
|||
_rpc_client: RPCClient
|
||||
|
||||
@classmethod
|
||||
async def init_cls(cls, config: Config) -> None:
|
||||
def init_cls(cls, config: Config) -> None:
|
||||
"""Initialize RPC to the Node backend."""
|
||||
cls._rpc_client = RPCClient(config)
|
||||
await cls._rpc_client.connect()
|
||||
# NOTE No need to store this, as cancelling the RPCClient will cancel this too
|
||||
asyncio.create_task(cls._keep_connected())
|
||||
|
||||
@classmethod
|
||||
async def stop_cls(cls) -> None:
|
||||
async def _keep_connected(cls) -> None:
|
||||
while True:
|
||||
await cls._rpc_client.connect()
|
||||
await cls._rpc_client.wait_for_disconnection()
|
||||
|
||||
@classmethod
|
||||
def stop_cls(cls) -> None:
|
||||
"""Stop and disconnect from the Node backend."""
|
||||
await cls._rpc_client.disconnect()
|
||||
cls._rpc_client.cancel()
|
||||
|
||||
|
||||
# region tokenless commands
|
||||
|
@ -124,12 +132,15 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
user: u.User
|
||||
_rpc_disconnection_task: asyncio.Task | None
|
||||
http: ClientSession
|
||||
log: TraceLogger
|
||||
|
||||
def __init__(self, user: User, log: Optional[TraceLogger] = None):
|
||||
def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
|
||||
"""Create a per-user client object for user-specific client functionality."""
|
||||
self.user = user
|
||||
self._rpc_disconnection_task = None
|
||||
|
||||
# TODO Let the Node backend use a proxy too!
|
||||
connector = None
|
||||
|
@ -188,13 +199,27 @@ class Client:
|
|||
Receive the user's profile info in response.
|
||||
"""
|
||||
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "start")
|
||||
if not self._rpc_disconnection_task:
|
||||
self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler())
|
||||
else:
|
||||
self.log.warning("Called \"start\" on an already-started client")
|
||||
return profile_req_struct.profile
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Immediately stop bridging this user."""
|
||||
self._stop_listen()
|
||||
if self._rpc_disconnection_task:
|
||||
self._rpc_disconnection_task.cancel()
|
||||
else:
|
||||
self.log.warning("Called \"stop\" on an already-stopped client")
|
||||
await self._rpc_client.request("stop", mxid=self.user.mxid)
|
||||
|
||||
async def _rpc_disconnection_handler(self) -> None:
|
||||
await self._rpc_client.wait_for_disconnection()
|
||||
self._rpc_disconnection_task = None
|
||||
self._stop_listen()
|
||||
asyncio.create_task(self.user.on_client_disconnect())
|
||||
|
||||
async def renew_and_save(self) -> None:
|
||||
"""Renew and save the user's session tokens."""
|
||||
oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)
|
||||
|
|
|
@ -29,6 +29,46 @@ from .types import RPCError
|
|||
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
class CancelableEvent:
|
||||
_event: asyncio.Event
|
||||
_task: asyncio.Task | None
|
||||
_cancelled: bool
|
||||
_loop: asyncio.AbstractEventLoop
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop | None):
|
||||
self._event = asyncio.Event()
|
||||
self._task = None
|
||||
self._cancelled = False
|
||||
self._loop = loop or asyncio.get_running_loop()
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._event.is_set()
|
||||
|
||||
def set(self) -> None:
|
||||
self._event.set()
|
||||
self._task = None
|
||||
|
||||
def clear(self) -> None:
|
||||
self._event.clear()
|
||||
|
||||
async def wait(self) -> None:
|
||||
if self._cancelled:
|
||||
raise asyncio.CancelledError()
|
||||
if self._event.is_set():
|
||||
return
|
||||
if not self._task:
|
||||
self._task = asyncio.create_task(self._event.wait())
|
||||
await self._task
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._cancelled = True
|
||||
if self._task is not None:
|
||||
self._task.cancel()
|
||||
|
||||
def cancelled(self) -> bool:
|
||||
return self._cancelled
|
||||
|
||||
|
||||
class RPCClient:
|
||||
config: Config
|
||||
loop: asyncio.AbstractEventLoop
|
||||
|
@ -41,6 +81,11 @@ class RPCClient:
|
|||
_response_waiters: dict[int, asyncio.Future[JSON]]
|
||||
_event_handlers: dict[str, list[EventHandler]]
|
||||
_command_queue: asyncio.Queue
|
||||
_read_task: asyncio.Task | None
|
||||
_connection_task: asyncio.Task | None
|
||||
_is_connected: CancelableEvent
|
||||
_is_disconnected: CancelableEvent
|
||||
_connection_lock: asyncio.Lock
|
||||
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
|
@ -52,16 +97,34 @@ class RPCClient:
|
|||
self._writer = None
|
||||
self._reader = None
|
||||
self._command_queue = asyncio.Queue()
|
||||
self.loop.create_task(self._command_loop())
|
||||
self._read_task = None
|
||||
self._connection_task = None
|
||||
self._is_connected = CancelableEvent(self.loop)
|
||||
self._is_disconnected = CancelableEvent(self.loop)
|
||||
self._is_disconnected.set()
|
||||
self._connection_lock = asyncio.Lock()
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._writer is not None:
|
||||
return
|
||||
async with self._connection_lock:
|
||||
if self._is_connected.cancelled():
|
||||
raise asyncio.CancelledError()
|
||||
if self._is_connected.is_set():
|
||||
return
|
||||
self._connection_task = self.loop.create_task(self._connect())
|
||||
try:
|
||||
await self._connection_task
|
||||
finally:
|
||||
self._connection_task = None
|
||||
|
||||
async def _connect(self) -> None:
|
||||
if self.config["rpc.connection.type"] == "unix":
|
||||
while True:
|
||||
try:
|
||||
r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"])
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except:
|
||||
self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...')
|
||||
await asyncio.sleep(10)
|
||||
|
@ -71,6 +134,8 @@ class RPCClient:
|
|||
r, w = await asyncio.open_connection(self.config["rpc.connection.host"],
|
||||
self.config["rpc.connection.port"])
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except:
|
||||
self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.path"]}, wait for it to become available...')
|
||||
await asyncio.sleep(10)
|
||||
|
@ -78,16 +143,46 @@ class RPCClient:
|
|||
raise RuntimeError("invalid rpc connection type")
|
||||
self._reader = r
|
||||
self._writer = w
|
||||
self.loop.create_task(self._try_read_loop())
|
||||
self.loop.create_task(self._command_loop())
|
||||
self._read_task = self.loop.create_task(self._try_read_loop())
|
||||
self._is_connected.set()
|
||||
self._is_disconnected.clear()
|
||||
await self.request("register", peer_id=self.config["appservice.address"])
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
async with self._connection_lock:
|
||||
if self._is_disconnected.cancelled():
|
||||
raise asyncio.CancelledError()
|
||||
if self._is_disconnected.is_set():
|
||||
return
|
||||
await self._disconnect()
|
||||
|
||||
async def _disconnect(self) -> None:
|
||||
if self._writer is not None:
|
||||
self._writer.write_eof()
|
||||
await self._writer.drain()
|
||||
self._writer = None
|
||||
self._reader = None
|
||||
if self._read_task is not None:
|
||||
self._read_task.cancel()
|
||||
self._read_task = None
|
||||
self._on_disconnect()
|
||||
|
||||
def _on_disconnect(self) -> None:
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._is_connected.clear()
|
||||
self._is_disconnected.set()
|
||||
|
||||
def wait_for_connection(self) -> Awaitable[None]:
|
||||
return self._is_connected.wait()
|
||||
|
||||
def wait_for_disconnection(self) -> Awaitable[None]:
|
||||
return self._is_disconnected.wait()
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._is_connected.cancel()
|
||||
self._is_disconnected.cancel()
|
||||
if self._connection_task is not None:
|
||||
self._connection_task.cancel()
|
||||
asyncio.run(self._disconnect())
|
||||
|
||||
@property
|
||||
def _next_req_id(self) -> int:
|
||||
|
@ -119,7 +214,7 @@ class RPCClient:
|
|||
for handler in handlers:
|
||||
try:
|
||||
await handler(req)
|
||||
except Exception:
|
||||
except:
|
||||
self.log.exception("Exception in event handler")
|
||||
|
||||
async def _handle_incoming_line(self, line: str) -> None:
|
||||
|
@ -162,7 +257,9 @@ class RPCClient:
|
|||
async def _try_read_loop(self) -> None:
|
||||
try:
|
||||
await self._read_loop()
|
||||
except Exception:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except:
|
||||
self.log.exception("Fatal error in read loop")
|
||||
|
||||
async def _read_loop(self) -> None:
|
||||
|
@ -178,6 +275,8 @@ class RPCClient:
|
|||
except asyncio.LimitOverrunError as e:
|
||||
self.log.warning(f"Buffer overrun: {e}")
|
||||
line += await self._reader.read(self._reader._limit)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
|
@ -187,11 +286,12 @@ class RPCClient:
|
|||
continue
|
||||
try:
|
||||
await self._handle_incoming_line(line_str)
|
||||
except Exception:
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except:
|
||||
self.log.exception("Failed to handle incoming request %s", line_str)
|
||||
self.log.debug("Reader disconnected")
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._on_disconnect()
|
||||
|
||||
async def _raw_request(self, command: str, is_secret: bool = False, **data: JSON) -> asyncio.Future[JSON]:
|
||||
req_id = self._next_req_id
|
||||
|
@ -205,5 +305,6 @@ class RPCClient:
|
|||
return future
|
||||
|
||||
async def request(self, command: str, **data: JSON) -> JSON:
|
||||
await self.wait_for_connection()
|
||||
future = await self._raw_request(command, **data)
|
||||
return await future
|
||||
|
|
|
@ -85,6 +85,7 @@ class User(DBUser, BaseUser):
|
|||
_connection_time: float
|
||||
_db_instance: DBUser | None
|
||||
_sync_lock: SimpleLock
|
||||
_is_rpc_reconnecting: bool
|
||||
_logged_in_info: ProfileStruct | None
|
||||
_logged_in_info_time: float
|
||||
|
||||
|
@ -121,6 +122,7 @@ class User(DBUser, BaseUser):
|
|||
self._sync_lock = SimpleLock(
|
||||
"Waiting for thread sync to finish before handling %s", log=self.log
|
||||
)
|
||||
self._is_rpc_reconnecting = False
|
||||
self._logged_in_info = None
|
||||
self._logged_in_info_time = 0
|
||||
|
||||
|
@ -332,6 +334,8 @@ class User(DBUser, BaseUser):
|
|||
state_event=BridgeStateEvent.UNKNOWN_ERROR,
|
||||
error_code="kt-reconnection-error",
|
||||
)
|
||||
finally:
|
||||
self._is_rpc_reconnecting = False
|
||||
|
||||
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
|
||||
if self.client:
|
||||
|
@ -545,9 +549,8 @@ class User(DBUser, BaseUser):
|
|||
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
|
||||
if self.is_connected:
|
||||
state.state_event = BridgeStateEvent.CONNECTED
|
||||
# TODO
|
||||
#elif self._is_logged_in and self._is_reconnecting:
|
||||
# state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
|
||||
elif self._is_rpc_reconnecting or self.client:
|
||||
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
|
||||
return [state]
|
||||
|
||||
async def get_puppet(self) -> pu.Puppet | None:
|
||||
|
@ -582,16 +585,18 @@ class User(DBUser, BaseUser):
|
|||
# region KakaoTalk event handling
|
||||
|
||||
async def on_connect(self) -> None:
|
||||
self.is_connected = True
|
||||
self._track_metric(METRIC_CONNECTED, True)
|
||||
""" TODO Don't auto-resync channels if disconnection was too short
|
||||
now = time.monotonic()
|
||||
disconnected_at = self._connection_time
|
||||
max_delay = self.config["bridge.resync_max_disconnected_time"]
|
||||
first_connect = self.is_connected is None
|
||||
self.is_connected = True
|
||||
self._track_metric(METRIC_CONNECTED, True)
|
||||
if not first_connect and disconnected_at + max_delay < now:
|
||||
duration = int(now - disconnected_at)
|
||||
self.log.debug(f"Disconnection lasted {duration} seconds")
|
||||
elif self.temp_disconnect_notices:
|
||||
self.log.debug(f"Disconnection lasted {duration} seconds, not re-syncing channels...")
|
||||
"""
|
||||
if self.temp_disconnect_notices:
|
||||
await self.send_bridge_notice("Connected to KakaoTalk chats")
|
||||
await self.push_bridge_state(BridgeStateEvent.CONNECTED)
|
||||
|
||||
|
@ -618,6 +623,19 @@ class User(DBUser, BaseUser):
|
|||
await self.logout()
|
||||
await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}")
|
||||
|
||||
async def on_client_disconnect(self) -> None:
|
||||
self.is_connected = False
|
||||
self._track_metric(METRIC_CONNECTED, False)
|
||||
self.client = None
|
||||
if self._is_logged_in:
|
||||
if self.temp_disconnect_notices:
|
||||
await self.send_bridge_notice(
|
||||
"Disconnected from KakaoTalk: backend helper module exited. "
|
||||
"Will reconnect once module resumes."
|
||||
)
|
||||
self._is_rpc_reconnecting = True
|
||||
asyncio.create_task(self.reload_session())
|
||||
|
||||
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
|
||||
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
|
||||
self.oauth_credential = oauth_credential
|
||||
|
|
Loading…
Reference in New Issue