Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
99ea716731 | |||
adb7453e1a |
matrix_appservice_kakaotalk
__main__.py
commands
config.pydb
formatter
kt
matrix.pyportal.pypuppet.pyrpc
user.py@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from mautrix.bridge import Bridge
|
||||
from mautrix.types import RoomID, UserID
|
||||
@ -46,7 +46,7 @@ class KakaoTalkBridge(Bridge):
|
||||
|
||||
config: Config
|
||||
matrix: MatrixHandler
|
||||
#public_website: PublicBridgeWebsite | None
|
||||
#public_website: Optional[PublicBridgeWebsite] None
|
||||
|
||||
def prepare_config(self)->None:
|
||||
super().prepare_config()
|
||||
@ -117,7 +117,7 @@ class KakaoTalkBridge(Bridge):
|
||||
async def count_logged_in_users(self) -> int:
|
||||
return len([user for user in User.by_ktid.values() if user.ktid])
|
||||
|
||||
async def manhole_global_namespace(self, user_id: UserID) -> dict[str, Any]:
|
||||
async def manhole_global_namespace(self, user_id: UserID) -> Dict[str, Any]:
|
||||
return {
|
||||
**await super().manhole_global_namespace(user_id),
|
||||
"User": User,
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Awaitable
|
||||
from typing import TYPE_CHECKING, Awaitable, Optional
|
||||
import asyncio
|
||||
|
||||
from mautrix.bridge.commands import HelpSection, command_handler
|
||||
@ -209,7 +209,7 @@ class MentionFormatString(EntityString[SimpleEntity, EntityType], MarkdownString
|
||||
class MentionParser(MatrixParser[MentionFormatString]):
|
||||
fs = MentionFormatString
|
||||
|
||||
async def _get_id_from_mxid(mxid: UserID) -> Long | None:
|
||||
async def _get_id_from_mxid(mxid: UserID) -> Optional[Long]:
|
||||
user = await u.User.get_by_mxid(mxid, create=False)
|
||||
if user and user.ktid:
|
||||
return user.ktid
|
||||
@ -294,7 +294,7 @@ async def _edit_friend_by_uuid(evt: CommandEvent, uuid: str, add: bool) -> None:
|
||||
else:
|
||||
await _on_friend_edited(evt, friend_struct, add)
|
||||
|
||||
async def _on_friend_edited(evt: CommandEvent, friend_struct: FriendStruct | None, add: bool):
|
||||
async def _on_friend_edited(evt: CommandEvent, friend_struct: Optional[FriendStruct], add: bool):
|
||||
await evt.reply(f"Friend {'added' if add else 'removed'}")
|
||||
if friend_struct:
|
||||
puppet = await pu.Puppet.get_by_ktid(friend_struct.userId)
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, List, Tuple
|
||||
import os
|
||||
|
||||
from mautrix.bridge.config import BaseBridgeConfig
|
||||
@ -31,7 +31,7 @@ class Config(BaseBridgeConfig):
|
||||
return super().__getitem__(key)
|
||||
|
||||
@property
|
||||
def forbidden_defaults(self) -> list[ForbiddenDefault]:
|
||||
def forbidden_defaults(self) -> List[ForbiddenDefault]:
|
||||
return [
|
||||
*super().forbidden_defaults,
|
||||
ForbiddenDefault("appservice.database", "postgres://username:password@hostname/db"),
|
||||
@ -118,14 +118,14 @@ class Config(BaseBridgeConfig):
|
||||
copy("rpc.connection.host")
|
||||
copy("rpc.connection.port")
|
||||
|
||||
def _get_permissions(self, key: str) -> tuple[bool, bool, bool, str]:
|
||||
def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, str]:
|
||||
level = self["bridge.permissions"].get(key, "")
|
||||
admin = level == "admin"
|
||||
user = level == "user" or admin
|
||||
relay = level == "relay" or user
|
||||
return relay, user, admin, level
|
||||
|
||||
def get_permissions(self, mxid: UserID) -> tuple[bool, bool, bool, str]:
|
||||
def get_permissions(self, mxid: UserID) -> Tuple[bool, bool, bool, str]:
|
||||
permissions = self["bridge.permissions"] or {}
|
||||
if mxid in permissions:
|
||||
return self._get_permissions(mxid)
|
||||
|
@ -15,9 +15,8 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
from typing import TYPE_CHECKING, ClassVar, Dict, Optional
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import UserID
|
||||
@ -35,7 +34,7 @@ class LoginCredential:
|
||||
password: str
|
||||
|
||||
@classmethod
|
||||
async def get_by_mxid(cls, mxid: UserID) -> LoginCredential | None:
|
||||
async def get_by_mxid(cls, mxid: UserID) -> Optional[LoginCredential]:
|
||||
q = "SELECT mxid, email, password FROM login_credential WHERE mxid=$1"
|
||||
row = await cls.db.fetchrow(q, mxid)
|
||||
return cls(**row) if row else None
|
||||
@ -51,7 +50,7 @@ class LoginCredential:
|
||||
"""
|
||||
await self.db.execute(q, self.mxid, self.email, self.password)
|
||||
|
||||
def get_form(self) -> dict[str, str]:
|
||||
def get_form(self) -> Dict[str, str]:
|
||||
return {
|
||||
"email": self.email,
|
||||
"password": self.password,
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass, field
|
||||
@ -34,7 +34,7 @@ class Message:
|
||||
|
||||
mxid: EventID
|
||||
mx_room: RoomID
|
||||
ktid: Long | None = field(converter=to_optional_long)
|
||||
ktid: Optional[Long] = field(converter=to_optional_long)
|
||||
index: int
|
||||
kt_channel: Long = field(converter=Long)
|
||||
kt_receiver: Long = field(converter=Long)
|
||||
@ -45,19 +45,19 @@ class Message:
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
def _from_optional_row(cls, row: Record | None) -> Message | None:
|
||||
def _from_optional_row(cls, row: Optional[Record]) -> Optional[Message]:
|
||||
return cls._from_row(row) if row is not None else None
|
||||
|
||||
columns = 'mxid, mx_room, ktid, "index", kt_channel, kt_receiver, timestamp'
|
||||
|
||||
@classmethod
|
||||
async def get_all_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int) -> list[Message]:
|
||||
async def get_all_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int) -> List[Message]:
|
||||
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3"
|
||||
rows = await cls.db.fetch(q, ktid, kt_channel, kt_receiver)
|
||||
return [cls._from_row(row) for row in rows if row]
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int, index: int = 0) -> Message | None:
|
||||
async def get_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int, index: int = 0) -> Optional[Message]:
|
||||
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3 AND "index"=$4'
|
||||
row = await cls.db.fetchrow(q, ktid, kt_channel, kt_receiver, index)
|
||||
return cls._from_optional_row(row)
|
||||
@ -67,13 +67,13 @@ class Message:
|
||||
await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
|
||||
|
||||
@classmethod
|
||||
async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
|
||||
async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional[Message]:
|
||||
q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2"
|
||||
row = await cls.db.fetchrow(q, mxid, mx_room)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_most_recent(cls, kt_channel: int, kt_receiver: int) -> Message | None:
|
||||
async def get_most_recent(cls, kt_channel: int, kt_receiver: int) -> Optional[Message]:
|
||||
q = (
|
||||
f"SELECT {cls.columns} "
|
||||
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
|
||||
@ -85,7 +85,7 @@ class Message:
|
||||
@classmethod
|
||||
async def get_closest_before(
|
||||
cls, kt_channel: int, kt_receiver: int, ktid: int
|
||||
) -> Message | None:
|
||||
) -> Optional[Message]:
|
||||
q = (
|
||||
f"SELECT {cls.columns} "
|
||||
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid<=$3 "
|
||||
@ -95,7 +95,7 @@ class Message:
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_all_since(cls, kt_channel: int, kt_receiver: int, since_ktid: Long | None) -> list[Message]:
|
||||
async def get_all_since(cls, kt_channel: int, kt_receiver: int, since_ktid: Optional[Long]) -> List[Message]:
|
||||
q = (
|
||||
f"SELECT {cls.columns} "
|
||||
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid>=$3 "
|
||||
@ -116,10 +116,10 @@ class Message:
|
||||
ktid: Long,
|
||||
kt_channel: Long,
|
||||
kt_receiver: Long,
|
||||
event_ids: list[EventID],
|
||||
event_ids: List[EventID],
|
||||
timestamp: int,
|
||||
mx_room: RoomID,
|
||||
) -> list[Message]:
|
||||
) -> List[Message]:
|
||||
if not event_ids:
|
||||
return []
|
||||
columns = [col.strip('"') for col in cls.columns.split(", ")]
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass, field
|
||||
@ -36,24 +36,24 @@ class Portal:
|
||||
ktid: Long = field(converter=Long)
|
||||
kt_receiver: Long = field(converter=Long)
|
||||
kt_type: ChannelType
|
||||
mxid: RoomID | None
|
||||
name: str | None
|
||||
description: str | None
|
||||
photo_id: str | None
|
||||
avatar_url: ContentURI | None
|
||||
mxid: Optional[RoomID]
|
||||
name: Optional[str]
|
||||
description: Optional[str]
|
||||
photo_id: Optional[str]
|
||||
avatar_url: Optional[ContentURI]
|
||||
encrypted: bool
|
||||
name_set: bool
|
||||
topic_set: bool
|
||||
avatar_set: bool
|
||||
fully_read_kt_chat: Long | None = field(converter=to_optional_long)
|
||||
relay_user_id: UserID | None
|
||||
fully_read_kt_chat: Optional[Long] = field(converter=to_optional_long)
|
||||
relay_user_id: Optional[UserID]
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record) -> Portal:
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
def _from_optional_row(cls, row: Record | None) -> Portal | None:
|
||||
def _from_optional_row(cls, row: Optional[Record]) -> Optional[Portal]:
|
||||
return cls._from_row(row) if row is not None else None
|
||||
|
||||
_columns = (
|
||||
@ -62,25 +62,25 @@ class Portal:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None:
|
||||
async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Optional[Portal]:
|
||||
q = f"SELECT {cls._columns} FROM portal WHERE ktid=$1 AND kt_receiver=$2"
|
||||
row = await cls.db.fetchrow(q, ktid, kt_receiver)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
|
||||
async def get_by_mxid(cls, mxid: RoomID) -> Optional[Portal]:
|
||||
q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1"
|
||||
row = await cls.db.fetchrow(q, mxid)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]:
|
||||
async def get_all_by_receiver(cls, kt_receiver: int) -> List[Portal]:
|
||||
q = f"SELECT {cls._columns} FROM portal WHERE kt_receiver=$1"
|
||||
rows = await cls.db.fetch(q, kt_receiver)
|
||||
return [cls._from_row(row) for row in rows if row]
|
||||
|
||||
@classmethod
|
||||
async def all(cls) -> list[Portal]:
|
||||
async def all(cls) -> List[Portal]:
|
||||
q = f"SELECT {cls._columns} FROM portal"
|
||||
rows = await cls.db.fetch(q)
|
||||
return [cls._from_row(row) for row in rows if row]
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass, field
|
||||
@ -34,17 +34,17 @@ class Puppet:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
ktid: Long = field(converter=Long)
|
||||
name: str | None
|
||||
photo_id: str | None
|
||||
photo_mxc: ContentURI | None
|
||||
name: Optional[str]
|
||||
photo_id: Optional[str]
|
||||
photo_mxc: Optional[ContentURI]
|
||||
name_set: bool
|
||||
avatar_set: bool
|
||||
is_registered: bool
|
||||
|
||||
custom_mxid: UserID | None
|
||||
access_token: str | None
|
||||
next_batch: SyncToken | None
|
||||
base_url: URL | None
|
||||
custom_mxid: Optional[UserID]
|
||||
access_token: Optional[str]
|
||||
next_batch: Optional[SyncToken]
|
||||
base_url: Optional[URL]
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record) -> Puppet:
|
||||
@ -53,11 +53,11 @@ class Puppet:
|
||||
return cls(**data, base_url=URL(base_url) if base_url else None)
|
||||
|
||||
@classmethod
|
||||
def _from_optional_row(cls, row: Record | None) -> Puppet | None:
|
||||
def _from_optional_row(cls, row: Optional[Record]) -> Optional[Puppet]:
|
||||
return cls._from_row(row) if row is not None else None
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: int) -> Puppet | None:
|
||||
async def get_by_ktid(cls, ktid: int) -> Optional[Puppet]:
|
||||
q = (
|
||||
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
|
||||
" custom_mxid, access_token, next_batch, base_url "
|
||||
@ -67,7 +67,7 @@ class Puppet:
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_by_name(cls, name: str) -> Puppet | None:
|
||||
async def get_by_name(cls, name: str) -> Optional[Puppet]:
|
||||
q = (
|
||||
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
|
||||
" custom_mxid, access_token, next_batch, base_url "
|
||||
@ -77,7 +77,7 @@ class Puppet:
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
|
||||
async def get_by_custom_mxid(cls, mxid: UserID) -> Optional[Puppet]:
|
||||
q = (
|
||||
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
|
||||
" custom_mxid, access_token, next_batch, base_url "
|
||||
@ -87,7 +87,7 @@ class Puppet:
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_all_with_custom_mxid(cls) -> list[Puppet]:
|
||||
async def get_all_with_custom_mxid(cls) -> List[Puppet]:
|
||||
q = (
|
||||
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
|
||||
" custom_mxid, access_token, next_batch, base_url "
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Set
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Set
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass, field
|
||||
@ -35,18 +35,18 @@ class User:
|
||||
mxid: UserID
|
||||
force_login: bool
|
||||
was_connected: bool
|
||||
ktid: Long | None = field(converter=to_optional_long)
|
||||
uuid: str | None
|
||||
access_token: str | None
|
||||
refresh_token: str | None
|
||||
notice_room: RoomID | None
|
||||
ktid: Optional[Long] = field(converter=to_optional_long)
|
||||
uuid: Optional[str]
|
||||
access_token: Optional[str]
|
||||
refresh_token: Optional[str]
|
||||
notice_room: Optional[RoomID]
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record) -> User:
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
def _from_optional_row(cls, row: Record | None) -> User | None:
|
||||
def _from_optional_row(cls, row: Optional[Record]) -> Optional[User]:
|
||||
return cls._from_row(row) if row is not None else None
|
||||
|
||||
_columns = "mxid, force_login, was_connected, ktid, uuid, access_token, refresh_token, notice_room"
|
||||
@ -58,13 +58,13 @@ class User:
|
||||
return [cls._from_row(row) for row in rows if row]
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: int) -> User | None:
|
||||
async def get_by_ktid(cls, ktid: int) -> Optional[User]:
|
||||
q = f'SELECT {cls._columns} FROM "user" WHERE ktid=$1'
|
||||
row = await cls.db.fetchrow(q, ktid)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_by_mxid(cls, mxid: UserID) -> User | None:
|
||||
async def get_by_mxid(cls, mxid: UserID) -> Optional[User]:
|
||||
q = f'SELECT {cls._columns} FROM "user" WHERE mxid=$1'
|
||||
row = await cls.db.fetchrow(q, mxid)
|
||||
return cls._from_optional_row(row)
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Match
|
||||
from typing import List, Match, Optional
|
||||
import re
|
||||
|
||||
from mautrix.types import Format, MessageType, TextMessageEventContent
|
||||
@ -28,7 +28,7 @@ from .. import puppet as pu, user as u
|
||||
MENTION_REGEX = re.compile(r"@(\d+)\u2063(.+?)\u2063")
|
||||
|
||||
|
||||
async def kakaotalk_to_matrix(msg: str | None, mentions: list[MentionStruct] | None) -> TextMessageEventContent:
|
||||
async def kakaotalk_to_matrix(msg: Optional[str], mentions: Optional[List[MentionStruct]]) -> TextMessageEventContent:
|
||||
# TODO Shouts
|
||||
text = msg or ""
|
||||
content = TextMessageEventContent(msgtype=MessageType.TEXT, body=text)
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
from typing import Dict, List, NamedTuple, Optional
|
||||
|
||||
from mautrix.types import Format, MessageEventContent, RelationType, RoomID, UserID
|
||||
from mautrix.util import utf16_surrogate
|
||||
@ -40,7 +40,7 @@ from ..db import Message as DBMessage
|
||||
|
||||
class SendParams(NamedTuple):
|
||||
text: str
|
||||
mentions: list[MentionStruct] | None
|
||||
mentions: Optional[List[MentionStruct]]
|
||||
reply_to: ReplyAttachment
|
||||
|
||||
|
||||
@ -89,7 +89,7 @@ class ToKakaoTalkParser(MatrixParser[KakaoTalkFormatString]):
|
||||
fs = KakaoTalkFormatString
|
||||
|
||||
|
||||
async def _get_id_from_mxid(mxid: UserID, portal: po.Portal) -> Long | None:
|
||||
async def _get_id_from_mxid(mxid: UserID, portal: po.Portal) -> Optional[Long]:
|
||||
orig_sender = await u.User.get_by_mxid(mxid, create=False)
|
||||
if orig_sender and orig_sender.ktid:
|
||||
return orig_sender.ktid
|
||||
@ -161,7 +161,7 @@ async def matrix_to_kakaotalk(
|
||||
):
|
||||
parsed = await ToKakaoTalkParser().parse(utf16_surrogate.add(content["formatted_body"]))
|
||||
text = utf16_surrogate.remove(parsed.text)
|
||||
mentions_by_user: dict[Long, MentionStruct] = {}
|
||||
mentions_by_user: Dict[Long, MentionStruct] = {}
|
||||
# Make sure to not create remote mentions for any remote user not in the room
|
||||
if parsed.entities:
|
||||
joined_members = set(await portal.main_intent.get_room_members(room_id))
|
||||
@ -189,7 +189,7 @@ async def matrix_to_kakaotalk(
|
||||
return SendParams(text=text, mentions=mentions, reply_to=reply_to)
|
||||
|
||||
|
||||
_media_type_reply_body_map: dict[KnownChatType, str] = {
|
||||
_media_type_reply_body_map: Dict[KnownChatType, str] = {
|
||||
KnownChatType.PHOTO: "Photo",
|
||||
KnownChatType.VIDEO: "Video",
|
||||
KnownChatType.AUDIO: "Voice Note",
|
||||
|
@ -22,7 +22,7 @@ with any other potential backend.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, cast, Awaitable, Type, Optional, Union
|
||||
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Set, Type, Union, cast
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
@ -122,7 +122,7 @@ class Client:
|
||||
# region tokenless commands
|
||||
|
||||
@classmethod
|
||||
async def generate_uuid(cls, used_uuids: set[str]) -> str:
|
||||
async def generate_uuid(cls, used_uuids: Set[str]) -> str:
|
||||
"""Randomly generate a UUID for a (fake) device."""
|
||||
tries_remaining = 10
|
||||
while True:
|
||||
@ -160,10 +160,10 @@ class Client:
|
||||
|
||||
|
||||
user: u.User
|
||||
_rpc_disconnection_task: asyncio.Task | None
|
||||
_rpc_disconnection_task: Optional[asyncio.Task]
|
||||
http: ClientSession
|
||||
log: TraceLogger
|
||||
_handler_methods: list[str]
|
||||
_handler_methods: List[str]
|
||||
|
||||
def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
|
||||
"""Create a per-user client object for user-specific client functionality."""
|
||||
@ -195,7 +195,7 @@ class Client:
|
||||
def get(
|
||||
self,
|
||||
url: Union[str, URL],
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
sandbox: bool = False,
|
||||
**kwargs,
|
||||
) -> _RequestContextManager:
|
||||
@ -215,7 +215,7 @@ class Client:
|
||||
|
||||
# region post-token commands
|
||||
|
||||
async def start(self) -> SettingsStruct | None:
|
||||
async def start(self) -> Optional[SettingsStruct]:
|
||||
"""
|
||||
Initialize user-specific bridging & state by providing a token obtained from a prior login.
|
||||
Receive the user's profile info in response.
|
||||
@ -252,7 +252,7 @@ class Client:
|
||||
self.user.oauth_credential = oauth_info.credential
|
||||
await self.user.save()
|
||||
|
||||
async def connect(self) -> LoginResult | None:
|
||||
async def connect(self) -> Optional[LoginResult]:
|
||||
"""
|
||||
Start a new talk session by providing a token obtained from a prior login.
|
||||
Receive a snapshot of account state in response.
|
||||
@ -304,14 +304,14 @@ class Client:
|
||||
channel_props=channel_props.serialize(),
|
||||
)
|
||||
|
||||
def get_participants(self, channel_props: ChannelProps) -> Awaitable[list[UserInfoUnion]]:
|
||||
def get_participants(self, channel_props: ChannelProps) -> Awaitable[List[UserInfoUnion]]:
|
||||
return self._api_user_request_result(
|
||||
ResultListType(UserInfoUnion),
|
||||
"get_participants",
|
||||
channel_props=channel_props.serialize(),
|
||||
)
|
||||
|
||||
def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> Awaitable[list[Chatlog]]:
|
||||
def get_chats(self, channel_props: ChannelProps, sync_from: Optional[Long], limit: Optional[int]) -> Awaitable[List[Chatlog]]:
|
||||
return self._api_user_request_result(
|
||||
ResultListType(Chatlog),
|
||||
"get_chats",
|
||||
@ -320,7 +320,7 @@ class Client:
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: list[Long]) -> Awaitable[list[Receipt]]:
|
||||
def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: List[Long]) -> Awaitable[List[Receipt]]:
|
||||
return self._api_user_request_result(
|
||||
ResultListType(Receipt),
|
||||
"get_read_receipts",
|
||||
@ -348,7 +348,7 @@ class Client:
|
||||
"list_friends",
|
||||
)
|
||||
|
||||
async def edit_friend(self, ktid: Long, add: bool) -> FriendStruct | None:
|
||||
async def edit_friend(self, ktid: Long, add: bool) -> Optional[FriendStruct]:
|
||||
try:
|
||||
friend_req_struct = await self._api_user_request_result(
|
||||
FriendReqStruct,
|
||||
@ -361,7 +361,7 @@ class Client:
|
||||
self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless")
|
||||
return None
|
||||
|
||||
async def edit_friend_by_uuid(self, uuid: str, add: bool) -> FriendStruct | None:
|
||||
async def edit_friend_by_uuid(self, uuid: str, add: bool) -> Optional[FriendStruct]:
|
||||
try:
|
||||
friend_req_struct = await self._api_user_request_result(
|
||||
FriendReqStruct,
|
||||
@ -374,7 +374,7 @@ class Client:
|
||||
self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless")
|
||||
return None
|
||||
|
||||
async def get_friend_dm_id(self, friend_id: Long) -> Long | None:
|
||||
async def get_friend_dm_id(self, friend_id: Long) -> Optional[Long]:
|
||||
try:
|
||||
return await self._api_user_request_result(
|
||||
Long,
|
||||
@ -385,7 +385,7 @@ class Client:
|
||||
self.log.exception(f"Could not find friend with ID {friend_id}")
|
||||
return None
|
||||
|
||||
async def get_memo_ids(self) -> list[Long]:
|
||||
async def get_memo_ids(self) -> List[Long]:
|
||||
return ResultListType(Long).deserialize(
|
||||
await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid)
|
||||
)
|
||||
@ -406,8 +406,8 @@ class Client:
|
||||
self,
|
||||
channel_props: ChannelProps,
|
||||
text: str,
|
||||
reply_to: ReplyAttachment | None,
|
||||
mentions: list[MentionStruct] | None,
|
||||
reply_to: Optional[ReplyAttachment],
|
||||
mentions: Optional[List[MentionStruct]],
|
||||
) -> Awaitable[Long]:
|
||||
return self._api_user_request_result(
|
||||
Long,
|
||||
@ -425,9 +425,9 @@ class Client:
|
||||
data: bytes,
|
||||
filename: str,
|
||||
*,
|
||||
width: int | None = None,
|
||||
height: int | None = None,
|
||||
ext: str | None = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
ext: Optional[str] = None,
|
||||
) -> Awaitable[Long]:
|
||||
return self._api_user_request_result(
|
||||
Long,
|
||||
@ -584,14 +584,14 @@ class Client:
|
||||
|
||||
# region listeners
|
||||
|
||||
def _on_chat(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_chat(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_chat(
|
||||
Chatlog.deserialize(data["chatlog"]),
|
||||
Long.deserialize(data["channelId"]),
|
||||
str(data["channelType"]),
|
||||
)
|
||||
|
||||
def _on_chat_deleted(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_chat_deleted(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_chat_deleted(
|
||||
Long.deserialize(data["chatId"]),
|
||||
Long.deserialize(data["senderId"]),
|
||||
@ -600,7 +600,7 @@ class Client:
|
||||
str(data["channelType"]),
|
||||
)
|
||||
|
||||
def _on_chat_read(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_chat_read(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_chat_read(
|
||||
Long.deserialize(data["chatId"]),
|
||||
Long.deserialize(data["senderId"]),
|
||||
@ -608,12 +608,12 @@ class Client:
|
||||
str(data["channelType"]),
|
||||
)
|
||||
|
||||
def _on_profile_changed(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_profile_changed(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_profile_changed(
|
||||
OpenLinkChannelUserInfo.deserialize(data["info"]),
|
||||
)
|
||||
|
||||
def _on_perm_changed(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_perm_changed(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_perm_changed(
|
||||
Long.deserialize(data["userId"]),
|
||||
OpenChannelUserPerm(data["perm"]),
|
||||
@ -622,23 +622,23 @@ class Client:
|
||||
str(data["channelType"]),
|
||||
)
|
||||
|
||||
def _on_channel_added(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_channel_added(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_channel_added(
|
||||
ChannelInfo.deserialize(data["channelInfo"]),
|
||||
)
|
||||
|
||||
def _on_channel_join(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_channel_join(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_channel_join(
|
||||
ChannelInfo.deserialize(data["channelInfo"]),
|
||||
)
|
||||
|
||||
def _on_channel_left(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_channel_left(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_channel_left(
|
||||
Long.deserialize(data["channelId"]),
|
||||
str(data["channelType"]),
|
||||
)
|
||||
|
||||
def _on_channel_kicked(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_channel_kicked(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_channel_kicked(
|
||||
Long.deserialize(data["userId"]),
|
||||
Long.deserialize(data["senderId"]),
|
||||
@ -646,21 +646,21 @@ class Client:
|
||||
str(data["channelType"]),
|
||||
)
|
||||
|
||||
def _on_user_join(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_user_join(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_user_join(
|
||||
Long.deserialize(data["userId"]),
|
||||
Long.deserialize(data["channelId"]),
|
||||
str(data["channelType"]),
|
||||
)
|
||||
|
||||
def _on_user_left(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_user_left(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_user_left(
|
||||
Long.deserialize(data["userId"]),
|
||||
Long.deserialize(data["channelId"]),
|
||||
str(data["channelType"]),
|
||||
)
|
||||
|
||||
def _on_channel_meta_change(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_channel_meta_change(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_channel_meta_change(
|
||||
PortalChannelInfo.deserialize(data["info"]),
|
||||
Long.deserialize(data["channelId"]),
|
||||
@ -668,7 +668,7 @@ class Client:
|
||||
)
|
||||
|
||||
|
||||
def _on_listen_disconnect(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_listen_disconnect(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
try:
|
||||
res = KickoutRes.deserialize(data)
|
||||
except Exception:
|
||||
@ -676,18 +676,18 @@ class Client:
|
||||
res = None
|
||||
return self._on_disconnect(res)
|
||||
|
||||
def _on_switch_server(self, _: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_switch_server(self, _: Dict[str, JSON]) -> Awaitable[None]:
|
||||
# TODO Reconnect automatically instead
|
||||
return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))
|
||||
|
||||
def _on_disconnect(self, res: KickoutRes | None) -> Awaitable[None]:
|
||||
def _on_disconnect(self, res: Optional[KickoutRes]) -> Awaitable[None]:
|
||||
self._stop_listen()
|
||||
return self.user.on_disconnect(res)
|
||||
|
||||
def _on_error(self, data: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_error(self, data: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_error(data)
|
||||
|
||||
def _on_unexpected_disconnect(self, _: dict[str, JSON]) -> Awaitable[None]:
|
||||
def _on_unexpected_disconnect(self, _: Dict[str, JSON]) -> Awaitable[None]:
|
||||
return self.user.on_unexpected_disconnect()
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
"""Internal helpers for error handling."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NoReturn, Type
|
||||
from typing import Dict, NoReturn, Type, Union
|
||||
|
||||
from .errors import (
|
||||
CommandException,
|
||||
@ -35,7 +35,7 @@ def raise_unsuccessful_response(resp: RootCommandResult) -> NoReturn:
|
||||
raise _error_code_class_map.get(resp.status, CommandException)(resp)
|
||||
|
||||
|
||||
_error_code_class_map: dict[KnownAuthStatusCode | KnownDataStatusCode | int, Type[CommandException]] = {
|
||||
_error_code_class_map: Dict[Union[KnownAuthStatusCode, KnownDataStatusCode, int], Type[CommandException]] = {
|
||||
#KnownAuthStatusCode.INVALID_PHONE_NUMBER: "Invalid phone number",
|
||||
#KnownAuthStatusCode.SUCCESS_WITH_ACCOUNT: "Success",
|
||||
#KnownAuthStatusCode.SUCCESS_WITH_DEVICE_CHANGED: "Success (device changed)",
|
||||
|
@ -16,6 +16,8 @@
|
||||
"""Helper functions & types for status codes for the KakaoTalk API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
from ..types.api.auth_api_client import KnownAuthStatusCode
|
||||
from ..types.request import KnownDataStatusCode, RootCommandResult
|
||||
|
||||
@ -72,7 +74,7 @@ class ResponseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
_status_code_message_map: dict[KnownAuthStatusCode | KnownDataStatusCode | int, str] = {
|
||||
_status_code_message_map: Dict[Union[KnownAuthStatusCode, KnownDataStatusCode, int], str] = {
|
||||
KnownAuthStatusCode.INVALID_PHONE_NUMBER: "Invalid phone number",
|
||||
KnownAuthStatusCode.SUCCESS_WITH_ACCOUNT: "Success",
|
||||
KnownAuthStatusCode.SUCCESS_WITH_DEVICE_CHANGED: "Success (device changed)",
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""Custom wrapper classes around types defined by the KakaoTalk API."""
|
||||
|
||||
from typing import Optional, NewType, Union
|
||||
from typing import Dict, List, NewType, Optional, Union
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -73,8 +73,8 @@ class Receipt(SerializableAttrs):
|
||||
|
||||
@dataclass
|
||||
class PortalChannelParticipantInfo(SerializableAttrs):
|
||||
participants: list[UserInfoUnion]
|
||||
kickedUserIds: list[Long]
|
||||
participants: List[UserInfoUnion]
|
||||
kickedUserIds: List[Long]
|
||||
|
||||
@dataclass
|
||||
class PortalChannelInfo(SerializableAttrs):
|
||||
@ -92,7 +92,7 @@ class ChannelProps(SerializableAttrs):
|
||||
|
||||
|
||||
# TODO Add non-media types, like polls & maps
|
||||
TO_MSGTYPE_MAP: dict[MessageType, KnownChatType] = {
|
||||
TO_MSGTYPE_MAP: Dict[MessageType, KnownChatType] = {
|
||||
MessageType.TEXT: KnownChatType.TEXT,
|
||||
MessageType.IMAGE: KnownChatType.PHOTO,
|
||||
MessageType.STICKER: KnownChatType.PHOTO,
|
||||
@ -102,13 +102,13 @@ TO_MSGTYPE_MAP: dict[MessageType, KnownChatType] = {
|
||||
}
|
||||
|
||||
# https://stackoverflow.com/a/483833
|
||||
FROM_MSGTYPE_MAP: dict[KnownChatType, MessageType] = {v: k for k, v in TO_MSGTYPE_MAP.items()}
|
||||
FROM_MSGTYPE_MAP: Dict[KnownChatType, MessageType] = {v: k for k, v in TO_MSGTYPE_MAP.items()}
|
||||
|
||||
|
||||
# TODO Consider allowing custom power/perm mappings
|
||||
# But must update default user level & permissions to match!
|
||||
|
||||
FROM_PERM_MAP: dict[OpenChannelUserPerm, int] = {
|
||||
FROM_PERM_MAP: Dict[OpenChannelUserPerm, int] = {
|
||||
OpenChannelUserPerm.OWNER: 100,
|
||||
OpenChannelUserPerm.MANAGER: 50,
|
||||
# TODO Decide on an appropriate value for this
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -54,11 +54,11 @@ class MoreSettingsStruct(SerializableAttrs):
|
||||
|
||||
@dataclass
|
||||
class MoreApps(SerializableAttrs):
|
||||
recommend: Optional[list[str]] = None # NOTE From unknown[]
|
||||
all: Optional[list[str]] = None # NOTE From unknown[]
|
||||
recommend: Optional[List[str]] = None # NOTE From unknown[]
|
||||
all: Optional[List[str]] = None # NOTE From unknown[]
|
||||
moreApps: MoreApps
|
||||
|
||||
shortcuts: Optional[dict[str, int]] = None # NOTE Made optional
|
||||
shortcuts: Optional[Dict[str, int]] = None # NOTE Made optional
|
||||
seasonProfileRev: int
|
||||
seasonNoticeRev: int
|
||||
serviceUserId: Union[Long, int]
|
||||
@ -86,8 +86,8 @@ class MoreSettingsStruct(SerializableAttrs):
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class LessSettingsStruct(SerializableAttrs):
|
||||
kakaoAutoLoginDomain: list[str]
|
||||
daumSsoDomain: list[str]
|
||||
kakaoAutoLoginDomain: List[str]
|
||||
daumSsoDomain: List[str]
|
||||
|
||||
@dataclass
|
||||
class GoogleMapsApi(SerializableAttrs):
|
||||
@ -124,7 +124,7 @@ class LessSettingsStruct(SerializableAttrs):
|
||||
newPostTerm: int
|
||||
postExpirationSetting: PostExpirationSetting
|
||||
|
||||
kakaoAlertIds: list[int]
|
||||
kakaoAlertIds: List[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -13,6 +13,8 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import List
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import SerializableAttrs
|
||||
@ -23,7 +25,7 @@ from .friend_struct import FriendStruct
|
||||
@dataclass
|
||||
class FriendBlockedListStruct(SerializableAttrs):
|
||||
total: int
|
||||
blockedFriends: list[FriendStruct]
|
||||
blockedFriends: List[FriendStruct]
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -13,6 +13,8 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import List
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import SerializableAttrs
|
||||
@ -24,7 +26,7 @@ from .friend_struct import FriendStruct
|
||||
@dataclass
|
||||
class FriendListStruct(SerializableAttrs):
|
||||
token: Long
|
||||
friends: list[FriendStruct]
|
||||
friends: List[FriendStruct]
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -25,7 +25,7 @@ from .friend_struct import FriendStruct
|
||||
@dataclass
|
||||
class FriendSearchUserListStruct(SerializableAttrs):
|
||||
count: int
|
||||
list: list[FriendStruct]
|
||||
list: List[FriendStruct]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -33,7 +33,7 @@ class FriendSearchStruct(SerializableAttrs):
|
||||
query: str
|
||||
user: Optional[FriendSearchUserListStruct] = None
|
||||
plus: Optional[FriendSearchUserListStruct] = None
|
||||
categories: list[str]
|
||||
categories: List[str]
|
||||
total_counts: int
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -34,7 +34,7 @@ class ProfileFeed(SerializableAttrs):
|
||||
serviceName: str
|
||||
typeIconUrl: str
|
||||
downloadId: str
|
||||
contents: list[ProfileFeedObject]
|
||||
contents: List[ProfileFeedObject]
|
||||
url: str
|
||||
serviceUrl: Optional[str] = None # NOTE Made optional
|
||||
webUrl: str
|
||||
@ -51,7 +51,7 @@ class ProfileFeed(SerializableAttrs):
|
||||
@dataclass(kw_only=True)
|
||||
class ProfileFeedList(SerializableAttrs):
|
||||
totalCnt: int # NOTE Renamed from "totalCnts"
|
||||
feeds: list[ProfileFeed]
|
||||
feeds: List[ProfileFeed]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -91,7 +91,7 @@ class ProfileStruct(SerializableAttrs):
|
||||
profileImageUrl: str
|
||||
fullProfileImageUrl: str
|
||||
originalProfileImageUrl: str
|
||||
decoration: list[ProfileDecoration]
|
||||
decoration: List[ProfileDecoration]
|
||||
profileFeeds: ProfileFeedList
|
||||
backgroundFeeds: ProfileFeedList
|
||||
allowStory: bool
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Generic, TypeVar, Optional
|
||||
from typing import Dict, Generic, List, Optional, TypeVar
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -37,7 +37,7 @@ class SetChannelMeta(ChannelMeta):
|
||||
authorId: Long
|
||||
updatedAt: int
|
||||
|
||||
ChannelMetaMap = dict[ChannelMetaType, SetChannelMeta] # Substitute for Record<ChannelMetaType, SetChannelMeta>
|
||||
ChannelMetaMap = Dict[ChannelMetaType, SetChannelMeta] # Substitute for Record<ChannelMetaType, SetChannelMeta>
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -50,7 +50,7 @@ class ChannelInfo(Channel):
|
||||
lastSeenLogId: Long
|
||||
lastChatLog: Optional[Chatlog] = None
|
||||
metaMap: ChannelMetaMap
|
||||
displayUserList: list[DisplayUserInfo]
|
||||
displayUserList: List[DisplayUserInfo]
|
||||
pushAlert: bool
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from attr import dataclass
|
||||
from enum import Enum, IntEnum
|
||||
@ -149,10 +149,10 @@ BotDelCommandStruct = BotCommandStruct
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class BotMetaContent(SerializableAttrs):
|
||||
add: Optional[list[BotAddCommandStruct]] = None
|
||||
update: Optional[list[BotAddCommandStruct]] = None
|
||||
full: Optional[list[BotAddCommandStruct]] = None
|
||||
delete: Optional[list[BotDelCommandStruct]] = field(json="del", default=None)
|
||||
add: Optional[List[BotAddCommandStruct]] = None
|
||||
update: Optional[List[BotAddCommandStruct]] = None
|
||||
full: Optional[List[BotAddCommandStruct]] = None
|
||||
delete: Optional[List[BotDelCommandStruct]] = field(json="del", default=None)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -25,8 +25,8 @@ from .mention import *
|
||||
@dataclass(kw_only=True)
|
||||
class Attachment(SerializableAttrs):
|
||||
shout: Optional[bool] = None
|
||||
mentions: Optional[list[MentionStruct]] = None
|
||||
urls: Optional[list[str]] = None
|
||||
mentions: Optional[List[MentionStruct]] = None
|
||||
urls: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -46,16 +46,16 @@ class PhotoAttachment(MediaKeyAttachment):
|
||||
|
||||
@dataclass
|
||||
class MultiPhotoAttachment(Attachment):
|
||||
kl: list[str]
|
||||
wl: list[int]
|
||||
hl: list[int]
|
||||
csl: list[str]
|
||||
imageUrls: list[str]
|
||||
thumbnailUrls: list[str]
|
||||
thumbnailWidths: list[int]
|
||||
thumbnailHeights: list[int]
|
||||
sl: list[int] # NOTE Changed to a list
|
||||
mtl: list[str] # NOTE Added
|
||||
kl: List[str]
|
||||
wl: List[int]
|
||||
hl: List[int]
|
||||
csl: List[str]
|
||||
imageUrls: List[str]
|
||||
thumbnailUrls: List[str]
|
||||
thumbnailWidths: List[int]
|
||||
thumbnailHeights: List[int]
|
||||
sl: List[int] # NOTE Changed to a list
|
||||
mtl: List[str] # NOTE Added
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Union
|
||||
from typing import List, Union
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -24,7 +24,7 @@ from ...bson import Long
|
||||
|
||||
@dataclass
|
||||
class MentionStruct(SerializableAttrs):
|
||||
at: list[int]
|
||||
at: List[int]
|
||||
len: int
|
||||
user_id: Union[Long, int]
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
from enum import IntEnum
|
||||
|
||||
from attr import dataclass
|
||||
@ -82,7 +82,7 @@ class PostItem:
|
||||
class Image(Unknown):
|
||||
t = KnownPostItemType.IMAGE
|
||||
tt: Optional[str] = None
|
||||
th: list[str]
|
||||
th: List[str]
|
||||
g: Optional[bool] = None
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -96,7 +96,7 @@ class PostItem:
|
||||
st: int
|
||||
tt: str
|
||||
ittpe: Optional[str] = None
|
||||
its: list[dict]
|
||||
its: List[dict]
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Video(Unknown):
|
||||
@ -119,7 +119,7 @@ class PostItem:
|
||||
@dataclass(kw_only=True)
|
||||
class PostAttachment(Attachment):
|
||||
subtype: Optional[PostSubItemType] = None
|
||||
os: list[PostItem.Unknown]
|
||||
os: List[PostItem.Unknown]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -29,7 +29,7 @@ class ReplyAttachment(Attachment):
|
||||
attach_type: Optional[ChatType] = None # NOTE Changed from int for outgoing typeless replies
|
||||
src_linkId: Optional[Long] = None
|
||||
src_logId: Long
|
||||
src_mentions: list[MentionStruct] = None # NOTE Made optional
|
||||
src_mentions: List[MentionStruct] = None # NOTE Made optional
|
||||
src_message: str
|
||||
src_type: ChatType
|
||||
src_userId: Long
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Union, Type
|
||||
from typing import Dict, Optional, Union, Type
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -62,7 +62,7 @@ class Chat(ChatTypeComponent):
|
||||
return obj
|
||||
|
||||
# TODO More
|
||||
_attachment_type_map: dict[KnownChatType, Type[Attachment]] = {
|
||||
_attachment_type_map: Dict[KnownChatType, Type[Attachment]] = {
|
||||
KnownChatType.PHOTO: PhotoAttachment,
|
||||
KnownChatType.VIDEO: VideoAttachment,
|
||||
KnownChatType.CONTACT: ContactAttachment,
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import NewType, Optional, Union
|
||||
from typing import List, NewType, Optional, Union
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@ -41,11 +41,11 @@ setattr(LoginDataItem, "deserialize", deserialize_channel_login_data_item)
|
||||
@dataclass
|
||||
class LoginResult(SerializableAttrs):
|
||||
"""Return value of TalkClient.login"""
|
||||
channelList: list[LoginDataItem]
|
||||
channelList: List[LoginDataItem]
|
||||
userId: Long
|
||||
lastTokenId: Long
|
||||
mcmRevision: int
|
||||
removedChannelIdList: list[Long]
|
||||
removedChannelIdList: List[Long]
|
||||
revision: int
|
||||
revisionInfo: str
|
||||
minLogId: Long
|
||||
|
@ -21,7 +21,7 @@ from .open_link_user_info import *
|
||||
"""
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from attr import dataclass
|
||||
from enum import IntEnum
|
||||
@ -79,7 +79,7 @@ class OpenLink(OpenLinkSettings, OpenLinkComponent, OpenTokenComponent, OpenPriv
|
||||
linkURL: str
|
||||
openToken: int
|
||||
linkOwner: "OpenLinkUserInfo"
|
||||
profileTagList: list[str]
|
||||
profileTagList: List[str]
|
||||
createdAt: int
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Generic, Type, TypeVar, Union
|
||||
from typing import Generic, List, Type, TypeVar, Union
|
||||
|
||||
from attr import dataclass
|
||||
from enum import IntEnum
|
||||
@ -94,12 +94,12 @@ ResultType = TypeVar("ResultType", bound=Serializable)
|
||||
|
||||
def ResultListType(result_type: Type[ResultType]):
|
||||
"""Custom type for setting a result to a list of serializable objects."""
|
||||
class _ResultListType(list[result_type], Serializable):
|
||||
def serialize(self) -> list[JSON]:
|
||||
class _ResultListType(List[result_type], Serializable):
|
||||
def serialize(self) -> List[JSON]:
|
||||
return [v.serialize() for v in self]
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: list[JSON]) -> "_ResultListType":
|
||||
def deserialize(cls, data: List[JSON]) -> "_ResultListType":
|
||||
return [result_type.deserialize(item) for item in data]
|
||||
|
||||
return _ResultListType
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from mautrix.bridge import BaseMatrixHandler
|
||||
from mautrix.types import (
|
||||
@ -175,7 +175,7 @@ class MatrixHandler(BaseMatrixHandler):
|
||||
await user.client.mark_read(portal.channel_props, message.ktid)
|
||||
|
||||
async def handle_ephemeral_event(
|
||||
self, evt: ReceiptEvent | Event
|
||||
self, evt: Union[ReceiptEvent, Event]
|
||||
) -> None:
|
||||
if evt.type == EventType.RECEIPT:
|
||||
await self.handle_receipt(evt)
|
||||
|
@ -22,9 +22,15 @@ from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from io import BytesIO
|
||||
@ -147,41 +153,41 @@ StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STAT
|
||||
|
||||
class Portal(DBPortal, BasePortal):
|
||||
invite_own_puppet_to_pm: bool = False
|
||||
by_mxid: dict[RoomID, Portal] = {}
|
||||
by_ktid: dict[tuple[int, int], Portal] = {}
|
||||
by_mxid: Dict[RoomID, Portal] = {}
|
||||
by_ktid: Dict[Tuple[int, int], Portal] = {}
|
||||
matrix: m.MatrixHandler
|
||||
config: Config
|
||||
|
||||
_main_intent: IntentAPI | None
|
||||
_kt_sender: Long | None
|
||||
_main_intent: Optional[IntentAPI]
|
||||
_kt_sender: Optional[Long]
|
||||
_create_room_lock: asyncio.Lock
|
||||
_send_locks: dict[int, asyncio.Lock]
|
||||
_send_locks: Dict[int, asyncio.Lock]
|
||||
_noop_lock: FakeLock = FakeLock()
|
||||
backfill_lock: SimpleLock
|
||||
_backfill_leave: set[IntentAPI] | None
|
||||
_backfill_leave: Optional[Set[IntentAPI]]
|
||||
_sleeping_to_resync: bool
|
||||
_scheduled_resync: asyncio.Task | None
|
||||
_resync_targets: dict[int, p.Puppet]
|
||||
_scheduled_resync: Optional[asyncio.Task]
|
||||
_resync_targets: Dict[int, p.Puppet]
|
||||
|
||||
_CHAT_TYPE_HANDLER_MAP: dict[ChatType, Callable[..., ACallable[list[EventID]]]]
|
||||
_STATE_EVENT_HANDLER_MAP: dict[EventType, StateEventHandler]
|
||||
_CHAT_TYPE_HANDLER_MAP: Dict[ChatType, Callable[..., ACallable[List[EventID]]]]
|
||||
_STATE_EVENT_HANDLER_MAP: Dict[EventType, StateEventHandler]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ktid: Long,
|
||||
kt_receiver: Long,
|
||||
kt_type: ChannelType,
|
||||
mxid: RoomID | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
photo_id: str | None = None,
|
||||
avatar_url: ContentURI | None = None,
|
||||
mxid: Optional[RoomID] = None,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
photo_id: Optional[str] = None,
|
||||
avatar_url: Optional[ContentURI] = None,
|
||||
encrypted: bool = False,
|
||||
name_set: bool = False,
|
||||
topic_set: bool = False,
|
||||
avatar_set: bool = False,
|
||||
fully_read_kt_chat: Long | None = None,
|
||||
relay_user_id: UserID | None = None,
|
||||
fully_read_kt_chat: Optional[Long] = None,
|
||||
relay_user_id: Optional[UserID] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ktid,
|
||||
@ -284,7 +290,7 @@ class Portal(DBPortal, BasePortal):
|
||||
# region Properties
|
||||
|
||||
@property
|
||||
def ktid_full(self) -> tuple[Long, Long]:
|
||||
def ktid_full(self) -> Tuple[Long, Long]:
|
||||
return self.ktid, self.kt_receiver
|
||||
|
||||
@property
|
||||
@ -302,7 +308,7 @@ class Portal(DBPortal, BasePortal):
|
||||
return KnownChannelType.is_open(self.kt_type)
|
||||
|
||||
@property
|
||||
def kt_sender(self) -> int | None:
|
||||
def kt_sender(self) -> Optional[int]:
|
||||
if self.is_direct:
|
||||
if not self._kt_sender:
|
||||
raise ValueError("Direct chat portal must set sender")
|
||||
@ -324,7 +330,7 @@ class Portal(DBPortal, BasePortal):
|
||||
raise ValueError("Portal must be postinit()ed before main_intent can be used")
|
||||
return self._main_intent
|
||||
|
||||
async def get_dm_puppet(self) -> p.Puppet | None:
|
||||
async def get_dm_puppet(self) -> Optional[p.Puppet]:
|
||||
if not self.is_direct:
|
||||
return None
|
||||
return await p.Puppet.get_by_ktid(self.kt_sender)
|
||||
@ -365,7 +371,7 @@ class Portal(DBPortal, BasePortal):
|
||||
async def update_info(
|
||||
self,
|
||||
source: u.User,
|
||||
info: PortalChannelInfo | None = None,
|
||||
info: Optional[PortalChannelInfo] = None,
|
||||
force_save: bool = False,
|
||||
) -> PortalChannelInfo:
|
||||
if not info:
|
||||
@ -380,6 +386,8 @@ class Portal(DBPortal, BasePortal):
|
||||
self._update_photo(source, info.photoURL),
|
||||
)
|
||||
)
|
||||
else:
|
||||
changed = await self._update_name(info.name)
|
||||
if info.participantInfo:
|
||||
changed = await self._update_participants(source, info.participantInfo) or changed
|
||||
if changed or force_save:
|
||||
@ -387,8 +395,8 @@ class Portal(DBPortal, BasePortal):
|
||||
await self.save()
|
||||
return info
|
||||
|
||||
async def _get_mapped_participant_power_levels(self, participants: list[UserInfoUnion]) -> dict[UserID, int]:
|
||||
user_power_levels: dict[UserID, int] = {}
|
||||
async def _get_mapped_participant_power_levels(self, participants: List[UserInfoUnion]) -> Dict[UserID, int]:
|
||||
user_power_levels: Dict[UserID, int] = {}
|
||||
for participant in participants:
|
||||
if not isinstance(participant, OpenChannelUserInfo):
|
||||
self.log.warning(f"Info object for participant {participant.userId} of open channel is not an OpenChannelUserInfo")
|
||||
@ -397,7 +405,7 @@ class Portal(DBPortal, BasePortal):
|
||||
return user_power_levels
|
||||
|
||||
@staticmethod
|
||||
async def _update_mapped_ktid_power_levels(user_power_levels: dict[UserID, int], ktid: int, perm: OpenChannelUserPerm) -> None:
|
||||
async def _update_mapped_ktid_power_levels(user_power_levels: Dict[UserID, int], ktid: int, perm: OpenChannelUserPerm) -> None:
|
||||
power_level = FROM_PERM_MAP[perm]
|
||||
user = await u.User.get_by_ktid(ktid)
|
||||
if user:
|
||||
@ -406,7 +414,7 @@ class Portal(DBPortal, BasePortal):
|
||||
if puppet:
|
||||
user_power_levels[puppet.mxid] = power_level
|
||||
|
||||
async def _set_user_power_levels(self, sender: p.Puppet | None, user_power_levels: dict[UserID, int]) -> None:
|
||||
async def _set_user_power_levels(self, sender: Optional[p.Puppet], user_power_levels: Dict[UserID, int]) -> None:
|
||||
if not self.mxid:
|
||||
return
|
||||
orig_power_levels = await self.main_intent.get_power_levels(self.mxid)
|
||||
@ -421,7 +429,7 @@ class Portal(DBPortal, BasePortal):
|
||||
}
|
||||
sender_intent = sender.intent_for(self) if sender else self.main_intent
|
||||
admin_level = orig_power_levels.get_user_level(sender_intent.mxid)
|
||||
demoter_ids: list[UserID] = []
|
||||
demoter_ids: List[UserID] = []
|
||||
power_levels = evolve(orig_power_levels)
|
||||
for user_id, new_level in user_power_levels.items():
|
||||
curr_level = orig_power_levels.get_user_level(user_id)
|
||||
@ -454,12 +462,12 @@ class Portal(DBPortal, BasePortal):
|
||||
source: u.User,
|
||||
intent: IntentAPI,
|
||||
*,
|
||||
filename: str | None = None,
|
||||
mimetype: str | None,
|
||||
filename: Optional[str] = None,
|
||||
mimetype: Optional[str],
|
||||
encrypt: bool = False,
|
||||
find_size: bool = False,
|
||||
convert_audio: bool = False,
|
||||
) -> tuple[ContentURI, FileInfo | VideoInfo | AudioInfo | ImageInfo, EncryptedFile | None]:
|
||||
) -> Tuple[ContentURI, Union[FileInfo, VideoInfo, AudioInfo, ImageInfo], Optional[EncryptedFile]]:
|
||||
if not url:
|
||||
raise ValueError("URL not provided")
|
||||
sandbox = cls.config["bridge.sandbox_media_download"]
|
||||
@ -485,12 +493,12 @@ class Portal(DBPortal, BasePortal):
|
||||
data: bytes,
|
||||
intent: IntentAPI,
|
||||
*,
|
||||
filename: str | None = None,
|
||||
mimetype: str | None = None,
|
||||
filename: Optional[str] = None,
|
||||
mimetype: Optional[str] = None,
|
||||
encrypt: bool = False,
|
||||
find_size: bool = False,
|
||||
convert_audio: bool = False,
|
||||
) -> tuple[ContentURI, FileInfo | VideoInfo | AudioInfo | ImageInfo, EncryptedFile | None]:
|
||||
) -> Tuple[ContentURI, Union[FileInfo, VideoInfo, AudioInfo, ImageInfo], Optional[EncryptedFile]]:
|
||||
if not mimetype:
|
||||
mimetype = magic.mimetype(data)
|
||||
if convert_audio and mimetype != "audio/ogg":
|
||||
@ -519,14 +527,14 @@ class Portal(DBPortal, BasePortal):
|
||||
decryption_info.url = url
|
||||
return url, info, decryption_info
|
||||
|
||||
async def _update_name(self, name: str | None) -> bool:
|
||||
async def _update_name(self, name: Optional[str]) -> bool:
|
||||
if not name:
|
||||
self.log.warning("Got empty name in _update_name call")
|
||||
return False
|
||||
if self.name != name or not self.name_set:
|
||||
self.log.trace("Updating name %s -> %s", self.name, name)
|
||||
self.name = name
|
||||
if self.mxid and (self.encrypted or not self.is_direct):
|
||||
if self.mxid:
|
||||
try:
|
||||
await self.main_intent.set_room_name(self.mxid, self.name)
|
||||
self.name_set = True
|
||||
@ -536,7 +544,7 @@ class Portal(DBPortal, BasePortal):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _update_description(self, description: str | None) -> bool:
|
||||
async def _update_description(self, description: Optional[str]) -> bool:
|
||||
if self.description != description or not self.topic_set:
|
||||
self.log.trace("Updating description %s -> %s", self.description, description)
|
||||
self.description = description
|
||||
@ -550,7 +558,7 @@ class Portal(DBPortal, BasePortal):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _update_photo(self, source: u.User, photo_id: str | None) -> bool:
|
||||
async def _update_photo(self, source: u.User, photo_id: Optional[str]) -> bool:
|
||||
if self.is_direct and not self.encrypted:
|
||||
return False
|
||||
if self.photo_id is not None and photo_id is None:
|
||||
@ -601,7 +609,7 @@ class Portal(DBPortal, BasePortal):
|
||||
self.avatar_set = False
|
||||
return True
|
||||
|
||||
async def update_info_from_puppet(self, puppet: p.Puppet | None = None) -> bool:
|
||||
async def update_info_from_puppet(self, puppet: Optional[p.Puppet] = None) -> bool:
|
||||
if not self.is_direct:
|
||||
return False
|
||||
if not puppet:
|
||||
@ -653,7 +661,7 @@ class Portal(DBPortal, BasePortal):
|
||||
async def _update_participants(
|
||||
self,
|
||||
source: u.User,
|
||||
participant_info: PortalChannelParticipantInfo | None = None,
|
||||
participant_info: Optional[PortalChannelParticipantInfo] = None,
|
||||
) -> bool:
|
||||
# NOTE This handles only non-logged-in users, because logged-in users should be handled by the channel list listeners
|
||||
# TODO nick map?
|
||||
@ -721,13 +729,13 @@ class Portal(DBPortal, BasePortal):
|
||||
# endregion
|
||||
# region Matrix room creation
|
||||
|
||||
async def update_matrix_room(self, source: u.User, info: PortalChannelInfo | None = None) -> None:
|
||||
async def update_matrix_room(self, source: u.User, info: Optional[PortalChannelInfo] = None) -> None:
|
||||
try:
|
||||
await self._update_matrix_room(source, info)
|
||||
except Exception:
|
||||
self.log.exception("Failed to update portal")
|
||||
|
||||
def _get_invite_content(self, double_puppet: p.Puppet | None) -> dict[str, Any]:
|
||||
def _get_invite_content(self, double_puppet: Optional[p.Puppet]) -> Dict[str, Any]:
|
||||
invite_content = {}
|
||||
if double_puppet:
|
||||
invite_content["fi.mau.will_auto_accept"] = True
|
||||
@ -736,7 +744,7 @@ class Portal(DBPortal, BasePortal):
|
||||
return invite_content
|
||||
|
||||
async def _update_matrix_room(
|
||||
self, source: u.User, info: PortalChannelInfo | None = None
|
||||
self, source: u.User, info: Optional[PortalChannelInfo] = None
|
||||
) -> None:
|
||||
puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
|
||||
await self.main_intent.invite_user(
|
||||
@ -785,8 +793,8 @@ class Portal(DBPortal, BasePortal):
|
||||
await self.save()
|
||||
|
||||
async def create_matrix_room(
|
||||
self, source: u.User, info: PortalChannelInfo | None = None
|
||||
) -> RoomID | None:
|
||||
self, source: u.User, info: Optional[PortalChannelInfo] = None
|
||||
) -> Optional[RoomID]:
|
||||
if self.mxid:
|
||||
try:
|
||||
await self._update_matrix_room(source, info)
|
||||
@ -805,7 +813,7 @@ class Portal(DBPortal, BasePortal):
|
||||
return f"net.miscworks.kakaotalk://kakaotalk/{self.ktid}"
|
||||
|
||||
@property
|
||||
def bridge_info(self) -> dict[str, Any]:
|
||||
def bridge_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"bridgebot": self.az.bot_mxid,
|
||||
"creator": self.main_intent.mxid,
|
||||
@ -838,7 +846,7 @@ class Portal(DBPortal, BasePortal):
|
||||
self.log.warning("Failed to update bridge info", exc_info=True)
|
||||
|
||||
async def _create_matrix_room(
|
||||
self, source: u.User, info: PortalChannelInfo | None = None
|
||||
self, source: u.User, info: Optional[PortalChannelInfo] = None
|
||||
) -> RoomID:
|
||||
if self.mxid:
|
||||
await self._update_matrix_room(source, info=info)
|
||||
@ -853,8 +861,8 @@ class Portal(DBPortal, BasePortal):
|
||||
assert info.participantInfo
|
||||
await self._update_participants(source, info.participantInfo)
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
initial_state = [
|
||||
{
|
||||
"type": str(StateBridge),
|
||||
@ -975,7 +983,7 @@ class Portal(DBPortal, BasePortal):
|
||||
self._send_locks[user_id] = lock
|
||||
return lock
|
||||
|
||||
def optional_send_lock(self, user_id: int) -> asyncio.Lock | FakeLock:
|
||||
def optional_send_lock(self, user_id: int) -> Union[asyncio.Lock, FakeLock]:
|
||||
try:
|
||||
return self._send_locks[user_id]
|
||||
except KeyError:
|
||||
@ -1051,7 +1059,7 @@ class Portal(DBPortal, BasePortal):
|
||||
raise NotImplementedError(f"Unsupported message type {message.msgtype}")
|
||||
|
||||
async def _send_chat(
|
||||
self, sender: u.User, message: TextMessageEventContent, event_id: EventID | None = None
|
||||
self, sender: u.User, message: TextMessageEventContent, event_id: Optional[EventID] = None
|
||||
) -> Long:
|
||||
converted = await matrix_to_kakaotalk(message, self.mxid, self.log, self)
|
||||
try:
|
||||
@ -1065,7 +1073,7 @@ class Portal(DBPortal, BasePortal):
|
||||
self.log.debug(f"Error handling Matrix message {event_id if event_id else '<extra>'}: {e!s}")
|
||||
raise
|
||||
|
||||
async def _make_dbm(self, event_id: EventID, ktid: Long | None = None) -> DBMessage:
|
||||
async def _make_dbm(self, event_id: EventID, ktid: Optional[Long] = None) -> DBMessage:
|
||||
dbm = DBMessage(
|
||||
mxid=event_id,
|
||||
mx_room=self.mxid,
|
||||
@ -1257,8 +1265,8 @@ class Portal(DBPortal, BasePortal):
|
||||
prev_content: PowerLevelStateEventContent,
|
||||
content: PowerLevelStateEventContent,
|
||||
) -> None:
|
||||
ktid_perms: dict[Long, OpenChannelUserPerm] = {}
|
||||
user_power_levels: dict[UserID, int] = {}
|
||||
ktid_perms: Dict[Long, OpenChannelUserPerm] = {}
|
||||
user_power_levels: Dict[UserID, int] = {}
|
||||
for user_id, level in content.users.items():
|
||||
if level == prev_content.get_user_level(user_id):
|
||||
continue
|
||||
@ -1301,7 +1309,7 @@ class Portal(DBPortal, BasePortal):
|
||||
)
|
||||
|
||||
async def _revert_matrix_power_levels(self, prev_content: PowerLevelStateEventContent) -> None:
|
||||
managed_power_levels: dict[UserID, int] = {}
|
||||
managed_power_levels: Dict[UserID, int] = {}
|
||||
for user_id, level in prev_content.users.items():
|
||||
if await p.Puppet.get_by_mxid(user_id) or await u.User.get_by_mxid(user_id):
|
||||
managed_power_levels[user_id] = level
|
||||
@ -1510,10 +1518,10 @@ class Portal(DBPortal, BasePortal):
|
||||
self,
|
||||
intent: IntentAPI,
|
||||
timestamp: int,
|
||||
chat_text: str | None,
|
||||
chat_text: Optional[str],
|
||||
chat_type: ChatType,
|
||||
**_
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
try:
|
||||
type_str = KnownChatType(chat_type).name.lower()
|
||||
except ValueError:
|
||||
@ -1543,9 +1551,9 @@ class Portal(DBPortal, BasePortal):
|
||||
async def _handle_kakaotalk_feed(
|
||||
self,
|
||||
timestamp: int,
|
||||
chat_text: str | None,
|
||||
chat_text: Optional[str],
|
||||
**_
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
self.log.info("Got feed message at %s: %s", timestamp, chat_text or "none")
|
||||
return []
|
||||
|
||||
@ -1553,18 +1561,18 @@ class Portal(DBPortal, BasePortal):
|
||||
self,
|
||||
timestamp: int,
|
||||
**_
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
self.log.info(f"Got deleted (?) message at {timestamp}")
|
||||
return []
|
||||
|
||||
async def _handle_kakaotalk_text(
|
||||
self,
|
||||
intent: IntentAPI,
|
||||
attachment: Attachment | None,
|
||||
attachment: Optional[Attachment],
|
||||
timestamp: int,
|
||||
chat_text: str | None,
|
||||
chat_text: Optional[str],
|
||||
**_
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
content = await kakaotalk_to_matrix(chat_text, attachment.mentions if attachment else None)
|
||||
return [await self._send_message(intent, content, timestamp=timestamp)]
|
||||
|
||||
@ -1575,19 +1583,19 @@ class Portal(DBPortal, BasePortal):
|
||||
timestamp: int,
|
||||
chat_text: str,
|
||||
**_
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
content = await kakaotalk_to_matrix(chat_text, attachment.mentions)
|
||||
await self._add_kakaotalk_reply(content, attachment)
|
||||
return [await self._send_message(intent, content, timestamp=timestamp)]
|
||||
|
||||
async def _handle_kakaotalk_photo(self, **kwargs) -> list[EventID]:
|
||||
async def _handle_kakaotalk_photo(self, **kwargs) -> List[EventID]:
|
||||
return [await self._handle_kakaotalk_uniphoto(**kwargs)]
|
||||
|
||||
async def _handle_kakaotalk_multiphoto(
|
||||
self,
|
||||
attachment: MultiPhotoAttachment,
|
||||
**kwargs
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
# TODO Upload media concurrently, but post messages sequentially
|
||||
return [
|
||||
await self._handle_kakaotalk_uniphoto(
|
||||
@ -1632,7 +1640,7 @@ class Portal(DBPortal, BasePortal):
|
||||
self,
|
||||
attachment: VideoAttachment,
|
||||
**kwargs
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
return [await self._handle_kakaotalk_media(
|
||||
attachment,
|
||||
VideoInfo(
|
||||
@ -1648,7 +1656,7 @@ class Portal(DBPortal, BasePortal):
|
||||
self,
|
||||
attachment: AudioAttachment,
|
||||
**kwargs
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
return [await self._handle_kakaotalk_media(
|
||||
attachment,
|
||||
AudioInfo(
|
||||
@ -1661,14 +1669,14 @@ class Portal(DBPortal, BasePortal):
|
||||
|
||||
async def _handle_kakaotalk_media(
|
||||
self,
|
||||
attachment: MediaAttachment | AudioAttachment,
|
||||
attachment: Union[MediaAttachment, AudioAttachment],
|
||||
info: MediaInfo,
|
||||
msgtype: MessageType,
|
||||
*,
|
||||
source: u.User,
|
||||
intent: IntentAPI,
|
||||
timestamp: int,
|
||||
chat_text: str | None,
|
||||
chat_text: Optional[str],
|
||||
**_
|
||||
) -> EventID:
|
||||
mxc, additional_info, decryption_info = await self._reupload_kakaotalk_file_from_url(
|
||||
@ -1692,9 +1700,9 @@ class Portal(DBPortal, BasePortal):
|
||||
intent: IntentAPI,
|
||||
attachment: FileAttachment,
|
||||
timestamp: int,
|
||||
chat_text: str | None,
|
||||
chat_text: Optional[str],
|
||||
**_
|
||||
) -> list[EventID]:
|
||||
) -> List[EventID]:
|
||||
data = await source.client.download_file(self.channel_props, attachment.k)
|
||||
mxc, info, decryption_info = await self._reupload_kakaotalk_file_from_bytes(
|
||||
data,
|
||||
@ -1741,7 +1749,7 @@ class Portal(DBPortal, BasePortal):
|
||||
async def handle_kakaotalk_perm_changed(
|
||||
self, source: u.User, sender: p.Puppet, user_id: Long, perm: OpenChannelUserPerm
|
||||
) -> None:
|
||||
user_power_levels: dict[UserID, int] = {}
|
||||
user_power_levels: Dict[UserID, int] = {}
|
||||
await self._update_mapped_ktid_power_levels(user_power_levels, user_id, perm)
|
||||
await self._set_user_power_levels(sender, user_power_levels)
|
||||
|
||||
@ -1755,7 +1763,7 @@ class Portal(DBPortal, BasePortal):
|
||||
self.schedule_resync(source, user)
|
||||
|
||||
async def handle_kakaotalk_user_left(
|
||||
self, source: u.User, sender: p.Puppet | None, removed: p.Puppet
|
||||
self, source: u.User, sender: Optional[p.Puppet], removed: p.Puppet
|
||||
) -> None:
|
||||
sender_intent = sender.intent_for(self) if sender else self.main_intent
|
||||
removed_user = await u.User.get_by_ktid(removed.ktid)
|
||||
@ -1791,7 +1799,7 @@ class Portal(DBPortal, BasePortal):
|
||||
# TODO Find when or if there is a listener for this
|
||||
# TODO Confirm whether this can refer to any user that was kicked, or only to the current user
|
||||
async def handle_kakaotalk_user_unkick(
|
||||
self, source: u.User, sender: p.Puppet | None, unkicked: p.Puppet
|
||||
self, source: u.User, sender: Optional[p.Puppet], unkicked: p.Puppet
|
||||
) -> None:
|
||||
assert sender != unkicked, f"Puppet for {unkicked.mxid} tried to unkick itself"
|
||||
sender_intent = sender.intent_for(self) if sender else self.main_intent
|
||||
@ -1848,8 +1856,8 @@ class Portal(DBPortal, BasePortal):
|
||||
async def _backfill(
|
||||
self,
|
||||
source: u.User,
|
||||
limit: int | None,
|
||||
after_log_id: Long | None,
|
||||
limit: Optional[int],
|
||||
after_log_id: Optional[Long],
|
||||
) -> None:
|
||||
self.log.debug(f"Backfilling history through {source.mxid}")
|
||||
self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}")
|
||||
@ -1898,7 +1906,7 @@ class Portal(DBPortal, BasePortal):
|
||||
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
|
||||
async def get_by_mxid(cls, mxid: RoomID) -> Optional[Portal]:
|
||||
try:
|
||||
return cls.by_mxid[mxid]
|
||||
except KeyError:
|
||||
@ -1919,8 +1927,8 @@ class Portal(DBPortal, BasePortal):
|
||||
*,
|
||||
kt_receiver: int = 0,
|
||||
create: bool = True,
|
||||
kt_type: ChannelType | None = None,
|
||||
) -> Portal | None:
|
||||
kt_type: Optional[ChannelType] = None,
|
||||
) -> Optional[Portal]:
|
||||
# TODO Direct chats are shared, so can remove kt_receiver if DM portals should be shared
|
||||
if kt_type:
|
||||
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, Dict, Optional, cast
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
|
||||
@ -46,24 +46,24 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
hs_domain: str
|
||||
mxid_template: SimpleTemplate[int]
|
||||
|
||||
by_ktid: dict[int, Puppet] = {}
|
||||
by_custom_mxid: dict[UserID, Puppet] = {}
|
||||
by_ktid: Dict[int, Puppet] = {}
|
||||
by_custom_mxid: Dict[UserID, Puppet] = {}
|
||||
|
||||
_last_info_sync: datetime | None
|
||||
_last_info_sync: Optional[datetime]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ktid: Long,
|
||||
name: str | None = None,
|
||||
photo_id: str | None = None,
|
||||
photo_mxc: ContentURI | None = None,
|
||||
name: Optional[str] = None,
|
||||
photo_id: Optional[str] = None,
|
||||
photo_mxc: Optional[ContentURI] = None,
|
||||
name_set: bool = False,
|
||||
avatar_set: bool = False,
|
||||
is_registered: bool = False,
|
||||
custom_mxid: UserID | None = None,
|
||||
access_token: str | None = None,
|
||||
next_batch: SyncToken | None = None,
|
||||
base_url: URL | None = None,
|
||||
custom_mxid: Optional[UserID] = None,
|
||||
access_token: Optional[str] = None,
|
||||
next_batch: Optional[SyncToken] = None,
|
||||
base_url: Optional[URL] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ktid,
|
||||
@ -249,7 +249,7 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Puppet | None:
|
||||
async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Optional[Puppet]:
|
||||
try:
|
||||
return cls.by_ktid[int]
|
||||
except KeyError:
|
||||
@ -269,7 +269,7 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
|
||||
async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional[Puppet]:
|
||||
ktid = cls.get_id_from_mxid(mxid)
|
||||
if ktid:
|
||||
return await cls.get_by_ktid(ktid, create=create)
|
||||
@ -277,7 +277,7 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
|
||||
async def get_by_custom_mxid(cls, mxid: UserID) -> Optional[Puppet]:
|
||||
try:
|
||||
return cls.by_custom_mxid[mxid]
|
||||
except KeyError:
|
||||
@ -291,7 +291,7 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_id_from_mxid(cls, mxid: UserID) -> int | None:
|
||||
def get_id_from_mxid(cls, mxid: UserID) -> Optional[int]:
|
||||
return cls.mxid_template.parse(mxid)
|
||||
|
||||
@classmethod
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
@ -26,16 +26,16 @@ from mautrix.types.primitive import JSON
|
||||
from ..config import Config
|
||||
from .types import RPCError
|
||||
|
||||
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
|
||||
EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
class CancelableEvent:
|
||||
_event: asyncio.Event
|
||||
_task: asyncio.Task | None
|
||||
_task: Optional[asyncio.Task]
|
||||
_cancelled: bool
|
||||
_loop: asyncio.AbstractEventLoop
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop | None):
|
||||
def __init__(self, loop: Optional[asyncio.AbstractEventLoop]):
|
||||
self._event = asyncio.Event()
|
||||
self._task = None
|
||||
self._cancelled = False
|
||||
@ -74,15 +74,15 @@ class RPCClient:
|
||||
loop: asyncio.AbstractEventLoop
|
||||
log: logging.Logger = logging.getLogger("mau.rpc")
|
||||
|
||||
_reader: asyncio.StreamReader | None
|
||||
_writer: asyncio.StreamWriter | None
|
||||
_reader: Optional[asyncio.StreamReader]
|
||||
_writer: Optional[asyncio.StreamWriter]
|
||||
_req_id: int
|
||||
_min_broadcast_id: int
|
||||
_response_waiters: dict[int, asyncio.Future[JSON]]
|
||||
_event_handlers: dict[str, list[EventHandler]]
|
||||
_response_waiters: Dict[int, asyncio.Future[JSON]]
|
||||
_event_handlers: Dict[str, List[EventHandler]]
|
||||
_command_queue: asyncio.Queue
|
||||
_read_task: asyncio.Task | None
|
||||
_connection_task: asyncio.Task | None
|
||||
_read_task: Optional[asyncio.Task]
|
||||
_connection_task: Optional[asyncio.Task]
|
||||
_is_connected: CancelableEvent
|
||||
_is_disconnected: CancelableEvent
|
||||
_connection_lock: asyncio.Lock
|
||||
@ -203,10 +203,10 @@ class RPCClient:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
|
||||
def set_event_handlers(self, method: str, handlers: List[EventHandler]) -> None:
|
||||
self._event_handlers[method] = handlers
|
||||
|
||||
async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
|
||||
async def _run_event_handler(self, req_id: int, command: str, req: Dict[str, Any]) -> None:
|
||||
if req_id > self._min_broadcast_id:
|
||||
self.log.debug(f"Ignoring duplicate broadcast {req_id}")
|
||||
return
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, Dict, List, Optional, cast
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
@ -92,22 +92,22 @@ class User(DBUser, BaseUser):
|
||||
shutdown: bool = False
|
||||
config: Config
|
||||
|
||||
by_mxid: dict[UserID, User] = {}
|
||||
by_ktid: dict[int, User] = {}
|
||||
by_mxid: Dict[UserID, User] = {}
|
||||
by_ktid: Dict[int, User] = {}
|
||||
|
||||
_client: Client | None
|
||||
_client: Optional[Client]
|
||||
|
||||
_notice_room_lock: asyncio.Lock
|
||||
_notice_send_lock: asyncio.Lock
|
||||
is_admin: bool
|
||||
permission_level: str
|
||||
_is_logged_in: bool | None
|
||||
_is_connected: bool | None
|
||||
_is_logged_in: Optional[bool]
|
||||
_is_connected: Optional[bool]
|
||||
_connection_time: float
|
||||
_db_instance: DBUser | None
|
||||
_db_instance: Optional[DBUser]
|
||||
_sync_lock: SimpleLock
|
||||
_is_rpc_reconnecting: bool
|
||||
_logged_in_info: SettingsStruct | None
|
||||
_logged_in_info: Optional[SettingsStruct]
|
||||
_logged_in_info_time: float
|
||||
|
||||
def __init__(
|
||||
@ -115,11 +115,11 @@ class User(DBUser, BaseUser):
|
||||
mxid: UserID,
|
||||
force_login: bool,
|
||||
was_connected: bool,
|
||||
ktid: Long | None = None,
|
||||
uuid: str | None = None,
|
||||
access_token: str | None = None,
|
||||
refresh_token: str | None = None,
|
||||
notice_room: RoomID | None = None,
|
||||
ktid: Optional[Long] = None,
|
||||
uuid: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
notice_room: Optional[RoomID] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
mxid=mxid,
|
||||
@ -168,11 +168,11 @@ class User(DBUser, BaseUser):
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool | None:
|
||||
def is_connected(self) -> Optional[bool]:
|
||||
return self._is_connected
|
||||
|
||||
@is_connected.setter
|
||||
def is_connected(self, val: bool | None) -> None:
|
||||
def is_connected(self, val: Optional[bool]) -> None:
|
||||
if self._is_connected != val:
|
||||
self._is_connected = val
|
||||
self._connection_time = time.monotonic()
|
||||
@ -209,7 +209,7 @@ class User(DBUser, BaseUser):
|
||||
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> User | None:
|
||||
async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> Optional[User]:
|
||||
if pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid:
|
||||
return None
|
||||
try:
|
||||
@ -233,7 +233,7 @@ class User(DBUser, BaseUser):
|
||||
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get_by_ktid(cls, ktid: int) -> User | None:
|
||||
async def get_by_ktid(cls, ktid: int) -> Optional[User]:
|
||||
try:
|
||||
return cls.by_ktid[ktid]
|
||||
except KeyError:
|
||||
@ -281,7 +281,7 @@ class User(DBUser, BaseUser):
|
||||
self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
|
||||
self.uuid = oauth_credential.deviceUUID
|
||||
|
||||
async def get_own_info(self, *, force: bool = False) -> SettingsStruct | None:
|
||||
async def get_own_info(self, *, force: bool = False) -> Optional[SettingsStruct]:
|
||||
if force or self._client and (not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic()):
|
||||
self._logged_in_info = await self._client.get_settings()
|
||||
self._logged_in_info_time = time.monotonic()
|
||||
@ -297,7 +297,7 @@ class User(DBUser, BaseUser):
|
||||
error="logged-out",
|
||||
)
|
||||
return False
|
||||
latest_exc: Exception | None = None
|
||||
latest_exc: Optional[Exception] = None
|
||||
password_ok = False
|
||||
try:
|
||||
creds = await LoginCredential.get_by_mxid(self.mxid)
|
||||
@ -357,7 +357,7 @@ class User(DBUser, BaseUser):
|
||||
asyncio.create_task(self.post_login(is_startup=is_startup, is_token_login=not password_ok))
|
||||
return True
|
||||
|
||||
async def _send_reset_notice(self, e: OAuthException, edit: EventID | None = None) -> None:
|
||||
async def _send_reset_notice(self, e: OAuthException, edit: Optional[EventID] = None) -> None:
|
||||
await self.send_bridge_notice(
|
||||
"Got authentication error from KakaoTalk:\n\n"
|
||||
f"> {e.message}\n\n"
|
||||
@ -389,7 +389,7 @@ class User(DBUser, BaseUser):
|
||||
return self._is_logged_in
|
||||
|
||||
async def reload_session(
|
||||
self, event_id: EventID | None = None, retries: int = 3, is_startup: bool = False
|
||||
self, event_id: Optional[EventID] = None, retries: int = 3, is_startup: bool = False
|
||||
) -> None:
|
||||
try:
|
||||
if not await self._load_session(is_startup=is_startup) and self.has_state:
|
||||
@ -487,7 +487,7 @@ class User(DBUser, BaseUser):
|
||||
self.was_connected = False
|
||||
await self.save()
|
||||
|
||||
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
|
||||
async def get_direct_chats(self) -> Dict[UserID, List[RoomID]]:
|
||||
return {
|
||||
pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
|
||||
async for portal in po.Portal.get_all_by_receiver(self.ktid)
|
||||
@ -495,7 +495,7 @@ class User(DBUser, BaseUser):
|
||||
} if self.ktid else {}
|
||||
|
||||
@async_time(METRIC_CONNECT_AND_SYNC)
|
||||
async def connect_and_sync(self, sync_count: int | None, force_sync: bool) -> bool:
|
||||
async def connect_and_sync(self, sync_count: Optional[int], force_sync: bool) -> bool:
|
||||
# TODO Look for a way to sync all channels without (re-)logging in
|
||||
try:
|
||||
login_result = await self.client.connect()
|
||||
@ -517,7 +517,7 @@ class User(DBUser, BaseUser):
|
||||
await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e))
|
||||
return False
|
||||
|
||||
async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None:
|
||||
async def _sync_channels(self, login_result: LoginResult, sync_count: Optional[int]) -> None:
|
||||
if sync_count is None:
|
||||
sync_count = self.config["bridge.initial_chat_sync"]
|
||||
if not sync_count:
|
||||
@ -625,12 +625,12 @@ class User(DBUser, BaseUser):
|
||||
async def send_bridge_notice(
|
||||
self,
|
||||
text: str,
|
||||
edit: EventID | None = None,
|
||||
state_event: BridgeStateEvent | None = None,
|
||||
edit: Optional[EventID] = None,
|
||||
state_event: Optional[BridgeStateEvent] = None,
|
||||
important: bool = False,
|
||||
error_code: str | None = None,
|
||||
error_message: str | None = None,
|
||||
) -> EventID | None:
|
||||
error_code: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
) -> Optional[EventID]:
|
||||
if state_event:
|
||||
await self.push_bridge_state(
|
||||
state_event,
|
||||
@ -662,7 +662,7 @@ class User(DBUser, BaseUser):
|
||||
puppet = await pu.Puppet.get_by_ktid(self.ktid)
|
||||
state.remote_name = puppet.name
|
||||
|
||||
async def get_bridge_states(self) -> list[BridgeState]:
|
||||
async def get_bridge_states(self) -> List[BridgeState]:
|
||||
if not self.has_state:
|
||||
return []
|
||||
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
|
||||
@ -672,12 +672,12 @@ class User(DBUser, BaseUser):
|
||||
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
|
||||
return [state]
|
||||
|
||||
async def get_puppet(self) -> pu.Puppet | None:
|
||||
async def get_puppet(self) -> Optional[pu.Puppet]:
|
||||
if not self.ktid:
|
||||
return None
|
||||
return await pu.Puppet.get_by_ktid(self.ktid)
|
||||
|
||||
async def get_portal_with(self, puppet: pu.Puppet, create: bool = True) -> po.Portal | None:
|
||||
async def get_portal_with(self, puppet: pu.Puppet, create: bool = True) -> Optional[po.Portal]:
|
||||
# TODO Make upstream request to return custom failure message
|
||||
if not self.ktid or not self.is_connected:
|
||||
return None
|
||||
@ -732,7 +732,7 @@ class User(DBUser, BaseUser):
|
||||
await self.save()
|
||||
return not skip_sync
|
||||
|
||||
async def on_disconnect(self, res: KickoutRes | None) -> None:
|
||||
async def on_disconnect(self, res: Optional[KickoutRes]) -> None:
|
||||
self.is_connected = False
|
||||
self._track_metric(METRIC_CONNECTED, False)
|
||||
logout = not self.config["bridge.remain_logged_in_on_disconnect"]
|
||||
@ -934,7 +934,7 @@ class User(DBUser, BaseUser):
|
||||
await self._sync_channel(channel_info)
|
||||
|
||||
@async_time(METRIC_CHANNEL_LEFT)
|
||||
async def on_channel_left(self, channel_id: Long, channel_type: ChannelType | None) -> None:
|
||||
async def on_channel_left(self, channel_id: Long, channel_type: Optional[ChannelType]) -> None:
|
||||
assert self.ktid
|
||||
portal = await po.Portal.get_by_ktid(
|
||||
channel_id,
|
||||
|
Loading…
Reference in New Issue
Block a user