diff --git a/matrix_appservice_kakaotalk/db/message.py b/matrix_appservice_kakaotalk/db/message.py
index 5c8d878..30ebb0c 100644
--- a/matrix_appservice_kakaotalk/db/message.py
+++ b/matrix_appservice_kakaotalk/db/message.py
@@ -18,12 +18,12 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
-from attr import dataclass
+from attr import dataclass, field
from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database
-from ..kt.types.bson import Long, StrLong
+from ..kt.types.bson import Long
fake_db = Database.create("") if TYPE_CHECKING else None
@@ -32,27 +32,17 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Message:
db: ClassVar[Database] = fake_db
- # TODO Store all Long values as the same type
mxid: EventID
mx_room: RoomID
- ktid: Long
+ ktid: Long = field(converter=Long)
index: int
- kt_chat: Long
- kt_receiver: Long
+ kt_chat: Long = field(converter=Long)
+ kt_receiver: Long = field(converter=Long)
timestamp: int
@classmethod
- def _from_row(cls, row: Record) -> Message | None:
- data = {**row}
- ktid = data.pop("ktid")
- kt_chat = data.pop("kt_chat")
- kt_receiver = data.pop("kt_receiver")
- return cls(
- **data,
- ktid=StrLong(ktid),
- kt_chat=Long.from_bytes(kt_chat),
- kt_receiver=Long.from_bytes(kt_receiver)
- )
+ def _from_row(cls, row: Record) -> Message:
+ return cls(**row)
@classmethod
def _from_optional_row(cls, row: Record | None) -> Message | None:
@@ -61,15 +51,15 @@ class Message:
columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
@classmethod
- async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
+ async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
- rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
+ rows = await cls.db.fetch(q, ktid, kt_receiver)
return [cls._from_row(row) for row in rows if row]
@classmethod
- async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
+ async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
- row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
+ row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
return cls._from_optional_row(row)
@classmethod
@@ -83,18 +73,18 @@ class Message:
return cls._from_optional_row(row)
@classmethod
- async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
+ async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1"
)
- row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
+ row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
return cls._from_optional_row(row)
@classmethod
async def get_closest_before(
- cls, kt_chat: Long, kt_receiver: Long, timestamp: int
+ cls, kt_chat: int, kt_receiver: int, timestamp: int
) -> Message | None:
q = (
f"SELECT {cls.columns} "
@@ -102,7 +92,7 @@ class Message:
" ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1"
)
- row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
+ row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp)
return cls._from_optional_row(row)
_insert_query = (
@@ -111,46 +101,23 @@ class Message:
"VALUES ($1, $2, $3, $4, $5, $6, $7)"
)
- @classmethod
- async def bulk_create(
- cls,
- ktid: Long,
- kt_chat: Long,
- kt_receiver: Long,
- event_ids: list[EventID],
- timestamp: int,
- mx_room: RoomID,
- ) -> None:
- if not event_ids:
- return
- columns = [col.strip('"') for col in cls.columns.split(", ")]
- records = [
- (mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp)
- for index, mxid in enumerate(event_ids)
- ]
- async with cls.db.acquire() as conn, conn.transaction():
- if cls.db.scheme == "postgres":
- await conn.copy_records_to_table("message", records=records, columns=columns)
- else:
- await conn.executemany(cls._insert_query, records)
-
async def insert(self) -> None:
q = self._insert_query
await self.db.execute(
q,
self.mxid,
self.mx_room,
- str(self.ktid),
+ self.ktid,
self.index,
- bytes(self.kt_chat),
- bytes(self.kt_receiver),
+ self.kt_chat,
+ self.kt_receiver,
self.timestamp,
)
async def delete(self) -> None:
q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
- await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index)
+ await self.db.execute(q, self.ktid, self.kt_receiver, self.index)
async def update(self) -> None:
q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
- await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room)
+ await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room)
diff --git a/matrix_appservice_kakaotalk/db/portal.py b/matrix_appservice_kakaotalk/db/portal.py
index 6a97212..3b5e756 100644
--- a/matrix_appservice_kakaotalk/db/portal.py
+++ b/matrix_appservice_kakaotalk/db/portal.py
@@ -18,7 +18,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
-from attr import dataclass
+from attr import dataclass, field
from mautrix.types import ContentURI, RoomID, UserID
from mautrix.util.async_db import Database
@@ -33,8 +33,8 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Portal:
db: ClassVar[Database] = fake_db
- ktid: Long
- kt_receiver: Long
+ ktid: Long = field(converter=Long)
+ kt_receiver: Long = field(converter=Long)
kt_type: ChannelType
mxid: RoomID | None
name: str | None
@@ -47,23 +47,20 @@ class Portal:
@classmethod
def _from_row(cls, row: Record) -> Portal:
- data = {**row}
- ktid = data.pop("ktid")
- kt_receiver = data.pop("kt_receiver")
- return cls(**data, ktid=Long.from_bytes(ktid), kt_receiver=Long.from_bytes(kt_receiver))
+ return cls(**row)
@classmethod
def _from_optional_row(cls, row: Record | None) -> Portal | None:
return cls._from_row(row) if row is not None else None
@classmethod
- async def get_by_ktid(cls, ktid: Long, kt_receiver: Long) -> Portal | None:
+ async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None:
q = """
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
name_set, avatar_set, relay_user_id
FROM portal WHERE ktid=$1 AND kt_receiver=$2
"""
- row = await cls.db.fetchrow(q, bytes(ktid), bytes(kt_receiver))
+ row = await cls.db.fetchrow(q, ktid, kt_receiver)
return cls._from_optional_row(row)
@classmethod
@@ -77,13 +74,13 @@ class Portal:
return cls._from_optional_row(row)
@classmethod
- async def get_all_by_receiver(cls, kt_receiver: Long) -> list[Portal]:
+ async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]:
q = """
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
name_set, avatar_set, relay_user_id
FROM portal WHERE kt_receiver=$1
"""
- rows = await cls.db.fetch(q, bytes(kt_receiver))
+ rows = await cls.db.fetch(q, kt_receiver)
return [cls._from_row(row) for row in rows if row]
@classmethod
@@ -99,8 +96,8 @@ class Portal:
@property
def _values(self):
return (
- Long.to_optional_bytes(self.ktid),
- Long.to_optional_bytes(self.kt_receiver),
+ self.ktid,
+ self.kt_receiver,
self.kt_type,
self.mxid,
self.name,
@@ -122,7 +119,7 @@ class Portal:
async def delete(self) -> None:
q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2"
- await self.db.execute(q, Long.to_optional_bytes(self.ktid), Long.to_optional_bytes(self.kt_receiver))
+ await self.db.execute(q, self.ktid, self.kt_receiver)
async def save(self) -> None:
q = """
diff --git a/matrix_appservice_kakaotalk/db/puppet.py b/matrix_appservice_kakaotalk/db/puppet.py
index 937949c..3f9b1fe 100644
--- a/matrix_appservice_kakaotalk/db/puppet.py
+++ b/matrix_appservice_kakaotalk/db/puppet.py
@@ -18,7 +18,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
-from attr import dataclass
+from attr import dataclass, field
from yarl import URL
from mautrix.types import ContentURI, SyncToken, UserID
@@ -33,7 +33,7 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Puppet:
db: ClassVar[Database] = fake_db
- ktid: Long
+ ktid: Long = field(converter=Long)
name: str | None
photo_id: str | None
photo_mxc: ContentURI | None
@@ -49,22 +49,21 @@ class Puppet:
@classmethod
def _from_row(cls, row: Record) -> Puppet:
data = {**row}
- ktid = data.pop("ktid")
base_url = data.pop("base_url", None)
- return cls(**data, ktid=Long.from_optional_bytes(ktid), base_url=URL(base_url) if base_url else None)
+ return cls(**data, base_url=URL(base_url) if base_url else None)
@classmethod
def _from_optional_row(cls, row: Record | None) -> Puppet | None:
return cls._from_row(row) if row is not None else None
@classmethod
- async def get_by_ktid(cls, ktid: Long) -> Puppet | None:
+ async def get_by_ktid(cls, ktid: int) -> Puppet | None:
q = (
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
" custom_mxid, access_token, next_batch, base_url "
"FROM puppet WHERE ktid=$1"
)
- row = await cls.db.fetchrow(q, bytes(ktid))
+ row = await cls.db.fetchrow(q, ktid)
return cls._from_optional_row(row)
@classmethod
@@ -100,7 +99,7 @@ class Puppet:
@property
def _values(self):
return (
- bytes(self.ktid),
+ self.ktid,
self.name,
self.photo_id,
self.photo_mxc,
@@ -123,7 +122,7 @@ class Puppet:
async def delete(self) -> None:
q = "DELETE FROM puppet WHERE ktid=$1"
- await self.db.execute(q, bytes(self.ktid))
+ await self.db.execute(q, self.ktid)
async def save(self) -> None:
q = """
diff --git a/matrix_appservice_kakaotalk/db/upgrade/v01_initial_revision.py b/matrix_appservice_kakaotalk/db/upgrade/v01_initial_revision.py
index e37dbf4..d24afae 100644
--- a/matrix_appservice_kakaotalk/db/upgrade/v01_initial_revision.py
+++ b/matrix_appservice_kakaotalk/db/upgrade/v01_initial_revision.py
@@ -30,7 +30,7 @@ async def create_v1_tables(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
- ktid BYTES UNIQUE,
+ ktid BIGINT UNIQUE,
uuid TEXT,
access_token TEXT,
refresh_token TEXT,
@@ -39,8 +39,8 @@ async def create_v1_tables(conn: Connection) -> None:
)
await conn.execute(
"""CREATE TABLE portal (
- ktid BYTES,
- kt_receiver BYTES,
+ ktid BIGINT,
+ kt_receiver BIGINT,
kt_type TEXT,
mxid TEXT UNIQUE,
name TEXT,
@@ -55,7 +55,7 @@ async def create_v1_tables(conn: Connection) -> None:
)
await conn.execute(
"""CREATE TABLE puppet (
- ktid BYTES PRIMARY KEY,
+ ktid BIGINT PRIMARY KEY,
name TEXT,
photo_id TEXT,
photo_mxc TEXT,
@@ -74,10 +74,10 @@ async def create_v1_tables(conn: Connection) -> None:
"""CREATE TABLE message (
mxid TEXT,
mx_room TEXT,
- ktid TEXT,
- kt_receiver BYTES,
+ ktid BIGINT,
+ kt_receiver BIGINT,
"index" SMALLINT,
- kt_chat BYTES,
+ kt_chat BIGINT,
timestamp BIGINT,
PRIMARY KEY (ktid, kt_receiver, "index"),
FOREIGN KEY (kt_chat, kt_receiver) REFERENCES portal(ktid, kt_receiver)
@@ -90,8 +90,8 @@ async def create_v1_tables(conn: Connection) -> None:
mxid TEXT,
mx_room TEXT,
kt_msgid TEXT,
- kt_receiver BYTES,
- kt_sender BYTES,
+ kt_receiver BIGINT,
+ kt_sender BIGINT,
reaction TEXT,
PRIMARY KEY (kt_msgid, kt_receiver, kt_sender),
UNIQUE (mxid, mx_room)
diff --git a/matrix_appservice_kakaotalk/db/user.py b/matrix_appservice_kakaotalk/db/user.py
index 005621f..22f8dbe 100644
--- a/matrix_appservice_kakaotalk/db/user.py
+++ b/matrix_appservice_kakaotalk/db/user.py
@@ -18,7 +18,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar, List, Set
from asyncpg import Record
-from attr import dataclass
+from attr import dataclass, field
from mautrix.types import RoomID, UserID
from mautrix.util.async_db import Database
@@ -33,7 +33,7 @@ class User:
db: ClassVar[Database] = fake_db
mxid: UserID
- ktid: Long | None
+ ktid: Long | None = field(converter=lambda x: Long(x) if x is not None else None)
uuid: str | None
access_token: str | None
refresh_token: str | None
@@ -41,9 +41,7 @@ class User:
@classmethod
def _from_row(cls, row: Record) -> User:
- data = {**row}
- ktid = data.pop("ktid", None)
- return cls(**data, ktid=Long.from_optional_bytes(ktid))
+ return cls(**row)
@classmethod
def _from_optional_row(cls, row: Record | None) -> User | None:
@@ -59,9 +57,9 @@ class User:
return [cls._from_row(row) for row in rows if row]
@classmethod
- async def get_by_ktid(cls, ktid: Long) -> User | None:
+ async def get_by_ktid(cls, ktid: int) -> User | None:
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1'
- row = await cls.db.fetchrow(q, bytes(ktid))
+ row = await cls.db.fetchrow(q, ktid)
return cls._from_optional_row(row)
@classmethod
@@ -81,7 +79,7 @@ class User:
VALUES ($1, $2, $3, $4, $5, $6)
"""
await self.db.execute(
- q, self.mxid, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room
+ q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room
)
async def delete(self) -> None:
@@ -93,5 +91,5 @@ class User:
WHERE mxid=$6
"""
await self.db.execute(
- q, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
+ q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
)
diff --git a/matrix_appservice_kakaotalk/kt/types/bson.py b/matrix_appservice_kakaotalk/kt/types/bson.py
index 296e3e5..ee71916 100644
--- a/matrix_appservice_kakaotalk/kt/types/bson.py
+++ b/matrix_appservice_kakaotalk/kt/types/bson.py
@@ -13,69 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import ClassVar, Optional
-
-from attr import dataclass, asdict
-import bson
-
-from mautrix.types import SerializableAttrs, JSON
+from mautrix.types import Serializable, JSON
-@dataclass(frozen=True)
-class Long(SerializableAttrs):
- high: int
- low: int
- unsigned: bool
-
- @classmethod
- def from_bytes(cls, raw: bytes) -> "Long":
- return cls(**bson.loads(raw))
-
- @classmethod
- def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]:
- return cls(**bson.loads(raw)) if raw is not None else None
-
- @classmethod
- def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]:
- return bytes(value) if value is not None else None
-
+class Long(int, Serializable):
def serialize(self) -> JSON:
- data = super().serialize()
- data["__type__"] = "Long"
- return data
+ return {"__type__": "Long", "str": str(self)}
- def __bytes__(self) -> bytes:
- return bson.dumps(asdict(self))
-
- def __int__(self) -> int:
- if self.unsigned:
- pass
- result = \
- ((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \
- ( self.low + (1 << 32 if self.low < 0 else 0))
- return result + (1 << 32 if self.unsigned and result < 0 else 0)
-
- def __str__(self) -> str:
- return str(int(self))
-
- ZERO: ClassVar["Long"]
-
-Long.ZERO = Long(0, 0, False)
-
-
-class IntLong(Long):
- """Helper class for constructing a Long from an int."""
- def __init__(self, val: int):
- if val < 0:
- pass
- super().__init__(
- high=(val & 0xffffffff00000000) >> 32,
- low = val & 0x00000000ffffffff,
- unsigned=val < 0,
- )
-
-
-class StrLong(IntLong):
- """Helper class for constructing a Long from the string representation of an int."""
- def __init__(self, val: str):
- super().__init__(int(val))
+ @classmethod
+ def deserialize(cls, raw: JSON) -> "Long":
+ assert isinstance(raw, str), f"Long deserialization expected a string, but got non-string value {raw}"
+ return cls(raw)
diff --git a/matrix_appservice_kakaotalk/portal.py b/matrix_appservice_kakaotalk/portal.py
index 1187461..8ed18e1 100644
--- a/matrix_appservice_kakaotalk/portal.py
+++ b/matrix_appservice_kakaotalk/portal.py
@@ -48,7 +48,7 @@ from .db import (
)
from .formatter import kakaotalk_to_matrix, matrix_to_kakaotalk
-from .kt.types.bson import Long, IntLong
+from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo
from .kt.types.channel.channel_type import KnownChannelType, ChannelType
from .kt.types.chat.chat import Chatlog
@@ -87,7 +87,7 @@ StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STAT
class Portal(DBPortal, BasePortal):
invite_own_puppet_to_pm: bool = False
by_mxid: dict[RoomID, Portal] = {}
- by_ktid: dict[tuple[Long, Long], Portal] = {}
+ by_ktid: dict[tuple[int, int], Portal] = {}
matrix: m.MatrixHandler
config: Config
@@ -162,7 +162,7 @@ class Portal(DBPortal, BasePortal):
async def delete(self) -> None:
if self.mxid:
await DBMessage.delete_all_by_room(self.mxid)
- self.by_ktid.pop(self.ktid_full, None)
+ self.by_ktid.pop(self._ktid_full, None)
self.by_mxid.pop(self.mxid, None)
await super().delete()
@@ -170,7 +170,7 @@ class Portal(DBPortal, BasePortal):
# region Properties
@property
- def ktid_full(self) -> tuple[Long, Long]:
+ def _ktid_full(self) -> tuple[int, int]:
return self.ktid, self.kt_receiver
@property
@@ -617,7 +617,7 @@ class Portal(DBPortal, BasePortal):
# endregion
# region Matrix event handling
- def require_send_lock(self, user_id: Long) -> asyncio.Lock:
+ def require_send_lock(self, user_id: int) -> asyncio.Lock:
try:
lock = self._send_locks[user_id]
except KeyError:
@@ -625,7 +625,7 @@ class Portal(DBPortal, BasePortal):
self._send_locks[user_id] = lock
return lock
- def optional_send_lock(self, user_id: Long) -> asyncio.Lock | FakeLock:
+ def optional_send_lock(self, user_id: int) -> asyncio.Lock | FakeLock:
try:
return self._send_locks[user_id]
except KeyError:
@@ -937,8 +937,8 @@ class Portal(DBPortal, BasePortal):
after_log_id: Long | None,
channel_info: ChannelInfo,
) -> None:
- self.log.debug("Backfilling history through %s", source.mxid)
- self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid))
+ self.log.debug(f"Backfilling history through {source.mxid}")
+ self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}")
messages = await source.client.get_chats(
channel_info.channelId,
limit,
@@ -961,7 +961,7 @@ class Portal(DBPortal, BasePortal):
# region Database getters
async def postinit(self) -> None:
- self.by_ktid[self.ktid_full] = self
+ self.by_ktid[self._ktid_full] = self
if self.mxid:
self.by_mxid[self.mxid] = self
self._main_intent = (
@@ -989,14 +989,14 @@ class Portal(DBPortal, BasePortal):
@async_getter_lock
async def get_by_ktid(
cls,
- ktid: Long,
+ ktid: int,
*,
- kt_receiver: Long = Long.ZERO,
+ kt_receiver: int = 0,
create: bool = True,
kt_type: ChannelType | None = None,
) -> Portal | None:
if kt_type:
- kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else Long.ZERO
+ kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0
ktid_full = (ktid, kt_receiver)
try:
return cls.by_ktid[ktid_full]
@@ -1017,7 +1017,7 @@ class Portal(DBPortal, BasePortal):
return None
@classmethod
- async def get_all_by_receiver(cls, kt_receiver: Long) -> AsyncGenerator[Portal, None]:
+ async def get_all_by_receiver(cls, kt_receiver: int) -> AsyncGenerator[Portal, None]:
portals = await super().get_all_by_receiver(kt_receiver)
portal: Portal
for portal in portals:
diff --git a/matrix_appservice_kakaotalk/puppet.py b/matrix_appservice_kakaotalk/puppet.py
index 180a1ac..644bfba 100644
--- a/matrix_appservice_kakaotalk/puppet.py
+++ b/matrix_appservice_kakaotalk/puppet.py
@@ -31,7 +31,7 @@ from . import matrix as m, portal as p, user as u
from .config import Config
from .db import Puppet as DBPuppet
-from .kt.types.bson import Long, StrLong
+from .kt.types.bson import Long
from .kt.client.types import UserInfoUnion
@@ -43,9 +43,9 @@ class Puppet(DBPuppet, BasePuppet):
mx: m.MatrixHandler
config: Config
hs_domain: str
- mxid_template: SimpleTemplate[StrLong]
+ mxid_template: SimpleTemplate[int]
- by_ktid: dict[Long, Puppet] = {}
+ by_ktid: dict[int, Puppet] = {}
by_custom_mxid: dict[UserID, Puppet] = {}
_last_info_sync: datetime | None
@@ -127,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet):
keyword="userid",
prefix="@",
suffix=f":{Puppet.hs_domain}",
- type=StrLong,
+ type=int,
)
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
cls.homeserver_url_map = {
@@ -218,9 +218,9 @@ class Puppet(DBPuppet, BasePuppet):
@classmethod
@async_getter_lock
- async def get_by_ktid(cls, ktid: Long, *, create: bool = True) -> Puppet | None:
+ async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Puppet | None:
try:
- return cls.by_ktid[ktid]
+ return cls.by_ktid[int]
except KeyError:
pass
@@ -260,11 +260,11 @@ class Puppet(DBPuppet, BasePuppet):
return None
@classmethod
- def get_id_from_mxid(cls, mxid: UserID) -> Long | None:
+ def get_id_from_mxid(cls, mxid: UserID) -> int | None:
return cls.mxid_template.parse(mxid)
@classmethod
- def get_mxid_from_id(cls, ktid: Long) -> UserID:
+ def get_mxid_from_id(cls, ktid: int) -> UserID:
return UserID(cls.mxid_template.format_full(ktid))
@classmethod
diff --git a/matrix_appservice_kakaotalk/user.py b/matrix_appservice_kakaotalk/user.py
index 2e8fce4..254c6aa 100644
--- a/matrix_appservice_kakaotalk/user.py
+++ b/matrix_appservice_kakaotalk/user.py
@@ -84,7 +84,7 @@ class User(DBUser, BaseUser):
config: Config
by_mxid: dict[UserID, User] = {}
- by_ktid: dict[Long, User] = {}
+ by_ktid: dict[int, User] = {}
client: Client | None
@@ -217,7 +217,7 @@ class User(DBUser, BaseUser):
@classmethod
@async_getter_lock
- async def get_by_ktid(cls, ktid: Long) -> User | None:
+ async def get_by_ktid(cls, ktid: int) -> User | None:
try:
return cls.by_ktid[ktid]
except KeyError:
diff --git a/node/src/client.js b/node/src/client.js
index 457cf2b..6320abd 100644
--- a/node/src/client.js
+++ b/node/src/client.js
@@ -168,10 +168,9 @@ export default class PeerClient {
* @returns {Promise}
*/
#write(data) {
- return promisify(cb => this.socket.write(JSON.stringify(data) + "\n", cb))
+ return promisify(cb => this.socket.write(JSON.stringify(data, this.#writeReplacer) + "\n", cb))
}
-
/**
* @param {Object} req
* @param {string} req.passcode
@@ -414,7 +413,7 @@ export default class PeerClient {
}
let req
try {
- req = JSON.parse(line)
+ req = JSON.parse(line, this.#readReviver)
} catch (err) {
this.log("Non-JSON request:", line)
return
@@ -465,7 +464,6 @@ export default class PeerClient {
const resp = { id: req.id }
delete req.id
delete req.command
- req = typeify(req)
resp.command = "response"
try {
resp.response = await handler(req)
@@ -483,29 +481,22 @@ export default class PeerClient {
}
await this.#write(resp)
}
-}
-/**
- * Recursively scan an object to check if any of its sub-objects
- * should be converted into instances of a specified class.
- * @param obj The object to be scanned & updated.
- * @returns The converted object.
- */
-function typeify(obj) {
- if (!(obj instanceof Object)) {
- return obj
+ #writeReplacer = function(key, value) {
+ if (value instanceof Long) {
+ return value.toString()
+ } else {
+ return value
+ }
}
- const converterFunc = TYPE_MAP.get(obj.__type__)
- if (converterFunc !== undefined) {
- return converterFunc(obj)
- }
- for (const key in obj) {
- obj[key] = typeify(obj[key])
- }
- return obj
-}
-// TODO Add more if needed
-const TYPE_MAP = new Map([
- ["Long", (obj) => new Long(obj.low, obj.high, obj.unsigned)],
-])
+ #readReviver = function(key, value) {
+ if (value instanceof Object) {
+ // TODO Use a type map if there will be many possible types
+ if (value.__type__ == "Long") {
+ return Long.fromString(value.str)
+ }
+ }
+ return value
+ }
+}