First crack at incoming messages

This commit is contained in:
Andrew Ferrazzutti 2022-03-10 02:46:24 -05:00
parent 6e6c6f5c48
commit 66b66bd27b
5 changed files with 134 additions and 46 deletions

View File

@ -42,7 +42,7 @@ class Message:
timestamp: int timestamp: int
@classmethod @classmethod
def _from_row(cls, row: Record | None) -> Message | None: def _from_row(cls, row: Record) -> Message | None:
data = {**row} data = {**row}
ktid = data.pop("ktid") ktid = data.pop("ktid")
kt_chat = data.pop("kt_chat") kt_chat = data.pop("kt_chat")
@ -64,13 +64,13 @@ class Message:
async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]: async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2" q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver)) rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
return [cls._from_row(row) for row in rows] return [cls._from_row(row) for row in rows if row]
@classmethod @classmethod
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None: async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3' q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index) row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
return cls._from_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
async def delete_all_by_room(cls, room_id: RoomID) -> None: async def delete_all_by_room(cls, room_id: RoomID) -> None:
@ -80,7 +80,7 @@ class Message:
async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None: async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2" q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2"
row = await cls.db.fetchrow(q, mxid, mx_room) row = await cls.db.fetchrow(q, mxid, mx_room)
return cls._from_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None: async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
@ -90,7 +90,7 @@ class Message:
"ORDER BY timestamp DESC LIMIT 1" "ORDER BY timestamp DESC LIMIT 1"
) )
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver)) row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
return cls._from_row(row) return cls._from_optional_row(row)
@classmethod @classmethod
async def get_closest_before( async def get_closest_before(
@ -103,7 +103,7 @@ class Message:
"ORDER BY timestamp DESC LIMIT 1" "ORDER BY timestamp DESC LIMIT 1"
) )
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp) row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
return cls._from_row(row) return cls._from_optional_row(row)
_insert_query = ( _insert_query = (
'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, ' 'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '

View File

@ -39,6 +39,7 @@ from ...rpc import RPCClient
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
from ..types.bson import Long from ..types.bson import Long
from ..types.client.client_session import LoginResult from ..types.client.client_session import LoginResult
from ..types.channel.channel_type import ChannelType
from ..types.chat.chat import Chatlog from ..types.chat.chat import Chatlog
from ..types.oauth import OAuthCredential, OAuthInfo from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.request import ( from ..types.request import (
@ -208,14 +209,12 @@ class Client:
) )
return profile_req_struct.profile return profile_req_struct.profile
async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo: async def get_portal_channel_info(self, channel_id: Long) -> PortalChannelInfo:
req = await self._api_user_request_result( return await self._api_user_request_result(
PortalChannelInfo, PortalChannelInfo,
"get_portal_channel_info", "get_portal_channel_info",
channel_id=channel_info.channelId.serialize() channel_id=channel_id.serialize()
) )
req.channel_info = channel_info
return req
async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]: async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]:
return (await self._api_user_request_result( return (await self._api_user_request_result(
@ -233,6 +232,10 @@ class Client:
text=text text=text
) )
async def start_listen(self) -> None:
# TODO Connect all listeners here?
await self._api_user_request_void("start_listen")
async def stop(self) -> None: async def stop(self) -> None:
# TODO Stop all event handlers # TODO Stop all event handlers
await self._api_user_request_void("stop") await self._api_user_request_void("stop")
@ -269,15 +272,19 @@ class Client:
# region listeners # region listeners
async def on_message(self, func: Callable[[Chatlog, Long], Awaitable[None]]) -> None: async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None:
async def wrapper(data: dict[str, JSON]) -> None: async def wrapper(data: dict[str, JSON]) -> None:
await func(Chatlog.deserialize(data["chatlog"]), data["channelId"]) await func(
Chatlog.deserialize(data["chatlog"]),
Long.deserialize(data["channelId"]),
data["channelType"]
)
self._add_user_handler("chat", wrapper) self._add_user_handler("message", wrapper)
def _add_user_handler(self, command: str, handler: EventHandler) -> str: def _add_user_handler(self, command: str, handler: EventHandler) -> str:
self._rpc_client.add_event_handler(f"{command}:{self.mxid}", handler) self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler)
# endregion # endregion

View File

@ -93,13 +93,14 @@ class Portal(DBPortal, BasePortal):
_main_intent: IntentAPI | None _main_intent: IntentAPI | None
_create_room_lock: asyncio.Lock _create_room_lock: asyncio.Lock
_dedup: deque[str]
_oti_dedup: dict[int, DBMessage]
_send_locks: dict[int, asyncio.Lock] _send_locks: dict[int, asyncio.Lock]
_noop_lock: FakeLock = FakeLock() _noop_lock: FakeLock = FakeLock()
_typing: set[UserID] _typing: set[UserID]
backfill_lock: SimpleLock backfill_lock: SimpleLock
_backfill_leave: set[IntentAPI] | None _backfill_leave: set[IntentAPI] | None
_sleeping_to_resync: bool
_scheduled_resync: asyncio.Task | None
_resync_targets: dict[int, p.Puppet]
def __init__( def __init__(
self, self,
@ -132,10 +133,11 @@ class Portal(DBPortal, BasePortal):
self._main_intent = None self._main_intent = None
self._create_room_lock = asyncio.Lock() self._create_room_lock = asyncio.Lock()
self._dedup = deque(maxlen=100)
self._oti_dedup = {}
self._send_locks = {} self._send_locks = {}
self._typing = set() self._typing = set()
self._sleeping_to_resync = False
self._scheduled_resync = None
self._resync_targets = {}
self.backfill_lock = SimpleLock( self.backfill_lock = SimpleLock(
"Waiting for backfilling to finish before handling %s", log=self.log "Waiting for backfilling to finish before handling %s", log=self.log
@ -190,12 +192,45 @@ class Portal(DBPortal, BasePortal):
# endregion # endregion
# region Chat info updating # region Chat info updating
def schedule_resync(self, source: u.User, target: p.Puppet) -> None:
self._resync_targets[target.ktid] = target
if (
self._sleeping_to_resync
and self._scheduled_resync
and not self._scheduled_resync.done()
):
return
self._sleeping_to_resync = True
self.log.debug(f"Scheduling resync through {source.mxid}/{source.ktid}")
self._scheduled_resync = asyncio.create_task(self._sleep_and_resync(source, 10))
async def _sleep_and_resync(self, source: u.User, sleep: int) -> None:
await asyncio.sleep(sleep)
targets = self._resync_targets
self._sleeping_to_resync = False
self._resync_targets = {}
for puppet in targets.values():
if not puppet.name or not puppet.name_set:
break
else:
self.log.debug(
f"Cancelled resync through {source.mxid}/{source.ktid}, all puppets have names"
)
return
self.log.debug(f"Resyncing chat through {source.mxid}/{source.ktid} after sleeping")
await self.update_info(source)
self._scheduled_resync = None
self.log.debug(f"Completed scheduled resync through {source.mxid}/{source.ktid}")
async def update_info( async def update_info(
self, self,
source: u.User, source: u.User,
info: PortalChannelInfo, info: PortalChannelInfo | None = None,
force_save: bool = False, force_save: bool = False,
) -> None: ) -> PortalChannelInfo | None:
if not info:
self.log.debug("Called update_info with no info, fetching channel info...")
info = await source.client.get_portal_channel_info(self.ktid)
changed = False changed = False
if not self.is_direct: if not self.is_direct:
changed = any( changed = any(
@ -209,6 +244,7 @@ class Portal(DBPortal, BasePortal):
if changed or force_save: if changed or force_save:
await self.update_bridge_info() await self.update_bridge_info()
await self.save() await self.save()
return info
""" """
@classmethod @classmethod
@ -365,7 +401,7 @@ class Portal(DBPortal, BasePortal):
# endregion # endregion
# region Matrix room creation # region Matrix room creation
async def update_matrix_room(self, source: u.User, info: PortalChannelInfo) -> None: async def update_matrix_room(self, source: u.User, info: PortalChannelInfo | None = None) -> None:
try: try:
await self._update_matrix_room(source, info) await self._update_matrix_room(source, info)
except Exception: except Exception:
@ -380,7 +416,7 @@ class Portal(DBPortal, BasePortal):
return invite_content return invite_content
async def _update_matrix_room( async def _update_matrix_room(
self, source: u.User, info: PortalChannelInfo self, source: u.User, info: PortalChannelInfo | None = None
) -> None: ) -> None:
puppet = await p.Puppet.get_by_custom_mxid(source.mxid) puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
await self.main_intent.invite_user( await self.main_intent.invite_user(
@ -394,7 +430,10 @@ class Portal(DBPortal, BasePortal):
if did_join and self.is_direct: if did_join and self.is_direct:
await source.update_direct_chats({self.main_intent.mxid: [self.mxid]}) await source.update_direct_chats({self.main_intent.mxid: [self.mxid]})
await self.update_info(source, info) info = await self.update_info(source, info)
if not info:
self.log.warning("Canceling _update_matrix_room as update_info didn't return info")
return
# TODO # TODO
#await self._sync_read_receipts(info.read_receipts.nodes) #await self._sync_read_receipts(info.read_receipts.nodes)
@ -421,7 +460,7 @@ class Portal(DBPortal, BasePortal):
""" """
async def create_matrix_room( async def create_matrix_room(
self, source: u.User, info: PortalChannelInfo self, source: u.User, info: PortalChannelInfo | None = None
) -> RoomID | None: ) -> RoomID | None:
if self.mxid: if self.mxid:
try: try:
@ -474,7 +513,7 @@ class Portal(DBPortal, BasePortal):
self.log.warning("Failed to update bridge info", exc_info=True) self.log.warning("Failed to update bridge info", exc_info=True)
async def _create_matrix_room( async def _create_matrix_room(
self, source: u.User, info: PortalChannelInfo self, source: u.User, info: PortalChannelInfo | None = None
) -> RoomID | None: ) -> RoomID | None:
if self.mxid: if self.mxid:
await self._update_matrix_room(source, info) await self._update_matrix_room(source, info)
@ -507,7 +546,10 @@ class Portal(DBPortal, BasePortal):
if self.is_direct: if self.is_direct:
invites.append(self.az.bot_mxid) invites.append(self.az.bot_mxid)
await self.update_info(source=source, info=info) info = await self.update_info(source=source, info=info)
if not info:
self.log.debug("update_info() didn't return info, cancelling room creation")
return None
if self.encrypted or not self.is_direct: if self.encrypted or not self.is_direct:
name = self.name name = self.name

View File

@ -39,14 +39,12 @@ from .kt.client import Client
from .kt.client.errors import AuthenticationRequired, ResponseError from .kt.client.errors import AuthenticationRequired, ResponseError
from .kt.types.api.struct.profile import ProfileStruct from .kt.types.api.struct.profile import ProfileStruct
from .kt.types.bson import Long from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_info import NormalChannelData from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import ChannelLoginDataItem from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
from .kt.types.client.client_session import LoginResult
from .kt.types.oauth import OAuthCredential from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
from .kt.types.openlink.open_channel_info import OpenChannelInfo
METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels") METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels")
METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync") METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync")
@ -418,7 +416,7 @@ class User(DBUser, BaseUser):
# TODO if not is_startup, close existing listeners # TODO if not is_startup, close existing listeners
login_result = await self.client.start() login_result = await self.client.start()
await self._sync_channels(login_result, is_startup) await self._sync_channels(login_result, is_startup)
# TODO connect listeners, even if channel sync fails (except if it's an auth failure) self.start_listen()
except AuthenticationRequired as e: except AuthenticationRequired as e:
await self.send_bridge_notice( await self.send_bridge_notice(
f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n", f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n",
@ -498,7 +496,8 @@ class User(DBUser, BaseUser):
kt_receiver=self.ktid, kt_receiver=self.ktid,
kt_type=channel_info.type kt_type=channel_info.type
) )
portal_info = await self.client.get_portal_channel_info(channel_info) portal_info = await self.client.get_portal_channel_info(channel_info.channelId)
portal_info.channel_info = channel_info
if not portal.mxid: if not portal.mxid:
await portal.create_matrix_room(self, portal_info) await portal.create_matrix_room(self, portal_info)
else: else:
@ -585,6 +584,30 @@ class User(DBUser, BaseUser):
# region KakaoTalk event handling # region KakaoTalk event handling
def start_listen(self) -> None:
self.listen_task = asyncio.create_task(self._try_listen())
def _disconnect_listener_after_error(self) -> None:
self.log.info("TODO: _disconnect_listener_after_error")
async def _try_listen(self) -> None:
try:
# TODO Pass all listeners to start_listen instead of registering them one-by-one?
await self.client.start_listen()
await self.client.on_message(self.on_message)
# TODO Handle auth errors specially?
#except AuthenticationRequired as e:
except Exception:
#self.is_connected = False
self.log.exception("Fatal error in listener")
await self.send_bridge_notice(
"Fatal error in listener (see logs for more info)",
state_event=BridgeStateEvent.UNKNOWN_ERROR,
important=True,
error_code="kt-connection-error",
)
self._disconnect_listener_after_error()
def stop_listen(self) -> None: def stop_listen(self) -> None:
self.log.info("TODO: stop_listen") self.log.info("TODO: stop_listen")
@ -603,8 +626,18 @@ class User(DBUser, BaseUser):
asyncio.create_task(self.post_login(is_startup=True)) asyncio.create_task(self.post_login(is_startup=True))
@async_time(METRIC_MESSAGE) @async_time(METRIC_MESSAGE)
async def on_message(self, evt: Chatlog, channel_id: Long) -> None: async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None:
self.log.info("TODO: on_message") portal = await po.Portal.get_by_ktid(
channel_id,
kt_receiver=self.ktid,
kt_type=channel_type
)
puppet = await pu.Puppet.get_by_ktid(evt.sender.userId)
await portal.backfill_lock.wait(evt.logId)
if not puppet.name:
portal.schedule_resync(self, puppet)
# TODO reply_to
await portal.handle_remote_message(self, puppet, evt)
# TODO Many more handlers # TODO Many more handlers

View File

@ -266,28 +266,33 @@ export default class PeerClient {
const res = await userClient.talkClient.login(req.oauth_credential) const res = await userClient.talkClient.login(req.oauth_credential)
if (!res.success) return res if (!res.success) return res
// Attach listeners in something like start_listen this.userClients.set(req.mxid, userClient)
/* return res
}
startListen = async (req) => {
const userClient = this.#getUser(req.mxid)
userClient.talkClient.on("chat", (data, channel) => { userClient.talkClient.on("chat", (data, channel) => {
this.log(`Found message in channel ${channel.channelId}`) this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`)
return this.#write({ return this.#write({
id: --this.notificationID, id: --this.notificationID,
command: userClient.getCmd("chat"), command: userClient.getCmd("message"),
//is_sequential: true, // TODO make sequential per user! //is_sequential: true, // TODO Make sequential per user & channel (if it isn't already)
chatlog: data.chat(), chatlog: data.chat,
channelId: channel.channelId, channelId: channel.channelId,
channelType: channel.info.type,
}) })
}) })
/* /* TODO Many more listeners
userClient.talkClient.on("chat_read", (chat, channel, reader) => { userClient.talkClient.on("chat_read", (chat, channel, reader) => {
this.log(`chat_read in channel ${channel.channelId}`) this.log(`chat_read in channel ${channel.channelId}`)
//chat.logId //chat.logId
}) })
*/ */
this.userClients.set(req.mxid, userClient) return this.#voidCommandResult
return res
} }
/** /**
@ -448,6 +453,7 @@ export default class PeerClient {
renew: this.handleRenew, renew: this.handleRenew,
generate_uuid: util.randomAndroidSubDeviceUUID, generate_uuid: util.randomAndroidSubDeviceUUID,
register_device: this.registerDevice, register_device: this.registerDevice,
start_listen: this.startListen,
get_own_profile: this.getOwnProfile, get_own_profile: this.getOwnProfile,
get_portal_channel_info: this.getPortalChannelInfo, get_portal_channel_info: this.getPortalChannelInfo,
get_chats: this.getChats, get_chats: this.getChats,