Puppets and backfilling

parent: e28694c987
commit: 0c9550841c
@@ -99,3 +99,15 @@ async def ping(evt: CommandEvent) -> None:
 async def refresh(evt: CommandEvent) -> None:
     await evt.sender.refresh(force_notice=True)
 """
+
+
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_CONNECTION,
+    help_text="Resync chats",
+)
+async def sync(evt: CommandEvent) -> None:
+    await evt.mark_read()
+    await evt.sender.post_login(is_startup=False)
+    await evt.reply("Sync complete")
@@ -23,6 +23,8 @@ from attr import dataclass
 from mautrix.types import EventID, RoomID
 from mautrix.util.async_db import Database
 
+from ..kt.types.bson import Long, StrLong
+
 fake_db = Database.create("") if TYPE_CHECKING else None
 
 
@@ -30,46 +32,44 @@ fake_db = Database.create("") if TYPE_CHECKING else None
 class Message:
     db: ClassVar[Database] = fake_db
 
+    # TODO Store all Long values as the same type
     mxid: EventID
     mx_room: RoomID
-    ktid: str | None
-    kt_txn_id: int | None
+    ktid: Long
     index: int
-    kt_chat: int
-    kt_receiver: int
-    kt_sender: int
+    kt_chat: Long
+    kt_receiver: Long
     timestamp: int
 
     @classmethod
     def _from_row(cls, row: Record | None) -> Message | None:
-        if row is None:
-            return None
-        return cls(**row)
-    columns = 'mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, kt_sender, timestamp'
+        data = {**row}
+        ktid = data.pop("ktid")
+        kt_chat = data.pop("kt_chat")
+        kt_receiver = data.pop("kt_receiver")
+        return cls(
+            **data,
+            ktid=StrLong(ktid),
+            kt_chat=Long.from_bytes(kt_chat),
+            kt_receiver=Long.from_bytes(kt_receiver)
+        )
 
     @classmethod
-    async def get_all_by_ktid(cls, ktid: str, kt_receiver: int) -> list[Message]:
+    def _from_optional_row(cls, row: Record | None) -> Message | None:
+        return cls._from_row(row) if row is not None else None
+
+    columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
+
+    @classmethod
+    async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
         q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
-        rows = await cls.db.fetch(q, ktid, kt_receiver)
+        rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def get_by_ktid(cls, ktid: str, kt_receiver: int, index: int = 0) -> Message | None:
+    async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
         q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
-        row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
-        return cls._from_row(row)
-
-    @classmethod
-    async def get_by_ktid_or_oti(
-        cls, ktid: str, oti: int, kt_receiver: int, kt_sender: int, index: int = 0
-    ) -> Message | None:
-        q = (
-            f"SELECT {cls.columns} "
-            "FROM message WHERE (ktid=$1 OR (kt_txn_id=$2 AND kt_sender=$3)) AND "
-            ' kt_receiver=$4 AND "index"=$5'
-        )
-        row = await cls.db.fetchrow(q, ktid, oti, kt_sender, kt_receiver, index)
+        row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
         return cls._from_row(row)
 
     @classmethod
@@ -83,18 +83,18 @@ class Message:
         return cls._from_row(row)
 
     @classmethod
-    async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
+    async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
         q = (
             f"SELECT {cls.columns} "
             "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
             "ORDER BY timestamp DESC LIMIT 1"
         )
-        row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
+        row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
         return cls._from_row(row)
 
     @classmethod
     async def get_closest_before(
-        cls, kt_chat: int, kt_receiver: int, timestamp: int
+        cls, kt_chat: Long, kt_receiver: Long, timestamp: int
     ) -> Message | None:
         q = (
             f"SELECT {cls.columns} "
@@ -102,23 +102,21 @@ class Message:
             " ktid IS NOT NULL "
             "ORDER BY timestamp DESC LIMIT 1"
         )
-        row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp)
+        row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
         return cls._from_row(row)
 
     _insert_query = (
-        'INSERT INTO message (mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, '
-        " kt_sender, timestamp) "
-        "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
+        'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '
+        " timestamp) "
+        "VALUES ($1, $2, $3, $4, $5, $6, $7)"
     )
 
     @classmethod
     async def bulk_create(
         cls,
-        ktid: str,
-        oti: int,
-        kt_chat: int,
-        kt_receiver: int,
-        kt_sender: int,
+        ktid: Long,
+        kt_chat: Long,
+        kt_receiver: Long,
         event_ids: list[EventID],
         timestamp: int,
         mx_room: RoomID,
@@ -127,7 +125,7 @@ class Message:
             return
         columns = [col.strip('"') for col in cls.columns.split(", ")]
         records = [
-            (mxid, mx_room, ktid, oti, index, kt_chat, kt_receiver, kt_sender, timestamp)
+            (mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp)
             for index, mxid in enumerate(event_ids)
         ]
         async with cls.db.acquire() as conn, conn.transaction():
@@ -142,19 +140,17 @@ class Message:
             q,
             self.mxid,
             self.mx_room,
-            self.ktid,
-            self.kt_txn_id,
+            str(self.ktid),
             self.index,
-            self.kt_chat,
-            self.kt_receiver,
-            self.kt_sender,
+            bytes(self.kt_chat),
+            bytes(self.kt_receiver),
             self.timestamp,
         )
 
     async def delete(self) -> None:
         q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
-        await self.db.execute(q, self.ktid, self.kt_receiver, self.index)
+        await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index)
 
     async def update(self) -> None:
         q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
-        await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room)
+        await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room)
@@ -186,7 +186,6 @@ bridge:
         # Whether or not the KakaoTalk users of logged in Matrix users should be
         # invited to private chats when backfilling history from KakaoTalk. This is
         # usually needed to prevent rate limits and to allow timestamp massaging.
-        # TODO Is this necessary?
         invite_own_puppet: true
         # Maximum number of messages to backfill initially.
        # Set to 0 to disable backfilling when creating portal.
@@ -43,7 +43,11 @@ from ..types.chat.chat import Chatlog
 from ..types.oauth import OAuthCredential, OAuthInfo
 from ..types.request import (
     deserialize_result,
-    ResultType, RootCommandResult, CommandResultDoneValue)
+    ResultType,
+    ResultListType,
+    RootCommandResult,
+    CommandResultDoneValue
+)
 
 from .types import ChannelInfoUnion
 from .types import PortalChannelInfo
@@ -155,7 +159,8 @@ class Client:
     ) -> _RequestContextManager:
         # TODO Is auth ever needed?
         headers = {
-            **self._headers,
+            # TODO Are any default headers needed?
+            #**self._headers,
             **(headers or {}),
         }
         url = URL(url)
@@ -185,16 +190,24 @@ class Client:
         assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
         return login_result
 
-    async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
-        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
-        return profile_req_struct.profile
-
     """
     async def is_connected(self) -> bool:
         resp = await self._rpc_client.request("is_connected")
         return resp["is_connected"]
     """
 
+    async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
+        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
+        return profile_req_struct.profile
+
+    async def get_profile(self, user_id: Long) -> ProfileStruct:
+        profile_req_struct = await self._api_user_request_result(
+            ProfileReqStruct,
+            "get_profile",
+            user_id=user_id.serialize()
+        )
+        return profile_req_struct.profile
+
     async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo:
         req = await self._api_user_request_result(
             PortalChannelInfo,
@@ -204,13 +217,13 @@ class Client:
         req.channel_info = channel_info
         return req
 
-    async def get_profile(self, user_id: Long) -> ProfileStruct:
-        profile_req_struct = await self._api_user_request_result(
-            ProfileReqStruct,
-            "get_profile",
-            user_id=user_id.serialize()
-        )
-        return profile_req_struct.profile
+    async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]:
+        return (await self._api_user_request_result(
+            ResultListType(Chatlog),
+            "get_chats",
+            channel_id=channel_id.serialize(),
+            sync_from=sync_from.serialize() if sync_from else None
+        ))[-limit if limit else 0:]
 
     async def stop(self) -> None:
         # TODO Stop all event handlers
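Note on the trailing slice in get_chats: it keeps only the newest entries when a limit is given and everything otherwise. A quick standalone sketch of the same expression on a plain list (values below are illustrative only):

    chats = list(range(10))
    limit = 3
    assert chats[-limit if limit else 0:] == [7, 8, 9]   # positive limit: newest N entries
    limit = None
    assert chats[-limit if limit else 0:] == chats       # no limit: keep everything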
@@ -15,23 +15,44 @@
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 """Custom wrapper classes around types defined by the KakaoTalk API."""
 
-from typing import Optional, Union
+from typing import Optional, NewType, Union
 
 from attr import dataclass
 
-from mautrix.types import SerializableAttrs
+from mautrix.types import SerializableAttrs, JSON, deserializer
 
 from ..types.channel.channel_info import NormalChannelInfo
 from ..types.openlink.open_channel_info import OpenChannelInfo
 from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo
 
 
-ChannelInfoUnion = Union[NormalChannelInfo, OpenChannelInfo]
-UserInfoUnion = Union[NormalChannelUserInfo, OpenChannelUserInfo]
+ChannelInfoUnion = NewType("ChannelInfoUnion", Union[NormalChannelInfo, OpenChannelInfo])
+
+@deserializer(ChannelInfoUnion)
+def deserialize_channel_info_union(data: JSON) -> ChannelInfoUnion:
+    if "openLink" in data:
+        return OpenChannelInfo.deserialize(data)
+    else:
+        return NormalChannelInfo.deserialize(data)
+
+setattr(ChannelInfoUnion, "deserialize", deserialize_channel_info_union)
+
+
+UserInfoUnion = NewType("UserInfoUnion", Union[NormalChannelUserInfo, OpenChannelUserInfo])
+
+@deserializer(UserInfoUnion)
+def deserialize_user_info_union(data: JSON) -> UserInfoUnion:
+    if "perm" in data:
+        return OpenChannelUserInfo.deserialize(data)
+    else:
+        return NormalChannelUserInfo.deserialize(data)
+
+setattr(UserInfoUnion, "deserialize", deserialize_user_info_union)
+
+
 @dataclass
 class PortalChannelInfo(SerializableAttrs):
     name: str
-    #participants: list[PuppetUserInfo]
+    participants: list[UserInfoUnion]
     # TODO Image
     channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller
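The two @deserializer hooks above pick the concrete class by probing a discriminating key ("openLink" for open channel info, "perm" for open-channel user info). A minimal self-contained sketch of that pattern, with plain dataclasses standing in for the real mautrix SerializableAttrs types (all names below are illustrative):

    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class NormalInfo:
        name: str

    @dataclass
    class OpenInfo:
        name: str
        openLink: dict

    InfoUnion = Union[NormalInfo, OpenInfo]

    def deserialize_info(data: dict) -> InfoUnion:
        # The presence of the discriminating key selects the open-channel variant.
        if "openLink" in data:
            return OpenInfo(name=data["name"], openLink=data["openLink"])
        return NormalInfo(name=data["name"])

    assert isinstance(deserialize_info({"name": "a"}), NormalInfo)
    assert isinstance(deserialize_info({"name": "b", "openLink": {}}), OpenInfo)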
@@ -32,11 +32,11 @@ class Long(SerializableAttrs):
         return cls(**bson.loads(raw))
 
     @classmethod
-    def from_optional_bytes(cls, raw: bytes | None) -> Optional["Long"]:
+    def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]:
         return cls(**bson.loads(raw)) if raw is not None else None
 
     @classmethod
-    def to_optional_bytes(cls, value: Optional["Long"]) -> bytes | None:
+    def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]:
         return bytes(value) if value is not None else None
 
     def serialize(self) -> JSON:
@@ -48,13 +48,34 @@ class Long(SerializableAttrs):
         return bson.dumps(asdict(self))
 
     def __int__(self) -> int:
-        # TODO Is this right?
-        return self.high << 32 + self.low
+        if self.unsigned:
+            pass
+        result = \
+            ((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \
+            ( self.low + (1 << 32 if self.low < 0 else 0))
+        return result + (1 << 32 if self.unsigned and result < 0 else 0)
 
     def __str__(self) -> str:
-        return f"{self.high << 32 if self.high else ''}{self.low}"
+        return str(int(self))
 
     ZERO: ClassVar["Long"]
 
 
 Long.ZERO = Long(0, 0, False)
+
+
+class IntLong(Long):
+    """Helper class for constructing a Long from an int."""
+    def __init__(self, val: int):
+        if val < 0:
+            pass
+        super().__init__(
+            high=(val & 0xffffffff00000000) >> 32,
+            low = val & 0x00000000ffffffff,
+            unsigned=val < 0,
+        )
+
+
+class StrLong(IntLong):
+    """Helper class for constructing a Long from the string representation of an int."""
+    def __init__(self, val: str):
+        super().__init__(int(val))
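Long stores a 64-bit value as two signed 32-bit halves (the bson/JavaScript Long layout), which is what __int__ and IntLong above convert between. A pure-Python sketch of that round trip, independent of the bson dependency (helper names below are illustrative only):

    def split_int64(val: int) -> tuple[int, int]:
        # Mirror IntLong.__init__: upper and lower 32 bits of a non-negative value.
        high = (val & 0xFFFFFFFF00000000) >> 32
        low = val & 0x00000000FFFFFFFF
        return high, low

    def join_int64(high: int, low: int) -> int:
        # Mirror Long.__int__: undo two's-complement wrapping of the halves.
        return ((high + (1 << 32 if high < 0 else 0)) << 32) + \
               (low + (1 << 32 if low < 0 else 0))

    assert join_int64(*split_int64(1_234_567_890_123)) == 1_234_567_890_123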
@@ -23,12 +23,12 @@ class KnownChannelType(str, Enum):
     DirectChat = "DirectChat"
     PlusChat = "PlusChat"
     MemoChat = "MemoChat"
-    OM = "OM"
-    OD = "OD"
+    OM = "OM" # "OpenMulti"?
+    OD = "OD" # "OpenDirect"?
 
     @classmethod
     def is_direct(cls, value: Union["KnownChannelType", str]) -> bool:
-        return str in [KnownChannelType.DirectChat, KnownChannelType.OD]
+        return value == KnownChannelType.DirectChat
 
 
 ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str
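Since KnownChannelType subclasses str, a member compares equal to its raw string value, which is what lets is_direct accept either an enum member or a plain string. A tiny self-contained check of that behaviour (the demo enum below is illustrative, not the bridge's class):

    from enum import Enum

    class _DemoChannelType(str, Enum):
        DirectChat = "DirectChat"
        OD = "OD"

    assert _DemoChannelType.DirectChat == "DirectChat"
    assert "DirectChat" in (_DemoChannelType.DirectChat, _DemoChannelType.OD)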
@@ -13,12 +13,12 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
-from typing import Generic, Type, TypeVar, Union
+from typing import Generic, Type, TypeVar, Union, Iterable
 
 from attr import dataclass
 from enum import IntEnum
 
-from mautrix.types import SerializableAttrs, JSON
+from mautrix.types import Serializable, SerializableAttrs, JSON
 
 from .api.auth_api_client import KnownAuthStatusCode
 
@@ -80,7 +80,22 @@ class RootCommandResult(ResponseState):
     success: bool
 
 
-ResultType = TypeVar("ResultType", bound=SerializableAttrs)
+ResultType = TypeVar("ResultType", bound=Serializable)
 
+def ResultListType(result_type: Type[ResultType]):
+    class _ResultListType(list[result_type], Serializable):
+        def __init__(self, iterable: Iterable[result_type]=()):
+            list.__init__(self, (result_type.deserialize(x) for x in iterable))
+
+        def serialize(self) -> list[JSON]:
+            return [v.serialize() for v in self]
+
+        @classmethod
+        def deserialize(cls, data: list[JSON]) -> "_ResultListType":
+            return cls(data)
+
+    return _ResultListType
+
+
 @dataclass
 class CommandResultDoneValue(RootCommandResult, Generic[ResultType]):
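ResultListType builds, per element type, a list subclass that knows how to (de)serialize each of its elements. A self-contained sketch of the same factory pattern with a toy item type in place of the mautrix Serializable classes (all names below are illustrative):

    from typing import Iterable, Type, TypeVar

    class Item:
        def __init__(self, value: int) -> None:
            self.value = value

        @classmethod
        def deserialize(cls, data: dict) -> "Item":
            return cls(data["value"])

        def serialize(self) -> dict:
            return {"value": self.value}

    T = TypeVar("T", bound=Item)

    def result_list_type(item_type: Type[T]):
        class _ResultList(list):
            @classmethod
            def deserialize(cls, data: Iterable[dict]) -> "_ResultList":
                return cls(item_type.deserialize(x) for x in data)

            def serialize(self) -> list:
                return [v.serialize() for v in self]

        return _ResultList

    ItemList = result_list_type(Item)
    items = ItemList.deserialize([{"value": 1}, {"value": 2}])
    assert items.serialize() == [{"value": 1}, {"value": 2}]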
@@ -45,12 +45,14 @@ from .db import (
     Message as DBMessage,
     Portal as DBPortal,
 )
+from .formatter.from_kakaotalk import kakaotalk_to_matrix
 
-from .kt.types.bson import Long
+from .kt.types.bson import Long, IntLong
+from .kt.types.channel.channel_info import ChannelInfo
 from .kt.types.channel.channel_type import KnownChannelType, ChannelType
-from .kt.types.user.channel_user_info import DisplayUserInfo
+from .kt.types.chat.chat import Chatlog
 
-from .kt.client.types import PortalChannelInfo
+from .kt.client.types import UserInfoUnion, PortalChannelInfo
 
 if TYPE_CHECKING:
     from .__main__ import KakaoTalkBridge
@@ -201,7 +203,7 @@ class Portal(DBPortal, BasePortal):
                 #self._update_photo(source, info.image),
             )
         )
-        changed = await self._update_participants(source, info.channel_info.displayUserList) or changed
+        changed = await self._update_participants(source, info.participants) or changed
         if changed or force_save:
             await self.update_bridge_info()
             await self.save()
@@ -342,7 +344,7 @@ class Portal(DBPortal, BasePortal):
         )
     """
 
-    async def _update_participants(self, source: u.User, participants: list[DisplayUserInfo]) -> bool:
+    async def _update_participants(self, source: u.User, participants: list[UserInfoUnion]) -> bool:
         changed = False
         # TODO nick_map?
         for participant in participants:
@@ -556,10 +558,10 @@ class Portal(DBPortal, BasePortal):
         )
 
         if not self.is_direct:
-            await self._update_participants(source, info.channel_info.displayUserList)
+            await self._update_participants(source, info.participants)
 
         try:
-            await self.backfill(source, is_initial=True, channel=info.channel_info)
+            await self.backfill(source, is_initial=True, channel_info=info.channel_info)
         except Exception:
             self.log.exception("Failed to backfill new portal")
 
@@ -752,31 +754,131 @@ class Portal(DBPortal, BasePortal):
         self,
         source: u.User,
         sender: p.Puppet,
-        message: str,
-        reply_to: None = None,
+        message: Chatlog,
+        reply_to: Chatlog | None = None,
     ) -> None:
         try:
             await self._handle_remote_message(source, sender, message, reply_to)
         except Exception:
             self.log.exception(
-                "Error handling Kakaotalk message <TODO: ID>"
+                "Error handling KakaoTalk message %s",
+                message.logId,
             )
 
     async def _handle_remote_message(
         self,
         source: u.User,
         sender: p.Puppet,
-        message: str,
-        reply_to: None = None,
+        message: Chatlog,
+        reply_to: Chatlog | None = None,
     ) -> None:
-        self.log.info("TODO")
+        self.log.debug(f"Handling KakaoTalk event {message.logId}")
+        if not self.mxid:
+            mxid = await self.create_matrix_room(source)
+            if not mxid:
+                # Failed to create
+                return
+        if not await self._bridge_own_message_pm(source, sender, f"message {message.logId}"):
+            return
+        intent = sender.intent_for(self)
+        if (
+            self._backfill_leave is not None
+            and self.ktid != sender.ktid
+            and intent != sender.intent
+            and intent not in self._backfill_leave
+        ):
+            self.log.debug("Adding %s's default puppet to room for backfilling", sender.mxid)
+            await self.main_intent.invite_user(self.mxid, intent.mxid)
+            await intent.ensure_joined(self.mxid)
+            self._backfill_leave.add(intent)
+
+        if message.attachment:
+            self.log.info("TODO: _handle_remote_message attachments")
+        if message.supplement:
+            self.log.info("TODO: _handle_remote_message supplements")
+        if message.text:
+            content = await kakaotalk_to_matrix(message.text)
+            event_id = await self._send_message(intent, content, timestamp=message.sendAt)
+            await DBMessage(
+                mxid=event_id,
+                mx_room=self.mxid,
+                ktid=message.logId,
+                index=0,
+                kt_chat=self.ktid,
+                kt_receiver=self.kt_receiver,
+                timestamp=message.sendAt,
+            ).insert()
+            await self._send_delivery_receipt(event_id)
+        else:
+            self.log.warning(f"Unhandled KakaoTalk message {message.logId}")
 
     # TODO Many more remote handlers
 
     # endregion
 
-    async def backfill(self, source: u.User, is_initial: bool, channel: PortalChannelInfo) -> None:
-        self.log.info("TODO: backfill")
+    async def backfill(self, source: u.User, is_initial: bool, channel_info: ChannelInfo) -> None:
+        limit = (
+            self.config["bridge.backfill.initial_limit"]
+            if is_initial
+            else self.config["bridge.backfill.missed_limit"]
+        )
+        if limit == 0:
+            return
+        elif limit < 0:
+            limit = None
+        last_log_id = None
+        if not is_initial and channel_info.lastChatLog:
+            last_log_id = channel_info.lastChatLog.logId
+        most_recent = await DBMessage.get_most_recent(self.ktid, self.kt_receiver)
+        if most_recent and is_initial:
+            self.log.debug("Not backfilling %s: already bridged messages found", self.ktid_log)
+        # TODO Should this be removed? With it, can't sync an empty portal!
+        #elif (not most_recent or not most_recent.timestamp) and not is_initial:
+        #    self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log)
+        elif last_log_id and most_recent and int(most_recent.ktid) >= int(last_log_id):
+            self.log.debug(
+                "Not backfilling %s: last activity is equal to most recent bridged "
+                "message (%s >= %s)",
+                self.ktid_log,
+                most_recent.ktid,
+                last_log_id,
+            )
+        else:
+            with self.backfill_lock:
+                await self._backfill(
+                    source,
+                    limit,
+                    most_recent.ktid if most_recent else None,
+                    channel_info=channel_info,
+                )
+
+    async def _backfill(
+        self,
+        source: u.User,
+        limit: int | None,
+        after_log_id: Long | None,
+        channel_info: ChannelInfo,
+    ) -> None:
+        self.log.debug("Backfilling history through %s", source.mxid)
+        self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid))
+        messages = await source.client.get_chats(
+            channel_info.channelId,
+            limit,
+            after_log_id
+        )
+        if not messages:
+            self.log.debug("Didn't get any messages from server")
+            return
+        self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server")
+        self._backfill_leave = set()
+        async with NotificationDisabler(self.mxid, source):
+            for message in messages:
+                puppet = await p.Puppet.get_by_ktid(message.sender.userId)
+                await self.handle_remote_message(source, puppet, message)
+        for intent in self._backfill_leave:
+            self.log.trace("Leaving room with %s post-backfill", intent.mxid)
+            await intent.leave_room(self.mxid)
+        self.log.info("Backfilled %d messages through %s", len(messages), source.mxid)
 
     # region Database getters
 
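The limit handling in backfill maps the two config values onto three cases: 0 disables backfilling, a negative value means no cap, and a positive value caps the fetch. A standalone sketch of that mapping, with a plain dict standing in for the bridge config (values below are illustrative):

    # Stand-in for self.config; the key names match the bridge config used above.
    config = {"bridge.backfill.initial_limit": -1, "bridge.backfill.missed_limit": 50}

    def effective_limit(is_initial: bool) -> int | None:
        limit = config["bridge.backfill.initial_limit" if is_initial else "bridge.backfill.missed_limit"]
        if limit == 0:
            return 0  # backfilling disabled
        return None if limit < 0 else limit  # None means "fetch everything"

    assert effective_limit(True) is None  # -1: no cap on the initial backfill
    assert effective_limit(False) == 50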
@@ -31,8 +31,9 @@ from . import matrix as m, portal as p, user as u
 from .config import Config
 from .db import Puppet as DBPuppet
 
-from .kt.types.bson import Long
-from .kt.types.user.channel_user_info import DisplayUserInfo
+from .kt.types.bson import Long, StrLong
+from .kt.client.types import UserInfoUnion
 
 if TYPE_CHECKING:
     from .__main__ import KakaoTalkBridge
@@ -42,7 +43,7 @@ class Puppet(DBPuppet, BasePuppet):
     mx: m.MatrixHandler
     config: Config
     hs_domain: str
-    mxid_template: SimpleTemplate[int]
+    mxid_template: SimpleTemplate[StrLong]
 
     by_ktid: dict[Long, Puppet] = {}
     by_custom_mxid: dict[UserID, Puppet] = {}
@@ -126,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet):
             keyword="userid",
             prefix="@",
             suffix=f":{Puppet.hs_domain}",
-            type=int,
+            type=StrLong,
         )
         cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
         cls.homeserver_url_map = {
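With type=StrLong, the localpart that mautrix's SimpleTemplate extracts from a puppet MXID is wrapped in StrLong rather than parsed as a plain int. A rough standalone sketch of that extraction, using a regex and a plain int conversion in place of SimpleTemplate and StrLong (the "ktbridge_{userid}" localpart pattern and domain below are assumed examples only, not the bridge's actual template):

    import re

    hs_domain = "example.com"  # assumed homeserver domain
    mxid = f"@ktbridge_1234567890:{hs_domain}"  # assumed username template "ktbridge_{userid}"

    match = re.fullmatch(rf"@ktbridge_(\d+):{re.escape(hs_domain)}", mxid)
    assert match is not None
    # The real code wraps this value in StrLong so it round-trips as a 64-bit Long.
    user_id = int(match.group(1))
    assert user_id == 1234567890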
@@ -147,7 +148,7 @@ class Puppet(DBPuppet, BasePuppet):
     async def update_info(
         self,
         source: u.User,
-        info: DisplayUserInfo,
+        info: UserInfoUnion,
         update_avatar: bool = True,
     ) -> Puppet:
         self._last_info_sync = datetime.now()
@@ -161,7 +162,7 @@ class Puppet(DBPuppet, BasePuppet):
             self.log.exception(f"Failed to update info from source {source.ktid}")
         return self
 
-    async def _update_name(self, info: DisplayUserInfo) -> bool:
+    async def _update_name(self, info: UserInfoUnion) -> bool:
         name = info.nickname
         if name != self.name or not self.name_set:
             self.name = name
@@ -259,7 +260,7 @@ class Puppet(DBPuppet, BasePuppet):
         return None
 
     @classmethod
-    def get_id_from_mxid(cls, mxid: UserID) -> int | None:
+    def get_id_from_mxid(cls, mxid: UserID) -> Long | None:
         return cls.mxid_template.parse(mxid)
 
     @classmethod
@@ -415,6 +415,7 @@ class User(DBUser, BaseUser):
 
         assert self.client
         try:
+            # TODO if not is_startup, close existing listeners
             login_result = await self.client.start()
             await self._sync_channels(login_result, is_startup)
             # TODO connect listeners, even if channel sync fails (except if it's an auth failure)
@@ -502,7 +503,7 @@ class User(DBUser, BaseUser):
                 await portal.create_matrix_room(self, portal_info)
             else:
                 await portal.update_matrix_room(self, portal_info)
-                await portal.backfill(self, is_initial=False, channel=channel_info)
+                await portal.backfill(self, is_initial=False, channel_info=channel_info)
 
     async def get_notice_room(self) -> RoomID:
         if not self.notice_room:
@@ -202,6 +202,14 @@ export default class PeerClient {
         return loginRes
     }
 
+    /**
+     * TODO Consider caching per-user
+     * @param {string} uuid
+     */
+    async #createAuthClient(uuid) {
+        return await AuthApiClient.create("KakaoTalk Bridge", uuid)
+    }
+
     // TODO Wrapper for per-user commands
 
     /**
@@ -237,14 +245,6 @@ export default class PeerClient {
         return await oAuthClient.renew(req.oauth_credential)
     }
 
-    /**
-     * TODO Consider caching per-user
-     * @param {string} uuid
-     */
-    async #createAuthClient(uuid) {
-        return await AuthApiClient.create("KakaoTalk Bridge", uuid)
-    }
-
     /**
      * @param {Object} req
      * @param {string} req.mxid
@@ -314,22 +314,33 @@ export default class PeerClient {
      * @param {string} req.mxid
      * @param {Long} req.channel_id
      */
-    getPortalChannelInfo = (req) => {
+    getPortalChannelInfo = async (req) => {
         const userClient = this.#getUser(req.mxid)
         const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
 
-        /* TODO Decide if this is needed. If it is, make function async!
         const res = await talkChannel.updateAll()
         if (!res.success) return res
-        */
 
         return this.#makeCommandResult({
             name: talkChannel.getDisplayName(),
-            //participants: Array.from(talkChannel.getAllUserInfo()),
+            participants: Array.from(talkChannel.getAllUserInfo()),
            // TODO Image
         })
     }
 
+    /**
+     * @param {Object} req
+     * @param {string} req.mxid
+     * @param {Long} req.channel_id
+     * @param {Long?} req.sync_from
+     */
+    getChats = async (req) => {
+        const userClient = this.#getUser(req.mxid)
+        const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
+
+        return await talkChannel.getChatListFrom(req.sync_from)
+    }
+
     /**
      * @param {Object} req
      * @param {string} req.mxid
|
||||||
register_device: this.registerDevice,
|
register_device: this.registerDevice,
|
||||||
get_own_profile: this.getOwnProfile,
|
get_own_profile: this.getOwnProfile,
|
||||||
get_portal_channel_info: this.getPortalChannelInfo,
|
get_portal_channel_info: this.getPortalChannelInfo,
|
||||||
|
get_chats: this.getChats,
|
||||||
get_profile: this.getProfile,
|
get_profile: this.getProfile,
|
||||||
/*
|
/*
|
||||||
send: req => this.puppet.sendMessage(req.chat_id, req.text),
|
send: req => this.puppet.sendMessage(req.chat_id, req.text),
|
||||||
|
|