Compare commits

...

2 Commits

Author SHA1 Message Date
99ea716731 3.8-ify one last thing 2022-07-14 23:49:57 -04:00
adb7453e1a Actually make compatible with Python 3.8
- Replace builtin generic type annotations with classes from Typing
- Replace union type expressions with Union and Optional
2022-07-14 01:41:24 -04:00
35 changed files with 323 additions and 310 deletions

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any
from typing import Any, Dict
from mautrix.bridge import Bridge
from mautrix.types import RoomID, UserID
@ -46,7 +46,7 @@ class KakaoTalkBridge(Bridge):
config: Config
matrix: MatrixHandler
#public_website: PublicBridgeWebsite | None
#public_website: Optional[PublicBridgeWebsite] None
def prepare_config(self)->None:
super().prepare_config()
@ -117,7 +117,7 @@ class KakaoTalkBridge(Bridge):
async def count_logged_in_users(self) -> int:
return len([user for user in User.by_ktid.values() if user.ktid])
async def manhole_global_namespace(self, user_id: UserID) -> dict[str, Any]:
async def manhole_global_namespace(self, user_id: UserID) -> Dict[str, Any]:
return {
**await super().manhole_global_namespace(user_id),
"User": User,

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, Awaitable
from typing import TYPE_CHECKING, Awaitable, Optional
import asyncio
from mautrix.bridge.commands import HelpSection, command_handler
@ -209,7 +209,7 @@ class MentionFormatString(EntityString[SimpleEntity, EntityType], MarkdownString
class MentionParser(MatrixParser[MentionFormatString]):
fs = MentionFormatString
async def _get_id_from_mxid(mxid: UserID) -> Long | None:
async def _get_id_from_mxid(mxid: UserID) -> Optional[Long]:
user = await u.User.get_by_mxid(mxid, create=False)
if user and user.ktid:
return user.ktid
@ -294,7 +294,7 @@ async def _edit_friend_by_uuid(evt: CommandEvent, uuid: str, add: bool) -> None:
else:
await _on_friend_edited(evt, friend_struct, add)
async def _on_friend_edited(evt: CommandEvent, friend_struct: FriendStruct | None, add: bool):
async def _on_friend_edited(evt: CommandEvent, friend_struct: Optional[FriendStruct], add: bool):
await evt.reply(f"Friend {'added' if add else 'removed'}")
if friend_struct:
puppet = await pu.Puppet.get_by_ktid(friend_struct.userId)

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any
from typing import Any, List, Tuple
import os
from mautrix.bridge.config import BaseBridgeConfig
@ -31,7 +31,7 @@ class Config(BaseBridgeConfig):
return super().__getitem__(key)
@property
def forbidden_defaults(self) -> list[ForbiddenDefault]:
def forbidden_defaults(self) -> List[ForbiddenDefault]:
return [
*super().forbidden_defaults,
ForbiddenDefault("appservice.database", "postgres://username:password@hostname/db"),
@ -118,14 +118,14 @@ class Config(BaseBridgeConfig):
copy("rpc.connection.host")
copy("rpc.connection.port")
def _get_permissions(self, key: str) -> tuple[bool, bool, bool, str]:
def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, str]:
level = self["bridge.permissions"].get(key, "")
admin = level == "admin"
user = level == "user" or admin
relay = level == "relay" or user
return relay, user, admin, level
def get_permissions(self, mxid: UserID) -> tuple[bool, bool, bool, str]:
def get_permissions(self, mxid: UserID) -> Tuple[bool, bool, bool, str]:
permissions = self["bridge.permissions"] or {}
if mxid in permissions:
return self._get_permissions(mxid)

View File

@ -15,9 +15,8 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, Dict, Optional
from asyncpg import Record
from attr import dataclass
from mautrix.types import UserID
@ -35,7 +34,7 @@ class LoginCredential:
password: str
@classmethod
async def get_by_mxid(cls, mxid: UserID) -> LoginCredential | None:
async def get_by_mxid(cls, mxid: UserID) -> Optional[LoginCredential]:
q = "SELECT mxid, email, password FROM login_credential WHERE mxid=$1"
row = await cls.db.fetchrow(q, mxid)
return cls(**row) if row else None
@ -51,7 +50,7 @@ class LoginCredential:
"""
await self.db.execute(q, self.mxid, self.email, self.password)
def get_form(self) -> dict[str, str]:
def get_form(self) -> Dict[str, str]:
return {
"email": self.email,
"password": self.password,

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, List, Optional
from asyncpg import Record
from attr import dataclass, field
@ -34,7 +34,7 @@ class Message:
mxid: EventID
mx_room: RoomID
ktid: Long | None = field(converter=to_optional_long)
ktid: Optional[Long] = field(converter=to_optional_long)
index: int
kt_channel: Long = field(converter=Long)
kt_receiver: Long = field(converter=Long)
@ -45,19 +45,19 @@ class Message:
return cls(**row)
@classmethod
def _from_optional_row(cls, row: Record | None) -> Message | None:
def _from_optional_row(cls, row: Optional[Record]) -> Optional[Message]:
return cls._from_row(row) if row is not None else None
columns = 'mxid, mx_room, ktid, "index", kt_channel, kt_receiver, timestamp'
@classmethod
async def get_all_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int) -> list[Message]:
async def get_all_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int) -> List[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3"
rows = await cls.db.fetch(q, ktid, kt_channel, kt_receiver)
return [cls._from_row(row) for row in rows if row]
@classmethod
async def get_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int, index: int = 0) -> Message | None:
async def get_by_ktid(cls, ktid: int, kt_channel: int, kt_receiver: int, index: int = 0) -> Optional[Message]:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_channel=$2 AND kt_receiver=$3 AND "index"=$4'
row = await cls.db.fetchrow(q, ktid, kt_channel, kt_receiver, index)
return cls._from_optional_row(row)
@ -67,13 +67,13 @@ class Message:
await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
@classmethod
async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional[Message]:
q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2"
row = await cls.db.fetchrow(q, mxid, mx_room)
return cls._from_optional_row(row)
@classmethod
async def get_most_recent(cls, kt_channel: int, kt_receiver: int) -> Message | None:
async def get_most_recent(cls, kt_channel: int, kt_receiver: int) -> Optional[Message]:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
@ -85,7 +85,7 @@ class Message:
@classmethod
async def get_closest_before(
cls, kt_channel: int, kt_receiver: int, ktid: int
) -> Message | None:
) -> Optional[Message]:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid<=$3 "
@ -95,7 +95,7 @@ class Message:
return cls._from_optional_row(row)
@classmethod
async def get_all_since(cls, kt_channel: int, kt_receiver: int, since_ktid: Long | None) -> list[Message]:
async def get_all_since(cls, kt_channel: int, kt_receiver: int, since_ktid: Optional[Long]) -> List[Message]:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE kt_channel=$1 AND kt_receiver=$2 AND ktid>=$3 "
@ -116,10 +116,10 @@ class Message:
ktid: Long,
kt_channel: Long,
kt_receiver: Long,
event_ids: list[EventID],
event_ids: List[EventID],
timestamp: int,
mx_room: RoomID,
) -> list[Message]:
) -> List[Message]:
if not event_ids:
return []
columns = [col.strip('"') for col in cls.columns.split(", ")]

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, List, Optional
from asyncpg import Record
from attr import dataclass, field
@ -36,24 +36,24 @@ class Portal:
ktid: Long = field(converter=Long)
kt_receiver: Long = field(converter=Long)
kt_type: ChannelType
mxid: RoomID | None
name: str | None
description: str | None
photo_id: str | None
avatar_url: ContentURI | None
mxid: Optional[RoomID]
name: Optional[str]
description: Optional[str]
photo_id: Optional[str]
avatar_url: Optional[ContentURI]
encrypted: bool
name_set: bool
topic_set: bool
avatar_set: bool
fully_read_kt_chat: Long | None = field(converter=to_optional_long)
relay_user_id: UserID | None
fully_read_kt_chat: Optional[Long] = field(converter=to_optional_long)
relay_user_id: Optional[UserID]
@classmethod
def _from_row(cls, row: Record) -> Portal:
return cls(**row)
@classmethod
def _from_optional_row(cls, row: Record | None) -> Portal | None:
def _from_optional_row(cls, row: Optional[Record]) -> Optional[Portal]:
return cls._from_row(row) if row is not None else None
_columns = (
@ -62,25 +62,25 @@ class Portal:
)
@classmethod
async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None:
async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Optional[Portal]:
q = f"SELECT {cls._columns} FROM portal WHERE ktid=$1 AND kt_receiver=$2"
row = await cls.db.fetchrow(q, ktid, kt_receiver)
return cls._from_optional_row(row)
@classmethod
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
async def get_by_mxid(cls, mxid: RoomID) -> Optional[Portal]:
q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1"
row = await cls.db.fetchrow(q, mxid)
return cls._from_optional_row(row)
@classmethod
async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]:
async def get_all_by_receiver(cls, kt_receiver: int) -> List[Portal]:
q = f"SELECT {cls._columns} FROM portal WHERE kt_receiver=$1"
rows = await cls.db.fetch(q, kt_receiver)
return [cls._from_row(row) for row in rows if row]
@classmethod
async def all(cls) -> list[Portal]:
async def all(cls) -> List[Portal]:
q = f"SELECT {cls._columns} FROM portal"
rows = await cls.db.fetch(q)
return [cls._from_row(row) for row in rows if row]

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, List, Optional
from asyncpg import Record
from attr import dataclass, field
@ -34,17 +34,17 @@ class Puppet:
db: ClassVar[Database] = fake_db
ktid: Long = field(converter=Long)
name: str | None
photo_id: str | None
photo_mxc: ContentURI | None
name: Optional[str]
photo_id: Optional[str]
photo_mxc: Optional[ContentURI]
name_set: bool
avatar_set: bool
is_registered: bool
custom_mxid: UserID | None
access_token: str | None
next_batch: SyncToken | None
base_url: URL | None
custom_mxid: Optional[UserID]
access_token: Optional[str]
next_batch: Optional[SyncToken]
base_url: Optional[URL]
@classmethod
def _from_row(cls, row: Record) -> Puppet:
@ -53,11 +53,11 @@ class Puppet:
return cls(**data, base_url=URL(base_url) if base_url else None)
@classmethod
def _from_optional_row(cls, row: Record | None) -> Puppet | None:
def _from_optional_row(cls, row: Optional[Record]) -> Optional[Puppet]:
return cls._from_row(row) if row is not None else None
@classmethod
async def get_by_ktid(cls, ktid: int) -> Puppet | None:
async def get_by_ktid(cls, ktid: int) -> Optional[Puppet]:
q = (
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
" custom_mxid, access_token, next_batch, base_url "
@ -67,7 +67,7 @@ class Puppet:
return cls._from_optional_row(row)
@classmethod
async def get_by_name(cls, name: str) -> Puppet | None:
async def get_by_name(cls, name: str) -> Optional[Puppet]:
q = (
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
" custom_mxid, access_token, next_batch, base_url "
@ -77,7 +77,7 @@ class Puppet:
return cls._from_optional_row(row)
@classmethod
async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
async def get_by_custom_mxid(cls, mxid: UserID) -> Optional[Puppet]:
q = (
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
" custom_mxid, access_token, next_batch, base_url "
@ -87,7 +87,7 @@ class Puppet:
return cls._from_optional_row(row)
@classmethod
async def get_all_with_custom_mxid(cls) -> list[Puppet]:
async def get_all_with_custom_mxid(cls) -> List[Puppet]:
q = (
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
" custom_mxid, access_token, next_batch, base_url "

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar, List, Set
from typing import TYPE_CHECKING, ClassVar, List, Optional, Set
from asyncpg import Record
from attr import dataclass, field
@ -35,18 +35,18 @@ class User:
mxid: UserID
force_login: bool
was_connected: bool
ktid: Long | None = field(converter=to_optional_long)
uuid: str | None
access_token: str | None
refresh_token: str | None
notice_room: RoomID | None
ktid: Optional[Long] = field(converter=to_optional_long)
uuid: Optional[str]
access_token: Optional[str]
refresh_token: Optional[str]
notice_room: Optional[RoomID]
@classmethod
def _from_row(cls, row: Record) -> User:
return cls(**row)
@classmethod
def _from_optional_row(cls, row: Record | None) -> User | None:
def _from_optional_row(cls, row: Optional[Record]) -> Optional[User]:
return cls._from_row(row) if row is not None else None
_columns = "mxid, force_login, was_connected, ktid, uuid, access_token, refresh_token, notice_room"
@ -58,13 +58,13 @@ class User:
return [cls._from_row(row) for row in rows if row]
@classmethod
async def get_by_ktid(cls, ktid: int) -> User | None:
async def get_by_ktid(cls, ktid: int) -> Optional[User]:
q = f'SELECT {cls._columns} FROM "user" WHERE ktid=$1'
row = await cls.db.fetchrow(q, ktid)
return cls._from_optional_row(row)
@classmethod
async def get_by_mxid(cls, mxid: UserID) -> User | None:
async def get_by_mxid(cls, mxid: UserID) -> Optional[User]:
q = f'SELECT {cls._columns} FROM "user" WHERE mxid=$1'
row = await cls.db.fetchrow(q, mxid)
return cls._from_optional_row(row)

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Match
from typing import List, Match, Optional
import re
from mautrix.types import Format, MessageType, TextMessageEventContent
@ -28,7 +28,7 @@ from .. import puppet as pu, user as u
MENTION_REGEX = re.compile(r"@(\d+)\u2063(.+?)\u2063")
async def kakaotalk_to_matrix(msg: str | None, mentions: list[MentionStruct] | None) -> TextMessageEventContent:
async def kakaotalk_to_matrix(msg: Optional[str], mentions: Optional[List[MentionStruct]]) -> TextMessageEventContent:
# TODO Shouts
text = msg or ""
content = TextMessageEventContent(msgtype=MessageType.TEXT, body=text)

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import NamedTuple
from typing import Dict, List, NamedTuple, Optional
from mautrix.types import Format, MessageEventContent, RelationType, RoomID, UserID
from mautrix.util import utf16_surrogate
@ -40,7 +40,7 @@ from ..db import Message as DBMessage
class SendParams(NamedTuple):
text: str
mentions: list[MentionStruct] | None
mentions: Optional[List[MentionStruct]]
reply_to: ReplyAttachment
@ -89,7 +89,7 @@ class ToKakaoTalkParser(MatrixParser[KakaoTalkFormatString]):
fs = KakaoTalkFormatString
async def _get_id_from_mxid(mxid: UserID, portal: po.Portal) -> Long | None:
async def _get_id_from_mxid(mxid: UserID, portal: po.Portal) -> Optional[Long]:
orig_sender = await u.User.get_by_mxid(mxid, create=False)
if orig_sender and orig_sender.ktid:
return orig_sender.ktid
@ -161,7 +161,7 @@ async def matrix_to_kakaotalk(
):
parsed = await ToKakaoTalkParser().parse(utf16_surrogate.add(content["formatted_body"]))
text = utf16_surrogate.remove(parsed.text)
mentions_by_user: dict[Long, MentionStruct] = {}
mentions_by_user: Dict[Long, MentionStruct] = {}
# Make sure to not create remote mentions for any remote user not in the room
if parsed.entities:
joined_members = set(await portal.main_intent.get_room_members(room_id))
@ -189,7 +189,7 @@ async def matrix_to_kakaotalk(
return SendParams(text=text, mentions=mentions, reply_to=reply_to)
_media_type_reply_body_map: dict[KnownChatType, str] = {
_media_type_reply_body_map: Dict[KnownChatType, str] = {
KnownChatType.PHOTO: "Photo",
KnownChatType.VIDEO: "Video",
KnownChatType.AUDIO: "Voice Note",

View File

@ -22,7 +22,7 @@ with any other potential backend.
from __future__ import annotations
from typing import TYPE_CHECKING, cast, Awaitable, Type, Optional, Union
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Set, Type, Union, cast
import asyncio
from contextlib import asynccontextmanager
import logging
@ -122,7 +122,7 @@ class Client:
# region tokenless commands
@classmethod
async def generate_uuid(cls, used_uuids: set[str]) -> str:
async def generate_uuid(cls, used_uuids: Set[str]) -> str:
"""Randomly generate a UUID for a (fake) device."""
tries_remaining = 10
while True:
@ -160,10 +160,10 @@ class Client:
user: u.User
_rpc_disconnection_task: asyncio.Task | None
_rpc_disconnection_task: Optional[asyncio.Task]
http: ClientSession
log: TraceLogger
_handler_methods: list[str]
_handler_methods: List[str]
def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
"""Create a per-user client object for user-specific client functionality."""
@ -195,7 +195,7 @@ class Client:
def get(
self,
url: Union[str, URL],
headers: Optional[dict[str, str]] = None,
headers: Optional[Dict[str, str]] = None,
sandbox: bool = False,
**kwargs,
) -> _RequestContextManager:
@ -215,7 +215,7 @@ class Client:
# region post-token commands
async def start(self) -> SettingsStruct | None:
async def start(self) -> Optional[SettingsStruct]:
"""
Initialize user-specific bridging & state by providing a token obtained from a prior login.
Receive the user's profile info in response.
@ -252,7 +252,7 @@ class Client:
self.user.oauth_credential = oauth_info.credential
await self.user.save()
async def connect(self) -> LoginResult | None:
async def connect(self) -> Optional[LoginResult]:
"""
Start a new talk session by providing a token obtained from a prior login.
Receive a snapshot of account state in response.
@ -304,14 +304,14 @@ class Client:
channel_props=channel_props.serialize(),
)
def get_participants(self, channel_props: ChannelProps) -> Awaitable[list[UserInfoUnion]]:
def get_participants(self, channel_props: ChannelProps) -> Awaitable[List[UserInfoUnion]]:
return self._api_user_request_result(
ResultListType(UserInfoUnion),
"get_participants",
channel_props=channel_props.serialize(),
)
def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> Awaitable[list[Chatlog]]:
def get_chats(self, channel_props: ChannelProps, sync_from: Optional[Long], limit: Optional[int]) -> Awaitable[List[Chatlog]]:
return self._api_user_request_result(
ResultListType(Chatlog),
"get_chats",
@ -320,7 +320,7 @@ class Client:
limit=limit,
)
def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: list[Long]) -> Awaitable[list[Receipt]]:
def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: List[Long]) -> Awaitable[List[Receipt]]:
return self._api_user_request_result(
ResultListType(Receipt),
"get_read_receipts",
@ -348,7 +348,7 @@ class Client:
"list_friends",
)
async def edit_friend(self, ktid: Long, add: bool) -> FriendStruct | None:
async def edit_friend(self, ktid: Long, add: bool) -> Optional[FriendStruct]:
try:
friend_req_struct = await self._api_user_request_result(
FriendReqStruct,
@ -361,7 +361,7 @@ class Client:
self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless")
return None
async def edit_friend_by_uuid(self, uuid: str, add: bool) -> FriendStruct | None:
async def edit_friend_by_uuid(self, uuid: str, add: bool) -> Optional[FriendStruct]:
try:
friend_req_struct = await self._api_user_request_result(
FriendReqStruct,
@ -374,7 +374,7 @@ class Client:
self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless")
return None
async def get_friend_dm_id(self, friend_id: Long) -> Long | None:
async def get_friend_dm_id(self, friend_id: Long) -> Optional[Long]:
try:
return await self._api_user_request_result(
Long,
@ -385,7 +385,7 @@ class Client:
self.log.exception(f"Could not find friend with ID {friend_id}")
return None
async def get_memo_ids(self) -> list[Long]:
async def get_memo_ids(self) -> List[Long]:
return ResultListType(Long).deserialize(
await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid)
)
@ -406,8 +406,8 @@ class Client:
self,
channel_props: ChannelProps,
text: str,
reply_to: ReplyAttachment | None,
mentions: list[MentionStruct] | None,
reply_to: Optional[ReplyAttachment],
mentions: Optional[List[MentionStruct]],
) -> Awaitable[Long]:
return self._api_user_request_result(
Long,
@ -425,9 +425,9 @@ class Client:
data: bytes,
filename: str,
*,
width: int | None = None,
height: int | None = None,
ext: str | None = None,
width: Optional[int] = None,
height: Optional[int] = None,
ext: Optional[str] = None,
) -> Awaitable[Long]:
return self._api_user_request_result(
Long,
@ -584,14 +584,14 @@ class Client:
# region listeners
def _on_chat(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_chat(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_chat(
Chatlog.deserialize(data["chatlog"]),
Long.deserialize(data["channelId"]),
str(data["channelType"]),
)
def _on_chat_deleted(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_chat_deleted(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_chat_deleted(
Long.deserialize(data["chatId"]),
Long.deserialize(data["senderId"]),
@ -600,7 +600,7 @@ class Client:
str(data["channelType"]),
)
def _on_chat_read(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_chat_read(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_chat_read(
Long.deserialize(data["chatId"]),
Long.deserialize(data["senderId"]),
@ -608,12 +608,12 @@ class Client:
str(data["channelType"]),
)
def _on_profile_changed(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_profile_changed(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_profile_changed(
OpenLinkChannelUserInfo.deserialize(data["info"]),
)
def _on_perm_changed(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_perm_changed(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_perm_changed(
Long.deserialize(data["userId"]),
OpenChannelUserPerm(data["perm"]),
@ -622,23 +622,23 @@ class Client:
str(data["channelType"]),
)
def _on_channel_added(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_channel_added(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_channel_added(
ChannelInfo.deserialize(data["channelInfo"]),
)
def _on_channel_join(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_channel_join(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_channel_join(
ChannelInfo.deserialize(data["channelInfo"]),
)
def _on_channel_left(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_channel_left(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_channel_left(
Long.deserialize(data["channelId"]),
str(data["channelType"]),
)
def _on_channel_kicked(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_channel_kicked(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_channel_kicked(
Long.deserialize(data["userId"]),
Long.deserialize(data["senderId"]),
@ -646,21 +646,21 @@ class Client:
str(data["channelType"]),
)
def _on_user_join(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_user_join(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_user_join(
Long.deserialize(data["userId"]),
Long.deserialize(data["channelId"]),
str(data["channelType"]),
)
def _on_user_left(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_user_left(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_user_left(
Long.deserialize(data["userId"]),
Long.deserialize(data["channelId"]),
str(data["channelType"]),
)
def _on_channel_meta_change(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_channel_meta_change(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_channel_meta_change(
PortalChannelInfo.deserialize(data["info"]),
Long.deserialize(data["channelId"]),
@ -668,7 +668,7 @@ class Client:
)
def _on_listen_disconnect(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_listen_disconnect(self, data: Dict[str, JSON]) -> Awaitable[None]:
try:
res = KickoutRes.deserialize(data)
except Exception:
@ -676,18 +676,18 @@ class Client:
res = None
return self._on_disconnect(res)
def _on_switch_server(self, _: dict[str, JSON]) -> Awaitable[None]:
def _on_switch_server(self, _: Dict[str, JSON]) -> Awaitable[None]:
# TODO Reconnect automatically instead
return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))
def _on_disconnect(self, res: KickoutRes | None) -> Awaitable[None]:
def _on_disconnect(self, res: Optional[KickoutRes]) -> Awaitable[None]:
self._stop_listen()
return self.user.on_disconnect(res)
def _on_error(self, data: dict[str, JSON]) -> Awaitable[None]:
def _on_error(self, data: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_error(data)
def _on_unexpected_disconnect(self, _: dict[str, JSON]) -> Awaitable[None]:
def _on_unexpected_disconnect(self, _: Dict[str, JSON]) -> Awaitable[None]:
return self.user.on_unexpected_disconnect()

View File

@ -16,7 +16,7 @@
"""Internal helpers for error handling."""
from __future__ import annotations
from typing import NoReturn, Type
from typing import Dict, NoReturn, Type, Union
from .errors import (
CommandException,
@ -35,7 +35,7 @@ def raise_unsuccessful_response(resp: RootCommandResult) -> NoReturn:
raise _error_code_class_map.get(resp.status, CommandException)(resp)
_error_code_class_map: dict[KnownAuthStatusCode | KnownDataStatusCode | int, Type[CommandException]] = {
_error_code_class_map: Dict[Union[KnownAuthStatusCode, KnownDataStatusCode, int], Type[CommandException]] = {
#KnownAuthStatusCode.INVALID_PHONE_NUMBER: "Invalid phone number",
#KnownAuthStatusCode.SUCCESS_WITH_ACCOUNT: "Success",
#KnownAuthStatusCode.SUCCESS_WITH_DEVICE_CHANGED: "Success (device changed)",

View File

@ -16,6 +16,8 @@
"""Helper functions & types for status codes for the KakaoTalk API."""
from __future__ import annotations
from typing import Dict, Union
from ..types.api.auth_api_client import KnownAuthStatusCode
from ..types.request import KnownDataStatusCode, RootCommandResult
@ -72,7 +74,7 @@ class ResponseError(Exception):
pass
_status_code_message_map: dict[KnownAuthStatusCode | KnownDataStatusCode | int, str] = {
_status_code_message_map: Dict[Union[KnownAuthStatusCode, KnownDataStatusCode, int], str] = {
KnownAuthStatusCode.INVALID_PHONE_NUMBER: "Invalid phone number",
KnownAuthStatusCode.SUCCESS_WITH_ACCOUNT: "Success",
KnownAuthStatusCode.SUCCESS_WITH_DEVICE_CHANGED: "Success (device changed)",

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Custom wrapper classes around types defined by the KakaoTalk API."""
from typing import Optional, NewType, Union
from typing import Dict, List, NewType, Optional, Union
from attr import dataclass
@ -73,8 +73,8 @@ class Receipt(SerializableAttrs):
@dataclass
class PortalChannelParticipantInfo(SerializableAttrs):
participants: list[UserInfoUnion]
kickedUserIds: list[Long]
participants: List[UserInfoUnion]
kickedUserIds: List[Long]
@dataclass
class PortalChannelInfo(SerializableAttrs):
@ -92,7 +92,7 @@ class ChannelProps(SerializableAttrs):
# TODO Add non-media types, like polls & maps
TO_MSGTYPE_MAP: dict[MessageType, KnownChatType] = {
TO_MSGTYPE_MAP: Dict[MessageType, KnownChatType] = {
MessageType.TEXT: KnownChatType.TEXT,
MessageType.IMAGE: KnownChatType.PHOTO,
MessageType.STICKER: KnownChatType.PHOTO,
@ -102,13 +102,13 @@ TO_MSGTYPE_MAP: dict[MessageType, KnownChatType] = {
}
# https://stackoverflow.com/a/483833
FROM_MSGTYPE_MAP: dict[KnownChatType, MessageType] = {v: k for k, v in TO_MSGTYPE_MAP.items()}
FROM_MSGTYPE_MAP: Dict[KnownChatType, MessageType] = {v: k for k, v in TO_MSGTYPE_MAP.items()}
# TODO Consider allowing custom power/perm mappings
# But must update default user level & permissions to match!
FROM_PERM_MAP: dict[OpenChannelUserPerm, int] = {
FROM_PERM_MAP: Dict[OpenChannelUserPerm, int] = {
OpenChannelUserPerm.OWNER: 100,
OpenChannelUserPerm.MANAGER: 50,
# TODO Decide on an appropriate value for this

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Union
from typing import Dict, List, Optional, Union
from attr import dataclass
@ -54,11 +54,11 @@ class MoreSettingsStruct(SerializableAttrs):
@dataclass
class MoreApps(SerializableAttrs):
recommend: Optional[list[str]] = None # NOTE From unknown[]
all: Optional[list[str]] = None # NOTE From unknown[]
recommend: Optional[List[str]] = None # NOTE From unknown[]
all: Optional[List[str]] = None # NOTE From unknown[]
moreApps: MoreApps
shortcuts: Optional[dict[str, int]] = None # NOTE Made optional
shortcuts: Optional[Dict[str, int]] = None # NOTE Made optional
seasonProfileRev: int
seasonNoticeRev: int
serviceUserId: Union[Long, int]
@ -86,8 +86,8 @@ class MoreSettingsStruct(SerializableAttrs):
@dataclass(kw_only=True)
class LessSettingsStruct(SerializableAttrs):
kakaoAutoLoginDomain: list[str]
daumSsoDomain: list[str]
kakaoAutoLoginDomain: List[str]
daumSsoDomain: List[str]
@dataclass
class GoogleMapsApi(SerializableAttrs):
@ -124,7 +124,7 @@ class LessSettingsStruct(SerializableAttrs):
newPostTerm: int
postExpirationSetting: PostExpirationSetting
kakaoAlertIds: list[int]
kakaoAlertIds: List[int]
@dataclass

View File

@ -13,6 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List
from attr import dataclass
from mautrix.types import SerializableAttrs
@ -23,7 +25,7 @@ from .friend_struct import FriendStruct
@dataclass
class FriendBlockedListStruct(SerializableAttrs):
total: int
blockedFriends: list[FriendStruct]
blockedFriends: List[FriendStruct]
__all__ = [

View File

@ -13,6 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List
from attr import dataclass
from mautrix.types import SerializableAttrs
@ -24,7 +26,7 @@ from .friend_struct import FriendStruct
@dataclass
class FriendListStruct(SerializableAttrs):
token: Long
friends: list[FriendStruct]
friends: List[FriendStruct]
__all__ = [

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import List, Optional
from attr import dataclass
@ -25,7 +25,7 @@ from .friend_struct import FriendStruct
@dataclass
class FriendSearchUserListStruct(SerializableAttrs):
count: int
list: list[FriendStruct]
list: List[FriendStruct]
@dataclass(kw_only=True)
@ -33,7 +33,7 @@ class FriendSearchStruct(SerializableAttrs):
query: str
user: Optional[FriendSearchUserListStruct] = None
plus: Optional[FriendSearchUserListStruct] = None
categories: list[str]
categories: List[str]
total_counts: int

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Union
from typing import List, Optional, Union
from attr import dataclass
@ -34,7 +34,7 @@ class ProfileFeed(SerializableAttrs):
serviceName: str
typeIconUrl: str
downloadId: str
contents: list[ProfileFeedObject]
contents: List[ProfileFeedObject]
url: str
serviceUrl: Optional[str] = None # NOTE Made optional
webUrl: str
@ -51,7 +51,7 @@ class ProfileFeed(SerializableAttrs):
@dataclass(kw_only=True)
class ProfileFeedList(SerializableAttrs):
totalCnt: int # NOTE Renamed from "totalCnts"
feeds: list[ProfileFeed]
feeds: List[ProfileFeed]
@dataclass
@ -91,7 +91,7 @@ class ProfileStruct(SerializableAttrs):
profileImageUrl: str
fullProfileImageUrl: str
originalProfileImageUrl: str
decoration: list[ProfileDecoration]
decoration: List[ProfileDecoration]
profileFeeds: ProfileFeedList
backgroundFeeds: ProfileFeedList
allowStory: bool

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Generic, TypeVar, Optional
from typing import Dict, Generic, List, Optional, TypeVar
from attr import dataclass
@ -37,7 +37,7 @@ class SetChannelMeta(ChannelMeta):
authorId: Long
updatedAt: int
ChannelMetaMap = dict[ChannelMetaType, SetChannelMeta] # Substitute for Record<ChannelMetaType, SetChannelMeta>
ChannelMetaMap = Dict[ChannelMetaType, SetChannelMeta] # Substitute for Record<ChannelMetaType, SetChannelMeta>
@dataclass(kw_only=True)
@ -50,7 +50,7 @@ class ChannelInfo(Channel):
lastSeenLogId: Long
lastChatLog: Optional[Chatlog] = None
metaMap: ChannelMetaMap
displayUserList: list[DisplayUserInfo]
displayUserList: List[DisplayUserInfo]
pushAlert: bool

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Union
from typing import List, Optional, Union
from attr import dataclass
from enum import Enum, IntEnum
@ -149,10 +149,10 @@ BotDelCommandStruct = BotCommandStruct
@dataclass(kw_only=True)
class BotMetaContent(SerializableAttrs):
add: Optional[list[BotAddCommandStruct]] = None
update: Optional[list[BotAddCommandStruct]] = None
full: Optional[list[BotAddCommandStruct]] = None
delete: Optional[list[BotDelCommandStruct]] = field(json="del", default=None)
add: Optional[List[BotAddCommandStruct]] = None
update: Optional[List[BotAddCommandStruct]] = None
full: Optional[List[BotAddCommandStruct]] = None
delete: Optional[List[BotDelCommandStruct]] = field(json="del", default=None)
__all__ = [

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import List, Optional
from attr import dataclass
@ -25,8 +25,8 @@ from .mention import *
@dataclass(kw_only=True)
class Attachment(SerializableAttrs):
shout: Optional[bool] = None
mentions: Optional[list[MentionStruct]] = None
urls: Optional[list[str]] = None
mentions: Optional[List[MentionStruct]] = None
urls: Optional[List[str]] = None
@dataclass

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import List, Optional
from attr import dataclass
@ -46,16 +46,16 @@ class PhotoAttachment(MediaKeyAttachment):
@dataclass
class MultiPhotoAttachment(Attachment):
kl: list[str]
wl: list[int]
hl: list[int]
csl: list[str]
imageUrls: list[str]
thumbnailUrls: list[str]
thumbnailWidths: list[int]
thumbnailHeights: list[int]
sl: list[int] # NOTE Changed to a list
mtl: list[str] # NOTE Added
kl: List[str]
wl: List[int]
hl: List[int]
csl: List[str]
imageUrls: List[str]
thumbnailUrls: List[str]
thumbnailWidths: List[int]
thumbnailHeights: List[int]
sl: List[int] # NOTE Changed to a list
mtl: List[str] # NOTE Added
@dataclass

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union
from typing import List, Union
from attr import dataclass
@ -24,7 +24,7 @@ from ...bson import Long
@dataclass
class MentionStruct(SerializableAttrs):
at: list[int]
at: List[int]
len: int
user_id: Union[Long, int]

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Union
from typing import List, Optional, Union
from enum import IntEnum
from attr import dataclass
@ -82,7 +82,7 @@ class PostItem:
class Image(Unknown):
t = KnownPostItemType.IMAGE
tt: Optional[str] = None
th: list[str]
th: List[str]
g: Optional[bool] = None
@dataclass(kw_only=True)
@ -96,7 +96,7 @@ class PostItem:
st: int
tt: str
ittpe: Optional[str] = None
its: list[dict]
its: List[dict]
@dataclass(kw_only=True)
class Video(Unknown):
@ -119,7 +119,7 @@ class PostItem:
@dataclass(kw_only=True)
class PostAttachment(Attachment):
subtype: Optional[PostSubItemType] = None
os: list[PostItem.Unknown]
os: List[PostItem.Unknown]
@dataclass

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import List, Optional
from attr import dataclass
@ -29,7 +29,7 @@ class ReplyAttachment(Attachment):
attach_type: Optional[ChatType] = None # NOTE Changed from int for outgoing typeless replies
src_linkId: Optional[Long] = None
src_logId: Long
src_mentions: list[MentionStruct] = None # NOTE Made optional
src_mentions: List[MentionStruct] = None # NOTE Made optional
src_message: str
src_type: ChatType
src_userId: Long

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Union, Type
from typing import Dict, Optional, Union, Type
from attr import dataclass
@ -62,7 +62,7 @@ class Chat(ChatTypeComponent):
return obj
# TODO More
_attachment_type_map: dict[KnownChatType, Type[Attachment]] = {
_attachment_type_map: Dict[KnownChatType, Type[Attachment]] = {
KnownChatType.PHOTO: PhotoAttachment,
KnownChatType.VIDEO: VideoAttachment,
KnownChatType.CONTACT: ContactAttachment,

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import NewType, Optional, Union
from typing import List, NewType, Optional, Union
from attr import dataclass
@ -41,11 +41,11 @@ setattr(LoginDataItem, "deserialize", deserialize_channel_login_data_item)
@dataclass
class LoginResult(SerializableAttrs):
"""Return value of TalkClient.login"""
channelList: list[LoginDataItem]
channelList: List[LoginDataItem]
userId: Long
lastTokenId: Long
mcmRevision: int
removedChannelIdList: list[Long]
removedChannelIdList: List[Long]
revision: int
revisionInfo: str
minLogId: Long

View File

@ -21,7 +21,7 @@ from .open_link_user_info import *
"""
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
from attr import dataclass
from enum import IntEnum
@ -79,7 +79,7 @@ class OpenLink(OpenLinkSettings, OpenLinkComponent, OpenTokenComponent, OpenPriv
linkURL: str
openToken: int
linkOwner: "OpenLinkUserInfo"
profileTagList: list[str]
profileTagList: List[str]
createdAt: int

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Generic, Type, TypeVar, Union
from typing import Generic, List, Type, TypeVar, Union
from attr import dataclass
from enum import IntEnum
@ -94,12 +94,12 @@ ResultType = TypeVar("ResultType", bound=Serializable)
def ResultListType(result_type: Type[ResultType]):
"""Custom type for setting a result to a list of serializable objects."""
class _ResultListType(list[result_type], Serializable):
def serialize(self) -> list[JSON]:
class _ResultListType(List[result_type], Serializable):
def serialize(self) -> List[JSON]:
return [v.serialize() for v in self]
@classmethod
def deserialize(cls, data: list[JSON]) -> "_ResultListType":
def deserialize(cls, data: List[JSON]) -> "_ResultListType":
return [result_type.deserialize(item) for item in data]
return _ResultListType

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from mautrix.bridge import BaseMatrixHandler
from mautrix.types import (
@ -175,7 +175,7 @@ class MatrixHandler(BaseMatrixHandler):
await user.client.mark_read(portal.channel_props, message.ktid)
async def handle_ephemeral_event(
self, evt: ReceiptEvent | Event
self, evt: Union[ReceiptEvent, Event]
) -> None:
if evt.type == EventType.RECEIPT:
await self.handle_receipt(evt)

View File

@ -22,9 +22,15 @@ from typing import (
Awaitable,
Callable,
Coroutine,
Dict,
Generic,
List,
Optional,
Pattern,
Set,
Tuple,
TypeVar,
Union,
cast,
)
from io import BytesIO
@ -147,41 +153,41 @@ StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STAT
class Portal(DBPortal, BasePortal):
invite_own_puppet_to_pm: bool = False
by_mxid: dict[RoomID, Portal] = {}
by_ktid: dict[tuple[int, int], Portal] = {}
by_mxid: Dict[RoomID, Portal] = {}
by_ktid: Dict[Tuple[int, int], Portal] = {}
matrix: m.MatrixHandler
config: Config
_main_intent: IntentAPI | None
_kt_sender: Long | None
_main_intent: Optional[IntentAPI]
_kt_sender: Optional[Long]
_create_room_lock: asyncio.Lock
_send_locks: dict[int, asyncio.Lock]
_send_locks: Dict[int, asyncio.Lock]
_noop_lock: FakeLock = FakeLock()
backfill_lock: SimpleLock
_backfill_leave: set[IntentAPI] | None
_backfill_leave: Optional[Set[IntentAPI]]
_sleeping_to_resync: bool
_scheduled_resync: asyncio.Task | None
_resync_targets: dict[int, p.Puppet]
_scheduled_resync: Optional[asyncio.Task]
_resync_targets: Dict[int, p.Puppet]
_CHAT_TYPE_HANDLER_MAP: dict[ChatType, Callable[..., ACallable[list[EventID]]]]
_STATE_EVENT_HANDLER_MAP: dict[EventType, StateEventHandler]
_CHAT_TYPE_HANDLER_MAP: Dict[ChatType, Callable[..., ACallable[List[EventID]]]]
_STATE_EVENT_HANDLER_MAP: Dict[EventType, StateEventHandler]
def __init__(
self,
ktid: Long,
kt_receiver: Long,
kt_type: ChannelType,
mxid: RoomID | None = None,
name: str | None = None,
description: str | None = None,
photo_id: str | None = None,
avatar_url: ContentURI | None = None,
mxid: Optional[RoomID] = None,
name: Optional[str] = None,
description: Optional[str] = None,
photo_id: Optional[str] = None,
avatar_url: Optional[ContentURI] = None,
encrypted: bool = False,
name_set: bool = False,
topic_set: bool = False,
avatar_set: bool = False,
fully_read_kt_chat: Long | None = None,
relay_user_id: UserID | None = None,
fully_read_kt_chat: Optional[Long] = None,
relay_user_id: Optional[UserID] = None,
) -> None:
super().__init__(
ktid,
@ -284,7 +290,7 @@ class Portal(DBPortal, BasePortal):
# region Properties
@property
def ktid_full(self) -> tuple[Long, Long]:
def ktid_full(self) -> Tuple[Long, Long]:
return self.ktid, self.kt_receiver
@property
@ -302,7 +308,7 @@ class Portal(DBPortal, BasePortal):
return KnownChannelType.is_open(self.kt_type)
@property
def kt_sender(self) -> int | None:
def kt_sender(self) -> Optional[int]:
if self.is_direct:
if not self._kt_sender:
raise ValueError("Direct chat portal must set sender")
@ -324,7 +330,7 @@ class Portal(DBPortal, BasePortal):
raise ValueError("Portal must be postinit()ed before main_intent can be used")
return self._main_intent
async def get_dm_puppet(self) -> p.Puppet | None:
async def get_dm_puppet(self) -> Optional[p.Puppet]:
if not self.is_direct:
return None
return await p.Puppet.get_by_ktid(self.kt_sender)
@ -365,7 +371,7 @@ class Portal(DBPortal, BasePortal):
async def update_info(
self,
source: u.User,
info: PortalChannelInfo | None = None,
info: Optional[PortalChannelInfo] = None,
force_save: bool = False,
) -> PortalChannelInfo:
if not info:
@ -380,6 +386,8 @@ class Portal(DBPortal, BasePortal):
self._update_photo(source, info.photoURL),
)
)
else:
changed = await self._update_name(info.name)
if info.participantInfo:
changed = await self._update_participants(source, info.participantInfo) or changed
if changed or force_save:
@ -387,8 +395,8 @@ class Portal(DBPortal, BasePortal):
await self.save()
return info
async def _get_mapped_participant_power_levels(self, participants: list[UserInfoUnion]) -> dict[UserID, int]:
user_power_levels: dict[UserID, int] = {}
async def _get_mapped_participant_power_levels(self, participants: List[UserInfoUnion]) -> Dict[UserID, int]:
user_power_levels: Dict[UserID, int] = {}
for participant in participants:
if not isinstance(participant, OpenChannelUserInfo):
self.log.warning(f"Info object for participant {participant.userId} of open channel is not an OpenChannelUserInfo")
@ -397,7 +405,7 @@ class Portal(DBPortal, BasePortal):
return user_power_levels
@staticmethod
async def _update_mapped_ktid_power_levels(user_power_levels: dict[UserID, int], ktid: int, perm: OpenChannelUserPerm) -> None:
async def _update_mapped_ktid_power_levels(user_power_levels: Dict[UserID, int], ktid: int, perm: OpenChannelUserPerm) -> None:
power_level = FROM_PERM_MAP[perm]
user = await u.User.get_by_ktid(ktid)
if user:
@ -406,7 +414,7 @@ class Portal(DBPortal, BasePortal):
if puppet:
user_power_levels[puppet.mxid] = power_level
async def _set_user_power_levels(self, sender: p.Puppet | None, user_power_levels: dict[UserID, int]) -> None:
async def _set_user_power_levels(self, sender: Optional[p.Puppet], user_power_levels: Dict[UserID, int]) -> None:
if not self.mxid:
return
orig_power_levels = await self.main_intent.get_power_levels(self.mxid)
@ -421,7 +429,7 @@ class Portal(DBPortal, BasePortal):
}
sender_intent = sender.intent_for(self) if sender else self.main_intent
admin_level = orig_power_levels.get_user_level(sender_intent.mxid)
demoter_ids: list[UserID] = []
demoter_ids: List[UserID] = []
power_levels = evolve(orig_power_levels)
for user_id, new_level in user_power_levels.items():
curr_level = orig_power_levels.get_user_level(user_id)
@ -454,12 +462,12 @@ class Portal(DBPortal, BasePortal):
source: u.User,
intent: IntentAPI,
*,
filename: str | None = None,
mimetype: str | None,
filename: Optional[str] = None,
mimetype: Optional[str],
encrypt: bool = False,
find_size: bool = False,
convert_audio: bool = False,
) -> tuple[ContentURI, FileInfo | VideoInfo | AudioInfo | ImageInfo, EncryptedFile | None]:
) -> Tuple[ContentURI, Union[FileInfo, VideoInfo, AudioInfo, ImageInfo], Optional[EncryptedFile]]:
if not url:
raise ValueError("URL not provided")
sandbox = cls.config["bridge.sandbox_media_download"]
@ -485,12 +493,12 @@ class Portal(DBPortal, BasePortal):
data: bytes,
intent: IntentAPI,
*,
filename: str | None = None,
mimetype: str | None = None,
filename: Optional[str] = None,
mimetype: Optional[str] = None,
encrypt: bool = False,
find_size: bool = False,
convert_audio: bool = False,
) -> tuple[ContentURI, FileInfo | VideoInfo | AudioInfo | ImageInfo, EncryptedFile | None]:
) -> Tuple[ContentURI, Union[FileInfo, VideoInfo, AudioInfo, ImageInfo], Optional[EncryptedFile]]:
if not mimetype:
mimetype = magic.mimetype(data)
if convert_audio and mimetype != "audio/ogg":
@ -519,14 +527,14 @@ class Portal(DBPortal, BasePortal):
decryption_info.url = url
return url, info, decryption_info
async def _update_name(self, name: str | None) -> bool:
async def _update_name(self, name: Optional[str]) -> bool:
if not name:
self.log.warning("Got empty name in _update_name call")
return False
if self.name != name or not self.name_set:
self.log.trace("Updating name %s -> %s", self.name, name)
self.name = name
if self.mxid and (self.encrypted or not self.is_direct):
if self.mxid:
try:
await self.main_intent.set_room_name(self.mxid, self.name)
self.name_set = True
@ -536,7 +544,7 @@ class Portal(DBPortal, BasePortal):
return True
return False
async def _update_description(self, description: str | None) -> bool:
async def _update_description(self, description: Optional[str]) -> bool:
if self.description != description or not self.topic_set:
self.log.trace("Updating description %s -> %s", self.description, description)
self.description = description
@ -550,7 +558,7 @@ class Portal(DBPortal, BasePortal):
return True
return False
async def _update_photo(self, source: u.User, photo_id: str | None) -> bool:
async def _update_photo(self, source: u.User, photo_id: Optional[str]) -> bool:
if self.is_direct and not self.encrypted:
return False
if self.photo_id is not None and photo_id is None:
@ -601,7 +609,7 @@ class Portal(DBPortal, BasePortal):
self.avatar_set = False
return True
async def update_info_from_puppet(self, puppet: p.Puppet | None = None) -> bool:
async def update_info_from_puppet(self, puppet: Optional[p.Puppet] = None) -> bool:
if not self.is_direct:
return False
if not puppet:
@ -653,7 +661,7 @@ class Portal(DBPortal, BasePortal):
async def _update_participants(
self,
source: u.User,
participant_info: PortalChannelParticipantInfo | None = None,
participant_info: Optional[PortalChannelParticipantInfo] = None,
) -> bool:
# NOTE This handles only non-logged-in users, because logged-in users should be handled by the channel list listeners
# TODO nick map?
@ -721,13 +729,13 @@ class Portal(DBPortal, BasePortal):
# endregion
# region Matrix room creation
async def update_matrix_room(self, source: u.User, info: PortalChannelInfo | None = None) -> None:
async def update_matrix_room(self, source: u.User, info: Optional[PortalChannelInfo] = None) -> None:
try:
await self._update_matrix_room(source, info)
except Exception:
self.log.exception("Failed to update portal")
def _get_invite_content(self, double_puppet: p.Puppet | None) -> dict[str, Any]:
def _get_invite_content(self, double_puppet: Optional[p.Puppet]) -> Dict[str, Any]:
invite_content = {}
if double_puppet:
invite_content["fi.mau.will_auto_accept"] = True
@ -736,7 +744,7 @@ class Portal(DBPortal, BasePortal):
return invite_content
async def _update_matrix_room(
self, source: u.User, info: PortalChannelInfo | None = None
self, source: u.User, info: Optional[PortalChannelInfo] = None
) -> None:
puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
await self.main_intent.invite_user(
@ -785,8 +793,8 @@ class Portal(DBPortal, BasePortal):
await self.save()
async def create_matrix_room(
self, source: u.User, info: PortalChannelInfo | None = None
) -> RoomID | None:
self, source: u.User, info: Optional[PortalChannelInfo] = None
) -> Optional[RoomID]:
if self.mxid:
try:
await self._update_matrix_room(source, info)
@ -805,7 +813,7 @@ class Portal(DBPortal, BasePortal):
return f"net.miscworks.kakaotalk://kakaotalk/{self.ktid}"
@property
def bridge_info(self) -> dict[str, Any]:
def bridge_info(self) -> Dict[str, Any]:
return {
"bridgebot": self.az.bot_mxid,
"creator": self.main_intent.mxid,
@ -838,7 +846,7 @@ class Portal(DBPortal, BasePortal):
self.log.warning("Failed to update bridge info", exc_info=True)
async def _create_matrix_room(
self, source: u.User, info: PortalChannelInfo | None = None
self, source: u.User, info: Optional[PortalChannelInfo] = None
) -> RoomID:
if self.mxid:
await self._update_matrix_room(source, info=info)
@ -853,8 +861,8 @@ class Portal(DBPortal, BasePortal):
assert info.participantInfo
await self._update_participants(source, info.participantInfo)
name: str | None = None
description: str | None = None
name: Optional[str] = None
description: Optional[str] = None
initial_state = [
{
"type": str(StateBridge),
@ -975,7 +983,7 @@ class Portal(DBPortal, BasePortal):
self._send_locks[user_id] = lock
return lock
def optional_send_lock(self, user_id: int) -> asyncio.Lock | FakeLock:
def optional_send_lock(self, user_id: int) -> Union[asyncio.Lock, FakeLock]:
try:
return self._send_locks[user_id]
except KeyError:
@ -1051,7 +1059,7 @@ class Portal(DBPortal, BasePortal):
raise NotImplementedError(f"Unsupported message type {message.msgtype}")
async def _send_chat(
self, sender: u.User, message: TextMessageEventContent, event_id: EventID | None = None
self, sender: u.User, message: TextMessageEventContent, event_id: Optional[EventID] = None
) -> Long:
converted = await matrix_to_kakaotalk(message, self.mxid, self.log, self)
try:
@ -1065,7 +1073,7 @@ class Portal(DBPortal, BasePortal):
self.log.debug(f"Error handling Matrix message {event_id if event_id else '<extra>'}: {e!s}")
raise
async def _make_dbm(self, event_id: EventID, ktid: Long | None = None) -> DBMessage:
async def _make_dbm(self, event_id: EventID, ktid: Optional[Long] = None) -> DBMessage:
dbm = DBMessage(
mxid=event_id,
mx_room=self.mxid,
@ -1257,8 +1265,8 @@ class Portal(DBPortal, BasePortal):
prev_content: PowerLevelStateEventContent,
content: PowerLevelStateEventContent,
) -> None:
ktid_perms: dict[Long, OpenChannelUserPerm] = {}
user_power_levels: dict[UserID, int] = {}
ktid_perms: Dict[Long, OpenChannelUserPerm] = {}
user_power_levels: Dict[UserID, int] = {}
for user_id, level in content.users.items():
if level == prev_content.get_user_level(user_id):
continue
@ -1301,7 +1309,7 @@ class Portal(DBPortal, BasePortal):
)
async def _revert_matrix_power_levels(self, prev_content: PowerLevelStateEventContent) -> None:
managed_power_levels: dict[UserID, int] = {}
managed_power_levels: Dict[UserID, int] = {}
for user_id, level in prev_content.users.items():
if await p.Puppet.get_by_mxid(user_id) or await u.User.get_by_mxid(user_id):
managed_power_levels[user_id] = level
@ -1510,10 +1518,10 @@ class Portal(DBPortal, BasePortal):
self,
intent: IntentAPI,
timestamp: int,
chat_text: str | None,
chat_text: Optional[str],
chat_type: ChatType,
**_
) -> list[EventID]:
) -> List[EventID]:
try:
type_str = KnownChatType(chat_type).name.lower()
except ValueError:
@ -1543,9 +1551,9 @@ class Portal(DBPortal, BasePortal):
async def _handle_kakaotalk_feed(
self,
timestamp: int,
chat_text: str | None,
chat_text: Optional[str],
**_
) -> list[EventID]:
) -> List[EventID]:
self.log.info("Got feed message at %s: %s", timestamp, chat_text or "none")
return []
@ -1553,18 +1561,18 @@ class Portal(DBPortal, BasePortal):
self,
timestamp: int,
**_
) -> list[EventID]:
) -> List[EventID]:
self.log.info(f"Got deleted (?) message at {timestamp}")
return []
async def _handle_kakaotalk_text(
self,
intent: IntentAPI,
attachment: Attachment | None,
attachment: Optional[Attachment],
timestamp: int,
chat_text: str | None,
chat_text: Optional[str],
**_
) -> list[EventID]:
) -> List[EventID]:
content = await kakaotalk_to_matrix(chat_text, attachment.mentions if attachment else None)
return [await self._send_message(intent, content, timestamp=timestamp)]
@ -1575,19 +1583,19 @@ class Portal(DBPortal, BasePortal):
timestamp: int,
chat_text: str,
**_
) -> list[EventID]:
) -> List[EventID]:
content = await kakaotalk_to_matrix(chat_text, attachment.mentions)
await self._add_kakaotalk_reply(content, attachment)
return [await self._send_message(intent, content, timestamp=timestamp)]
async def _handle_kakaotalk_photo(self, **kwargs) -> list[EventID]:
async def _handle_kakaotalk_photo(self, **kwargs) -> List[EventID]:
return [await self._handle_kakaotalk_uniphoto(**kwargs)]
async def _handle_kakaotalk_multiphoto(
self,
attachment: MultiPhotoAttachment,
**kwargs
) -> list[EventID]:
) -> List[EventID]:
# TODO Upload media concurrently, but post messages sequentially
return [
await self._handle_kakaotalk_uniphoto(
@ -1632,7 +1640,7 @@ class Portal(DBPortal, BasePortal):
self,
attachment: VideoAttachment,
**kwargs
) -> list[EventID]:
) -> List[EventID]:
return [await self._handle_kakaotalk_media(
attachment,
VideoInfo(
@ -1648,7 +1656,7 @@ class Portal(DBPortal, BasePortal):
self,
attachment: AudioAttachment,
**kwargs
) -> list[EventID]:
) -> List[EventID]:
return [await self._handle_kakaotalk_media(
attachment,
AudioInfo(
@ -1661,14 +1669,14 @@ class Portal(DBPortal, BasePortal):
async def _handle_kakaotalk_media(
self,
attachment: MediaAttachment | AudioAttachment,
attachment: Union[MediaAttachment, AudioAttachment],
info: MediaInfo,
msgtype: MessageType,
*,
source: u.User,
intent: IntentAPI,
timestamp: int,
chat_text: str | None,
chat_text: Optional[str],
**_
) -> EventID:
mxc, additional_info, decryption_info = await self._reupload_kakaotalk_file_from_url(
@ -1692,9 +1700,9 @@ class Portal(DBPortal, BasePortal):
intent: IntentAPI,
attachment: FileAttachment,
timestamp: int,
chat_text: str | None,
chat_text: Optional[str],
**_
) -> list[EventID]:
) -> List[EventID]:
data = await source.client.download_file(self.channel_props, attachment.k)
mxc, info, decryption_info = await self._reupload_kakaotalk_file_from_bytes(
data,
@ -1741,7 +1749,7 @@ class Portal(DBPortal, BasePortal):
async def handle_kakaotalk_perm_changed(
self, source: u.User, sender: p.Puppet, user_id: Long, perm: OpenChannelUserPerm
) -> None:
user_power_levels: dict[UserID, int] = {}
user_power_levels: Dict[UserID, int] = {}
await self._update_mapped_ktid_power_levels(user_power_levels, user_id, perm)
await self._set_user_power_levels(sender, user_power_levels)
@ -1755,7 +1763,7 @@ class Portal(DBPortal, BasePortal):
self.schedule_resync(source, user)
async def handle_kakaotalk_user_left(
self, source: u.User, sender: p.Puppet | None, removed: p.Puppet
self, source: u.User, sender: Optional[p.Puppet], removed: p.Puppet
) -> None:
sender_intent = sender.intent_for(self) if sender else self.main_intent
removed_user = await u.User.get_by_ktid(removed.ktid)
@ -1791,7 +1799,7 @@ class Portal(DBPortal, BasePortal):
# TODO Find when or if there is a listener for this
# TODO Confirm whether this can refer to any user that was kicked, or only to the current user
async def handle_kakaotalk_user_unkick(
self, source: u.User, sender: p.Puppet | None, unkicked: p.Puppet
self, source: u.User, sender: Optional[p.Puppet], unkicked: p.Puppet
) -> None:
assert sender != unkicked, f"Puppet for {unkicked.mxid} tried to unkick itself"
sender_intent = sender.intent_for(self) if sender else self.main_intent
@ -1848,8 +1856,8 @@ class Portal(DBPortal, BasePortal):
async def _backfill(
self,
source: u.User,
limit: int | None,
after_log_id: Long | None,
limit: Optional[int],
after_log_id: Optional[Long],
) -> None:
self.log.debug(f"Backfilling history through {source.mxid}")
self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}")
@ -1898,7 +1906,7 @@ class Portal(DBPortal, BasePortal):
@classmethod
@async_getter_lock
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
async def get_by_mxid(cls, mxid: RoomID) -> Optional[Portal]:
try:
return cls.by_mxid[mxid]
except KeyError:
@ -1919,8 +1927,8 @@ class Portal(DBPortal, BasePortal):
*,
kt_receiver: int = 0,
create: bool = True,
kt_type: ChannelType | None = None,
) -> Portal | None:
kt_type: Optional[ChannelType] = None,
) -> Optional[Portal]:
# TODO Direct chats are shared, so can remove kt_receiver if DM portals should be shared
if kt_type:
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, Dict, Optional, cast
from datetime import datetime, timedelta
import asyncio
@ -46,24 +46,24 @@ class Puppet(DBPuppet, BasePuppet):
hs_domain: str
mxid_template: SimpleTemplate[int]
by_ktid: dict[int, Puppet] = {}
by_custom_mxid: dict[UserID, Puppet] = {}
by_ktid: Dict[int, Puppet] = {}
by_custom_mxid: Dict[UserID, Puppet] = {}
_last_info_sync: datetime | None
_last_info_sync: Optional[datetime]
def __init__(
self,
ktid: Long,
name: str | None = None,
photo_id: str | None = None,
photo_mxc: ContentURI | None = None,
name: Optional[str] = None,
photo_id: Optional[str] = None,
photo_mxc: Optional[ContentURI] = None,
name_set: bool = False,
avatar_set: bool = False,
is_registered: bool = False,
custom_mxid: UserID | None = None,
access_token: str | None = None,
next_batch: SyncToken | None = None,
base_url: URL | None = None,
custom_mxid: Optional[UserID] = None,
access_token: Optional[str] = None,
next_batch: Optional[SyncToken] = None,
base_url: Optional[URL] = None,
) -> None:
super().__init__(
ktid,
@ -249,7 +249,7 @@ class Puppet(DBPuppet, BasePuppet):
@classmethod
@async_getter_lock
async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Puppet | None:
async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Optional[Puppet]:
try:
return cls.by_ktid[int]
except KeyError:
@ -269,7 +269,7 @@ class Puppet(DBPuppet, BasePuppet):
return None
@classmethod
async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional[Puppet]:
ktid = cls.get_id_from_mxid(mxid)
if ktid:
return await cls.get_by_ktid(ktid, create=create)
@ -277,7 +277,7 @@ class Puppet(DBPuppet, BasePuppet):
@classmethod
@async_getter_lock
async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
async def get_by_custom_mxid(cls, mxid: UserID) -> Optional[Puppet]:
try:
return cls.by_custom_mxid[mxid]
except KeyError:
@ -291,7 +291,7 @@ class Puppet(DBPuppet, BasePuppet):
return None
@classmethod
def get_id_from_mxid(cls, mxid: UserID) -> int | None:
def get_id_from_mxid(cls, mxid: UserID) -> Optional[int]:
return cls.mxid_template.parse(mxid)
@classmethod

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any, Callable, Awaitable
from typing import Any, Awaitable, Callable, Dict, List, Optional
import asyncio
import json
@ -26,16 +26,16 @@ from mautrix.types.primitive import JSON
from ..config import Config
from .types import RPCError
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
class CancelableEvent:
_event: asyncio.Event
_task: asyncio.Task | None
_task: Optional[asyncio.Task]
_cancelled: bool
_loop: asyncio.AbstractEventLoop
def __init__(self, loop: asyncio.AbstractEventLoop | None):
def __init__(self, loop: Optional[asyncio.AbstractEventLoop]):
self._event = asyncio.Event()
self._task = None
self._cancelled = False
@ -74,15 +74,15 @@ class RPCClient:
loop: asyncio.AbstractEventLoop
log: logging.Logger = logging.getLogger("mau.rpc")
_reader: asyncio.StreamReader | None
_writer: asyncio.StreamWriter | None
_reader: Optional[asyncio.StreamReader]
_writer: Optional[asyncio.StreamWriter]
_req_id: int
_min_broadcast_id: int
_response_waiters: dict[int, asyncio.Future[JSON]]
_event_handlers: dict[str, list[EventHandler]]
_response_waiters: Dict[int, asyncio.Future[JSON]]
_event_handlers: Dict[str, List[EventHandler]]
_command_queue: asyncio.Queue
_read_task: asyncio.Task | None
_connection_task: asyncio.Task | None
_read_task: Optional[asyncio.Task]
_connection_task: Optional[asyncio.Task]
_is_connected: CancelableEvent
_is_disconnected: CancelableEvent
_connection_lock: asyncio.Lock
@ -203,10 +203,10 @@ class RPCClient:
except ValueError:
pass
def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
def set_event_handlers(self, method: str, handlers: List[EventHandler]) -> None:
self._event_handlers[method] = handlers
async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
async def _run_event_handler(self, req_id: int, command: str, req: Dict[str, Any]) -> None:
if req_id > self._min_broadcast_id:
self.log.debug(f"Ignoring duplicate broadcast {req_id}")
return

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, Dict, List, Optional, cast
import asyncio
import time
@ -92,22 +92,22 @@ class User(DBUser, BaseUser):
shutdown: bool = False
config: Config
by_mxid: dict[UserID, User] = {}
by_ktid: dict[int, User] = {}
by_mxid: Dict[UserID, User] = {}
by_ktid: Dict[int, User] = {}
_client: Client | None
_client: Optional[Client]
_notice_room_lock: asyncio.Lock
_notice_send_lock: asyncio.Lock
is_admin: bool
permission_level: str
_is_logged_in: bool | None
_is_connected: bool | None
_is_logged_in: Optional[bool]
_is_connected: Optional[bool]
_connection_time: float
_db_instance: DBUser | None
_db_instance: Optional[DBUser]
_sync_lock: SimpleLock
_is_rpc_reconnecting: bool
_logged_in_info: SettingsStruct | None
_logged_in_info: Optional[SettingsStruct]
_logged_in_info_time: float
def __init__(
@ -115,11 +115,11 @@ class User(DBUser, BaseUser):
mxid: UserID,
force_login: bool,
was_connected: bool,
ktid: Long | None = None,
uuid: str | None = None,
access_token: str | None = None,
refresh_token: str | None = None,
notice_room: RoomID | None = None,
ktid: Optional[Long] = None,
uuid: Optional[str] = None,
access_token: Optional[str] = None,
refresh_token: Optional[str] = None,
notice_room: Optional[RoomID] = None,
) -> None:
super().__init__(
mxid=mxid,
@ -168,11 +168,11 @@ class User(DBUser, BaseUser):
return self._client
@property
def is_connected(self) -> bool | None:
def is_connected(self) -> Optional[bool]:
return self._is_connected
@is_connected.setter
def is_connected(self, val: bool | None) -> None:
def is_connected(self, val: Optional[bool]) -> None:
if self._is_connected != val:
self._is_connected = val
self._connection_time = time.monotonic()
@ -209,7 +209,7 @@ class User(DBUser, BaseUser):
@classmethod
@async_getter_lock
async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> User | None:
async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> Optional[User]:
if pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid:
return None
try:
@ -233,7 +233,7 @@ class User(DBUser, BaseUser):
@classmethod
@async_getter_lock
async def get_by_ktid(cls, ktid: int) -> User | None:
async def get_by_ktid(cls, ktid: int) -> Optional[User]:
try:
return cls.by_ktid[ktid]
except KeyError:
@ -281,7 +281,7 @@ class User(DBUser, BaseUser):
self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
self.uuid = oauth_credential.deviceUUID
async def get_own_info(self, *, force: bool = False) -> SettingsStruct | None:
async def get_own_info(self, *, force: bool = False) -> Optional[SettingsStruct]:
if force or self._client and (not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic()):
self._logged_in_info = await self._client.get_settings()
self._logged_in_info_time = time.monotonic()
@ -297,7 +297,7 @@ class User(DBUser, BaseUser):
error="logged-out",
)
return False
latest_exc: Exception | None = None
latest_exc: Optional[Exception] = None
password_ok = False
try:
creds = await LoginCredential.get_by_mxid(self.mxid)
@ -357,7 +357,7 @@ class User(DBUser, BaseUser):
asyncio.create_task(self.post_login(is_startup=is_startup, is_token_login=not password_ok))
return True
async def _send_reset_notice(self, e: OAuthException, edit: EventID | None = None) -> None:
async def _send_reset_notice(self, e: OAuthException, edit: Optional[EventID] = None) -> None:
await self.send_bridge_notice(
"Got authentication error from KakaoTalk:\n\n"
f"> {e.message}\n\n"
@ -389,7 +389,7 @@ class User(DBUser, BaseUser):
return self._is_logged_in
async def reload_session(
self, event_id: EventID | None = None, retries: int = 3, is_startup: bool = False
self, event_id: Optional[EventID] = None, retries: int = 3, is_startup: bool = False
) -> None:
try:
if not await self._load_session(is_startup=is_startup) and self.has_state:
@ -487,7 +487,7 @@ class User(DBUser, BaseUser):
self.was_connected = False
await self.save()
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
async def get_direct_chats(self) -> Dict[UserID, List[RoomID]]:
return {
pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
async for portal in po.Portal.get_all_by_receiver(self.ktid)
@ -495,7 +495,7 @@ class User(DBUser, BaseUser):
} if self.ktid else {}
@async_time(METRIC_CONNECT_AND_SYNC)
async def connect_and_sync(self, sync_count: int | None, force_sync: bool) -> bool:
async def connect_and_sync(self, sync_count: Optional[int], force_sync: bool) -> bool:
# TODO Look for a way to sync all channels without (re-)logging in
try:
login_result = await self.client.connect()
@ -517,7 +517,7 @@ class User(DBUser, BaseUser):
await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e))
return False
async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None:
async def _sync_channels(self, login_result: LoginResult, sync_count: Optional[int]) -> None:
if sync_count is None:
sync_count = self.config["bridge.initial_chat_sync"]
if not sync_count:
@ -625,12 +625,12 @@ class User(DBUser, BaseUser):
async def send_bridge_notice(
self,
text: str,
edit: EventID | None = None,
state_event: BridgeStateEvent | None = None,
edit: Optional[EventID] = None,
state_event: Optional[BridgeStateEvent] = None,
important: bool = False,
error_code: str | None = None,
error_message: str | None = None,
) -> EventID | None:
error_code: Optional[str] = None,
error_message: Optional[str] = None,
) -> Optional[EventID]:
if state_event:
await self.push_bridge_state(
state_event,
@ -662,7 +662,7 @@ class User(DBUser, BaseUser):
puppet = await pu.Puppet.get_by_ktid(self.ktid)
state.remote_name = puppet.name
async def get_bridge_states(self) -> list[BridgeState]:
async def get_bridge_states(self) -> List[BridgeState]:
if not self.has_state:
return []
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
@ -672,12 +672,12 @@ class User(DBUser, BaseUser):
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
return [state]
async def get_puppet(self) -> pu.Puppet | None:
async def get_puppet(self) -> Optional[pu.Puppet]:
if not self.ktid:
return None
return await pu.Puppet.get_by_ktid(self.ktid)
async def get_portal_with(self, puppet: pu.Puppet, create: bool = True) -> po.Portal | None:
async def get_portal_with(self, puppet: pu.Puppet, create: bool = True) -> Optional[po.Portal]:
# TODO Make upstream request to return custom failure message
if not self.ktid or not self.is_connected:
return None
@ -732,7 +732,7 @@ class User(DBUser, BaseUser):
await self.save()
return not skip_sync
async def on_disconnect(self, res: KickoutRes | None) -> None:
async def on_disconnect(self, res: Optional[KickoutRes]) -> None:
self.is_connected = False
self._track_metric(METRIC_CONNECTED, False)
logout = not self.config["bridge.remain_logged_in_on_disconnect"]
@ -934,7 +934,7 @@ class User(DBUser, BaseUser):
await self._sync_channel(channel_info)
@async_time(METRIC_CHANNEL_LEFT)
async def on_channel_left(self, channel_id: Long, channel_type: ChannelType | None) -> None:
async def on_channel_left(self, channel_id: Long, channel_type: Optional[ChannelType]) -> None:
assert self.ktid
portal = await po.Portal.get_by_ktid(
channel_id,