Compare commits

...

4 Commits

18 changed files with 322 additions and 120 deletions

View File

@ -128,6 +128,7 @@ async def enter_dv_code(evt: CommandEvent) -> None:
assert(evt.sender.command_status)
req: dict = evt.sender.command_status["req"]
passcode = evt.content.body
await evt.mark_read()
try:
await KakaoTalkClient.register_device(passcode, **req)
await _do_login(evt, req)
@ -191,6 +192,7 @@ async def reset_device(evt: CommandEvent) -> None:
if await evt.sender.is_logged_in():
await evt.reply("This command requires you to be logged out.")
else:
await evt.mark_read()
await evt.sender.logout(reset_device=True)
await evt.reply(
"Your next login will use a different device ID.\n\n"

View File

@ -71,6 +71,7 @@ async def ping(evt: CommandEvent) -> None:
if not await evt.sender.is_logged_in():
await evt.reply("You're not logged into KakaoTalk")
return
await evt.mark_read()
# try:
own_info = await evt.sender.get_own_info()
# TODO catch errors
@ -99,3 +100,15 @@ async def ping(evt: CommandEvent) -> None:
async def refresh(evt: CommandEvent) -> None:
await evt.sender.refresh(force_notice=True)
"""
@command_handler(
needs_auth=True,
management_only=True,
help_section=SECTION_CONNECTION,
help_text="Resync chats",
)
async def sync(evt: CommandEvent) -> None:
await evt.mark_read()
await evt.sender.post_login(is_startup=False)
await evt.reply("Sync complete")

View File

@ -23,6 +23,8 @@ from attr import dataclass
from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database
from ..kt.types.bson import Long, StrLong
fake_db = Database.create("") if TYPE_CHECKING else None
@ -30,46 +32,44 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Message:
db: ClassVar[Database] = fake_db
# TODO Store all Long values as the same type
mxid: EventID
mx_room: RoomID
ktid: str | None
kt_txn_id: int | None
ktid: Long
index: int
kt_chat: int
kt_receiver: int
kt_sender: int
kt_chat: Long
kt_receiver: Long
timestamp: int
@classmethod
def _from_row(cls, row: Record | None) -> Message | None:
if row is None:
return None
return cls(**row)
columns = 'mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, kt_sender, timestamp'
data = {**row}
ktid = data.pop("ktid")
kt_chat = data.pop("kt_chat")
kt_receiver = data.pop("kt_receiver")
return cls(
**data,
ktid=StrLong(ktid),
kt_chat=Long.from_bytes(kt_chat),
kt_receiver=Long.from_bytes(kt_receiver)
)
@classmethod
async def get_all_by_ktid(cls, ktid: str, kt_receiver: int) -> list[Message]:
def _from_optional_row(cls, row: Record | None) -> Message | None:
return cls._from_row(row) if row is not None else None
columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
@classmethod
async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
rows = await cls.db.fetch(q, ktid, kt_receiver)
rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
return [cls._from_row(row) for row in rows]
@classmethod
async def get_by_ktid(cls, ktid: str, kt_receiver: int, index: int = 0) -> Message | None:
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
return cls._from_row(row)
@classmethod
async def get_by_ktid_or_oti(
cls, ktid: str, oti: int, kt_receiver: int, kt_sender: int, index: int = 0
) -> Message | None:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE (ktid=$1 OR (kt_txn_id=$2 AND kt_sender=$3)) AND "
' kt_receiver=$4 AND "index"=$5'
)
row = await cls.db.fetchrow(q, ktid, oti, kt_sender, kt_receiver, index)
row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
return cls._from_row(row)
@classmethod
@ -83,18 +83,18 @@ class Message:
return cls._from_row(row)
@classmethod
async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
q = (
f"SELECT {cls.columns} "
"FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1"
)
row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
return cls._from_row(row)
@classmethod
async def get_closest_before(
cls, kt_chat: int, kt_receiver: int, timestamp: int
cls, kt_chat: Long, kt_receiver: Long, timestamp: int
) -> Message | None:
q = (
f"SELECT {cls.columns} "
@ -102,23 +102,21 @@ class Message:
" ktid IS NOT NULL "
"ORDER BY timestamp DESC LIMIT 1"
)
row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp)
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
return cls._from_row(row)
_insert_query = (
'INSERT INTO message (mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, '
" kt_sender, timestamp) "
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '
" timestamp) "
"VALUES ($1, $2, $3, $4, $5, $6, $7)"
)
@classmethod
async def bulk_create(
cls,
ktid: str,
oti: int,
kt_chat: int,
kt_receiver: int,
kt_sender: int,
ktid: Long,
kt_chat: Long,
kt_receiver: Long,
event_ids: list[EventID],
timestamp: int,
mx_room: RoomID,
@ -127,7 +125,7 @@ class Message:
return
columns = [col.strip('"') for col in cls.columns.split(", ")]
records = [
(mxid, mx_room, ktid, oti, index, kt_chat, kt_receiver, kt_sender, timestamp)
(mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp)
for index, mxid in enumerate(event_ids)
]
async with cls.db.acquire() as conn, conn.transaction():
@ -142,19 +140,17 @@ class Message:
q,
self.mxid,
self.mx_room,
self.ktid,
self.kt_txn_id,
str(self.ktid),
self.index,
self.kt_chat,
self.kt_receiver,
self.kt_sender,
bytes(self.kt_chat),
bytes(self.kt_receiver),
self.timestamp,
)
async def delete(self) -> None:
q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
await self.db.execute(q, self.ktid, self.kt_receiver, self.index)
await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index)
async def update(self) -> None:
q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room)
await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room)

View File

@ -50,7 +50,7 @@ class Portal:
data = {**row}
ktid = data.pop("ktid")
kt_receiver = data.pop("kt_receiver")
return cls(**data, ktid=Long.from_optional_bytes(ktid), kt_receiver=Long.from_optional_bytes(kt_receiver))
return cls(**data, ktid=Long.from_bytes(ktid), kt_receiver=Long.from_bytes(kt_receiver))
@classmethod
def _from_optional_row(cls, row: Record | None) -> Portal | None:

View File

@ -186,7 +186,6 @@ bridge:
# Whether or not the KakaoTalk users of logged in Matrix users should be
# invited to private chats when backfilling history from KakaoTalk. This is
# usually needed to prevent rate limits and to allow timestamp massaging.
# TODO Is this necessary?
invite_own_puppet: true
# Maximum number of messages to backfill initially.
# Set to 0 to disable backfilling when creating portal.

View File

@ -43,7 +43,11 @@ from ..types.chat.chat import Chatlog
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import (
deserialize_result,
ResultType, RootCommandResult, CommandResultDoneValue)
ResultType,
ResultListType,
RootCommandResult,
CommandResultDoneValue
)
from .types import ChannelInfoUnion
from .types import PortalChannelInfo
@ -155,7 +159,8 @@ class Client:
) -> _RequestContextManager:
# TODO Is auth ever needed?
headers = {
**self._headers,
# TODO Are any default headers needed?
#**self._headers,
**(headers or {}),
}
url = URL(url)
@ -185,16 +190,24 @@ class Client:
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
return login_result
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
return profile_req_struct.profile
"""
async def is_connected(self) -> bool:
resp = await self._rpc_client.request("is_connected")
return resp["is_connected"]
"""
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
return profile_req_struct.profile
async def get_profile(self, user_id: Long) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(
ProfileReqStruct,
"get_profile",
user_id=user_id.serialize()
)
return profile_req_struct.profile
async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo:
req = await self._api_user_request_result(
PortalChannelInfo,
@ -204,13 +217,13 @@ class Client:
req.channel_info = channel_info
return req
async def get_profile(self, user_id: Long) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(
ProfileReqStruct,
"get_profile",
user_id=user_id.serialize()
)
return profile_req_struct.profile
async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]:
return (await self._api_user_request_result(
ResultListType(Chatlog),
"get_chats",
channel_id=channel_id.serialize(),
sync_from=sync_from.serialize() if sync_from else None
))[-limit if limit else 0:]
async def stop(self) -> None:
# TODO Stop all event handlers

View File

@ -15,23 +15,44 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Custom wrapper classes around types defined by the KakaoTalk API."""
from typing import Optional, Union
from typing import Optional, NewType, Union
from attr import dataclass
from mautrix.types import SerializableAttrs
from mautrix.types import SerializableAttrs, JSON, deserializer
from ..types.channel.channel_info import NormalChannelInfo
from ..types.openlink.open_channel_info import OpenChannelInfo
from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo
ChannelInfoUnion = Union[NormalChannelInfo, OpenChannelInfo]
UserInfoUnion = Union[NormalChannelUserInfo, OpenChannelUserInfo]
ChannelInfoUnion = NewType("ChannelInfoUnion", Union[NormalChannelInfo, OpenChannelInfo])
@deserializer(ChannelInfoUnion)
def deserialize_channel_info_union(data: JSON) -> ChannelInfoUnion:
if "openLink" in data:
return OpenChannelInfo.deserialize(data)
else:
return NormalChannelInfo.deserialize(data)
setattr(ChannelInfoUnion, "deserialize", deserialize_channel_info_union)
UserInfoUnion = NewType("UserInfoUnion", Union[NormalChannelUserInfo, OpenChannelUserInfo])
@deserializer(UserInfoUnion)
def deserialize_user_info_union(data: JSON) -> UserInfoUnion:
if "perm" in data:
return OpenChannelUserInfo.deserialize(data)
else:
return NormalChannelUserInfo.deserialize(data)
setattr(UserInfoUnion, "deserialize", deserialize_user_info_union)
@dataclass
class PortalChannelInfo(SerializableAttrs):
name: str
#participants: list[PuppetUserInfo]
participants: list[UserInfoUnion]
# TODO Image
channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller

View File

@ -13,6 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from attr import dataclass
from mautrix.types import SerializableAttrs
@ -23,9 +25,9 @@ from .mention import MentionStruct
@dataclass(kw_only=True)
class Attachment(SerializableAttrs):
shout: bool | None = None
mentions: list[MentionStruct] | None = None
urls: list[str] | None = None
shout: Optional[bool] = None
mentions: Optional[list[MentionStruct]] = None
urls: Optional[list[str]] = None
@dataclass

View File

@ -32,11 +32,11 @@ class Long(SerializableAttrs):
return cls(**bson.loads(raw))
@classmethod
def from_optional_bytes(cls, raw: bytes | None) -> Optional["Long"]:
def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]:
return cls(**bson.loads(raw)) if raw is not None else None
@classmethod
def to_optional_bytes(cls, value: Optional["Long"]) -> bytes | None:
def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]:
return bytes(value) if value is not None else None
def serialize(self) -> JSON:
@ -48,13 +48,34 @@ class Long(SerializableAttrs):
return bson.dumps(asdict(self))
def __int__(self) -> int:
# TODO Is this right?
return self.high << 32 + self.low
if self.unsigned:
pass
result = \
((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \
( self.low + (1 << 32 if self.low < 0 else 0))
return result + (1 << 32 if self.unsigned and result < 0 else 0)
def __str__(self) -> str:
return f"{self.high << 32 if self.high else ''}{self.low}"
return str(int(self))
ZERO: ClassVar["Long"]
Long.ZERO = Long(0, 0, False)
class IntLong(Long):
"""Helper class for constructing a Long from an int."""
def __init__(self, val: int):
if val < 0:
pass
super().__init__(
high=(val & 0xffffffff00000000) >> 32,
low = val & 0x00000000ffffffff,
unsigned=val < 0,
)
class StrLong(IntLong):
"""Helper class for constructing a Long from the string representation of an int."""
def __init__(self, val: str):
super().__init__(int(val))

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Optional
from attr import dataclass
@ -48,7 +48,7 @@ class ChannelInfo(Channel):
newChatCountInvalid: bool
lastChatLogId: Long
lastSeenLogId: Long
lastChatLog: Chatlog | None = None
lastChatLog: Optional[Chatlog] = None
metaMap: ChannelMetaMap
displayUserList: list[DisplayUserInfo]
pushAlert: bool

View File

@ -23,12 +23,12 @@ class KnownChannelType(str, Enum):
DirectChat = "DirectChat"
PlusChat = "PlusChat"
MemoChat = "MemoChat"
OM = "OM"
OD = "OD"
OM = "OM" # "OpenMulti"?
OD = "OD" # "OpenDirect"?
@classmethod
def is_direct(cls, value: Union["KnownChannelType", str]) -> bool:
return str in [KnownChannelType.DirectChat, KnownChannelType.OD]
return value == KnownChannelType.DirectChat
ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str

View File

@ -13,6 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from attr import dataclass
from mautrix.types import JSON
@ -25,7 +27,7 @@ from . import OpenTokenComponent, OpenLink
@dataclass(kw_only=True)
class OpenChannelInfo(ChannelInfo, OpenChannel, OpenTokenComponent):
directChannel: bool
openLink: OpenLink | None = None
openLink: Optional[OpenLink] = None
@dataclass

View File

@ -13,12 +13,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Generic, Type, TypeVar, Union
from typing import Generic, Type, TypeVar, Union, Iterable
from attr import dataclass
from enum import IntEnum
from mautrix.types import SerializableAttrs, JSON
from mautrix.types import Serializable, SerializableAttrs, JSON
from .api.auth_api_client import KnownAuthStatusCode
@ -80,7 +80,22 @@ class RootCommandResult(ResponseState):
success: bool
ResultType = TypeVar("ResultType", bound=SerializableAttrs)
ResultType = TypeVar("ResultType", bound=Serializable)
def ResultListType(result_type: Type[ResultType]):
class _ResultListType(list[result_type], Serializable):
def __init__(self, iterable: Iterable[result_type]=()):
list.__init__(self, (result_type.deserialize(x) for x in iterable))
def serialize(self) -> list[JSON]:
return [v.serialize() for v in self]
@classmethod
def deserialize(cls, data: list[JSON]) -> "_ResultListType":
return cls(data)
return _ResultListType
@dataclass
class CommandResultDoneValue(RootCommandResult, Generic[ResultType]):

View File

@ -13,6 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from attr import dataclass
from mautrix.types import SerializableAttrs
@ -28,7 +30,7 @@ class ChannelUser(SerializableAttrs):
@dataclass(kw_only=True)
class PartialOpenLinkComponent(SerializableAttrs):
"""Substitute for Partial<OpenLinkComponent>"""
linkId: Long | None = None
linkId: Optional[Long] = None
@dataclass

View File

@ -45,12 +45,14 @@ from .db import (
Message as DBMessage,
Portal as DBPortal,
)
from .formatter.from_kakaotalk import kakaotalk_to_matrix
from .kt.types.bson import Long
from .kt.types.bson import Long, IntLong
from .kt.types.channel.channel_info import ChannelInfo
from .kt.types.channel.channel_type import KnownChannelType, ChannelType
from .kt.types.user.channel_user_info import DisplayUserInfo
from .kt.types.chat.chat import Chatlog
from .kt.client.types import PortalChannelInfo
from .kt.client.types import UserInfoUnion, PortalChannelInfo
if TYPE_CHECKING:
from .__main__ import KakaoTalkBridge
@ -201,7 +203,7 @@ class Portal(DBPortal, BasePortal):
#self._update_photo(source, info.image),
)
)
changed = await self._update_participants(source, info.channel_info.displayUserList) or changed
changed = await self._update_participants(source, info.participants) or changed
if changed or force_save:
await self.update_bridge_info()
await self.save()
@ -342,7 +344,7 @@ class Portal(DBPortal, BasePortal):
)
"""
async def _update_participants(self, source: u.User, participants: list[DisplayUserInfo]) -> bool:
async def _update_participants(self, source: u.User, participants: list[UserInfoUnion]) -> bool:
changed = False
# TODO nick_map?
for participant in participants:
@ -556,10 +558,10 @@ class Portal(DBPortal, BasePortal):
)
if not self.is_direct:
await self._update_participants(source, info.channel_info.displayUserList)
await self._update_participants(source, info.participants)
try:
await self.backfill(source, is_initial=True, channel=info.channel_info)
await self.backfill(source, is_initial=True, channel_info=info.channel_info)
except Exception:
self.log.exception("Failed to backfill new portal")
@ -752,31 +754,131 @@ class Portal(DBPortal, BasePortal):
self,
source: u.User,
sender: p.Puppet,
message: str,
reply_to: None = None,
message: Chatlog,
reply_to: Chatlog | None = None,
) -> None:
try:
await self._handle_remote_message(source, sender, message, reply_to)
except Exception:
self.log.exception(
"Error handling Kakaotalk message <TODO: ID>"
"Error handling KakaoTalk message %s",
message.logId,
)
async def _handle_remote_message(
self,
source: u.User,
sender: p.Puppet,
message: str,
reply_to: None = None,
message: Chatlog,
reply_to: Chatlog | None = None,
) -> None:
self.log.info("TODO")
self.log.debug(f"Handling KakaoTalk event {message.logId}")
if not self.mxid:
mxid = await self.create_matrix_room(source)
if not mxid:
# Failed to create
return
if not await self._bridge_own_message_pm(source, sender, f"message {message.logId}"):
return
intent = sender.intent_for(self)
if (
self._backfill_leave is not None
and self.ktid != sender.ktid
and intent != sender.intent
and intent not in self._backfill_leave
):
self.log.debug("Adding %s's default puppet to room for backfilling", sender.mxid)
await self.main_intent.invite_user(self.mxid, intent.mxid)
await intent.ensure_joined(self.mxid)
self._backfill_leave.add(intent)
if message.attachment:
self.log.info("TODO: _handle_remote_message attachments")
if message.supplement:
self.log.info("TODO: _handle_remote_message supplements")
if message.text:
content = await kakaotalk_to_matrix(message.text)
event_id = await self._send_message(intent, content, timestamp=message.sendAt)
await DBMessage(
mxid=event_id,
mx_room=self.mxid,
ktid=message.logId,
index=0,
kt_chat=self.ktid,
kt_receiver=self.kt_receiver,
timestamp=message.sendAt,
).insert()
await self._send_delivery_receipt(event_id)
else:
self.log.warning(f"Unhandled KakaoTalk message {message.logId}")
# TODO Many more remote handlers
# endregion
async def backfill(self, source: u.User, is_initial: bool, channel: PortalChannelInfo) -> None:
self.log.info("TODO: backfill")
async def backfill(self, source: u.User, is_initial: bool, channel_info: ChannelInfo) -> None:
limit = (
self.config["bridge.backfill.initial_limit"]
if is_initial
else self.config["bridge.backfill.missed_limit"]
)
if limit == 0:
return
elif limit < 0:
limit = None
last_log_id = None
if not is_initial and channel_info.lastChatLog:
last_log_id = channel_info.lastChatLog.logId
most_recent = await DBMessage.get_most_recent(self.ktid, self.kt_receiver)
if most_recent and is_initial:
self.log.debug("Not backfilling %s: already bridged messages found", self.ktid_log)
# TODO Should this be removed? With it, can't sync an empty portal!
#elif (not most_recent or not most_recent.timestamp) and not is_initial:
# self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log)
elif last_log_id and most_recent and int(most_recent.ktid) >= int(last_log_id):
self.log.debug(
"Not backfilling %s: last activity is equal to most recent bridged "
"message (%s >= %s)",
self.ktid_log,
most_recent.ktid,
last_log_id,
)
else:
with self.backfill_lock:
await self._backfill(
source,
limit,
most_recent.ktid if most_recent else None,
channel_info=channel_info,
)
async def _backfill(
self,
source: u.User,
limit: int | None,
after_log_id: Long | None,
channel_info: ChannelInfo,
) -> None:
self.log.debug("Backfilling history through %s", source.mxid)
self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid))
messages = await source.client.get_chats(
channel_info.channelId,
limit,
after_log_id
)
if not messages:
self.log.debug("Didn't get any messages from server")
return
self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server")
self._backfill_leave = set()
async with NotificationDisabler(self.mxid, source):
for message in messages:
puppet = await p.Puppet.get_by_ktid(message.sender.userId)
await self.handle_remote_message(source, puppet, message)
for intent in self._backfill_leave:
self.log.trace("Leaving room with %s post-backfill", intent.mxid)
await intent.leave_room(self.mxid)
self.log.info("Backfilled %d messages through %s", len(messages), source.mxid)
# region Database getters

View File

@ -31,8 +31,9 @@ from . import matrix as m, portal as p, user as u
from .config import Config
from .db import Puppet as DBPuppet
from .kt.types.bson import Long
from .kt.types.user.channel_user_info import DisplayUserInfo
from .kt.types.bson import Long, StrLong
from .kt.client.types import UserInfoUnion
if TYPE_CHECKING:
from .__main__ import KakaoTalkBridge
@ -42,7 +43,7 @@ class Puppet(DBPuppet, BasePuppet):
mx: m.MatrixHandler
config: Config
hs_domain: str
mxid_template: SimpleTemplate[int]
mxid_template: SimpleTemplate[StrLong]
by_ktid: dict[Long, Puppet] = {}
by_custom_mxid: dict[UserID, Puppet] = {}
@ -126,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet):
keyword="userid",
prefix="@",
suffix=f":{Puppet.hs_domain}",
type=int,
type=StrLong,
)
cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
cls.homeserver_url_map = {
@ -147,7 +148,7 @@ class Puppet(DBPuppet, BasePuppet):
async def update_info(
self,
source: u.User,
info: DisplayUserInfo,
info: UserInfoUnion,
update_avatar: bool = True,
) -> Puppet:
self._last_info_sync = datetime.now()
@ -161,7 +162,7 @@ class Puppet(DBPuppet, BasePuppet):
self.log.exception(f"Failed to update info from source {source.ktid}")
return self
async def _update_name(self, info: DisplayUserInfo) -> bool:
async def _update_name(self, info: UserInfoUnion) -> bool:
name = info.nickname
if name != self.name or not self.name_set:
self.name = name
@ -259,7 +260,7 @@ class Puppet(DBPuppet, BasePuppet):
return None
@classmethod
def get_id_from_mxid(cls, mxid: UserID) -> int | None:
def get_id_from_mxid(cls, mxid: UserID) -> Long | None:
return cls.mxid_template.parse(mxid)
@classmethod

View File

@ -415,6 +415,7 @@ class User(DBUser, BaseUser):
assert self.client
try:
# TODO if not is_startup, close existing listeners
login_result = await self.client.start()
await self._sync_channels(login_result, is_startup)
# TODO connect listeners, even if channel sync fails (except if it's an auth failure)
@ -502,7 +503,7 @@ class User(DBUser, BaseUser):
await portal.create_matrix_room(self, portal_info)
else:
await portal.update_matrix_room(self, portal_info)
await portal.backfill(self, is_initial=False, channel=channel_info)
await portal.backfill(self, is_initial=False, channel_info=channel_info)
async def get_notice_room(self) -> RoomID:
if not self.notice_room:

View File

@ -202,6 +202,14 @@ export default class PeerClient {
return loginRes
}
/**
* TODO Consider caching per-user
* @param {string} uuid
*/
async #createAuthClient(uuid) {
return await AuthApiClient.create("KakaoTalk Bridge", uuid)
}
// TODO Wrapper for per-user commands
/**
@ -237,14 +245,6 @@ export default class PeerClient {
return await oAuthClient.renew(req.oauth_credential)
}
/**
* TODO Consider caching per-user
* @param {string} uuid
*/
async #createAuthClient(uuid) {
return await AuthApiClient.create("KakaoTalk Bridge", uuid)
}
/**
* @param {Object} req
* @param {string} req.mxid
@ -314,22 +314,33 @@ export default class PeerClient {
* @param {string} req.mxid
* @param {Long} req.channel_id
*/
getPortalChannelInfo = (req) => {
getPortalChannelInfo = async (req) => {
const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
/* TODO Decide if this is needed. If it is, make function async!
const res = await talkChannel.updateAll()
if (!res.success) return res
*/
return this.#makeCommandResult({
name: talkChannel.getDisplayName(),
//participants: Array.from(talkChannel.getAllUserInfo()),
participants: Array.from(talkChannel.getAllUserInfo()),
// TODO Image
})
}
/**
* @param {Object} req
* @param {string} req.mxid
* @param {Long} req.channel_id
* @param {Long?} req.sync_from
*/
getChats = async (req) => {
const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
return await talkChannel.getChatListFrom(req.sync_from)
}
/**
* @param {Object} req
* @param {string} req.mxid
@ -421,6 +432,7 @@ export default class PeerClient {
register_device: this.registerDevice,
get_own_profile: this.getOwnProfile,
get_portal_channel_info: this.getPortalChannelInfo,
get_chats: this.getChats,
get_profile: this.getProfile,
/*
send: req => this.puppet.sendMessage(req.chat_id, req.text),