Improved syncing, cleanups
This commit is contained in:
parent
d975dd8c1f
commit
60c47e5a20
|
@ -72,8 +72,6 @@ class KakaoTalkBridge(Bridge):
|
||||||
puppet.stop()
|
puppet.stop()
|
||||||
self.log.debug("Stopping kakaotalk listeners")
|
self.log.debug("Stopping kakaotalk listeners")
|
||||||
User.shutdown = True
|
User.shutdown = True
|
||||||
for user in User.by_ktid.values():
|
|
||||||
user.stop_listen()
|
|
||||||
self.add_shutdown_actions(user.save() for user in User.by_mxid.values())
|
self.add_shutdown_actions(user.save() for user in User.by_mxid.values())
|
||||||
self.add_shutdown_actions(KakaoTalkClient.stop_cls())
|
self.add_shutdown_actions(KakaoTalkClient.stop_cls())
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,8 @@ from mautrix.bridge.commands import HelpSection, command_handler
|
||||||
|
|
||||||
from .typehint import CommandEvent
|
from .typehint import CommandEvent
|
||||||
|
|
||||||
|
from ..kt.client.errors import CommandException
|
||||||
|
|
||||||
SECTION_CONNECTION = HelpSection("Connection management", 15, "")
|
SECTION_CONNECTION = HelpSection("Connection management", 15, "")
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,35 +34,6 @@ async def set_notice_room(evt: CommandEvent) -> None:
|
||||||
await evt.reply("This room has been marked as your bridge notice room")
|
await evt.reply("This room has been marked as your bridge notice room")
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
@command_handler(
|
|
||||||
needs_auth=True,
|
|
||||||
management_only=True,
|
|
||||||
help_section=SECTION_CONNECTION,
|
|
||||||
help_text="Disconnect from KakaoTalk",
|
|
||||||
)
|
|
||||||
async def disconnect(evt: CommandEvent) -> None:
|
|
||||||
if not evt.sender.mqtt:
|
|
||||||
await evt.reply("You don't have a KakaoTalk MQTT connection")
|
|
||||||
return
|
|
||||||
evt.sender.mqtt.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
@command_handler(
|
|
||||||
needs_auth=True,
|
|
||||||
management_only=True,
|
|
||||||
help_section=SECTION_CONNECTION,
|
|
||||||
help_text="Connect to KakaoTalk",
|
|
||||||
aliases=["reconnect"],
|
|
||||||
)
|
|
||||||
async def connect(evt: CommandEvent) -> None:
|
|
||||||
if evt.sender.listen_task and not evt.sender.listen_task.done():
|
|
||||||
await evt.reply("You already have a KakaoTalk MQTT connection")
|
|
||||||
return
|
|
||||||
evt.sender.start_listen()
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@command_handler(
|
@command_handler(
|
||||||
needs_auth=True,
|
needs_auth=True,
|
||||||
management_only=True,
|
management_only=True,
|
||||||
|
@ -72,34 +45,11 @@ async def ping(evt: CommandEvent) -> None:
|
||||||
await evt.reply("You're not logged into KakaoTalk")
|
await evt.reply("You're not logged into KakaoTalk")
|
||||||
return
|
return
|
||||||
await evt.mark_read()
|
await evt.mark_read()
|
||||||
# try:
|
try:
|
||||||
own_info = await evt.sender.get_own_info()
|
own_info = await evt.sender.get_own_info()
|
||||||
# TODO catch errors
|
await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})")
|
||||||
# except fbchat.PleaseRefresh as e:
|
except CommandException as e:
|
||||||
# await evt.reply(f"{e}\n\nUse `$cmdprefix+sp refresh` refresh the session.")
|
await evt.reply(f"Error from KakaoTalk: {e}")
|
||||||
# return
|
|
||||||
await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})")
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not evt.sender.listen_task or evt.sender.listen_task.done():
|
|
||||||
await evt.reply("You don't have a KakaoTalk MQTT connection. Use `connect` to connect.")
|
|
||||||
elif not evt.sender.is_connected:
|
|
||||||
await evt.reply("The KakaoTalk MQTT listener is **disconnected**.")
|
|
||||||
else:
|
|
||||||
await evt.reply("The KakaoTalk MQTT listener is connected.")
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
@command_handler(
|
|
||||||
needs_auth=True,
|
|
||||||
management_only=True,
|
|
||||||
help_section=SECTION_CONNECTION,
|
|
||||||
help_text="Resync chats and reconnect to MQTT",
|
|
||||||
)
|
|
||||||
async def refresh(evt: CommandEvent) -> None:
|
|
||||||
await evt.sender.refresh(force_notice=True)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@command_handler(
|
@command_handler(
|
||||||
|
@ -107,10 +57,19 @@ async def refresh(evt: CommandEvent) -> None:
|
||||||
management_only=True,
|
management_only=True,
|
||||||
help_section=SECTION_CONNECTION,
|
help_section=SECTION_CONNECTION,
|
||||||
help_text="Resync chats",
|
help_text="Resync chats",
|
||||||
|
help_args="[count]",
|
||||||
)
|
)
|
||||||
async def sync(evt: CommandEvent) -> None:
|
async def sync(evt: CommandEvent) -> None:
|
||||||
|
try:
|
||||||
|
sync_count = int(evt.args[0])
|
||||||
|
except IndexError:
|
||||||
|
sync_count = None
|
||||||
|
except ValueError:
|
||||||
|
await evt.reply("**Usage:** `$cmdprefix+sp logout [--reset-device]`")
|
||||||
|
return
|
||||||
|
|
||||||
await evt.mark_read()
|
await evt.mark_read()
|
||||||
if await evt.sender.post_login(is_startup=False):
|
if await evt.sender.connect_and_sync(sync_count):
|
||||||
await evt.reply("Sync complete")
|
await evt.reply("Sync complete")
|
||||||
else:
|
else:
|
||||||
await evt.reply("Sync failed")
|
await evt.reply("Sync failed")
|
||||||
|
|
|
@ -47,24 +47,23 @@ class User:
|
||||||
def _from_optional_row(cls, row: Record | None) -> User | None:
|
def _from_optional_row(cls, row: Record | None) -> User | None:
|
||||||
return cls._from_row(row) if row is not None else None
|
return cls._from_row(row) if row is not None else None
|
||||||
|
|
||||||
|
_columns = "mxid, ktid, uuid, access_token, refresh_token, notice_room"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def all_logged_in(cls) -> List[User]:
|
async def all_logged_in(cls) -> List[User]:
|
||||||
q = """
|
q = f'SELECT {cls._columns} FROM "user" WHERE ktid<>0'
|
||||||
SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user"
|
|
||||||
WHERE ktid<>0
|
|
||||||
"""
|
|
||||||
rows = await cls.db.fetch(q)
|
rows = await cls.db.fetch(q)
|
||||||
return [cls._from_row(row) for row in rows if row]
|
return [cls._from_row(row) for row in rows if row]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_ktid(cls, ktid: int) -> User | None:
|
async def get_by_ktid(cls, ktid: int) -> User | None:
|
||||||
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1'
|
q = f'SELECT {cls._columns} FROM "user" WHERE ktid=$1'
|
||||||
row = await cls.db.fetchrow(q, ktid)
|
row = await cls.db.fetchrow(q, ktid)
|
||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_mxid(cls, mxid: UserID) -> User | None:
|
async def get_by_mxid(cls, mxid: UserID) -> User | None:
|
||||||
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE mxid=$1'
|
q = f'SELECT {cls._columns} FROM "user" WHERE mxid=$1'
|
||||||
row = await cls.db.fetchrow(q, mxid)
|
row = await cls.db.fetchrow(q, mxid)
|
||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
|
@ -73,23 +72,30 @@ class User:
|
||||||
q = 'SELECT uuid FROM "user" WHERE uuid IS NOT NULL'
|
q = 'SELECT uuid FROM "user" WHERE uuid IS NOT NULL'
|
||||||
return {tuple(record)[0] for record in await cls.db.fetch(q)}
|
return {tuple(record)[0] for record in await cls.db.fetch(q)}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _values(self):
|
||||||
|
return (
|
||||||
|
self.mxid,
|
||||||
|
self.ktid,
|
||||||
|
self.uuid,
|
||||||
|
self.access_token,
|
||||||
|
self.refresh_token,
|
||||||
|
self.notice_room,
|
||||||
|
)
|
||||||
|
|
||||||
async def insert(self) -> None:
|
async def insert(self) -> None:
|
||||||
q = """
|
q = """
|
||||||
INSERT INTO "user" (mxid, ktid, uuid, access_token, refresh_token, notice_room)
|
INSERT INTO "user" (mxid, ktid, uuid, access_token, refresh_token, notice_room)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
"""
|
"""
|
||||||
await self.db.execute(
|
await self.db.execute(q, *self._values)
|
||||||
q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room
|
|
||||||
)
|
|
||||||
|
|
||||||
async def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
|
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
|
||||||
|
|
||||||
async def save(self) -> None:
|
async def save(self) -> None:
|
||||||
q = """
|
q = """
|
||||||
UPDATE "user" SET ktid=$1, uuid=$2, access_token=$3, refresh_token=$4, notice_room=$5
|
UPDATE "user" SET ktid=$2, uuid=$3, access_token=$4, refresh_token=$5, notice_room=$6
|
||||||
WHERE mxid=$6
|
WHERE mxid=$1
|
||||||
"""
|
"""
|
||||||
await self.db.execute(
|
await self.db.execute(q, *self._values)
|
||||||
q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
|
|
||||||
)
|
|
||||||
|
|
|
@ -122,7 +122,7 @@ bridge:
|
||||||
command_prefix: "!kt"
|
command_prefix: "!kt"
|
||||||
|
|
||||||
# Number of chats to sync (and create portals for) on startup/login.
|
# Number of chats to sync (and create portals for) on startup/login.
|
||||||
# Set 0 to disable automatic syncing.
|
# Set to 0 to disable automatic syncing, or -1 to sync as much as possible.
|
||||||
initial_chat_sync: 20
|
initial_chat_sync: 20
|
||||||
# Whether or not the KakaoTalk users of logged in Matrix users should be
|
# Whether or not the KakaoTalk users of logged in Matrix users should be
|
||||||
# invited to private chats when the user sends a message from another client.
|
# invited to private chats when the user sends a message from another client.
|
||||||
|
@ -188,11 +188,11 @@ bridge:
|
||||||
# usually needed to prevent rate limits and to allow timestamp massaging.
|
# usually needed to prevent rate limits and to allow timestamp massaging.
|
||||||
invite_own_puppet: true
|
invite_own_puppet: true
|
||||||
# Maximum number of messages to backfill initially.
|
# Maximum number of messages to backfill initially.
|
||||||
# Set to 0 to disable backfilling when creating portal.
|
# Set to 0 to disable backfilling when creating portal, or -1 to backfill as much as possible.
|
||||||
initial_limit: 0
|
initial_limit: 0
|
||||||
# Maximum number of messages to backfill if messages were missed while
|
# Maximum number of messages to backfill if messages were missed while
|
||||||
# the bridge was disconnected.
|
# the bridge was disconnected.
|
||||||
# Set to 0 to disable backfilling missed messages.
|
# Set to 0 to disable backfilling missed messages, or -1 to backfill as much as possible.
|
||||||
missed_limit: 1000
|
missed_limit: 1000
|
||||||
# If using double puppeting, should notifications be disabled
|
# If using double puppeting, should notifications be disabled
|
||||||
# while the initial backfill is in progress?
|
# while the initial backfill is in progress?
|
||||||
|
@ -213,7 +213,7 @@ bridge:
|
||||||
# The number of seconds that a disconnection can last without triggering an automatic re-sync
|
# The number of seconds that a disconnection can last without triggering an automatic re-sync
|
||||||
# and missed message backfilling when reconnecting.
|
# and missed message backfilling when reconnecting.
|
||||||
# Set to 0 to always re-sync, or -1 to never re-sync automatically.
|
# Set to 0 to always re-sync, or -1 to never re-sync automatically.
|
||||||
resync_max_disconnected_time: 5
|
#resync_max_disconnected_time: 5
|
||||||
# Should the bridge do a resync on startup?
|
# Should the bridge do a resync on startup?
|
||||||
sync_on_startup: true
|
sync_on_startup: true
|
||||||
# Whether or not temporary disconnections should send notices to the notice room.
|
# Whether or not temporary disconnections should send notices to the notice room.
|
||||||
|
|
|
@ -22,7 +22,7 @@ with any other potential backend.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, cast, Awaitable, Callable, Type, Optional, Union
|
from typing import TYPE_CHECKING, cast, Type, Optional, Union
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
@ -39,7 +39,6 @@ from ...rpc import RPCClient
|
||||||
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
|
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
|
||||||
from ..types.bson import Long
|
from ..types.bson import Long
|
||||||
from ..types.client.client_session import LoginResult
|
from ..types.client.client_session import LoginResult
|
||||||
from ..types.channel.channel_type import ChannelType
|
|
||||||
from ..types.chat.chat import Chatlog
|
from ..types.chat.chat import Chatlog
|
||||||
from ..types.oauth import OAuthCredential, OAuthInfo
|
from ..types.oauth import OAuthCredential, OAuthInfo
|
||||||
from ..types.request import (
|
from ..types.request import (
|
||||||
|
@ -63,7 +62,6 @@ except ImportError:
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from mautrix.types import JSON
|
from mautrix.types import JSON
|
||||||
from ...user import User
|
from ...user import User
|
||||||
from ...rpc.rpc import EventHandler
|
|
||||||
|
|
||||||
|
|
||||||
# TODO Consider defining an interface for this, with node/native backend as swappable implementations
|
# TODO Consider defining an interface for this, with node/native backend as swappable implementations
|
||||||
|
@ -80,7 +78,6 @@ class Client:
|
||||||
@classmethod
|
@classmethod
|
||||||
async def stop_cls(cls) -> None:
|
async def stop_cls(cls) -> None:
|
||||||
"""Stop and disconnect from the Node backend."""
|
"""Stop and disconnect from the Node backend."""
|
||||||
await cls._rpc_client.request("stop")
|
|
||||||
await cls._rpc_client.disconnect()
|
await cls._rpc_client.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,16 +99,17 @@ class Client:
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def register_device(cls, passcode: str, **req) -> None:
|
async def register_device(cls, passcode: str, **req: JSON) -> None:
|
||||||
"""Register a (fake) device that will be associated with the provided login credentials."""
|
"""Register a (fake) device that will be associated with the provided login credentials."""
|
||||||
await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)
|
await cls._api_request_void("register_device", passcode=passcode, is_secret=True, **req)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def login(cls, **req) -> OAuthCredential:
|
async def login(cls, **req: JSON) -> OAuthCredential:
|
||||||
"""
|
"""
|
||||||
Obtain a session token by logging in with user-provided credentials.
|
Obtain a session token by logging in with user-provided credentials.
|
||||||
Must have first called register_device with these credentials.
|
Must have first called register_device with these credentials.
|
||||||
"""
|
"""
|
||||||
|
# NOTE Actually returns a LoginData object, but this only needs an OAuthCredential
|
||||||
return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)
|
return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
@ -181,20 +179,21 @@ class Client:
|
||||||
self.user.oauth_credential = oauth_info.credential
|
self.user.oauth_credential = oauth_info.credential
|
||||||
await self.user.save()
|
await self.user.save()
|
||||||
|
|
||||||
async def start(self) -> LoginResult:
|
async def connect(self) -> LoginResult:
|
||||||
"""
|
"""
|
||||||
Start a new session by providing a token obtained from a prior login.
|
Start a new session by providing a token obtained from a prior login.
|
||||||
Receive a snapshot of account state in response.
|
Receive a snapshot of account state in response.
|
||||||
"""
|
"""
|
||||||
login_result = await self._api_user_request_result(LoginResult, "start")
|
login_result = await self._api_user_request_result(LoginResult, "connect")
|
||||||
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
|
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
|
||||||
|
# TODO Skip if handlers are already listening. But this is idempotent and thus probably safe
|
||||||
|
self._start_listen()
|
||||||
return login_result
|
return login_result
|
||||||
|
|
||||||
"""
|
async def disconnect(self) -> bool:
|
||||||
async def is_connected(self) -> bool:
|
connection_existed = await self._rpc_client.request("disconnect", mxid=self.user.mxid)
|
||||||
resp = await self._rpc_client.request("is_connected")
|
self._stop_listen()
|
||||||
return resp["is_connected"]
|
return connection_existed
|
||||||
"""
|
|
||||||
|
|
||||||
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
|
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
|
||||||
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
|
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
|
||||||
|
@ -232,14 +231,6 @@ class Client:
|
||||||
text=text
|
text=text
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start_listen(self) -> None:
|
|
||||||
# TODO Connect all listeners here?
|
|
||||||
await self._api_user_request_void("start_listen")
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
# TODO Stop all event handlers
|
|
||||||
await self._api_user_request_void("stop")
|
|
||||||
|
|
||||||
|
|
||||||
# TODO Combine these into one
|
# TODO Combine these into one
|
||||||
|
|
||||||
|
@ -272,19 +263,31 @@ class Client:
|
||||||
|
|
||||||
# region listeners
|
# region listeners
|
||||||
|
|
||||||
async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None:
|
async def _on_message(self, data: dict[str, JSON]) -> None:
|
||||||
async def wrapper(data: dict[str, JSON]) -> None:
|
await self.user.on_message(
|
||||||
await func(
|
Chatlog.deserialize(data["chatlog"]),
|
||||||
Chatlog.deserialize(data["chatlog"]),
|
Long.deserialize(data["channelId"]),
|
||||||
Long.deserialize(data["channelId"]),
|
data["channelType"]
|
||||||
data["channelType"]
|
)
|
||||||
)
|
|
||||||
|
|
||||||
self._add_user_handler("message", wrapper)
|
""" TODO
|
||||||
|
async def _on_receipt(self, data: Dict[str, JSON]) -> None:
|
||||||
|
await self.user.on_receipt(Receipt.deserialize(data["receipt"]))
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _add_user_handler(self, command: str, handler: EventHandler) -> str:
|
def _start_listen(self) -> None:
|
||||||
self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler)
|
# TODO Automate this somehow, like with a fancy enum
|
||||||
|
self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [self._on_message])
|
||||||
|
# TODO many more listeners
|
||||||
|
|
||||||
|
def _stop_listen(self) -> None:
|
||||||
|
# TODO Automate this somehow, like with a fancy enum
|
||||||
|
self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [])
|
||||||
|
# TODO many more listeners
|
||||||
|
|
||||||
|
def _get_user_cmd(self, command) -> str:
|
||||||
|
return f"{command}:{self.user.mxid}"
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
|
@ -79,6 +79,16 @@ class RootCommandResult(ResponseState):
|
||||||
"""For brevity, this also encompasses CommandResultFailed and CommandResultDoneVoid"""
|
"""For brevity, this also encompasses CommandResultFailed and CommandResultDoneVoid"""
|
||||||
success: bool
|
success: bool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, data: JSON) -> "RootCommandResult":
|
||||||
|
if not data or "success" not in data or "status" not in data:
|
||||||
|
return RootCommandResult(
|
||||||
|
success=True,
|
||||||
|
status=KnownDataStatusCode.SUCCESS
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().deserialize(data)
|
||||||
|
|
||||||
|
|
||||||
ResultType = TypeVar("ResultType", bound=Serializable)
|
ResultType = TypeVar("ResultType", bound=Serializable)
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Pattern, cast
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Pattern, cast
|
||||||
from collections import deque
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
@ -845,11 +844,6 @@ class Portal(DBPortal, BasePortal):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _add_kakaotalk_reply(
|
|
||||||
self, content: MessageEventContent, reply_to: None
|
|
||||||
) -> None:
|
|
||||||
self.log.info("TODO")
|
|
||||||
|
|
||||||
async def handle_remote_message(
|
async def handle_remote_message(
|
||||||
self,
|
self,
|
||||||
source: u.User,
|
source: u.User,
|
||||||
|
@ -969,7 +963,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
if not messages:
|
if not messages:
|
||||||
self.log.debug("Didn't get any messages from server")
|
self.log.debug("Didn't get any messages from server")
|
||||||
return
|
return
|
||||||
self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server")
|
self.log.debug(f"Got {len(messages)} message{'s' if len(messages) > 1 else ''} from server")
|
||||||
self._backfill_leave = set()
|
self._backfill_leave = set()
|
||||||
async with NotificationDisabler(self.mxid, source):
|
async with NotificationDisabler(self.mxid, source):
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
|
|
@ -1 +1,2 @@
|
||||||
from .rpc import RPCClient
|
from .rpc import RPCClient
|
||||||
|
from .types import RPCError
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable, Awaitable, List
|
from typing import Any, Callable, Awaitable
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
@ -39,7 +39,7 @@ class RPCClient:
|
||||||
_req_id: int
|
_req_id: int
|
||||||
_min_broadcast_id: int
|
_min_broadcast_id: int
|
||||||
_response_waiters: dict[int, asyncio.Future[JSON]]
|
_response_waiters: dict[int, asyncio.Future[JSON]]
|
||||||
_event_handlers: dict[str, List[EventHandler]]
|
_event_handlers: dict[str, list[EventHandler]]
|
||||||
_command_queue: asyncio.Queue
|
_command_queue: asyncio.Queue
|
||||||
|
|
||||||
def __init__(self, config: Config) -> None:
|
def __init__(self, config: Config) -> None:
|
||||||
|
@ -98,7 +98,13 @@ class RPCClient:
|
||||||
self._event_handlers.setdefault(method, []).append(handler)
|
self._event_handlers.setdefault(method, []).append(handler)
|
||||||
|
|
||||||
def remove_event_handler(self, method: str, handler: EventHandler) -> None:
|
def remove_event_handler(self, method: str, handler: EventHandler) -> None:
|
||||||
self._event_handlers.setdefault(method, []).remove(handler)
|
try:
|
||||||
|
self._event_handlers.setdefault(method, []).remove(handler)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
|
||||||
|
self._event_handlers[method] = handlers
|
||||||
|
|
||||||
async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
|
async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
|
||||||
if req_id > self._min_broadcast_id:
|
if req_id > self._min_broadcast_id:
|
||||||
|
|
|
@ -46,9 +46,7 @@ from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
|
||||||
from .kt.types.oauth import OAuthCredential
|
from .kt.types.oauth import OAuthCredential
|
||||||
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
|
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
|
||||||
|
|
||||||
METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels")
|
METRIC_CONNECT_AND_SYNC = Summary("bridge_sync_channels", "calls to connect_and_sync")
|
||||||
METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync")
|
|
||||||
METRIC_UNKNOWN_EVENT = Summary("bridge_on_unknown_event", "calls to on_unknown_event")
|
|
||||||
METRIC_MEMBERS_ADDED = Summary("bridge_on_members_added", "calls to on_members_added")
|
METRIC_MEMBERS_ADDED = Summary("bridge_on_members_added", "calls to on_members_added")
|
||||||
METRIC_MEMBER_REMOVED = Summary("bridge_on_member_removed", "calls to on_member_removed")
|
METRIC_MEMBER_REMOVED = Summary("bridge_on_member_removed", "calls to on_member_removed")
|
||||||
METRIC_TYPING = Summary("bridge_on_typing", "calls to on_typing")
|
METRIC_TYPING = Summary("bridge_on_typing", "calls to on_typing")
|
||||||
|
@ -58,7 +56,6 @@ METRIC_MESSAGE_UNSENT = Summary("bridge_on_unsent", "calls to on_unsent")
|
||||||
METRIC_MESSAGE_SEEN = Summary("bridge_on_message_seen", "calls to on_message_seen")
|
METRIC_MESSAGE_SEEN = Summary("bridge_on_message_seen", "calls to on_message_seen")
|
||||||
METRIC_TITLE_CHANGE = Summary("bridge_on_title_change", "calls to on_title_change")
|
METRIC_TITLE_CHANGE = Summary("bridge_on_title_change", "calls to on_title_change")
|
||||||
METRIC_AVATAR_CHANGE = Summary("bridge_on_avatar_change", "calls to on_avatar_change")
|
METRIC_AVATAR_CHANGE = Summary("bridge_on_avatar_change", "calls to on_avatar_change")
|
||||||
METRIC_THREAD_CHANGE = Summary("bridge_on_thread_change", "calls to on_thread_change")
|
|
||||||
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
|
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
|
||||||
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
|
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
|
||||||
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")
|
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")
|
||||||
|
@ -71,7 +68,6 @@ BridgeState.human_readable_errors.update(
|
||||||
"kt-reconnection-error": "Failed to reconnect to KakaoTalk",
|
"kt-reconnection-error": "Failed to reconnect to KakaoTalk",
|
||||||
"kt-connection-error": "KakaoTalk disconnected unexpectedly",
|
"kt-connection-error": "KakaoTalk disconnected unexpectedly",
|
||||||
"kt-auth-error": "Authentication error from KakaoTalk: {message}",
|
"kt-auth-error": "Authentication error from KakaoTalk: {message}",
|
||||||
"kt-start-error": "Startup error from KakaoTalk: {message}",
|
|
||||||
"kt-disconnected": None,
|
"kt-disconnected": None,
|
||||||
"logged-out": "You're not logged into KakaoTalk",
|
"logged-out": "You're not logged into KakaoTalk",
|
||||||
}
|
}
|
||||||
|
@ -79,7 +75,6 @@ BridgeState.human_readable_errors.update(
|
||||||
|
|
||||||
|
|
||||||
class User(DBUser, BaseUser):
|
class User(DBUser, BaseUser):
|
||||||
#temp_disconnect_notices: bool = True
|
|
||||||
shutdown: bool = False
|
shutdown: bool = False
|
||||||
config: Config
|
config: Config
|
||||||
|
|
||||||
|
@ -90,16 +85,13 @@ class User(DBUser, BaseUser):
|
||||||
|
|
||||||
_notice_room_lock: asyncio.Lock
|
_notice_room_lock: asyncio.Lock
|
||||||
_notice_send_lock: asyncio.Lock
|
_notice_send_lock: asyncio.Lock
|
||||||
command_status: dict | None
|
|
||||||
is_admin: bool
|
is_admin: bool
|
||||||
permission_level: str
|
permission_level: str
|
||||||
_is_logged_in: bool | None
|
_is_logged_in: bool | None
|
||||||
#_is_connected: bool | None
|
_is_connected: bool | None
|
||||||
#_connection_time: float
|
_connection_time: float
|
||||||
_prev_reconnect_fail_refresh: float
|
|
||||||
_db_instance: DBUser | None
|
_db_instance: DBUser | None
|
||||||
_sync_lock: SimpleLock
|
_sync_lock: SimpleLock
|
||||||
_is_refreshing: bool
|
|
||||||
_logged_in_info: ProfileStruct | None
|
_logged_in_info: ProfileStruct | None
|
||||||
_logged_in_info_time: float
|
_logged_in_info_time: float
|
||||||
|
|
||||||
|
@ -124,7 +116,6 @@ class User(DBUser, BaseUser):
|
||||||
self.notice_room = notice_room
|
self.notice_room = notice_room
|
||||||
self._notice_room_lock = asyncio.Lock()
|
self._notice_room_lock = asyncio.Lock()
|
||||||
self._notice_send_lock = asyncio.Lock()
|
self._notice_send_lock = asyncio.Lock()
|
||||||
self.command_status = None
|
|
||||||
(
|
(
|
||||||
self.relay_whitelisted,
|
self.relay_whitelisted,
|
||||||
self.is_whitelisted,
|
self.is_whitelisted,
|
||||||
|
@ -132,13 +123,11 @@ class User(DBUser, BaseUser):
|
||||||
self.permission_level,
|
self.permission_level,
|
||||||
) = self.config.get_permissions(mxid)
|
) = self.config.get_permissions(mxid)
|
||||||
self._is_logged_in = None
|
self._is_logged_in = None
|
||||||
#self._is_connected = None
|
self._is_connected = None
|
||||||
#self._connection_time = time.monotonic()
|
self._connection_time = time.monotonic()
|
||||||
self._prev_reconnect_fail_refresh = time.monotonic()
|
|
||||||
self._sync_lock = SimpleLock(
|
self._sync_lock = SimpleLock(
|
||||||
"Waiting for thread sync to finish before handling %s", log=self.log
|
"Waiting for thread sync to finish before handling %s", log=self.log
|
||||||
)
|
)
|
||||||
self._is_refreshing = False
|
|
||||||
self._logged_in_info = None
|
self._logged_in_info = None
|
||||||
self._logged_in_info_time = 0
|
self._logged_in_info_time = 0
|
||||||
|
|
||||||
|
@ -150,10 +139,8 @@ class User(DBUser, BaseUser):
|
||||||
cls.config = bridge.config
|
cls.config = bridge.config
|
||||||
cls.az = bridge.az
|
cls.az = bridge.az
|
||||||
cls.loop = bridge.loop
|
cls.loop = bridge.loop
|
||||||
#cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
|
|
||||||
return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())
|
return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())
|
||||||
|
|
||||||
"""
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool | None:
|
def is_connected(self) -> bool | None:
|
||||||
return self._is_connected
|
return self._is_connected
|
||||||
|
@ -167,11 +154,11 @@ class User(DBUser, BaseUser):
|
||||||
@property
|
@property
|
||||||
def connection_time(self) -> float:
|
def connection_time(self) -> float:
|
||||||
return self._connection_time
|
return self._connection_time
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_state(self) -> bool:
|
def has_state(self) -> bool:
|
||||||
return bool(self.uuid and self.ktid and self.access_token and self.refresh_token)
|
# TODO If more state is needed, consider returning a saved LoginResult
|
||||||
|
return bool(self.access_token and self.refresh_token)
|
||||||
|
|
||||||
# region Database getters
|
# region Database getters
|
||||||
|
|
||||||
|
@ -233,11 +220,13 @@ class User(DBUser, BaseUser):
|
||||||
async def get_uuid(self, force: bool = False) -> str:
|
async def get_uuid(self, force: bool = False) -> str:
|
||||||
if self.uuid is None or force:
|
if self.uuid is None or force:
|
||||||
self.uuid = await self._generate_uuid()
|
self.uuid = await self._generate_uuid()
|
||||||
|
# TODO Maybe don't save yet
|
||||||
await self.save()
|
await self.save()
|
||||||
return self.uuid
|
return self.uuid
|
||||||
|
|
||||||
async def _generate_uuid(self) -> str:
|
@classmethod
|
||||||
return await Client.generate_uuid(await self.get_all_uuids())
|
async def _generate_uuid(cls) -> str:
|
||||||
|
return await Client.generate_uuid(await super().get_all_uuids())
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
@ -285,8 +274,7 @@ class User(DBUser, BaseUser):
|
||||||
self._logged_in_info_time = time.monotonic()
|
self._logged_in_info_time = time.monotonic()
|
||||||
self._track_metric(METRIC_LOGGED_IN, True)
|
self._track_metric(METRIC_LOGGED_IN, True)
|
||||||
self._is_logged_in = True
|
self._is_logged_in = True
|
||||||
#self.is_connected = None
|
self.is_connected = None
|
||||||
self.stop_listen()
|
|
||||||
asyncio.create_task(self.post_login(is_startup=is_startup))
|
asyncio.create_task(self.post_login(is_startup=is_startup))
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
@ -310,10 +298,9 @@ class User(DBUser, BaseUser):
|
||||||
) -> ProfileStruct:
|
) -> ProfileStruct:
|
||||||
if not client:
|
if not client:
|
||||||
client = self.client
|
client = self.client
|
||||||
# TODO Retry network connection failures here, or in the client?
|
# TODO Retry network connection failures here, or in the client (like token refreshes are)?
|
||||||
try:
|
try:
|
||||||
return await client.fetch_logged_in_user()
|
return await client.fetch_logged_in_user()
|
||||||
# NOTE Not catching InvalidAccessToken here, as client handles it & tries to refresh the token
|
|
||||||
except AuthenticationRequired as e:
|
except AuthenticationRequired as e:
|
||||||
if action != "restore session":
|
if action != "restore session":
|
||||||
await self._send_reset_notice(e)
|
await self._send_reset_notice(e)
|
||||||
|
@ -352,7 +339,7 @@ class User(DBUser, BaseUser):
|
||||||
state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
|
state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
await self.reload_session(event_id, retries - 1)
|
await self.reload_session(event_id, retries - 1, is_startup)
|
||||||
else:
|
else:
|
||||||
await self.send_bridge_notice(
|
await self.send_bridge_notice(
|
||||||
notice,
|
notice,
|
||||||
|
@ -362,44 +349,43 @@ class User(DBUser, BaseUser):
|
||||||
error_code="kt-reconnection-error",
|
error_code="kt-reconnection-error",
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
self.log.exception("Error connecting to KakaoTalk")
|
||||||
await self.send_bridge_notice(
|
await self.send_bridge_notice(
|
||||||
"Failed to connect to KakaoTalk: unknown error (see logs for more details)",
|
"Failed to connect to KakaoTalk: unknown error (see logs for more details)",
|
||||||
edit=event_id,
|
edit=event_id,
|
||||||
state_event=BridgeStateEvent.UNKNOWN_ERROR,
|
state_event=BridgeStateEvent.UNKNOWN_ERROR,
|
||||||
error_code="kt-reconnection-error",
|
error_code="kt-reconnection-error",
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
self._is_refreshing = False
|
|
||||||
|
|
||||||
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> bool:
|
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
|
||||||
ok = True
|
if self.client:
|
||||||
self.stop_listen()
|
# TODO Look for a logout API call
|
||||||
if self.has_state:
|
was_connected = await self.client.disconnect()
|
||||||
# TODO Log out of KakaoTalk if an API exists for it
|
if was_connected != self._is_connected:
|
||||||
pass
|
self.log.warn(
|
||||||
|
f"Node backend was{' not' if not was_connected else ''} connected, "
|
||||||
|
f"but we thought it was{' not' if not self._is_connected else ''}")
|
||||||
if remove_ktid:
|
if remove_ktid:
|
||||||
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
|
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
|
||||||
self._track_metric(METRIC_LOGGED_IN, False)
|
self._track_metric(METRIC_LOGGED_IN, False)
|
||||||
self._is_logged_in = False
|
|
||||||
#self.is_connected = None
|
|
||||||
if self.client:
|
|
||||||
await self.client.stop()
|
|
||||||
self.client = None
|
|
||||||
|
|
||||||
if self.ktid and remove_ktid:
|
|
||||||
#await UserPortal.delete_all(self.ktid)
|
|
||||||
del self.by_ktid[self.ktid]
|
|
||||||
self.ktid = None
|
|
||||||
|
|
||||||
if reset_device:
|
if reset_device:
|
||||||
self.uuid = await self._generate_uuid()
|
self.uuid = await self._generate_uuid()
|
||||||
self.access_token = None
|
self.access_token = None
|
||||||
self.refresh_token = None
|
self.refresh_token = None
|
||||||
|
|
||||||
await self.save()
|
self._is_logged_in = False
|
||||||
return ok
|
self.is_connected = None
|
||||||
|
self.client = None
|
||||||
|
|
||||||
async def post_login(self, is_startup: bool) -> bool:
|
if self.ktid and remove_ktid:
|
||||||
|
#await UserPortal.delete_all(self.ktid)
|
||||||
|
del self.by_ktid[self.ktid]
|
||||||
|
self.ktid = None
|
||||||
|
|
||||||
|
await self.save()
|
||||||
|
|
||||||
|
async def post_login(self, is_startup: bool) -> None:
|
||||||
self.log.info("Running post-login actions")
|
self.log.info("Running post-login actions")
|
||||||
self._add_to_cache()
|
self._add_to_cache()
|
||||||
|
|
||||||
|
@ -412,12 +398,27 @@ class User(DBUser, BaseUser):
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to automatically enable custom puppet")
|
self.log.exception("Failed to automatically enable custom puppet")
|
||||||
|
|
||||||
assert self.client
|
# TODO Check if things break when a live message comes in during syncing
|
||||||
|
if self.config["bridge.sync_on_startup"] or not is_startup:
|
||||||
|
sync_count = self.config["bridge.initial_chat_sync"]
|
||||||
|
else:
|
||||||
|
sync_count = None
|
||||||
|
await self.connect_and_sync(sync_count)
|
||||||
|
|
||||||
|
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
|
||||||
|
return {
|
||||||
|
pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
|
||||||
|
async for portal in po.Portal.get_all_by_receiver(self.ktid)
|
||||||
|
if portal.mxid
|
||||||
|
}
|
||||||
|
|
||||||
|
@async_time(METRIC_CONNECT_AND_SYNC)
|
||||||
|
async def connect_and_sync(self, sync_count: int | None) -> bool:
|
||||||
|
# TODO Look for a way to sync all channels without (re-)logging in
|
||||||
try:
|
try:
|
||||||
# TODO if not is_startup, close existing listeners
|
login_result = await self.client.connect()
|
||||||
login_result = await self.client.start()
|
await self.push_bridge_state(BridgeStateEvent.CONNECTED)
|
||||||
await self._sync_channels(login_result, is_startup)
|
await self._sync_channels(login_result, sync_count)
|
||||||
self.start_listen()
|
|
||||||
return True
|
return True
|
||||||
except AuthenticationRequired as e:
|
except AuthenticationRequired as e:
|
||||||
await self.send_bridge_notice(
|
await self.send_bridge_notice(
|
||||||
|
@ -429,41 +430,34 @@ class User(DBUser, BaseUser):
|
||||||
)
|
)
|
||||||
await self.logout(remove_ktid=False)
|
await self.logout(remove_ktid=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.exception("Failed to start client")
|
self.log.exception("Failed to connect and sync")
|
||||||
await self.send_bridge_notice(
|
await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e))
|
||||||
f"Got error from KakaoTalk:\n\n> {e!s}\n\n",
|
|
||||||
important=True,
|
|
||||||
state_event=BridgeStateEvent.UNKNOWN_ERROR,
|
|
||||||
error_code="kt-start-error",
|
|
||||||
error_message=str(e),
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
|
async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None:
|
||||||
return {
|
if sync_count is None:
|
||||||
pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
|
sync_count = self.config["bridge.initial_chat_sync"]
|
||||||
async for portal in po.Portal.get_all_by_receiver(self.ktid)
|
if not sync_count:
|
||||||
if portal.mxid
|
self.log.debug("Skipping channel syncing")
|
||||||
}
|
|
||||||
|
|
||||||
@async_time(METRIC_SYNC_CHANNELS)
|
|
||||||
async def _sync_channels(self, login_result: LoginResult, is_startup: bool) -> None:
|
|
||||||
# TODO Look for a way to sync all channels without (re-)logging in
|
|
||||||
sync_count = self.config["bridge.initial_chat_sync"]
|
|
||||||
if sync_count <= 0 or not self.config["bridge.sync_on_startup"] and is_startup:
|
|
||||||
self.log.debug(f"Skipping channel syncing{' on startup' if sync_count > 0 else ''}")
|
|
||||||
return
|
return
|
||||||
if not login_result.channelList:
|
if not login_result.channelList:
|
||||||
self.log.debug("No channels to sync")
|
self.log.debug("No channels to sync")
|
||||||
return
|
return
|
||||||
# TODO What about removed channels? Don't early-return then
|
# TODO What about removed channels? Don't early-return then
|
||||||
|
|
||||||
sync_count = min(sync_count, len(login_result.channelList))
|
num_channels = len(login_result.channelList)
|
||||||
|
sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels)
|
||||||
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
|
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
|
||||||
self.log.debug(f"Syncing {sync_count} of {len(login_result.channelList)} channels...")
|
self.log.debug(f"Syncing {sync_count} of {num_channels} channels...")
|
||||||
for channel_item in login_result.channelList[:sync_count]:
|
for channel_item in login_result.channelList[:sync_count]:
|
||||||
# TODO try-except here, above, below?
|
try:
|
||||||
await self._sync_channel(channel_item)
|
await self._sync_channel(channel_item)
|
||||||
|
except AuthenticationRequired:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.log.exception(f"Failed to sync channel {channel_item.channel.channelId}")
|
||||||
|
|
||||||
|
await self.update_direct_chats()
|
||||||
|
|
||||||
async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None:
|
async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None:
|
||||||
channel_data = channel_item.channel
|
channel_data = channel_item.channel
|
||||||
|
@ -573,18 +567,12 @@ class User(DBUser, BaseUser):
|
||||||
state.remote_name = puppet.name
|
state.remote_name = puppet.name
|
||||||
|
|
||||||
async def get_bridge_states(self) -> list[BridgeState]:
|
async def get_bridge_states(self) -> list[BridgeState]:
|
||||||
self.log.info("TODO: get_bridge_states")
|
|
||||||
return []
|
|
||||||
"""
|
|
||||||
if not self.state:
|
if not self.state:
|
||||||
return []
|
return []
|
||||||
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
|
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
state.state_event = BridgeStateEvent.CONNECTED
|
state.state_event = BridgeStateEvent.CONNECTED
|
||||||
elif self._is_refreshing or self.mqtt:
|
|
||||||
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
|
|
||||||
return [state]
|
return [state]
|
||||||
"""
|
|
||||||
|
|
||||||
async def get_puppet(self) -> pu.Puppet | None:
|
async def get_puppet(self) -> pu.Puppet | None:
|
||||||
if not self.ktid:
|
if not self.ktid:
|
||||||
|
@ -593,36 +581,10 @@ class User(DBUser, BaseUser):
|
||||||
|
|
||||||
# region KakaoTalk event handling
|
# region KakaoTalk event handling
|
||||||
|
|
||||||
def start_listen(self) -> None:
|
|
||||||
self.listen_task = asyncio.create_task(self._try_listen())
|
|
||||||
|
|
||||||
def _disconnect_listener_after_error(self) -> None:
|
|
||||||
self.log.info("TODO: _disconnect_listener_after_error")
|
|
||||||
|
|
||||||
async def _try_listen(self) -> None:
|
|
||||||
try:
|
|
||||||
# TODO Pass all listeners to start_listen instead of registering them one-by-one?
|
|
||||||
await self.client.start_listen()
|
|
||||||
await self.client.on_message(self.on_message)
|
|
||||||
# TODO Handle auth errors specially?
|
|
||||||
#except AuthenticationRequired as e:
|
|
||||||
except Exception:
|
|
||||||
#self.is_connected = False
|
|
||||||
self.log.exception("Fatal error in listener")
|
|
||||||
await self.send_bridge_notice(
|
|
||||||
"Fatal error in listener (see logs for more info)",
|
|
||||||
state_event=BridgeStateEvent.UNKNOWN_ERROR,
|
|
||||||
important=True,
|
|
||||||
error_code="kt-connection-error",
|
|
||||||
)
|
|
||||||
self._disconnect_listener_after_error()
|
|
||||||
|
|
||||||
def stop_listen(self) -> None:
|
|
||||||
self.log.info("TODO: stop_listen")
|
|
||||||
|
|
||||||
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
|
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
|
||||||
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
|
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
|
||||||
self.oauth_credential = oauth_credential
|
self.oauth_credential = oauth_credential
|
||||||
|
await self.push_bridge_state(BridgeStateEvent.CONNECTING)
|
||||||
self.client = Client(self, log=self.log.getChild("ktclient"))
|
self.client = Client(self, log=self.log.getChild("ktclient"))
|
||||||
await self.save()
|
await self.save()
|
||||||
self._is_logged_in = True
|
self._is_logged_in = True
|
||||||
|
@ -631,7 +593,6 @@ class User(DBUser, BaseUser):
|
||||||
self._logged_in_info_time = time.monotonic()
|
self._logged_in_info_time = time.monotonic()
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to fetch post-login info")
|
self.log.exception("Failed to fetch post-login info")
|
||||||
self.stop_listen()
|
|
||||||
asyncio.create_task(self.post_login(is_startup=True))
|
asyncio.create_task(self.post_login(is_startup=True))
|
||||||
|
|
||||||
@async_time(METRIC_MESSAGE)
|
@async_time(METRIC_MESSAGE)
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
// You should have received a copy of the GNU Affero General Public License
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
import { Long } from "bson"
|
import { Long } from "bson"
|
||||||
import { emitLines, promisify } from "./util.js"
|
|
||||||
import {
|
import {
|
||||||
AuthApiClient,
|
AuthApiClient,
|
||||||
OAuthApiClient,
|
OAuthApiClient,
|
||||||
|
@ -28,6 +27,8 @@ const { KnownChatType } = chat
|
||||||
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
|
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
|
||||||
/** @typedef {import("./clientmanager.js").default} ClientManager} */
|
/** @typedef {import("./clientmanager.js").default} ClientManager} */
|
||||||
|
|
||||||
|
import { emitLines, promisify } from "./util.js"
|
||||||
|
|
||||||
|
|
||||||
class UserClient {
|
class UserClient {
|
||||||
static #initializing = false
|
static #initializing = false
|
||||||
|
@ -36,7 +37,7 @@ class UserClient {
|
||||||
get talkClient() { return this.#talkClient }
|
get talkClient() { return this.#talkClient }
|
||||||
|
|
||||||
/** @type {ServiceApiClient} */
|
/** @type {ServiceApiClient} */
|
||||||
#serviceClient = null
|
#serviceClient
|
||||||
get serviceClient() { return this.#serviceClient }
|
get serviceClient() { return this.#serviceClient }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -80,7 +81,6 @@ class UserClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
export default class PeerClient {
|
export default class PeerClient {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {ClientManager} manager
|
* @param {ClientManager} manager
|
||||||
* @param {import("net").Socket} socket
|
* @param {import("net").Socket} socket
|
||||||
|
@ -136,6 +136,7 @@ export default class PeerClient {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
this.stopped = true
|
this.stopped = true
|
||||||
|
this.#closeUsers()
|
||||||
try {
|
try {
|
||||||
await this.#write({ id: --this.notificationID, command: "quit", error })
|
await this.#write({ id: --this.notificationID, command: "quit", error })
|
||||||
await promisify(cb => this.socket.end(cb))
|
await promisify(cb => this.socket.end(cb))
|
||||||
|
@ -145,20 +146,21 @@ export default class PeerClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
handleEnd = async () => {
|
handleEnd = () => {
|
||||||
// TODO Persist clients across bridge disconnections.
|
this.stopped = true
|
||||||
// But then have to queue received events until bridge acks them!
|
this.#closeUsers()
|
||||||
|
if (this.peerID && this.manager.clients.get(this.peerID) === this) {
|
||||||
|
this.manager.clients.delete(this.peerID)
|
||||||
|
}
|
||||||
|
this.log(`Connection closed (peer: ${this.peerID})`)
|
||||||
|
}
|
||||||
|
|
||||||
|
#closeUsers() {
|
||||||
this.log("Closing all API clients for", this.peerID)
|
this.log("Closing all API clients for", this.peerID)
|
||||||
for (const userClient of this.userClients.values()) {
|
for (const userClient of this.userClients.values()) {
|
||||||
userClient.close()
|
userClient.close()
|
||||||
}
|
}
|
||||||
this.userClients.clear()
|
this.userClients.clear()
|
||||||
|
|
||||||
this.stopped = true
|
|
||||||
if (this.peerID && this.manager.clients.get(this.peerID) === this) {
|
|
||||||
this.manager.clients.delete(this.peerID)
|
|
||||||
}
|
|
||||||
this.log(`Connection closed (peer: ${this.peerID})`)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -178,6 +180,7 @@ export default class PeerClient {
|
||||||
* @param {Object} req.form
|
* @param {Object} req.form
|
||||||
*/
|
*/
|
||||||
registerDevice = async (req) => {
|
registerDevice = async (req) => {
|
||||||
|
// TODO Look for a deregister API call
|
||||||
const authClient = await this.#createAuthClient(req.uuid)
|
const authClient = await this.#createAuthClient(req.uuid)
|
||||||
return await authClient.registerDevice(req.form, req.passcode, true)
|
return await authClient.registerDevice(req.form, req.passcode, true)
|
||||||
}
|
}
|
||||||
|
@ -192,6 +195,7 @@ export default class PeerClient {
|
||||||
* request failed, its status is stored here.
|
* request failed, its status is stored here.
|
||||||
*/
|
*/
|
||||||
handleLogin = async (req) => {
|
handleLogin = async (req) => {
|
||||||
|
// TODO Look for a logout API call
|
||||||
const authClient = await this.#createAuthClient(req.uuid)
|
const authClient = await this.#createAuthClient(req.uuid)
|
||||||
const loginRes = await authClient.login(req.form, true)
|
const loginRes = await authClient.login(req.form, true)
|
||||||
if (loginRes.status === KnownAuthStatusCode.DEVICE_NOT_REGISTERED) {
|
if (loginRes.status === KnownAuthStatusCode.DEVICE_NOT_REGISTERED) {
|
||||||
|
@ -226,6 +230,15 @@ export default class PeerClient {
|
||||||
return userClient
|
return userClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unchecked lookup of a UserClient for a given mxid.
|
||||||
|
* @param {string} mxid
|
||||||
|
* @returns {UserClient | undefined}
|
||||||
|
*/
|
||||||
|
#tryGetUser(mxid) {
|
||||||
|
return this.userClients.get(mxid)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the service client for the specified user ID, or create
|
* Get the service client for the specified user ID, or create
|
||||||
* and return a new service client if no user ID is provided.
|
* and return a new service client if no user ID is provided.
|
||||||
|
@ -233,7 +246,7 @@ export default class PeerClient {
|
||||||
* @param {OAuthCredential} oauth_credential
|
* @param {OAuthCredential} oauth_credential
|
||||||
*/
|
*/
|
||||||
async #getServiceClient(mxid, oauth_credential) {
|
async #getServiceClient(mxid, oauth_credential) {
|
||||||
return this.userClients.get(mxid)?.serviceClient ||
|
return this.#tryGetUser(mxid)?.serviceClient ||
|
||||||
await ServiceApiClient.create(oauth_credential)
|
await ServiceApiClient.create(oauth_credential)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,26 +264,15 @@ export default class PeerClient {
|
||||||
* @param {string} req.mxid
|
* @param {string} req.mxid
|
||||||
* @param {OAuthCredential} req.oauth_credential
|
* @param {OAuthCredential} req.oauth_credential
|
||||||
*/
|
*/
|
||||||
handleStart = async (req) => {
|
handleConnect = async (req) => {
|
||||||
// TODO Don't re-login if possible. But must still return a LoginResult!
|
// TODO Don't re-login if possible. But must still return a LoginResult!
|
||||||
{
|
this.handleDisconnect(req)
|
||||||
const oldUserClient = this.userClients.get(req.mxid)
|
|
||||||
if (oldUserClient !== undefined) {
|
|
||||||
oldUserClient.close()
|
|
||||||
this.userClients.delete(req.mxid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const userClient = await UserClient.create(req.mxid, req.oauth_credential)
|
const userClient = await UserClient.create(req.mxid, req.oauth_credential)
|
||||||
const res = await userClient.talkClient.login(req.oauth_credential)
|
const res = await userClient.talkClient.login(req.oauth_credential)
|
||||||
if (!res.success) return res
|
if (!res.success) return res
|
||||||
|
|
||||||
this.userClients.set(req.mxid, userClient)
|
this.userClients.set(req.mxid, userClient)
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
startListen = async (req) => {
|
|
||||||
const userClient = this.#getUser(req.mxid)
|
|
||||||
|
|
||||||
userClient.talkClient.on("chat", (data, channel) => {
|
userClient.talkClient.on("chat", (data, channel) => {
|
||||||
this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`)
|
this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`)
|
||||||
|
@ -291,7 +293,22 @@ export default class PeerClient {
|
||||||
})
|
})
|
||||||
*/
|
*/
|
||||||
|
|
||||||
return this.#voidCommandResult
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {Object} req
|
||||||
|
* @param {string} req.mxid
|
||||||
|
*/
|
||||||
|
handleDisconnect = (req) => {
|
||||||
|
const userClient = this.#tryGetUser(req.mxid)
|
||||||
|
if (!!userClient) {
|
||||||
|
userClient.close()
|
||||||
|
this.userClients.delete(req.mxid)
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -368,16 +385,6 @@ export default class PeerClient {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @param {Object} req
|
|
||||||
* @param {string} req.mxid
|
|
||||||
*/
|
|
||||||
handleStop = async (req) => {
|
|
||||||
this.#getUser(req.mxid).close()
|
|
||||||
this.userClients.delete(req.mxid)
|
|
||||||
return this.#voidCommandResult
|
|
||||||
}
|
|
||||||
|
|
||||||
#makeCommandResult(result) {
|
#makeCommandResult(result) {
|
||||||
return {
|
return {
|
||||||
success: true,
|
success: true,
|
||||||
|
@ -386,11 +393,6 @@ export default class PeerClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#voidCommandResult = {
|
|
||||||
success: true,
|
|
||||||
status: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
handleUnknownCommand = () => {
|
handleUnknownCommand = () => {
|
||||||
throw new Error("Unknown command")
|
throw new Error("Unknown command")
|
||||||
}
|
}
|
||||||
|
@ -431,9 +433,6 @@ export default class PeerClient {
|
||||||
this.log("Ignoring old request", req.id)
|
this.log("Ignoring old request", req.id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if (req.command != "is_connected") {
|
|
||||||
this.log("Received request", req.id, "with command", req.command)
|
|
||||||
}
|
|
||||||
this.maxCommandID = req.id
|
this.maxCommandID = req.id
|
||||||
let handler
|
let handler
|
||||||
if (!this.peerID) {
|
if (!this.peerID) {
|
||||||
|
@ -449,21 +448,18 @@ export default class PeerClient {
|
||||||
handler = this.handleRegister
|
handler = this.handleRegister
|
||||||
} else {
|
} else {
|
||||||
handler = {
|
handler = {
|
||||||
// TODO Subclass / object for KakaoTalk-specific handlers?
|
// TODO Wrapper for per-user commands
|
||||||
start: this.handleStart,
|
|
||||||
stop: this.handleStop,
|
|
||||||
disconnect: () => this.stop(),
|
|
||||||
login: this.handleLogin,
|
|
||||||
renew: this.handleRenew,
|
|
||||||
generate_uuid: util.randomAndroidSubDeviceUUID,
|
generate_uuid: util.randomAndroidSubDeviceUUID,
|
||||||
register_device: this.registerDevice,
|
register_device: this.registerDevice,
|
||||||
start_listen: this.startListen,
|
login: this.handleLogin,
|
||||||
|
renew: this.handleRenew,
|
||||||
|
connect: this.handleConnect,
|
||||||
|
disconnect: this.handleDisconnect,
|
||||||
get_own_profile: this.getOwnProfile,
|
get_own_profile: this.getOwnProfile,
|
||||||
|
get_profile: this.getProfile,
|
||||||
get_portal_channel_info: this.getPortalChannelInfo,
|
get_portal_channel_info: this.getPortalChannelInfo,
|
||||||
get_chats: this.getChats,
|
get_chats: this.getChats,
|
||||||
get_profile: this.getProfile,
|
|
||||||
send_message: this.sendMessage,
|
send_message: this.sendMessage,
|
||||||
//is_connected: async () => ({ is_connected: !await this.puppet.isDisconnected() }),
|
|
||||||
}[req.command] || this.handleUnknownCommand
|
}[req.command] || this.handleUnknownCommand
|
||||||
}
|
}
|
||||||
const resp = { id: req.id }
|
const resp = { id: req.id }
|
||||||
|
@ -482,6 +478,7 @@ export default class PeerClient {
|
||||||
resp.command = "error"
|
resp.command = "error"
|
||||||
resp.error = err.toString()
|
resp.error = err.toString()
|
||||||
this.log(`Error handling request ${resp.id} ${err.stack}`)
|
this.log(`Error handling request ${resp.id} ${err.stack}`)
|
||||||
|
// TODO Check if session is broken. If it is, close the PeerClient
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
await this.#write(resp)
|
await this.#write(resp)
|
||||||
|
|
|
@ -20,6 +20,7 @@ import path from "path"
|
||||||
import PeerClient from "./client.js"
|
import PeerClient from "./client.js"
|
||||||
import { promisify } from "./util.js"
|
import { promisify } from "./util.js"
|
||||||
|
|
||||||
|
|
||||||
export default class ClientManager {
|
export default class ClientManager {
|
||||||
constructor(listenConfig) {
|
constructor(listenConfig) {
|
||||||
this.listenConfig = listenConfig
|
this.listenConfig = listenConfig
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
import process from "process"
|
import process from "process"
|
||||||
import fs from "fs"
|
import fs from "fs"
|
||||||
import sd from "systemd-daemon"
|
|
||||||
|
|
||||||
import arg from "arg"
|
import arg from "arg"
|
||||||
|
import sd from "systemd-daemon"
|
||||||
|
|
||||||
import ClientManager from "./clientmanager.js"
|
import ClientManager from "./clientmanager.js"
|
||||||
|
|
||||||
|
|
||||||
const args = arg({
|
const args = arg({
|
||||||
"--config": String,
|
"--config": String,
|
||||||
"-c": "--config",
|
"-c": "--config",
|
||||||
|
@ -31,10 +32,10 @@ const configPath = args["--config"] || "config.json"
|
||||||
console.log("[Main] Reading config from", configPath)
|
console.log("[Main] Reading config from", configPath)
|
||||||
const config = JSON.parse(fs.readFileSync(configPath).toString())
|
const config = JSON.parse(fs.readFileSync(configPath).toString())
|
||||||
|
|
||||||
const api = new ClientManager(config.listen)
|
const manager = new ClientManager(config.listen)
|
||||||
|
|
||||||
function stop() {
|
function stop() {
|
||||||
api.stop().then(() => {
|
manager.stop().then(() => {
|
||||||
console.log("[Main] Everything stopped")
|
console.log("[Main] Everything stopped")
|
||||||
process.exit(0)
|
process.exit(0)
|
||||||
}, err => {
|
}, err => {
|
||||||
|
@ -43,7 +44,7 @@ function stop() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
api.start().then(() => {
|
manager.start().then(() => {
|
||||||
process.once("SIGINT", stop)
|
process.once("SIGINT", stop)
|
||||||
process.once("SIGTERM", stop)
|
process.once("SIGTERM", stop)
|
||||||
sd.notify("READY=1")
|
sd.notify("READY=1")
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
aiohttp>=3,<4
|
aiohttp>=3,<4
|
||||||
asyncpg>=0.20,<0.26
|
asyncpg>=0.20,<0.26
|
||||||
bson>=0.5,<0.6
|
|
||||||
commonmark>=0.8,<0.10
|
commonmark>=0.8,<0.10
|
||||||
mautrix==0.15.0rc4
|
mautrix==0.15.0rc4
|
||||||
pycryptodome>=3,<4
|
pycryptodome>=3,<4
|
||||||
|
|
Loading…
Reference in New Issue