diff --git a/matrix_appservice_kakaotalk/__main__.py b/matrix_appservice_kakaotalk/__main__.py
index 3a05681..8a03485 100644
--- a/matrix_appservice_kakaotalk/__main__.py
+++ b/matrix_appservice_kakaotalk/__main__.py
@@ -72,8 +72,6 @@ class KakaoTalkBridge(Bridge):
puppet.stop()
self.log.debug("Stopping kakaotalk listeners")
User.shutdown = True
- for user in User.by_ktid.values():
- user.stop_listen()
self.add_shutdown_actions(user.save() for user in User.by_mxid.values())
self.add_shutdown_actions(KakaoTalkClient.stop_cls())
diff --git a/matrix_appservice_kakaotalk/commands/conn.py b/matrix_appservice_kakaotalk/commands/conn.py
index 6a1336e..3a171d6 100644
--- a/matrix_appservice_kakaotalk/commands/conn.py
+++ b/matrix_appservice_kakaotalk/commands/conn.py
@@ -17,6 +17,8 @@ from mautrix.bridge.commands import HelpSection, command_handler
from .typehint import CommandEvent
+from ..kt.client.errors import CommandException
+
SECTION_CONNECTION = HelpSection("Connection management", 15, "")
@@ -32,35 +34,6 @@ async def set_notice_room(evt: CommandEvent) -> None:
await evt.reply("This room has been marked as your bridge notice room")
-"""
-@command_handler(
- needs_auth=True,
- management_only=True,
- help_section=SECTION_CONNECTION,
- help_text="Disconnect from KakaoTalk",
-)
-async def disconnect(evt: CommandEvent) -> None:
- if not evt.sender.mqtt:
- await evt.reply("You don't have a KakaoTalk MQTT connection")
- return
- evt.sender.mqtt.disconnect()
-
-
-@command_handler(
- needs_auth=True,
- management_only=True,
- help_section=SECTION_CONNECTION,
- help_text="Connect to KakaoTalk",
- aliases=["reconnect"],
-)
-async def connect(evt: CommandEvent) -> None:
- if evt.sender.listen_task and not evt.sender.listen_task.done():
- await evt.reply("You already have a KakaoTalk MQTT connection")
- return
- evt.sender.start_listen()
-"""
-
-
@command_handler(
needs_auth=True,
management_only=True,
@@ -72,34 +45,11 @@ async def ping(evt: CommandEvent) -> None:
await evt.reply("You're not logged into KakaoTalk")
return
await evt.mark_read()
- # try:
- own_info = await evt.sender.get_own_info()
- # TODO catch errors
- # except fbchat.PleaseRefresh as e:
- # await evt.reply(f"{e}\n\nUse `$cmdprefix+sp refresh` refresh the session.")
- # return
- await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})")
-
- """
- if not evt.sender.listen_task or evt.sender.listen_task.done():
- await evt.reply("You don't have a KakaoTalk MQTT connection. Use `connect` to connect.")
- elif not evt.sender.is_connected:
- await evt.reply("The KakaoTalk MQTT listener is **disconnected**.")
- else:
- await evt.reply("The KakaoTalk MQTT listener is connected.")
- """
-
-
-"""
-@command_handler(
- needs_auth=True,
- management_only=True,
- help_section=SECTION_CONNECTION,
- help_text="Resync chats and reconnect to MQTT",
-)
-async def refresh(evt: CommandEvent) -> None:
- await evt.sender.refresh(force_notice=True)
-"""
+ try:
+ own_info = await evt.sender.get_own_info()
+ await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})")
+ except CommandException as e:
+ await evt.reply(f"Error from KakaoTalk: {e}")
@command_handler(
@@ -107,10 +57,19 @@ async def refresh(evt: CommandEvent) -> None:
management_only=True,
help_section=SECTION_CONNECTION,
help_text="Resync chats",
+ help_args="[count]",
)
async def sync(evt: CommandEvent) -> None:
+ try:
+ sync_count = int(evt.args[0])
+ except IndexError:
+ sync_count = None
+ except ValueError:
+ await evt.reply("**Usage:** `$cmdprefix+sp logout [--reset-device]`")
+ return
+
await evt.mark_read()
- if await evt.sender.post_login(is_startup=False):
+ if await evt.sender.connect_and_sync(sync_count):
await evt.reply("Sync complete")
else:
await evt.reply("Sync failed")
diff --git a/matrix_appservice_kakaotalk/db/user.py b/matrix_appservice_kakaotalk/db/user.py
index 22f8dbe..476e186 100644
--- a/matrix_appservice_kakaotalk/db/user.py
+++ b/matrix_appservice_kakaotalk/db/user.py
@@ -47,24 +47,23 @@ class User:
def _from_optional_row(cls, row: Record | None) -> User | None:
return cls._from_row(row) if row is not None else None
+ _columns = "mxid, ktid, uuid, access_token, refresh_token, notice_room"
+
@classmethod
async def all_logged_in(cls) -> List[User]:
- q = """
- SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user"
- WHERE ktid<>0
- """
+ q = f'SELECT {cls._columns} FROM "user" WHERE ktid<>0'
rows = await cls.db.fetch(q)
return [cls._from_row(row) for row in rows if row]
@classmethod
async def get_by_ktid(cls, ktid: int) -> User | None:
- q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1'
+ q = f'SELECT {cls._columns} FROM "user" WHERE ktid=$1'
row = await cls.db.fetchrow(q, ktid)
return cls._from_optional_row(row)
@classmethod
async def get_by_mxid(cls, mxid: UserID) -> User | None:
- q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE mxid=$1'
+ q = f'SELECT {cls._columns} FROM "user" WHERE mxid=$1'
row = await cls.db.fetchrow(q, mxid)
return cls._from_optional_row(row)
@@ -73,23 +72,30 @@ class User:
q = 'SELECT uuid FROM "user" WHERE uuid IS NOT NULL'
return {tuple(record)[0] for record in await cls.db.fetch(q)}
+ @property
+ def _values(self):
+ return (
+ self.mxid,
+ self.ktid,
+ self.uuid,
+ self.access_token,
+ self.refresh_token,
+ self.notice_room,
+ )
+
async def insert(self) -> None:
q = """
INSERT INTO "user" (mxid, ktid, uuid, access_token, refresh_token, notice_room)
VALUES ($1, $2, $3, $4, $5, $6)
"""
- await self.db.execute(
- q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room
- )
+ await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
async def save(self) -> None:
q = """
- UPDATE "user" SET ktid=$1, uuid=$2, access_token=$3, refresh_token=$4, notice_room=$5
- WHERE mxid=$6
+ UPDATE "user" SET ktid=$2, uuid=$3, access_token=$4, refresh_token=$5, notice_room=$6
+ WHERE mxid=$1
"""
- await self.db.execute(
- q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
- )
+ await self.db.execute(q, *self._values)
diff --git a/matrix_appservice_kakaotalk/example-config.yaml b/matrix_appservice_kakaotalk/example-config.yaml
index 1f6334a..1425d6a 100644
--- a/matrix_appservice_kakaotalk/example-config.yaml
+++ b/matrix_appservice_kakaotalk/example-config.yaml
@@ -122,7 +122,7 @@ bridge:
command_prefix: "!kt"
# Number of chats to sync (and create portals for) on startup/login.
- # Set 0 to disable automatic syncing.
+ # Set to 0 to disable automatic syncing, or -1 to sync as much as possible.
initial_chat_sync: 20
# Whether or not the KakaoTalk users of logged in Matrix users should be
# invited to private chats when the user sends a message from another client.
@@ -188,11 +188,11 @@ bridge:
# usually needed to prevent rate limits and to allow timestamp massaging.
invite_own_puppet: true
# Maximum number of messages to backfill initially.
- # Set to 0 to disable backfilling when creating portal.
+ # Set to 0 to disable backfilling when creating portal, or -1 to backfill as much as possible.
initial_limit: 0
# Maximum number of messages to backfill if messages were missed while
# the bridge was disconnected.
- # Set to 0 to disable backfilling missed messages.
+ # Set to 0 to disable backfilling missed messages, or -1 to backfill as much as possible.
missed_limit: 1000
# If using double puppeting, should notifications be disabled
# while the initial backfill is in progress?
@@ -213,7 +213,7 @@ bridge:
# The number of seconds that a disconnection can last without triggering an automatic re-sync
# and missed message backfilling when reconnecting.
# Set to 0 to always re-sync, or -1 to never re-sync automatically.
- resync_max_disconnected_time: 5
+ #resync_max_disconnected_time: 5
# Should the bridge do a resync on startup?
sync_on_startup: true
# Whether or not temporary disconnections should send notices to the notice room.
diff --git a/matrix_appservice_kakaotalk/kt/client/client.py b/matrix_appservice_kakaotalk/kt/client/client.py
index 7ee83cb..b231993 100644
--- a/matrix_appservice_kakaotalk/kt/client/client.py
+++ b/matrix_appservice_kakaotalk/kt/client/client.py
@@ -22,7 +22,7 @@ with any other potential backend.
from __future__ import annotations
-from typing import TYPE_CHECKING, cast, Awaitable, Callable, Type, Optional, Union
+from typing import TYPE_CHECKING, cast, Type, Optional, Union
import logging
import urllib.request
@@ -39,7 +39,6 @@ from ...rpc import RPCClient
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
from ..types.bson import Long
from ..types.client.client_session import LoginResult
-from ..types.channel.channel_type import ChannelType
from ..types.chat.chat import Chatlog
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import (
@@ -63,7 +62,6 @@ except ImportError:
if TYPE_CHECKING:
from mautrix.types import JSON
from ...user import User
- from ...rpc.rpc import EventHandler
# TODO Consider defining an interface for this, with node/native backend as swappable implementations
@@ -80,7 +78,6 @@ class Client:
@classmethod
async def stop_cls(cls) -> None:
"""Stop and disconnect from the Node backend."""
- await cls._rpc_client.request("stop")
await cls._rpc_client.disconnect()
@@ -102,16 +99,17 @@ class Client:
)
@classmethod
- async def register_device(cls, passcode: str, **req) -> None:
+ async def register_device(cls, passcode: str, **req: JSON) -> None:
"""Register a (fake) device that will be associated with the provided login credentials."""
await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)
@classmethod
- async def login(cls, **req) -> OAuthCredential:
+ async def login(cls, **req: JSON) -> OAuthCredential:
"""
Obtain a session token by logging in with user-provided credentials.
Must have first called register_device with these credentials.
"""
+ # NOTE Actually returns a LoginData object, but this only needs an OAuthCredential
return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)
# endregion
@@ -181,20 +179,21 @@ class Client:
self.user.oauth_credential = oauth_info.credential
await self.user.save()
- async def start(self) -> LoginResult:
+ async def connect(self) -> LoginResult:
"""
Start a new session by providing a token obtained from a prior login.
Receive a snapshot of account state in response.
"""
- login_result = await self._api_user_request_result(LoginResult, "start")
+ login_result = await self._api_user_request_result(LoginResult, "connect")
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
+ # TODO Skip if handlers are already listening. But this is idempotent and thus probably safe
+ self._start_listen()
return login_result
- """
- async def is_connected(self) -> bool:
- resp = await self._rpc_client.request("is_connected")
- return resp["is_connected"]
- """
+ async def disconnect(self) -> bool:
+ connection_existed = await self._rpc_client.request("disconnect", mxid=self.user.mxid)
+ self._stop_listen()
+ return connection_existed
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
@@ -232,14 +231,6 @@ class Client:
text=text
)
- async def start_listen(self) -> None:
- # TODO Connect all listeners here?
- await self._api_user_request_void("start_listen")
-
- async def stop(self) -> None:
- # TODO Stop all event handlers
- await self._api_user_request_void("stop")
-
# TODO Combine these into one
@@ -272,19 +263,31 @@ class Client:
# region listeners
- async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None:
- async def wrapper(data: dict[str, JSON]) -> None:
- await func(
- Chatlog.deserialize(data["chatlog"]),
- Long.deserialize(data["channelId"]),
- data["channelType"]
- )
+ async def _on_message(self, data: dict[str, JSON]) -> None:
+ await self.user.on_message(
+ Chatlog.deserialize(data["chatlog"]),
+ Long.deserialize(data["channelId"]),
+ data["channelType"]
+ )
- self._add_user_handler("message", wrapper)
+ """ TODO
+ async def _on_receipt(self, data: Dict[str, JSON]) -> None:
+ await self.user.on_receipt(Receipt.deserialize(data["receipt"]))
+ """
- def _add_user_handler(self, command: str, handler: EventHandler) -> str:
- self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler)
+ def _start_listen(self) -> None:
+ # TODO Automate this somehow, like with a fancy enum
+ self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [self._on_message])
+ # TODO many more listeners
+
+ def _stop_listen(self) -> None:
+ # TODO Automate this somehow, like with a fancy enum
+ self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [])
+ # TODO many more listeners
+
+ def _get_user_cmd(self, command) -> str:
+ return f"{command}:{self.user.mxid}"
# endregion
diff --git a/matrix_appservice_kakaotalk/kt/types/request.py b/matrix_appservice_kakaotalk/kt/types/request.py
index eb862a0..ac92693 100644
--- a/matrix_appservice_kakaotalk/kt/types/request.py
+++ b/matrix_appservice_kakaotalk/kt/types/request.py
@@ -79,6 +79,16 @@ class RootCommandResult(ResponseState):
"""For brevity, this also encompasses CommandResultFailed and CommandResultDoneVoid"""
success: bool
+ @classmethod
+ def deserialize(cls, data: JSON) -> "RootCommandResult":
+ if not data or "success" not in data or "status" not in data:
+ return RootCommandResult(
+ success=True,
+ status=KnownDataStatusCode.SUCCESS
+ )
+ else:
+ return super().deserialize(data)
+
ResultType = TypeVar("ResultType", bound=Serializable)
diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py
index 5b39445..6e7fbf9 100644
--- a/matrix_appservice_kakaotalk/portal.py
+++ b/matrix_appservice_kakaotalk/portal.py
@@ -16,7 +16,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, Pattern, cast
-from collections import deque
import asyncio
import re
import time
@@ -845,11 +844,6 @@ class Portal(DBPortal, BasePortal):
return False
return True
- async def _add_kakaotalk_reply(
- self, content: MessageEventContent, reply_to: None
- ) -> None:
- self.log.info("TODO")
-
async def handle_remote_message(
self,
source: u.User,
@@ -969,7 +963,7 @@ class Portal(DBPortal, BasePortal):
if not messages:
self.log.debug("Didn't get any messages from server")
return
- self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server")
+ self.log.debug(f"Got {len(messages)} message{'s' if len(messages) > 1 else ''} from server")
self._backfill_leave = set()
async with NotificationDisabler(self.mxid, source):
for message in messages:
diff --git a/matrix_appservice_kakaotalk/rpc/__init__.py b/matrix_appservice_kakaotalk/rpc/__init__.py
index 33792bd..f5125b1 100644
--- a/matrix_appservice_kakaotalk/rpc/__init__.py
+++ b/matrix_appservice_kakaotalk/rpc/__init__.py
@@ -1 +1,2 @@
from .rpc import RPCClient
+from .types import RPCError
diff --git a/matrix_appservice_kakaotalk/rpc/rpc.py b/matrix_appservice_kakaotalk/rpc/rpc.py
index 3820317..59d8d84 100644
--- a/matrix_appservice_kakaotalk/rpc/rpc.py
+++ b/matrix_appservice_kakaotalk/rpc/rpc.py
@@ -15,7 +15,7 @@
# along with this program. If not, see .
from __future__ import annotations
-from typing import Any, Callable, Awaitable, List
+from typing import Any, Callable, Awaitable
import asyncio
import json
@@ -39,7 +39,7 @@ class RPCClient:
_req_id: int
_min_broadcast_id: int
_response_waiters: dict[int, asyncio.Future[JSON]]
- _event_handlers: dict[str, List[EventHandler]]
+ _event_handlers: dict[str, list[EventHandler]]
_command_queue: asyncio.Queue
def __init__(self, config: Config) -> None:
@@ -98,7 +98,13 @@ class RPCClient:
self._event_handlers.setdefault(method, []).append(handler)
def remove_event_handler(self, method: str, handler: EventHandler) -> None:
- self._event_handlers.setdefault(method, []).remove(handler)
+ try:
+ self._event_handlers.setdefault(method, []).remove(handler)
+ except ValueError:
+ pass
+
+ def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
+ self._event_handlers[method] = handlers
async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
if req_id > self._min_broadcast_id:
diff --git a/matrix_appservice_kakaotalk/user.py b/matrix_appservice_kakaotalk/user.py
index 136d647..b8b7c89 100644
--- a/matrix_appservice_kakaotalk/user.py
+++ b/matrix_appservice_kakaotalk/user.py
@@ -46,9 +46,7 @@ from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
-METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels")
-METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync")
-METRIC_UNKNOWN_EVENT = Summary("bridge_on_unknown_event", "calls to on_unknown_event")
+METRIC_CONNECT_AND_SYNC = Summary("bridge_sync_channels", "calls to connect_and_sync")
METRIC_MEMBERS_ADDED = Summary("bridge_on_members_added", "calls to on_members_added")
METRIC_MEMBER_REMOVED = Summary("bridge_on_member_removed", "calls to on_member_removed")
METRIC_TYPING = Summary("bridge_on_typing", "calls to on_typing")
@@ -58,7 +56,6 @@ METRIC_MESSAGE_UNSENT = Summary("bridge_on_unsent", "calls to on_unsent")
METRIC_MESSAGE_SEEN = Summary("bridge_on_message_seen", "calls to on_message_seen")
METRIC_TITLE_CHANGE = Summary("bridge_on_title_change", "calls to on_title_change")
METRIC_AVATAR_CHANGE = Summary("bridge_on_avatar_change", "calls to on_avatar_change")
-METRIC_THREAD_CHANGE = Summary("bridge_on_thread_change", "calls to on_thread_change")
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")
@@ -71,7 +68,6 @@ BridgeState.human_readable_errors.update(
"kt-reconnection-error": "Failed to reconnect to KakaoTalk",
"kt-connection-error": "KakaoTalk disconnected unexpectedly",
"kt-auth-error": "Authentication error from KakaoTalk: {message}",
- "kt-start-error": "Startup error from KakaoTalk: {message}",
"kt-disconnected": None,
"logged-out": "You're not logged into KakaoTalk",
}
@@ -79,7 +75,6 @@ BridgeState.human_readable_errors.update(
class User(DBUser, BaseUser):
- #temp_disconnect_notices: bool = True
shutdown: bool = False
config: Config
@@ -90,16 +85,13 @@ class User(DBUser, BaseUser):
_notice_room_lock: asyncio.Lock
_notice_send_lock: asyncio.Lock
- command_status: dict | None
is_admin: bool
permission_level: str
_is_logged_in: bool | None
- #_is_connected: bool | None
- #_connection_time: float
- _prev_reconnect_fail_refresh: float
+ _is_connected: bool | None
+ _connection_time: float
_db_instance: DBUser | None
_sync_lock: SimpleLock
- _is_refreshing: bool
_logged_in_info: ProfileStruct | None
_logged_in_info_time: float
@@ -124,7 +116,6 @@ class User(DBUser, BaseUser):
self.notice_room = notice_room
self._notice_room_lock = asyncio.Lock()
self._notice_send_lock = asyncio.Lock()
- self.command_status = None
(
self.relay_whitelisted,
self.is_whitelisted,
@@ -132,13 +123,11 @@ class User(DBUser, BaseUser):
self.permission_level,
) = self.config.get_permissions(mxid)
self._is_logged_in = None
- #self._is_connected = None
- #self._connection_time = time.monotonic()
- self._prev_reconnect_fail_refresh = time.monotonic()
+ self._is_connected = None
+ self._connection_time = time.monotonic()
self._sync_lock = SimpleLock(
"Waiting for thread sync to finish before handling %s", log=self.log
)
- self._is_refreshing = False
self._logged_in_info = None
self._logged_in_info_time = 0
@@ -150,10 +139,8 @@ class User(DBUser, BaseUser):
cls.config = bridge.config
cls.az = bridge.az
cls.loop = bridge.loop
- #cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())
- """
@property
def is_connected(self) -> bool | None:
return self._is_connected
@@ -167,11 +154,11 @@ class User(DBUser, BaseUser):
@property
def connection_time(self) -> float:
return self._connection_time
- """
@property
def has_state(self) -> bool:
- return bool(self.uuid and self.ktid and self.access_token and self.refresh_token)
+ # TODO If more state is needed, consider returning a saved LoginResult
+ return bool(self.access_token and self.refresh_token)
# region Database getters
@@ -233,11 +220,13 @@ class User(DBUser, BaseUser):
async def get_uuid(self, force: bool = False) -> str:
if self.uuid is None or force:
self.uuid = await self._generate_uuid()
+ # TODO Maybe don't save yet
await self.save()
return self.uuid
- async def _generate_uuid(self) -> str:
- return await Client.generate_uuid(await self.get_all_uuids())
+ @classmethod
+ async def _generate_uuid(cls) -> str:
+ return await Client.generate_uuid(await super().get_all_uuids())
# endregion
@@ -285,8 +274,7 @@ class User(DBUser, BaseUser):
self._logged_in_info_time = time.monotonic()
self._track_metric(METRIC_LOGGED_IN, True)
self._is_logged_in = True
- #self.is_connected = None
- self.stop_listen()
+ self.is_connected = None
asyncio.create_task(self.post_login(is_startup=is_startup))
return True
return False
@@ -310,10 +298,9 @@ class User(DBUser, BaseUser):
) -> ProfileStruct:
if not client:
client = self.client
- # TODO Retry network connection failures here, or in the client?
+ # TODO Retry network connection failures here, or in the client (like token refreshes are)?
try:
return await client.fetch_logged_in_user()
- # NOTE Not catching InvalidAccessToken here, as client handles it & tries to refresh the token
except AuthenticationRequired as e:
if action != "restore session":
await self._send_reset_notice(e)
@@ -352,7 +339,7 @@ class User(DBUser, BaseUser):
state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
)
await asyncio.sleep(60)
- await self.reload_session(event_id, retries - 1)
+ await self.reload_session(event_id, retries - 1, is_startup)
else:
await self.send_bridge_notice(
notice,
@@ -362,44 +349,43 @@ class User(DBUser, BaseUser):
error_code="kt-reconnection-error",
)
except Exception:
+ self.log.exception("Error connecting to KakaoTalk")
await self.send_bridge_notice(
"Failed to connect to KakaoTalk: unknown error (see logs for more details)",
edit=event_id,
state_event=BridgeStateEvent.UNKNOWN_ERROR,
error_code="kt-reconnection-error",
)
- finally:
- self._is_refreshing = False
- async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> bool:
- ok = True
- self.stop_listen()
- if self.has_state:
- # TODO Log out of KakaoTalk if an API exists for it
- pass
+ async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
+ if self.client:
+ # TODO Look for a logout API call
+ was_connected = await self.client.disconnect()
+ if was_connected != self._is_connected:
+ self.log.warn(
+ f"Node backend was{' not' if not was_connected else ''} connected, "
+ f"but we thought it was{' not' if not self._is_connected else ''}")
if remove_ktid:
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
self._track_metric(METRIC_LOGGED_IN, False)
- self._is_logged_in = False
- #self.is_connected = None
- if self.client:
- await self.client.stop()
- self.client = None
-
- if self.ktid and remove_ktid:
- #await UserPortal.delete_all(self.ktid)
- del self.by_ktid[self.ktid]
- self.ktid = None
if reset_device:
self.uuid = await self._generate_uuid()
self.access_token = None
self.refresh_token = None
- await self.save()
- return ok
+ self._is_logged_in = False
+ self.is_connected = None
+ self.client = None
- async def post_login(self, is_startup: bool) -> bool:
+ if self.ktid and remove_ktid:
+ #await UserPortal.delete_all(self.ktid)
+ del self.by_ktid[self.ktid]
+ self.ktid = None
+
+ await self.save()
+
+ async def post_login(self, is_startup: bool) -> None:
self.log.info("Running post-login actions")
self._add_to_cache()
@@ -412,12 +398,27 @@ class User(DBUser, BaseUser):
except Exception:
self.log.exception("Failed to automatically enable custom puppet")
- assert self.client
+ # TODO Check if things break when a live message comes in during syncing
+ if self.config["bridge.sync_on_startup"] or not is_startup:
+ sync_count = self.config["bridge.initial_chat_sync"]
+ else:
+ sync_count = None
+ await self.connect_and_sync(sync_count)
+
+ async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
+ return {
+ pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
+ async for portal in po.Portal.get_all_by_receiver(self.ktid)
+ if portal.mxid
+ }
+
+ @async_time(METRIC_CONNECT_AND_SYNC)
+ async def connect_and_sync(self, sync_count: int | None) -> bool:
+ # TODO Look for a way to sync all channels without (re-)logging in
try:
- # TODO if not is_startup, close existing listeners
- login_result = await self.client.start()
- await self._sync_channels(login_result, is_startup)
- self.start_listen()
+ login_result = await self.client.connect()
+ await self.push_bridge_state(BridgeStateEvent.CONNECTED)
+ await self._sync_channels(login_result, sync_count)
return True
except AuthenticationRequired as e:
await self.send_bridge_notice(
@@ -429,41 +430,34 @@ class User(DBUser, BaseUser):
)
await self.logout(remove_ktid=False)
except Exception as e:
- self.log.exception("Failed to start client")
- await self.send_bridge_notice(
- f"Got error from KakaoTalk:\n\n> {e!s}\n\n",
- important=True,
- state_event=BridgeStateEvent.UNKNOWN_ERROR,
- error_code="kt-start-error",
- error_message=str(e),
- )
+ self.log.exception("Failed to connect and sync")
+ await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e))
return False
- async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
- return {
- pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
- async for portal in po.Portal.get_all_by_receiver(self.ktid)
- if portal.mxid
- }
-
- @async_time(METRIC_SYNC_CHANNELS)
- async def _sync_channels(self, login_result: LoginResult, is_startup: bool) -> None:
- # TODO Look for a way to sync all channels without (re-)logging in
- sync_count = self.config["bridge.initial_chat_sync"]
- if sync_count <= 0 or not self.config["bridge.sync_on_startup"] and is_startup:
- self.log.debug(f"Skipping channel syncing{' on startup' if sync_count > 0 else ''}")
+ async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None:
+ if sync_count is None:
+ sync_count = self.config["bridge.initial_chat_sync"]
+ if not sync_count:
+ self.log.debug("Skipping channel syncing")
return
if not login_result.channelList:
self.log.debug("No channels to sync")
return
# TODO What about removed channels? Don't early-return then
- sync_count = min(sync_count, len(login_result.channelList))
+ num_channels = len(login_result.channelList)
+ sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels)
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
- self.log.debug(f"Syncing {sync_count} of {len(login_result.channelList)} channels...")
+ self.log.debug(f"Syncing {sync_count} of {num_channels} channels...")
for channel_item in login_result.channelList[:sync_count]:
- # TODO try-except here, above, below?
- await self._sync_channel(channel_item)
+ try:
+ await self._sync_channel(channel_item)
+ except AuthenticationRequired:
+ raise
+ except Exception:
+ self.log.exception(f"Failed to sync channel {channel_item.channel.channelId}")
+
+ await self.update_direct_chats()
async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None:
channel_data = channel_item.channel
@@ -573,18 +567,12 @@ class User(DBUser, BaseUser):
state.remote_name = puppet.name
async def get_bridge_states(self) -> list[BridgeState]:
- self.log.info("TODO: get_bridge_states")
- return []
- """
if not self.state:
return []
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
if self.is_connected:
state.state_event = BridgeStateEvent.CONNECTED
- elif self._is_refreshing or self.mqtt:
- state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
return [state]
- """
async def get_puppet(self) -> pu.Puppet | None:
if not self.ktid:
@@ -593,36 +581,10 @@ class User(DBUser, BaseUser):
# region KakaoTalk event handling
- def start_listen(self) -> None:
- self.listen_task = asyncio.create_task(self._try_listen())
-
- def _disconnect_listener_after_error(self) -> None:
- self.log.info("TODO: _disconnect_listener_after_error")
-
- async def _try_listen(self) -> None:
- try:
- # TODO Pass all listeners to start_listen instead of registering them one-by-one?
- await self.client.start_listen()
- await self.client.on_message(self.on_message)
- # TODO Handle auth errors specially?
- #except AuthenticationRequired as e:
- except Exception:
- #self.is_connected = False
- self.log.exception("Fatal error in listener")
- await self.send_bridge_notice(
- "Fatal error in listener (see logs for more info)",
- state_event=BridgeStateEvent.UNKNOWN_ERROR,
- important=True,
- error_code="kt-connection-error",
- )
- self._disconnect_listener_after_error()
-
- def stop_listen(self) -> None:
- self.log.info("TODO: stop_listen")
-
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
self.oauth_credential = oauth_credential
+ await self.push_bridge_state(BridgeStateEvent.CONNECTING)
self.client = Client(self, log=self.log.getChild("ktclient"))
await self.save()
self._is_logged_in = True
@@ -631,7 +593,6 @@ class User(DBUser, BaseUser):
self._logged_in_info_time = time.monotonic()
except Exception:
self.log.exception("Failed to fetch post-login info")
- self.stop_listen()
asyncio.create_task(self.post_login(is_startup=True))
@async_time(METRIC_MESSAGE)
diff --git a/node/src/client.js b/node/src/client.js
index 6c203ac..f05d8de 100644
--- a/node/src/client.js
+++ b/node/src/client.js
@@ -14,7 +14,6 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
import { Long } from "bson"
-import { emitLines, promisify } from "./util.js"
import {
AuthApiClient,
OAuthApiClient,
@@ -28,6 +27,8 @@ const { KnownChatType } = chat
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
/** @typedef {import("./clientmanager.js").default} ClientManager} */
+import { emitLines, promisify } from "./util.js"
+
class UserClient {
static #initializing = false
@@ -36,7 +37,7 @@ class UserClient {
get talkClient() { return this.#talkClient }
/** @type {ServiceApiClient} */
- #serviceClient = null
+ #serviceClient
get serviceClient() { return this.#serviceClient }
/**
@@ -80,7 +81,6 @@ class UserClient {
}
export default class PeerClient {
-
/**
* @param {ClientManager} manager
* @param {import("net").Socket} socket
@@ -136,6 +136,7 @@ export default class PeerClient {
return
}
this.stopped = true
+ this.#closeUsers()
try {
await this.#write({ id: --this.notificationID, command: "quit", error })
await promisify(cb => this.socket.end(cb))
@@ -145,20 +146,21 @@ export default class PeerClient {
}
}
- handleEnd = async () => {
- // TODO Persist clients across bridge disconnections.
- // But then have to queue received events until bridge acks them!
+ handleEnd = () => {
+ this.stopped = true
+ this.#closeUsers()
+ if (this.peerID && this.manager.clients.get(this.peerID) === this) {
+ this.manager.clients.delete(this.peerID)
+ }
+ this.log(`Connection closed (peer: ${this.peerID})`)
+ }
+
+ #closeUsers() {
this.log("Closing all API clients for", this.peerID)
for (const userClient of this.userClients.values()) {
userClient.close()
}
this.userClients.clear()
-
- this.stopped = true
- if (this.peerID && this.manager.clients.get(this.peerID) === this) {
- this.manager.clients.delete(this.peerID)
- }
- this.log(`Connection closed (peer: ${this.peerID})`)
}
/**
@@ -178,6 +180,7 @@ export default class PeerClient {
* @param {Object} req.form
*/
registerDevice = async (req) => {
+ // TODO Look for a deregister API call
const authClient = await this.#createAuthClient(req.uuid)
return await authClient.registerDevice(req.form, req.passcode, true)
}
@@ -192,6 +195,7 @@ export default class PeerClient {
* request failed, its status is stored here.
*/
handleLogin = async (req) => {
+ // TODO Look for a logout API call
const authClient = await this.#createAuthClient(req.uuid)
const loginRes = await authClient.login(req.form, true)
if (loginRes.status === KnownAuthStatusCode.DEVICE_NOT_REGISTERED) {
@@ -226,6 +230,15 @@ export default class PeerClient {
return userClient
}
+ /**
+ * Unchecked lookup of a UserClient for a given mxid.
+ * @param {string} mxid
+ * @returns {UserClient | undefined}
+ */
+ #tryGetUser(mxid) {
+ return this.userClients.get(mxid)
+ }
+
/**
* Get the service client for the specified user ID, or create
* and return a new service client if no user ID is provided.
@@ -233,7 +246,7 @@ export default class PeerClient {
* @param {OAuthCredential} oauth_credential
*/
async #getServiceClient(mxid, oauth_credential) {
- return this.userClients.get(mxid)?.serviceClient ||
+ return this.#tryGetUser(mxid)?.serviceClient ||
await ServiceApiClient.create(oauth_credential)
}
@@ -251,26 +264,15 @@ export default class PeerClient {
* @param {string} req.mxid
* @param {OAuthCredential} req.oauth_credential
*/
- handleStart = async (req) => {
+ handleConnect = async (req) => {
// TODO Don't re-login if possible. But must still return a LoginResult!
- {
- const oldUserClient = this.userClients.get(req.mxid)
- if (oldUserClient !== undefined) {
- oldUserClient.close()
- this.userClients.delete(req.mxid)
- }
- }
+ this.handleDisconnect(req)
const userClient = await UserClient.create(req.mxid, req.oauth_credential)
const res = await userClient.talkClient.login(req.oauth_credential)
if (!res.success) return res
this.userClients.set(req.mxid, userClient)
- return res
- }
-
- startListen = async (req) => {
- const userClient = this.#getUser(req.mxid)
userClient.talkClient.on("chat", (data, channel) => {
this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`)
@@ -291,7 +293,22 @@ export default class PeerClient {
})
*/
- return this.#voidCommandResult
+ return res
+ }
+
+ /**
+ * @param {Object} req
+ * @param {string} req.mxid
+ */
+ handleDisconnect = (req) => {
+ const userClient = this.#tryGetUser(req.mxid)
+ if (!!userClient) {
+ userClient.close()
+ this.userClients.delete(req.mxid)
+ return true
+ } else {
+ return false
+ }
}
/**
@@ -368,16 +385,6 @@ export default class PeerClient {
})
}
- /**
- * @param {Object} req
- * @param {string} req.mxid
- */
- handleStop = async (req) => {
- this.#getUser(req.mxid).close()
- this.userClients.delete(req.mxid)
- return this.#voidCommandResult
- }
-
#makeCommandResult(result) {
return {
success: true,
@@ -386,11 +393,6 @@ export default class PeerClient {
}
}
- #voidCommandResult = {
- success: true,
- status: 0,
- }
-
handleUnknownCommand = () => {
throw new Error("Unknown command")
}
@@ -431,9 +433,6 @@ export default class PeerClient {
this.log("Ignoring old request", req.id)
return
}
- if (req.command != "is_connected") {
- this.log("Received request", req.id, "with command", req.command)
- }
this.maxCommandID = req.id
let handler
if (!this.peerID) {
@@ -449,21 +448,18 @@ export default class PeerClient {
handler = this.handleRegister
} else {
handler = {
- // TODO Subclass / object for KakaoTalk-specific handlers?
- start: this.handleStart,
- stop: this.handleStop,
- disconnect: () => this.stop(),
- login: this.handleLogin,
- renew: this.handleRenew,
+ // TODO Wrapper for per-user commands
generate_uuid: util.randomAndroidSubDeviceUUID,
register_device: this.registerDevice,
- start_listen: this.startListen,
+ login: this.handleLogin,
+ renew: this.handleRenew,
+ connect: this.handleConnect,
+ disconnect: this.handleDisconnect,
get_own_profile: this.getOwnProfile,
+ get_profile: this.getProfile,
get_portal_channel_info: this.getPortalChannelInfo,
get_chats: this.getChats,
- get_profile: this.getProfile,
send_message: this.sendMessage,
- //is_connected: async () => ({ is_connected: !await this.puppet.isDisconnected() }),
}[req.command] || this.handleUnknownCommand
}
const resp = { id: req.id }
@@ -482,6 +478,7 @@ export default class PeerClient {
resp.command = "error"
resp.error = err.toString()
this.log(`Error handling request ${resp.id} ${err.stack}`)
+ // TODO Check if session is broken. If it is, close the PeerClient
}
}
await this.#write(resp)
diff --git a/node/src/clientmanager.js b/node/src/clientmanager.js
index c993d26..988dc63 100644
--- a/node/src/clientmanager.js
+++ b/node/src/clientmanager.js
@@ -20,6 +20,7 @@ import path from "path"
import PeerClient from "./client.js"
import { promisify } from "./util.js"
+
export default class ClientManager {
constructor(listenConfig) {
this.listenConfig = listenConfig
diff --git a/node/src/main.js b/node/src/main.js
index 6fcc49a..d258d9f 100644
--- a/node/src/main.js
+++ b/node/src/main.js
@@ -15,12 +15,13 @@
// along with this program. If not, see .
import process from "process"
import fs from "fs"
-import sd from "systemd-daemon"
import arg from "arg"
+import sd from "systemd-daemon"
import ClientManager from "./clientmanager.js"
+
const args = arg({
"--config": String,
"-c": "--config",
@@ -31,10 +32,10 @@ const configPath = args["--config"] || "config.json"
console.log("[Main] Reading config from", configPath)
const config = JSON.parse(fs.readFileSync(configPath).toString())
-const api = new ClientManager(config.listen)
+const manager = new ClientManager(config.listen)
function stop() {
- api.stop().then(() => {
+ manager.stop().then(() => {
console.log("[Main] Everything stopped")
process.exit(0)
}, err => {
@@ -43,7 +44,7 @@ function stop() {
})
}
-api.start().then(() => {
+manager.start().then(() => {
process.once("SIGINT", stop)
process.once("SIGTERM", stop)
sd.notify("READY=1")
diff --git a/requirements.txt b/requirements.txt
index 8ea98d6..ca5978a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
aiohttp>=3,<4
asyncpg>=0.20,<0.26
-bson>=0.5,<0.6
commonmark>=0.8,<0.10
mautrix==0.15.0rc4
pycryptodome>=3,<4