Puppets and backfilling
parent e28694c987
commit 0c9550841c
@@ -99,3 +99,15 @@ async def ping(evt: CommandEvent) -> None:
async def refresh(evt: CommandEvent) -> None:
    await evt.sender.refresh(force_notice=True)
"""


@command_handler(
    needs_auth=True,
    management_only=True,
    help_section=SECTION_CONNECTION,
    help_text="Resync chats",
)
async def sync(evt: CommandEvent) -> None:
    await evt.mark_read()
    await evt.sender.post_login(is_startup=False)
    await evt.reply("Sync complete")

@@ -23,6 +23,8 @@ from attr import dataclass
from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database

from ..kt.types.bson import Long, StrLong

fake_db = Database.create("") if TYPE_CHECKING else None


@@ -30,46 +32,44 @@ fake_db = Database.create("") if TYPE_CHECKING else None
class Message:
    db: ClassVar[Database] = fake_db

    # TODO Store all Long values as the same type
    mxid: EventID
    mx_room: RoomID
    ktid: str | None
    kt_txn_id: int | None
    ktid: Long
    index: int
    kt_chat: int
    kt_receiver: int
    kt_sender: int
    kt_chat: Long
    kt_receiver: Long
    timestamp: int

    @classmethod
    def _from_row(cls, row: Record | None) -> Message | None:
        if row is None:
            return None
        return cls(**row)

    columns = 'mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, kt_sender, timestamp'
        data = {**row}
        ktid = data.pop("ktid")
        kt_chat = data.pop("kt_chat")
        kt_receiver = data.pop("kt_receiver")
        return cls(
            **data,
            ktid=StrLong(ktid),
            kt_chat=Long.from_bytes(kt_chat),
            kt_receiver=Long.from_bytes(kt_receiver)
        )

    @classmethod
    async def get_all_by_ktid(cls, ktid: str, kt_receiver: int) -> list[Message]:
    def _from_optional_row(cls, row: Record | None) -> Message | None:
        return cls._from_row(row) if row is not None else None

    columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'

    @classmethod
    async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
        q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
        rows = await cls.db.fetch(q, ktid, kt_receiver)
        rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
        return [cls._from_row(row) for row in rows]

    @classmethod
    async def get_by_ktid(cls, ktid: str, kt_receiver: int, index: int = 0) -> Message | None:
    async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
        q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
        row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
        return cls._from_row(row)

    @classmethod
    async def get_by_ktid_or_oti(
        cls, ktid: str, oti: int, kt_receiver: int, kt_sender: int, index: int = 0
    ) -> Message | None:
        q = (
            f"SELECT {cls.columns} "
            "FROM message WHERE (ktid=$1 OR (kt_txn_id=$2 AND kt_sender=$3)) AND "
            ' kt_receiver=$4 AND "index"=$5'
        )
        row = await cls.db.fetchrow(q, ktid, oti, kt_sender, kt_receiver, index)
        row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
        return cls._from_row(row)

    @classmethod
@@ -83,18 +83,18 @@ class Message:
        return cls._from_row(row)

    @classmethod
    async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
    async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
        q = (
            f"SELECT {cls.columns} "
            "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
            "ORDER BY timestamp DESC LIMIT 1"
        )
        row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
        row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
        return cls._from_row(row)

    @classmethod
    async def get_closest_before(
        cls, kt_chat: int, kt_receiver: int, timestamp: int
        cls, kt_chat: Long, kt_receiver: Long, timestamp: int
    ) -> Message | None:
        q = (
            f"SELECT {cls.columns} "
@@ -102,23 +102,21 @@ class Message:
            " ktid IS NOT NULL "
            "ORDER BY timestamp DESC LIMIT 1"
        )
        row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp)
        row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
        return cls._from_row(row)

    _insert_query = (
        'INSERT INTO message (mxid, mx_room, ktid, kt_txn_id, "index", kt_chat, kt_receiver, '
        " kt_sender, timestamp) "
        "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
        'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '
        " timestamp) "
        "VALUES ($1, $2, $3, $4, $5, $6, $7)"
    )

    @classmethod
    async def bulk_create(
        cls,
        ktid: str,
        oti: int,
        kt_chat: int,
        kt_receiver: int,
        kt_sender: int,
        ktid: Long,
        kt_chat: Long,
        kt_receiver: Long,
        event_ids: list[EventID],
        timestamp: int,
        mx_room: RoomID,
@@ -127,7 +125,7 @@ class Message:
            return
        columns = [col.strip('"') for col in cls.columns.split(", ")]
        records = [
            (mxid, mx_room, ktid, oti, index, kt_chat, kt_receiver, kt_sender, timestamp)
            (mxid, mx_room, str(ktid), index, bytes(kt_chat), bytes(kt_receiver), timestamp)
            for index, mxid in enumerate(event_ids)
        ]
        async with cls.db.acquire() as conn, conn.transaction():
@@ -142,19 +140,17 @@ class Message:
            q,
            self.mxid,
            self.mx_room,
            self.ktid,
            self.kt_txn_id,
            str(self.ktid),
            self.index,
            self.kt_chat,
            self.kt_receiver,
            self.kt_sender,
            bytes(self.kt_chat),
            bytes(self.kt_receiver),
            self.timestamp,
        )

    async def delete(self) -> None:
        q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
        await self.db.execute(q, self.ktid, self.kt_receiver, self.index)
        await self.db.execute(q, str(self.ktid), bytes(self.kt_receiver), self.index)

    async def update(self) -> None:
        q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
        await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room)
        await self.db.execute(q, str(self.ktid), self.timestamp, self.mxid, self.mx_room)

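Throughout this file, the new schema stores ktid as its decimal string and kt_chat/kt_receiver as BSON-encoded bytes. A minimal round-trip sketch of both conversions, assuming the Long/IntLong/StrLong helpers defined in kt/types/bson.py further down in this commit (the ID values are placeholders):

ktid = StrLong("3021685019850486")
assert str(ktid) == "3021685019850486"         # value written to the ktid column

kt_chat = IntLong(18007756)
blob = bytes(kt_chat)                          # BSON blob written to kt_chat / kt_receiver
assert int(Long.from_bytes(blob)) == 18007756  # what _from_row() reconstructs
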
@@ -186,7 +186,6 @@ bridge:
    # Whether or not the KakaoTalk users of logged in Matrix users should be
    # invited to private chats when backfilling history from KakaoTalk. This is
    # usually needed to prevent rate limits and to allow timestamp massaging.
    # TODO Is this necessary?
    invite_own_puppet: true
    # Maximum number of messages to backfill initially.
    # Set to 0 to disable backfilling when creating portal.

@@ -43,7 +43,11 @@ from ..types.chat.chat import Chatlog
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import (
    deserialize_result,
    ResultType, RootCommandResult, CommandResultDoneValue)
    ResultType,
    ResultListType,
    RootCommandResult,
    CommandResultDoneValue
)

from .types import ChannelInfoUnion
from .types import PortalChannelInfo
@@ -155,7 +159,8 @@ class Client:
    ) -> _RequestContextManager:
        # TODO Is auth ever needed?
        headers = {
            **self._headers,
            # TODO Are any default headers needed?
            #**self._headers,
            **(headers or {}),
        }
        url = URL(url)
@@ -185,16 +190,24 @@ class Client:
        assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
        return login_result

    async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    """
    async def is_connected(self) -> bool:
        resp = await self._rpc_client.request("is_connected")
        return resp["is_connected"]
    """

    async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct:
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize()
        )
        return profile_req_struct.profile

    async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo:
        req = await self._api_user_request_result(
            PortalChannelInfo,
@@ -204,13 +217,13 @@ class Client:
        req.channel_info = channel_info
        return req

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize()
        )
        return profile_req_struct.profile
    async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]:
        return (await self._api_user_request_result(
            ResultListType(Chatlog),
            "get_chats",
            channel_id=channel_id.serialize(),
            sync_from=sync_from.serialize() if sync_from else None
        ))[-limit if limit else 0:]

    async def stop(self) -> None:
        # TODO Stop all event handlers

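The trailing slice in get_chats() keeps only the newest limit entries of whatever the node side returns; a falsy limit (0 or None) keeps everything. The same expression on a plain list, purely as an illustration:

chats = ["c1", "c2", "c3", "c4", "c5"]
limit = 2
assert chats[-limit if limit else 0:] == ["c4", "c5"]    # newest two
limit = None
assert chats[-limit if limit else 0:] == chats           # falsy limit keeps everything
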
@@ -15,23 +15,44 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Custom wrapper classes around types defined by the KakaoTalk API."""

from typing import Optional, Union
from typing import Optional, NewType, Union

from attr import dataclass

from mautrix.types import SerializableAttrs
from mautrix.types import SerializableAttrs, JSON, deserializer

from ..types.channel.channel_info import NormalChannelInfo
from ..types.openlink.open_channel_info import OpenChannelInfo
from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo


ChannelInfoUnion = Union[NormalChannelInfo, OpenChannelInfo]
UserInfoUnion = Union[NormalChannelUserInfo, OpenChannelUserInfo]
ChannelInfoUnion = NewType("ChannelInfoUnion", Union[NormalChannelInfo, OpenChannelInfo])

@deserializer(ChannelInfoUnion)
def deserialize_channel_info_union(data: JSON) -> ChannelInfoUnion:
    if "openLink" in data:
        return OpenChannelInfo.deserialize(data)
    else:
        return NormalChannelInfo.deserialize(data)

setattr(ChannelInfoUnion, "deserialize", deserialize_channel_info_union)


UserInfoUnion = NewType("UserInfoUnion", Union[NormalChannelUserInfo, OpenChannelUserInfo])

@deserializer(UserInfoUnion)
def deserialize_user_info_union(data: JSON) -> UserInfoUnion:
    if "perm" in data:
        return OpenChannelUserInfo.deserialize(data)
    else:
        return NormalChannelUserInfo.deserialize(data)

setattr(UserInfoUnion, "deserialize", deserialize_user_info_union)


@dataclass
class PortalChannelInfo(SerializableAttrs):
    name: str
    #participants: list[PuppetUserInfo]
    participants: list[UserInfoUnion]
    # TODO Image
    channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller

@@ -32,11 +32,11 @@ class Long(SerializableAttrs):
        return cls(**bson.loads(raw))

    @classmethod
    def from_optional_bytes(cls, raw: bytes | None) -> Optional["Long"]:
    def from_optional_bytes(cls, raw: Optional[bytes]) -> Optional["Long"]:
        return cls(**bson.loads(raw)) if raw is not None else None

    @classmethod
    def to_optional_bytes(cls, value: Optional["Long"]) -> bytes | None:
    def to_optional_bytes(cls, value: Optional["Long"]) -> Optional[bytes]:
        return bytes(value) if value is not None else None

    def serialize(self) -> JSON:
@@ -48,13 +48,34 @@ class Long(SerializableAttrs):
        return bson.dumps(asdict(self))

    def __int__(self) -> int:
        # TODO Is this right?
        return self.high << 32 + self.low
        if self.unsigned:
            pass
        result = \
            ((self.high + (1 << 32 if self.high < 0 else 0)) << 32) + \
            ( self.low + (1 << 32 if self.low < 0 else 0))
        return result + (1 << 32 if self.unsigned and result < 0 else 0)

    def __str__(self) -> str:
        return f"{self.high << 32 if self.high else ''}{self.low}"
        return str(int(self))

    ZERO: ClassVar["Long"]


Long.ZERO = Long(0, 0, False)


class IntLong(Long):
    """Helper class for constructing a Long from an int."""
    def __init__(self, val: int):
        if val < 0:
            pass
        super().__init__(
            high=(val & 0xffffffff00000000) >> 32,
            low = val & 0x00000000ffffffff,
            unsigned=val < 0,
        )


class StrLong(IntLong):
    """Helper class for constructing a Long from the string representation of an int."""
    def __init__(self, val: str):
        super().__init__(int(val))

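A quick illustrative check of the signed-halves reconstruction above (BSON stores high/low as signed 32-bit integers), assuming the Long, IntLong and StrLong definitions from this hunk:

assert int(Long(high=1, low=-1, unsigned=False)) == 2**33 - 1   # negative low half gets re-wrapped
assert int(IntLong(2**33 - 1)) == 2**33 - 1                     # positive round trip
assert str(StrLong("8589934591")) == "8589934591"               # 8589934591 == 2**33 - 1
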
@@ -23,12 +23,12 @@ class KnownChannelType(str, Enum):
    DirectChat = "DirectChat"
    PlusChat = "PlusChat"
    MemoChat = "MemoChat"
    OM = "OM"
    OD = "OD"
    OM = "OM" # "OpenMulti"?
    OD = "OD" # "OpenDirect"?

    @classmethod
    def is_direct(cls, value: Union["KnownChannelType", str]) -> bool:
        return str in [KnownChannelType.DirectChat, KnownChannelType.OD]
        return value == KnownChannelType.DirectChat


ChannelType = Union[KnownChannelType, str] # Substitute for ChannelType = "name1" | ... | "nameN" | str

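Because KnownChannelType subclasses str, raw strings from the API compare equal to the enum members, which is what the corrected is_direct relies on. A quick sketch of the new behaviour:

assert KnownChannelType.is_direct("DirectChat")
assert KnownChannelType.is_direct(KnownChannelType.DirectChat)
assert not KnownChannelType.is_direct("OD")   # only DirectChat counts as direct now
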
@@ -13,12 +13,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Generic, Type, TypeVar, Union
from typing import Generic, Type, TypeVar, Union, Iterable

from attr import dataclass
from enum import IntEnum

from mautrix.types import SerializableAttrs, JSON
from mautrix.types import Serializable, SerializableAttrs, JSON

from .api.auth_api_client import KnownAuthStatusCode

@@ -80,7 +80,22 @@ class RootCommandResult(ResponseState):
    success: bool


ResultType = TypeVar("ResultType", bound=SerializableAttrs)
ResultType = TypeVar("ResultType", bound=Serializable)

def ResultListType(result_type: Type[ResultType]):
    class _ResultListType(list[result_type], Serializable):
        def __init__(self, iterable: Iterable[result_type]=()):
            list.__init__(self, (result_type.deserialize(x) for x in iterable))

        def serialize(self) -> list[JSON]:
            return [v.serialize() for v in self]

        @classmethod
        def deserialize(cls, data: list[JSON]) -> "_ResultListType":
            return cls(data)

    return _ResultListType


@dataclass
class CommandResultDoneValue(RootCommandResult, Generic[ResultType]):

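ResultListType builds a one-off list subclass whose deserialize maps each element through the wrapped type; this is how get_chats in kt/client/client.py receives a list of Chatlog objects. A minimal usage sketch (the empty payload is only a placeholder):

ChatlogList = ResultListType(Chatlog)
chatlogs = ChatlogList.deserialize([])   # a list[JSON] from the RPC layer goes here
assert isinstance(chatlogs, list) and chatlogs.serialize() == []
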
@@ -45,12 +45,14 @@ from .db import (
    Message as DBMessage,
    Portal as DBPortal,
)
from .formatter.from_kakaotalk import kakaotalk_to_matrix

from .kt.types.bson import Long
from .kt.types.bson import Long, IntLong
from .kt.types.channel.channel_info import ChannelInfo
from .kt.types.channel.channel_type import KnownChannelType, ChannelType
from .kt.types.user.channel_user_info import DisplayUserInfo
from .kt.types.chat.chat import Chatlog

from .kt.client.types import PortalChannelInfo
from .kt.client.types import UserInfoUnion, PortalChannelInfo

if TYPE_CHECKING:
    from .__main__ import KakaoTalkBridge
@@ -201,7 +203,7 @@ class Portal(DBPortal, BasePortal):
            #self._update_photo(source, info.image),
        )
        )
        changed = await self._update_participants(source, info.channel_info.displayUserList) or changed
        changed = await self._update_participants(source, info.participants) or changed
        if changed or force_save:
            await self.update_bridge_info()
            await self.save()
@@ -342,7 +344,7 @@ class Portal(DBPortal, BasePortal):
        )
    """

    async def _update_participants(self, source: u.User, participants: list[DisplayUserInfo]) -> bool:
    async def _update_participants(self, source: u.User, participants: list[UserInfoUnion]) -> bool:
        changed = False
        # TODO nick_map?
        for participant in participants:
@@ -556,10 +558,10 @@ class Portal(DBPortal, BasePortal):
        )

        if not self.is_direct:
            await self._update_participants(source, info.channel_info.displayUserList)
            await self._update_participants(source, info.participants)

        try:
            await self.backfill(source, is_initial=True, channel=info.channel_info)
            await self.backfill(source, is_initial=True, channel_info=info.channel_info)
        except Exception:
            self.log.exception("Failed to backfill new portal")

@@ -752,31 +754,131 @@ class Portal(DBPortal, BasePortal):
        self,
        source: u.User,
        sender: p.Puppet,
        message: str,
        reply_to: None = None,
        message: Chatlog,
        reply_to: Chatlog | None = None,
    ) -> None:
        try:
            await self._handle_remote_message(source, sender, message, reply_to)
        except Exception:
            self.log.exception(
                "Error handling Kakaotalk message <TODO: ID>"
                "Error handling KakaoTalk message %s",
                message.logId,
            )

    async def _handle_remote_message(
        self,
        source: u.User,
        sender: p.Puppet,
        message: str,
        reply_to: None = None,
        message: Chatlog,
        reply_to: Chatlog | None = None,
    ) -> None:
        self.log.info("TODO")
        self.log.debug(f"Handling KakaoTalk event {message.logId}")
        if not self.mxid:
            mxid = await self.create_matrix_room(source)
            if not mxid:
                # Failed to create
                return
        if not await self._bridge_own_message_pm(source, sender, f"message {message.logId}"):
            return
        intent = sender.intent_for(self)
        if (
            self._backfill_leave is not None
            and self.ktid != sender.ktid
            and intent != sender.intent
            and intent not in self._backfill_leave
        ):
            self.log.debug("Adding %s's default puppet to room for backfilling", sender.mxid)
            await self.main_intent.invite_user(self.mxid, intent.mxid)
            await intent.ensure_joined(self.mxid)
            self._backfill_leave.add(intent)

        if message.attachment:
            self.log.info("TODO: _handle_remote_message attachments")
        if message.supplement:
            self.log.info("TODO: _handle_remote_message supplements")
        if message.text:
            content = await kakaotalk_to_matrix(message.text)
            event_id = await self._send_message(intent, content, timestamp=message.sendAt)
            await DBMessage(
                mxid=event_id,
                mx_room=self.mxid,
                ktid=message.logId,
                index=0,
                kt_chat=self.ktid,
                kt_receiver=self.kt_receiver,
                timestamp=message.sendAt,
            ).insert()
            await self._send_delivery_receipt(event_id)
        else:
            self.log.warning(f"Unhandled KakaoTalk message {message.logId}")

    # TODO Many more remote handlers

    # endregion

    async def backfill(self, source: u.User, is_initial: bool, channel: PortalChannelInfo) -> None:
        self.log.info("TODO: backfill")
    async def backfill(self, source: u.User, is_initial: bool, channel_info: ChannelInfo) -> None:
        limit = (
            self.config["bridge.backfill.initial_limit"]
            if is_initial
            else self.config["bridge.backfill.missed_limit"]
        )
        if limit == 0:
            return
        elif limit < 0:
            limit = None
        last_log_id = None
        if not is_initial and channel_info.lastChatLog:
            last_log_id = channel_info.lastChatLog.logId
        most_recent = await DBMessage.get_most_recent(self.ktid, self.kt_receiver)
        if most_recent and is_initial:
            self.log.debug("Not backfilling %s: already bridged messages found", self.ktid_log)
        # TODO Should this be removed? With it, can't sync an empty portal!
        #elif (not most_recent or not most_recent.timestamp) and not is_initial:
        # self.log.debug("Not backfilling %s: no most recent message found", self.ktid_log)
        elif last_log_id and most_recent and int(most_recent.ktid) >= int(last_log_id):
            self.log.debug(
                "Not backfilling %s: last activity is equal to most recent bridged "
                "message (%s >= %s)",
                self.ktid_log,
                most_recent.ktid,
                last_log_id,
            )
        else:
            with self.backfill_lock:
                await self._backfill(
                    source,
                    limit,
                    most_recent.ktid if most_recent else None,
                    channel_info=channel_info,
                )

    async def _backfill(
        self,
        source: u.User,
        limit: int | None,
        after_log_id: Long | None,
        channel_info: ChannelInfo,
    ) -> None:
        self.log.debug("Backfilling history through %s", source.mxid)
        self.log.debug("Fetching %s messages through %s", f"up to {limit}" if limit else "all", str(source.ktid))
        messages = await source.client.get_chats(
            channel_info.channelId,
            limit,
            after_log_id
        )
        if not messages:
            self.log.debug("Didn't get any messages from server")
            return
        self.log.debug(f"Got {len(messages)} message{'s' if len(messages) is not 1 else ''} from server")
        self._backfill_leave = set()
        async with NotificationDisabler(self.mxid, source):
            for message in messages:
                puppet = await p.Puppet.get_by_ktid(message.sender.userId)
                await self.handle_remote_message(source, puppet, message)
        for intent in self._backfill_leave:
            self.log.trace("Leaving room with %s post-backfill", intent.mxid)
            await intent.leave_room(self.mxid)
        self.log.info("Backfilled %d messages through %s", len(messages), source.mxid)

    # region Database getters

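For reference, backfill() above interprets the configured limit as follows: 0 disables backfilling, a negative value means no cap, and a positive value caps the fetch. A small sketch of that mapping (not code from the commit):

def interpret_limit(configured: int) -> tuple[bool, int | None]:
    """Return (should_backfill, limit passed on to _backfill)."""
    if configured == 0:
        return False, None                 # skip backfill entirely
    return True, (None if configured < 0 else configured)

assert interpret_limit(0) == (False, None)
assert interpret_limit(-1) == (True, None)
assert interpret_limit(25) == (True, 25)
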
@@ -31,8 +31,9 @@ from . import matrix as m, portal as p, user as u
from .config import Config
from .db import Puppet as DBPuppet

from .kt.types.bson import Long
from .kt.types.user.channel_user_info import DisplayUserInfo
from .kt.types.bson import Long, StrLong

from .kt.client.types import UserInfoUnion

if TYPE_CHECKING:
    from .__main__ import KakaoTalkBridge
@@ -42,7 +43,7 @@ class Puppet(DBPuppet, BasePuppet):
    mx: m.MatrixHandler
    config: Config
    hs_domain: str
    mxid_template: SimpleTemplate[int]
    mxid_template: SimpleTemplate[StrLong]

    by_ktid: dict[Long, Puppet] = {}
    by_custom_mxid: dict[UserID, Puppet] = {}
@@ -126,7 +127,7 @@ class Puppet(DBPuppet, BasePuppet):
            keyword="userid",
            prefix="@",
            suffix=f":{Puppet.hs_domain}",
            type=int,
            type=StrLong,
        )
        cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
        cls.homeserver_url_map = {
@@ -147,7 +148,7 @@ class Puppet(DBPuppet, BasePuppet):
    async def update_info(
        self,
        source: u.User,
        info: DisplayUserInfo,
        info: UserInfoUnion,
        update_avatar: bool = True,
    ) -> Puppet:
        self._last_info_sync = datetime.now()
@@ -161,7 +162,7 @@ class Puppet(DBPuppet, BasePuppet):
            self.log.exception(f"Failed to update info from source {source.ktid}")
        return self

    async def _update_name(self, info: DisplayUserInfo) -> bool:
    async def _update_name(self, info: UserInfoUnion) -> bool:
        name = info.nickname
        if name != self.name or not self.name_set:
            self.name = name
@@ -259,7 +260,7 @@ class Puppet(DBPuppet, BasePuppet):
        return None

    @classmethod
    def get_id_from_mxid(cls, mxid: UserID) -> int | None:
    def get_id_from_mxid(cls, mxid: UserID) -> Long | None:
        return cls.mxid_template.parse(mxid)

    @classmethod

@@ -415,6 +415,7 @@ class User(DBUser, BaseUser):

        assert self.client
        try:
            # TODO if not is_startup, close existing listeners
            login_result = await self.client.start()
            await self._sync_channels(login_result, is_startup)
            # TODO connect listeners, even if channel sync fails (except if it's an auth failure)
@@ -502,7 +503,7 @@ class User(DBUser, BaseUser):
            await portal.create_matrix_room(self, portal_info)
        else:
            await portal.update_matrix_room(self, portal_info)
            await portal.backfill(self, is_initial=False, channel=channel_info)
            await portal.backfill(self, is_initial=False, channel_info=channel_info)

    async def get_notice_room(self) -> RoomID:
        if not self.notice_room:

@@ -202,6 +202,14 @@ export default class PeerClient {
        return loginRes
    }

    /**
     * TODO Consider caching per-user
     * @param {string} uuid
     */
    async #createAuthClient(uuid) {
        return await AuthApiClient.create("KakaoTalk Bridge", uuid)
    }

    // TODO Wrapper for per-user commands

    /**
@@ -237,14 +245,6 @@ export default class PeerClient {
        return await oAuthClient.renew(req.oauth_credential)
    }

    /**
     * TODO Consider caching per-user
     * @param {string} uuid
     */
    async #createAuthClient(uuid) {
        return await AuthApiClient.create("KakaoTalk Bridge", uuid)
    }

    /**
     * @param {Object} req
     * @param {string} req.mxid
@@ -314,22 +314,33 @@ export default class PeerClient {
     * @param {string} req.mxid
     * @param {Long} req.channel_id
     */
    getPortalChannelInfo = (req) => {
    getPortalChannelInfo = async (req) => {
        const userClient = this.#getUser(req.mxid)
        const talkChannel = userClient.talkClient.channelList.get(req.channel_id)

        /* TODO Decide if this is needed. If it is, make function async!
        const res = await talkChannel.updateAll()
        if (!res.success) return res
        */

        return this.#makeCommandResult({
            name: talkChannel.getDisplayName(),
            //participants: Array.from(talkChannel.getAllUserInfo()),
            participants: Array.from(talkChannel.getAllUserInfo()),
            // TODO Image
        })
    }

    /**
     * @param {Object} req
     * @param {string} req.mxid
     * @param {Long} req.channel_id
     * @param {Long?} req.sync_from
     */
    getChats = async (req) => {
        const userClient = this.#getUser(req.mxid)
        const talkChannel = userClient.talkClient.channelList.get(req.channel_id)

        return await talkChannel.getChatListFrom(req.sync_from)
    }

    /**
     * @param {Object} req
     * @param {string} req.mxid
@@ -421,6 +432,7 @@ export default class PeerClient {
            register_device: this.registerDevice,
            get_own_profile: this.getOwnProfile,
            get_portal_channel_info: this.getPortalChannelInfo,
            get_chats: this.getChats,
            get_profile: this.getProfile,
            /*
            send: req => this.puppet.sendMessage(req.chat_id, req.text),