Compare commits

...

4 Commits

18 changed files with 322 additions and 120 deletions

View File

@ -128,6 +128,7 @@ async def enter_dv_code(evt: CommandEvent) -> None:
assert(evt.sender.command_status) assert(evt.sender.command_status)
req: dict = evt.sender.command_status["req"] req: dict = evt.sender.command_status["req"]
passcode = evt.content.body passcode = evt.content.body
await evt.mark_read()
try: try:
await KakaoTalkClient.register_device(passcode, **req) await KakaoTalkClient.register_device(passcode, **req)
await _do_login(evt, req) await _do_login(evt, req)
@ -191,6 +192,7 @@ async def reset_device(evt: CommandEvent) -> None:
if await evt.sender.is_logged_in(): if await evt.sender.is_logged_in():
await evt.reply("This command requires you to be logged out.") await evt.reply("This command requires you to be logged out.")
else: else:
await evt.mark_read()
await evt.sender.logout(reset_device=True) await evt.sender.logout(reset_device=True)
await evt.reply( await evt.reply(
"Your next login will use a different device ID.\n\n" "Your next login will use a different device ID.\n\n"

View File

@ -71,6 +71,7 @@ async def ping(evt: CommandEvent) -> None:
if not await evt.sender.is_logged_in(): if not await evt.sender.is_logged_in():
await evt.reply("You're not logged into KakaoTalk") await evt.reply("You're not logged into KakaoTalk")
return return
await evt.mark_read()
# try: # try:
own_info = await evt.sender.get_own_info() own_info = await evt.sender.get_own_info()
# TODO catch errors # TODO catch errors
@ -99,3 +100,15 @@ async def ping(evt: CommandEvent) -> None:
async def refresh(evt: CommandEvent) -> None: async def refresh(evt: CommandEvent) -> None:
await evt.sender.refresh(force_notice=True) await evt.sender.refresh(force_notice=True)
""" """
@command_handler(
needs_auth=True,
management_only=True,
help_section=SECTION_CONNECTION,
help_text="Resync chats",
)
async def sync(evt: CommandEvent) -> None:
await evt.mark_read()
await evt.sender.post_login(is_startup=False)
await evt.reply("Sync complete")

View File

@ -23,6 +23,8 @@ from attr import dataclass
from mautrix.types import EventID, RoomID from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database from mautrix.util.async_db import Database
from ..kt.types.bson import Long, StrLong
fake_db = Database.create("") if TYPE_CHECKING else None fake_db = Database.create("") if TYPE_CHECKING else None
@ -30,46 +32,44 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Message: class Message:
db: ClassVar[Database] = fake_db db: ClassVar[Database] = fake_db
# TODO Store all Long values as the same type
mxid: EventID mxid: EventID
mx_room: RoomID mx_room: RoomID
ktid: str | None ktid: Long
kt_txn_id: int | None
index: int index: int
kt_chat: int kt_chat: Long
kt_receiver: int kt_receiver: Long
kt_sender: int
timestamp: int timestamp: int
@classmethod @classmethod
def _from_row(cls, row: Record | None) -> Message | None: def _from_row(cls, row: Record | None) -> Message | None:
if row is None: data = {**row}
return None ktid = data.pop("ktid")
return cls(**row) kt_chat = data.pop("kt_chat")
kt_receiver = data.pop("kt_receiver")
columns = 'mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, kt_sender, timestamp' return cls(
**data,
ktid=StrLong(ktid),
kt_chat=Long.from_bytes(kt_chat),
kt_receiver=Long.from_bytes(kt_receiver)
)
@classmethod @classmethod
async def get_all_by_ktid(cls, ktid: str, kt_receiver: int) -> list[Message]: def _from_optional_row(cls, row: Record | None) -> Message | None:
return cls._from_row(row) if row is not None else None
columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
@classmethod
async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2" q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
rows = await cls.db.fetch(q, ktid, kt_receiver) rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
return [cls._from_row(row) for row in rows] return [cls._from_row(row) for row in rows]
@classmethod @classmethod
async def get_by_ktid(cls, ktid: str, kt_receiver: int, index: int = 0) -> Message | None: async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
row = await cls.db.fetchrow(q, ktid, kt_receiver, index) row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
return cls._from_row(row)
@classmethod
async def get_by_ktid_or_oti(
cls, ktid: str, oti: int, kt_receiver: int, kt_sender: int, index: int = 0
) -> Message | None:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE (ktid=$1 OR (kt_txn_id=$2 AND kt_sender=$3)) AND "
' kt_receiver=$4 AND "index"=$5'
)
row = await cls.db.fetchrow(q, ktid, oti, kt_sender, kt_receiver, index)
return cls._from_row(row) return cls._from_row(row)
@classmethod @classmethod
@ -83,18 +83,18 @@ class Message:
return cls._from_row(row) return cls._from_row(row)
@classmethod @classmethod
async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None: async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
q = ( q = (
f"SELECT {cls.columns} " f"SELECT {cls.columns} "
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL " "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1" "ORDER BY timestamp DESC LIMIT 1"
) )
row = await cls.db.fetchrow(q, kt_chat, kt_receiver) row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
return cls._from_row(row) return cls._from_row(row)
@classmethod @classmethod
async def get_closest_before( async def get_closest_before(
cls, kt_chat: int, kt_receiver: int, timestamp: int cls, kt_chat: Long, kt_receiver: Long, timestamp: int
) -> Message | None: ) -> Message | None:
q = ( q = (
f"SELECT {cls.columns} " f"SELECT {cls.columns} "
@ -102,23 +102,21 @@ class Message:
" ktid IS NOT NULL " " ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1" "ORDER BY timestamp DESC LIMIT 1"
) )
row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp) row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
return cls._from_row(row) return cls._from_row(row)
_insert_query = ( _insert_query = (
'INSERT INTO message (mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, ' 'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '
" kt_sender, timestamp) " " timestamp) "
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" "VALUES ($1, $2, $3, $4, $5, $6, $7)"
) )
@classmethod @classmethod
async def bulk_create( async def bulk_create(
cls, cls,
ktid: str, ktid: Long,
oti: int, kt_chat: Long,
kt_chat: int, kt_receiver: Long,
kt_receiver: int,
kt_sender: int,
event_ids: list[EventID], event_ids: list[EventID],
timestamp: int, timestamp: int,
mx_room: RoomID, mx_room: RoomID,
@ -127,7 +125,7 @@ class Message:
return return
columns = [col.strip('"') for col in cls.columns.split(", ")] columns = [col.strip('"') for col in cls.columns.split(", ")]
records = [ records = [
(mxid, mx_room, ktid, oti, index, kt_chat, kt_receiver, kt_sender, timestamp) (mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp)
for index, mxid in enumerate(event_ids) for index, mxid in enumerate(event_ids)
] ]
async with cls.db.acquire() as conn, conn.transaction(): async with cls.db.acquire() as conn, conn.transaction():
@ -142,19 +140,17 @@ class Message:
q, q,
self.mxid, self.mxid,
self.mx_room, self.mx_room,
self.ktid, str(self.ktid),
self.kt_txn_id,
self.index, self.index,
self.kt_chat, bytes(self.kt_chat),
self.kt_receiver, bytes(self.kt_receiver),
self.kt_sender,
self.timestamp, self.timestamp,
) )
async def delete(self) -> None: async def delete(self) -> None:
q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
await self.db.execute(q, self.ktid, self.kt_receiver, self.index) await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index)
async def update(self) -> None: async def update(self) -> None:
q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4" q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room) await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room)

View File

@ -50,7 +50,7 @@ class Portal:
data = {**row} data = {**row}
ktid = data.pop("ktid") ktid = data.pop("ktid")
kt_receiver = data.pop("kt_receiver") kt_receiver = data.pop("kt_receiver")
return cls(**data, ktid=Long.from_optional_bytes(ktid), kt_receiver=Long.from_optional_bytes(kt_receiver)) return cls(**data, ktid=Long.from_bytes(ktid), kt_receiver=Long.from_bytes(kt_receiver))
@classmethod @classmethod
def _from_optional_row(cls, row: Record | None) -> Portal | None: def _from_optional_row(cls, row: Record | None) -> Portal | None:

View File

@ -186,7 +186,6 @@ bridge:
# Whether or not the KakaoTalk users of logged in Matrix users should be # Whether or not the KakaoTalk users of logged in Matrix users should be
# invited to private chats when backfilling history from KakaoTalk. This is # invited to private chats when backfilling history from KakaoTalk. This is
# usually needed to prevent rate limits and to allow timestamp massaging. # usually needed to prevent rate limits and to allow timestamp massaging.
# TODO Is this necessary?
invite_own_puppet: true invite_own_puppet: true
# Maximum number of messages to backfill initially. # Maximum number of messages to backfill initially.
# Set to 0 to disable backfilling when creating portal. # Set to 0 to disable backfilling when creating portal.

View File

@ -43,7 +43,11 @@ from ..types.chat.chat import Chatlog
from ..types.oauth import OAuthCredential, OAuthInfo from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import ( from ..types.request import (
deserialize_result, deserialize_result,
ResultType, RootCommandResult, CommandResultDoneValue) ResultType,
ResultListType,
RootCommandResult,
CommandResultDoneValue
)
from .types import ChannelInfoUnion from .types import ChannelInfoUnion
from .types import PortalChannelInfo from .types import PortalChannelInfo
@ -155,7 +159,8 @@ class Client:
) -> _RequestContextManager: ) -> _RequestContextManager:
# TODO Is auth ever needed? # TODO Is auth ever needed?
headers = { headers = {
**self._headers, # TODO Are any default headers needed?
#**self._headers,
**(headers or {}), **(headers or {}),
} }
url = URL(url) url = URL(url)
@ -185,16 +190,24 @@ class Client:
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}" assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
return login_result return login_result
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
return profile_req_struct.profile
""" """
async def is_connected(self) -> bool: async def is_connected(self) -> bool:
resp = await self._rpc_client.request("is_connected") resp = await self._rpc_client.request("is_connected")
return resp["is_connected"] return resp["is_connected"]
""" """
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
return profile_req_struct.profile
async def get_profile(self, user_id: Long) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(
ProfileReqStruct,
"get_profile",
user_id=user_id.serialize()
)
return profile_req_struct.profile
async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo: async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo:
req = await self._api_user_request_result( req = await self._api_user_request_result(
PortalChannelInfo, PortalChannelInfo,
@ -204,13 +217,13 @@ class Client:
req.channel_info = channel_info req.channel_info = channel_info
return req return req
async def get_profile(self, user_id: Long) -> ProfileStruct: async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]:
profile_req_struct = await self._api_user_request_result( return (await self._api_user_request_result(
ProfileReqStruct, ResultListType(Chatlog),
"get_profile", "get_chats",
user_id=user_id.serialize() channel_id=channel_id.serialize(),
) sync_from=sync_from.serialize() if sync_from else None
return profile_req_struct.profile ))[-limit if limit else 0:]
async def stop(self) -> None: async def stop(self) -> None:
# TODO Stop all event handlers # TODO Stop all event handlers

View File

@ -15,23 +15,44 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Custom wrapper classes around types defined by the KakaoTalk API.""" """Custom wrapper classes around types defined by the KakaoTalk API."""
from typing import Optional, Union from typing import Optional, NewType, Union
from attr import dataclass from attr import dataclass
from mautrix.types import SerializableAttrs from mautrix.types import SerializableAttrs, JSON, deserializer
from ..types.channel.channel_info import NormalChannelInfo from ..types.channel.channel_info import NormalChannelInfo
from ..types.openlink.open_channel_info import OpenChannelInfo from ..types.openlink.open_channel_info import OpenChannelInfo
from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo
ChannelInfoUnion = Union[NormalChannelInfo, OpenChannelInfo] ChannelInfoUnion = NewType("ChannelInfoUnion", Union[NormalChannelInfo, OpenChannelInfo])
UserInfoUnion = Union[NormalChannelUserInfo, OpenChannelUserInfo]
@deserializer(ChannelInfoUnion)
def deserialize_channel_info_union(data: JSON) -> ChannelInfoUnion:
if "openLink" in data:
return OpenChannelInfo.deserialize(data)
else:
return NormalChannelInfo.deserialize(data)
setattr(ChannelInfoUnion, "deserialize", deserialize_channel_info_union)
UserInfoUnion = NewType("UserInfoUnion", Union[NormalChannelUserInfo, OpenChannelUserInfo])
@deserializer(UserInfoUnion)
def deserialize_user_info_union(data: JSON) -> UserInfoUnion:
if "perm" in data:
return OpenChannelUserInfo.deserialize(data)
else:
return NormalChannelUserInfo.deserialize(data)
setattr(UserInfoUnion, "deserialize", deserialize_user_info_union)
@dataclass @dataclass
class PortalChannelInfo(SerializableAttrs): class PortalChannelInfo(SerializableAttrs):
name: str name: str
#participants: list[PuppetUserInfo] participants: list[UserInfoUnion]
# TODO Image # TODO Image
channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller

View File

@ -13,6 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from attr import dataclass from attr import dataclass
from mautrix.types import SerializableAttrs from mautrix.types import SerializableAttrs
@ -23,9 +25,9 @@ from .mention import MentionStruct
@dataclass(kw_only=True) @dataclass(kw_only=True)
class Attachment(SerializableAttrs): class Attachment(SerializableAttrs):
shout: bool | None = None shout: Optional[bool] = None
mentions: list[MentionStruct] | None = None mentions: Optional[list[MentionStruct]] = None
urls: list[str] | None = None urls: Optional[list[str]] = None
@dataclass @dataclass

View File

@ -32,11 +32,11 @@ class Long(SerializableAttrs):
return cls(**bson.loads(raw)) return cls(**bson.loads(raw))
@classmethod @classmethod
def from_optional_bytes(cls, raw: bytes | None) -> Optional["Long"]: def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]:
return cls(**bson.loads(raw)) if raw is not None else None return cls(**bson.loads(raw)) if raw is not None else None
@classmethod @classmethod
def to_optional_bytes(cls, value: Optional["Long"]) -> bytes | None: def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]:
return bytes(value) if value is not None else None return bytes(value) if value is not None else None
def serialize(self) -> JSON: def serialize(self) -> JSON:
@ -48,13 +48,34 @@ class Long(SerializableAttrs):
return bson.dumps(asdict(self)) return bson.dumps(asdict(self))
def __int__(self) -> int: def __int__(self) -> int:
# TODO Is this right? if self.unsigned:
return self.high << 32 + self.low pass
result = \
((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \
( self.low + (1 << 32 if self.low < 0 else 0))
return result + (1 << 32 if self.unsigned and result < 0 else 0)
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.high << 32 if self.high else ''}{self.low}" return str(int(self))
ZERO: ClassVar["Long"] ZERO: ClassVar["Long"]
Long.ZERO = Long(0, 0, False) Long.ZERO = Long(0, 0, False)
class IntLong(Long):
"""Helper class for constructing a Long from an int."""
def __init__(self, val: int):
if val < 0:
pass
super().__init__(
high=(val & 0xffffffff00000000) >> 32,
low = val & 0x00000000ffffffff,
unsigned=val < 0,
)
class StrLong(IntLong):
"""Helper class for constructing a Long from the string representation of an int."""
def __init__(self, val: str):
super().__init__(int(val))

View File

@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Generic, TypeVar from typing import Generic, TypeVar, Optional
from attr import dataclass from attr import dataclass
@ -48,7 +48,7 @@ class ChannelInfo(Channel):
newChatCountInvalid: bool newChatCountInvalid: bool
lastChatLogId: Long lastChatLogId: Long
lastSeenLogId: Long lastSeenLogId: Long
lastChatLog: Chatlog | None = None lastChatLog: Optional[Chatlog] = None
metaMap: ChannelMetaMap metaMap: ChannelMetaMap
displayUserList: list[DisplayUserInfo] displayUserList: list[DisplayUserInfo]
pushAlert: bool pushAlert: bool

View File

@ -23,12 +23,12 @@ class KnownChannelType(str, Enum):
DirectChat = "DirectChat" DirectChat = "DirectChat"
PlusChat = "PlusChat" PlusChat = "PlusChat"
MemoChat = "MemoChat" MemoChat = "MemoChat"
OM = "OM" OM = "OM" # "OpenMulti"?
OD = "OD" OD = "OD" # "OpenDirect"?
@classmethod @classmethod
def is_direct(cls, value: Union["KnownChannelType", str]) -> bool: def is_direct(cls, value: Union["KnownChannelType", str]) -> bool:
return str in [KnownChannelType.DirectChat, KnownChannelType.OD] return value == KnownChannelType.DirectChat
ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str

View File

@ -13,6 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from attr import dataclass from attr import dataclass
from mautrix.types import JSON from mautrix.types import JSON
@ -25,7 +27,7 @@ from . import OpenTokenComponent, OpenLink
@dataclass(kw_only=True) @dataclass(kw_only=True)
class OpenChannelInfo(ChannelInfo, OpenChannel, OpenTokenComponent): class OpenChannelInfo(ChannelInfo, OpenChannel, OpenTokenComponent):
directChannel: bool directChannel: bool
openLink: OpenLink | None = None openLink: Optional[OpenLink] = None
@dataclass @dataclass

View File

@ -13,12 +13,12 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Generic, Type, TypeVar, Union from typing import Generic, Type, TypeVar, Union, Iterable
from attr import dataclass from attr import dataclass
from enum import IntEnum from enum import IntEnum
from mautrix.types import SerializableAttrs, JSON from mautrix.types import Serializable, SerializableAttrs, JSON
from .api.auth_api_client import KnownAuthStatusCode from .api.auth_api_client import KnownAuthStatusCode
@ -80,7 +80,22 @@ class RootCommandResult(ResponseState):
success: bool success: bool
ResultType = TypeVar("ResultType", bound=SerializableAttrs) ResultType = TypeVar("ResultType", bound=Serializable)
def ResultListType(result_type: Type[ResultType]):
class _ResultListType(list[result_type], Serializable):
def __init__(self, iterable: Iterable[result_type]=()):
list.__init__(self, (result_type.deserialize(x) for x in iterable))
def serialize(self) -> list[JSON]:
return [v.serialize() for v in self]
@classmethod
def deserialize(cls, data: list[JSON]) -> "_ResultListType":
return cls(data)
return _ResultListType
@dataclass @dataclass
class CommandResultDoneValue(RootCommandResult, Generic[ResultType]): class CommandResultDoneValue(RootCommandResult, Generic[ResultType]):

View File

@ -13,6 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from attr import dataclass from attr import dataclass
from mautrix.types import SerializableAttrs from mautrix.types import SerializableAttrs
@ -28,7 +30,7 @@ class ChannelUser(SerializableAttrs):
@dataclass(kw_only=True) @dataclass(kw_only=True)
class PartialOpenLinkComponent(SerializableAttrs): class PartialOpenLinkComponent(SerializableAttrs):
"""Substitute for Partial<OpenLinkComponent>""" """Substitute for Partial<OpenLinkComponent>"""
linkId: Long | None = None linkId: Optional[Long] = None
@dataclass @dataclass

View File

@ -45,12 +45,14 @@ from .db import (
Message as DBMessage, Message as DBMessage,
Portal as DBPortal, Portal as DBPortal,
) )
from .formatter.from_kakaotalk import kakaotalk_to_matrix
from .kt.types.bson import Long from .kt.types.bson import Long, IntLong
from .kt.types.channel.channel_info import ChannelInfo
from .kt.types.channel.channel_type import KnownChannelType, ChannelType from .kt.types.channel.channel_type import KnownChannelType, ChannelType
from .kt.types.user.channel_user_info import DisplayUserInfo from .kt.types.chat.chat import Chatlog
from .kt.client.types import PortalChannelInfo from .kt.client.types import UserInfoUnion, PortalChannelInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from .__main__ import KakaoTalkBridge from .__main__ import KakaoTalkBridge
@ -201,7 +203,7 @@ class Portal(DBPortal, BasePortal):
#self._update_photo(source, info.image), #self._update_photo(source, info.image),
) )
) )
changed = await self._update_participants(source, info.channel_info.displayUserList) or changed changed = await self._update_participants(source, info.participants) or changed
if changed or force_save: if changed or force_save:
await self.update_bridge_info() await self.update_bridge_info()
await self.save() await self.save()
@ -342,7 +344,7 @@ class Portal(DBPortal, BasePortal):
) )
""" """
async def _update_participants(self, source: u.User, participants: list[DisplayUserInfo]) -> bool: async def _update_participants(self, source: u.User, participants: list[UserInfoUnion]) -> bool:
changed = False changed = False
# TODO nick_map? # TODO nick_map?
for participant in participants: for participant in participants:
@ -556,10 +558,10 @@ class Portal(DBPortal, BasePortal):
) )
if not self.is_direct: if not self.is_direct:
await self._update_participants(source, info.channel_info.displayUserList) await self._update_participants(source, info.participants)
try: try:
await self.backfill(source, is_initial=True, channel=info.channel_info) await self.backfill(source, is_initial=True, channel_info=info.channel_info)
except Exception: except Exception:
self.log.exception("Failed to backfill new portal") self.log.exception("Failed to backfill new portal")
@ -752,31 +754,131 @@ class Portal(DBPortal, BasePortal):
self, self,
source: u.User, source: u.User,
sender: p.Puppet, sender: p.Puppet,
message: str, message: Chatlog,
reply_to: None = None, reply_to: Chatlog | None = None,
) -> None: ) -> None:
try: try:
await self._handle_remote_message(source, sender, message, reply_to) await self._handle_remote_message(source, sender, message, reply_to)
except Exception: except Exception:
self.log.exception( self.log.exception(
"Error handling Kakaotalk message <TODO: ID>" "Error handling KakaoTalk message %s",
message.logId,
) )
async def _handle_remote_message( async def _handle_remote_message(
self, self,
source: u.User, source: u.User,
sender: p.Puppet, sender: p.Puppet,
message: str, message: Chatlog,
reply_to: None = None, reply_to: Chatlog | None = None,
) -> None: ) -> None:
self.log.info("TODO") self.log.debug(f"Handling KakaoTalk event {message.logId}")
if not self.mxid:
mxid = await self.create_matrix_room(source)
if not mxid:
# Failed to create
return
if not await self._bridge_own_message_pm(source, sender, f"message {message.logId}"):
return
intent = sender.intent_for(self)
if (
self._backfill_leave is not None
and self.ktid != sender.ktid
and intent != sender.intent
and intent not in self._backfill_leave
):
self.log.debug("Adding %s's default puppet to room for backfilling", sender.mxid)
await self.main_intent.invite_user(self.mxid, intent.mxid)
await intent.ensure_joined(self.mxid)
self._backfill_leave.add(intent)
if message.attachment:
self.log.info("TODO: _handle_remote_message attachments")
if message.supplement:
self.log.info("TODO: _handle_remote_message supplements")
if message.text:
content = await kakaotalk_to_matrix(message.text)
event_id = await self._send_message(intent, content, timestamp=message.sendAt)
await DBMessage(
mxid=event_id,
mx_room=self.mxid,
ktid=message.logId,
index=0,
kt_chat=self.ktid,
kt_receiver=self.kt_receiver,
timestamp=message.sendAt,
).insert()
await self._send_delivery_receipt(event_id)
else:
self.log.warning(f"Unhandled KakaoTalk message {message.logId}")
# TODO Many more remote handlers # TODO Many more remote handlers
# endregion # endregion
async def backfill(self, source: u.User, is_initial: bool, channel: PortalChannelInfo) -> None: async def backfill(self, source: u.User, is_initial: bool, channel_info: ChannelInfo) -> None:
self.log.info("TODO: backfill") limit = (
self.config["bridge.backfill.initial_limit"]
if is_initial
else self.config["bridge.backfill.missed_limit"]
)
if limit == 0:
return
elif limit < 0:
limit = None
last_log_id = None
if not is_initial and channel_info.lastChatLog:
last_log_id = channel_info.lastChatLog.logId
most_recent = await DBMessage.get_most_recent(self.ktid, self.kt_receiver)
if most_recent and is_initial:
self.log.debug("Not backfilling %s: already bridged messages found", self.ktid_log)
# TODO Should this be removed? With it, can't sync an empty portal!
#elif (not most_recent or not most_recent.timestamp) and not is_initial:
# self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log)
elif last_log_id and most_recent and int(most_recent.ktid) >= int(last_log_id):
self.log.debug(
"Not backfilling %s: last activity is equal to most recent bridged "
"message (%s >= %s)",
self.ktid_log,
most_recent.ktid,
last_log_id,
)
else:
with self.backfill_lock:
await self._backfill(
source,
limit,
most_recent.ktid if most_recent else None,
channel_info=channel_info,
)
async def _backfill(
self,
source: u.User,
limit: int | None,
after_log_id: Long | None,
channel_info: ChannelInfo,
) -> None:
self.log.debug("Backfilling history through %s", source.mxid)
self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid))
messages = await source.client.get_chats(
channel_info.channelId,
limit,
after_log_id
)
if not messages:
self.log.debug("Didn't get any messages from server")
return
self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server")
self._backfill_leave = set()
async with NotificationDisabler(self.mxid, source):
for message in messages:
puppet = await p.Puppet.get_by_ktid(message.sender.userId)
await self.handle_remote_message(source, puppet, message)
for intent in self._backfill_leave:
self.log.trace("Leaving room with %s post-backfill", intent.mxid)
await intent.leave_room(self.mxid)
self.log.info("Backfilled %d messages through %s", len(messages), source.mxid)
# region Database getters # region Database getters

View File

@ -31,8 +31,9 @@ from . import matrix as m, portal as p, user as u
from .config import Config from .config import Config
from .db import Puppet as DBPuppet from .db import Puppet as DBPuppet
from .kt.types.bson import Long from .kt.types.bson import Long, StrLong
from .kt.types.user.channel_user_info import DisplayUserInfo
from .kt.client.types import UserInfoUnion
if TYPE_CHECKING: if TYPE_CHECKING:
from .__main__ import KakaoTalkBridge from .__main__ import KakaoTalkBridge
@ -42,7 +43,7 @@ class Puppet(DBPuppet, BasePuppet):
mx: m.MatrixHandler mx: m.MatrixHandler
config: Config config: Config
hs_domain: str hs_domain: str
mxid_template: SimpleTemplate[int] mxid_template: SimpleTemplate[StrLong]
by_ktid: dict[Long, Puppet] = {} by_ktid: dict[Long, Puppet] = {}
by_custom_mxid: dict[UserID, Puppet] = {} by_custom_mxid: dict[UserID, Puppet] = {}
@ -126,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet):
keyword="userid", keyword="userid",
prefix="@", prefix="@",
suffix=f":{Puppet.hs_domain}", suffix=f":{Puppet.hs_domain}",
type=int, type=StrLong,
) )
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"] cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
cls.homeserver_url_map = { cls.homeserver_url_map = {
@ -147,7 +148,7 @@ class Puppet(DBPuppet, BasePuppet):
async def update_info( async def update_info(
self, self,
source: u.User, source: u.User,
info: DisplayUserInfo, info: UserInfoUnion,
update_avatar: bool = True, update_avatar: bool = True,
) -> Puppet: ) -> Puppet:
self._last_info_sync = datetime.now() self._last_info_sync = datetime.now()
@ -161,7 +162,7 @@ class Puppet(DBPuppet, BasePuppet):
self.log.exception(f"Failed to update info from source {source.ktid}") self.log.exception(f"Failed to update info from source {source.ktid}")
return self return self
async def _update_name(self, info: DisplayUserInfo) -> bool: async def _update_name(self, info: UserInfoUnion) -> bool:
name = info.nickname name = info.nickname
if name != self.name or not self.name_set: if name != self.name or not self.name_set:
self.name = name self.name = name
@ -259,7 +260,7 @@ class Puppet(DBPuppet, BasePuppet):
return None return None
@classmethod @classmethod
def get_id_from_mxid(cls, mxid: UserID) -> int | None: def get_id_from_mxid(cls, mxid: UserID) -> Long | None:
return cls.mxid_template.parse(mxid) return cls.mxid_template.parse(mxid)
@classmethod @classmethod

View File

@ -415,6 +415,7 @@ class User(DBUser, BaseUser):
assert self.client assert self.client
try: try:
# TODO if not is_startup, close existing listeners
login_result = await self.client.start() login_result = await self.client.start()
await self._sync_channels(login_result, is_startup) await self._sync_channels(login_result, is_startup)
# TODO connect listeners, even if channel sync fails (except if it's an auth failure) # TODO connect listeners, even if channel sync fails (except if it's an auth failure)
@ -502,7 +503,7 @@ class User(DBUser, BaseUser):
await portal.create_matrix_room(self, portal_info) await portal.create_matrix_room(self, portal_info)
else: else:
await portal.update_matrix_room(self, portal_info) await portal.update_matrix_room(self, portal_info)
await portal.backfill(self, is_initial=False, channel=channel_info) await portal.backfill(self, is_initial=False, channel_info=channel_info)
async def get_notice_room(self) -> RoomID: async def get_notice_room(self) -> RoomID:
if not self.notice_room: if not self.notice_room:

View File

@ -202,6 +202,14 @@ export default class PeerClient {
return loginRes return loginRes
} }
/**
* TODO Consider caching per-user
* @param {string} uuid
*/
async #createAuthClient(uuid) {
return await AuthApiClient.create("KakaoTalk Bridge", uuid)
}
// TODO Wrapper for per-user commands // TODO Wrapper for per-user commands
/** /**
@ -237,14 +245,6 @@ export default class PeerClient {
return await oAuthClient.renew(req.oauth_credential) return await oAuthClient.renew(req.oauth_credential)
} }
/**
* TODO Consider caching per-user
* @param {string} uuid
*/
async #createAuthClient(uuid) {
return await AuthApiClient.create("KakaoTalk Bridge", uuid)
}
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid * @param {string} req.mxid
@ -314,22 +314,33 @@ export default class PeerClient {
* @param {string} req.mxid * @param {string} req.mxid
* @param {Long} req.channel_id * @param {Long} req.channel_id
*/ */
getPortalChannelInfo = (req) => { getPortalChannelInfo = async (req) => {
const userClient = this.#getUser(req.mxid) const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id) const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
/* TODO Decide if this is needed. If it is, make function async!
const res = await talkChannel.updateAll() const res = await talkChannel.updateAll()
if (!res.success) return res if (!res.success) return res
*/
return this.#makeCommandResult({ return this.#makeCommandResult({
name: talkChannel.getDisplayName(), name: talkChannel.getDisplayName(),
//participants: Array.from(talkChannel.getAllUserInfo()), participants: Array.from(talkChannel.getAllUserInfo()),
// TODO Image // TODO Image
}) })
} }
/**
* @param {Object} req
* @param {string} req.mxid
* @param {Long} req.channel_id
* @param {Long?} req.sync_from
*/
getChats = async (req) => {
const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
return await talkChannel.getChatListFrom(req.sync_from)
}
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid * @param {string} req.mxid
@ -421,6 +432,7 @@ export default class PeerClient {
register_device: this.registerDevice, register_device: this.registerDevice,
get_own_profile: this.getOwnProfile, get_own_profile: this.getOwnProfile,
get_portal_channel_info: this.getPortalChannelInfo, get_portal_channel_info: this.getPortalChannelInfo,
get_chats: this.getChats,
get_profile: this.getProfile, get_profile: this.getProfile,
/* /*
send: req => this.puppet.sendMessage(req.chat_id, req.text), send: req => this.puppet.sendMessage(req.chat_id, req.text),