Compare commits

..

No commits in common. "e94c598e3b55e32097b2ea37a05fe76eed8986bb" and "3ced968494cbbdbb6aae027e06f674d8e41d6726" have entirely different histories.

13 changed files with 241 additions and 367 deletions

View File

@ -57,7 +57,6 @@ class KakaoTalkBridge(Bridge):
def prepare_bridge(self) -> None:
super().prepare_bridge()
""" TODO Implement web login
if self.config["appservice.public.enabled"]:
secret = self.config["appservice.public.shared_secret"]
self.public_website = PublicBridgeWebsite(loop=self.loop, shared_secret=secret)
@ -66,21 +65,21 @@ class KakaoTalkBridge(Bridge):
)
else:
self.public_website = None
"""
self.public_website = None
def prepare_stop(self) -> None:
self.log.debug("Stopping RPC connection")
KakaoTalkClient.stop_cls()
self.log.debug("Stopping puppet syncers")
for puppet in Puppet.by_custom_mxid.values():
puppet.stop()
self.log.debug("Stopping kakaotalk listeners")
User.shutdown = True
self.add_shutdown_actions(user.save() for user in User.by_mxid.values())
self.add_shutdown_actions(KakaoTalkClient.stop_cls())
async def start(self) -> None:
KakaoTalkClient.init_cls(self.config)
# Block all other startup actions until RPC is ready
# TODO Remove when/if node backend is replaced with native
await KakaoTalkClient.init_cls(self.config)
self.add_startup_actions(User.init_cls(self))
self.add_startup_actions(Puppet.init_cls(self))
Portal.init_cls(self)

View File

@ -23,6 +23,7 @@ from mautrix.types import UserID
from mautrix.util.config import ConfigUpdateHelper, ForbiddenDefault, ForbiddenKey
# TODO Remove unneeded configs!!
class Config(BaseBridgeConfig):
def __getitem__(self, key: str) -> Any:
try:
@ -93,15 +94,31 @@ class Config(BaseBridgeConfig):
copy("bridge.backfill.initial_limit")
copy("bridge.backfill.missed_limit")
copy("bridge.backfill.disable_notifications")
""" TODO
copy("bridge.periodic_reconnect.interval")
copy("bridge.periodic_reconnect.always")
copy("bridge.periodic_reconnect.min_connected_time")
"""
if "bridge.periodic_reconnect_interval" in self:
base["bridge.periodic_reconnect.interval"] = self["bridge.periodic_reconnect_interval"]
base["bridge.periodic_reconnect.mode"] = self["bridge.periodic_reconnect_mode"]
else:
copy("bridge.periodic_reconnect.interval")
copy("bridge.periodic_reconnect.mode")
copy("bridge.periodic_reconnect.always")
copy("bridge.periodic_reconnect.min_connected_time")
copy("bridge.resync_max_disconnected_time")
copy("bridge.sync_on_startup")
copy("bridge.temporary_disconnect_notices")
copy("bridge.disable_bridge_notices")
if "bridge.refresh_on_reconnection_fail" in self:
base["bridge.on_reconnection_fail.action"] = (
"refresh" if self["bridge.refresh_on_reconnection_fail"] else None
)
base["bridge.on_reconnection_fail.wait_for"] = 0
elif "bridge.on_reconnection_fail.refresh" in self:
base["bridge.on_reconnection_fail.action"] = (
"refresh" if self["bridge.on_reconnection_fail.refresh"] else None
)
copy("bridge.on_reconnection_fail.wait_for")
else:
copy("bridge.on_reconnection_fail.action")
copy("bridge.on_reconnection_fail.wait_for")
copy("bridge.resend_bridge_info")
copy("bridge.mute_bridging")
copy("bridge.tag_only_on_create")
@ -109,14 +126,13 @@ class Config(BaseBridgeConfig):
copy_dict("bridge.permissions")
""" TODO
for key in (
"bridge.periodic_reconnect.interval",
"bridge.on_reconnection_fail.wait_for",
):
value = base.get(key, None)
if isinstance(value, list) and len(value) != 2:
raise ValueError(f"{key} must only be a list of two items")
"""
copy("rpc.connection.type")
if base["rpc.connection.type"] == "unix":

View File

@ -50,7 +50,6 @@ appservice:
max_size: 10
# Public part of web server for out-of-Matrix interaction with the bridge.
# TODO Implement web login
public:
# Whether or not the public-facing endpoints should be enabled.
enabled: false
@ -201,13 +200,15 @@ bridge:
# If using double puppeting, should notifications be disabled
# while the initial backfill is in progress?
disable_notifications: false
# TODO Implement this
# TODO Confirm this isn't needed
#periodic_reconnect:
# # Interval in seconds in which to automatically reconnect all users.
# # This may prevent KakaoTalk from "switching servers".
# # This can be used to automatically mitigate the bug where KakaoTalk stops sending messages.
# # Set to -1 to disable periodic reconnections entirely.
# # Set to a list of two items to randomize the interval (min, max).
# interval: -1
# # What to do in periodic reconnects. Either "refresh" or "reconnect"
# mode: refresh
# # Should even disconnected users be reconnected?
# always: false
# # Only reconnect if the user has been connected for longer than this value
@ -215,7 +216,6 @@ bridge:
# The number of seconds that a disconnection can last without triggering an automatic re-sync
# and missed message backfilling when reconnecting.
# Set to 0 to always re-sync, or -1 to never re-sync automatically.
# TODO Actually use this setting
resync_max_disconnected_time: 5
# Should the bridge do a resync on startup?
sync_on_startup: true

View File

@ -16,12 +16,11 @@
from __future__ import annotations
from typing import Match
from html import escape
import re
from mautrix.types import Format, MessageType, TextMessageEventContent
from ..kt.types.chat.attachment.mention import MentionStruct
from .. import puppet as pu, user as u
_START = r"^|\s"
@ -80,38 +79,92 @@ def _handle_blockquote(output: list[str], blockquote: bool, line: str) -> tuple[
return blockquote, line
async def kakaotalk_to_matrix(msg: str | None, mentions: list[MentionStruct] | None) -> TextMessageEventContent:
# TODO Shouts
text = msg or ""
content = TextMessageEventContent(msgtype=MessageType.TEXT, body=text)
if mentions:
mention_user_ids = []
at_chunks = text.split("@")
for m in mentions:
for idx in m.at:
chunk = at_chunks[idx]
original = chunk[:m.len]
mention_user_ids.append(int(m.user_id))
at_chunks[idx] = f"{m.user_id}\u2063{original}\u2063{chunk[m.len:]}"
text = "@".join(at_chunks)
mention_user_map = {}
for ktid in mention_user_ids:
user = await u.User.get_by_ktid(ktid)
if user:
mention_user_map[ktid] = user.mxid
def _handle_codeblock_pre(
output: list[str], codeblock: bool, line: str
) -> tuple[bool, str, tuple[str | None, str | None, str | None]]:
cb = line.find("```")
cb_lang = None
cb_content = None
post_cb_content = None
if cb != -1:
if not codeblock:
cb_lang = line[cb + 3 :]
if "```" in cb_lang:
end = cb_lang.index("```")
cb_content = cb_lang[:end]
post_cb_content = cb_lang[end + 3 :]
cb_lang = ""
else:
puppet = await pu.Puppet.get_by_ktid(ktid, create=False)
mention_user_map[ktid] = puppet.mxid if puppet else None
codeblock = True
line = line[:cb]
else:
output.append("</code></pre>")
codeblock = False
line = line[cb + 3 :]
return codeblock, line, (cb_lang, cb_content, post_cb_content)
if mention_user_map:
def _mention_replacer(match: Match) -> str:
mxid = mention_user_map[int(match.group(1))]
if not mxid:
return match.group(2)
return f'<a href="https://matrix.to/#/{mxid}">{match.group(2)}</a>'
content.format = Format.HTML
content.formatted_body = MENTION_REGEX.sub(_mention_replacer, text).replace("\n", "<br/>\n")
def _handle_codeblock_post(
output: list[str], cb_lang: str | None, cb_content: str | None, post_cb_content: str | None
) -> None:
if cb_lang is not None:
if cb_lang:
output.append(f'<pre><code class="language-{cb_lang}">')
else:
output.append("<pre><code>")
if cb_content:
output.append(cb_content)
output.append("</code></pre>")
output.append(_convert_formatting(post_cb_content))
async def kakaotalk_to_matrix(msg: str) -> TextMessageEventContent:
text = msg or ""
mentions = []
content = TextMessageEventContent(msgtype=MessageType.TEXT, body=text)
mention_user_ids = []
for m in reversed(mentions):
original = text[m.offset : m.offset + m.length]
if len(original) > 0 and original[0] == "@":
original = original[1:]
mention_user_ids.append(int(m.user_id))
text = f"{text[:m.offset]}@{m.user_id}\u2063{original}\u2063{text[m.offset + m.length:]}"
html = escape(text)
output = []
if html:
codeblock = False
blockquote = False
line: str
lines = html.split("\n")
for i, line in enumerate(lines):
blockquote, line = _handle_blockquote(output, blockquote, line)
codeblock, line, post_args = _handle_codeblock_pre(output, codeblock, line)
output.append(_convert_formatting(line))
if i != len(lines) - 1:
if codeblock:
output.append("\n")
else:
output.append("<br/>")
_handle_codeblock_post(output, *post_args)
html = "".join(output)
mention_user_map = {}
for ktid in mention_user_ids:
user = await u.User.get_by_ktid(ktid)
if user:
mention_user_map[ktid] = user.mxid
else:
puppet = await pu.Puppet.get_by_ktid(ktid, create=False)
mention_user_map[ktid] = puppet.mxid if puppet else None
def _mention_replacer(match: Match) -> str:
mxid = mention_user_map[int(match.group(1))]
if not mxid:
return match.group(2)
return f'<a href="https://matrix.to/#/{mxid}">{match.group(2)}</a>'
html = MENTION_REGEX.sub(_mention_replacer, html)
if html != escape(content.body).replace("\n", "<br/>\n"):
content.format = Format.HTML
content.formatted_body = html
return content

View File

@ -17,35 +17,29 @@ from __future__ import annotations
from typing import NamedTuple
from mautrix.appservice import IntentAPI
from mautrix.types import Format, MessageEventContent, RelationType, RoomID, UserID
from mautrix.types import Format, MessageEventContent, RelationType, RoomID
from mautrix.util.formatter import (
EntityString,
EntityType,
MarkdownString,
MatrixParser,
MatrixParser as BaseMatrixParser,
SimpleEntity,
)
from mautrix.util.logging import TraceLogger
from ..kt.types.bson import Long
from ..kt.types.chat import KnownChatType
from ..kt.types.chat.attachment import ReplyAttachment, MentionStruct
from ..kt.client.types import TO_MSGTYPE_MAP
from .. import puppet as pu, user as u
from ..db import Message as DBMessage
class SendParams(NamedTuple):
text: str
mentions: list[MentionStruct] | None
reply_to: ReplyAttachment
mentions: list[None]
reply_to: str
class KakaoTalkFormatString(EntityString[SimpleEntity, EntityType], MarkdownString):
def format(self, entity_type: EntityType, **kwargs) -> KakaoTalkFormatString:
class FacebookFormatString(EntityString[SimpleEntity, EntityType], MarkdownString):
def format(self, entity_type: EntityType, **kwargs) -> FacebookFormatString:
prefix = suffix = ""
if entity_type == EntityType.USER_MENTION:
self.entities.append(
SimpleEntity(
@ -55,110 +49,72 @@ class KakaoTalkFormatString(EntityString[SimpleEntity, EntityType], MarkdownStri
extra_info={"user_id": kwargs["user_id"]},
)
)
self.text = f"@{self.text}"
return self
elif entity_type == EntityType.BOLD:
prefix = suffix = "*"
elif entity_type == EntityType.ITALIC:
prefix = suffix = "_"
elif entity_type == EntityType.STRIKETHROUGH:
prefix = suffix = "~"
elif entity_type == EntityType.URL:
if kwargs["url"] != self.text:
suffix = f" ({kwargs['url']})"
elif entity_type == EntityType.PREFORMATTED:
prefix = f"```{kwargs['language']}\n"
suffix = "\n```"
elif entity_type == EntityType.INLINE_CODE:
prefix = suffix = "`"
elif entity_type == EntityType.BLOCKQUOTE:
children = self.trim().split("\n")
children = [child.prepend("> ") for child in children]
return self.join(children, "\n")
elif entity_type == EntityType.HEADER:
prefix = "#" * kwargs["size"] + " "
else:
return self
self._offset_entities(len(prefix))
self.text = f"{prefix}{self.text}{suffix}"
return self
class ToKakaoTalkParser(MatrixParser[KakaoTalkFormatString]):
fs = KakaoTalkFormatString
async def _get_id_from_mxid(mxid: UserID) -> Long | None:
user = await u.User.get_by_mxid(mxid, create=False)
if user and user.ktid:
return user.ktid
else:
puppet = await pu.Puppet.get_by_mxid(mxid, create=False)
return puppet.ktid if puppet else None
class MatrixParser(BaseMatrixParser[FacebookFormatString]):
fs = FacebookFormatString
async def matrix_to_kakaotalk(
content: MessageEventContent,
room_id: RoomID,
log: TraceLogger,
intent: IntentAPI,
skip_reply: bool = False
content: MessageEventContent, room_id: RoomID, log: TraceLogger
) -> SendParams:
# NOTE By design, this *throws* if user intent can't be matched (i.e. if a reply can't be created)
if content.relates_to.rel_type == RelationType.REPLY and not skip_reply:
mentions = []
reply_to = None
if content.relates_to.rel_type == RelationType.REPLY:
message = await DBMessage.get_by_mxid(content.relates_to.event_id, room_id)
if not message:
raise ValueError(
f"Couldn't find reply target {content.relates_to.event_id}"
" to bridge text message reply metadata to KakaoTalk"
)
try:
src_event = await intent.get_event(room_id, message.mxid)
except:
log.exception(f"Failed to find Matrix event for reply target {message.mxid}")
raise
src_kt_sender = await _get_id_from_mxid(src_event.sender)
if src_kt_sender is None:
raise ValueError(
f"Found no KakaoTalk user ID for reply target sender {src_event.sender}"
)
content.trim_reply_fallback()
src_converted = await matrix_to_kakaotalk(src_event.content, room_id, log, intent, skip_reply=True)
if src_event.content.relates_to.rel_type == RelationType.REPLY:
src_type = KnownChatType.REPLY
src_message = src_converted.text
if message:
content.trim_reply_fallback()
reply_to = message.ktid
else:
src_type = TO_MSGTYPE_MAP[src_event.content.msgtype]
if src_type == KnownChatType.FILE:
src_message = _media_type_reply_body_map[KnownChatType.FILE] + src_converted.text
else:
src_message = _media_type_reply_body_map.get(src_type, src_converted.text)
reply_to = ReplyAttachment(
# NOTE mentions will be merged into this later
# TODO Set this for emoticon reply, but must first support them
attach_only=False,
# TODO If replying with media works, must set type AND all attachment properties
# But then, the reply object must be an intersection of a ReplyAttachment and something else
#attach_type=TO_MSGTYPE_MAP.get(content.msgtype),
# TODO Confirm why official client sets this to 0, and whether this should be left as None instead
attach_type=0,
src_logId=message.ktid,
src_mentions=src_converted.mentions or [],
src_message=src_message,
src_type=src_type,
src_userId=src_kt_sender,
)
else:
reply_to = None
if content.get("format", None) == Format.HTML and content["formatted_body"] and content.msgtype.is_text:
parsed = await ToKakaoTalkParser().parse(content["formatted_body"])
log.warning(
f"Couldn't find reply target {content.relates_to.event_id}"
" to bridge text message reply metadata to Facebook"
)
if content.get("format", None) == Format.HTML and content["formatted_body"]:
parsed = await MatrixParser().parse(content["formatted_body"])
text = parsed.text
mentions_by_user: dict[Long, MentionStruct] = {}
# Make sure to not create remote mentions for any remote user not in the room
if parsed.entities:
joined_members = set(await intent.get_room_members(room_id))
last_offset = 0
at = 0
for mention in sorted(parsed.entities, key=lambda entity: entity.offset):
mxid = mention.extra_info["user_id"]
if mxid not in joined_members:
mentions = []
for mention in parsed.entities:
mxid = mention.extra_info["user_id"]
user = await u.User.get_by_mxid(mxid, create=False)
if user and user.ktid:
ktid = user.ktid
else:
puppet = await pu.Puppet.get_by_mxid(mxid, create=False)
if puppet:
ktid = puppet.ktid
else:
continue
ktid = await _get_id_from_mxid(mxid)
if ktid is None:
continue
at += text[last_offset:mention.offset+1].count("@")
last_offset = mention.offset+1
mention = mentions_by_user.setdefault(ktid, MentionStruct(
at=[],
len=mention.length,
user_id=ktid,
))
mention.at.append(at)
mentions = list(mentions_by_user.values()) if mentions_by_user else None
#mentions.append(
# Mention(user_id=str(ktid), offset=mention.offset, length=mention.length)
#)
else:
text = content.body
mentions = None
return SendParams(text=text, mentions=mentions, reply_to=reply_to)
_media_type_reply_body_map: dict[KnownChatType, str] = {
KnownChatType.PHOTO: "Photo",
KnownChatType.VIDEO: "Video",
KnownChatType.AUDIO: "Voice Note",
KnownChatType.FILE: "File: ",
}

View File

@ -22,8 +22,7 @@ with any other potential backend.
from __future__ import annotations
from typing import TYPE_CHECKING, cast, ClassVar, Type, Optional, Union
import asyncio
from typing import TYPE_CHECKING, cast, Type, Optional, Union
from contextlib import asynccontextmanager
import logging
import urllib.request
@ -42,7 +41,6 @@ from ..types.api.struct import FriendListStruct
from ..types.bson import Long
from ..types.client.client_session import LoginResult
from ..types.chat import Chatlog, KnownChatType
from ..types.chat.attachment import MentionStruct, ReplyAttachment
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.packet.chat.kickout import KnownKickoutType, KickoutRes
from ..types.request import (
@ -65,7 +63,7 @@ except ImportError:
if TYPE_CHECKING:
from mautrix.types import JSON
from ... import user as u
from ...user import User
@asynccontextmanager
@ -80,22 +78,15 @@ class Client:
_rpc_client: RPCClient
@classmethod
def init_cls(cls, config: Config) -> None:
async def init_cls(cls, config: Config) -> None:
"""Initialize RPC to the Node backend."""
cls._rpc_client = RPCClient(config)
# NOTE No need to store this, as cancelling the RPCClient will cancel this too
asyncio.create_task(cls._keep_connected())
await cls._rpc_client.connect()
@classmethod
async def _keep_connected(cls) -> None:
while True:
await cls._rpc_client.connect()
await cls._rpc_client.wait_for_disconnection()
@classmethod
def stop_cls(cls) -> None:
async def stop_cls(cls) -> None:
"""Stop and disconnect from the Node backend."""
cls._rpc_client.cancel()
await cls._rpc_client.disconnect()
# region tokenless commands
@ -132,15 +123,12 @@ class Client:
# endregion
user: u.User
_rpc_disconnection_task: asyncio.Task | None
http: ClientSession
log: TraceLogger
def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
def __init__(self, user: User, log: Optional[TraceLogger] = None):
"""Create a per-user client object for user-specific client functionality."""
self.user = user
self._rpc_disconnection_task = None
# TODO Let the Node backend use a proxy too!
connector = None
@ -199,27 +187,13 @@ class Client:
Receive the user's profile info in response.
"""
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "start")
if not self._rpc_disconnection_task:
self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler())
else:
self.log.warning("Called \"start\" on an already-started client")
return profile_req_struct.profile
async def stop(self) -> None:
"""Immediately stop bridging this user."""
self._stop_listen()
if self._rpc_disconnection_task:
self._rpc_disconnection_task.cancel()
else:
self.log.warning("Called \"stop\" on an already-stopped client")
await self._rpc_client.request("stop", mxid=self.user.mxid)
async def _rpc_disconnection_handler(self) -> None:
await self._rpc_client.wait_for_disconnection()
self._rpc_disconnection_task = None
self._stop_listen()
asyncio.create_task(self.user.on_client_disconnect())
async def renew_and_save(self) -> None:
"""Renew and save the user's session tokens."""
oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)
@ -299,20 +273,12 @@ class Client:
await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid)
)
async def send_message(
self,
channel_props: ChannelProps,
text: str,
reply_to: ReplyAttachment | None,
mentions: list[MentionStruct] | None,
) -> Chatlog:
async def send_message(self, channel_props: ChannelProps, text: str) -> Chatlog:
return await self._api_user_request_result(
Chatlog,
"send_chat",
channel_props=channel_props.serialize(),
text=text,
reply_to=reply_to.serialize() if reply_to is not None else None,
mentions=[m.serialize() for m in mentions] if mentions is not None else None,
)
async def send_media(

View File

@ -25,11 +25,11 @@ from .mention import MentionStruct
@dataclass(kw_only=True)
class ReplyAttachment(Attachment):
attach_only: bool = None # NOTE Made optional
attach_type: Optional[ChatType] = None # NOTE Changed from int for outgoing typeless replies
attach_only: bool
attach_type: int
src_linkId: Optional[Long] = None
src_logId: Long
src_mentions: list[MentionStruct] = None # NOTE Made optional
src_mentions: list[MentionStruct]
src_message: str
src_type: ChatType
src_userId: Long

View File

@ -94,12 +94,15 @@ ResultType = TypeVar("ResultType", bound=Serializable)
def ResultListType(result_type: Type[ResultType]):
class _ResultListType(list[result_type], Serializable):
def __init__(self, iterable: Iterable[result_type]=()):
list.__init__(self, (result_type.deserialize(x) for x in iterable))
def serialize(self) -> list[JSON]:
return [v.serialize() for v in self]
@classmethod
def deserialize(cls, data: list[JSON]) -> "_ResultListType":
return [result_type.deserialize(item) for item in data]
return cls(data)
return _ResultListType

View File

@ -46,7 +46,6 @@ from mautrix.types import (
Membership,
MessageEventContent,
MessageType,
RelationType,
RoomID,
TextMessageEventContent,
UserID,
@ -69,7 +68,6 @@ from .kt.types.channel.channel_info import ChannelInfo
from .kt.types.channel.channel_type import KnownChannelType, ChannelType
from .kt.types.chat import Chatlog, ChatType, KnownChatType
from .kt.types.chat.attachment import (
Attachment,
AudioAttachment,
#FileAttachment,
MediaAttachment,
@ -772,8 +770,6 @@ class Portal(DBPortal, BasePortal):
) -> None:
if message.get_edit():
raise NotImplementedError("Edits are not supported by the KakaoTalk bridge.")
if message.relates_to.rel_type == RelationType.REPLY and not message.msgtype.is_text:
raise NotImplementedError("Replying with non-text content is not supported by the KakaoTalk bridge.")
sender, is_relay = await self.get_relay_sender(orig_sender, f"message {event_id}")
if not sender:
raise Exception("not logged in")
@ -806,13 +802,14 @@ class Portal(DBPortal, BasePortal):
async def _handle_matrix_text(
self, event_id: EventID, sender: u.User, message: TextMessageEventContent
) -> None:
converted = await matrix_to_kakaotalk(message, self.mxid, self.log, self.main_intent)
converted = await matrix_to_kakaotalk(message, self.mxid, self.log)
try:
chatlog = await sender.client.send_message(
self.channel_props,
text=converted.text,
reply_to=converted.reply_to,
mentions=converted.mentions,
# TODO
#mentions=converted.mentions,
#reply_to=converted.reply_to,
)
except CommandException as e:
self.log.debug(f"Error handling Matrix message {event_id}: {e!s}")
@ -840,6 +837,18 @@ class Portal(DBPortal, BasePortal):
else:
raise NotImplementedError("No file or URL specified")
mimetype = message.info.mimetype or magic.mimetype(data)
""" TODO Replies
reply_to = None
if message.relates_to.rel_type == RelationType.REPLY:
reply_to_msg = await DBMessage.get_by_mxid(message.relates_to.event_id, self.mxid)
if reply_to_msg:
reply_to = reply_to_msg.ktid
else:
self.log.warning(
f"Couldn't find reply target {message.relates_to.event_id}"
" to bridge media message reply metadata to KakaoTalk"
)
"""
filename = message.body
width, height = None, None
if message.info in (MessageType.IMAGE, MessageType.STICKER, MessageType.VIDEO):
@ -854,6 +863,8 @@ class Portal(DBPortal, BasePortal):
width=width,
height=height,
ext=guess_extension(mimetype)[1:],
# TODO
#reply_to=reply_to,
)
except CommandException as e:
self.log.debug(f"Error uploading media for Matrix message {event_id}: {e!s}")
@ -1082,12 +1093,12 @@ class Portal(DBPortal, BasePortal):
async def _handle_remote_text(
self,
intent: IntentAPI,
attachment: Attachment | None,
timestamp: int,
message_text: str | None,
**_
) -> list[EventID]:
content = await kakaotalk_to_matrix(message_text, attachment.mentions if attachment else None)
# TODO Handle mentions properly
content = await kakaotalk_to_matrix(message_text)
return [await self._send_message(intent, content, timestamp=timestamp)]
async def _handle_remote_reply(
@ -1098,7 +1109,7 @@ class Portal(DBPortal, BasePortal):
message_text: str,
**_
) -> list[EventID]:
content = await kakaotalk_to_matrix(message_text, attachment.mentions)
content = await kakaotalk_to_matrix(message_text)
await self._add_remote_reply(content, attachment)
return [await self._send_message(intent, content, timestamp=timestamp)]

View File

@ -29,46 +29,6 @@ from .types import RPCError
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
class CancelableEvent:
_event: asyncio.Event
_task: asyncio.Task | None
_cancelled: bool
_loop: asyncio.AbstractEventLoop
def __init__(self, loop: asyncio.AbstractEventLoop | None):
self._event = asyncio.Event()
self._task = None
self._cancelled = False
self._loop = loop or asyncio.get_running_loop()
def is_set(self) -> bool:
return self._event.is_set()
def set(self) -> None:
self._event.set()
self._task = None
def clear(self) -> None:
self._event.clear()
async def wait(self) -> None:
if self._cancelled:
raise asyncio.CancelledError()
if self._event.is_set():
return
if not self._task:
self._task = asyncio.create_task(self._event.wait())
await self._task
def cancel(self) -> None:
self._cancelled = True
if self._task is not None:
self._task.cancel()
def cancelled(self) -> bool:
return self._cancelled
class RPCClient:
config: Config
loop: asyncio.AbstractEventLoop
@ -81,11 +41,6 @@ class RPCClient:
_response_waiters: dict[int, asyncio.Future[JSON]]
_event_handlers: dict[str, list[EventHandler]]
_command_queue: asyncio.Queue
_read_task: asyncio.Task | None
_connection_task: asyncio.Task | None
_is_connected: CancelableEvent
_is_disconnected: CancelableEvent
_connection_lock: asyncio.Lock
def __init__(self, config: Config) -> None:
self.config = config
@ -97,34 +52,16 @@ class RPCClient:
self._writer = None
self._reader = None
self._command_queue = asyncio.Queue()
self.loop.create_task(self._command_loop())
self._read_task = None
self._connection_task = None
self._is_connected = CancelableEvent(self.loop)
self._is_disconnected = CancelableEvent(self.loop)
self._is_disconnected.set()
self._connection_lock = asyncio.Lock()
async def connect(self) -> None:
async with self._connection_lock:
if self._is_connected.cancelled():
raise asyncio.CancelledError()
if self._is_connected.is_set():
return
self._connection_task = self.loop.create_task(self._connect())
try:
await self._connection_task
finally:
self._connection_task = None
if self._writer is not None:
return
async def _connect(self) -> None:
if self.config["rpc.connection.type"] == "unix":
while True:
try:
r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"])
break
except asyncio.CancelledError:
raise
except:
self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...')
await asyncio.sleep(10)
@ -134,8 +71,6 @@ class RPCClient:
r, w = await asyncio.open_connection(self.config["rpc.connection.host"],
self.config["rpc.connection.port"])
break
except asyncio.CancelledError:
raise
except:
self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.path"]}, wait for it to become available...')
await asyncio.sleep(10)
@ -143,46 +78,16 @@ class RPCClient:
raise RuntimeError("invalid rpc connection type")
self._reader = r
self._writer = w
self._read_task = self.loop.create_task(self._try_read_loop())
self._is_connected.set()
self._is_disconnected.clear()
self.loop.create_task(self._try_read_loop())
self.loop.create_task(self._command_loop())
await self.request("register", peer_id=self.config["appservice.address"])
async def disconnect(self) -> None:
async with self._connection_lock:
if self._is_disconnected.cancelled():
raise asyncio.CancelledError()
if self._is_disconnected.is_set():
return
await self._disconnect()
async def _disconnect(self) -> None:
if self._writer is not None:
self._writer.write_eof()
await self._writer.drain()
if self._read_task is not None:
self._read_task.cancel()
self._read_task = None
self._on_disconnect()
def _on_disconnect(self) -> None:
self._reader = None
self._writer = None
self._is_connected.clear()
self._is_disconnected.set()
def wait_for_connection(self) -> Awaitable[None]:
return self._is_connected.wait()
def wait_for_disconnection(self) -> Awaitable[None]:
return self._is_disconnected.wait()
def cancel(self) -> None:
self._is_connected.cancel()
self._is_disconnected.cancel()
if self._connection_task is not None:
self._connection_task.cancel()
asyncio.run(self._disconnect())
self._writer = None
self._reader = None
@property
def _next_req_id(self) -> int:
@ -214,7 +119,7 @@ class RPCClient:
for handler in handlers:
try:
await handler(req)
except:
except Exception:
self.log.exception("Exception in event handler")
async def _handle_incoming_line(self, line: str) -> None:
@ -257,9 +162,7 @@ class RPCClient:
async def _try_read_loop(self) -> None:
try:
await self._read_loop()
except asyncio.CancelledError:
pass
except:
except Exception:
self.log.exception("Fatal error in read loop")
async def _read_loop(self) -> None:
@ -275,8 +178,6 @@ class RPCClient:
except asyncio.LimitOverrunError as e:
self.log.warning(f"Buffer overrun: {e}")
line += await self._reader.read(self._reader._limit)
except asyncio.CancelledError:
raise
if not line:
continue
try:
@ -286,12 +187,11 @@ class RPCClient:
continue
try:
await self._handle_incoming_line(line_str)
except asyncio.CancelledError:
raise
except:
except Exception:
self.log.exception("Failed to handle incoming request %s", line_str)
self.log.debug("Reader disconnected")
self._on_disconnect()
self._reader = None
self._writer = None
async def _raw_request(self, command: str, is_secret: bool = False, **data: JSON) -> asyncio.Future[JSON]:
req_id = self._next_req_id
@ -305,6 +205,5 @@ class RPCClient:
return future
async def request(self, command: str, **data: JSON) -> JSON:
await self.wait_for_connection()
future = await self._raw_request(command, **data)
return await future

View File

@ -85,7 +85,6 @@ class User(DBUser, BaseUser):
_connection_time: float
_db_instance: DBUser | None
_sync_lock: SimpleLock
_is_rpc_reconnecting: bool
_logged_in_info: ProfileStruct | None
_logged_in_info_time: float
@ -122,7 +121,6 @@ class User(DBUser, BaseUser):
self._sync_lock = SimpleLock(
"Waiting for thread sync to finish before handling %s", log=self.log
)
self._is_rpc_reconnecting = False
self._logged_in_info = None
self._logged_in_info_time = 0
@ -334,8 +332,6 @@ class User(DBUser, BaseUser):
state_event=BridgeStateEvent.UNKNOWN_ERROR,
error_code="kt-reconnection-error",
)
finally:
self._is_rpc_reconnecting = False
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
if self.client:
@ -426,16 +422,7 @@ class User(DBUser, BaseUser):
sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels)
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
self.log.debug(f"Syncing {sync_count} of {num_channels} channels...")
def get_channel_update_time(login_data: LoginDataItem):
channel_info = login_data.channel.info
return channel_info.lastChatLog.sendAt if channel_info.lastChatLog else 0
for login_data in sorted(
login_result.channelList,
reverse=True,
key=get_channel_update_time
)[:sync_count]:
for login_data in login_result.channelList[:sync_count]:
try:
await self._sync_channel(login_data)
except AuthenticationRequired:
@ -558,8 +545,9 @@ class User(DBUser, BaseUser):
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
if self.is_connected:
state.state_event = BridgeStateEvent.CONNECTED
elif self._is_rpc_reconnecting or self.client:
state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
# TODO
#elif self._is_logged_in and self._is_reconnecting:
# state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
return [state]
async def get_puppet(self) -> pu.Puppet | None:
@ -594,18 +582,16 @@ class User(DBUser, BaseUser):
# region KakaoTalk event handling
async def on_connect(self) -> None:
self.is_connected = True
self._track_metric(METRIC_CONNECTED, True)
""" TODO Don't auto-resync channels if disconnection was too short
now = time.monotonic()
disconnected_at = self._connection_time
max_delay = self.config["bridge.resync_max_disconnected_time"]
first_connect = self.is_connected is None
self.is_connected = True
self._track_metric(METRIC_CONNECTED, True)
if not first_connect and disconnected_at + max_delay < now:
duration = int(now - disconnected_at)
self.log.debug(f"Disconnection lasted {duration} seconds, not re-syncing channels...")
"""
if self.temp_disconnect_notices:
self.log.debug(f"Disconnection lasted {duration} seconds")
elif self.temp_disconnect_notices:
await self.send_bridge_notice("Connected to KakaoTalk chats")
await self.push_bridge_state(BridgeStateEvent.CONNECTED)
@ -632,19 +618,6 @@ class User(DBUser, BaseUser):
await self.logout()
await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}")
async def on_client_disconnect(self) -> None:
self.is_connected = False
self._track_metric(METRIC_CONNECTED, False)
self.client = None
if self._is_logged_in:
if self.temp_disconnect_notices:
await self.send_bridge_notice(
"Disconnected from KakaoTalk: backend helper module exited. "
"Will reconnect once module resumes."
)
self._is_rpc_reconnecting = True
asyncio.create_task(self.reload_session())
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
self.oauth_credential = oauth_credential

View File

@ -24,10 +24,8 @@ import {
util,
} from "node-kakao"
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
/** @typedef {import("node-kakao").ChannelType} ChannelType */
/** @typedef {import("node-kakao").ReplyAttachment} ReplyAttachment */
/** @typedef {import("node-kakao").MentionStruct} MentionStruct */
/** @typedef {import("node-kakao/dist/talk").TalkChannelList} TalkChannelList */
/** @typedef {import("node-kakao").ChannelType} ChannelType */
import chat from "node-kakao/chat"
const { KnownChatType } = chat
@ -508,16 +506,13 @@ export default class PeerClient {
* @param {string} req.mxid
* @param {Object} req.channel_props
* @param {string} req.text
* @param {?ReplyAttachment} req.reply_to
* @param {?MentionStruct[]} req.mentions
*/
sendChat = async (req) => {
const talkChannel = await this.#getUserChannel(req.mxid, req.channel_props)
return await talkChannel.sendChat({
type: KnownChatType.TEXT,
text: req.text,
type: !!req.reply_to ? KnownChatType.REPLY : KnownChatType.TEXT,
attachment: !req.mentions ? req.reply_to : {...req.reply_to, mentions: req.mentions},
})
}
@ -526,7 +521,7 @@ export default class PeerClient {
* @param {string} req.mxid
* @param {Object} req.channel_props
* @param {int} req.type
* @param {number[]} req.data
* @param {[number]} req.data
* @param {string} req.name
* @param {?int} req.width
* @param {?int} req.height

View File

@ -52,6 +52,9 @@ export default class ClientManager {
} catch (err) {
await fs.promises.mkdir(path.dirname(socketPath), 0o700)
}
try {
await fs.promises.unlink(socketPath)
} catch (err) {}
await promisify(cb => this.server.listen(socketPath, cb))
await fs.promises.chmod(socketPath, 0o700)
this.log("Now listening at", socketPath)