First crack at incoming messages
This commit is contained in:
parent
6e6c6f5c48
commit
66b66bd27b
|
@ -42,7 +42,7 @@ class Message:
|
|||
timestamp: int
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record | None) -> Message | None:
|
||||
def _from_row(cls, row: Record) -> Message | None:
|
||||
data = {**row}
|
||||
ktid = data.pop("ktid")
|
||||
kt_chat = data.pop("kt_chat")
|
||||
|
@ -64,13 +64,13 @@ class Message:
|
|||
async def get_all_by_ktid(cls, ktid: Long, kt_receiver: Long) -> list[Message]:
|
||||
q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
|
||||
rows = await cls.db.fetch(q, str(ktid), bytes(kt_receiver))
|
||||
return [cls._from_row(row) for row in rows]
|
||||
return [cls._from_row(row) for row in rows if row]
|
||||
|
||||
@classmethod
|
||||
async def get_by_ktid(cls, ktid: Long, kt_receiver: Long, index: int = 0) -> Message | None:
|
||||
q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
|
||||
row = await cls.db.fetchrow(q, str(ktid), bytes(kt_receiver), index)
|
||||
return cls._from_row(row)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def delete_all_by_room(cls, room_id: RoomID) -> None:
|
||||
|
@ -80,7 +80,7 @@ class Message:
|
|||
async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
|
||||
q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2"
|
||||
row = await cls.db.fetchrow(q, mxid, mx_room)
|
||||
return cls._from_row(row)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_most_recent(cls, kt_chat: Long, kt_receiver: Long) -> Message | None:
|
||||
|
@ -90,7 +90,7 @@ class Message:
|
|||
"ORDER BY timestamp DESC LIMIT 1"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver))
|
||||
return cls._from_row(row)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
@classmethod
|
||||
async def get_closest_before(
|
||||
|
@ -103,7 +103,7 @@ class Message:
|
|||
"ORDER BY timestamp DESC LIMIT 1"
|
||||
)
|
||||
row = await cls.db.fetchrow(q, bytes(kt_chat), bytes(kt_receiver), timestamp)
|
||||
return cls._from_row(row)
|
||||
return cls._from_optional_row(row)
|
||||
|
||||
_insert_query = (
|
||||
'INSERT INTO message (mxid, mx_room, ktid, "index", kt_chat, kt_receiver, '
|
||||
|
|
|
@ -39,6 +39,7 @@ from ...rpc import RPCClient
|
|||
from ..types.api.struct.profile import ProfileReqStruct, ProfileStruct
|
||||
from ..types.bson import Long
|
||||
from ..types.client.client_session import LoginResult
|
||||
from ..types.channel.channel_type import ChannelType
|
||||
from ..types.chat.chat import Chatlog
|
||||
from ..types.oauth import OAuthCredential, OAuthInfo
|
||||
from ..types.request import (
|
||||
|
@ -208,14 +209,12 @@ class Client:
|
|||
)
|
||||
return profile_req_struct.profile
|
||||
|
||||
async def get_portal_channel_info(self, channel_info: ChannelInfoUnion) -> PortalChannelInfo:
|
||||
req = await self._api_user_request_result(
|
||||
async def get_portal_channel_info(self, channel_id: Long) -> PortalChannelInfo:
|
||||
return await self._api_user_request_result(
|
||||
PortalChannelInfo,
|
||||
"get_portal_channel_info",
|
||||
channel_id=channel_info.channelId.serialize()
|
||||
channel_id=channel_id.serialize()
|
||||
)
|
||||
req.channel_info = channel_info
|
||||
return req
|
||||
|
||||
async def get_chats(self, channel_id: Long, limit: int | None, sync_from: Long | None) -> list[Chatlog]:
|
||||
return (await self._api_user_request_result(
|
||||
|
@ -233,6 +232,10 @@ class Client:
|
|||
text=text
|
||||
)
|
||||
|
||||
async def start_listen(self) -> None:
|
||||
# TODO Connect all listeners here?
|
||||
await self._api_user_request_void("start_listen")
|
||||
|
||||
async def stop(self) -> None:
|
||||
# TODO Stop all event handlers
|
||||
await self._api_user_request_void("stop")
|
||||
|
@ -269,15 +272,19 @@ class Client:
|
|||
|
||||
# region listeners
|
||||
|
||||
async def on_message(self, func: Callable[[Chatlog, Long], Awaitable[None]]) -> None:
|
||||
async def on_message(self, func: Callable[[Chatlog, Long, ChannelType], Awaitable[None]]) -> None:
|
||||
async def wrapper(data: dict[str, JSON]) -> None:
|
||||
await func(Chatlog.deserialize(data["chatlog"]), data["channelId"])
|
||||
await func(
|
||||
Chatlog.deserialize(data["chatlog"]),
|
||||
Long.deserialize(data["channelId"]),
|
||||
data["channelType"]
|
||||
)
|
||||
|
||||
self._add_user_handler("chat", wrapper)
|
||||
self._add_user_handler("message", wrapper)
|
||||
|
||||
|
||||
def _add_user_handler(self, command: str, handler: EventHandler) -> str:
|
||||
self._rpc_client.add_event_handler(f"{command}:{self.mxid}", handler)
|
||||
self._rpc_client.add_event_handler(f"{command}:{self.user.mxid}", handler)
|
||||
|
||||
# endregion
|
||||
|
||||
|
|
|
@ -93,13 +93,14 @@ class Portal(DBPortal, BasePortal):
|
|||
|
||||
_main_intent: IntentAPI | None
|
||||
_create_room_lock: asyncio.Lock
|
||||
_dedup: deque[str]
|
||||
_oti_dedup: dict[int, DBMessage]
|
||||
_send_locks: dict[int, asyncio.Lock]
|
||||
_noop_lock: FakeLock = FakeLock()
|
||||
_typing: set[UserID]
|
||||
backfill_lock: SimpleLock
|
||||
_backfill_leave: set[IntentAPI] | None
|
||||
_sleeping_to_resync: bool
|
||||
_scheduled_resync: asyncio.Task | None
|
||||
_resync_targets: dict[int, p.Puppet]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -132,10 +133,11 @@ class Portal(DBPortal, BasePortal):
|
|||
|
||||
self._main_intent = None
|
||||
self._create_room_lock = asyncio.Lock()
|
||||
self._dedup = deque(maxlen=100)
|
||||
self._oti_dedup = {}
|
||||
self._send_locks = {}
|
||||
self._typing = set()
|
||||
self._sleeping_to_resync = False
|
||||
self._scheduled_resync = None
|
||||
self._resync_targets = {}
|
||||
|
||||
self.backfill_lock = SimpleLock(
|
||||
"Waiting for backfilling to finish before handling %s", log=self.log
|
||||
|
@ -190,12 +192,45 @@ class Portal(DBPortal, BasePortal):
|
|||
# endregion
|
||||
# region Chat info updating
|
||||
|
||||
def schedule_resync(self, source: u.User, target: p.Puppet) -> None:
|
||||
self._resync_targets[target.ktid] = target
|
||||
if (
|
||||
self._sleeping_to_resync
|
||||
and self._scheduled_resync
|
||||
and not self._scheduled_resync.done()
|
||||
):
|
||||
return
|
||||
self._sleeping_to_resync = True
|
||||
self.log.debug(f"Scheduling resync through {source.mxid}/{source.ktid}")
|
||||
self._scheduled_resync = asyncio.create_task(self._sleep_and_resync(source, 10))
|
||||
|
||||
async def _sleep_and_resync(self, source: u.User, sleep: int) -> None:
|
||||
await asyncio.sleep(sleep)
|
||||
targets = self._resync_targets
|
||||
self._sleeping_to_resync = False
|
||||
self._resync_targets = {}
|
||||
for puppet in targets.values():
|
||||
if not puppet.name or not puppet.name_set:
|
||||
break
|
||||
else:
|
||||
self.log.debug(
|
||||
f"Cancelled resync through {source.mxid}/{source.ktid}, all puppets have names"
|
||||
)
|
||||
return
|
||||
self.log.debug(f"Resyncing chat through {source.mxid}/{source.ktid} after sleeping")
|
||||
await self.update_info(source)
|
||||
self._scheduled_resync = None
|
||||
self.log.debug(f"Completed scheduled resync through {source.mxid}/{source.ktid}")
|
||||
|
||||
async def update_info(
|
||||
self,
|
||||
source: u.User,
|
||||
info: PortalChannelInfo,
|
||||
info: PortalChannelInfo | None = None,
|
||||
force_save: bool = False,
|
||||
) -> None:
|
||||
) -> PortalChannelInfo | None:
|
||||
if not info:
|
||||
self.log.debug("Called update_info with no info, fetching channel info...")
|
||||
info = await source.client.get_portal_channel_info(self.ktid)
|
||||
changed = False
|
||||
if not self.is_direct:
|
||||
changed = any(
|
||||
|
@ -209,6 +244,7 @@ class Portal(DBPortal, BasePortal):
|
|||
if changed or force_save:
|
||||
await self.update_bridge_info()
|
||||
await self.save()
|
||||
return info
|
||||
|
||||
"""
|
||||
@classmethod
|
||||
|
@ -365,7 +401,7 @@ class Portal(DBPortal, BasePortal):
|
|||
# endregion
|
||||
# region Matrix room creation
|
||||
|
||||
async def update_matrix_room(self, source: u.User, info: PortalChannelInfo) -> None:
|
||||
async def update_matrix_room(self, source: u.User, info: PortalChannelInfo | None = None) -> None:
|
||||
try:
|
||||
await self._update_matrix_room(source, info)
|
||||
except Exception:
|
||||
|
@ -380,7 +416,7 @@ class Portal(DBPortal, BasePortal):
|
|||
return invite_content
|
||||
|
||||
async def _update_matrix_room(
|
||||
self, source: u.User, info: PortalChannelInfo
|
||||
self, source: u.User, info: PortalChannelInfo | None = None
|
||||
) -> None:
|
||||
puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
|
||||
await self.main_intent.invite_user(
|
||||
|
@ -394,7 +430,10 @@ class Portal(DBPortal, BasePortal):
|
|||
if did_join and self.is_direct:
|
||||
await source.update_direct_chats({self.main_intent.mxid: [self.mxid]})
|
||||
|
||||
await self.update_info(source, info)
|
||||
info = await self.update_info(source, info)
|
||||
if not info:
|
||||
self.log.warning("Canceling _update_matrix_room as update_info didn't return info")
|
||||
return
|
||||
|
||||
# TODO
|
||||
#await self._sync_read_receipts(info.read_receipts.nodes)
|
||||
|
@ -421,7 +460,7 @@ class Portal(DBPortal, BasePortal):
|
|||
"""
|
||||
|
||||
async def create_matrix_room(
|
||||
self, source: u.User, info: PortalChannelInfo
|
||||
self, source: u.User, info: PortalChannelInfo | None = None
|
||||
) -> RoomID | None:
|
||||
if self.mxid:
|
||||
try:
|
||||
|
@ -474,7 +513,7 @@ class Portal(DBPortal, BasePortal):
|
|||
self.log.warning("Failed to update bridge info", exc_info=True)
|
||||
|
||||
async def _create_matrix_room(
|
||||
self, source: u.User, info: PortalChannelInfo
|
||||
self, source: u.User, info: PortalChannelInfo | None = None
|
||||
) -> RoomID | None:
|
||||
if self.mxid:
|
||||
await self._update_matrix_room(source, info)
|
||||
|
@ -507,7 +546,10 @@ class Portal(DBPortal, BasePortal):
|
|||
if self.is_direct:
|
||||
invites.append(self.az.bot_mxid)
|
||||
|
||||
await self.update_info(source=source, info=info)
|
||||
info = await self.update_info(source=source, info=info)
|
||||
if not info:
|
||||
self.log.debug("update_info() didn't return info, cancelling room creation")
|
||||
return None
|
||||
|
||||
if self.encrypted or not self.is_direct:
|
||||
name = self.name
|
||||
|
|
|
@ -39,14 +39,12 @@ from .kt.client import Client
|
|||
from .kt.client.errors import AuthenticationRequired, ResponseError
|
||||
from .kt.types.api.struct.profile import ProfileStruct
|
||||
from .kt.types.bson import Long
|
||||
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo
|
||||
from .kt.types.channel.channel_info import NormalChannelData
|
||||
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
|
||||
from .kt.types.channel.channel_type import ChannelType
|
||||
from .kt.types.chat.chat import Chatlog
|
||||
from .kt.types.client.client_session import ChannelLoginDataItem
|
||||
from .kt.types.client.client_session import LoginResult
|
||||
from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
|
||||
from .kt.types.oauth import OAuthCredential
|
||||
from .kt.types.openlink.open_channel_info import OpenChannelData
|
||||
from .kt.types.openlink.open_channel_info import OpenChannelInfo
|
||||
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
|
||||
|
||||
METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels")
|
||||
METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync")
|
||||
|
@ -418,7 +416,7 @@ class User(DBUser, BaseUser):
|
|||
# TODO if not is_startup, close existing listeners
|
||||
login_result = await self.client.start()
|
||||
await self._sync_channels(login_result, is_startup)
|
||||
# TODO connect listeners, even if channel sync fails (except if it's an auth failure)
|
||||
self.start_listen()
|
||||
except AuthenticationRequired as e:
|
||||
await self.send_bridge_notice(
|
||||
f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n",
|
||||
|
@ -498,7 +496,8 @@ class User(DBUser, BaseUser):
|
|||
kt_receiver=self.ktid,
|
||||
kt_type=channel_info.type
|
||||
)
|
||||
portal_info = await self.client.get_portal_channel_info(channel_info)
|
||||
portal_info = await self.client.get_portal_channel_info(channel_info.channelId)
|
||||
portal_info.channel_info = channel_info
|
||||
if not portal.mxid:
|
||||
await portal.create_matrix_room(self, portal_info)
|
||||
else:
|
||||
|
@ -585,6 +584,30 @@ class User(DBUser, BaseUser):
|
|||
|
||||
# region KakaoTalk event handling
|
||||
|
||||
def start_listen(self) -> None:
|
||||
self.listen_task = asyncio.create_task(self._try_listen())
|
||||
|
||||
def _disconnect_listener_after_error(self) -> None:
|
||||
self.log.info("TODO: _disconnect_listener_after_error")
|
||||
|
||||
async def _try_listen(self) -> None:
|
||||
try:
|
||||
# TODO Pass all listeners to start_listen instead of registering them one-by-one?
|
||||
await self.client.start_listen()
|
||||
await self.client.on_message(self.on_message)
|
||||
# TODO Handle auth errors specially?
|
||||
#except AuthenticationRequired as e:
|
||||
except Exception:
|
||||
#self.is_connected = False
|
||||
self.log.exception("Fatal error in listener")
|
||||
await self.send_bridge_notice(
|
||||
"Fatal error in listener (see logs for more info)",
|
||||
state_event=BridgeStateEvent.UNKNOWN_ERROR,
|
||||
important=True,
|
||||
error_code="kt-connection-error",
|
||||
)
|
||||
self._disconnect_listener_after_error()
|
||||
|
||||
def stop_listen(self) -> None:
|
||||
self.log.info("TODO: stop_listen")
|
||||
|
||||
|
@ -603,8 +626,18 @@ class User(DBUser, BaseUser):
|
|||
asyncio.create_task(self.post_login(is_startup=True))
|
||||
|
||||
@async_time(METRIC_MESSAGE)
|
||||
async def on_message(self, evt: Chatlog, channel_id: Long) -> None:
|
||||
self.log.info("TODO: on_message")
|
||||
async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None:
|
||||
portal = await po.Portal.get_by_ktid(
|
||||
channel_id,
|
||||
kt_receiver=self.ktid,
|
||||
kt_type=channel_type
|
||||
)
|
||||
puppet = await pu.Puppet.get_by_ktid(evt.sender.userId)
|
||||
await portal.backfill_lock.wait(evt.logId)
|
||||
if not puppet.name:
|
||||
portal.schedule_resync(self, puppet)
|
||||
# TODO reply_to
|
||||
await portal.handle_remote_message(self, puppet, evt)
|
||||
|
||||
# TODO Many more handlers
|
||||
|
||||
|
|
|
@ -266,28 +266,33 @@ export default class PeerClient {
|
|||
const res = await userClient.talkClient.login(req.oauth_credential)
|
||||
if (!res.success) return res
|
||||
|
||||
// Attach listeners in something like start_listen
|
||||
/*
|
||||
this.userClients.set(req.mxid, userClient)
|
||||
return res
|
||||
}
|
||||
|
||||
startListen = async (req) => {
|
||||
const userClient = this.#getUser(req.mxid)
|
||||
|
||||
userClient.talkClient.on("chat", (data, channel) => {
|
||||
this.log(`Found message in channel ${channel.channelId}`)
|
||||
this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`)
|
||||
return this.#write({
|
||||
id: --this.notificationID,
|
||||
command: userClient.getCmd("chat"),
|
||||
//is_sequential: true, // TODO make sequential per user!
|
||||
chatlog: data.chat(),
|
||||
command: userClient.getCmd("message"),
|
||||
//is_sequential: true, // TODO Make sequential per user & channel (if it isn't already)
|
||||
chatlog: data.chat,
|
||||
channelId: channel.channelId,
|
||||
channelType: channel.info.type,
|
||||
})
|
||||
})
|
||||
|
||||
/*
|
||||
/* TODO Many more listeners
|
||||
userClient.talkClient.on("chat_read", (chat, channel, reader) => {
|
||||
this.log(`chat_read in channel ${channel.channelId}`)
|
||||
//chat.logId
|
||||
})
|
||||
*/
|
||||
|
||||
this.userClients.set(req.mxid, userClient)
|
||||
return res
|
||||
return this.#voidCommandResult
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -448,6 +453,7 @@ export default class PeerClient {
|
|||
renew: this.handleRenew,
|
||||
generate_uuid: util.randomAndroidSubDeviceUUID,
|
||||
register_device: this.registerDevice,
|
||||
start_listen: this.startListen,
|
||||
get_own_profile: this.getOwnProfile,
|
||||
get_portal_channel_info: this.getPortalChannelInfo,
|
||||
get_chats: this.getChats,
|
||||
|
|
Loading…
Reference in New Issue