Improved syncing, cleanups

This commit is contained in:
Andrew Ferrazzutti 2022-03-18 03:52:55 -04:00
parent d975dd8c1f
commit 60c47e5a20
14 changed files with 226 additions and 290 deletions

View File

@ -72,8 +72,6 @@ class KakaoTalkBridge(Bridge):
puppet.stop()
self.log.debug("Stopping kakaotalk listeners")
User.shutdown = True
for user in User.by_ktid.values():
user.stop_listen()
self.add_shutdown_actions(user.save() for user in User.by_mxid.values())
self.add_shutdown_actions(KakaoTalkClient.stop_cls())

View File

@ -17,6 +17,8 @@ from mautrix.bridge.commands import HelpSection, command_handler
from .typehint import CommandEvent
from ..kt.client.errors import CommandException
SECTION_CONNECTION = HelpSection("Connection management", 15, "")
@ -32,35 +34,6 @@ async def set_notice_room(evt: CommandEvent) -> None:
await evt.reply("This room has been marked as your bridge notice room")
"""
@command_handler(
needs_auth=True,
management_only=True,
help_section=SECTION_CONNECTION,
help_text="Disconnect from KakaoTalk",
)
async def disconnect(evt: CommandEvent) -> None:
if not evt.sender.mqtt:
await evt.reply("You don't have a KakaoTalk MQTT connection")
return
evt.sender.mqtt.disconnect()
@command_handler(
needs_auth=True,
management_only=True,
help_section=SECTION_CONNECTION,
help_text="Connect to KakaoTalk",
aliases=["reconnect"],
)
async def connect(evt: CommandEvent) -> None:
if evt.sender.listen_task and not evt.sender.listen_task.done():
await evt.reply("You already have a KakaoTalk MQTT connection")
return
evt.sender.start_listen()
"""
@command_handler(
needs_auth=True,
management_only=True,
@ -72,34 +45,11 @@ async def ping(evt: CommandEvent) -> None:
await evt.reply("You're not logged into KakaoTalk")
return
await evt.mark_read()
# try:
own_info = await evt.sender.get_own_info()
# TODO catch errors
# except fbchat.PleaseRefresh as e:
# await evt.reply(f"{e}\n\nUse `$cmdprefix+sp refresh` refresh the session.")
# return
await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})")
"""
if not evt.sender.listen_task or evt.sender.listen_task.done():
await evt.reply("You don't have a KakaoTalk MQTT connection. Use `connect` to connect.")
elif not evt.sender.is_connected:
await evt.reply("The KakaoTalk MQTT listener is **disconnected**.")
else:
await evt.reply("The KakaoTalk MQTT listener is connected.")
"""
"""
@command_handler(
needs_auth=True,
management_only=True,
help_section=SECTION_CONNECTION,
help_text="Resync chats and reconnect to MQTT",
)
async def refresh(evt: CommandEvent) -> None:
await evt.sender.refresh(force_notice=True)
"""
try:
own_info = await evt.sender.get_own_info()
await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})")
except CommandException as e:
await evt.reply(f"Error from KakaoTalk: {e}")
@command_handler(
@ -107,10 +57,19 @@ async def refresh(evt: CommandEvent) -> None:
management_only=True,
help_section=SECTION_CONNECTION,
help_text="Resync chats",
help_args="[count]",
)
async def sync(evt: CommandEvent) -> None:
try:
sync_count = int(evt.args[0])
except IndexError:
sync_count = None
except ValueError:
await evt.reply("**Usage:** `$cmdprefix+sp logout [--reset-device]`")
return
await evt.mark_read()
if await evt.sender.post_login(is_startup=False):
if await evt.sender.connect_and_sync(sync_count):
await evt.reply("Sync complete")
else:
await evt.reply("Sync failed")

View File

@ -47,24 +47,23 @@ class User:
def _from_optional_row(cls, row: Record | None) -> User | None:
return cls._from_row(row) if row is not None else None
_columns = "mxid, ktid, uuid, access_token, refresh_token, notice_room"
@classmethod
async def all_logged_in(cls) -> List[User]:
q = """
SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user"
WHERE ktid<>0
"""
q = f'SELECT {cls._columns} FROM "user" WHERE ktid<>0'
rows = await cls.db.fetch(q)
return [cls._from_row(row) for row in rows if row]
@classmethod
async def get_by_ktid(cls, ktid: int) -> User | None:
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1'
q = f'SELECT {cls._columns} FROM "user" WHERE ktid=$1'
row = await cls.db.fetchrow(q, ktid)
return cls._from_optional_row(row)
@classmethod
async def get_by_mxid(cls, mxid: UserID) -> User | None:
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE mxid=$1'
q = f'SELECT {cls._columns} FROM "user" WHERE mxid=$1'
row = await cls.db.fetchrow(q, mxid)
return cls._from_optional_row(row)
@ -73,23 +72,30 @@ class User:
q = 'SELECT uuid FROM "user" WHERE uuid IS NOT NULL'
return {tuple(record)[0] for record in await cls.db.fetch(q)}
@property
def _values(self):
return (
self.mxid,
self.ktid,
self.uuid,
self.access_token,
self.refresh_token,
self.notice_room,
)
async def insert(self) -> None:
q = """
INSERT INTO "user" (mxid, ktid, uuid, access_token, refresh_token, notice_room)
VALUES ($1, $2, $3, $4, $5, $6)
"""
await self.db.execute(
q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room
)
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
async def save(self) -> None:
q = """
UPDATE "user" SET ktid=$1, uuid=$2, access_token=$3, refresh_token=$4, notice_room=$5
WHERE mxid=$6
UPDATE "user" SET ktid=$2, uuid=$3, access_token=$4, refresh_token=$5, notice_room=$6
WHERE mxid=$1
"""
await self.db.execute(
q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
)
await self.db.execute(q, *self._values)

View File

@ -122,7 +122,7 @@ bridge:
command_prefix: "!kt"
# Number of chats to sync (and create portals for) on startup/login.
# Set 0 to disable automatic syncing.
# Set to 0 to disable automatic syncing, or -1 to sync as much as possible.
initial_chat_sync: 20
# Whether or not the KakaoTalk users of logged in Matrix users should be
# invited to private chats when the user sends a message from another client.
@ -188,11 +188,11 @@ bridge:
# usually needed to prevent rate limits and to allow timestamp massaging.
invite_own_puppet: true
# Maximum number of messages to backfill initially.
# Set to 0 to disable backfilling when creating portal.
# Set to 0 to disable backfilling when creating portal, or -1 to backfill as much as possible.
initial_limit: 0
# Maximum number of messages to backfill if messages were missed while
# the bridge was disconnected.
# Set to 0 to disable backfilling missed messages.
# Set to 0 to disable backfilling missed messages, or -1 to backfill as much as possible.
missed_limit: 1000
# If using double puppeting, should notifications be disabled
# while the initial backfill is in progress?
@ -213,7 +213,7 @@ bridge:
# The number of seconds that a disconnection can last without triggering an automatic re-sync
# and missed message backfilling when reconnecting.
# Set to 0 to always re-sync, or -1 to never re-sync automatically.
resync_max_disconnected_time: 5
#resync_max_disconnected_time: 5
# Should the bridge do a resync on startup?
sync_on_startup: true
# Whether or not temporary disconnections should send notices to the notice room.

View File

@ -22,7 +22,7 @@ with any other potential backend.
from __future__ import annotations
from typing import TYPE_CHECKING, cast, Awaitable, Callable, Type, Optional, Union
from typing import TYPE_CHECKING, cast, Type, Optional, Union
import logging
import urllib.request
@ -39,7 +39,6 @@ from ...rpc import RPCClient
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
from ..types.bson import Long
from ..types.client.client_session import LoginResult
from ..types.channel.channel_type import ChannelType
from ..types.chat.chat import Chatlog
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import (
@ -63,7 +62,6 @@ except ImportError:
if TYPE_CHECKING:
from mautrix.types import JSON
from ...user import User
from ...rpc.rpc import EventHandler
# TODO Consider defining an interface for this, with node/native backend as swappable implementations
@ -80,7 +78,6 @@ class Client:
@classmethod
async def stop_cls(cls) -> None:
"""Stop and disconnect from the Node backend."""
await cls._rpc_client.request("stop")
await cls._rpc_client.disconnect()
@ -102,16 +99,17 @@ class Client:
)
@classmethod
async def register_device(cls, passcode: str, **req) -> None:
async def register_device(cls, passcode: str, **req: JSON) -> None:
"""Register a (fake) device that will be associated with the provided login credentials."""
await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)
@classmethod
async def login(cls, **req) -> OAuthCredential:
async def login(cls, **req: JSON) -> OAuthCredential:
"""
Obtain a session token by logging in with user-provided credentials.
Must have first called register_device with these credentials.
"""
# NOTE Actually returns a LoginData object, but this only needs an OAuthCredential
return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)
# endregion
@ -181,20 +179,21 @@ class Client:
self.user.oauth_credential = oauth_info.credential
await self.user.save()
async def start(self) -> LoginResult:
async def connect(self) -> LoginResult:
"""
Start a new session by providing a token obtained from a prior login.
Receive a snapshot of account state in response.
"""
login_result = await self._api_user_request_result(LoginResult, "start")
login_result = await self._api_user_request_result(LoginResult, "connect")
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
# TODO Skip if handlers are already listening. But this is idempotent and thus probably safe
self._start_listen()
return login_result
"""
async def is_connected(self) -> bool:
resp = await self._rpc_client.request("is_connected")
return resp["is_connected"]
"""
async def disconnect(self) -> bool:
connection_existed = await self._rpc_client.request("disconnect", mxid=self.user.mxid)
self._stop_listen()
return connection_existed
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
@ -232,14 +231,6 @@ class Client:
text=text
)
async def start_listen(self) -> None:
# TODO Connect all listeners here?
await self._api_user_request_void("start_listen")
async def stop(self) -> None:
# TODO Stop all event handlers
await self._api_user_request_void("stop")
# TODO Combine these into one
@ -272,19 +263,31 @@ class Client:
# region listeners
async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None:
async def wrapper(data: dict[str, JSON]) -> None:
await func(
Chatlog.deserialize(data["chatlog"]),
Long.deserialize(data["channelId"]),
data["channelType"]
)
async def _on_message(self, data: dict[str, JSON]) -> None:
await self.user.on_message(
Chatlog.deserialize(data["chatlog"]),
Long.deserialize(data["channelId"]),
data["channelType"]
)
self._add_user_handler("message", wrapper)
""" TODO
async def _on_receipt(self, data: Dict[str, JSON]) -> None:
await self.user.on_receipt(Receipt.deserialize(data["receipt"]))
"""
def _add_user_handler(self, command: str, handler: EventHandler) -> str:
self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler)
def _start_listen(self) -> None:
# TODO Automate this somehow, like with a fancy enum
self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [self._on_message])
# TODO many more listeners
def _stop_listen(self) -> None:
# TODO Automate this somehow, like with a fancy enum
self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [])
# TODO many more listeners
def _get_user_cmd(self, command) -> str:
return f"{command}:{self.user.mxid}"
# endregion

View File

@ -79,6 +79,16 @@ class RootCommandResult(ResponseState):
"""For brevity, this also encompasses CommandResultFailed and CommandResultDoneVoid"""
success: bool
@classmethod
def deserialize(cls, data: JSON) -> "RootCommandResult":
if not data or "success" not in data or "status" not in data:
return RootCommandResult(
success=True,
status=KnownDataStatusCode.SUCCESS
)
else:
return super().deserialize(data)
ResultType = TypeVar("ResultType", bound=Serializable)

View File

@ -16,7 +16,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, Pattern, cast
from collections import deque
import asyncio
import re
import time
@ -845,11 +844,6 @@ class Portal(DBPortal, BasePortal):
return False
return True
async def _add_kakaotalk_reply(
self, content: MessageEventContent, reply_to: None
) -> None:
self.log.info("TODO")
async def handle_remote_message(
self,
source: u.User,
@ -969,7 +963,7 @@ class Portal(DBPortal, BasePortal):
if not messages:
self.log.debug("Didn't get any messages from server")
return
self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server")
self.log.debug(f"Got {len(messages)} message{'s' if len(messages) > 1 else ''} from server")
self._backfill_leave = set()
async with NotificationDisabler(self.mxid, source):
for message in messages:

View File

@ -1 +1,2 @@
from .rpc import RPCClient
from .types import RPCError

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any, Callable, Awaitable, List
from typing import Any, Callable, Awaitable
import asyncio
import json
@ -39,7 +39,7 @@ class RPCClient:
_req_id: int
_min_broadcast_id: int
_response_waiters: dict[int, asyncio.Future[JSON]]
_event_handlers: dict[str, List[EventHandler]]
_event_handlers: dict[str, list[EventHandler]]
_command_queue: asyncio.Queue
def __init__(self, config: Config) -> None:
@ -98,7 +98,13 @@ class RPCClient:
self._event_handlers.setdefault(method, []).append(handler)
def remove_event_handler(self, method: str, handler: EventHandler) -> None:
self._event_handlers.setdefault(method, []).remove(handler)
try:
self._event_handlers.setdefault(method, []).remove(handler)
except ValueError:
pass
def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
self._event_handlers[method] = handlers
async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
if req_id > self._min_broadcast_id:

View File

@ -46,9 +46,7 @@ from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels")
METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync")
METRIC_UNKNOWN_EVENT = Summary("bridge_on_unknown_event", "calls to on_unknown_event")
METRIC_CONNECT_AND_SYNC = Summary("bridge_sync_channels", "calls to connect_and_sync")
METRIC_MEMBERS_ADDED = Summary("bridge_on_members_added", "calls to on_members_added")
METRIC_MEMBER_REMOVED = Summary("bridge_on_member_removed", "calls to on_member_removed")
METRIC_TYPING = Summary("bridge_on_typing", "calls to on_typing")
@ -58,7 +56,6 @@ METRIC_MESSAGE_UNSENT = Summary("bridge_on_unsent", "calls to on_unsent")
METRIC_MESSAGE_SEEN = Summary("bridge_on_message_seen", "calls to on_message_seen")
METRIC_TITLE_CHANGE = Summary("bridge_on_title_change", "calls to on_title_change")
METRIC_AVATAR_CHANGE = Summary("bridge_on_avatar_change", "calls to on_avatar_change")
METRIC_THREAD_CHANGE = Summary("bridge_on_thread_change", "calls to on_thread_change")
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")
@ -71,7 +68,6 @@ BridgeState.human_readable_errors.update(
"kt-reconnection-error": "Failed to reconnect to KakaoTalk",
"kt-connection-error": "KakaoTalk disconnected unexpectedly",
"kt-auth-error": "Authentication error from KakaoTalk: {message}",
"kt-start-error": "Startup error from KakaoTalk: {message}",
"kt-disconnected": None,
"logged-out": "You're not logged into KakaoTalk",
}
@ -79,7 +75,6 @@ BridgeState.human_readable_errors.update(
class User(DBUser, BaseUser):
#temp_disconnect_notices: bool = True
shutdown: bool = False
config: Config
@ -90,16 +85,13 @@ class User(DBUser, BaseUser):
_notice_room_lock: asyncio.Lock
_notice_send_lock: asyncio.Lock
command_status: dict | None
is_admin: bool
permission_level: str
_is_logged_in: bool | None
#_is_connected: bool | None
#_connection_time: float
_prev_reconnect_fail_refresh: float
_is_connected: bool | None
_connection_time: float
_db_instance: DBUser | None
_sync_lock: SimpleLock
_is_refreshing: bool
_logged_in_info: ProfileStruct | None
_logged_in_info_time: float
@ -124,7 +116,6 @@ class User(DBUser, BaseUser):
self.notice_room = notice_room
self._notice_room_lock = asyncio.Lock()
self._notice_send_lock = asyncio.Lock()
self.command_status = None
(
self.relay_whitelisted,
self.is_whitelisted,
@ -132,13 +123,11 @@ class User(DBUser, BaseUser):
self.permission_level,
) = self.config.get_permissions(mxid)
self._is_logged_in = None
#self._is_connected = None
#self._connection_time = time.monotonic()
self._prev_reconnect_fail_refresh = time.monotonic()
self._is_connected = None
self._connection_time = time.monotonic()
self._sync_lock = SimpleLock(
"Waiting for thread sync to finish before handling %s", log=self.log
)
self._is_refreshing = False
self._logged_in_info = None
self._logged_in_info_time = 0
@ -150,10 +139,8 @@ class User(DBUser, BaseUser):
cls.config = bridge.config
cls.az = bridge.az
cls.loop = bridge.loop
#cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())
"""
@property
def is_connected(self) -> bool | None:
return self._is_connected
@ -167,11 +154,11 @@ class User(DBUser, BaseUser):
@property
def connection_time(self) -> float:
return self._connection_time
"""
@property
def has_state(self) -> bool:
return bool(self.uuid and self.ktid and self.access_token and self.refresh_token)
# TODO If more state is needed, consider returning a saved LoginResult
return bool(self.access_token and self.refresh_token)
# region Database getters
@ -233,11 +220,13 @@ class User(DBUser, BaseUser):
async def get_uuid(self, force: bool = False) -> str:
if self.uuid is None or force:
self.uuid = await self._generate_uuid()
# TODO Maybe don't save yet
await self.save()
return self.uuid
async def _generate_uuid(self) -> str:
return await Client.generate_uuid(await self.get_all_uuids())
@classmethod
async def _generate_uuid(cls) -> str:
return await Client.generate_uuid(await super().get_all_uuids())
# endregion
@ -285,8 +274,7 @@ class User(DBUser, BaseUser):
self._logged_in_info_time = time.monotonic()
self._track_metric(METRIC_LOGGED_IN, True)
self._is_logged_in = True
#self.is_connected = None
self.stop_listen()
self.is_connected = None
asyncio.create_task(self.post_login(is_startup=is_startup))
return True
return False
@ -310,10 +298,9 @@ class User(DBUser, BaseUser):
) -> ProfileStruct:
if not client:
client = self.client
# TODO Retry network connection failures here, or in the client?
# TODO Retry network connection failures here, or in the client (like token refreshes are)?
try:
return await client.fetch_logged_in_user()
# NOTE Not catching InvalidAccessToken here, as client handles it & tries to refresh the token
except AuthenticationRequired as e:
if action != "restore session":
await self._send_reset_notice(e)
@ -352,7 +339,7 @@ class User(DBUser, BaseUser):
state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
)
await asyncio.sleep(60)
await self.reload_session(event_id, retries - 1)
await self.reload_session(event_id, retries - 1, is_startup)
else:
await self.send_bridge_notice(
notice,
@ -362,44 +349,43 @@ class User(DBUser, BaseUser):
error_code="kt-reconnection-error",
)
except Exception:
self.log.exception("Error connecting to KakaoTalk")
await self.send_bridge_notice(
"Failed to connect to KakaoTalk: unknown error (see logs for more details)",
edit=event_id,
state_event=BridgeStateEvent.UNKNOWN_ERROR,
error_code="kt-reconnection-error",
)
finally:
self._is_refreshing = False
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> bool:
ok = True
self.stop_listen()
if self.has_state:
# TODO Log out of KakaoTalk if an API exists for it
pass
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
if self.client:
# TODO Look for a logout API call
was_connected = await self.client.disconnect()
if was_connected != self._is_connected:
self.log.warn(
f"Node backend was{' not' if not was_connected else ''} connected, "
f"but we thought it was{' not' if not self._is_connected else ''}")
if remove_ktid:
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
self._track_metric(METRIC_LOGGED_IN, False)
self._is_logged_in = False
#self.is_connected = None
if self.client:
await self.client.stop()
self.client = None
if self.ktid and remove_ktid:
#await UserPortal.delete_all(self.ktid)
del self.by_ktid[self.ktid]
self.ktid = None
if reset_device:
self.uuid = await self._generate_uuid()
self.access_token = None
self.refresh_token = None
await self.save()
return ok
self._is_logged_in = False
self.is_connected = None
self.client = None
async def post_login(self, is_startup: bool) -> bool:
if self.ktid and remove_ktid:
#await UserPortal.delete_all(self.ktid)
del self.by_ktid[self.ktid]
self.ktid = None
await self.save()
async def post_login(self, is_startup: bool) -> None:
self.log.info("Running post-login actions")
self._add_to_cache()
@ -412,12 +398,27 @@ class User(DBUser, BaseUser):
except Exception:
self.log.exception("Failed to automatically enable custom puppet")
assert self.client
# TODO Check if things break when a live message comes in during syncing
if self.config["bridge.sync_on_startup"] or not is_startup:
sync_count = self.config["bridge.initial_chat_sync"]
else:
sync_count = None
await self.connect_and_sync(sync_count)
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
return {
pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
async for portal in po.Portal.get_all_by_receiver(self.ktid)
if portal.mxid
}
@async_time(METRIC_CONNECT_AND_SYNC)
async def connect_and_sync(self, sync_count: int | None) -> bool:
# TODO Look for a way to sync all channels without (re-)logging in
try:
# TODO if not is_startup, close existing listeners
login_result = await self.client.start()
await self._sync_channels(login_result, is_startup)
self.start_listen()
login_result = await self.client.connect()
await self.push_bridge_state(BridgeStateEvent.CONNECTED)
await self._sync_channels(login_result, sync_count)
return True
except AuthenticationRequired as e:
await self.send_bridge_notice(
@ -429,41 +430,34 @@ class User(DBUser, BaseUser):
)
await self.logout(remove_ktid=False)
except Exception as e:
self.log.exception("Failed to start client")
await self.send_bridge_notice(
f"Got error from KakaoTalk:\n\n> {e!s}\n\n",
important=True,
state_event=BridgeStateEvent.UNKNOWN_ERROR,
error_code="kt-start-error",
error_message=str(e),
)
self.log.exception("Failed to connect and sync")
await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e))
return False
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
return {
pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
async for portal in po.Portal.get_all_by_receiver(self.ktid)
if portal.mxid
}
@async_time(METRIC_SYNC_CHANNELS)
async def _sync_channels(self, login_result: LoginResult, is_startup: bool) -> None:
# TODO Look for a way to sync all channels without (re-)logging in
sync_count = self.config["bridge.initial_chat_sync"]
if sync_count <= 0 or not self.config["bridge.sync_on_startup"] and is_startup:
self.log.debug(f"Skipping channel syncing{' on startup' if sync_count > 0 else ''}")
async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None:
if sync_count is None:
sync_count = self.config["bridge.initial_chat_sync"]
if not sync_count:
self.log.debug("Skipping channel syncing")
return
if not login_result.channelList:
self.log.debug("No channels to sync")
return
# TODO What about removed channels? Don't early-return then
sync_count = min(sync_count, len(login_result.channelList))
num_channels = len(login_result.channelList)
sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels)
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
self.log.debug(f"Syncing {sync_count} of {len(login_result.channelList)} channels...")
self.log.debug(f"Syncing {sync_count} of {num_channels} channels...")
for channel_item in login_result.channelList[:sync_count]:
# TODO try-except here, above, below?
await self._sync_channel(channel_item)
try:
await self._sync_channel(channel_item)
except AuthenticationRequired:
raise
except Exception:
self.log.exception(f"Failed to sync channel {channel_item.channel.channelId}")
await self.update_direct_chats()
async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None:
channel_data = channel_item.channel
@ -573,18 +567,12 @@ class User(DBUser, BaseUser):
state.remote_name = puppet.name
async def get_bridge_states(self) -> list[BridgeState]:
self.log.info("TODO: get_bridge_states")
return []
"""
if not self.state:
return []
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
if self.is_connected:
state.state_event = BridgeStateEvent.CONNECTED
elif self._is_refreshing or self.mqtt:
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
return [state]
"""
async def get_puppet(self) -> pu.Puppet | None:
if not self.ktid:
@ -593,36 +581,10 @@ class User(DBUser, BaseUser):
# region KakaoTalk event handling
def start_listen(self) -> None:
self.listen_task = asyncio.create_task(self._try_listen())
def _disconnect_listener_after_error(self) -> None:
self.log.info("TODO: _disconnect_listener_after_error")
async def _try_listen(self) -> None:
try:
# TODO Pass all listeners to start_listen instead of registering them one-by-one?
await self.client.start_listen()
await self.client.on_message(self.on_message)
# TODO Handle auth errors specially?
#except AuthenticationRequired as e:
except Exception:
#self.is_connected = False
self.log.exception("Fatal error in listener")
await self.send_bridge_notice(
"Fatal error in listener (see logs for more info)",
state_event=BridgeStateEvent.UNKNOWN_ERROR,
important=True,
error_code="kt-connection-error",
)
self._disconnect_listener_after_error()
def stop_listen(self) -> None:
self.log.info("TODO: stop_listen")
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
self.oauth_credential = oauth_credential
await self.push_bridge_state(BridgeStateEvent.CONNECTING)
self.client = Client(self, log=self.log.getChild("ktclient"))
await self.save()
self._is_logged_in = True
@ -631,7 +593,6 @@ class User(DBUser, BaseUser):
self._logged_in_info_time = time.monotonic()
except Exception:
self.log.exception("Failed to fetch post-login info")
self.stop_listen()
asyncio.create_task(self.post_login(is_startup=True))
@async_time(METRIC_MESSAGE)

View File

@ -14,7 +14,6 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import { Long } from "bson"
import { emitLines, promisify } from "./util.js"
import {
AuthApiClient,
OAuthApiClient,
@ -28,6 +27,8 @@ const { KnownChatType } = chat
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
/** @typedef {import("./clientmanager.js").default} ClientManager} */
import { emitLines, promisify } from "./util.js"
class UserClient {
static #initializing = false
@ -36,7 +37,7 @@ class UserClient {
get talkClient() { return this.#talkClient }
/** @type {ServiceApiClient} */
#serviceClient = null
#serviceClient
get serviceClient() { return this.#serviceClient }
/**
@ -80,7 +81,6 @@ class UserClient {
}
export default class PeerClient {
/**
* @param {ClientManager} manager
* @param {import("net").Socket} socket
@ -136,6 +136,7 @@ export default class PeerClient {
return
}
this.stopped = true
this.#closeUsers()
try {
await this.#write({ id: --this.notificationID, command: "quit", error })
await promisify(cb => this.socket.end(cb))
@ -145,20 +146,21 @@ export default class PeerClient {
}
}
handleEnd = async () => {
// TODO Persist clients across bridge disconnections.
// But then have to queue received events until bridge acks them!
handleEnd = () => {
this.stopped = true
this.#closeUsers()
if (this.peerID && this.manager.clients.get(this.peerID) === this) {
this.manager.clients.delete(this.peerID)
}
this.log(`Connection closed (peer: ${this.peerID})`)
}
#closeUsers() {
this.log("Closing all API clients for", this.peerID)
for (const userClient of this.userClients.values()) {
userClient.close()
}
this.userClients.clear()
this.stopped = true
if (this.peerID && this.manager.clients.get(this.peerID) === this) {
this.manager.clients.delete(this.peerID)
}
this.log(`Connection closed (peer: ${this.peerID})`)
}
/**
@ -178,6 +180,7 @@ export default class PeerClient {
* @param {Object} req.form
*/
registerDevice = async (req) => {
// TODO Look for a deregister API call
const authClient = await this.#createAuthClient(req.uuid)
return await authClient.registerDevice(req.form, req.passcode, true)
}
@ -192,6 +195,7 @@ export default class PeerClient {
* request failed, its status is stored here.
*/
handleLogin = async (req) => {
// TODO Look for a logout API call
const authClient = await this.#createAuthClient(req.uuid)
const loginRes = await authClient.login(req.form, true)
if (loginRes.status === KnownAuthStatusCode.DEVICE_NOT_REGISTERED) {
@ -226,6 +230,15 @@ export default class PeerClient {
return userClient
}
/**
* Unchecked lookup of a UserClient for a given mxid.
* @param {string} mxid
* @returns {UserClient | undefined}
*/
#tryGetUser(mxid) {
return this.userClients.get(mxid)
}
/**
* Get the service client for the specified user ID, or create
* and return a new service client if no user ID is provided.
@ -233,7 +246,7 @@ export default class PeerClient {
* @param {OAuthCredential} oauth_credential
*/
async #getServiceClient(mxid, oauth_credential) {
return this.userClients.get(mxid)?.serviceClient ||
return this.#tryGetUser(mxid)?.serviceClient ||
await ServiceApiClient.create(oauth_credential)
}
@ -251,26 +264,15 @@ export default class PeerClient {
* @param {string} req.mxid
* @param {OAuthCredential} req.oauth_credential
*/
handleStart = async (req) => {
handleConnect = async (req) => {
// TODO Don't re-login if possible. But must still return a LoginResult!
{
const oldUserClient = this.userClients.get(req.mxid)
if (oldUserClient !== undefined) {
oldUserClient.close()
this.userClients.delete(req.mxid)
}
}
this.handleDisconnect(req)
const userClient = await UserClient.create(req.mxid, req.oauth_credential)
const res = await userClient.talkClient.login(req.oauth_credential)
if (!res.success) return res
this.userClients.set(req.mxid, userClient)
return res
}
startListen = async (req) => {
const userClient = this.#getUser(req.mxid)
userClient.talkClient.on("chat", (data, channel) => {
this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`)
@ -291,7 +293,22 @@ export default class PeerClient {
})
*/
return this.#voidCommandResult
return res
}
/**
* @param {Object} req
* @param {string} req.mxid
*/
handleDisconnect = (req) => {
const userClient = this.#tryGetUser(req.mxid)
if (!!userClient) {
userClient.close()
this.userClients.delete(req.mxid)
return true
} else {
return false
}
}
/**
@ -368,16 +385,6 @@ export default class PeerClient {
})
}
/**
* @param {Object} req
* @param {string} req.mxid
*/
handleStop = async (req) => {
this.#getUser(req.mxid).close()
this.userClients.delete(req.mxid)
return this.#voidCommandResult
}
#makeCommandResult(result) {
return {
success: true,
@ -386,11 +393,6 @@ export default class PeerClient {
}
}
#voidCommandResult = {
success: true,
status: 0,
}
handleUnknownCommand = () => {
throw new Error("Unknown command")
}
@ -431,9 +433,6 @@ export default class PeerClient {
this.log("Ignoring old request", req.id)
return
}
if (req.command != "is_connected") {
this.log("Received request", req.id, "with command", req.command)
}
this.maxCommandID = req.id
let handler
if (!this.peerID) {
@ -449,21 +448,18 @@ export default class PeerClient {
handler = this.handleRegister
} else {
handler = {
// TODO Subclass / object for KakaoTalk-specific handlers?
start: this.handleStart,
stop: this.handleStop,
disconnect: () => this.stop(),
login: this.handleLogin,
renew: this.handleRenew,
// TODO Wrapper for per-user commands
generate_uuid: util.randomAndroidSubDeviceUUID,
register_device: this.registerDevice,
start_listen: this.startListen,
login: this.handleLogin,
renew: this.handleRenew,
connect: this.handleConnect,
disconnect: this.handleDisconnect,
get_own_profile: this.getOwnProfile,
get_profile: this.getProfile,
get_portal_channel_info: this.getPortalChannelInfo,
get_chats: this.getChats,
get_profile: this.getProfile,
send_message: this.sendMessage,
//is_connected: async () => ({ is_connected: !await this.puppet.isDisconnected() }),
}[req.command] || this.handleUnknownCommand
}
const resp = { id: req.id }
@ -482,6 +478,7 @@ export default class PeerClient {
resp.command = "error"
resp.error = err.toString()
this.log(`Error handling request ${resp.id} ${err.stack}`)
// TODO Check if session is broken. If it is, close the PeerClient
}
}
await this.#write(resp)

View File

@ -20,6 +20,7 @@ import path from "path"
import PeerClient from "./client.js"
import { promisify } from "./util.js"
export default class ClientManager {
constructor(listenConfig) {
this.listenConfig = listenConfig

View File

@ -15,12 +15,13 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import process from "process"
import fs from "fs"
import sd from "systemd-daemon"
import arg from "arg"
import sd from "systemd-daemon"
import ClientManager from "./clientmanager.js"
const args = arg({
"--config": String,
"-c": "--config",
@ -31,10 +32,10 @@ const configPath = args["--config"] || "config.json"
console.log("[Main] Reading config from", configPath)
const config = JSON.parse(fs.readFileSync(configPath).toString())
const api = new ClientManager(config.listen)
const manager = new ClientManager(config.listen)
function stop() {
api.stop().then(() => {
manager.stop().then(() => {
console.log("[Main] Everything stopped")
process.exit(0)
}, err => {
@ -43,7 +44,7 @@ function stop() {
})
}
api.start().then(() => {
manager.start().then(() => {
process.once("SIGINT", stop)
process.once("SIGTERM", stop)
sd.notify("READY=1")

View File

@ -1,6 +1,5 @@
aiohttp>=3,<4
asyncpg>=0.20,<0.26
bson>=0.5,<0.6
commonmark>=0.8,<0.10
mautrix==0.15.0rc4
pycryptodome>=3,<4