From 4158788496be2ffb9e6bcd885f599dce325dbad3 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Fri, 11 Mar 2022 03:43:00 -0500 Subject: [PATCH] Clean up Long Keep it as an int in Python, and do all fancy conversions in Node --- matrix_appservice_kakaotalk/db/message.py | 73 +++++-------------- matrix_appservice_kakaotalk/db/portal.py | 25 +++---- matrix_appservice_kakaotalk/db/puppet.py | 15 ++-- .../db/upgrade/v01_initial_revision.py | 18 ++--- matrix_appservice_kakaotalk/db/user.py | 16 ++-- matrix_appservice_kakaotalk/kt/types/bson.py | 69 ++---------------- matrix_appservice_kakaotalk/portal.py | 26 +++---- matrix_appservice_kakaotalk/puppet.py | 16 ++-- matrix_appservice_kakaotalk/user.py | 4 +- node/src/client.js | 45 +++++------- 10 files changed, 102 insertions(+), 205 deletions(-) diff --git a/matrix_appservice_kakaotalk/db/message.py b/matrix_appservice_kakaotalk/db/message.py index 5c8d878..30ebb0c 100644 --- a/matrix_appservice_kakaotalk/db/message.py +++ b/matrix_appservice_kakaotalk/db/message.py @@ -18,12 +18,12 @@ from __future__ import annotations from typing import TYPE_CHECKING, ClassVar from asyncpg import Record -from attr import dataclass +from attr import dataclass, field from mautrix.types import EventID, RoomID from mautrix.util.async_db import Database -from ..kt.types.bson import Long, StrLong +from ..kt.types.bson import Long fake_db = Database.create("") if TYPE_CHECKING else None @@ -32,27 +32,17 @@ fake_db = Database.create("") if TYPE_CHECKING else None class Message: db: ClassVar[Database] = fake_db - # TODO Store all Long values as the same type mxid: EventID mx_room: RoomID - ktid: Long + ktid: Long = field(converter=Long) index: int - kt_chat: Long - kt_receiver: Long + kt_chat: Long = field(converter=Long) + kt_receiver: Long = field(converter=Long) timestamp: int @classmethod - def _from_row(cls, row: Record) -> Message | None: - data = {**row} - ktid = data.pop("ktid") - kt_chat = data.pop("kt_chat") - kt_receiver = data.pop("kt_receiver") - return cls( - **data, - ktid=StrLong(ktid), - kt_chat=Long.from_bytes(kt_chat), - kt_receiver=Long.from_bytes(kt_receiver) - ) + def _from_row(cls, row: Record) -> Message: + return cls(**row) @classmethod def _from_optional_row(cls, row: Record | None) -> Message | None: @@ -61,15 +51,15 @@ class Message: columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp' @classmethod - async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]: + async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]: q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2" - rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver)) + rows = await cls.db.fetch(q, ktid, kt_receiver) return [cls._from_row(row) for row in rows if row] @classmethod - async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None: + async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None: q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' - row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index) + row = await cls.db.fetchrow(q, ktid, kt_receiver, index) return cls._from_optional_row(row) @classmethod @@ -83,18 +73,18 @@ class Message: return cls._from_optional_row(row) @classmethod - async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None: + async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None: q = ( f"SELECT {cls.columns} " "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL " "ORDER BY timestamp DESC LIMIT 1" ) - row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver)) + row = await cls.db.fetchrow(q, kt_chat, kt_receiver) return cls._from_optional_row(row) @classmethod async def get_closest_before( - cls, kt_chat: Long, kt_receiver: Long, timestamp: int + cls, kt_chat: int, kt_receiver: int, timestamp: int ) -> Message | None: q = ( f"SELECT {cls.columns} " @@ -102,7 +92,7 @@ class Message: " ktid IS NOT NULL " "ORDER BY timestamp DESC LIMIT 1" ) - row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp) + row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp) return cls._from_optional_row(row) _insert_query = ( @@ -111,46 +101,23 @@ class Message: "VALUES ($1, $2, $3, $4, $5, $6, $7)" ) - @classmethod - async def bulk_create( - cls, - ktid: Long, - kt_chat: Long, - kt_receiver: Long, - event_ids: list[EventID], - timestamp: int, - mx_room: RoomID, - ) -> None: - if not event_ids: - return - columns = [col.strip('"') for col in cls.columns.split(", ")] - records = [ - (mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp) - for index, mxid in enumerate(event_ids) - ] - async with cls.db.acquire() as conn, conn.transaction(): - if cls.db.scheme == "postgres": - await conn.copy_records_to_table("message", records=records, columns=columns) - else: - await conn.executemany(cls._insert_query, records) - async def insert(self) -> None: q = self._insert_query await self.db.execute( q, self.mxid, self.mx_room, - str(self.ktid), + self.ktid, self.index, - bytes(self.kt_chat), - bytes(self.kt_receiver), + self.kt_chat, + self.kt_receiver, self.timestamp, ) async def delete(self) -> None: q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' - await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index) + await self.db.execute(q, self.ktid, self.kt_receiver, self.index) async def update(self) -> None: q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4" - await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room) + await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room) diff --git a/matrix_appservice_kakaotalk/db/portal.py b/matrix_appservice_kakaotalk/db/portal.py index 6a97212..3b5e756 100644 --- a/matrix_appservice_kakaotalk/db/portal.py +++ b/matrix_appservice_kakaotalk/db/portal.py @@ -18,7 +18,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, ClassVar from asyncpg import Record -from attr import dataclass +from attr import dataclass, field from mautrix.types import ContentURI, RoomID, UserID from mautrix.util.async_db import Database @@ -33,8 +33,8 @@ fake_db = Database.create("") if TYPE_CHECKING else None class Portal: db: ClassVar[Database] = fake_db - ktid: Long - kt_receiver: Long + ktid: Long = field(converter=Long) + kt_receiver: Long = field(converter=Long) kt_type: ChannelType mxid: RoomID | None name: str | None @@ -47,23 +47,20 @@ class Portal: @classmethod def _from_row(cls, row: Record) -> Portal: - data = {**row} - ktid = data.pop("ktid") - kt_receiver = data.pop("kt_receiver") - return cls(**data, ktid=Long.from_bytes(ktid), kt_receiver=Long.from_bytes(kt_receiver)) + return cls(**row) @classmethod def _from_optional_row(cls, row: Record | None) -> Portal | None: return cls._from_row(row) if row is not None else None @classmethod - async def get_by_ktid(cls, ktid: Long, kt_receiver: Long) -> Portal | None: + async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None: q = """ SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted, name_set, avatar_set, relay_user_id FROM portal WHERE ktid=$1 AND kt_receiver=$2 """ - row = await cls.db.fetchrow(q, bytes(ktid), bytes(kt_receiver)) + row = await cls.db.fetchrow(q, ktid, kt_receiver) return cls._from_optional_row(row) @classmethod @@ -77,13 +74,13 @@ class Portal: return cls._from_optional_row(row) @classmethod - async def get_all_by_receiver(cls, kt_receiver: Long) -> list[Portal]: + async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]: q = """ SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted, name_set, avatar_set, relay_user_id FROM portal WHERE kt_receiver=$1 """ - rows = await cls.db.fetch(q, bytes(kt_receiver)) + rows = await cls.db.fetch(q, kt_receiver) return [cls._from_row(row) for row in rows if row] @classmethod @@ -99,8 +96,8 @@ class Portal: @property def _values(self): return ( - Long.to_optional_bytes(self.ktid), - Long.to_optional_bytes(self.kt_receiver), + self.ktid, + self.kt_receiver, self.kt_type, self.mxid, self.name, @@ -122,7 +119,7 @@ class Portal: async def delete(self) -> None: q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2" - await self.db.execute(q, Long.to_optional_bytes(self.ktid), Long.to_optional_bytes(self.kt_receiver)) + await self.db.execute(q, self.ktid, self.kt_receiver) async def save(self) -> None: q = """ diff --git a/matrix_appservice_kakaotalk/db/puppet.py b/matrix_appservice_kakaotalk/db/puppet.py index 937949c..3f9b1fe 100644 --- a/matrix_appservice_kakaotalk/db/puppet.py +++ b/matrix_appservice_kakaotalk/db/puppet.py @@ -18,7 +18,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, ClassVar from asyncpg import Record -from attr import dataclass +from attr import dataclass, field from yarl import URL from mautrix.types import ContentURI, SyncToken, UserID @@ -33,7 +33,7 @@ fake_db = Database.create("") if TYPE_CHECKING else None class Puppet: db: ClassVar[Database] = fake_db - ktid: Long + ktid: Long = field(converter=Long) name: str | None photo_id: str | None photo_mxc: ContentURI | None @@ -49,22 +49,21 @@ class Puppet: @classmethod def _from_row(cls, row: Record) -> Puppet: data = {**row} - ktid = data.pop("ktid") base_url = data.pop("base_url", None) - return cls(**data, ktid=Long.from_optional_bytes(ktid), base_url=URL(base_url) if base_url else None) + return cls(**data, base_url=URL(base_url) if base_url else None) @classmethod def _from_optional_row(cls, row: Record | None) -> Puppet | None: return cls._from_row(row) if row is not None else None @classmethod - async def get_by_ktid(cls, ktid: Long) -> Puppet | None: + async def get_by_ktid(cls, ktid: int) -> Puppet | None: q = ( "SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, " " custom_mxid, access_token, next_batch, base_url " "FROM puppet WHERE ktid=$1" ) - row = await cls.db.fetchrow(q, bytes(ktid)) + row = await cls.db.fetchrow(q, ktid) return cls._from_optional_row(row) @classmethod @@ -100,7 +99,7 @@ class Puppet: @property def _values(self): return ( - bytes(self.ktid), + self.ktid, self.name, self.photo_id, self.photo_mxc, @@ -123,7 +122,7 @@ class Puppet: async def delete(self) -> None: q = "DELETE FROM puppet WHERE ktid=$1" - await self.db.execute(q, bytes(self.ktid)) + await self.db.execute(q, self.ktid) async def save(self) -> None: q = """ diff --git a/matrix_appservice_kakaotalk/db/upgrade/v01_initial_revision.py b/matrix_appservice_kakaotalk/db/upgrade/v01_initial_revision.py index e37dbf4..d24afae 100644 --- a/matrix_appservice_kakaotalk/db/upgrade/v01_initial_revision.py +++ b/matrix_appservice_kakaotalk/db/upgrade/v01_initial_revision.py @@ -30,7 +30,7 @@ async def create_v1_tables(conn: Connection) -> None: await conn.execute( """CREATE TABLE "user" ( mxid TEXT PRIMARY KEY, - ktid BYTES UNIQUE, + ktid BIGINT UNIQUE, uuid TEXT, access_token TEXT, refresh_token TEXT, @@ -39,8 +39,8 @@ async def create_v1_tables(conn: Connection) -> None: ) await conn.execute( """CREATE TABLE portal ( - ktid BYTES, - kt_receiver BYTES, + ktid BIGINT, + kt_receiver BIGINT, kt_type TEXT, mxid TEXT UNIQUE, name TEXT, @@ -55,7 +55,7 @@ async def create_v1_tables(conn: Connection) -> None: ) await conn.execute( """CREATE TABLE puppet ( - ktid BYTES PRIMARY KEY, + ktid BIGINT PRIMARY KEY, name TEXT, photo_id TEXT, photo_mxc TEXT, @@ -74,10 +74,10 @@ async def create_v1_tables(conn: Connection) -> None: """CREATE TABLE message ( mxid TEXT, mx_room TEXT, - ktid TEXT, - kt_receiver BYTES, + ktid BIGINT, + kt_receiver BIGINT, "index" SMALLINT, - kt_chat BYTES, + kt_chat BIGINT, timestamp BIGINT, PRIMARY KEY (ktid, kt_receiver, "index"), FOREIGN KEY (kt_chat, kt_receiver) REFERENCES portal(ktid, kt_receiver) @@ -90,8 +90,8 @@ async def create_v1_tables(conn: Connection) -> None: mxid TEXT, mx_room TEXT, kt_msgid TEXT, - kt_receiver BYTES, - kt_sender BYTES, + kt_receiver BIGINT, + kt_sender BIGINT, reaction TEXT, PRIMARY KEY (kt_msgid, kt_receiver, kt_sender), UNIQUE (mxid, mx_room) diff --git a/matrix_appservice_kakaotalk/db/user.py b/matrix_appservice_kakaotalk/db/user.py index 005621f..22f8dbe 100644 --- a/matrix_appservice_kakaotalk/db/user.py +++ b/matrix_appservice_kakaotalk/db/user.py @@ -18,7 +18,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, ClassVar, List, Set from asyncpg import Record -from attr import dataclass +from attr import dataclass, field from mautrix.types import RoomID, UserID from mautrix.util.async_db import Database @@ -33,7 +33,7 @@ class User: db: ClassVar[Database] = fake_db mxid: UserID - ktid: Long | None + ktid: Long | None = field(converter=lambda x: Long(x) if x is not None else None) uuid: str | None access_token: str | None refresh_token: str | None @@ -41,9 +41,7 @@ class User: @classmethod def _from_row(cls, row: Record) -> User: - data = {**row} - ktid = data.pop("ktid", None) - return cls(**data, ktid=Long.from_optional_bytes(ktid)) + return cls(**row) @classmethod def _from_optional_row(cls, row: Record | None) -> User | None: @@ -59,9 +57,9 @@ class User: return [cls._from_row(row) for row in rows if row] @classmethod - async def get_by_ktid(cls, ktid: Long) -> User | None: + async def get_by_ktid(cls, ktid: int) -> User | None: q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1' - row = await cls.db.fetchrow(q, bytes(ktid)) + row = await cls.db.fetchrow(q, ktid) return cls._from_optional_row(row) @classmethod @@ -81,7 +79,7 @@ class User: VALUES ($1, $2, $3, $4, $5, $6) """ await self.db.execute( - q, self.mxid, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room + q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room ) async def delete(self) -> None: @@ -93,5 +91,5 @@ class User: WHERE mxid=$6 """ await self.db.execute( - q, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid + q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid ) diff --git a/matrix_appservice_kakaotalk/kt/types/bson.py b/matrix_appservice_kakaotalk/kt/types/bson.py index 296e3e5..ee71916 100644 --- a/matrix_appservice_kakaotalk/kt/types/bson.py +++ b/matrix_appservice_kakaotalk/kt/types/bson.py @@ -13,69 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import ClassVar, Optional - -from attr import dataclass, asdict -import bson - -from mautrix.types import SerializableAttrs, JSON +from mautrix.types import Serializable, JSON -@dataclass(frozen=True) -class Long(SerializableAttrs): - high: int - low: int - unsigned: bool - - @classmethod - def from_bytes(cls, raw: bytes) -> "Long": - return cls(**bson.loads(raw)) - - @classmethod - def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]: - return cls(**bson.loads(raw)) if raw is not None else None - - @classmethod - def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]: - return bytes(value) if value is not None else None - +class Long(int, Serializable): def serialize(self) -> JSON: - data = super().serialize() - data["__type__"] = "Long" - return data + return {"__type__": "Long", "str": str(self)} - def __bytes__(self) -> bytes: - return bson.dumps(asdict(self)) - - def __int__(self) -> int: - if self.unsigned: - pass - result = \ - ((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \ - ( self.low + (1 << 32 if self.low < 0 else 0)) - return result + (1 << 32 if self.unsigned and result < 0 else 0) - - def __str__(self) -> str: - return str(int(self)) - - ZERO: ClassVar["Long"] - -Long.ZERO = Long(0, 0, False) - - -class IntLong(Long): - """Helper class for constructing a Long from an int.""" - def __init__(self, val: int): - if val < 0: - pass - super().__init__( - high=(val & 0xffffffff00000000) >> 32, - low = val & 0x00000000ffffffff, - unsigned=val < 0, - ) - - -class StrLong(IntLong): - """Helper class for constructing a Long from the string representation of an int.""" - def __init__(self, val: str): - super().__init__(int(val)) + @classmethod + def deserialize(cls, raw: JSON) -> "Long": + assert isinstance(raw, str), f"Long deserialization expected a string, but got non-string value {raw}" + return cls(raw) diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py index 1187461..8ed18e1 100644 --- a/matrix_appservice_kakaotalk/portal.py +++ b/matrix_appservice_kakaotalk/portal.py @@ -48,7 +48,7 @@ from .db import ( ) from .formatter import kakaotalk_to_matrix, matrix_to_kakaotalk -from .kt.types.bson import Long, IntLong +from .kt.types.bson import Long from .kt.types.channel.channel_info import ChannelInfo from .kt.types.channel.channel_type import KnownChannelType, ChannelType from .kt.types.chat.chat import Chatlog @@ -87,7 +87,7 @@ StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STAT class Portal(DBPortal, BasePortal): invite_own_puppet_to_pm: bool = False by_mxid: dict[RoomID, Portal] = {} - by_ktid: dict[tuple[Long, Long], Portal] = {} + by_ktid: dict[tuple[int, int], Portal] = {} matrix: m.MatrixHandler config: Config @@ -162,7 +162,7 @@ class Portal(DBPortal, BasePortal): async def delete(self) -> None: if self.mxid: await DBMessage.delete_all_by_room(self.mxid) - self.by_ktid.pop(self.ktid_full, None) + self.by_ktid.pop(self._ktid_full, None) self.by_mxid.pop(self.mxid, None) await super().delete() @@ -170,7 +170,7 @@ class Portal(DBPortal, BasePortal): # region Properties @property - def ktid_full(self) -> tuple[Long, Long]: + def _ktid_full(self) -> tuple[int, int]: return self.ktid, self.kt_receiver @property @@ -617,7 +617,7 @@ class Portal(DBPortal, BasePortal): # endregion # region Matrix event handling - def require_send_lock(self, user_id: Long) -> asyncio.Lock: + def require_send_lock(self, user_id: int) -> asyncio.Lock: try: lock = self._send_locks[user_id] except KeyError: @@ -625,7 +625,7 @@ class Portal(DBPortal, BasePortal): self._send_locks[user_id] = lock return lock - def optional_send_lock(self, user_id: Long) -> asyncio.Lock | FakeLock: + def optional_send_lock(self, user_id: int) -> asyncio.Lock | FakeLock: try: return self._send_locks[user_id] except KeyError: @@ -937,8 +937,8 @@ class Portal(DBPortal, BasePortal): after_log_id: Long | None, channel_info: ChannelInfo, ) -> None: - self.log.debug("Backfilling history through %s", source.mxid) - self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid)) + self.log.debug(f"Backfilling history through {source.mxid}") + self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}") messages = await source.client.get_chats( channel_info.channelId, limit, @@ -961,7 +961,7 @@ class Portal(DBPortal, BasePortal): # region Database getters async def postinit(self) -> None: - self.by_ktid[self.ktid_full] = self + self.by_ktid[self._ktid_full] = self if self.mxid: self.by_mxid[self.mxid] = self self._main_intent = ( @@ -989,14 +989,14 @@ class Portal(DBPortal, BasePortal): @async_getter_lock async def get_by_ktid( cls, - ktid: Long, + ktid: int, *, - kt_receiver: Long = Long.ZERO, + kt_receiver: int = 0, create: bool = True, kt_type: ChannelType | None = None, ) -> Portal | None: if kt_type: - kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else Long.ZERO + kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0 ktid_full = (ktid, kt_receiver) try: return cls.by_ktid[ktid_full] @@ -1017,7 +1017,7 @@ class Portal(DBPortal, BasePortal): return None @classmethod - async def get_all_by_receiver(cls, kt_receiver: Long) -> AsyncGenerator[Portal, None]: + async def get_all_by_receiver(cls, kt_receiver: int) -> AsyncGenerator[Portal, None]: portals = await super().get_all_by_receiver(kt_receiver) portal: Portal for portal in portals: diff --git a/matrix_appservice_kakaotalk/puppet.py b/matrix_appservice_kakaotalk/puppet.py index 180a1ac..644bfba 100644 --- a/matrix_appservice_kakaotalk/puppet.py +++ b/matrix_appservice_kakaotalk/puppet.py @@ -31,7 +31,7 @@ from . import matrix as m, portal as p, user as u from .config import Config from .db import Puppet as DBPuppet -from .kt.types.bson import Long, StrLong +from .kt.types.bson import Long from .kt.client.types import UserInfoUnion @@ -43,9 +43,9 @@ class Puppet(DBPuppet, BasePuppet): mx: m.MatrixHandler config: Config hs_domain: str - mxid_template: SimpleTemplate[StrLong] + mxid_template: SimpleTemplate[int] - by_ktid: dict[Long, Puppet] = {} + by_ktid: dict[int, Puppet] = {} by_custom_mxid: dict[UserID, Puppet] = {} _last_info_sync: datetime | None @@ -127,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet): keyword="userid", prefix="@", suffix=f":{Puppet.hs_domain}", - type=StrLong, + type=int, ) cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"] cls.homeserver_url_map = { @@ -218,9 +218,9 @@ class Puppet(DBPuppet, BasePuppet): @classmethod @async_getter_lock - async def get_by_ktid(cls, ktid: Long, *, create: bool = True) -> Puppet | None: + async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Puppet | None: try: - return cls.by_ktid[ktid] + return cls.by_ktid[int] except KeyError: pass @@ -260,11 +260,11 @@ class Puppet(DBPuppet, BasePuppet): return None @classmethod - def get_id_from_mxid(cls, mxid: UserID) -> Long | None: + def get_id_from_mxid(cls, mxid: UserID) -> int | None: return cls.mxid_template.parse(mxid) @classmethod - def get_mxid_from_id(cls, ktid: Long) -> UserID: + def get_mxid_from_id(cls, ktid: int) -> UserID: return UserID(cls.mxid_template.format_full(ktid)) @classmethod diff --git a/matrix_appservice_kakaotalk/user.py b/matrix_appservice_kakaotalk/user.py index 2e8fce4..254c6aa 100644 --- a/matrix_appservice_kakaotalk/user.py +++ b/matrix_appservice_kakaotalk/user.py @@ -84,7 +84,7 @@ class User(DBUser, BaseUser): config: Config by_mxid: dict[UserID, User] = {} - by_ktid: dict[Long, User] = {} + by_ktid: dict[int, User] = {} client: Client | None @@ -217,7 +217,7 @@ class User(DBUser, BaseUser): @classmethod @async_getter_lock - async def get_by_ktid(cls, ktid: Long) -> User | None: + async def get_by_ktid(cls, ktid: int) -> User | None: try: return cls.by_ktid[ktid] except KeyError: diff --git a/node/src/client.js b/node/src/client.js index 457cf2b..6320abd 100644 --- a/node/src/client.js +++ b/node/src/client.js @@ -168,10 +168,9 @@ export default class PeerClient { * @returns {Promise} */ #write(data) { - return promisify(cb => this.socket.write(JSON.stringify(data) + "\n", cb)) + return promisify(cb => this.socket.write(JSON.stringify(data, this.#writeReplacer) + "\n", cb)) } - /** * @param {Object} req * @param {string} req.passcode @@ -414,7 +413,7 @@ export default class PeerClient { } let req try { - req = JSON.parse(line) + req = JSON.parse(line, this.#readReviver) } catch (err) { this.log("Non-JSON request:", line) return @@ -465,7 +464,6 @@ export default class PeerClient { const resp = { id: req.id } delete req.id delete req.command - req = typeify(req) resp.command = "response" try { resp.response = await handler(req) @@ -483,29 +481,22 @@ export default class PeerClient { } await this.#write(resp) } -} -/** - * Recursively scan an object to check if any of its sub-objects - * should be converted into instances of a specified class. - * @param obj The object to be scanned & updated. - * @returns The converted object. - */ -function typeify(obj) { - if (!(obj instanceof Object)) { - return obj + #writeReplacer = function(key, value) { + if (value instanceof Long) { + return value.toString() + } else { + return value + } } - const converterFunc = TYPE_MAP.get(obj.__type__) - if (converterFunc !== undefined) { - return converterFunc(obj) - } - for (const key in obj) { - obj[key] = typeify(obj[key]) - } - return obj -} -// TODO Add more if needed -const TYPE_MAP = new Map([ - ["Long", (obj) => new Long(obj.low, obj.high, obj.unsigned)], -]) + #readReviver = function(key, value) { + if (value instanceof Object) { + // TODO Use a type map if there will be many possible types + if (value.__type__ == "Long") { + return Long.fromString(value.str) + } + } + return value + } +}