# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge. # Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. from __future__ import annotations from typing import Any, Callable, Awaitable import asyncio import json import logging from mautrix.types.primitive import JSON from ..config import Config from .types import RPCError EventHandler = Callable[[dict[str, Any]], Awaitable[None]] class CancelableEvent: _event: asyncio.Event _task: asyncio.Task | None _cancelled: bool _loop: asyncio.AbstractEventLoop def __init__(self, loop: asyncio.AbstractEventLoop | None): self._event = asyncio.Event() self._task = None self._cancelled = False self._loop = loop or asyncio.get_running_loop() def is_set(self) -> bool: return self._event.is_set() def set(self) -> None: self._event.set() self._task = None def clear(self) -> None: self._event.clear() async def wait(self) -> None: if self._cancelled: raise asyncio.CancelledError() if self._event.is_set(): return if not self._task: self._task = asyncio.create_task(self._event.wait()) await self._task def cancel(self) -> None: self._cancelled = True if self._task is not None: self._task.cancel() def cancelled(self) -> bool: return self._cancelled class RPCClient: config: Config loop: asyncio.AbstractEventLoop log: logging.Logger = logging.getLogger("mau.rpc") _reader: asyncio.StreamReader | None _writer: asyncio.StreamWriter | None _req_id: int _min_broadcast_id: int _response_waiters: dict[int, asyncio.Future[JSON]] _event_handlers: dict[str, list[EventHandler]] _command_queue: asyncio.Queue _read_task: asyncio.Task | None _connection_task: asyncio.Task | None _is_connected: CancelableEvent _is_disconnected: CancelableEvent _connection_lock: asyncio.Lock def __init__(self, config: Config, register_config_key: str) -> None: self.config = config self.register_config_key = register_config_key self.loop = asyncio.get_running_loop() self._req_id = 0 self._min_broadcast_id = 0 self._event_handlers = {} self._response_waiters = {} self._writer = None self._reader = None self._command_queue = asyncio.Queue() self.loop.create_task(self._command_loop()) self._read_task = None self._connection_task = None self._is_connected = CancelableEvent(self.loop) self._is_disconnected = CancelableEvent(self.loop) self._is_disconnected.set() self._connection_lock = asyncio.Lock() async def connect(self) -> None: async with self._connection_lock: if self._is_connected.cancelled(): raise asyncio.CancelledError() if self._is_connected.is_set(): return self._connection_task = self.loop.create_task(self._connect()) try: await self._connection_task finally: self._connection_task = None async def _connect(self) -> None: if self.config["rpc.connection.type"] == "unix": while True: try: r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"]) break except asyncio.CancelledError: raise except: self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...') await asyncio.sleep(10) elif self.config["rpc.connection.type"] == "tcp": while True: try: r, w = await asyncio.open_connection(self.config["rpc.connection.host"], self.config["rpc.connection.port"]) break except asyncio.CancelledError: raise except: self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.path"]}, wait for it to become available...') await asyncio.sleep(10) else: raise RuntimeError("invalid rpc connection type") self._reader = r self._writer = w self._read_task = self.loop.create_task(self._try_read_loop()) await self._raw_request("register", peer_id=self.config["appservice.address"], register_config=self.config[self.register_config_key]) self._is_connected.set() self._is_disconnected.clear() async def disconnect(self) -> None: async with self._connection_lock: if self._is_disconnected.cancelled(): raise asyncio.CancelledError() if self._is_disconnected.is_set(): return await self._disconnect() async def _disconnect(self) -> None: if self._writer is not None: self._writer.write_eof() await self._writer.drain() if self._read_task is not None: self._read_task.cancel() self._read_task = None self._on_disconnect() def _on_disconnect(self) -> None: self._req_id = 0 self._min_broadcast_id = 0 self._reader = None self._writer = None self._is_connected.clear() self._is_disconnected.set() def wait_for_connection(self) -> Awaitable[None]: return self._is_connected.wait() def wait_for_disconnection(self) -> Awaitable[None]: return self._is_disconnected.wait() def cancel(self) -> None: self._is_connected.cancel() self._is_disconnected.cancel() if self._connection_task is not None: self._connection_task.cancel() asyncio.run(self._disconnect()) @property def _next_req_id(self) -> int: self._req_id += 1 return self._req_id def add_event_handler(self, method: str, handler: EventHandler) -> None: self._event_handlers.setdefault(method, []).append(handler) def remove_event_handler(self, method: str, handler: EventHandler) -> None: try: self._event_handlers.setdefault(method, []).remove(handler) except ValueError: pass def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None: self._event_handlers[method] = handlers async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None: if req_id > self._min_broadcast_id: self.log.debug(f"Ignoring duplicate broadcast {req_id}") return self._min_broadcast_id = req_id try: handlers = self._event_handlers[command] except KeyError: self.log.warning("No handlers for %s", command) else: for handler in handlers: try: await handler(req) except: self.log.exception("Exception in event handler") async def _handle_incoming_line(self, line: str) -> None: try: req = json.loads(line) except json.JSONDecodeError: self.log.debug(f"Got non-JSON data from server: {line}") return try: req_id = req.pop("id") command = req.pop("command") is_sequential = req.pop("is_sequential", False) except KeyError: self.log.debug(f"Got invalid request from server: {line}") return if req_id < 0: if not is_sequential: self.loop.create_task(self._run_event_handler(req_id, command, req)) else: self._command_queue.put_nowait((req_id, command, req)) return try: waiter = self._response_waiters[req_id] except KeyError: self.log.debug(f"Nobody waiting for response to {req_id}") return if command == "response": self.log.debug("Received response %d", req_id) waiter.set_result(req.get("response")) elif command == "error": waiter.set_exception(RPCError(req.get("error", line))) else: self.log.warning(f"Unexpected response command to {req_id}: {command} {req}") async def _command_loop(self) -> None: while True: req_id, command, req = await self._command_queue.get() await self._run_event_handler(req_id, command, req) self._command_queue.task_done() async def _try_read_loop(self) -> None: try: await self._read_loop() except asyncio.CancelledError: return except: self.log.exception("Fatal error in read loop") self.log.debug("Reader disconnected") self._on_disconnect() async def _read_loop(self) -> None: while self._reader is not None and not self._reader.at_eof(): line = b'' while True: try: line += await self._reader.readuntil() break except asyncio.IncompleteReadError as e: line += e.partial break except asyncio.LimitOverrunError as e: self.log.warning(f"Buffer overrun: {e}") line += await self._reader.read(e.consumed) except asyncio.CancelledError: raise if not line: continue try: line_str = line.decode("utf-8") except UnicodeDecodeError: self.log.exception("Got non-unicode request from server: %s", line) continue try: await self._handle_incoming_line(line_str) except asyncio.CancelledError: raise except: self.log.exception("Failed to handle incoming request %s", line_str) async def _raw_request(self, command: str, **data: JSON) -> asyncio.Future[JSON]: req_id = self._next_req_id future = self._response_waiters[req_id] = self.loop.create_future() req = {"id": req_id, "command": command, **data} self.log.debug("Request %d: %s", req_id, command) assert self._writer is not None self._writer.write(json.dumps(req).encode("utf-8")) self._writer.write(b"\n") await self._writer.drain() return future async def request(self, command: str, **data: JSON) -> JSON: await self.wait_for_connection() future = await self._raw_request(command, **data) return await future