From 0c9550841c8255bb6ea3b9f6d4df7ae485c477c9 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Wed, 9 Mar 2022 02:25:28 -0500 Subject: [PATCH] Puppets and backfilling --- matrix_appservice_kakaotalk/commands/conn.py | 12 ++ matrix_appservice_kakaotalk/db/message.py | 88 ++++++------ .../example-config.yaml | 1 - .../kt/client/client.py | 39 ++++-- .../kt/client/types.py | 31 +++- matrix_appservice_kakaotalk/kt/types/bson.py | 33 ++++- .../kt/types/channel/channel_type.py | 6 +- .../kt/types/request.py | 21 ++- matrix_appservice_kakaotalk/portal.py | 132 ++++++++++++++++-- matrix_appservice_kakaotalk/puppet.py | 15 +- matrix_appservice_kakaotalk/user.py | 3 +- node/src/client.js | 36 +++-- 12 files changed, 305 insertions(+), 112 deletions(-) diff --git a/matrix_appservice_kakaotalk/commands/conn.py b/matrix_appservice_kakaotalk/commands/conn.py index 00446c4..b390de0 100644 --- a/matrix_appservice_kakaotalk/commands/conn.py +++ b/matrix_appservice_kakaotalk/commands/conn.py @@ -99,3 +99,15 @@ async def ping(evt: CommandEvent) -> None: async def refresh(evt: CommandEvent) -> None: await evt.sender.refresh(force_notice=True) """ + + +@command_handler( + needs_auth=True, + management_only=True, + help_section=SECTION_CONNECTION, + help_text="Resync chats", +) +async def sync(evt: CommandEvent) -> None: + await evt.mark_read() + await evt.sender.post_login(is_startup=False) + await evt.reply("Sync complete") diff --git a/matrix_appservice_kakaotalk/db/message.py b/matrix_appservice_kakaotalk/db/message.py index 847740f..79c69b2 100644 --- a/matrix_appservice_kakaotalk/db/message.py +++ b/matrix_appservice_kakaotalk/db/message.py @@ -23,6 +23,8 @@ from attr import dataclass from mautrix.types import EventID, RoomID from mautrix.util.async_db import Database +from ..kt.types.bson import Long, StrLong + fake_db = Database.create("") if TYPE_CHECKING else None @@ -30,46 +32,44 @@ fake_db = Database.create("") if TYPE_CHECKING else None class Message: db: ClassVar[Database] = fake_db + # TODO Store all Long values as the same type mxid: EventID mx_room: RoomID - ktid: str | None - kt_txn_id: int | None + ktid: Long index: int - kt_chat: int - kt_receiver: int - kt_sender: int + kt_chat: Long + kt_receiver: Long timestamp: int @classmethod def _from_row(cls, row: Record | None) -> Message | None: - if row is None: - return None - return cls(**row) - - columns = 'mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, kt_sender, timestamp' + data = {**row} + ktid = data.pop("ktid") + kt_chat = data.pop("kt_chat") + kt_receiver = data.pop("kt_receiver") + return cls( + **data, + ktid=StrLong(ktid), + kt_chat=Long.from_bytes(kt_chat), + kt_receiver=Long.from_bytes(kt_receiver) + ) @classmethod - async def get_all_by_ktid(cls, ktid: str, kt_receiver: int) -> list[Message]: + def _from_optional_row(cls, row: Record | None) -> Message | None: + return cls._from_row(row) if row is not None else None + + columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp' + + @classmethod + async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]: q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2" - rows = await cls.db.fetch(q, ktid, kt_receiver) + rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver)) return [cls._from_row(row) for row in rows] @classmethod - async def get_by_ktid(cls, ktid: str, kt_receiver: int, index: int = 0) -> Message | None: + async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None: q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' - row = await cls.db.fetchrow(q, ktid, kt_receiver, index) - return cls._from_row(row) - - @classmethod - async def get_by_ktid_or_oti( - cls, ktid: str, oti: int, kt_receiver: int, kt_sender: int, index: int = 0 - ) -> Message | None: - q = ( - f"SELECT {cls.columns} " - "FROM message WHERE (ktid=$1 OR (kt_txn_id=$2 AND kt_sender=$3)) AND " - ' kt_receiver=$4 AND "index"=$5' - ) - row = await cls.db.fetchrow(q, ktid, oti, kt_sender, kt_receiver, index) + row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index) return cls._from_row(row) @classmethod @@ -83,18 +83,18 @@ class Message: return cls._from_row(row) @classmethod - async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None: + async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None: q = ( f"SELECT {cls.columns} " "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL " "ORDER BY timestamp DESC LIMIT 1" ) - row = await cls.db.fetchrow(q, kt_chat, kt_receiver) + row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver)) return cls._from_row(row) @classmethod async def get_closest_before( - cls, kt_chat: int, kt_receiver: int, timestamp: int + cls, kt_chat: Long, kt_receiver: Long, timestamp: int ) -> Message | None: q = ( f"SELECT {cls.columns} " @@ -102,23 +102,21 @@ class Message: " ktid IS NOT NULL " "ORDER BY timestamp DESC LIMIT 1" ) - row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp) + row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp) return cls._from_row(row) _insert_query = ( - 'INSERT INTO message (mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, ' - " kt_sender, timestamp) " - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + 'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, ' + " timestamp) " + "VALUES ($1, $2, $3, $4, $5, $6, $7)" ) @classmethod async def bulk_create( cls, - ktid: str, - oti: int, - kt_chat: int, - kt_receiver: int, - kt_sender: int, + ktid: Long, + kt_chat: Long, + kt_receiver: Long, event_ids: list[EventID], timestamp: int, mx_room: RoomID, @@ -127,7 +125,7 @@ class Message: return columns = [col.strip('"') for col in cls.columns.split(", ")] records = [ - (mxid, mx_room, ktid, oti, index, kt_chat, kt_receiver, kt_sender, timestamp) + (mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp) for index, mxid in enumerate(event_ids) ] async with cls.db.acquire() as conn, conn.transaction(): @@ -142,19 +140,17 @@ class Message: q, self.mxid, self.mx_room, - self.ktid, - self.kt_txn_id, + str(self.ktid), self.index, - self.kt_chat, - self.kt_receiver, - self.kt_sender, + bytes(self.kt_chat), + bytes(self.kt_receiver), self.timestamp, ) async def delete(self) -> None: q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' - await self.db.execute(q, self.ktid, self.kt_receiver, self.index) + await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index) async def update(self) -> None: q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4" - await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room) + await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room) diff --git a/matrix_appservice_kakaotalk/example-config.yaml b/matrix_appservice_kakaotalk/example-config.yaml index 3197d93..1f6334a 100644 --- a/matrix_appservice_kakaotalk/example-config.yaml +++ b/matrix_appservice_kakaotalk/example-config.yaml @@ -186,7 +186,6 @@ bridge: # Whether or not the KakaoTalk users of logged in Matrix users should be # invited to private chats when backfilling history from KakaoTalk. This is # usually needed to prevent rate limits and to allow timestamp massaging. - # TODO Is this necessary? invite_own_puppet: true # Maximum number of messages to backfill initially. # Set to 0 to disable backfilling when creating portal. diff --git a/matrix_appservice_kakaotalk/kt/client/client.py b/matrix_appservice_kakaotalk/kt/client/client.py index e589ea3..c415414 100644 --- a/matrix_appservice_kakaotalk/kt/client/client.py +++ b/matrix_appservice_kakaotalk/kt/client/client.py @@ -43,7 +43,11 @@ from ..types.chat.chat import Chatlog from ..types.oauth import OAuthCredential, OAuthInfo from ..types.request import ( deserialize_result, - ResultType, RootCommandResult, CommandResultDoneValue) + ResultType, + ResultListType, + RootCommandResult, + CommandResultDoneValue +) from .types import ChannelInfoUnion from .types import PortalChannelInfo @@ -155,7 +159,8 @@ class Client: ) -> _RequestContextManager: # TODO Is auth ever needed? headers = { - **self._headers, + # TODO Are any default headers needed? + #**self._headers, **(headers or {}), } url = URL(url) @@ -185,16 +190,24 @@ class Client: assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}" return login_result - async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct: - profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile") - return profile_req_struct.profile - """ async def is_connected(self) -> bool: resp = await self._rpc_client.request("is_connected") return resp["is_connected"] """ + async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct: + profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile") + return profile_req_struct.profile + + async def get_profile(self, user_id: Long) -> ProfileStruct: + profile_req_struct = await self._api_user_request_result( + ProfileReqStruct, + "get_profile", + user_id=user_id.serialize() + ) + return profile_req_struct.profile + async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo: req = await self._api_user_request_result( PortalChannelInfo, @@ -204,13 +217,13 @@ class Client: req.channel_info = channel_info return req - async def get_profile(self, user_id: Long) -> ProfileStruct: - profile_req_struct = await self._api_user_request_result( - ProfileReqStruct, - "get_profile", - user_id=user_id.serialize() - ) - return profile_req_struct.profile + async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]: + return (await self._api_user_request_result( + ResultListType(Chatlog), + "get_chats", + channel_id=channel_id.serialize(), + sync_from=sync_from.serialize() if sync_from else None + ))[-limit if limit else 0:] async def stop(self) -> None: # TODO Stop all event handlers diff --git a/matrix_appservice_kakaotalk/kt/client/types.py b/matrix_appservice_kakaotalk/kt/client/types.py index 07a51c5..f44b9e8 100644 --- a/matrix_appservice_kakaotalk/kt/client/types.py +++ b/matrix_appservice_kakaotalk/kt/client/types.py @@ -15,23 +15,44 @@ # along with this program. If not, see . """Custom wrapper classes around types defined by the KakaoTalk API.""" -from typing import Optional, Union +from typing import Optional, NewType, Union from attr import dataclass -from mautrix.types import SerializableAttrs +from mautrix.types import SerializableAttrs, JSON, deserializer from ..types.channel.channel_info import NormalChannelInfo from ..types.openlink.open_channel_info import OpenChannelInfo from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo -ChannelInfoUnion = Union[NormalChannelInfo, OpenChannelInfo] -UserInfoUnion = Union[NormalChannelUserInfo, OpenChannelUserInfo] +ChannelInfoUnion = NewType("ChannelInfoUnion", Union[NormalChannelInfo, OpenChannelInfo]) + +@deserializer(ChannelInfoUnion) +def deserialize_channel_info_union(data: JSON) -> ChannelInfoUnion: + if "openLink" in data: + return OpenChannelInfo.deserialize(data) + else: + return NormalChannelInfo.deserialize(data) + +setattr(ChannelInfoUnion, "deserialize", deserialize_channel_info_union) + + +UserInfoUnion = NewType("UserInfoUnion", Union[NormalChannelUserInfo, OpenChannelUserInfo]) + +@deserializer(UserInfoUnion) +def deserialize_user_info_union(data: JSON) -> UserInfoUnion: + if "perm" in data: + return OpenChannelUserInfo.deserialize(data) + else: + return NormalChannelUserInfo.deserialize(data) + +setattr(UserInfoUnion, "deserialize", deserialize_user_info_union) + @dataclass class PortalChannelInfo(SerializableAttrs): name: str - #participants: list[PuppetUserInfo] + participants: list[UserInfoUnion] # TODO Image channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller diff --git a/matrix_appservice_kakaotalk/kt/types/bson.py b/matrix_appservice_kakaotalk/kt/types/bson.py index 60939cb..296e3e5 100644 --- a/matrix_appservice_kakaotalk/kt/types/bson.py +++ b/matrix_appservice_kakaotalk/kt/types/bson.py @@ -32,11 +32,11 @@ class Long(SerializableAttrs): return cls(**bson.loads(raw)) @classmethod - def from_optional_bytes(cls, raw: bytes | None) -> Optional["Long"]: + def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]: return cls(**bson.loads(raw)) if raw is not None else None @classmethod - def to_optional_bytes(cls, value: Optional["Long"]) -> bytes | None: + def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]: return bytes(value) if value is not None else None def serialize(self) -> JSON: @@ -48,13 +48,34 @@ class Long(SerializableAttrs): return bson.dumps(asdict(self)) def __int__(self) -> int: - # TODO Is this right? - return self.high << 32 + self.low + if self.unsigned: + pass + result = \ + ((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \ + ( self.low + (1 << 32 if self.low < 0 else 0)) + return result + (1 << 32 if self.unsigned and result < 0 else 0) def __str__(self) -> str: - return f"{self.high << 32 if self.high else ''}{self.low}" + return str(int(self)) ZERO: ClassVar["Long"] - Long.ZERO = Long(0, 0, False) + + +class IntLong(Long): + """Helper class for constructing a Long from an int.""" + def __init__(self, val: int): + if val < 0: + pass + super().__init__( + high=(val & 0xffffffff00000000) >> 32, + low = val & 0x00000000ffffffff, + unsigned=val < 0, + ) + + +class StrLong(IntLong): + """Helper class for constructing a Long from the string representation of an int.""" + def __init__(self, val: str): + super().__init__(int(val)) diff --git a/matrix_appservice_kakaotalk/kt/types/channel/channel_type.py b/matrix_appservice_kakaotalk/kt/types/channel/channel_type.py index eb97727..6b01b65 100644 --- a/matrix_appservice_kakaotalk/kt/types/channel/channel_type.py +++ b/matrix_appservice_kakaotalk/kt/types/channel/channel_type.py @@ -23,12 +23,12 @@ class KnownChannelType(str, Enum): DirectChat = "DirectChat" PlusChat = "PlusChat" MemoChat = "MemoChat" - OM = "OM" - OD = "OD" + OM = "OM" # "OpenMulti"? + OD = "OD" # "OpenDirect"? @classmethod def is_direct(cls, value: Union["KnownChannelType", str]) -> bool: - return str in [KnownChannelType.DirectChat, KnownChannelType.OD] + return value == KnownChannelType.DirectChat ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str diff --git a/matrix_appservice_kakaotalk/kt/types/request.py b/matrix_appservice_kakaotalk/kt/types/request.py index fa74098..eb862a0 100644 --- a/matrix_appservice_kakaotalk/kt/types/request.py +++ b/matrix_appservice_kakaotalk/kt/types/request.py @@ -13,12 +13,12 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Generic, Type, TypeVar, Union +from typing import Generic, Type, TypeVar, Union, Iterable from attr import dataclass from enum import IntEnum -from mautrix.types import SerializableAttrs, JSON +from mautrix.types import Serializable, SerializableAttrs, JSON from .api.auth_api_client import KnownAuthStatusCode @@ -80,7 +80,22 @@ class RootCommandResult(ResponseState): success: bool -ResultType = TypeVar("ResultType", bound=SerializableAttrs) +ResultType = TypeVar("ResultType", bound=Serializable) + +def ResultListType(result_type: Type[ResultType]): + class _ResultListType(list[result_type], Serializable): + def __init__(self, iterable: Iterable[result_type]=()): + list.__init__(self, (result_type.deserialize(x) for x in iterable)) + + def serialize(self) -> list[JSON]: + return [v.serialize() for v in self] + + @classmethod + def deserialize(cls, data: list[JSON]) -> "_ResultListType": + return cls(data) + + return _ResultListType + @dataclass class CommandResultDoneValue(RootCommandResult, Generic[ResultType]): diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py index e1d82a1..c47c14b 100644 --- a/matrix_appservice_kakaotalk/portal.py +++ b/matrix_appservice_kakaotalk/portal.py @@ -45,12 +45,14 @@ from .db import ( Message as DBMessage, Portal as DBPortal, ) +from .formatter.from_kakaotalk import kakaotalk_to_matrix -from .kt.types.bson import Long +from .kt.types.bson import Long, IntLong +from .kt.types.channel.channel_info import ChannelInfo from .kt.types.channel.channel_type import KnownChannelType, ChannelType -from .kt.types.user.channel_user_info import DisplayUserInfo +from .kt.types.chat.chat import Chatlog -from .kt.client.types import PortalChannelInfo +from .kt.client.types import UserInfoUnion, PortalChannelInfo if TYPE_CHECKING: from .__main__ import KakaoTalkBridge @@ -201,7 +203,7 @@ class Portal(DBPortal, BasePortal): #self._update_photo(source, info.image), ) ) - changed = await self._update_participants(source, info.channel_info.displayUserList) or changed + changed = await self._update_participants(source, info.participants) or changed if changed or force_save: await self.update_bridge_info() await self.save() @@ -342,7 +344,7 @@ class Portal(DBPortal, BasePortal): ) """ - async def _update_participants(self, source: u.User, participants: list[DisplayUserInfo]) -> bool: + async def _update_participants(self, source: u.User, participants: list[UserInfoUnion]) -> bool: changed = False # TODO nick_map? for participant in participants: @@ -556,10 +558,10 @@ class Portal(DBPortal, BasePortal): ) if not self.is_direct: - await self._update_participants(source, info.channel_info.displayUserList) + await self._update_participants(source, info.participants) try: - await self.backfill(source, is_initial=True, channel=info.channel_info) + await self.backfill(source, is_initial=True, channel_info=info.channel_info) except Exception: self.log.exception("Failed to backfill new portal") @@ -752,31 +754,131 @@ class Portal(DBPortal, BasePortal): self, source: u.User, sender: p.Puppet, - message: str, - reply_to: None = None, + message: Chatlog, + reply_to: Chatlog | None = None, ) -> None: try: await self._handle_remote_message(source, sender, message, reply_to) except Exception: self.log.exception( - "Error handling Kakaotalk message " + "Error handling KakaoTalk message %s", + message.logId, ) async def _handle_remote_message( self, source: u.User, sender: p.Puppet, - message: str, - reply_to: None = None, + message: Chatlog, + reply_to: Chatlog | None = None, ) -> None: - self.log.info("TODO") + self.log.debug(f"Handling KakaoTalk event {message.logId}") + if not self.mxid: + mxid = await self.create_matrix_room(source) + if not mxid: + # Failed to create + return + if not await self._bridge_own_message_pm(source, sender, f"message {message.logId}"): + return + intent = sender.intent_for(self) + if ( + self._backfill_leave is not None + and self.ktid != sender.ktid + and intent != sender.intent + and intent not in self._backfill_leave + ): + self.log.debug("Adding %s's default puppet to room for backfilling", sender.mxid) + await self.main_intent.invite_user(self.mxid, intent.mxid) + await intent.ensure_joined(self.mxid) + self._backfill_leave.add(intent) + + if message.attachment: + self.log.info("TODO: _handle_remote_message attachments") + if message.supplement: + self.log.info("TODO: _handle_remote_message supplements") + if message.text: + content = await kakaotalk_to_matrix(message.text) + event_id = await self._send_message(intent, content, timestamp=message.sendAt) + await DBMessage( + mxid=event_id, + mx_room=self.mxid, + ktid=message.logId, + index=0, + kt_chat=self.ktid, + kt_receiver=self.kt_receiver, + timestamp=message.sendAt, + ).insert() + await self._send_delivery_receipt(event_id) + else: + self.log.warning(f"Unhandled KakaoTalk message {message.logId}") # TODO Many more remote handlers # endregion - async def backfill(self, source: u.User, is_initial: bool, channel: PortalChannelInfo) -> None: - self.log.info("TODO: backfill") + async def backfill(self, source: u.User, is_initial: bool, channel_info: ChannelInfo) -> None: + limit = ( + self.config["bridge.backfill.initial_limit"] + if is_initial + else self.config["bridge.backfill.missed_limit"] + ) + if limit == 0: + return + elif limit < 0: + limit = None + last_log_id = None + if not is_initial and channel_info.lastChatLog: + last_log_id = channel_info.lastChatLog.logId + most_recent = await DBMessage.get_most_recent(self.ktid, self.kt_receiver) + if most_recent and is_initial: + self.log.debug("Not backfilling %s: already bridged messages found", self.ktid_log) + # TODO Should this be removed? With it, can't sync an empty portal! + #elif (not most_recent or not most_recent.timestamp) and not is_initial: + # self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log) + elif last_log_id and most_recent and int(most_recent.ktid) >= int(last_log_id): + self.log.debug( + "Not backfilling %s: last activity is equal to most recent bridged " + "message (%s >= %s)", + self.ktid_log, + most_recent.ktid, + last_log_id, + ) + else: + with self.backfill_lock: + await self._backfill( + source, + limit, + most_recent.ktid if most_recent else None, + channel_info=channel_info, + ) + + async def _backfill( + self, + source: u.User, + limit: int | None, + after_log_id: Long | None, + channel_info: ChannelInfo, + ) -> None: + self.log.debug("Backfilling history through %s", source.mxid) + self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid)) + messages = await source.client.get_chats( + channel_info.channelId, + limit, + after_log_id + ) + if not messages: + self.log.debug("Didn't get any messages from server") + return + self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server") + self._backfill_leave = set() + async with NotificationDisabler(self.mxid, source): + for message in messages: + puppet = await p.Puppet.get_by_ktid(message.sender.userId) + await self.handle_remote_message(source, puppet, message) + for intent in self._backfill_leave: + self.log.trace("Leaving room with %s post-backfill", intent.mxid) + await intent.leave_room(self.mxid) + self.log.info("Backfilled %d messages through %s", len(messages), source.mxid) # region Database getters diff --git a/matrix_appservice_kakaotalk/puppet.py b/matrix_appservice_kakaotalk/puppet.py index 6b718e4..180a1ac 100644 --- a/matrix_appservice_kakaotalk/puppet.py +++ b/matrix_appservice_kakaotalk/puppet.py @@ -31,8 +31,9 @@ from . import matrix as m, portal as p, user as u from .config import Config from .db import Puppet as DBPuppet -from .kt.types.bson import Long -from .kt.types.user.channel_user_info import DisplayUserInfo +from .kt.types.bson import Long, StrLong + +from .kt.client.types import UserInfoUnion if TYPE_CHECKING: from .__main__ import KakaoTalkBridge @@ -42,7 +43,7 @@ class Puppet(DBPuppet, BasePuppet): mx: m.MatrixHandler config: Config hs_domain: str - mxid_template: SimpleTemplate[int] + mxid_template: SimpleTemplate[StrLong] by_ktid: dict[Long, Puppet] = {} by_custom_mxid: dict[UserID, Puppet] = {} @@ -126,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet): keyword="userid", prefix="@", suffix=f":{Puppet.hs_domain}", - type=int, + type=StrLong, ) cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"] cls.homeserver_url_map = { @@ -147,7 +148,7 @@ class Puppet(DBPuppet, BasePuppet): async def update_info( self, source: u.User, - info: DisplayUserInfo, + info: UserInfoUnion, update_avatar: bool = True, ) -> Puppet: self._last_info_sync = datetime.now() @@ -161,7 +162,7 @@ class Puppet(DBPuppet, BasePuppet): self.log.exception(f"Failed to update info from source {source.ktid}") return self - async def _update_name(self, info: DisplayUserInfo) -> bool: + async def _update_name(self, info: UserInfoUnion) -> bool: name = info.nickname if name != self.name or not self.name_set: self.name = name @@ -259,7 +260,7 @@ class Puppet(DBPuppet, BasePuppet): return None @classmethod - def get_id_from_mxid(cls, mxid: UserID) -> int | None: + def get_id_from_mxid(cls, mxid: UserID) -> Long | None: return cls.mxid_template.parse(mxid) @classmethod diff --git a/matrix_appservice_kakaotalk/user.py b/matrix_appservice_kakaotalk/user.py index 8c2f778..ba75fcd 100644 --- a/matrix_appservice_kakaotalk/user.py +++ b/matrix_appservice_kakaotalk/user.py @@ -415,6 +415,7 @@ class User(DBUser, BaseUser): assert self.client try: + # TODO if not is_startup, close existing listeners login_result = await self.client.start() await self._sync_channels(login_result, is_startup) # TODO connect listeners, even if channel sync fails (except if it's an auth failure) @@ -502,7 +503,7 @@ class User(DBUser, BaseUser): await portal.create_matrix_room(self, portal_info) else: await portal.update_matrix_room(self, portal_info) - await portal.backfill(self, is_initial=False, channel=channel_info) + await portal.backfill(self, is_initial=False, channel_info=channel_info) async def get_notice_room(self) -> RoomID: if not self.notice_room: diff --git a/node/src/client.js b/node/src/client.js index 1e34195..521f317 100644 --- a/node/src/client.js +++ b/node/src/client.js @@ -202,6 +202,14 @@ export default class PeerClient { return loginRes } + /** + * TODO Consider caching per-user + * @param {string} uuid + */ + async #createAuthClient(uuid) { + return await AuthApiClient.create("KakaoTalk Bridge", uuid) + } + // TODO Wrapper for per-user commands /** @@ -237,14 +245,6 @@ export default class PeerClient { return await oAuthClient.renew(req.oauth_credential) } - /** - * TODO Consider caching per-user - * @param {string} uuid - */ - async #createAuthClient(uuid) { - return await AuthApiClient.create("KakaoTalk Bridge", uuid) - } - /** * @param {Object} req * @param {string} req.mxid @@ -314,22 +314,33 @@ export default class PeerClient { * @param {string} req.mxid * @param {Long} req.channel_id */ - getPortalChannelInfo = (req) => { + getPortalChannelInfo = async (req) => { const userClient = this.#getUser(req.mxid) const talkChannel = userClient.talkClient.channelList.get(req.channel_id) - /* TODO Decide if this is needed. If it is, make function async! const res = await talkChannel.updateAll() if (!res.success) return res - */ return this.#makeCommandResult({ name: talkChannel.getDisplayName(), - //participants: Array.from(talkChannel.getAllUserInfo()), + participants: Array.from(talkChannel.getAllUserInfo()), // TODO Image }) } + /** + * @param {Object} req + * @param {string} req.mxid + * @param {Long} req.channel_id + * @param {Long?} req.sync_from + */ + getChats = async (req) => { + const userClient = this.#getUser(req.mxid) + const talkChannel = userClient.talkClient.channelList.get(req.channel_id) + + return await talkChannel.getChatListFrom(req.sync_from) + } + /** * @param {Object} req * @param {string} req.mxid @@ -421,6 +432,7 @@ export default class PeerClient { register_device: this.registerDevice, get_own_profile: this.getOwnProfile, get_portal_channel_info: this.getPortalChannelInfo, + get_chats: this.getChats, get_profile: this.getProfile, /* send: req => this.puppet.sendMessage(req.chat_id, req.text),