From 491cdca7b64febe9d66f83825a18ad26b25975a5 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Sun, 13 Mar 2022 00:30:29 -0500 Subject: [PATCH] Fix how DirectChat and MemoChat channels are handled --- .../kt/types/channel/channel_type.py | 2 +- matrix_appservice_kakaotalk/portal.py | 56 +++++++++++++------ matrix_appservice_kakaotalk/puppet.py | 4 +- 3 files changed, 43 insertions(+), 19 deletions(-) diff --git a/matrix_appservice_kakaotalk/kt/types/channel/channel_type.py b/matrix_appservice_kakaotalk/kt/types/channel/channel_type.py index 6b01b65..634c702 100644 --- a/matrix_appservice_kakaotalk/kt/types/channel/channel_type.py +++ b/matrix_appservice_kakaotalk/kt/types/channel/channel_type.py @@ -28,7 +28,7 @@ class KnownChannelType(str, Enum): @classmethod def is_direct(cls, value: Union["KnownChannelType", str]) -> bool: - return value == KnownChannelType.DirectChat + return value in [cls.DirectChat, cls.MemoChat] ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py index 1065890..5b39445 100644 --- a/matrix_appservice_kakaotalk/portal.py +++ b/matrix_appservice_kakaotalk/portal.py @@ -92,6 +92,7 @@ class Portal(DBPortal, BasePortal): config: Config _main_intent: IntentAPI | None + _kt_sender: int | None _create_room_lock: asyncio.Lock _send_locks: dict[int, asyncio.Lock] _noop_lock: FakeLock = FakeLock() @@ -132,6 +133,7 @@ class Portal(DBPortal, BasePortal): self.log = self.log.getChild(self.ktid_log) self._main_intent = None + self._kt_sender = None self._create_room_lock = asyncio.Lock() self._send_locks = {} self._typing = set() @@ -176,17 +178,31 @@ class Portal(DBPortal, BasePortal): @property def ktid_log(self) -> str: if self.is_direct: - return f"{self.ktid}<->{self.kt_receiver}" + return f"{self.ktid}->{self.kt_receiver}" return str(self.ktid) @property def is_direct(self) -> bool: return KnownChannelType.is_direct(self.kt_type) + @property + def kt_sender(self) -> int | None: + if self.is_direct: + if not self._kt_sender: + raise ValueError("Direct chat portal must set sender") + else: + if self._kt_sender: + raise ValueError(f"Non-direct chat portal should have no sender, but has sender {self._kt_sender}") + return self._kt_sender + @property def main_intent(self) -> IntentAPI: if not self._main_intent: - raise ValueError("Portal must be postinit()ed before main_intent can be used") + raise ValueError( + "Portal must be postinit()ed before main_intent can be used" + if not self.is_direct else + "Direct chat portal must call postinit and _update_participants before main_intent can be used" + ) return self._main_intent # endregion @@ -227,7 +243,7 @@ class Portal(DBPortal, BasePortal): source: u.User, info: PortalChannelInfo | None = None, force_save: bool = False, - ) -> PortalChannelInfo | None: + ) -> PortalChannelInfo: if not info: self.log.debug("Called update_info with no info, fetching channel info...") info = await source.client.get_portal_channel_info(self.ktid) @@ -384,11 +400,19 @@ class Portal(DBPortal, BasePortal): async def _update_participants(self, source: u.User, participants: list[UserInfoUnion]) -> bool: changed = False + if not self._main_intent: + assert self.is_direct, "_main_intent for non-direct chat portal should have been set already" + self._kt_sender = participants[ + 0 if self.kt_type == KnownChannelType.MemoChat or participants[0].userId != source.ktid else 1 + ].userId + self._main_intent = (await p.Puppet.get_by_ktid(self._kt_sender)).default_mxid_intent + else: + self._kt_sender = (await p.Puppet.get_by_mxid(self._main_intent.mxid)).ktid if self.is_direct else None # TODO nick_map? for participant in participants: puppet = await p.Puppet.get_by_ktid(participant.userId) await puppet.update_info(source, participant) - if self.is_direct and self.ktid == puppet.ktid and self.encrypted: + if self.is_direct and self._kt_sender == puppet.ktid and self.encrypted: changed = await self._update_name(puppet.name) or changed changed = await self._update_photo_from_puppet(puppet) or changed if self.mxid: @@ -418,6 +442,8 @@ class Portal(DBPortal, BasePortal): async def _update_matrix_room( self, source: u.User, info: PortalChannelInfo | None = None ) -> None: + info = await self.update_info(source, info) + puppet = await p.Puppet.get_by_custom_mxid(source.mxid) await self.main_intent.invite_user( self.mxid, @@ -430,11 +456,6 @@ class Portal(DBPortal, BasePortal): if did_join and self.is_direct: await source.update_direct_chats({self.main_intent.mxid: [self.mxid]}) - info = await self.update_info(source, info) - if not info: - self.log.warning("Canceling _update_matrix_room as update_info didn't return info") - return - # TODO #await self._sync_read_receipts(info.read_receipts.nodes) @@ -520,6 +541,9 @@ class Portal(DBPortal, BasePortal): return self.mxid self.log.debug(f"Creating Matrix room") + if self.is_direct: + # NOTE Must do this to find the other member of the DM, since the channel ID != the member's ID! + await self._update_participants(source, info.participants) name: str | None = None initial_state = [ { @@ -547,9 +571,6 @@ class Portal(DBPortal, BasePortal): invites.append(self.az.bot_mxid) info = await self.update_info(source=source, info=info) - if not info: - self.log.debug("update_info() didn't return info, cancelling room creation") - return None if self.encrypted or not self.is_direct: name = self.name @@ -602,6 +623,7 @@ class Portal(DBPortal, BasePortal): ) if not self.is_direct: + # NOTE Calling this after room creation to invite participants await self._update_participants(source, info.participants) try: @@ -964,11 +986,10 @@ class Portal(DBPortal, BasePortal): self.by_ktid[self._ktid_full] = self if self.mxid: self.by_mxid[self.mxid] = self - self._main_intent = ( - (await p.Puppet.get_by_ktid(self.ktid)).default_mxid_intent - if self.is_direct - else self.az.intent - ) + if not self.is_direct: + self._main_intent = self.az.intent + else: + self.log.debug("Not setting _main_intent of direct chat until after checking participant list") @classmethod @async_getter_lock @@ -995,6 +1016,7 @@ class Portal(DBPortal, BasePortal): create: bool = True, kt_type: ChannelType | None = None, ) -> Portal | None: + # TODO Find out if direct channels are shared. If so, don't need kt_receiver! if kt_type: kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0 ktid_full = (ktid, kt_receiver) diff --git a/matrix_appservice_kakaotalk/puppet.py b/matrix_appservice_kakaotalk/puppet.py index 644bfba..65a7366 100644 --- a/matrix_appservice_kakaotalk/puppet.py +++ b/matrix_appservice_kakaotalk/puppet.py @@ -33,6 +33,7 @@ from .db import Puppet as DBPuppet from .kt.types.bson import Long +from .kt.types.channel.channel_type import KnownChannelType from .kt.client.types import UserInfoUnion if TYPE_CHECKING: @@ -94,8 +95,9 @@ class Puppet(DBPuppet, BasePuppet): return False async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool: + # TODO Find out if direct channels are shared. If not, default puppet shouldn't leave! portal = await p.Portal.get_by_mxid(room_id) - return portal and portal.ktid != self.ktid + return portal and portal.kt_type != KnownChannelType.MemoChat async def _leave_rooms_with_default_user(self) -> None: await super()._leave_rooms_with_default_user()