Fix how DirectChat and MemoChat channels are handled
This commit is contained in:
parent
c7df4b1e65
commit
491cdca7b6
|
@ -28,7 +28,7 @@ class KnownChannelType(str, Enum):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_direct(cls, value: Union["KnownChannelType", str]) -> bool:
|
def is_direct(cls, value: Union["KnownChannelType", str]) -> bool:
|
||||||
return value == KnownChannelType.DirectChat
|
return value in [cls.DirectChat, cls.MemoChat]
|
||||||
|
|
||||||
|
|
||||||
ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str
|
ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str
|
||||||
|
|
|
@ -92,6 +92,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
config: Config
|
config: Config
|
||||||
|
|
||||||
_main_intent: IntentAPI | None
|
_main_intent: IntentAPI | None
|
||||||
|
_kt_sender: int | None
|
||||||
_create_room_lock: asyncio.Lock
|
_create_room_lock: asyncio.Lock
|
||||||
_send_locks: dict[int, asyncio.Lock]
|
_send_locks: dict[int, asyncio.Lock]
|
||||||
_noop_lock: FakeLock = FakeLock()
|
_noop_lock: FakeLock = FakeLock()
|
||||||
|
@ -132,6 +133,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
self.log = self.log.getChild(self.ktid_log)
|
self.log = self.log.getChild(self.ktid_log)
|
||||||
|
|
||||||
self._main_intent = None
|
self._main_intent = None
|
||||||
|
self._kt_sender = None
|
||||||
self._create_room_lock = asyncio.Lock()
|
self._create_room_lock = asyncio.Lock()
|
||||||
self._send_locks = {}
|
self._send_locks = {}
|
||||||
self._typing = set()
|
self._typing = set()
|
||||||
|
@ -176,17 +178,31 @@ class Portal(DBPortal, BasePortal):
|
||||||
@property
|
@property
|
||||||
def ktid_log(self) -> str:
|
def ktid_log(self) -> str:
|
||||||
if self.is_direct:
|
if self.is_direct:
|
||||||
return f"{self.ktid}<->{self.kt_receiver}"
|
return f"{self.ktid}->{self.kt_receiver}"
|
||||||
return str(self.ktid)
|
return str(self.ktid)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_direct(self) -> bool:
|
def is_direct(self) -> bool:
|
||||||
return KnownChannelType.is_direct(self.kt_type)
|
return KnownChannelType.is_direct(self.kt_type)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kt_sender(self) -> int | None:
|
||||||
|
if self.is_direct:
|
||||||
|
if not self._kt_sender:
|
||||||
|
raise ValueError("Direct chat portal must set sender")
|
||||||
|
else:
|
||||||
|
if self._kt_sender:
|
||||||
|
raise ValueError(f"Non-direct chat portal should have no sender, but has sender {self._kt_sender}")
|
||||||
|
return self._kt_sender
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def main_intent(self) -> IntentAPI:
|
def main_intent(self) -> IntentAPI:
|
||||||
if not self._main_intent:
|
if not self._main_intent:
|
||||||
raise ValueError("Portal must be postinit()ed before main_intent can be used")
|
raise ValueError(
|
||||||
|
"Portal must be postinit()ed before main_intent can be used"
|
||||||
|
if not self.is_direct else
|
||||||
|
"Direct chat portal must call postinit and _update_participants before main_intent can be used"
|
||||||
|
)
|
||||||
return self._main_intent
|
return self._main_intent
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
@ -227,7 +243,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
source: u.User,
|
source: u.User,
|
||||||
info: PortalChannelInfo | None = None,
|
info: PortalChannelInfo | None = None,
|
||||||
force_save: bool = False,
|
force_save: bool = False,
|
||||||
) -> PortalChannelInfo | None:
|
) -> PortalChannelInfo:
|
||||||
if not info:
|
if not info:
|
||||||
self.log.debug("Called update_info with no info, fetching channel info...")
|
self.log.debug("Called update_info with no info, fetching channel info...")
|
||||||
info = await source.client.get_portal_channel_info(self.ktid)
|
info = await source.client.get_portal_channel_info(self.ktid)
|
||||||
|
@ -384,11 +400,19 @@ class Portal(DBPortal, BasePortal):
|
||||||
|
|
||||||
async def _update_participants(self, source: u.User, participants: list[UserInfoUnion]) -> bool:
|
async def _update_participants(self, source: u.User, participants: list[UserInfoUnion]) -> bool:
|
||||||
changed = False
|
changed = False
|
||||||
|
if not self._main_intent:
|
||||||
|
assert self.is_direct, "_main_intent for non-direct chat portal should have been set already"
|
||||||
|
self._kt_sender = participants[
|
||||||
|
0 if self.kt_type == KnownChannelType.MemoChat or participants[0].userId != source.ktid else 1
|
||||||
|
].userId
|
||||||
|
self._main_intent = (await p.Puppet.get_by_ktid(self._kt_sender)).default_mxid_intent
|
||||||
|
else:
|
||||||
|
self._kt_sender = (await p.Puppet.get_by_mxid(self._main_intent.mxid)).ktid if self.is_direct else None
|
||||||
# TODO nick_map?
|
# TODO nick_map?
|
||||||
for participant in participants:
|
for participant in participants:
|
||||||
puppet = await p.Puppet.get_by_ktid(participant.userId)
|
puppet = await p.Puppet.get_by_ktid(participant.userId)
|
||||||
await puppet.update_info(source, participant)
|
await puppet.update_info(source, participant)
|
||||||
if self.is_direct and self.ktid == puppet.ktid and self.encrypted:
|
if self.is_direct and self._kt_sender == puppet.ktid and self.encrypted:
|
||||||
changed = await self._update_name(puppet.name) or changed
|
changed = await self._update_name(puppet.name) or changed
|
||||||
changed = await self._update_photo_from_puppet(puppet) or changed
|
changed = await self._update_photo_from_puppet(puppet) or changed
|
||||||
if self.mxid:
|
if self.mxid:
|
||||||
|
@ -418,6 +442,8 @@ class Portal(DBPortal, BasePortal):
|
||||||
async def _update_matrix_room(
|
async def _update_matrix_room(
|
||||||
self, source: u.User, info: PortalChannelInfo | None = None
|
self, source: u.User, info: PortalChannelInfo | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
info = await self.update_info(source, info)
|
||||||
|
|
||||||
puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
|
puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
|
||||||
await self.main_intent.invite_user(
|
await self.main_intent.invite_user(
|
||||||
self.mxid,
|
self.mxid,
|
||||||
|
@ -430,11 +456,6 @@ class Portal(DBPortal, BasePortal):
|
||||||
if did_join and self.is_direct:
|
if did_join and self.is_direct:
|
||||||
await source.update_direct_chats({self.main_intent.mxid: [self.mxid]})
|
await source.update_direct_chats({self.main_intent.mxid: [self.mxid]})
|
||||||
|
|
||||||
info = await self.update_info(source, info)
|
|
||||||
if not info:
|
|
||||||
self.log.warning("Canceling _update_matrix_room as update_info didn't return info")
|
|
||||||
return
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
#await self._sync_read_receipts(info.read_receipts.nodes)
|
#await self._sync_read_receipts(info.read_receipts.nodes)
|
||||||
|
|
||||||
|
@ -520,6 +541,9 @@ class Portal(DBPortal, BasePortal):
|
||||||
return self.mxid
|
return self.mxid
|
||||||
|
|
||||||
self.log.debug(f"Creating Matrix room")
|
self.log.debug(f"Creating Matrix room")
|
||||||
|
if self.is_direct:
|
||||||
|
# NOTE Must do this to find the other member of the DM, since the channel ID != the member's ID!
|
||||||
|
await self._update_participants(source, info.participants)
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
initial_state = [
|
initial_state = [
|
||||||
{
|
{
|
||||||
|
@ -547,9 +571,6 @@ class Portal(DBPortal, BasePortal):
|
||||||
invites.append(self.az.bot_mxid)
|
invites.append(self.az.bot_mxid)
|
||||||
|
|
||||||
info = await self.update_info(source=source, info=info)
|
info = await self.update_info(source=source, info=info)
|
||||||
if not info:
|
|
||||||
self.log.debug("update_info() didn't return info, cancelling room creation")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if self.encrypted or not self.is_direct:
|
if self.encrypted or not self.is_direct:
|
||||||
name = self.name
|
name = self.name
|
||||||
|
@ -602,6 +623,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.is_direct:
|
if not self.is_direct:
|
||||||
|
# NOTE Calling this after room creation to invite participants
|
||||||
await self._update_participants(source, info.participants)
|
await self._update_participants(source, info.participants)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -964,11 +986,10 @@ class Portal(DBPortal, BasePortal):
|
||||||
self.by_ktid[self._ktid_full] = self
|
self.by_ktid[self._ktid_full] = self
|
||||||
if self.mxid:
|
if self.mxid:
|
||||||
self.by_mxid[self.mxid] = self
|
self.by_mxid[self.mxid] = self
|
||||||
self._main_intent = (
|
if not self.is_direct:
|
||||||
(await p.Puppet.get_by_ktid(self.ktid)).default_mxid_intent
|
self._main_intent = self.az.intent
|
||||||
if self.is_direct
|
else:
|
||||||
else self.az.intent
|
self.log.debug("Not setting _main_intent of direct chat until after checking participant list")
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@async_getter_lock
|
@async_getter_lock
|
||||||
|
@ -995,6 +1016,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
create: bool = True,
|
create: bool = True,
|
||||||
kt_type: ChannelType | None = None,
|
kt_type: ChannelType | None = None,
|
||||||
) -> Portal | None:
|
) -> Portal | None:
|
||||||
|
# TODO Find out if direct channels are shared. If so, don't need kt_receiver!
|
||||||
if kt_type:
|
if kt_type:
|
||||||
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0
|
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0
|
||||||
ktid_full = (ktid, kt_receiver)
|
ktid_full = (ktid, kt_receiver)
|
||||||
|
|
|
@ -33,6 +33,7 @@ from .db import Puppet as DBPuppet
|
||||||
|
|
||||||
from .kt.types.bson import Long
|
from .kt.types.bson import Long
|
||||||
|
|
||||||
|
from .kt.types.channel.channel_type import KnownChannelType
|
||||||
from .kt.client.types import UserInfoUnion
|
from .kt.client.types import UserInfoUnion
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -94,8 +95,9 @@ class Puppet(DBPuppet, BasePuppet):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
|
async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
|
||||||
|
# TODO Find out if direct channels are shared. If not, default puppet shouldn't leave!
|
||||||
portal = await p.Portal.get_by_mxid(room_id)
|
portal = await p.Portal.get_by_mxid(room_id)
|
||||||
return portal and portal.ktid != self.ktid
|
return portal and portal.kt_type != KnownChannelType.MemoChat
|
||||||
|
|
||||||
async def _leave_rooms_with_default_user(self) -> None:
|
async def _leave_rooms_with_default_user(self) -> None:
|
||||||
await super()._leave_rooms_with_default_user()
|
await super()._leave_rooms_with_default_user()
|
||||||
|
|
Loading…
Reference in New Issue