Clean up Long

Keep it as an int in Python, and do all fancy conversions in Node
This commit is contained in:
Andrew Ferrazzutti 2022-03-11 03:43:00 -05:00
parent 2a91a7b43e
commit 4158788496
10 changed files with 102 additions and 205 deletions

View File

@ -18,12 +18,12 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass, field
from mautrix.types import EventID, RoomID from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from ..kt.types.bson import Long, StrLong from ..kt.types.bson import Long
fake_db = Database.create("") if TYPE_CHECKING else None fake_db = Database.create("") if TYPE_CHECKING else None
@ -32,27 +32,17 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Message: class Message:
db: ClassVar[Database] = fake_db db: ClassVar[Database] = fake_db
# TODO Store all Long values as the same type
mxid: EventID mxid: EventID
mx_room: RoomID mx_room: RoomID
ktid: Long ktid: Long = field(converter=Long)
index: int index: int
kt_chat: Long kt_chat: Long = field(converter=Long)
kt_receiver: Long kt_receiver: Long = field(converter=Long)
timestamp: int timestamp: int
@classmethod @classmethod
def _from_row(cls, row: Record) -> Message | None: def _from_row(cls, row: Record) -> Message:
data = {**row} return cls(**row)
ktid = data.pop("ktid")
kt_chat = data.pop("kt_chat")
kt_receiver = data.pop("kt_receiver")
return cls(
**data,
ktid=StrLong(ktid),
kt_chat=Long.from_bytes(kt_chat),
kt_receiver=Long.from_bytes(kt_receiver)
)
@classmethod @classmethod
def _from_optional_row(cls, row: Record | None) -> Message | None: def _from_optional_row(cls, row: Record | None) -> Message | None:
@ -61,15 +51,15 @@ class Message:
columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp' columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
@classmethod @classmethod
async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]: async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2" q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver)) rows = await cls.db.fetch(q, ktid, kt_receiver)
return [cls._from_row(row) for row in rows if row] return [cls._from_row(row) for row in rows if row]
@classmethod @classmethod
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None: async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index) row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
return cls._from_optional_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
@ -83,18 +73,18 @@ class Message:
return cls._from_optional_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None: async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
q = ( q = (
f"SELECT {cls.columns} " f"SELECT {cls.columns} "
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL " "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1" "ORDER BY timestamp DESC LIMIT 1"
) )
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver)) row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
return cls._from_optional_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
async def get_closest_before( async def get_closest_before(
cls, kt_chat: Long, kt_receiver: Long, timestamp: int cls, kt_chat: int, kt_receiver: int, timestamp: int
) -> Message | None: ) -> Message | None:
q = ( q = (
f"SELECT {cls.columns} " f"SELECT {cls.columns} "
@ -102,7 +92,7 @@ class Message:
" ktid IS NOT NULL " " ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1" "ORDER BY timestamp DESC LIMIT 1"
) )
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp) row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp)
return cls._from_optional_row(row) return cls._from_optional_row(row)
_insert_query = ( _insert_query = (
@ -111,46 +101,23 @@ class Message:
"VALUES ($1, $2, $3, $4, $5, $6, $7)" "VALUES ($1, $2, $3, $4, $5, $6, $7)"
) )
@classmethod
async def bulk_create(
cls,
ktid: Long,
kt_chat: Long,
kt_receiver: Long,
event_ids: list[EventID],
timestamp: int,
mx_room: RoomID,
) -> None:
if not event_ids:
return
columns = [col.strip('"') for col in cls.columns.split(", ")]
records = [
(mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp)
for index, mxid in enumerate(event_ids)
]
async with cls.db.acquire() as conn, conn.transaction():
if cls.db.scheme == "postgres":
await conn.copy_records_to_table("message", records=records, columns=columns)
else:
await conn.executemany(cls._insert_query, records)
async def insert(self) -> None: async def insert(self) -> None:
q = self._insert_query q = self._insert_query
await self.db.execute( await self.db.execute(
q, q,
self.mxid, self.mxid,
self.mx_room, self.mx_room,
str(self.ktid), self.ktid,
self.index, self.index,
bytes(self.kt_chat), self.kt_chat,
bytes(self.kt_receiver), self.kt_receiver,
self.timestamp, self.timestamp,
) )
async def delete(self) -> None: async def delete(self) -> None:
q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index) await self.db.execute(q, self.ktid, self.kt_receiver, self.index)
async def update(self) -> None: async def update(self) -> None:
q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4" q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room) await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room)

View File

@ -18,7 +18,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass, field
from mautrix.types import ContentURI, RoomID, UserID from mautrix.types import ContentURI, RoomID, UserID
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
@ -33,8 +33,8 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Portal: class Portal:
db: ClassVar[Database] = fake_db db: ClassVar[Database] = fake_db
ktid: Long ktid: Long = field(converter=Long)
kt_receiver: Long kt_receiver: Long = field(converter=Long)
kt_type: ChannelType kt_type: ChannelType
mxid: RoomID | None mxid: RoomID | None
name: str | None name: str | None
@ -47,23 +47,20 @@ class Portal:
@classmethod @classmethod
def _from_row(cls, row: Record) -> Portal: def _from_row(cls, row: Record) -> Portal:
data = {**row} return cls(**row)
ktid = data.pop("ktid")
kt_receiver = data.pop("kt_receiver")
return cls(**data, ktid=Long.from_bytes(ktid), kt_receiver=Long.from_bytes(kt_receiver))
@classmethod @classmethod
def _from_optional_row(cls, row: Record | None) -> Portal | None: def _from_optional_row(cls, row: Record | None) -> Portal | None:
return cls._from_row(row) if row is not None else None return cls._from_row(row) if row is not None else None
@classmethod @classmethod
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long) -> Portal | None: async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None:
q = """ q = """
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted, SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
name_set, avatar_set, relay_user_id name_set, avatar_set, relay_user_id
FROM portal WHERE ktid=$1 AND kt_receiver=$2 FROM portal WHERE ktid=$1 AND kt_receiver=$2
""" """
row = await cls.db.fetchrow(q, bytes(ktid), bytes(kt_receiver)) row = await cls.db.fetchrow(q, ktid, kt_receiver)
return cls._from_optional_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
@ -77,13 +74,13 @@ class Portal:
return cls._from_optional_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
async def get_all_by_receiver(cls, kt_receiver: Long) -> list[Portal]: async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]:
q = """ q = """
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted, SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
name_set, avatar_set, relay_user_id name_set, avatar_set, relay_user_id
FROM portal WHERE kt_receiver=$1 FROM portal WHERE kt_receiver=$1
""" """
rows = await cls.db.fetch(q, bytes(kt_receiver)) rows = await cls.db.fetch(q, kt_receiver)
return [cls._from_row(row) for row in rows if row] return [cls._from_row(row) for row in rows if row]
@classmethod @classmethod
@ -99,8 +96,8 @@ class Portal:
@property @property
def _values(self): def _values(self):
return ( return (
Long.to_optional_bytes(self.ktid), self.ktid,
Long.to_optional_bytes(self.kt_receiver), self.kt_receiver,
self.kt_type, self.kt_type,
self.mxid, self.mxid,
self.name, self.name,
@ -122,7 +119,7 @@ class Portal:
async def delete(self) -> None: async def delete(self) -> None:
q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2" q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2"
await self.db.execute(q, Long.to_optional_bytes(self.ktid), Long.to_optional_bytes(self.kt_receiver)) await self.db.execute(q, self.ktid, self.kt_receiver)
async def save(self) -> None: async def save(self) -> None:
q = """ q = """

View File

@ -18,7 +18,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass, field
from yarl import URL from yarl import URL
from mautrix.types import ContentURI, SyncToken, UserID from mautrix.types import ContentURI, SyncToken, UserID
@ -33,7 +33,7 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Puppet: class Puppet:
db: ClassVar[Database] = fake_db db: ClassVar[Database] = fake_db
ktid: Long ktid: Long = field(converter=Long)
name: str | None name: str | None
photo_id: str | None photo_id: str | None
photo_mxc: ContentURI | None photo_mxc: ContentURI | None
@ -49,22 +49,21 @@ class Puppet:
@classmethod @classmethod
def _from_row(cls, row: Record) -> Puppet: def _from_row(cls, row: Record) -> Puppet:
data = {**row} data = {**row}
ktid = data.pop("ktid")
base_url = data.pop("base_url", None) base_url = data.pop("base_url", None)
return cls(**data, ktid=Long.from_optional_bytes(ktid), base_url=URL(base_url) if base_url else None) return cls(**data, base_url=URL(base_url) if base_url else None)
@classmethod @classmethod
def _from_optional_row(cls, row: Record | None) -> Puppet | None: def _from_optional_row(cls, row: Record | None) -> Puppet | None:
return cls._from_row(row) if row is not None else None return cls._from_row(row) if row is not None else None
@classmethod @classmethod
async def get_by_ktid(cls, ktid: Long) -> Puppet | None: async def get_by_ktid(cls, ktid: int) -> Puppet | None:
q = ( q = (
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, " "SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
" custom_mxid, access_token, next_batch, base_url " " custom_mxid, access_token, next_batch, base_url "
"FROM puppet WHERE ktid=$1" "FROM puppet WHERE ktid=$1"
) )
row = await cls.db.fetchrow(q, bytes(ktid)) row = await cls.db.fetchrow(q, ktid)
return cls._from_optional_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
@ -100,7 +99,7 @@ class Puppet:
@property @property
def _values(self): def _values(self):
return ( return (
bytes(self.ktid), self.ktid,
self.name, self.name,
self.photo_id, self.photo_id,
self.photo_mxc, self.photo_mxc,
@ -123,7 +122,7 @@ class Puppet:
async def delete(self) -> None: async def delete(self) -> None:
q = "DELETE FROM puppet WHERE ktid=$1" q = "DELETE FROM puppet WHERE ktid=$1"
await self.db.execute(q, bytes(self.ktid)) await self.db.execute(q, self.ktid)
async def save(self) -> None: async def save(self) -> None:
q = """ q = """

View File

@ -30,7 +30,7 @@ async def create_v1_tables(conn: Connection) -> None:
await conn.execute( await conn.execute(
"""CREATE TABLE "user" ( """CREATE TABLE "user" (
mxid TEXT PRIMARY KEY, mxid TEXT PRIMARY KEY,
ktid BYTES UNIQUE, ktid BIGINT UNIQUE,
uuid TEXT, uuid TEXT,
access_token TEXT, access_token TEXT,
refresh_token TEXT, refresh_token TEXT,
@ -39,8 +39,8 @@ async def create_v1_tables(conn: Connection) -> None:
) )
await conn.execute( await conn.execute(
"""CREATE TABLE portal ( """CREATE TABLE portal (
ktid BYTES, ktid BIGINT,
kt_receiver BYTES, kt_receiver BIGINT,
kt_type TEXT, kt_type TEXT,
mxid TEXT UNIQUE, mxid TEXT UNIQUE,
name TEXT, name TEXT,
@ -55,7 +55,7 @@ async def create_v1_tables(conn: Connection) -> None:
) )
await conn.execute( await conn.execute(
"""CREATE TABLE puppet ( """CREATE TABLE puppet (
ktid BYTES PRIMARY KEY, ktid BIGINT PRIMARY KEY,
name TEXT, name TEXT,
photo_id TEXT, photo_id TEXT,
photo_mxc TEXT, photo_mxc TEXT,
@ -74,10 +74,10 @@ async def create_v1_tables(conn: Connection) -> None:
"""CREATE TABLE message ( """CREATE TABLE message (
mxid TEXT, mxid TEXT,
mx_room TEXT, mx_room TEXT,
ktid TEXT, ktid BIGINT,
kt_receiver BYTES, kt_receiver BIGINT,
"index" SMALLINT, "index" SMALLINT,
kt_chat BYTES, kt_chat BIGINT,
timestamp BIGINT, timestamp BIGINT,
PRIMARY KEY (ktid, kt_receiver, "index"), PRIMARY KEY (ktid, kt_receiver, "index"),
FOREIGN KEY (kt_chat, kt_receiver) REFERENCES portal(ktid, kt_receiver) FOREIGN KEY (kt_chat, kt_receiver) REFERENCES portal(ktid, kt_receiver)
@ -90,8 +90,8 @@ async def create_v1_tables(conn: Connection) -> None:
mxid TEXT, mxid TEXT,
mx_room TEXT, mx_room TEXT,
kt_msgid TEXT, kt_msgid TEXT,
kt_receiver BYTES, kt_receiver BIGINT,
kt_sender BYTES, kt_sender BIGINT,
reaction TEXT, reaction TEXT,
PRIMARY KEY (kt_msgid, kt_receiver, kt_sender), PRIMARY KEY (kt_msgid, kt_receiver, kt_sender),
UNIQUE (mxid, mx_room) UNIQUE (mxid, mx_room)

View File

@ -18,7 +18,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar, List, Set from typing import TYPE_CHECKING, ClassVar, List, Set
from asyncpg import Record from asyncpg import Record
from attr import dataclass from attr import dataclass, field
from mautrix.types import RoomID, UserID from mautrix.types import RoomID, UserID
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
@ -33,7 +33,7 @@ class User:
db: ClassVar[Database] = fake_db db: ClassVar[Database] = fake_db
mxid: UserID mxid: UserID
ktid: Long | None ktid: Long | None = field(converter=lambda x: Long(x) if x is not None else None)
uuid: str | None uuid: str | None
access_token: str | None access_token: str | None
refresh_token: str | None refresh_token: str | None
@ -41,9 +41,7 @@ class User:
@classmethod @classmethod
def _from_row(cls, row: Record) -> User: def _from_row(cls, row: Record) -> User:
data = {**row} return cls(**row)
ktid = data.pop("ktid", None)
return cls(**data, ktid=Long.from_optional_bytes(ktid))
@classmethod @classmethod
def _from_optional_row(cls, row: Record | None) -> User | None: def _from_optional_row(cls, row: Record | None) -> User | None:
@ -59,9 +57,9 @@ class User:
return [cls._from_row(row) for row in rows if row] return [cls._from_row(row) for row in rows if row]
@classmethod @classmethod
async def get_by_ktid(cls, ktid: Long) -> User | None: async def get_by_ktid(cls, ktid: int) -> User | None:
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1' q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1'
row = await cls.db.fetchrow(q, bytes(ktid)) row = await cls.db.fetchrow(q, ktid)
return cls._from_optional_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
@ -81,7 +79,7 @@ class User:
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
""" """
await self.db.execute( await self.db.execute(
q, self.mxid, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room
) )
async def delete(self) -> None: async def delete(self) -> None:
@ -93,5 +91,5 @@ class User:
WHERE mxid=$6 WHERE mxid=$6
""" """
await self.db.execute( await self.db.execute(
q, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
) )

View File

@ -13,69 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import ClassVar, Optional from mautrix.types import Serializable, JSON
from attr import dataclass, asdict
import bson
from mautrix.types import SerializableAttrs, JSON
@dataclass(frozen=True) class Long(int, Serializable):
class Long(SerializableAttrs):
high: int
low: int
unsigned: bool
@classmethod
def from_bytes(cls, raw: bytes) -> "Long":
return cls(**bson.loads(raw))
@classmethod
def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]:
return cls(**bson.loads(raw)) if raw is not None else None
@classmethod
def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]:
return bytes(value) if value is not None else None
def serialize(self) -> JSON: def serialize(self) -> JSON:
data = super().serialize() return {"__type__": "Long", "str": str(self)}
data["__type__"] = "Long"
return data
def __bytes__(self) -> bytes: @classmethod
return bson.dumps(asdict(self)) def deserialize(cls, raw: JSON) -> "Long":
assert isinstance(raw, str), f"Long deserialization expected a string, but got non-string value {raw}"
def __int__(self) -> int: return cls(raw)
if self.unsigned:
pass
result = \
((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \
( self.low + (1 << 32 if self.low < 0 else 0))
return result + (1 << 32 if self.unsigned and result < 0 else 0)
def __str__(self) -> str:
return str(int(self))
ZERO: ClassVar["Long"]
Long.ZERO = Long(0, 0, False)
class IntLong(Long):
"""Helper class for constructing a Long from an int."""
def __init__(self, val: int):
if val < 0:
pass
super().__init__(
high=(val & 0xffffffff00000000) >> 32,
low = val & 0x00000000ffffffff,
unsigned=val < 0,
)
class StrLong(IntLong):
"""Helper class for constructing a Long from the string representation of an int."""
def __init__(self, val: str):
super().__init__(int(val))

View File

@ -48,7 +48,7 @@ from .db import (
) )
from .formatter import kakaotalk_to_matrix, matrix_to_kakaotalk from .formatter import kakaotalk_to_matrix, matrix_to_kakaotalk
from .kt.types.bson import Long, IntLong from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo from .kt.types.channel.channel_info import ChannelInfo
from .kt.types.channel.channel_type import KnownChannelType, ChannelType from .kt.types.channel.channel_type import KnownChannelType, ChannelType
from .kt.types.chat.chat import Chatlog from .kt.types.chat.chat import Chatlog
@ -87,7 +87,7 @@ StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STAT
class Portal(DBPortal, BasePortal): class Portal(DBPortal, BasePortal):
invite_own_puppet_to_pm: bool = False invite_own_puppet_to_pm: bool = False
by_mxid: dict[RoomID, Portal] = {} by_mxid: dict[RoomID, Portal] = {}
by_ktid: dict[tuple[Long, Long], Portal] = {} by_ktid: dict[tuple[int, int], Portal] = {}
matrix: m.MatrixHandler matrix: m.MatrixHandler
config: Config config: Config
@ -162,7 +162,7 @@ class Portal(DBPortal, BasePortal):
async def delete(self) -> None: async def delete(self) -> None:
if self.mxid: if self.mxid:
await DBMessage.delete_all_by_room(self.mxid) await DBMessage.delete_all_by_room(self.mxid)
self.by_ktid.pop(self.ktid_full, None) self.by_ktid.pop(self._ktid_full, None)
self.by_mxid.pop(self.mxid, None) self.by_mxid.pop(self.mxid, None)
await super().delete() await super().delete()
@ -170,7 +170,7 @@ class Portal(DBPortal, BasePortal):
# region Properties # region Properties
@property @property
def ktid_full(self) -> tuple[Long, Long]: def _ktid_full(self) -> tuple[int, int]:
return self.ktid, self.kt_receiver return self.ktid, self.kt_receiver
@property @property
@ -617,7 +617,7 @@ class Portal(DBPortal, BasePortal):
# endregion # endregion
# region Matrix event handling # region Matrix event handling
def require_send_lock(self, user_id: Long) -> asyncio.Lock: def require_send_lock(self, user_id: int) -> asyncio.Lock:
try: try:
lock = self._send_locks[user_id] lock = self._send_locks[user_id]
except KeyError: except KeyError:
@ -625,7 +625,7 @@ class Portal(DBPortal, BasePortal):
self._send_locks[user_id] = lock self._send_locks[user_id] = lock
return lock return lock
def optional_send_lock(self, user_id: Long) -> asyncio.Lock | FakeLock: def optional_send_lock(self, user_id: int) -> asyncio.Lock | FakeLock:
try: try:
return self._send_locks[user_id] return self._send_locks[user_id]
except KeyError: except KeyError:
@ -937,8 +937,8 @@ class Portal(DBPortal, BasePortal):
after_log_id: Long | None, after_log_id: Long | None,
channel_info: ChannelInfo, channel_info: ChannelInfo,
) -> None: ) -> None:
self.log.debug("Backfilling history through %s", source.mxid) self.log.debug(f"Backfilling history through {source.mxid}")
self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid)) self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}")
messages = await source.client.get_chats( messages = await source.client.get_chats(
channel_info.channelId, channel_info.channelId,
limit, limit,
@ -961,7 +961,7 @@ class Portal(DBPortal, BasePortal):
# region Database getters # region Database getters
async def postinit(self) -> None: async def postinit(self) -> None:
self.by_ktid[self.ktid_full] = self self.by_ktid[self._ktid_full] = self
if self.mxid: if self.mxid:
self.by_mxid[self.mxid] = self self.by_mxid[self.mxid] = self
self._main_intent = ( self._main_intent = (
@ -989,14 +989,14 @@ class Portal(DBPortal, BasePortal):
@async_getter_lock @async_getter_lock
async def get_by_ktid( async def get_by_ktid(
cls, cls,
ktid: Long, ktid: int,
*, *,
kt_receiver: Long = Long.ZERO, kt_receiver: int = 0,
create: bool = True, create: bool = True,
kt_type: ChannelType | None = None, kt_type: ChannelType | None = None,
) -> Portal | None: ) -> Portal | None:
if kt_type: if kt_type:
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else Long.ZERO kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0
ktid_full = (ktid, kt_receiver) ktid_full = (ktid, kt_receiver)
try: try:
return cls.by_ktid[ktid_full] return cls.by_ktid[ktid_full]
@ -1017,7 +1017,7 @@ class Portal(DBPortal, BasePortal):
return None return None
@classmethod @classmethod
async def get_all_by_receiver(cls, kt_receiver: Long) -> AsyncGenerator[Portal, None]: async def get_all_by_receiver(cls, kt_receiver: int) -> AsyncGenerator[Portal, None]:
portals = await super().get_all_by_receiver(kt_receiver) portals = await super().get_all_by_receiver(kt_receiver)
portal: Portal portal: Portal
for portal in portals: for portal in portals:

View File

@ -31,7 +31,7 @@ from . import matrix as m, portal as p, user as u
from .config import Config from .config import Config
from .db import Puppet as DBPuppet from .db import Puppet as DBPuppet
from .kt.types.bson import Long, StrLong from .kt.types.bson import Long
from .kt.client.types import UserInfoUnion from .kt.client.types import UserInfoUnion
@ -43,9 +43,9 @@ class Puppet(DBPuppet, BasePuppet):
mx: m.MatrixHandler mx: m.MatrixHandler
config: Config config: Config
hs_domain: str hs_domain: str
mxid_template: SimpleTemplate[StrLong] mxid_template: SimpleTemplate[int]
by_ktid: dict[Long, Puppet] = {} by_ktid: dict[int, Puppet] = {}
by_custom_mxid: dict[UserID, Puppet] = {} by_custom_mxid: dict[UserID, Puppet] = {}
_last_info_sync: datetime | None _last_info_sync: datetime | None
@ -127,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet):
keyword="userid", keyword="userid",
prefix="@", prefix="@",
suffix=f":{Puppet.hs_domain}", suffix=f":{Puppet.hs_domain}",
type=StrLong, type=int,
) )
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"] cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
cls.homeserver_url_map = { cls.homeserver_url_map = {
@ -218,9 +218,9 @@ class Puppet(DBPuppet, BasePuppet):
@classmethod @classmethod
@async_getter_lock @async_getter_lock
async def get_by_ktid(cls, ktid: Long, *, create: bool = True) -> Puppet | None: async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Puppet | None:
try: try:
return cls.by_ktid[ktid] return cls.by_ktid[int]
except KeyError: except KeyError:
pass pass
@ -260,11 +260,11 @@ class Puppet(DBPuppet, BasePuppet):
return None return None
@classmethod @classmethod
def get_id_from_mxid(cls, mxid: UserID) -> Long | None: def get_id_from_mxid(cls, mxid: UserID) -> int | None:
return cls.mxid_template.parse(mxid) return cls.mxid_template.parse(mxid)
@classmethod @classmethod
def get_mxid_from_id(cls, ktid: Long) -> UserID: def get_mxid_from_id(cls, ktid: int) -> UserID:
return UserID(cls.mxid_template.format_full(ktid)) return UserID(cls.mxid_template.format_full(ktid))
@classmethod @classmethod

View File

@ -84,7 +84,7 @@ class User(DBUser, BaseUser):
config: Config config: Config
by_mxid: dict[UserID, User] = {} by_mxid: dict[UserID, User] = {}
by_ktid: dict[Long, User] = {} by_ktid: dict[int, User] = {}
client: Client | None client: Client | None
@ -217,7 +217,7 @@ class User(DBUser, BaseUser):
@classmethod @classmethod
@async_getter_lock @async_getter_lock
async def get_by_ktid(cls, ktid: Long) -> User | None: async def get_by_ktid(cls, ktid: int) -> User | None:
try: try:
return cls.by_ktid[ktid] return cls.by_ktid[ktid]
except KeyError: except KeyError:

View File

@ -168,10 +168,9 @@ export default class PeerClient {
* @returns {Promise<void>} * @returns {Promise<void>}
*/ */
#write(data) { #write(data) {
return promisify(cb => this.socket.write(JSON.stringify(data) + "\n", cb)) return promisify(cb => this.socket.write(JSON.stringify(data, this.#writeReplacer) + "\n", cb))
} }
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.passcode * @param {string} req.passcode
@ -414,7 +413,7 @@ export default class PeerClient {
} }
let req let req
try { try {
req = JSON.parse(line) req = JSON.parse(line, this.#readReviver)
} catch (err) { } catch (err) {
this.log("Non-JSON request:", line) this.log("Non-JSON request:", line)
return return
@ -465,7 +464,6 @@ export default class PeerClient {
const resp = { id: req.id } const resp = { id: req.id }
delete req.id delete req.id
delete req.command delete req.command
req = typeify(req)
resp.command = "response" resp.command = "response"
try { try {
resp.response = await handler(req) resp.response = await handler(req)
@ -483,29 +481,22 @@ export default class PeerClient {
} }
await this.#write(resp) await this.#write(resp)
} }
}
/** #writeReplacer = function(key, value) {
* Recursively scan an object to check if any of its sub-objects if (value instanceof Long) {
* should be converted into instances of a specified class. return value.toString()
* @param obj The object to be scanned & updated. } else {
* @returns The converted object. return value
*/ }
function typeify(obj) {
if (!(obj instanceof Object)) {
return obj
} }
const converterFunc = TYPE_MAP.get(obj.__type__)
if (converterFunc !== undefined) {
return converterFunc(obj)
}
for (const key in obj) {
obj[key] = typeify(obj[key])
}
return obj
}
// TODO Add more if needed #readReviver = function(key, value) {
const TYPE_MAP = new Map([ if (value instanceof Object) {
["Long", (obj) => new Long(obj.low, obj.high, obj.unsigned)], // TODO Use a type map if there will be many possible types
]) if (value.__type__ == "Long") {
return Long.fromString(value.str)
}
}
return value
}
}