Clean up Long
Keep it as an int in Python, and do all fancy conversions in Node
This commit is contained in:
parent
2a91a7b43e
commit
4158788496
@ -18,12 +18,12 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, ClassVar
|
from typing import TYPE_CHECKING, ClassVar
|
||||||
|
|
||||||
from asyncpg import Record
|
from asyncpg import Record
|
||||||
from attr import dataclass
|
from attr import dataclass, field
|
||||||
|
|
||||||
from mautrix.types import EventID, RoomID
|
from mautrix.types import EventID, RoomID
|
||||||
from mautrix.util.async_db import Database
|
from mautrix.util.async_db import Database
|
||||||
|
|
||||||
from ..kt.types.bson import Long, StrLong
|
from ..kt.types.bson import Long
|
||||||
|
|
||||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||||
|
|
||||||
@ -32,27 +32,17 @@ fake_db = Database.create("") if TYPE_CHECKING else None
|
|||||||
class Message:
|
class Message:
|
||||||
db: ClassVar[Database] = fake_db
|
db: ClassVar[Database] = fake_db
|
||||||
|
|
||||||
# TODO Store all Long values as the same type
|
|
||||||
mxid: EventID
|
mxid: EventID
|
||||||
mx_room: RoomID
|
mx_room: RoomID
|
||||||
ktid: Long
|
ktid: Long = field(converter=Long)
|
||||||
index: int
|
index: int
|
||||||
kt_chat: Long
|
kt_chat: Long = field(converter=Long)
|
||||||
kt_receiver: Long
|
kt_receiver: Long = field(converter=Long)
|
||||||
timestamp: int
|
timestamp: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_row(cls, row: Record) -> Message | None:
|
def _from_row(cls, row: Record) -> Message:
|
||||||
data = {**row}
|
return cls(**row)
|
||||||
ktid = data.pop("ktid")
|
|
||||||
kt_chat = data.pop("kt_chat")
|
|
||||||
kt_receiver = data.pop("kt_receiver")
|
|
||||||
return cls(
|
|
||||||
**data,
|
|
||||||
ktid=StrLong(ktid),
|
|
||||||
kt_chat=Long.from_bytes(kt_chat),
|
|
||||||
kt_receiver=Long.from_bytes(kt_receiver)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_optional_row(cls, row: Record | None) -> Message | None:
|
def _from_optional_row(cls, row: Record | None) -> Message | None:
|
||||||
@ -61,15 +51,15 @@ class Message:
|
|||||||
columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
|
columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
|
async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]:
|
||||||
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
|
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
|
||||||
rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
|
rows = await cls.db.fetch(q, ktid, kt_receiver)
|
||||||
return [cls._from_row(row) for row in rows if row]
|
return [cls._from_row(row) for row in rows if row]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
|
async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None:
|
||||||
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
|
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
|
||||||
row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
|
row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
|
||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -83,18 +73,18 @@ class Message:
|
|||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
|
async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
|
||||||
q = (
|
q = (
|
||||||
f"SELECT {cls.columns} "
|
f"SELECT {cls.columns} "
|
||||||
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
|
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
|
||||||
"ORDER BY timestamp DESC LIMIT 1"
|
"ORDER BY timestamp DESC LIMIT 1"
|
||||||
)
|
)
|
||||||
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
|
row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
|
||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_closest_before(
|
async def get_closest_before(
|
||||||
cls, kt_chat: Long, kt_receiver: Long, timestamp: int
|
cls, kt_chat: int, kt_receiver: int, timestamp: int
|
||||||
) -> Message | None:
|
) -> Message | None:
|
||||||
q = (
|
q = (
|
||||||
f"SELECT {cls.columns} "
|
f"SELECT {cls.columns} "
|
||||||
@ -102,7 +92,7 @@ class Message:
|
|||||||
" ktid IS NOT NULL "
|
" ktid IS NOT NULL "
|
||||||
"ORDER BY timestamp DESC LIMIT 1"
|
"ORDER BY timestamp DESC LIMIT 1"
|
||||||
)
|
)
|
||||||
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
|
row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp)
|
||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
_insert_query = (
|
_insert_query = (
|
||||||
@ -111,46 +101,23 @@ class Message:
|
|||||||
"VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
"VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def bulk_create(
|
|
||||||
cls,
|
|
||||||
ktid: Long,
|
|
||||||
kt_chat: Long,
|
|
||||||
kt_receiver: Long,
|
|
||||||
event_ids: list[EventID],
|
|
||||||
timestamp: int,
|
|
||||||
mx_room: RoomID,
|
|
||||||
) -> None:
|
|
||||||
if not event_ids:
|
|
||||||
return
|
|
||||||
columns = [col.strip('"') for col in cls.columns.split(", ")]
|
|
||||||
records = [
|
|
||||||
(mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp)
|
|
||||||
for index, mxid in enumerate(event_ids)
|
|
||||||
]
|
|
||||||
async with cls.db.acquire() as conn, conn.transaction():
|
|
||||||
if cls.db.scheme == "postgres":
|
|
||||||
await conn.copy_records_to_table("message", records=records, columns=columns)
|
|
||||||
else:
|
|
||||||
await conn.executemany(cls._insert_query, records)
|
|
||||||
|
|
||||||
async def insert(self) -> None:
|
async def insert(self) -> None:
|
||||||
q = self._insert_query
|
q = self._insert_query
|
||||||
await self.db.execute(
|
await self.db.execute(
|
||||||
q,
|
q,
|
||||||
self.mxid,
|
self.mxid,
|
||||||
self.mx_room,
|
self.mx_room,
|
||||||
str(self.ktid),
|
self.ktid,
|
||||||
self.index,
|
self.index,
|
||||||
bytes(self.kt_chat),
|
self.kt_chat,
|
||||||
bytes(self.kt_receiver),
|
self.kt_receiver,
|
||||||
self.timestamp,
|
self.timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
|
q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
|
||||||
await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index)
|
await self.db.execute(q, self.ktid, self.kt_receiver, self.index)
|
||||||
|
|
||||||
async def update(self) -> None:
|
async def update(self) -> None:
|
||||||
q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
|
q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
|
||||||
await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room)
|
await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room)
|
||||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, ClassVar
|
from typing import TYPE_CHECKING, ClassVar
|
||||||
|
|
||||||
from asyncpg import Record
|
from asyncpg import Record
|
||||||
from attr import dataclass
|
from attr import dataclass, field
|
||||||
|
|
||||||
from mautrix.types import ContentURI, RoomID, UserID
|
from mautrix.types import ContentURI, RoomID, UserID
|
||||||
from mautrix.util.async_db import Database
|
from mautrix.util.async_db import Database
|
||||||
@ -33,8 +33,8 @@ fake_db = Database.create("") if TYPE_CHECKING else None
|
|||||||
class Portal:
|
class Portal:
|
||||||
db: ClassVar[Database] = fake_db
|
db: ClassVar[Database] = fake_db
|
||||||
|
|
||||||
ktid: Long
|
ktid: Long = field(converter=Long)
|
||||||
kt_receiver: Long
|
kt_receiver: Long = field(converter=Long)
|
||||||
kt_type: ChannelType
|
kt_type: ChannelType
|
||||||
mxid: RoomID | None
|
mxid: RoomID | None
|
||||||
name: str | None
|
name: str | None
|
||||||
@ -47,23 +47,20 @@ class Portal:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_row(cls, row: Record) -> Portal:
|
def _from_row(cls, row: Record) -> Portal:
|
||||||
data = {**row}
|
return cls(**row)
|
||||||
ktid = data.pop("ktid")
|
|
||||||
kt_receiver = data.pop("kt_receiver")
|
|
||||||
return cls(**data, ktid=Long.from_bytes(ktid), kt_receiver=Long.from_bytes(kt_receiver))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_optional_row(cls, row: Record | None) -> Portal | None:
|
def _from_optional_row(cls, row: Record | None) -> Portal | None:
|
||||||
return cls._from_row(row) if row is not None else None
|
return cls._from_row(row) if row is not None else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long) -> Portal | None:
|
async def get_by_ktid(cls, ktid: int, kt_receiver: int) -> Portal | None:
|
||||||
q = """
|
q = """
|
||||||
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
|
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
|
||||||
name_set, avatar_set, relay_user_id
|
name_set, avatar_set, relay_user_id
|
||||||
FROM portal WHERE ktid=$1 AND kt_receiver=$2
|
FROM portal WHERE ktid=$1 AND kt_receiver=$2
|
||||||
"""
|
"""
|
||||||
row = await cls.db.fetchrow(q, bytes(ktid), bytes(kt_receiver))
|
row = await cls.db.fetchrow(q, ktid, kt_receiver)
|
||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -77,13 +74,13 @@ class Portal:
|
|||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_all_by_receiver(cls, kt_receiver: Long) -> list[Portal]:
|
async def get_all_by_receiver(cls, kt_receiver: int) -> list[Portal]:
|
||||||
q = """
|
q = """
|
||||||
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
|
SELECT ktid, kt_receiver, kt_type, mxid, name, photo_id, avatar_url, encrypted,
|
||||||
name_set, avatar_set, relay_user_id
|
name_set, avatar_set, relay_user_id
|
||||||
FROM portal WHERE kt_receiver=$1
|
FROM portal WHERE kt_receiver=$1
|
||||||
"""
|
"""
|
||||||
rows = await cls.db.fetch(q, bytes(kt_receiver))
|
rows = await cls.db.fetch(q, kt_receiver)
|
||||||
return [cls._from_row(row) for row in rows if row]
|
return [cls._from_row(row) for row in rows if row]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -99,8 +96,8 @@ class Portal:
|
|||||||
@property
|
@property
|
||||||
def _values(self):
|
def _values(self):
|
||||||
return (
|
return (
|
||||||
Long.to_optional_bytes(self.ktid),
|
self.ktid,
|
||||||
Long.to_optional_bytes(self.kt_receiver),
|
self.kt_receiver,
|
||||||
self.kt_type,
|
self.kt_type,
|
||||||
self.mxid,
|
self.mxid,
|
||||||
self.name,
|
self.name,
|
||||||
@ -122,7 +119,7 @@ class Portal:
|
|||||||
|
|
||||||
async def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2"
|
q = "DELETE FROM portal WHERE ktid=$1 AND kt_receiver=$2"
|
||||||
await self.db.execute(q, Long.to_optional_bytes(self.ktid), Long.to_optional_bytes(self.kt_receiver))
|
await self.db.execute(q, self.ktid, self.kt_receiver)
|
||||||
|
|
||||||
async def save(self) -> None:
|
async def save(self) -> None:
|
||||||
q = """
|
q = """
|
||||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, ClassVar
|
from typing import TYPE_CHECKING, ClassVar
|
||||||
|
|
||||||
from asyncpg import Record
|
from asyncpg import Record
|
||||||
from attr import dataclass
|
from attr import dataclass, field
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from mautrix.types import ContentURI, SyncToken, UserID
|
from mautrix.types import ContentURI, SyncToken, UserID
|
||||||
@ -33,7 +33,7 @@ fake_db = Database.create("") if TYPE_CHECKING else None
|
|||||||
class Puppet:
|
class Puppet:
|
||||||
db: ClassVar[Database] = fake_db
|
db: ClassVar[Database] = fake_db
|
||||||
|
|
||||||
ktid: Long
|
ktid: Long = field(converter=Long)
|
||||||
name: str | None
|
name: str | None
|
||||||
photo_id: str | None
|
photo_id: str | None
|
||||||
photo_mxc: ContentURI | None
|
photo_mxc: ContentURI | None
|
||||||
@ -49,22 +49,21 @@ class Puppet:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _from_row(cls, row: Record) -> Puppet:
|
def _from_row(cls, row: Record) -> Puppet:
|
||||||
data = {**row}
|
data = {**row}
|
||||||
ktid = data.pop("ktid")
|
|
||||||
base_url = data.pop("base_url", None)
|
base_url = data.pop("base_url", None)
|
||||||
return cls(**data, ktid=Long.from_optional_bytes(ktid), base_url=URL(base_url) if base_url else None)
|
return cls(**data, base_url=URL(base_url) if base_url else None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_optional_row(cls, row: Record | None) -> Puppet | None:
|
def _from_optional_row(cls, row: Record | None) -> Puppet | None:
|
||||||
return cls._from_row(row) if row is not None else None
|
return cls._from_row(row) if row is not None else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_ktid(cls, ktid: Long) -> Puppet | None:
|
async def get_by_ktid(cls, ktid: int) -> Puppet | None:
|
||||||
q = (
|
q = (
|
||||||
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
|
"SELECT ktid, name, photo_id, photo_mxc, name_set, avatar_set, is_registered, "
|
||||||
" custom_mxid, access_token, next_batch, base_url "
|
" custom_mxid, access_token, next_batch, base_url "
|
||||||
"FROM puppet WHERE ktid=$1"
|
"FROM puppet WHERE ktid=$1"
|
||||||
)
|
)
|
||||||
row = await cls.db.fetchrow(q, bytes(ktid))
|
row = await cls.db.fetchrow(q, ktid)
|
||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -100,7 +99,7 @@ class Puppet:
|
|||||||
@property
|
@property
|
||||||
def _values(self):
|
def _values(self):
|
||||||
return (
|
return (
|
||||||
bytes(self.ktid),
|
self.ktid,
|
||||||
self.name,
|
self.name,
|
||||||
self.photo_id,
|
self.photo_id,
|
||||||
self.photo_mxc,
|
self.photo_mxc,
|
||||||
@ -123,7 +122,7 @@ class Puppet:
|
|||||||
|
|
||||||
async def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
q = "DELETE FROM puppet WHERE ktid=$1"
|
q = "DELETE FROM puppet WHERE ktid=$1"
|
||||||
await self.db.execute(q, bytes(self.ktid))
|
await self.db.execute(q, self.ktid)
|
||||||
|
|
||||||
async def save(self) -> None:
|
async def save(self) -> None:
|
||||||
q = """
|
q = """
|
||||||
|
@ -30,7 +30,7 @@ async def create_v1_tables(conn: Connection) -> None:
|
|||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""CREATE TABLE "user" (
|
"""CREATE TABLE "user" (
|
||||||
mxid TEXT PRIMARY KEY,
|
mxid TEXT PRIMARY KEY,
|
||||||
ktid BYTES UNIQUE,
|
ktid BIGINT UNIQUE,
|
||||||
uuid TEXT,
|
uuid TEXT,
|
||||||
access_token TEXT,
|
access_token TEXT,
|
||||||
refresh_token TEXT,
|
refresh_token TEXT,
|
||||||
@ -39,8 +39,8 @@ async def create_v1_tables(conn: Connection) -> None:
|
|||||||
)
|
)
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""CREATE TABLE portal (
|
"""CREATE TABLE portal (
|
||||||
ktid BYTES,
|
ktid BIGINT,
|
||||||
kt_receiver BYTES,
|
kt_receiver BIGINT,
|
||||||
kt_type TEXT,
|
kt_type TEXT,
|
||||||
mxid TEXT UNIQUE,
|
mxid TEXT UNIQUE,
|
||||||
name TEXT,
|
name TEXT,
|
||||||
@ -55,7 +55,7 @@ async def create_v1_tables(conn: Connection) -> None:
|
|||||||
)
|
)
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""CREATE TABLE puppet (
|
"""CREATE TABLE puppet (
|
||||||
ktid BYTES PRIMARY KEY,
|
ktid BIGINT PRIMARY KEY,
|
||||||
name TEXT,
|
name TEXT,
|
||||||
photo_id TEXT,
|
photo_id TEXT,
|
||||||
photo_mxc TEXT,
|
photo_mxc TEXT,
|
||||||
@ -74,10 +74,10 @@ async def create_v1_tables(conn: Connection) -> None:
|
|||||||
"""CREATE TABLE message (
|
"""CREATE TABLE message (
|
||||||
mxid TEXT,
|
mxid TEXT,
|
||||||
mx_room TEXT,
|
mx_room TEXT,
|
||||||
ktid TEXT,
|
ktid BIGINT,
|
||||||
kt_receiver BYTES,
|
kt_receiver BIGINT,
|
||||||
"index" SMALLINT,
|
"index" SMALLINT,
|
||||||
kt_chat BYTES,
|
kt_chat BIGINT,
|
||||||
timestamp BIGINT,
|
timestamp BIGINT,
|
||||||
PRIMARY KEY (ktid, kt_receiver, "index"),
|
PRIMARY KEY (ktid, kt_receiver, "index"),
|
||||||
FOREIGN KEY (kt_chat, kt_receiver) REFERENCES portal(ktid, kt_receiver)
|
FOREIGN KEY (kt_chat, kt_receiver) REFERENCES portal(ktid, kt_receiver)
|
||||||
@ -90,8 +90,8 @@ async def create_v1_tables(conn: Connection) -> None:
|
|||||||
mxid TEXT,
|
mxid TEXT,
|
||||||
mx_room TEXT,
|
mx_room TEXT,
|
||||||
kt_msgid TEXT,
|
kt_msgid TEXT,
|
||||||
kt_receiver BYTES,
|
kt_receiver BIGINT,
|
||||||
kt_sender BYTES,
|
kt_sender BIGINT,
|
||||||
reaction TEXT,
|
reaction TEXT,
|
||||||
PRIMARY KEY (kt_msgid, kt_receiver, kt_sender),
|
PRIMARY KEY (kt_msgid, kt_receiver, kt_sender),
|
||||||
UNIQUE (mxid, mx_room)
|
UNIQUE (mxid, mx_room)
|
||||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, ClassVar, List, Set
|
from typing import TYPE_CHECKING, ClassVar, List, Set
|
||||||
|
|
||||||
from asyncpg import Record
|
from asyncpg import Record
|
||||||
from attr import dataclass
|
from attr import dataclass, field
|
||||||
|
|
||||||
from mautrix.types import RoomID, UserID
|
from mautrix.types import RoomID, UserID
|
||||||
from mautrix.util.async_db import Database
|
from mautrix.util.async_db import Database
|
||||||
@ -33,7 +33,7 @@ class User:
|
|||||||
db: ClassVar[Database] = fake_db
|
db: ClassVar[Database] = fake_db
|
||||||
|
|
||||||
mxid: UserID
|
mxid: UserID
|
||||||
ktid: Long | None
|
ktid: Long | None = field(converter=lambda x: Long(x) if x is not None else None)
|
||||||
uuid: str | None
|
uuid: str | None
|
||||||
access_token: str | None
|
access_token: str | None
|
||||||
refresh_token: str | None
|
refresh_token: str | None
|
||||||
@ -41,9 +41,7 @@ class User:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_row(cls, row: Record) -> User:
|
def _from_row(cls, row: Record) -> User:
|
||||||
data = {**row}
|
return cls(**row)
|
||||||
ktid = data.pop("ktid", None)
|
|
||||||
return cls(**data, ktid=Long.from_optional_bytes(ktid))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_optional_row(cls, row: Record | None) -> User | None:
|
def _from_optional_row(cls, row: Record | None) -> User | None:
|
||||||
@ -59,9 +57,9 @@ class User:
|
|||||||
return [cls._from_row(row) for row in rows if row]
|
return [cls._from_row(row) for row in rows if row]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_ktid(cls, ktid: Long) -> User | None:
|
async def get_by_ktid(cls, ktid: int) -> User | None:
|
||||||
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1'
|
q = 'SELECT mxid, ktid, uuid, access_token, refresh_token, notice_room FROM "user" WHERE ktid=$1'
|
||||||
row = await cls.db.fetchrow(q, bytes(ktid))
|
row = await cls.db.fetchrow(q, ktid)
|
||||||
return cls._from_optional_row(row)
|
return cls._from_optional_row(row)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -81,7 +79,7 @@ class User:
|
|||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
"""
|
"""
|
||||||
await self.db.execute(
|
await self.db.execute(
|
||||||
q, self.mxid, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room
|
q, self.mxid, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
@ -93,5 +91,5 @@ class User:
|
|||||||
WHERE mxid=$6
|
WHERE mxid=$6
|
||||||
"""
|
"""
|
||||||
await self.db.execute(
|
await self.db.execute(
|
||||||
q, Long.to_optional_bytes(self.ktid), self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
|
q, self.ktid, self.uuid, self.access_token, self.refresh_token, self.notice_room, self.mxid
|
||||||
)
|
)
|
||||||
|
@ -13,69 +13,14 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import ClassVar, Optional
|
from mautrix.types import Serializable, JSON
|
||||||
|
|
||||||
from attr import dataclass, asdict
|
|
||||||
import bson
|
|
||||||
|
|
||||||
from mautrix.types import SerializableAttrs, JSON
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
class Long(int, Serializable):
|
||||||
class Long(SerializableAttrs):
|
|
||||||
high: int
|
|
||||||
low: int
|
|
||||||
unsigned: bool
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_bytes(cls, raw: bytes) -> "Long":
|
|
||||||
return cls(**bson.loads(raw))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]:
|
|
||||||
return cls(**bson.loads(raw)) if raw is not None else None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]:
|
|
||||||
return bytes(value) if value is not None else None
|
|
||||||
|
|
||||||
def serialize(self) -> JSON:
|
def serialize(self) -> JSON:
|
||||||
data = super().serialize()
|
return {"__type__": "Long", "str": str(self)}
|
||||||
data["__type__"] = "Long"
|
|
||||||
return data
|
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
@classmethod
|
||||||
return bson.dumps(asdict(self))
|
def deserialize(cls, raw: JSON) -> "Long":
|
||||||
|
assert isinstance(raw, str), f"Long deserialization expected a string, but got non-string value {raw}"
|
||||||
def __int__(self) -> int:
|
return cls(raw)
|
||||||
if self.unsigned:
|
|
||||||
pass
|
|
||||||
result = \
|
|
||||||
((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \
|
|
||||||
( self.low + (1 << 32 if self.low < 0 else 0))
|
|
||||||
return result + (1 << 32 if self.unsigned and result < 0 else 0)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return str(int(self))
|
|
||||||
|
|
||||||
ZERO: ClassVar["Long"]
|
|
||||||
|
|
||||||
Long.ZERO = Long(0, 0, False)
|
|
||||||
|
|
||||||
|
|
||||||
class IntLong(Long):
|
|
||||||
"""Helper class for constructing a Long from an int."""
|
|
||||||
def __init__(self, val: int):
|
|
||||||
if val < 0:
|
|
||||||
pass
|
|
||||||
super().__init__(
|
|
||||||
high=(val & 0xffffffff00000000) >> 32,
|
|
||||||
low = val & 0x00000000ffffffff,
|
|
||||||
unsigned=val < 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StrLong(IntLong):
|
|
||||||
"""Helper class for constructing a Long from the string representation of an int."""
|
|
||||||
def __init__(self, val: str):
|
|
||||||
super().__init__(int(val))
|
|
||||||
|
@ -48,7 +48,7 @@ from .db import (
|
|||||||
)
|
)
|
||||||
from .formatter import kakaotalk_to_matrix, matrix_to_kakaotalk
|
from .formatter import kakaotalk_to_matrix, matrix_to_kakaotalk
|
||||||
|
|
||||||
from .kt.types.bson import Long, IntLong
|
from .kt.types.bson import Long
|
||||||
from .kt.types.channel.channel_info import ChannelInfo
|
from .kt.types.channel.channel_info import ChannelInfo
|
||||||
from .kt.types.channel.channel_type import KnownChannelType, ChannelType
|
from .kt.types.channel.channel_type import KnownChannelType, ChannelType
|
||||||
from .kt.types.chat.chat import Chatlog
|
from .kt.types.chat.chat import Chatlog
|
||||||
@ -87,7 +87,7 @@ StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STAT
|
|||||||
class Portal(DBPortal, BasePortal):
|
class Portal(DBPortal, BasePortal):
|
||||||
invite_own_puppet_to_pm: bool = False
|
invite_own_puppet_to_pm: bool = False
|
||||||
by_mxid: dict[RoomID, Portal] = {}
|
by_mxid: dict[RoomID, Portal] = {}
|
||||||
by_ktid: dict[tuple[Long, Long], Portal] = {}
|
by_ktid: dict[tuple[int, int], Portal] = {}
|
||||||
matrix: m.MatrixHandler
|
matrix: m.MatrixHandler
|
||||||
config: Config
|
config: Config
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ class Portal(DBPortal, BasePortal):
|
|||||||
async def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
if self.mxid:
|
if self.mxid:
|
||||||
await DBMessage.delete_all_by_room(self.mxid)
|
await DBMessage.delete_all_by_room(self.mxid)
|
||||||
self.by_ktid.pop(self.ktid_full, None)
|
self.by_ktid.pop(self._ktid_full, None)
|
||||||
self.by_mxid.pop(self.mxid, None)
|
self.by_mxid.pop(self.mxid, None)
|
||||||
await super().delete()
|
await super().delete()
|
||||||
|
|
||||||
@ -170,7 +170,7 @@ class Portal(DBPortal, BasePortal):
|
|||||||
# region Properties
|
# region Properties
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ktid_full(self) -> tuple[Long, Long]:
|
def _ktid_full(self) -> tuple[int, int]:
|
||||||
return self.ktid, self.kt_receiver
|
return self.ktid, self.kt_receiver
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -617,7 +617,7 @@ class Portal(DBPortal, BasePortal):
|
|||||||
# endregion
|
# endregion
|
||||||
# region Matrix event handling
|
# region Matrix event handling
|
||||||
|
|
||||||
def require_send_lock(self, user_id: Long) -> asyncio.Lock:
|
def require_send_lock(self, user_id: int) -> asyncio.Lock:
|
||||||
try:
|
try:
|
||||||
lock = self._send_locks[user_id]
|
lock = self._send_locks[user_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -625,7 +625,7 @@ class Portal(DBPortal, BasePortal):
|
|||||||
self._send_locks[user_id] = lock
|
self._send_locks[user_id] = lock
|
||||||
return lock
|
return lock
|
||||||
|
|
||||||
def optional_send_lock(self, user_id: Long) -> asyncio.Lock | FakeLock:
|
def optional_send_lock(self, user_id: int) -> asyncio.Lock | FakeLock:
|
||||||
try:
|
try:
|
||||||
return self._send_locks[user_id]
|
return self._send_locks[user_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -937,8 +937,8 @@ class Portal(DBPortal, BasePortal):
|
|||||||
after_log_id: Long | None,
|
after_log_id: Long | None,
|
||||||
channel_info: ChannelInfo,
|
channel_info: ChannelInfo,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.log.debug("Backfilling history through %s", source.mxid)
|
self.log.debug(f"Backfilling history through {source.mxid}")
|
||||||
self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid))
|
self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}")
|
||||||
messages = await source.client.get_chats(
|
messages = await source.client.get_chats(
|
||||||
channel_info.channelId,
|
channel_info.channelId,
|
||||||
limit,
|
limit,
|
||||||
@ -961,7 +961,7 @@ class Portal(DBPortal, BasePortal):
|
|||||||
# region Database getters
|
# region Database getters
|
||||||
|
|
||||||
async def postinit(self) -> None:
|
async def postinit(self) -> None:
|
||||||
self.by_ktid[self.ktid_full] = self
|
self.by_ktid[self._ktid_full] = self
|
||||||
if self.mxid:
|
if self.mxid:
|
||||||
self.by_mxid[self.mxid] = self
|
self.by_mxid[self.mxid] = self
|
||||||
self._main_intent = (
|
self._main_intent = (
|
||||||
@ -989,14 +989,14 @@ class Portal(DBPortal, BasePortal):
|
|||||||
@async_getter_lock
|
@async_getter_lock
|
||||||
async def get_by_ktid(
|
async def get_by_ktid(
|
||||||
cls,
|
cls,
|
||||||
ktid: Long,
|
ktid: int,
|
||||||
*,
|
*,
|
||||||
kt_receiver: Long = Long.ZERO,
|
kt_receiver: int = 0,
|
||||||
create: bool = True,
|
create: bool = True,
|
||||||
kt_type: ChannelType | None = None,
|
kt_type: ChannelType | None = None,
|
||||||
) -> Portal | None:
|
) -> Portal | None:
|
||||||
if kt_type:
|
if kt_type:
|
||||||
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else Long.ZERO
|
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0
|
||||||
ktid_full = (ktid, kt_receiver)
|
ktid_full = (ktid, kt_receiver)
|
||||||
try:
|
try:
|
||||||
return cls.by_ktid[ktid_full]
|
return cls.by_ktid[ktid_full]
|
||||||
@ -1017,7 +1017,7 @@ class Portal(DBPortal, BasePortal):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_all_by_receiver(cls, kt_receiver: Long) -> AsyncGenerator[Portal, None]:
|
async def get_all_by_receiver(cls, kt_receiver: int) -> AsyncGenerator[Portal, None]:
|
||||||
portals = await super().get_all_by_receiver(kt_receiver)
|
portals = await super().get_all_by_receiver(kt_receiver)
|
||||||
portal: Portal
|
portal: Portal
|
||||||
for portal in portals:
|
for portal in portals:
|
||||||
|
@ -31,7 +31,7 @@ from . import matrix as m, portal as p, user as u
|
|||||||
from .config import Config
|
from .config import Config
|
||||||
from .db import Puppet as DBPuppet
|
from .db import Puppet as DBPuppet
|
||||||
|
|
||||||
from .kt.types.bson import Long, StrLong
|
from .kt.types.bson import Long
|
||||||
|
|
||||||
from .kt.client.types import UserInfoUnion
|
from .kt.client.types import UserInfoUnion
|
||||||
|
|
||||||
@ -43,9 +43,9 @@ class Puppet(DBPuppet, BasePuppet):
|
|||||||
mx: m.MatrixHandler
|
mx: m.MatrixHandler
|
||||||
config: Config
|
config: Config
|
||||||
hs_domain: str
|
hs_domain: str
|
||||||
mxid_template: SimpleTemplate[StrLong]
|
mxid_template: SimpleTemplate[int]
|
||||||
|
|
||||||
by_ktid: dict[Long, Puppet] = {}
|
by_ktid: dict[int, Puppet] = {}
|
||||||
by_custom_mxid: dict[UserID, Puppet] = {}
|
by_custom_mxid: dict[UserID, Puppet] = {}
|
||||||
|
|
||||||
_last_info_sync: datetime | None
|
_last_info_sync: datetime | None
|
||||||
@ -127,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet):
|
|||||||
keyword="userid",
|
keyword="userid",
|
||||||
prefix="@",
|
prefix="@",
|
||||||
suffix=f":{Puppet.hs_domain}",
|
suffix=f":{Puppet.hs_domain}",
|
||||||
type=StrLong,
|
type=int,
|
||||||
)
|
)
|
||||||
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
|
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
|
||||||
cls.homeserver_url_map = {
|
cls.homeserver_url_map = {
|
||||||
@ -218,9 +218,9 @@ class Puppet(DBPuppet, BasePuppet):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@async_getter_lock
|
@async_getter_lock
|
||||||
async def get_by_ktid(cls, ktid: Long, *, create: bool = True) -> Puppet | None:
|
async def get_by_ktid(cls, ktid: int, *, create: bool = True) -> Puppet | None:
|
||||||
try:
|
try:
|
||||||
return cls.by_ktid[ktid]
|
return cls.by_ktid[int]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -260,11 +260,11 @@ class Puppet(DBPuppet, BasePuppet):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_id_from_mxid(cls, mxid: UserID) -> Long | None:
|
def get_id_from_mxid(cls, mxid: UserID) -> int | None:
|
||||||
return cls.mxid_template.parse(mxid)
|
return cls.mxid_template.parse(mxid)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_mxid_from_id(cls, ktid: Long) -> UserID:
|
def get_mxid_from_id(cls, ktid: int) -> UserID:
|
||||||
return UserID(cls.mxid_template.format_full(ktid))
|
return UserID(cls.mxid_template.format_full(ktid))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -84,7 +84,7 @@ class User(DBUser, BaseUser):
|
|||||||
config: Config
|
config: Config
|
||||||
|
|
||||||
by_mxid: dict[UserID, User] = {}
|
by_mxid: dict[UserID, User] = {}
|
||||||
by_ktid: dict[Long, User] = {}
|
by_ktid: dict[int, User] = {}
|
||||||
|
|
||||||
client: Client | None
|
client: Client | None
|
||||||
|
|
||||||
@ -217,7 +217,7 @@ class User(DBUser, BaseUser):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@async_getter_lock
|
@async_getter_lock
|
||||||
async def get_by_ktid(cls, ktid: Long) -> User | None:
|
async def get_by_ktid(cls, ktid: int) -> User | None:
|
||||||
try:
|
try:
|
||||||
return cls.by_ktid[ktid]
|
return cls.by_ktid[ktid]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -168,10 +168,9 @@ export default class PeerClient {
|
|||||||
* @returns {Promise<void>}
|
* @returns {Promise<void>}
|
||||||
*/
|
*/
|
||||||
#write(data) {
|
#write(data) {
|
||||||
return promisify(cb => this.socket.write(JSON.stringify(data) + "\n", cb))
|
return promisify(cb => this.socket.write(JSON.stringify(data, this.#writeReplacer) + "\n", cb))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {Object} req
|
* @param {Object} req
|
||||||
* @param {string} req.passcode
|
* @param {string} req.passcode
|
||||||
@ -414,7 +413,7 @@ export default class PeerClient {
|
|||||||
}
|
}
|
||||||
let req
|
let req
|
||||||
try {
|
try {
|
||||||
req = JSON.parse(line)
|
req = JSON.parse(line, this.#readReviver)
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
this.log("Non-JSON request:", line)
|
this.log("Non-JSON request:", line)
|
||||||
return
|
return
|
||||||
@ -465,7 +464,6 @@ export default class PeerClient {
|
|||||||
const resp = { id: req.id }
|
const resp = { id: req.id }
|
||||||
delete req.id
|
delete req.id
|
||||||
delete req.command
|
delete req.command
|
||||||
req = typeify(req)
|
|
||||||
resp.command = "response"
|
resp.command = "response"
|
||||||
try {
|
try {
|
||||||
resp.response = await handler(req)
|
resp.response = await handler(req)
|
||||||
@ -483,29 +481,22 @@ export default class PeerClient {
|
|||||||
}
|
}
|
||||||
await this.#write(resp)
|
await this.#write(resp)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
#writeReplacer = function(key, value) {
|
||||||
* Recursively scan an object to check if any of its sub-objects
|
if (value instanceof Long) {
|
||||||
* should be converted into instances of a specified class.
|
return value.toString()
|
||||||
* @param obj The object to be scanned & updated.
|
} else {
|
||||||
* @returns The converted object.
|
return value
|
||||||
*/
|
}
|
||||||
function typeify(obj) {
|
|
||||||
if (!(obj instanceof Object)) {
|
|
||||||
return obj
|
|
||||||
}
|
}
|
||||||
const converterFunc = TYPE_MAP.get(obj.__type__)
|
|
||||||
if (converterFunc !== undefined) {
|
|
||||||
return converterFunc(obj)
|
|
||||||
}
|
|
||||||
for (const key in obj) {
|
|
||||||
obj[key] = typeify(obj[key])
|
|
||||||
}
|
|
||||||
return obj
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO Add more if needed
|
#readReviver = function(key, value) {
|
||||||
const TYPE_MAP = new Map([
|
if (value instanceof Object) {
|
||||||
["Long", (obj) => new Long(obj.low, obj.high, obj.unsigned)],
|
// TODO Use a type map if there will be many possible types
|
||||||
])
|
if (value.__type__ == "Long") {
|
||||||
|
return Long.fromString(value.str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user