Clean up Long
Keep it as an int in Python, and do all fancy conversions in Node
This commit is contained in:
parent
2a91a7b43e
commit
4158788496
@ -18,12 +18,12 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
from attr import dataclass, field
|
||||
|
||||
from mautrix.types import EventID, RoomID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ..kt.types.bson import Long, StrLong
|
||||
from ..kt.types.bson import Long
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
@ -32,27 +32,17 @@ fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
class Message:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
# TODO Store all Long values as the same type
|
||||
mxid: EventID
|
||||
mx_room: RoomID
|
||||
ktid: Long
|
||||
ktid: Long = field(converter=Long)
|
||||
index: int
|
||||
kt_chat: Long
|
||||
kt_receiver: Long
|
||||
kt_chat: Long = field(converter=Long)
|
||||
kt_receiver: Long = field(converter=Long)
|
||||
timestamp: int
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record) -> Message | None:
|
||||
data = {**row}
|
||||
ktid = data.pop("ktid")
|
||||
kt_chat = data.pop("kt_chat")
|
||||
kt_receiver = data.pop("kt_receiver")
|
||||
return cls(
|
||||
**data,
|
||||
ktid=StrLong(ktid),
|
||||
kt_chat=Long.from_bytes(kt_chat),
|
||||
kt_receiver=Long.from_bytes(kt_receiver)
|
||||
)
|
||||
def _from_row(cls, row: Record) -> Message:
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
def _from_optional_row(cls, row: Record | None) -> Message | None:
|
||||
@ -61,15 +51,15 @@ class Message:
|
||||
columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
|
||||
|
||||
@classmethod
|
||||
async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
|
||||
async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]:
|
||||
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
|
||||
rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
|
||||
rows = await cls.db.fetch(q, ktid, kt_receiver)
|
||||
return [cls._from_row(row) for row in rows if row]
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
|
||||
async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None:
|
||||
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
|
||||
row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
|
||||
row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
@ -83,18 +73,18 @@ class Message:
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
|
||||
async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
|
||||
q = (
|
||||
f"SELECT {cls.columns} "
|
||||
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
|
||||
"ORDER BY timestamp DESC LIMIT 1"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
|
||||
row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_closest_before(
|
||||
cls, kt_chat: Long, kt_receiver: Long, timestamp: int
|
||||
cls, kt_chat: int, kt_receiver: int, timestamp: int
|
||||
) -> Message | None:
|
||||
q = (
|
||||
f"SELECT {cls.columns} "
|
||||
@ -102,7 +92,7 @@ class Message:
|
||||
" ktid IS NOT NULL "
|
||||
"ORDER BY timestamp DESC LIMIT 1"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
|
||||
row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
_insert_query = (
|
||||
@ -111,46 +101,23 @@ class Message:
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def bulk_create(
|
||||
cls,
|
||||
ktid: Long,
|
||||
kt_chat: Long,
|
||||
kt_receiver: Long,
|
||||
event_ids: list[EventID],
|
||||
timestamp: int,
|
||||
mx_room: RoomID,
|
||||
) -> None:
|
||||
if not event_ids:
|
||||
return
|
||||
columns = [col.strip('"') for col in cls.columns.split(", ")]
|
||||
records = [
|
||||
(mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp)
|
||||
for index, mxid in enumerate(event_ids)
|
||||
]
|
||||
async with cls.db.acquire() as conn, conn.transaction():
|
||||
if cls.db.scheme == "postgres":
|
||||
await conn.copy_records_to_table("message", records=records, columns=columns)
|
||||
else:
|
||||
await conn.executemany(cls._insert_query, records)
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = self._insert_query
|
||||
await self.db.execute(
|
||||
q,
|
||||
self.mxid,
|
||||
self.mx_room,
|
||||
str(self.ktid),
|
||||
self.ktid,
|
||||
self.index,
|
||||
bytes(self.kt_chat),
|
||||
bytes(self.kt_receiver),
|
||||
self.kt_chat,
|
||||
self.kt_receiver,
|
||||
self.timestamp,
|
||||
)
|
||||
|
||||
async def delete(self) -> None:
|
||||
q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
|
||||
await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index)
|
||||
await self.db.execute(q, self.ktid, self.kt_receiver, self.index)
|
||||
|
||||
async def update(self) -> None:
|
||||
q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
|
||||
await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room)
|
||||
await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room)
|
||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
from attr import dataclass, field
|
||||
|
||||
from mautrix.types import ContentURI, RoomID, UserID
|
||||
from mautrix.util.async_db import Database
|
||||
@ -33,8 +33,8 @@ fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
class Portal:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
ktid: Long
|
||||
kt_receiver: Long
|
||||
ktid: Long = field(converter=Long)
|
||||
kt_receiver: Long = field(converter=Long)
|
||||
kt_type: ChannelType
|
||||
mxid: RoomID | None
|
||||
name: str | None
|
||||
@ -47,23 +47,20 @@ class Portal:
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record) -> Portal:
|
||||
data = {**row}
|
||||
ktid = data.pop("ktid")
|
||||
kt_receiver = data.pop("kt_receiver")
|
||||
return cls(**data, ktid=Long.from_bytes(ktid), kt_receiver=Long.from_bytes(kt_receiver))
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
def _from_optional_row(cls, row: Record | None) -> Portal | None:
|
||||
return cls._from_row(row) if row is not None else None
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long) -> Portal | None:
|
||||
async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None:
|
||||
q = """
|
||||
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
|
||||
name_set, avatar_set, relay_user_id
|
||||
FROM portal WHERE ktid=$1 AND kt_receiver=$2
|
||||
"""
|
||||
row = await cls.db.fetchrow(q, bytes(ktid), bytes(kt_receiver))
|
||||
row = await cls.db.fetchrow(q, ktid, kt_receiver)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
@ -77,13 +74,13 @@ class Portal:
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_all_by_receiver(cls, kt_receiver: Long) -> list[Portal]:
|
||||
async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]:
|
||||
q = """
|
||||
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
|
||||
name_set, avatar_set, relay_user_id
|
||||
FROM portal WHERE kt_receiver=$1
|
||||
"""
|
||||
rows = await cls.db.fetch(q, bytes(kt_receiver))
|
||||
rows = await cls.db.fetch(q, kt_receiver)
|
||||
return [cls._from_row(row) for row in rows if row]
|
||||
|
||||
@classmethod
|
||||
@ -99,8 +96,8 @@ class Portal:
|
||||
@property
|
||||
def _values(self):
|
||||
return (
|
||||
Long.to_optional_bytes(self.ktid),
|
||||
Long.to_optional_bytes(self.kt_receiver),
|
||||
self.ktid,
|
||||
self.kt_receiver,
|
||||
self.kt_type,
|
||||
self.mxid,
|
||||
self.name,
|
||||
@ -122,7 +119,7 @@ class Portal:
|
||||
|
||||
async def delete(self) -> None:
|
||||
q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2"
|
||||
await self.db.execute(q, Long.to_optional_bytes(self.ktid), Long.to_optional_bytes(self.kt_receiver))
|
||||
await self.db.execute(q, self.ktid, self.kt_receiver)
|
||||
|
||||
async def save(self) -> None:
|
||||
q = """
|
||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
from attr import dataclass, field
|
||||
from yarl import URL
|
||||
|
||||
from mautrix.types import ContentURI, SyncToken, UserID
|
||||
@ -33,7 +33,7 @@ fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
class Puppet:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
ktid: Long
|
||||
ktid: Long = field(converter=Long)
|
||||
name: str | None
|
||||
photo_id: str | None
|
||||
photo_mxc: ContentURI | None
|
||||
@ -49,22 +49,21 @@ class Puppet:
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record) -> Puppet:
|
||||
data = {**row}
|
||||
ktid = data.pop("ktid")
|
||||
base_url = data.pop("base_url", None)
|
||||
return cls(**data, ktid=Long.from_optional_bytes(ktid), base_url=URL(base_url) if base_url else None)
|
||||
return cls(**data, base_url=URL(base_url) if base_url else None)
|
||||
|
||||
@classmethod
|
||||
def _from_optional_row(cls, row: Record | None) -> Puppet | None:
|
||||
return cls._from_row(row) if row is not None else None
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: Long) -> Puppet | None:
|
||||
async def get_by_ktid(cls, ktid: int) -> Puppet | None:
|
||||
q = (
|
||||
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
|
||||
" custom_mxid, access_token, next_batch, base_url "
|
||||
"FROM puppet WHERE ktid=$1"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, bytes(ktid))
|
||||
row = await cls.db.fetchrow(q, ktid)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
@ -100,7 +99,7 @@ class Puppet:
|
||||
@property
|
||||
def _values(self):
|
||||
return (
|
||||
bytes(self.ktid),
|
||||
self.ktid,
|
||||
self.name,
|
||||
self.photo_id,
|
||||
self.photo_mxc,
|
||||
@ -123,7 +122,7 @@ class Puppet:
|
||||
|
||||
async def delete(self) -> None:
|
||||
q = "DELETE FROM puppet WHERE ktid=$1"
|
||||
await self.db.execute(q, bytes(self.ktid))
|
||||
await self.db.execute(q, self.ktid)
|
||||
|
||||
async def save(self) -> None:
|
||||
q = """
|
||||
|
@ -30,7 +30,7 @@ async def create_v1_tables(conn: Connection) -> None:
|
||||
await conn.execute(
|
||||
"""CREATE TABLE "user" (
|
||||
mxid TEXT PRIMARY KEY,
|
||||
ktid BYTES UNIQUE,
|
||||
ktid BIGINT UNIQUE,
|
||||
uuid TEXT,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
@ -39,8 +39,8 @@ async def create_v1_tables(conn: Connection) -> None:
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE portal (
|
||||
ktid BYTES,
|
||||
kt_receiver BYTES,
|
||||
ktid BIGINT,
|
||||
kt_receiver BIGINT,
|
||||
kt_type TEXT,
|
||||
mxid TEXT UNIQUE,
|
||||
name TEXT,
|
||||
@ -55,7 +55,7 @@ async def create_v1_tables(conn: Connection) -> None:
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE puppet (
|
||||
ktid BYTES PRIMARY KEY,
|
||||
ktid BIGINT PRIMARY KEY,
|
||||
name TEXT,
|
||||
photo_id TEXT,
|
||||
photo_mxc TEXT,
|
||||
@ -74,10 +74,10 @@ async def create_v1_tables(conn: Connection) -> None:
|
||||
"""CREATE TABLE message (
|
||||
mxid TEXT,
|
||||
mx_room TEXT,
|
||||
ktid TEXT,
|
||||
kt_receiver BYTES,
|
||||
ktid BIGINT,
|
||||
kt_receiver BIGINT,
|
||||
"index" SMALLINT,
|
||||
kt_chat BYTES,
|
||||
kt_chat BIGINT,
|
||||
timestamp BIGINT,
|
||||
PRIMARY KEY (ktid, kt_receiver, "index"),
|
||||
FOREIGN KEY (kt_chat, kt_receiver) REFERENCES portal(ktid, kt_receiver)
|
||||
@ -90,8 +90,8 @@ async def create_v1_tables(conn: Connection) -> None:
|
||||
mxid TEXT,
|
||||
mx_room TEXT,
|
||||
kt_msgid TEXT,
|
||||
kt_receiver BYTES,
|
||||
kt_sender BYTES,
|
||||
kt_receiver BIGINT,
|
||||
kt_sender BIGINT,
|
||||
reaction TEXT,
|
||||
PRIMARY KEY (kt_msgid, kt_receiver, kt_sender),
|
||||
UNIQUE (mxid, mx_room)
|
||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Set
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
from attr import dataclass, field
|
||||
|
||||
from mautrix.types import RoomID, UserID
|
||||
from mautrix.util.async_db import Database
|
||||
@ -33,7 +33,7 @@ class User:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
mxid: UserID
|
||||
ktid: Long | None
|
||||
ktid: Long | None = field(converter=lambda x: Long(x) if x is not None else None)
|
||||
uuid: str | None
|
||||
access_token: str | None
|
||||
refresh_token: str | None
|
||||
@ -41,9 +41,7 @@ class User:
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record) -> User:
|
||||
data = {**row}
|
||||
ktid = data.pop("ktid", None)
|
||||
return cls(**data, ktid=Long.from_optional_bytes(ktid))
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
def _from_optional_row(cls, row: Record | None) -> User | None:
|
||||
@ -59,9 +57,9 @@ class User:
|
||||
return [cls._from_row(row) for row in rows if row]
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: Long) -> User | None:
|
||||
async def get_by_ktid(cls, ktid: int) -> User | None:
|
||||
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1'
|
||||
row = await cls.db.fetchrow(q, bytes(ktid))
|
||||
row = await cls.db.fetchrow(q, ktid)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
@ -81,7 +79,7 @@ class User:
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
"""
|
||||
await self.db.execute(
|
||||
q, self.mxid, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room
|
||||
q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room
|
||||
)
|
||||
|
||||
async def delete(self) -> None:
|
||||
@ -93,5 +91,5 @@ class User:
|
||||
WHERE mxid=$6
|
||||
"""
|
||||
await self.db.execute(
|
||||
q, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
|
||||
q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
|
||||
)
|
||||
|
@ -13,69 +13,14 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from attr import dataclass, asdict
|
||||
import bson
|
||||
|
||||
from mautrix.types import SerializableAttrs, JSON
|
||||
from mautrix.types import Serializable, JSON
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Long(SerializableAttrs):
|
||||
high: int
|
||||
low: int
|
||||
unsigned: bool
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, raw: bytes) -> "Long":
|
||||
return cls(**bson.loads(raw))
|
||||
|
||||
@classmethod
|
||||
def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]:
|
||||
return cls(**bson.loads(raw)) if raw is not None else None
|
||||
|
||||
@classmethod
|
||||
def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]:
|
||||
return bytes(value) if value is not None else None
|
||||
|
||||
class Long(int, Serializable):
|
||||
def serialize(self) -> JSON:
|
||||
data = super().serialize()
|
||||
data["__type__"] = "Long"
|
||||
return data
|
||||
return {"__type__": "Long", "str": str(self)}
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bson.dumps(asdict(self))
|
||||
|
||||
def __int__(self) -> int:
|
||||
if self.unsigned:
|
||||
pass
|
||||
result = \
|
||||
((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \
|
||||
( self.low + (1 << 32 if self.low < 0 else 0))
|
||||
return result + (1 << 32 if self.unsigned and result < 0 else 0)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(int(self))
|
||||
|
||||
ZERO: ClassVar["Long"]
|
||||
|
||||
Long.ZERO = Long(0, 0, False)
|
||||
|
||||
|
||||
class IntLong(Long):
|
||||
"""Helper class for constructing a Long from an int."""
|
||||
def __init__(self, val: int):
|
||||
if val < 0:
|
||||
pass
|
||||
super().__init__(
|
||||
high=(val & 0xffffffff00000000) >> 32,
|
||||
low = val & 0x00000000ffffffff,
|
||||
unsigned=val < 0,
|
||||
)
|
||||
|
||||
|
||||
class StrLong(IntLong):
|
||||
"""Helper class for constructing a Long from the string representation of an int."""
|
||||
def __init__(self, val: str):
|
||||
super().__init__(int(val))
|
||||
@classmethod
|
||||
def deserialize(cls, raw: JSON) -> "Long":
|
||||
assert isinstance(raw, str), f"Long deserialization expected a string, but got non-string value {raw}"
|
||||
return cls(raw)
|
||||
|
@ -48,7 +48,7 @@ from .db import (
|
||||
)
|
||||
from .formatter import kakaotalk_to_matrix, matrix_to_kakaotalk
|
||||
|
||||
from .kt.types.bson import Long, IntLong
|
||||
from .kt.types.bson import Long
|
||||
from .kt.types.channel.channel_info import ChannelInfo
|
||||
from .kt.types.channel.channel_type import KnownChannelType, ChannelType
|
||||
from .kt.types.chat.chat import Chatlog
|
||||
@ -87,7 +87,7 @@ StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STAT
|
||||
class Portal(DBPortal, BasePortal):
|
||||
invite_own_puppet_to_pm: bool = False
|
||||
by_mxid: dict[RoomID, Portal] = {}
|
||||
by_ktid: dict[tuple[Long, Long], Portal] = {}
|
||||
by_ktid: dict[tuple[int, int], Portal] = {}
|
||||
matrix: m.MatrixHandler
|
||||
config: Config
|
||||
|
||||
@ -162,7 +162,7 @@ class Portal(DBPortal, BasePortal):
|
||||
async def delete(self) -> None:
|
||||
if self.mxid:
|
||||
await DBMessage.delete_all_by_room(self.mxid)
|
||||
self.by_ktid.pop(self.ktid_full, None)
|
||||
self.by_ktid.pop(self._ktid_full, None)
|
||||
self.by_mxid.pop(self.mxid, None)
|
||||
await super().delete()
|
||||
|
||||
@ -170,7 +170,7 @@ class Portal(DBPortal, BasePortal):
|
||||
# region Properties
|
||||
|
||||
@property
|
||||
def ktid_full(self) -> tuple[Long, Long]:
|
||||
def _ktid_full(self) -> tuple[int, int]:
|
||||
return self.ktid, self.kt_receiver
|
||||
|
||||
@property
|
||||
@ -617,7 +617,7 @@ class Portal(DBPortal, BasePortal):
|
||||
# endregion
|
||||
# region Matrix event handling
|
||||
|
||||
def require_send_lock(self, user_id: Long) -> asyncio.Lock:
|
||||
def require_send_lock(self, user_id: int) -> asyncio.Lock:
|
||||
try:
|
||||
lock = self._send_locks[user_id]
|
||||
except KeyError:
|
||||
@ -625,7 +625,7 @@ class Portal(DBPortal, BasePortal):
|
||||
self._send_locks[user_id] = lock
|
||||
return lock
|
||||
|
||||
def optional_send_lock(self, user_id: Long) -> asyncio.Lock | FakeLock:
|
||||
def optional_send_lock(self, user_id: int) -> asyncio.Lock | FakeLock:
|
||||
try:
|
||||
return self._send_locks[user_id]
|
||||
except KeyError:
|
||||
@ -937,8 +937,8 @@ class Portal(DBPortal, BasePortal):
|
||||
after_log_id: Long | None,
|
||||
channel_info: ChannelInfo,
|
||||
) -> None:
|
||||
self.log.debug("Backfilling history through %s", source.mxid)
|
||||
self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid))
|
||||
self.log.debug(f"Backfilling history through {source.mxid}")
|
||||
self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}")
|
||||
messages = await source.client.get_chats(
|
||||
channel_info.channelId,
|
||||
limit,
|
||||
@ -961,7 +961,7 @@ class Portal(DBPortal, BasePortal):
|
||||
# region Database getters
|
||||
|
||||
async def postinit(self) -> None:
|
||||
self.by_ktid[self.ktid_full] = self
|
||||
self.by_ktid[self._ktid_full] = self
|
||||
if self.mxid:
|
||||
self.by_mxid[self.mxid] = self
|
||||
self._main_intent = (
|
||||
@ -989,14 +989,14 @@ class Portal(DBPortal, BasePortal):
|
||||
@async_getter_lock
|
||||
async def get_by_ktid(
|
||||
cls,
|
||||
ktid: Long,
|
||||
ktid: int,
|
||||
*,
|
||||
kt_receiver: Long = Long.ZERO,
|
||||
kt_receiver: int = 0,
|
||||
create: bool = True,
|
||||
kt_type: ChannelType | None = None,
|
||||
) -> Portal | None:
|
||||
if kt_type:
|
||||
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else Long.ZERO
|
||||
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0
|
||||
ktid_full = (ktid, kt_receiver)
|
||||
try:
|
||||
return cls.by_ktid[ktid_full]
|
||||
@ -1017,7 +1017,7 @@ class Portal(DBPortal, BasePortal):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def get_all_by_receiver(cls, kt_receiver: Long) -> AsyncGenerator[Portal, None]:
|
||||
async def get_all_by_receiver(cls, kt_receiver: int) -> AsyncGenerator[Portal, None]:
|
||||
portals = await super().get_all_by_receiver(kt_receiver)
|
||||
portal: Portal
|
||||
for portal in portals:
|
||||
|
@ -31,7 +31,7 @@ from . import matrix as m, portal as p, user as u
|
||||
from .config import Config
|
||||
from .db import Puppet as DBPuppet
|
||||
|
||||
from .kt.types.bson import Long, StrLong
|
||||
from .kt.types.bson import Long
|
||||
|
||||
from .kt.client.types import UserInfoUnion
|
||||
|
||||
@ -43,9 +43,9 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
mx: m.MatrixHandler
|
||||
config: Config
|
||||
hs_domain: str
|
||||
mxid_template: SimpleTemplate[StrLong]
|
||||
mxid_template: SimpleTemplate[int]
|
||||
|
||||
by_ktid: dict[Long, Puppet] = {}
|
||||
by_ktid: dict[int, Puppet] = {}
|
||||
by_custom_mxid: dict[UserID, Puppet] = {}
|
||||
|
||||
_last_info_sync: datetime | None
|
||||
@ -127,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
keyword="userid",
|
||||
prefix="@",
|
||||
suffix=f":{Puppet.hs_domain}",
|
||||
type=StrLong,
|
||||
type=int,
|
||||
)
|
||||
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
|
||||
cls.homeserver_url_map = {
|
||||
@ -218,9 +218,9 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get_by_ktid(cls, ktid: Long, *, create: bool = True) -> Puppet | None:
|
||||
async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Puppet | None:
|
||||
try:
|
||||
return cls.by_ktid[ktid]
|
||||
return cls.by_ktid[int]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@ -260,11 +260,11 @@ class Puppet(DBPuppet, BasePuppet):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_id_from_mxid(cls, mxid: UserID) -> Long | None:
|
||||
def get_id_from_mxid(cls, mxid: UserID) -> int | None:
|
||||
return cls.mxid_template.parse(mxid)
|
||||
|
||||
@classmethod
|
||||
def get_mxid_from_id(cls, ktid: Long) -> UserID:
|
||||
def get_mxid_from_id(cls, ktid: int) -> UserID:
|
||||
return UserID(cls.mxid_template.format_full(ktid))
|
||||
|
||||
@classmethod
|
||||
|
@ -84,7 +84,7 @@ class User(DBUser, BaseUser):
|
||||
config: Config
|
||||
|
||||
by_mxid: dict[UserID, User] = {}
|
||||
by_ktid: dict[Long, User] = {}
|
||||
by_ktid: dict[int, User] = {}
|
||||
|
||||
client: Client | None
|
||||
|
||||
@ -217,7 +217,7 @@ class User(DBUser, BaseUser):
|
||||
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get_by_ktid(cls, ktid: Long) -> User | None:
|
||||
async def get_by_ktid(cls, ktid: int) -> User | None:
|
||||
try:
|
||||
return cls.by_ktid[ktid]
|
||||
except KeyError:
|
||||
|
@ -168,10 +168,9 @@ export default class PeerClient {
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
#write(data) {
|
||||
return promisify(cb => this.socket.write(JSON.stringify(data) + "\n", cb))
|
||||
return promisify(cb => this.socket.write(JSON.stringify(data, this.#writeReplacer) + "\n", cb))
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @param {Object} req
|
||||
* @param {string} req.passcode
|
||||
@ -414,7 +413,7 @@ export default class PeerClient {
|
||||
}
|
||||
let req
|
||||
try {
|
||||
req = JSON.parse(line)
|
||||
req = JSON.parse(line, this.#readReviver)
|
||||
} catch (err) {
|
||||
this.log("Non-JSON request:", line)
|
||||
return
|
||||
@ -465,7 +464,6 @@ export default class PeerClient {
|
||||
const resp = { id: req.id }
|
||||
delete req.id
|
||||
delete req.command
|
||||
req = typeify(req)
|
||||
resp.command = "response"
|
||||
try {
|
||||
resp.response = await handler(req)
|
||||
@ -483,29 +481,22 @@ export default class PeerClient {
|
||||
}
|
||||
await this.#write(resp)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively scan an object to check if any of its sub-objects
|
||||
* should be converted into instances of a specified class.
|
||||
* @param obj The object to be scanned & updated.
|
||||
* @returns The converted object.
|
||||
*/
|
||||
function typeify(obj) {
|
||||
if (!(obj instanceof Object)) {
|
||||
return obj
|
||||
#writeReplacer = function(key, value) {
|
||||
if (value instanceof Long) {
|
||||
return value.toString()
|
||||
} else {
|
||||
return value
|
||||
}
|
||||
}
|
||||
const converterFunc = TYPE_MAP.get(obj.__type__)
|
||||
if (converterFunc !== undefined) {
|
||||
return converterFunc(obj)
|
||||
}
|
||||
for (const key in obj) {
|
||||
obj[key] = typeify(obj[key])
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// TODO Add more if needed
|
||||
const TYPE_MAP = new Map([
|
||||
["Long", (obj) => new Long(obj.low, obj.high, obj.unsigned)],
|
||||
])
|
||||
#readReviver = function(key, value) {
|
||||
if (value instanceof Object) {
|
||||
// TODO Use a type map if there will be many possible types
|
||||
if (value.__type__ == "Long") {
|
||||
return Long.fromString(value.str)
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user