Many fixes thanks to mypy

Also add some missing license headers
This commit is contained in:
Andrew Ferrazzutti 2022-04-28 01:50:47 -04:00
parent 2143282195
commit e952c05d35
16 changed files with 188 additions and 68 deletions

View File

@ -1,3 +1,18 @@
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from .auth import SECTION_AUTH from .auth import SECTION_AUTH
from .conn import SECTION_CONNECTION from .conn import SECTION_CONNECTION
from .kakaotalk import SECTION_FRIENDS from .kakaotalk import SECTION_FRIENDS

View File

@ -60,16 +60,19 @@ async def whoami(evt: CommandEvent) -> None:
await evt.mark_read() await evt.mark_read()
try: try:
own_info = await evt.sender.get_own_info() own_info = await evt.sender.get_own_info()
except SerializerError:
evt.sender.log.exception("Failed to deserialize settings struct")
own_info = None
except CommandException as e:
await evt.reply(f"Error from KakaoTalk: {e}")
if own_info:
await evt.reply( await evt.reply(
f"You're logged in as `{own_info.more.uuid}` (nickname: {own_info.more.nickName}, user ID: {evt.sender.ktid})." f"You're logged in as `{own_info.more.uuid}` (nickname: {own_info.more.nickName}, user ID: {evt.sender.ktid})."
) )
except SerializerError: else:
evt.sender.log.exception("Failed to deserialize settings struct")
await evt.reply( await evt.reply(
f"You're logged in, but the bridge is unable to retrieve your profile information (user ID: {evt.sender.ktid})." f"You're logged in, but the bridge is unable to retrieve your profile information (user ID: {evt.sender.ktid})."
) )
except CommandException as e:
await evt.reply(f"Error from KakaoTalk: {e}")
@command_handler( @command_handler(

View File

@ -1,12 +1,31 @@
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from mautrix.bridge.commands import CommandEvent as BaseCommandEvent from mautrix.bridge.commands import CommandEvent as BaseCommandEvent
if TYPE_CHECKING: if TYPE_CHECKING:
from ..__main__ import KakaoTalkBridge from ..__main__ import KakaoTalkBridge
from ..portal import Portal
from ..user import User from ..user import User
class CommandEvent(BaseCommandEvent): class CommandEvent(BaseCommandEvent):
bridge: "KakaoTalkBridge" bridge: KakaoTalkBridge
sender: "User" portal: Portal
sender: User

View File

@ -1,3 +1,18 @@
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from .message import Message from .message import Message

View File

@ -23,7 +23,7 @@ from attr import dataclass, field
from mautrix.types import EventID, RoomID from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database, Scheme from mautrix.util.async_db import Database, Scheme
from ..kt.types.bson import Long from ..kt.types.bson import Long, to_optional_long
fake_db = Database.create("") if TYPE_CHECKING else None fake_db = Database.create("") if TYPE_CHECKING else None
@ -34,7 +34,7 @@ class Message:
mxid: EventID mxid: EventID
mx_room: RoomID mx_room: RoomID
ktid: Long | None = field(converter=lambda ktid: Long(ktid) if ktid is not None else None) ktid: Long | None = field(converter=to_optional_long)
index: int index: int
kt_chat: Long = field(converter=Long) kt_chat: Long = field(converter=Long)
kt_receiver: Long = field(converter=Long) kt_receiver: Long = field(converter=Long)

View File

@ -1,3 +1,18 @@
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import UpgradeTable from mautrix.util.async_db import UpgradeTable
upgrade_table = UpgradeTable() upgrade_table = UpgradeTable()

View File

@ -23,7 +23,7 @@ from attr import dataclass, field
from mautrix.types import RoomID, UserID from mautrix.types import RoomID, UserID
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from ..kt.types.bson import Long from ..kt.types.bson import Long, to_optional_long
fake_db = Database.create("") if TYPE_CHECKING else None fake_db = Database.create("") if TYPE_CHECKING else None
@ -33,7 +33,7 @@ class User:
db: ClassVar[Database] = fake_db db: ClassVar[Database] = fake_db
mxid: UserID mxid: UserID
ktid: Long | None = field(converter=lambda x: Long(x) if x is not None else None) ktid: Long | None = field(converter=to_optional_long)
uuid: str | None uuid: str | None
access_token: str | None access_token: str | None
refresh_token: str | None refresh_token: str | None

View File

@ -1,2 +1,17 @@
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from .from_kakaotalk import kakaotalk_to_matrix from .from_kakaotalk import kakaotalk_to_matrix
from .from_matrix import matrix_to_kakaotalk from .from_matrix import matrix_to_kakaotalk

View File

@ -111,7 +111,7 @@ async def matrix_to_kakaotalk(
# NOTE By design, this *throws* if user intent can't be matched (i.e. if a reply can't be created) # NOTE By design, this *throws* if user intent can't be matched (i.e. if a reply can't be created)
if content.relates_to.rel_type == RelationType.REPLY and not skip_reply: if content.relates_to.rel_type == RelationType.REPLY and not skip_reply:
message = await DBMessage.get_by_mxid(content.relates_to.event_id, room_id) message = await DBMessage.get_by_mxid(content.relates_to.event_id, room_id)
if not message: if not message or not message.ktid:
raise ValueError( raise ValueError(
f"Couldn't find reply target {content.relates_to.event_id}" f"Couldn't find reply target {content.relates_to.event_id}"
" to bridge text message reply metadata to KakaoTalk" " to bridge text message reply metadata to KakaoTalk"
@ -167,17 +167,17 @@ async def matrix_to_kakaotalk(
mxid = mention.extra_info["user_id"] mxid = mention.extra_info["user_id"]
if mxid not in joined_members: if mxid not in joined_members:
continue continue
ktid = await _get_id_from_mxid(mxid) ktid = await _get_id_from_mxid(mxid, portal)
if ktid is None: if ktid is None:
continue continue
at += text[last_offset:mention.offset+1].count("@") at += text[last_offset:mention.offset+1].count("@")
last_offset = mention.offset+1 last_offset = mention.offset+1
mention = mentions_by_user.setdefault(ktid, MentionStruct( mention_by_user = mentions_by_user.setdefault(ktid, MentionStruct(
at=[], at=[],
len=mention.length, len=mention.length,
user_id=ktid, user_id=ktid,
)) ))
mention.at.append(at) mention_by_user.at.append(at)
mentions = list(mentions_by_user.values()) if mentions_by_user else None mentions = list(mentions_by_user.values()) if mentions_by_user else None
else: else:
text = content.body text = content.body

View File

@ -137,7 +137,7 @@ class Client:
return cls._api_request_void("register_device", passcode=passcode, **req) return cls._api_request_void("register_device", passcode=passcode, **req)
@classmethod @classmethod
async def login(cls, **req: JSON) -> Awaitable[OAuthCredential]: def login(cls, **req: JSON) -> Awaitable[OAuthCredential]:
""" """
Obtain a session token by logging in with user-provided credentials. Obtain a session token by logging in with user-provided credentials.
Must have first called register_device with these credentials. Must have first called register_device with these credentials.
@ -493,7 +493,7 @@ class Client:
async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None: async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
while True: while True:
try: try:
return await self._api_user_request_result( await self._api_user_request_void(
command, oauth_credential=self._oauth_credential, renew=False, **data command, oauth_credential=self._oauth_credential, renew=False, **data
) )
except InvalidAccessToken: except InvalidAccessToken:
@ -599,7 +599,7 @@ class Client:
res = None res = None
return self._on_disconnect(res) return self._on_disconnect(res)
def _on_switch_server(self) -> Awaitable[None]: def _on_switch_server(self, _: dict[str, JSON]) -> Awaitable[None]:
# TODO Reconnect automatically instead # TODO Reconnect automatically instead
return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER)) return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))

View File

@ -70,7 +70,7 @@ class ResponseError(Exception):
pass pass
_status_code_message_map: dict[KnownAuthStatusCode | KnownDataStatusCode | int] = { _status_code_message_map: dict[KnownAuthStatusCode | KnownDataStatusCode | int, str] = {
KnownAuthStatusCode.INVALID_PHONE_NUMBER: "Invalid phone number", KnownAuthStatusCode.INVALID_PHONE_NUMBER: "Invalid phone number",
KnownAuthStatusCode.SUCCESS_WITH_ACCOUNT: "Success", KnownAuthStatusCode.SUCCESS_WITH_ACCOUNT: "Success",
KnownAuthStatusCode.SUCCESS_WITH_DEVICE_CHANGED: "Success (device changed)", KnownAuthStatusCode.SUCCESS_WITH_DEVICE_CHANGED: "Success (device changed)",

View File

@ -26,9 +26,9 @@ from ...bson import Long
class OpenChatSettingsStruct(SerializableAttrs): class OpenChatSettingsStruct(SerializableAttrs):
chatMemberMaxJoin: int chatMemberMaxJoin: int
chatRoomMaxJoin: int chatRoomMaxJoin: int
createLinkLimit: 10; createLinkLimit = 10
createCardLinkLimit: 3; createCardLinkLimit = 3
numOfStaffLimit: 5; numOfStaffLimit = 5
rewritable: bool rewritable: bool
handoverEnabled: bool handoverEnabled: bool

View File

@ -13,6 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any
from mautrix.types import Serializable, JSON from mautrix.types import Serializable, JSON
@ -23,3 +25,6 @@ class Long(int, Serializable):
@classmethod @classmethod
def deserialize(cls, raw: JSON) -> "Long": def deserialize(cls, raw: JSON) -> "Long":
return cls(raw) return cls(raw)
def to_optional_long(x: Any | None) -> Long | None:
return Long(x) if x is not None else None

View File

@ -21,7 +21,6 @@ from attr import dataclass
from mautrix.types import SerializableAttrs from mautrix.types import SerializableAttrs
@dataclass
class KnownKickoutType(IntEnum): class KnownKickoutType(IntEnum):
CHANGE_SERVER = -2 CHANGE_SERVER = -2
LOGIN_ANOTHER = 0 LOGIN_ANOTHER = 0

View File

@ -21,8 +21,10 @@ from typing import (
AsyncGenerator, AsyncGenerator,
Awaitable, Awaitable,
Callable, Callable,
NamedTuple, Coroutine,
Generic,
Pattern, Pattern,
TypeVar,
cast, cast,
) )
from io import BytesIO from io import BytesIO
@ -31,6 +33,8 @@ import asyncio
import re import re
import time import time
from attr import dataclass
from mautrix.appservice import IntentAPI from mautrix.appservice import IntentAPI
from mautrix.bridge import BasePortal, NotificationDisabler, async_getter_lock from mautrix.bridge import BasePortal, NotificationDisabler, async_getter_lock
from mautrix.errors import MatrixError, MForbidden, MNotFound, SessionNotFound from mautrix.errors import MatrixError, MForbidden, MNotFound, SessionNotFound
@ -125,11 +129,15 @@ class FakeLock:
pass pass
class StateEventHandler(NamedTuple): T = TypeVar("T")
# TODO Can this use Generic to force the two StateEventContent parameters to be of the same type? ACallable = Coroutine[None, None, T]
# Or, just have a single StateEvent parameter
apply: Callable[[u.User, StateEventContent, StateEventContent], Awaitable[None]] StateEventHandlerContentType = TypeVar("StateEventHandlerContentType", bound=StateEventContent)
revert: Callable[[StateEventContent], Awaitable[None]]
@dataclass
class StateEventHandler(Generic[StateEventHandlerContentType]):
apply: Callable[[Portal, u.User, StateEventHandlerContentType, StateEventHandlerContentType], ACallable[None]]
revert: Callable[[Portal, StateEventHandlerContentType], ACallable[None]]
action_name: str action_name: str
@ -155,6 +163,9 @@ class Portal(DBPortal, BasePortal):
_scheduled_resync: asyncio.Task | None _scheduled_resync: asyncio.Task | None
_resync_targets: dict[int, p.Puppet] _resync_targets: dict[int, p.Puppet]
_CHAT_TYPE_HANDLER_MAP: dict[ChatType, Callable[..., ACallable[list[EventID]]]]
_STATE_EVENT_HANDLER_MAP: dict[EventType, StateEventHandler]
def __init__( def __init__(
self, self,
ktid: Long, ktid: Long,
@ -226,7 +237,7 @@ class Portal(DBPortal, BasePortal):
16385: cls._handle_kakaotalk_deleted, 16385: cls._handle_kakaotalk_deleted,
} }
cls._STATE_EVENT_HANDLER_MAP: dict[EventType, StateEventHandler] = { cls._STATE_EVENT_HANDLER_MAP = {
EventType.ROOM_POWER_LEVELS: StateEventHandler( EventType.ROOM_POWER_LEVELS: StateEventHandler(
cls._handle_matrix_power_levels, cls._handle_matrix_power_levels,
cls._revert_matrix_power_levels, cls._revert_matrix_power_levels,
@ -506,7 +517,7 @@ class Portal(DBPortal, BasePortal):
decryption_info.url = url decryption_info.url = url
return url, info, decryption_info return url, info, decryption_info
async def _update_name(self, name: str) -> bool: async def _update_name(self, name: str | None) -> bool:
if not name: if not name:
self.log.warning("Got empty name in _update_name call") self.log.warning("Got empty name in _update_name call")
return False return False
@ -593,6 +604,8 @@ class Portal(DBPortal, BasePortal):
return False return False
if not puppet: if not puppet:
puppet = await self.get_dm_puppet() puppet = await self.get_dm_puppet()
if not puppet:
return False
changed = await self._update_name(puppet.name) changed = await self._update_name(puppet.name)
changed = await self._update_photo_from_puppet(puppet) or changed changed = await self._update_photo_from_puppet(puppet) or changed
return changed return changed
@ -815,15 +828,20 @@ class Portal(DBPortal, BasePortal):
async def _create_matrix_room( async def _create_matrix_room(
self, source: u.User, info: PortalChannelInfo | None = None self, source: u.User, info: PortalChannelInfo | None = None
) -> RoomID | None: ) -> RoomID:
if self.mxid: if self.mxid:
await self._update_matrix_room(source, info) await self._update_matrix_room(source, info)
return self.mxid return self.mxid
self.log.debug(f"Creating Matrix room") self.log.debug(f"Creating Matrix room")
if self.is_direct: if self.is_direct:
# NOTE Must do this to find the other member of the DM, since the channel ID != the member's ID! # NOTE Must do this to find the other member of the DM, since the channel ID != the member's ID!
if not info or not info.participantInfo:
info = await source.client.get_portal_channel_info(self.channel_props)
assert info.participantInfo
await self._update_participants(source, info.participantInfo) await self._update_participants(source, info.participantInfo)
name: str | None = None name: str | None = None
description: str | None = None description: str | None = None
initial_state = [ initial_state = [
@ -843,6 +861,9 @@ class Portal(DBPortal, BasePortal):
if self.is_open: if self.is_open:
preset = RoomCreatePreset.PUBLIC preset = RoomCreatePreset.PUBLIC
# TODO Find whether perms apply to any non-direct channel, or just open ones # TODO Find whether perms apply to any non-direct channel, or just open ones
if not info or not info.participantInfo:
info = await source.client.get_portal_channel_info(self.channel_props)
assert info.participantInfo
user_power_levels = await self._get_mapped_participant_power_levels(info.participantInfo.participants) user_power_levels = await self._get_mapped_participant_power_levels(info.participantInfo.participants)
# NOTE Giving the bot a +1 power level if necessary so it can demote non-puppet admins # NOTE Giving the bot a +1 power level if necessary so it can demote non-puppet admins
user_power_levels[self.main_intent.mxid] = max(100, 1 + FROM_PERM_MAP[OpenChannelUserPerm.OWNER]) user_power_levels[self.main_intent.mxid] = max(100, 1 + FROM_PERM_MAP[OpenChannelUserPerm.OWNER])
@ -924,6 +945,8 @@ class Portal(DBPortal, BasePortal):
await self._update_participants(source, info.participantInfo) await self._update_participants(source, info.participantInfo)
try: try:
# TODO Think of better typing for this
assert info.channel_info
await self.backfill(source, is_initial=True, channel_info=info.channel_info) await self.backfill(source, is_initial=True, channel_info=info.channel_info)
except Exception: except Exception:
self.log.exception("Failed to backfill new portal") self.log.exception("Failed to backfill new portal")
@ -1075,10 +1098,11 @@ class Portal(DBPortal, BasePortal):
mimetype = message.info.mimetype or magic.mimetype(data) mimetype = message.info.mimetype or magic.mimetype(data)
filename = message.body filename = message.body
width, height = None, None width, height = None, None
if message.info in (MessageType.IMAGE, MessageType.STICKER, MessageType.VIDEO): if message.msgtype in (MessageType.IMAGE, MessageType.STICKER, MessageType.VIDEO):
width = message.info.width width = message.info.width
height = message.info.height height = message.info.height
try: try:
ext = guess_extension(mimetype)
log_id = await sender.client.send_media( log_id = await sender.client.send_media(
self.channel_props, self.channel_props,
TO_MSGTYPE_MAP[message.msgtype], TO_MSGTYPE_MAP[message.msgtype],
@ -1086,7 +1110,7 @@ class Portal(DBPortal, BasePortal):
filename, filename,
width=width, width=width,
height=height, height=height,
ext=guess_extension(mimetype)[1:], ext=ext[1:] if ext else "",
) )
except CommandException as e: except CommandException as e:
self.log.debug(f"Error uploading media for Matrix message {event_id}: {e!s}") self.log.debug(f"Error uploading media for Matrix message {event_id}: {e!s}")
@ -1191,7 +1215,8 @@ class Portal(DBPortal, BasePortal):
return return
try: try:
effective_sender, _ = await self.get_relay_sender(sender, f"{handler.action_name} {evt.event_id}") effective_sender, _ = await self.get_relay_sender(sender, f"{handler.action_name} {evt.event_id}")
await handler.apply(self, effective_sender, evt.prev_content, evt.content) if effective_sender:
await handler.apply(self, effective_sender, evt.prev_content, evt.content)
except Exception as e: except Exception as e:
self.log.error( self.log.error(
f"Failed to handle Matrix {handler.action_name} {evt.event_id}: {e}", f"Failed to handle Matrix {handler.action_name} {evt.event_id}: {e}",
@ -1314,7 +1339,7 @@ class Portal(DBPortal, BasePortal):
await self.save() await self.save()
async def _revert_matrix_room_topic(self, prev_content: RoomTopicStateEventContent) -> None: async def _revert_matrix_room_topic(self, prev_content: RoomTopicStateEventContent) -> None:
await self.main_intent.set_room_topic(self.mxid, prev_content.topic) await self.main_intent.set_room_topic(self.mxid, prev_content.topic or "")
async def _handle_matrix_room_avatar( async def _handle_matrix_room_avatar(
self, self,
@ -1479,7 +1504,7 @@ class Portal(DBPortal, BasePortal):
chat_text: str | None, chat_text: str | None,
chat_type: ChatType, chat_type: ChatType,
**_ **_
) -> Awaitable[list[EventID]]: ) -> list[EventID]:
try: try:
type_str = KnownChatType(chat_type).name.lower() type_str = KnownChatType(chat_type).name.lower()
except ValueError: except ValueError:
@ -1546,14 +1571,14 @@ class Portal(DBPortal, BasePortal):
await self._add_kakaotalk_reply(content, attachment) await self._add_kakaotalk_reply(content, attachment)
return [await self._send_message(intent, content, timestamp=timestamp)] return [await self._send_message(intent, content, timestamp=timestamp)]
def _handle_kakaotalk_photo(self, **kwargs) -> Awaitable[list[EventID]]: async def _handle_kakaotalk_photo(self, **kwargs) -> list[EventID]:
return asyncio.gather(self._handle_kakaotalk_uniphoto(**kwargs)) return [await self._handle_kakaotalk_uniphoto(**kwargs)]
async def _handle_kakaotalk_multiphoto( async def _handle_kakaotalk_multiphoto(
self, self,
attachment: MultiPhotoAttachment, attachment: MultiPhotoAttachment,
**kwargs **kwargs
) -> Awaitable[list[EventID]]: ) -> list[EventID]:
# TODO Upload media concurrently, but post messages sequentially # TODO Upload media concurrently, but post messages sequentially
return [ return [
await self._handle_kakaotalk_uniphoto( await self._handle_kakaotalk_uniphoto(
@ -1594,12 +1619,12 @@ class Portal(DBPortal, BasePortal):
**kwargs **kwargs
) )
def _handle_kakaotalk_video( async def _handle_kakaotalk_video(
self, self,
attachment: VideoAttachment, attachment: VideoAttachment,
**kwargs **kwargs
) -> Awaitable[list[EventID]]: ) -> list[EventID]:
return asyncio.gather(self._handle_kakaotalk_media( return [await self._handle_kakaotalk_media(
attachment, attachment,
VideoInfo( VideoInfo(
duration=attachment.d, duration=attachment.d,
@ -1608,14 +1633,14 @@ class Portal(DBPortal, BasePortal):
), ),
MessageType.VIDEO, MessageType.VIDEO,
**kwargs **kwargs
)) )]
def _handle_kakaotalk_audio( async def _handle_kakaotalk_audio(
self, self,
attachment: AudioAttachment, attachment: AudioAttachment,
**kwargs **kwargs
) -> Awaitable[list[EventID]]: ) -> list[EventID]:
return asyncio.gather(self._handle_kakaotalk_media( return [await self._handle_kakaotalk_media(
attachment, attachment,
AudioInfo( AudioInfo(
size=attachment.s, size=attachment.s,
@ -1623,11 +1648,11 @@ class Portal(DBPortal, BasePortal):
), ),
MessageType.AUDIO, MessageType.AUDIO,
**kwargs **kwargs
)) )]
async def _handle_kakaotalk_media( async def _handle_kakaotalk_media(
self, self,
attachment: MediaAttachment, attachment: MediaAttachment | AudioAttachment,
info: MediaInfo, info: MediaInfo,
msgtype: MessageType, msgtype: MessageType,
*, *,
@ -1648,7 +1673,7 @@ class Portal(DBPortal, BasePortal):
info.size = additional_info.size info.size = additional_info.size
info.mimetype = additional_info.mimetype info.mimetype = additional_info.mimetype
content = MediaMessageEventContent( content = MediaMessageEventContent(
url=mxc, file=decryption_info, msgtype=msgtype, body=chat_text, info=info url=mxc, file=decryption_info, msgtype=msgtype, body=chat_text or "", info=info
) )
return await self._send_message(intent, content, timestamp=timestamp) return await self._send_message(intent, content, timestamp=timestamp)
@ -1787,7 +1812,7 @@ class Portal(DBPortal, BasePortal):
# TODO Should this be removed? With it, can't sync an empty portal! # TODO Should this be removed? With it, can't sync an empty portal!
#elif (not most_recent or not most_recent.timestamp) and not is_initial: #elif (not most_recent or not most_recent.timestamp) and not is_initial:
# self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log) # self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log)
elif last_log_id and most_recent and int(most_recent.ktid) >= int(last_log_id): elif last_log_id and most_recent and int(most_recent.ktid or 0) >= int(last_log_id):
self.log.debug( self.log.debug(
"Not backfilling %s: last activity is equal to most recent bridged " "Not backfilling %s: last activity is equal to most recent bridged "
"message (%s >= %s)", "message (%s >= %s)",

View File

@ -89,7 +89,7 @@ class User(DBUser, BaseUser):
by_mxid: dict[UserID, User] = {} by_mxid: dict[UserID, User] = {}
by_ktid: dict[int, User] = {} by_ktid: dict[int, User] = {}
client: Client | None _client: Client | None
_notice_room_lock: asyncio.Lock _notice_room_lock: asyncio.Lock
_notice_send_lock: asyncio.Lock _notice_send_lock: asyncio.Lock
@ -141,7 +141,7 @@ class User(DBUser, BaseUser):
self._logged_in_info = None self._logged_in_info = None
self._logged_in_info_time = 0 self._logged_in_info_time = 0
self.client = None self._client = None
@classmethod @classmethod
def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[bool]]: def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[bool]]:
@ -151,6 +151,12 @@ class User(DBUser, BaseUser):
cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"] cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
return (user.reload_session(is_startup=True) async for user in cls.all_logged_in()) return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())
@property
def client(self) -> Client:
if not self._client:
raise ValueError("User must be logged in before its client can be used")
return self._client
@property @property
def is_connected(self) -> bool | None: def is_connected(self) -> bool | None:
return self._is_connected return self._is_connected
@ -242,7 +248,10 @@ class User(DBUser, BaseUser):
@property @property
def oauth_credential(self) -> OAuthCredential: def oauth_credential(self) -> OAuthCredential:
assert None not in (self.ktid, self.uuid, self.access_token, self.refresh_token) assert self.ktid is not None
assert self.uuid is not None
assert self.access_token is not None
assert self.refresh_token is not None
return OAuthCredential( return OAuthCredential(
self.ktid, self.ktid,
self.uuid, self.uuid,
@ -259,9 +268,9 @@ class User(DBUser, BaseUser):
self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}") self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
self.uuid = oauth_credential.deviceUUID self.uuid = oauth_credential.deviceUUID
async def get_own_info(self) -> SettingsStruct: async def get_own_info(self) -> SettingsStruct | None:
if not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic(): if self._client and (not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic()):
self._logged_in_info = await self.client.get_settings() self._logged_in_info = await self._client.get_settings()
self._logged_in_info_time = time.monotonic() self._logged_in_info_time = time.monotonic()
return self._logged_in_info return self._logged_in_info
@ -280,7 +289,7 @@ class User(DBUser, BaseUser):
user_info = await client.start() user_info = await client.start()
# NOTE On failure, client.start throws instead of returning something falsy # NOTE On failure, client.start throws instead of returning something falsy
self.log.info("Loaded session successfully") self.log.info("Loaded session successfully")
self.client = client self._client = client
self._logged_in_info = user_info self._logged_in_info = user_info
self._logged_in_info_time = time.monotonic() self._logged_in_info_time = time.monotonic()
self._track_metric(METRIC_LOGGED_IN, True) self._track_metric(METRIC_LOGGED_IN, True)
@ -304,7 +313,7 @@ class User(DBUser, BaseUser):
await self.logout(remove_ktid=False) await self.logout(remove_ktid=False)
async def is_logged_in(self, _override: bool = False) -> bool: async def is_logged_in(self, _override: bool = False) -> bool:
if not self.has_state or not self.client: if not self.has_state or not self._client:
return False return False
if self._is_logged_in is None or _override: if self._is_logged_in is None or _override:
try: try:
@ -360,9 +369,9 @@ class User(DBUser, BaseUser):
self._is_rpc_reconnecting = False self._is_rpc_reconnecting = False
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None: async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
if self.client: if self._client:
# TODO Look for a logout API call # TODO Look for a logout API call
await self.client.stop() await self._client.stop()
if remove_ktid: if remove_ktid:
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT) await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
self._track_metric(METRIC_LOGGED_IN, False) self._track_metric(METRIC_LOGGED_IN, False)
@ -374,7 +383,7 @@ class User(DBUser, BaseUser):
self._is_logged_in = False self._is_logged_in = False
self.is_connected = None self.is_connected = None
self.client = None self._client = None
if self.ktid and remove_ktid: if self.ktid and remove_ktid:
#await UserPortal.delete_all(self.ktid) #await UserPortal.delete_all(self.ktid)
@ -581,12 +590,12 @@ class User(DBUser, BaseUser):
state.remote_name = puppet.name state.remote_name = puppet.name
async def get_bridge_states(self) -> list[BridgeState]: async def get_bridge_states(self) -> list[BridgeState]:
if not self.state: if not self.has_state:
return [] return []
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR) state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
if self.is_connected: if self.is_connected:
state.state_event = BridgeStateEvent.CONNECTED state.state_event = BridgeStateEvent.CONNECTED
elif self._is_rpc_reconnecting or self.client: elif self._is_rpc_reconnecting or self._client:
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
return [state] return [state]
@ -660,8 +669,8 @@ class User(DBUser, BaseUser):
await self.logout() await self.logout()
await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}") await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}")
def on_error(self, error: JSON) -> Awaitable[None]: async def on_error(self, error: JSON) -> None:
return self.send_bridge_notice( await self.send_bridge_notice(
f"Got error event from KakaoTalk:\n\n> {error}", f"Got error event from KakaoTalk:\n\n> {error}",
# TODO Which error code to use? # TODO Which error code to use?
#error_code="kt-connection-error", #error_code="kt-connection-error",
@ -671,7 +680,7 @@ class User(DBUser, BaseUser):
async def on_client_disconnect(self) -> None: async def on_client_disconnect(self) -> None:
self.is_connected = False self.is_connected = False
self._track_metric(METRIC_CONNECTED, False) self._track_metric(METRIC_CONNECTED, False)
self.client = None self._client = None
if self._is_logged_in: if self._is_logged_in:
if self.temp_disconnect_notices: if self.temp_disconnect_notices:
await self.send_bridge_notice( await self.send_bridge_notice(
@ -685,12 +694,12 @@ class User(DBUser, BaseUser):
self.log.debug(f"Successfully logged in as {oauth_credential.userId}") self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
self.oauth_credential = oauth_credential self.oauth_credential = oauth_credential
await self.push_bridge_state(BridgeStateEvent.CONNECTING) await self.push_bridge_state(BridgeStateEvent.CONNECTING)
self.client = Client(self, log=self.log.getChild("ktclient")) self._client = Client(self, log=self.log.getChild("ktclient"))
await self.save() await self.save()
self._is_logged_in = True self._is_logged_in = True
# TODO Retry network connection failures here, or in the client (like token refreshes are)? # TODO Retry network connection failures here, or in the client (like token refreshes are)?
# Should also catch unlikely authentication errors # Should also catch unlikely authentication errors
self._logged_in_info = await self.client.start() self._logged_in_info = await self._client.start()
self._logged_in_info_time = time.monotonic() self._logged_in_info_time = time.monotonic()
asyncio.create_task(self.post_login(is_startup=True)) asyncio.create_task(self.post_login(is_startup=True))