Many fixes thanks to mypy
Also add some missing license headers
This commit is contained in:
parent
2143282195
commit
e952c05d35
|
@ -1,3 +1,18 @@
|
||||||
|
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from .auth import SECTION_AUTH
|
from .auth import SECTION_AUTH
|
||||||
from .conn import SECTION_CONNECTION
|
from .conn import SECTION_CONNECTION
|
||||||
from .kakaotalk import SECTION_FRIENDS
|
from .kakaotalk import SECTION_FRIENDS
|
||||||
|
|
|
@ -60,16 +60,19 @@ async def whoami(evt: CommandEvent) -> None:
|
||||||
await evt.mark_read()
|
await evt.mark_read()
|
||||||
try:
|
try:
|
||||||
own_info = await evt.sender.get_own_info()
|
own_info = await evt.sender.get_own_info()
|
||||||
|
except SerializerError:
|
||||||
|
evt.sender.log.exception("Failed to deserialize settings struct")
|
||||||
|
own_info = None
|
||||||
|
except CommandException as e:
|
||||||
|
await evt.reply(f"Error from KakaoTalk: {e}")
|
||||||
|
if own_info:
|
||||||
await evt.reply(
|
await evt.reply(
|
||||||
f"You're logged in as `{own_info.more.uuid}` (nickname: {own_info.more.nickName}, user ID: {evt.sender.ktid})."
|
f"You're logged in as `{own_info.more.uuid}` (nickname: {own_info.more.nickName}, user ID: {evt.sender.ktid})."
|
||||||
)
|
)
|
||||||
except SerializerError:
|
else:
|
||||||
evt.sender.log.exception("Failed to deserialize settings struct")
|
|
||||||
await evt.reply(
|
await evt.reply(
|
||||||
f"You're logged in, but the bridge is unable to retrieve your profile information (user ID: {evt.sender.ktid})."
|
f"You're logged in, but the bridge is unable to retrieve your profile information (user ID: {evt.sender.ktid})."
|
||||||
)
|
)
|
||||||
except CommandException as e:
|
|
||||||
await evt.reply(f"Error from KakaoTalk: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
@command_handler(
|
@command_handler(
|
||||||
|
|
|
@ -1,12 +1,31 @@
|
||||||
|
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from mautrix.bridge.commands import CommandEvent as BaseCommandEvent
|
from mautrix.bridge.commands import CommandEvent as BaseCommandEvent
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..__main__ import KakaoTalkBridge
|
from ..__main__ import KakaoTalkBridge
|
||||||
|
from ..portal import Portal
|
||||||
from ..user import User
|
from ..user import User
|
||||||
|
|
||||||
|
|
||||||
class CommandEvent(BaseCommandEvent):
|
class CommandEvent(BaseCommandEvent):
|
||||||
bridge: "KakaoTalkBridge"
|
bridge: KakaoTalkBridge
|
||||||
sender: "User"
|
portal: Portal
|
||||||
|
sender: User
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from mautrix.util.async_db import Database
|
from mautrix.util.async_db import Database
|
||||||
|
|
||||||
from .message import Message
|
from .message import Message
|
||||||
|
|
|
@ -23,7 +23,7 @@ from attr import dataclass, field
|
||||||
from mautrix.types import EventID, RoomID
|
from mautrix.types import EventID, RoomID
|
||||||
from mautrix.util.async_db import Database, Scheme
|
from mautrix.util.async_db import Database, Scheme
|
||||||
|
|
||||||
from ..kt.types.bson import Long
|
from ..kt.types.bson import Long, to_optional_long
|
||||||
|
|
||||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ class Message:
|
||||||
|
|
||||||
mxid: EventID
|
mxid: EventID
|
||||||
mx_room: RoomID
|
mx_room: RoomID
|
||||||
ktid: Long | None = field(converter=lambda ktid: Long(ktid) if ktid is not None else None)
|
ktid: Long | None = field(converter=to_optional_long)
|
||||||
index: int
|
index: int
|
||||||
kt_chat: Long = field(converter=Long)
|
kt_chat: Long = field(converter=Long)
|
||||||
kt_receiver: Long = field(converter=Long)
|
kt_receiver: Long = field(converter=Long)
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from mautrix.util.async_db import UpgradeTable
|
from mautrix.util.async_db import UpgradeTable
|
||||||
|
|
||||||
upgrade_table = UpgradeTable()
|
upgrade_table = UpgradeTable()
|
||||||
|
|
|
@ -23,7 +23,7 @@ from attr import dataclass, field
|
||||||
from mautrix.types import RoomID, UserID
|
from mautrix.types import RoomID, UserID
|
||||||
from mautrix.util.async_db import Database
|
from mautrix.util.async_db import Database
|
||||||
|
|
||||||
from ..kt.types.bson import Long
|
from ..kt.types.bson import Long, to_optional_long
|
||||||
|
|
||||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ class User:
|
||||||
db: ClassVar[Database] = fake_db
|
db: ClassVar[Database] = fake_db
|
||||||
|
|
||||||
mxid: UserID
|
mxid: UserID
|
||||||
ktid: Long | None = field(converter=lambda x: Long(x) if x is not None else None)
|
ktid: Long | None = field(converter=to_optional_long)
|
||||||
uuid: str | None
|
uuid: str | None
|
||||||
access_token: str | None
|
access_token: str | None
|
||||||
refresh_token: str | None
|
refresh_token: str | None
|
||||||
|
|
|
@ -1,2 +1,17 @@
|
||||||
|
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from .from_kakaotalk import kakaotalk_to_matrix
|
from .from_kakaotalk import kakaotalk_to_matrix
|
||||||
from .from_matrix import matrix_to_kakaotalk
|
from .from_matrix import matrix_to_kakaotalk
|
||||||
|
|
|
@ -111,7 +111,7 @@ async def matrix_to_kakaotalk(
|
||||||
# NOTE By design, this *throws* if user intent can't be matched (i.e. if a reply can't be created)
|
# NOTE By design, this *throws* if user intent can't be matched (i.e. if a reply can't be created)
|
||||||
if content.relates_to.rel_type == RelationType.REPLY and not skip_reply:
|
if content.relates_to.rel_type == RelationType.REPLY and not skip_reply:
|
||||||
message = await DBMessage.get_by_mxid(content.relates_to.event_id, room_id)
|
message = await DBMessage.get_by_mxid(content.relates_to.event_id, room_id)
|
||||||
if not message:
|
if not message or not message.ktid:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Couldn't find reply target {content.relates_to.event_id}"
|
f"Couldn't find reply target {content.relates_to.event_id}"
|
||||||
" to bridge text message reply metadata to KakaoTalk"
|
" to bridge text message reply metadata to KakaoTalk"
|
||||||
|
@ -167,17 +167,17 @@ async def matrix_to_kakaotalk(
|
||||||
mxid = mention.extra_info["user_id"]
|
mxid = mention.extra_info["user_id"]
|
||||||
if mxid not in joined_members:
|
if mxid not in joined_members:
|
||||||
continue
|
continue
|
||||||
ktid = await _get_id_from_mxid(mxid)
|
ktid = await _get_id_from_mxid(mxid, portal)
|
||||||
if ktid is None:
|
if ktid is None:
|
||||||
continue
|
continue
|
||||||
at += text[last_offset:mention.offset+1].count("@")
|
at += text[last_offset:mention.offset+1].count("@")
|
||||||
last_offset = mention.offset+1
|
last_offset = mention.offset+1
|
||||||
mention = mentions_by_user.setdefault(ktid, MentionStruct(
|
mention_by_user = mentions_by_user.setdefault(ktid, MentionStruct(
|
||||||
at=[],
|
at=[],
|
||||||
len=mention.length,
|
len=mention.length,
|
||||||
user_id=ktid,
|
user_id=ktid,
|
||||||
))
|
))
|
||||||
mention.at.append(at)
|
mention_by_user.at.append(at)
|
||||||
mentions = list(mentions_by_user.values()) if mentions_by_user else None
|
mentions = list(mentions_by_user.values()) if mentions_by_user else None
|
||||||
else:
|
else:
|
||||||
text = content.body
|
text = content.body
|
||||||
|
|
|
@ -137,7 +137,7 @@ class Client:
|
||||||
return cls._api_request_void("register_device", passcode=passcode, **req)
|
return cls._api_request_void("register_device", passcode=passcode, **req)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def login(cls, **req: JSON) -> Awaitable[OAuthCredential]:
|
def login(cls, **req: JSON) -> Awaitable[OAuthCredential]:
|
||||||
"""
|
"""
|
||||||
Obtain a session token by logging in with user-provided credentials.
|
Obtain a session token by logging in with user-provided credentials.
|
||||||
Must have first called register_device with these credentials.
|
Must have first called register_device with these credentials.
|
||||||
|
@ -493,7 +493,7 @@ class Client:
|
||||||
async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
|
async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
return await self._api_user_request_result(
|
await self._api_user_request_void(
|
||||||
command, oauth_credential=self._oauth_credential, renew=False, **data
|
command, oauth_credential=self._oauth_credential, renew=False, **data
|
||||||
)
|
)
|
||||||
except InvalidAccessToken:
|
except InvalidAccessToken:
|
||||||
|
@ -599,7 +599,7 @@ class Client:
|
||||||
res = None
|
res = None
|
||||||
return self._on_disconnect(res)
|
return self._on_disconnect(res)
|
||||||
|
|
||||||
def _on_switch_server(self) -> Awaitable[None]:
|
def _on_switch_server(self, _: dict[str, JSON]) -> Awaitable[None]:
|
||||||
# TODO Reconnect automatically instead
|
# TODO Reconnect automatically instead
|
||||||
return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))
|
return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))
|
||||||
|
|
||||||
|
|
|
@ -70,7 +70,7 @@ class ResponseError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
_status_code_message_map: dict[KnownAuthStatusCode | KnownDataStatusCode | int] = {
|
_status_code_message_map: dict[KnownAuthStatusCode | KnownDataStatusCode | int, str] = {
|
||||||
KnownAuthStatusCode.INVALID_PHONE_NUMBER: "Invalid phone number",
|
KnownAuthStatusCode.INVALID_PHONE_NUMBER: "Invalid phone number",
|
||||||
KnownAuthStatusCode.SUCCESS_WITH_ACCOUNT: "Success",
|
KnownAuthStatusCode.SUCCESS_WITH_ACCOUNT: "Success",
|
||||||
KnownAuthStatusCode.SUCCESS_WITH_DEVICE_CHANGED: "Success (device changed)",
|
KnownAuthStatusCode.SUCCESS_WITH_DEVICE_CHANGED: "Success (device changed)",
|
||||||
|
|
|
@ -26,9 +26,9 @@ from ...bson import Long
|
||||||
class OpenChatSettingsStruct(SerializableAttrs):
|
class OpenChatSettingsStruct(SerializableAttrs):
|
||||||
chatMemberMaxJoin: int
|
chatMemberMaxJoin: int
|
||||||
chatRoomMaxJoin: int
|
chatRoomMaxJoin: int
|
||||||
createLinkLimit: 10;
|
createLinkLimit = 10
|
||||||
createCardLinkLimit: 3;
|
createCardLinkLimit = 3
|
||||||
numOfStaffLimit: 5;
|
numOfStaffLimit = 5
|
||||||
rewritable: bool
|
rewritable: bool
|
||||||
handoverEnabled: bool
|
handoverEnabled: bool
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from mautrix.types import Serializable, JSON
|
from mautrix.types import Serializable, JSON
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,3 +25,6 @@ class Long(int, Serializable):
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize(cls, raw: JSON) -> "Long":
|
def deserialize(cls, raw: JSON) -> "Long":
|
||||||
return cls(raw)
|
return cls(raw)
|
||||||
|
|
||||||
|
def to_optional_long(x: Any | None) -> Long | None:
|
||||||
|
return Long(x) if x is not None else None
|
||||||
|
|
|
@ -21,7 +21,6 @@ from attr import dataclass
|
||||||
from mautrix.types import SerializableAttrs
|
from mautrix.types import SerializableAttrs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class KnownKickoutType(IntEnum):
|
class KnownKickoutType(IntEnum):
|
||||||
CHANGE_SERVER = -2
|
CHANGE_SERVER = -2
|
||||||
LOGIN_ANOTHER = 0
|
LOGIN_ANOTHER = 0
|
||||||
|
|
|
@ -21,8 +21,10 @@ from typing import (
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
NamedTuple,
|
Coroutine,
|
||||||
|
Generic,
|
||||||
Pattern,
|
Pattern,
|
||||||
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
@ -31,6 +33,8 @@ import asyncio
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from attr import dataclass
|
||||||
|
|
||||||
from mautrix.appservice import IntentAPI
|
from mautrix.appservice import IntentAPI
|
||||||
from mautrix.bridge import BasePortal, NotificationDisabler, async_getter_lock
|
from mautrix.bridge import BasePortal, NotificationDisabler, async_getter_lock
|
||||||
from mautrix.errors import MatrixError, MForbidden, MNotFound, SessionNotFound
|
from mautrix.errors import MatrixError, MForbidden, MNotFound, SessionNotFound
|
||||||
|
@ -125,11 +129,15 @@ class FakeLock:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StateEventHandler(NamedTuple):
|
T = TypeVar("T")
|
||||||
# TODO Can this use Generic to force the two StateEventContent parameters to be of the same type?
|
ACallable = Coroutine[None, None, T]
|
||||||
# Or, just have a single StateEvent parameter
|
|
||||||
apply: Callable[[u.User, StateEventContent, StateEventContent], Awaitable[None]]
|
StateEventHandlerContentType = TypeVar("StateEventHandlerContentType", bound=StateEventContent)
|
||||||
revert: Callable[[StateEventContent], Awaitable[None]]
|
|
||||||
|
@dataclass
|
||||||
|
class StateEventHandler(Generic[StateEventHandlerContentType]):
|
||||||
|
apply: Callable[[Portal, u.User, StateEventHandlerContentType, StateEventHandlerContentType], ACallable[None]]
|
||||||
|
revert: Callable[[Portal, StateEventHandlerContentType], ACallable[None]]
|
||||||
action_name: str
|
action_name: str
|
||||||
|
|
||||||
|
|
||||||
|
@ -155,6 +163,9 @@ class Portal(DBPortal, BasePortal):
|
||||||
_scheduled_resync: asyncio.Task | None
|
_scheduled_resync: asyncio.Task | None
|
||||||
_resync_targets: dict[int, p.Puppet]
|
_resync_targets: dict[int, p.Puppet]
|
||||||
|
|
||||||
|
_CHAT_TYPE_HANDLER_MAP: dict[ChatType, Callable[..., ACallable[list[EventID]]]]
|
||||||
|
_STATE_EVENT_HANDLER_MAP: dict[EventType, StateEventHandler]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ktid: Long,
|
ktid: Long,
|
||||||
|
@ -226,7 +237,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
16385: cls._handle_kakaotalk_deleted,
|
16385: cls._handle_kakaotalk_deleted,
|
||||||
}
|
}
|
||||||
|
|
||||||
cls._STATE_EVENT_HANDLER_MAP: dict[EventType, StateEventHandler] = {
|
cls._STATE_EVENT_HANDLER_MAP = {
|
||||||
EventType.ROOM_POWER_LEVELS: StateEventHandler(
|
EventType.ROOM_POWER_LEVELS: StateEventHandler(
|
||||||
cls._handle_matrix_power_levels,
|
cls._handle_matrix_power_levels,
|
||||||
cls._revert_matrix_power_levels,
|
cls._revert_matrix_power_levels,
|
||||||
|
@ -506,7 +517,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
decryption_info.url = url
|
decryption_info.url = url
|
||||||
return url, info, decryption_info
|
return url, info, decryption_info
|
||||||
|
|
||||||
async def _update_name(self, name: str) -> bool:
|
async def _update_name(self, name: str | None) -> bool:
|
||||||
if not name:
|
if not name:
|
||||||
self.log.warning("Got empty name in _update_name call")
|
self.log.warning("Got empty name in _update_name call")
|
||||||
return False
|
return False
|
||||||
|
@ -593,6 +604,8 @@ class Portal(DBPortal, BasePortal):
|
||||||
return False
|
return False
|
||||||
if not puppet:
|
if not puppet:
|
||||||
puppet = await self.get_dm_puppet()
|
puppet = await self.get_dm_puppet()
|
||||||
|
if not puppet:
|
||||||
|
return False
|
||||||
changed = await self._update_name(puppet.name)
|
changed = await self._update_name(puppet.name)
|
||||||
changed = await self._update_photo_from_puppet(puppet) or changed
|
changed = await self._update_photo_from_puppet(puppet) or changed
|
||||||
return changed
|
return changed
|
||||||
|
@ -815,15 +828,20 @@ class Portal(DBPortal, BasePortal):
|
||||||
|
|
||||||
async def _create_matrix_room(
|
async def _create_matrix_room(
|
||||||
self, source: u.User, info: PortalChannelInfo | None = None
|
self, source: u.User, info: PortalChannelInfo | None = None
|
||||||
) -> RoomID | None:
|
) -> RoomID:
|
||||||
if self.mxid:
|
if self.mxid:
|
||||||
await self._update_matrix_room(source, info)
|
await self._update_matrix_room(source, info)
|
||||||
return self.mxid
|
return self.mxid
|
||||||
|
|
||||||
self.log.debug(f"Creating Matrix room")
|
self.log.debug(f"Creating Matrix room")
|
||||||
|
|
||||||
if self.is_direct:
|
if self.is_direct:
|
||||||
# NOTE Must do this to find the other member of the DM, since the channel ID != the member's ID!
|
# NOTE Must do this to find the other member of the DM, since the channel ID != the member's ID!
|
||||||
|
if not info or not info.participantInfo:
|
||||||
|
info = await source.client.get_portal_channel_info(self.channel_props)
|
||||||
|
assert info.participantInfo
|
||||||
await self._update_participants(source, info.participantInfo)
|
await self._update_participants(source, info.participantInfo)
|
||||||
|
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
initial_state = [
|
initial_state = [
|
||||||
|
@ -843,6 +861,9 @@ class Portal(DBPortal, BasePortal):
|
||||||
if self.is_open:
|
if self.is_open:
|
||||||
preset = RoomCreatePreset.PUBLIC
|
preset = RoomCreatePreset.PUBLIC
|
||||||
# TODO Find whether perms apply to any non-direct channel, or just open ones
|
# TODO Find whether perms apply to any non-direct channel, or just open ones
|
||||||
|
if not info or not info.participantInfo:
|
||||||
|
info = await source.client.get_portal_channel_info(self.channel_props)
|
||||||
|
assert info.participantInfo
|
||||||
user_power_levels = await self._get_mapped_participant_power_levels(info.participantInfo.participants)
|
user_power_levels = await self._get_mapped_participant_power_levels(info.participantInfo.participants)
|
||||||
# NOTE Giving the bot a +1 power level if necessary so it can demote non-puppet admins
|
# NOTE Giving the bot a +1 power level if necessary so it can demote non-puppet admins
|
||||||
user_power_levels[self.main_intent.mxid] = max(100, 1 + FROM_PERM_MAP[OpenChannelUserPerm.OWNER])
|
user_power_levels[self.main_intent.mxid] = max(100, 1 + FROM_PERM_MAP[OpenChannelUserPerm.OWNER])
|
||||||
|
@ -924,6 +945,8 @@ class Portal(DBPortal, BasePortal):
|
||||||
await self._update_participants(source, info.participantInfo)
|
await self._update_participants(source, info.participantInfo)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# TODO Think of better typing for this
|
||||||
|
assert info.channel_info
|
||||||
await self.backfill(source, is_initial=True, channel_info=info.channel_info)
|
await self.backfill(source, is_initial=True, channel_info=info.channel_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to backfill new portal")
|
self.log.exception("Failed to backfill new portal")
|
||||||
|
@ -1075,10 +1098,11 @@ class Portal(DBPortal, BasePortal):
|
||||||
mimetype = message.info.mimetype or magic.mimetype(data)
|
mimetype = message.info.mimetype or magic.mimetype(data)
|
||||||
filename = message.body
|
filename = message.body
|
||||||
width, height = None, None
|
width, height = None, None
|
||||||
if message.info in (MessageType.IMAGE, MessageType.STICKER, MessageType.VIDEO):
|
if message.msgtype in (MessageType.IMAGE, MessageType.STICKER, MessageType.VIDEO):
|
||||||
width = message.info.width
|
width = message.info.width
|
||||||
height = message.info.height
|
height = message.info.height
|
||||||
try:
|
try:
|
||||||
|
ext = guess_extension(mimetype)
|
||||||
log_id = await sender.client.send_media(
|
log_id = await sender.client.send_media(
|
||||||
self.channel_props,
|
self.channel_props,
|
||||||
TO_MSGTYPE_MAP[message.msgtype],
|
TO_MSGTYPE_MAP[message.msgtype],
|
||||||
|
@ -1086,7 +1110,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
filename,
|
filename,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
ext=guess_extension(mimetype)[1:],
|
ext=ext[1:] if ext else "",
|
||||||
)
|
)
|
||||||
except CommandException as e:
|
except CommandException as e:
|
||||||
self.log.debug(f"Error uploading media for Matrix message {event_id}: {e!s}")
|
self.log.debug(f"Error uploading media for Matrix message {event_id}: {e!s}")
|
||||||
|
@ -1191,7 +1215,8 @@ class Portal(DBPortal, BasePortal):
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
effective_sender, _ = await self.get_relay_sender(sender, f"{handler.action_name} {evt.event_id}")
|
effective_sender, _ = await self.get_relay_sender(sender, f"{handler.action_name} {evt.event_id}")
|
||||||
await handler.apply(self, effective_sender, evt.prev_content, evt.content)
|
if effective_sender:
|
||||||
|
await handler.apply(self, effective_sender, evt.prev_content, evt.content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error(
|
self.log.error(
|
||||||
f"Failed to handle Matrix {handler.action_name} {evt.event_id}: {e}",
|
f"Failed to handle Matrix {handler.action_name} {evt.event_id}: {e}",
|
||||||
|
@ -1314,7 +1339,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
await self.save()
|
await self.save()
|
||||||
|
|
||||||
async def _revert_matrix_room_topic(self, prev_content: RoomTopicStateEventContent) -> None:
|
async def _revert_matrix_room_topic(self, prev_content: RoomTopicStateEventContent) -> None:
|
||||||
await self.main_intent.set_room_topic(self.mxid, prev_content.topic)
|
await self.main_intent.set_room_topic(self.mxid, prev_content.topic or "")
|
||||||
|
|
||||||
async def _handle_matrix_room_avatar(
|
async def _handle_matrix_room_avatar(
|
||||||
self,
|
self,
|
||||||
|
@ -1479,7 +1504,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
chat_text: str | None,
|
chat_text: str | None,
|
||||||
chat_type: ChatType,
|
chat_type: ChatType,
|
||||||
**_
|
**_
|
||||||
) -> Awaitable[list[EventID]]:
|
) -> list[EventID]:
|
||||||
try:
|
try:
|
||||||
type_str = KnownChatType(chat_type).name.lower()
|
type_str = KnownChatType(chat_type).name.lower()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -1546,14 +1571,14 @@ class Portal(DBPortal, BasePortal):
|
||||||
await self._add_kakaotalk_reply(content, attachment)
|
await self._add_kakaotalk_reply(content, attachment)
|
||||||
return [await self._send_message(intent, content, timestamp=timestamp)]
|
return [await self._send_message(intent, content, timestamp=timestamp)]
|
||||||
|
|
||||||
def _handle_kakaotalk_photo(self, **kwargs) -> Awaitable[list[EventID]]:
|
async def _handle_kakaotalk_photo(self, **kwargs) -> list[EventID]:
|
||||||
return asyncio.gather(self._handle_kakaotalk_uniphoto(**kwargs))
|
return [await self._handle_kakaotalk_uniphoto(**kwargs)]
|
||||||
|
|
||||||
async def _handle_kakaotalk_multiphoto(
|
async def _handle_kakaotalk_multiphoto(
|
||||||
self,
|
self,
|
||||||
attachment: MultiPhotoAttachment,
|
attachment: MultiPhotoAttachment,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Awaitable[list[EventID]]:
|
) -> list[EventID]:
|
||||||
# TODO Upload media concurrently, but post messages sequentially
|
# TODO Upload media concurrently, but post messages sequentially
|
||||||
return [
|
return [
|
||||||
await self._handle_kakaotalk_uniphoto(
|
await self._handle_kakaotalk_uniphoto(
|
||||||
|
@ -1594,12 +1619,12 @@ class Portal(DBPortal, BasePortal):
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_kakaotalk_video(
|
async def _handle_kakaotalk_video(
|
||||||
self,
|
self,
|
||||||
attachment: VideoAttachment,
|
attachment: VideoAttachment,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Awaitable[list[EventID]]:
|
) -> list[EventID]:
|
||||||
return asyncio.gather(self._handle_kakaotalk_media(
|
return [await self._handle_kakaotalk_media(
|
||||||
attachment,
|
attachment,
|
||||||
VideoInfo(
|
VideoInfo(
|
||||||
duration=attachment.d,
|
duration=attachment.d,
|
||||||
|
@ -1608,14 +1633,14 @@ class Portal(DBPortal, BasePortal):
|
||||||
),
|
),
|
||||||
MessageType.VIDEO,
|
MessageType.VIDEO,
|
||||||
**kwargs
|
**kwargs
|
||||||
))
|
)]
|
||||||
|
|
||||||
def _handle_kakaotalk_audio(
|
async def _handle_kakaotalk_audio(
|
||||||
self,
|
self,
|
||||||
attachment: AudioAttachment,
|
attachment: AudioAttachment,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Awaitable[list[EventID]]:
|
) -> list[EventID]:
|
||||||
return asyncio.gather(self._handle_kakaotalk_media(
|
return [await self._handle_kakaotalk_media(
|
||||||
attachment,
|
attachment,
|
||||||
AudioInfo(
|
AudioInfo(
|
||||||
size=attachment.s,
|
size=attachment.s,
|
||||||
|
@ -1623,11 +1648,11 @@ class Portal(DBPortal, BasePortal):
|
||||||
),
|
),
|
||||||
MessageType.AUDIO,
|
MessageType.AUDIO,
|
||||||
**kwargs
|
**kwargs
|
||||||
))
|
)]
|
||||||
|
|
||||||
async def _handle_kakaotalk_media(
|
async def _handle_kakaotalk_media(
|
||||||
self,
|
self,
|
||||||
attachment: MediaAttachment,
|
attachment: MediaAttachment | AudioAttachment,
|
||||||
info: MediaInfo,
|
info: MediaInfo,
|
||||||
msgtype: MessageType,
|
msgtype: MessageType,
|
||||||
*,
|
*,
|
||||||
|
@ -1648,7 +1673,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
info.size = additional_info.size
|
info.size = additional_info.size
|
||||||
info.mimetype = additional_info.mimetype
|
info.mimetype = additional_info.mimetype
|
||||||
content = MediaMessageEventContent(
|
content = MediaMessageEventContent(
|
||||||
url=mxc, file=decryption_info, msgtype=msgtype, body=chat_text, info=info
|
url=mxc, file=decryption_info, msgtype=msgtype, body=chat_text or "", info=info
|
||||||
)
|
)
|
||||||
return await self._send_message(intent, content, timestamp=timestamp)
|
return await self._send_message(intent, content, timestamp=timestamp)
|
||||||
|
|
||||||
|
@ -1787,7 +1812,7 @@ class Portal(DBPortal, BasePortal):
|
||||||
# TODO Should this be removed? With it, can't sync an empty portal!
|
# TODO Should this be removed? With it, can't sync an empty portal!
|
||||||
#elif (not most_recent or not most_recent.timestamp) and not is_initial:
|
#elif (not most_recent or not most_recent.timestamp) and not is_initial:
|
||||||
# self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log)
|
# self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log)
|
||||||
elif last_log_id and most_recent and int(most_recent.ktid) >= int(last_log_id):
|
elif last_log_id and most_recent and int(most_recent.ktid or 0) >= int(last_log_id):
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
"Not backfilling %s: last activity is equal to most recent bridged "
|
"Not backfilling %s: last activity is equal to most recent bridged "
|
||||||
"message (%s >= %s)",
|
"message (%s >= %s)",
|
||||||
|
|
|
@ -89,7 +89,7 @@ class User(DBUser, BaseUser):
|
||||||
by_mxid: dict[UserID, User] = {}
|
by_mxid: dict[UserID, User] = {}
|
||||||
by_ktid: dict[int, User] = {}
|
by_ktid: dict[int, User] = {}
|
||||||
|
|
||||||
client: Client | None
|
_client: Client | None
|
||||||
|
|
||||||
_notice_room_lock: asyncio.Lock
|
_notice_room_lock: asyncio.Lock
|
||||||
_notice_send_lock: asyncio.Lock
|
_notice_send_lock: asyncio.Lock
|
||||||
|
@ -141,7 +141,7 @@ class User(DBUser, BaseUser):
|
||||||
self._logged_in_info = None
|
self._logged_in_info = None
|
||||||
self._logged_in_info_time = 0
|
self._logged_in_info_time = 0
|
||||||
|
|
||||||
self.client = None
|
self._client = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[bool]]:
|
def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[bool]]:
|
||||||
|
@ -151,6 +151,12 @@ class User(DBUser, BaseUser):
|
||||||
cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
|
cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
|
||||||
return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())
|
return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> Client:
|
||||||
|
if not self._client:
|
||||||
|
raise ValueError("User must be logged in before its client can be used")
|
||||||
|
return self._client
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool | None:
|
def is_connected(self) -> bool | None:
|
||||||
return self._is_connected
|
return self._is_connected
|
||||||
|
@ -242,7 +248,10 @@ class User(DBUser, BaseUser):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def oauth_credential(self) -> OAuthCredential:
|
def oauth_credential(self) -> OAuthCredential:
|
||||||
assert None not in (self.ktid, self.uuid, self.access_token, self.refresh_token)
|
assert self.ktid is not None
|
||||||
|
assert self.uuid is not None
|
||||||
|
assert self.access_token is not None
|
||||||
|
assert self.refresh_token is not None
|
||||||
return OAuthCredential(
|
return OAuthCredential(
|
||||||
self.ktid,
|
self.ktid,
|
||||||
self.uuid,
|
self.uuid,
|
||||||
|
@ -259,9 +268,9 @@ class User(DBUser, BaseUser):
|
||||||
self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
|
self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
|
||||||
self.uuid = oauth_credential.deviceUUID
|
self.uuid = oauth_credential.deviceUUID
|
||||||
|
|
||||||
async def get_own_info(self) -> SettingsStruct:
|
async def get_own_info(self) -> SettingsStruct | None:
|
||||||
if not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic():
|
if self._client and (not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic()):
|
||||||
self._logged_in_info = await self.client.get_settings()
|
self._logged_in_info = await self._client.get_settings()
|
||||||
self._logged_in_info_time = time.monotonic()
|
self._logged_in_info_time = time.monotonic()
|
||||||
return self._logged_in_info
|
return self._logged_in_info
|
||||||
|
|
||||||
|
@ -280,7 +289,7 @@ class User(DBUser, BaseUser):
|
||||||
user_info = await client.start()
|
user_info = await client.start()
|
||||||
# NOTE On failure, client.start throws instead of returning something falsy
|
# NOTE On failure, client.start throws instead of returning something falsy
|
||||||
self.log.info("Loaded session successfully")
|
self.log.info("Loaded session successfully")
|
||||||
self.client = client
|
self._client = client
|
||||||
self._logged_in_info = user_info
|
self._logged_in_info = user_info
|
||||||
self._logged_in_info_time = time.monotonic()
|
self._logged_in_info_time = time.monotonic()
|
||||||
self._track_metric(METRIC_LOGGED_IN, True)
|
self._track_metric(METRIC_LOGGED_IN, True)
|
||||||
|
@ -304,7 +313,7 @@ class User(DBUser, BaseUser):
|
||||||
await self.logout(remove_ktid=False)
|
await self.logout(remove_ktid=False)
|
||||||
|
|
||||||
async def is_logged_in(self, _override: bool = False) -> bool:
|
async def is_logged_in(self, _override: bool = False) -> bool:
|
||||||
if not self.has_state or not self.client:
|
if not self.has_state or not self._client:
|
||||||
return False
|
return False
|
||||||
if self._is_logged_in is None or _override:
|
if self._is_logged_in is None or _override:
|
||||||
try:
|
try:
|
||||||
|
@ -360,9 +369,9 @@ class User(DBUser, BaseUser):
|
||||||
self._is_rpc_reconnecting = False
|
self._is_rpc_reconnecting = False
|
||||||
|
|
||||||
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
|
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
|
||||||
if self.client:
|
if self._client:
|
||||||
# TODO Look for a logout API call
|
# TODO Look for a logout API call
|
||||||
await self.client.stop()
|
await self._client.stop()
|
||||||
if remove_ktid:
|
if remove_ktid:
|
||||||
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
|
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
|
||||||
self._track_metric(METRIC_LOGGED_IN, False)
|
self._track_metric(METRIC_LOGGED_IN, False)
|
||||||
|
@ -374,7 +383,7 @@ class User(DBUser, BaseUser):
|
||||||
|
|
||||||
self._is_logged_in = False
|
self._is_logged_in = False
|
||||||
self.is_connected = None
|
self.is_connected = None
|
||||||
self.client = None
|
self._client = None
|
||||||
|
|
||||||
if self.ktid and remove_ktid:
|
if self.ktid and remove_ktid:
|
||||||
#await UserPortal.delete_all(self.ktid)
|
#await UserPortal.delete_all(self.ktid)
|
||||||
|
@ -581,12 +590,12 @@ class User(DBUser, BaseUser):
|
||||||
state.remote_name = puppet.name
|
state.remote_name = puppet.name
|
||||||
|
|
||||||
async def get_bridge_states(self) -> list[BridgeState]:
|
async def get_bridge_states(self) -> list[BridgeState]:
|
||||||
if not self.state:
|
if not self.has_state:
|
||||||
return []
|
return []
|
||||||
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
|
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
state.state_event = BridgeStateEvent.CONNECTED
|
state.state_event = BridgeStateEvent.CONNECTED
|
||||||
elif self._is_rpc_reconnecting or self.client:
|
elif self._is_rpc_reconnecting or self._client:
|
||||||
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
|
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
|
||||||
return [state]
|
return [state]
|
||||||
|
|
||||||
|
@ -660,8 +669,8 @@ class User(DBUser, BaseUser):
|
||||||
await self.logout()
|
await self.logout()
|
||||||
await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}")
|
await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}")
|
||||||
|
|
||||||
def on_error(self, error: JSON) -> Awaitable[None]:
|
async def on_error(self, error: JSON) -> None:
|
||||||
return self.send_bridge_notice(
|
await self.send_bridge_notice(
|
||||||
f"Got error event from KakaoTalk:\n\n> {error}",
|
f"Got error event from KakaoTalk:\n\n> {error}",
|
||||||
# TODO Which error code to use?
|
# TODO Which error code to use?
|
||||||
#error_code="kt-connection-error",
|
#error_code="kt-connection-error",
|
||||||
|
@ -671,7 +680,7 @@ class User(DBUser, BaseUser):
|
||||||
async def on_client_disconnect(self) -> None:
|
async def on_client_disconnect(self) -> None:
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self._track_metric(METRIC_CONNECTED, False)
|
self._track_metric(METRIC_CONNECTED, False)
|
||||||
self.client = None
|
self._client = None
|
||||||
if self._is_logged_in:
|
if self._is_logged_in:
|
||||||
if self.temp_disconnect_notices:
|
if self.temp_disconnect_notices:
|
||||||
await self.send_bridge_notice(
|
await self.send_bridge_notice(
|
||||||
|
@ -685,12 +694,12 @@ class User(DBUser, BaseUser):
|
||||||
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
|
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
|
||||||
self.oauth_credential = oauth_credential
|
self.oauth_credential = oauth_credential
|
||||||
await self.push_bridge_state(BridgeStateEvent.CONNECTING)
|
await self.push_bridge_state(BridgeStateEvent.CONNECTING)
|
||||||
self.client = Client(self, log=self.log.getChild("ktclient"))
|
self._client = Client(self, log=self.log.getChild("ktclient"))
|
||||||
await self.save()
|
await self.save()
|
||||||
self._is_logged_in = True
|
self._is_logged_in = True
|
||||||
# TODO Retry network connection failures here, or in the client (like token refreshes are)?
|
# TODO Retry network connection failures here, or in the client (like token refreshes are)?
|
||||||
# Should also catch unlikely authentication errors
|
# Should also catch unlikely authentication errors
|
||||||
self._logged_in_info = await self.client.start()
|
self._logged_in_info = await self._client.start()
|
||||||
self._logged_in_info_time = time.monotonic()
|
self._logged_in_info_time = time.monotonic()
|
||||||
asyncio.create_task(self.post_login(is_startup=True))
|
asyncio.create_task(self.post_login(is_startup=True))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue