First crack at incoming messages

This commit is contained in:
Andrew Ferrazzutti 2022-03-10 02:46:24 -05:00
parent 6e6c6f5c48
commit 66b66bd27b
5 changed files with 134 additions and 46 deletions

View File

@ -42,7 +42,7 @@ class Message:
timestamp: int
@classmethod
def _from_row(cls, row: Record | None) -> Message | None:
def _from_row(cls, row: Record) -> Message | None:
data = {**row}
ktid = data.pop("ktid")
kt_chat = data.pop("kt_chat")
@ -64,13 +64,13 @@ class Message:
async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
return [cls._from_row(row) for row in rows]
return [cls._from_row(row) for row in rows if row]
@classmethod
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
return cls._from_row(row)
return cls._from_optional_row(row)
@classmethod
async def delete_all_by_room(cls, room_id: RoomID) -> None:
@ -80,7 +80,7 @@ class Message:
async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2"
row = await cls.db.fetchrow(q, mxid, mx_room)
return cls._from_row(row)
return cls._from_optional_row(row)
@classmethod
async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
@ -90,7 +90,7 @@ class Message:
"ORDER BY timestamp DESC LIMIT 1"
)
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
return cls._from_row(row)
return cls._from_optional_row(row)
@classmethod
async def get_closest_before(
@ -103,7 +103,7 @@ class Message:
"ORDER BY timestamp DESC LIMIT 1"
)
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
return cls._from_row(row)
return cls._from_optional_row(row)
_insert_query = (
'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '

View File

@ -39,6 +39,7 @@ from ...rpc import RPCClient
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
from ..types.bson import Long
from ..types.client.client_session import LoginResult
from ..types.channel.channel_type import ChannelType
from ..types.chat.chat import Chatlog
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import (
@ -208,14 +209,12 @@ class Client:
)
return profile_req_struct.profile
async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo:
req = await self._api_user_request_result(
async def get_portal_channel_info(self, channel_id: Long) -> PortalChannelInfo:
return await self._api_user_request_result(
PortalChannelInfo,
"get_portal_channel_info",
channel_id=channel_info.channelId.serialize()
channel_id=channel_id.serialize()
)
req.channel_info = channel_info
return req
async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]:
return (await self._api_user_request_result(
@ -233,6 +232,10 @@ class Client:
text=text
)
async def start_listen(self) -> None:
# TODO Connect all listeners here?
await self._api_user_request_void("start_listen")
async def stop(self) -> None:
# TODO Stop all event handlers
await self._api_user_request_void("stop")
@ -269,15 +272,19 @@ class Client:
# region listeners
async def on_message(self, func: Callable[[Chatlog, Long], Awaitable[None]]) -> None:
async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None:
async def wrapper(data: dict[str, JSON]) -> None:
await func(Chatlog.deserialize(data["chatlog"]), data["channelId"])
await func(
Chatlog.deserialize(data["chatlog"]),
Long.deserialize(data["channelId"]),
data["channelType"]
)
self._add_user_handler("chat", wrapper)
self._add_user_handler("message", wrapper)
def _add_user_handler(self, command: str, handler: EventHandler) -> str:
self._rpc_client.add_event_handler(f"{command}:{self.mxid}", handler)
self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler)
# endregion

View File

@ -93,13 +93,14 @@ class Portal(DBPortal, BasePortal):
_main_intent: IntentAPI | None
_create_room_lock: asyncio.Lock
_dedup: deque[str]
_oti_dedup: dict[int, DBMessage]
_send_locks: dict[int, asyncio.Lock]
_noop_lock: FakeLock = FakeLock()
_typing: set[UserID]
backfill_lock: SimpleLock
_backfill_leave: set[IntentAPI] | None
_sleeping_to_resync: bool
_scheduled_resync: asyncio.Task | None
_resync_targets: dict[int, p.Puppet]
def __init__(
self,
@ -132,10 +133,11 @@ class Portal(DBPortal, BasePortal):
self._main_intent = None
self._create_room_lock = asyncio.Lock()
self._dedup = deque(maxlen=100)
self._oti_dedup = {}
self._send_locks = {}
self._typing = set()
self._sleeping_to_resync = False
self._scheduled_resync = None
self._resync_targets = {}
self.backfill_lock = SimpleLock(
"Waiting for backfilling to finish before handling %s", log=self.log
@ -190,12 +192,45 @@ class Portal(DBPortal, BasePortal):
# endregion
# region Chat info updating
def schedule_resync(self, source: u.User, target: p.Puppet) -> None:
self._resync_targets[target.ktid] = target
if (
self._sleeping_to_resync
and self._scheduled_resync
and not self._scheduled_resync.done()
):
return
self._sleeping_to_resync = True
self.log.debug(f"Scheduling resync through {source.mxid}/{source.ktid}")
self._scheduled_resync = asyncio.create_task(self._sleep_and_resync(source, 10))
async def _sleep_and_resync(self, source: u.User, sleep: int) -> None:
await asyncio.sleep(sleep)
targets = self._resync_targets
self._sleeping_to_resync = False
self._resync_targets = {}
for puppet in targets.values():
if not puppet.name or not puppet.name_set:
break
else:
self.log.debug(
f"Cancelled resync through {source.mxid}/{source.ktid}, all puppets have names"
)
return
self.log.debug(f"Resyncing chat through {source.mxid}/{source.ktid} after sleeping")
await self.update_info(source)
self._scheduled_resync = None
self.log.debug(f"Completed scheduled resync through {source.mxid}/{source.ktid}")
async def update_info(
self,
source: u.User,
info: PortalChannelInfo,
info: PortalChannelInfo | None = None,
force_save: bool = False,
) -> None:
) -> PortalChannelInfo | None:
if not info:
self.log.debug("Called update_info with no info, fetching channel info...")
info = await source.client.get_portal_channel_info(self.ktid)
changed = False
if not self.is_direct:
changed = any(
@ -209,6 +244,7 @@ class Portal(DBPortal, BasePortal):
if changed or force_save:
await self.update_bridge_info()
await self.save()
return info
"""
@classmethod
@ -365,7 +401,7 @@ class Portal(DBPortal, BasePortal):
# endregion
# region Matrix room creation
async def update_matrix_room(self, source: u.User, info: PortalChannelInfo) -> None:
async def update_matrix_room(self, source: u.User, info: PortalChannelInfo | None = None) -> None:
try:
await self._update_matrix_room(source, info)
except Exception:
@ -380,7 +416,7 @@ class Portal(DBPortal, BasePortal):
return invite_content
async def _update_matrix_room(
self, source: u.User, info: PortalChannelInfo
self, source: u.User, info: PortalChannelInfo | None = None
) -> None:
puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
await self.main_intent.invite_user(
@ -394,7 +430,10 @@ class Portal(DBPortal, BasePortal):
if did_join and self.is_direct:
await source.update_direct_chats({self.main_intent.mxid: [self.mxid]})
await self.update_info(source, info)
info = await self.update_info(source, info)
if not info:
self.log.warning("Canceling _update_matrix_room as update_info didn't return info")
return
# TODO
#await self._sync_read_receipts(info.read_receipts.nodes)
@ -421,7 +460,7 @@ class Portal(DBPortal, BasePortal):
"""
async def create_matrix_room(
self, source: u.User, info: PortalChannelInfo
self, source: u.User, info: PortalChannelInfo | None = None
) -> RoomID | None:
if self.mxid:
try:
@ -474,7 +513,7 @@ class Portal(DBPortal, BasePortal):
self.log.warning("Failed to update bridge info", exc_info=True)
async def _create_matrix_room(
self, source: u.User, info: PortalChannelInfo
self, source: u.User, info: PortalChannelInfo | None = None
) -> RoomID | None:
if self.mxid:
await self._update_matrix_room(source, info)
@ -507,7 +546,10 @@ class Portal(DBPortal, BasePortal):
if self.is_direct:
invites.append(self.az.bot_mxid)
await self.update_info(source=source, info=info)
info = await self.update_info(source=source, info=info)
if not info:
self.log.debug("update_info() didn't return info, cancelling room creation")
return None
if self.encrypted or not self.is_direct:
name = self.name

View File

@ -39,14 +39,12 @@ from .kt.client import Client
from .kt.client.errors import AuthenticationRequired, ResponseError
from .kt.types.api.struct.profile import ProfileStruct
from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo
from .kt.types.channel.channel_info import NormalChannelData
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import ChannelLoginDataItem
from .kt.types.client.client_session import LoginResult
from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData
from .kt.types.openlink.open_channel_info import OpenChannelInfo
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels")
METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync")
@ -418,7 +416,7 @@ class User(DBUser, BaseUser):
# TODO if not is_startup, close existing listeners
login_result = await self.client.start()
await self._sync_channels(login_result, is_startup)
# TODO connect listeners, even if channel sync fails (except if it's an auth failure)
self.start_listen()
except AuthenticationRequired as e:
await self.send_bridge_notice(
f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n",
@ -498,7 +496,8 @@ class User(DBUser, BaseUser):
kt_receiver=self.ktid,
kt_type=channel_info.type
)
portal_info = await self.client.get_portal_channel_info(channel_info)
portal_info = await self.client.get_portal_channel_info(channel_info.channelId)
portal_info.channel_info = channel_info
if not portal.mxid:
await portal.create_matrix_room(self, portal_info)
else:
@ -585,6 +584,30 @@ class User(DBUser, BaseUser):
# region KakaoTalk event handling
def start_listen(self) -> None:
self.listen_task = asyncio.create_task(self._try_listen())
def _disconnect_listener_after_error(self) -> None:
self.log.info("TODO: _disconnect_listener_after_error")
async def _try_listen(self) -> None:
try:
# TODO Pass all listeners to start_listen instead of registering them one-by-one?
await self.client.start_listen()
await self.client.on_message(self.on_message)
# TODO Handle auth errors specially?
#except AuthenticationRequired as e:
except Exception:
#self.is_connected = False
self.log.exception("Fatal error in listener")
await self.send_bridge_notice(
"Fatal error in listener (see logs for more info)",
state_event=BridgeStateEvent.UNKNOWN_ERROR,
important=True,
error_code="kt-connection-error",
)
self._disconnect_listener_after_error()
def stop_listen(self) -> None:
self.log.info("TODO: stop_listen")
@ -603,8 +626,18 @@ class User(DBUser, BaseUser):
asyncio.create_task(self.post_login(is_startup=True))
@async_time(METRIC_MESSAGE)
async def on_message(self, evt: Chatlog, channel_id: Long) -> None:
self.log.info("TODO: on_message")
async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None:
portal = await po.Portal.get_by_ktid(
channel_id,
kt_receiver=self.ktid,
kt_type=channel_type
)
puppet = await pu.Puppet.get_by_ktid(evt.sender.userId)
await portal.backfill_lock.wait(evt.logId)
if not puppet.name:
portal.schedule_resync(self, puppet)
# TODO reply_to
await portal.handle_remote_message(self, puppet, evt)
# TODO Many more handlers

View File

@ -266,28 +266,33 @@ export default class PeerClient {
const res = await userClient.talkClient.login(req.oauth_credential)
if (!res.success) return res
// Attach listeners in something like start_listen
/*
this.userClients.set(req.mxid, userClient)
return res
}
startListen = async (req) => {
const userClient = this.#getUser(req.mxid)
userClient.talkClient.on("chat", (data, channel) => {
this.log(`Found message in channel ${channel.channelId}`)
this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`)
return this.#write({
id: --this.notificationID,
command: userClient.getCmd("chat"),
//is_sequential: true, // TODO make sequential per user!
chatlog: data.chat(),
command: userClient.getCmd("message"),
//is_sequential: true, // TODO Make sequential per user & channel (if it isn't already)
chatlog: data.chat,
channelId: channel.channelId,
channelType: channel.info.type,
})
})
/*
/* TODO Many more listeners
userClient.talkClient.on("chat_read", (chat, channel, reader) => {
this.log(`chat_read in channel ${channel.channelId}`)
//chat.logId
})
*/
this.userClients.set(req.mxid, userClient)
return res
return this.#voidCommandResult
}
/**
@ -448,6 +453,7 @@ export default class PeerClient {
renew: this.handleRenew,
generate_uuid: util.randomAndroidSubDeviceUUID,
register_device: this.registerDevice,
start_listen: this.startListen,
get_own_profile: this.getOwnProfile,
get_portal_channel_info: this.getPortalChannelInfo,
get_chats: this.getChats,