# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge. # Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. from __future__ import annotations from typing import Any, Callable, Awaitable import asyncio import json import logging from mautrix.types.primitive import JSON from ..config import Config from .types import RPCError EventHandler = Callable[[dict[str, Any]], Awaitable[None]] class RPCClient: config: Config loop: asyncio.AbstractEventLoop log: logging.Logger = logging.getLogger("mau.rpc") _reader: asyncio.StreamReader | None _writer: asyncio.StreamWriter | None _req_id: int _min_broadcast_id: int _response_waiters: dict[int, asyncio.Future[JSON]] _event_handlers: dict[str, list[EventHandler]] _command_queue: asyncio.Queue def __init__(self, config: Config) -> None: self.config = config self.loop = asyncio.get_running_loop() self._req_id = 0 self._min_broadcast_id = 0 self._event_handlers = {} self._response_waiters = {} self._writer = None self._reader = None self._command_queue = asyncio.Queue() async def connect(self) -> None: if self._writer is not None: return if self.config["rpc.connection.type"] == "unix": while True: try: r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"]) break except: self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...') await asyncio.sleep(10) elif self.config["rpc.connection.type"] == "tcp": while True: try: r, w = await asyncio.open_connection(self.config["rpc.connection.host"], self.config["rpc.connection.port"]) break except: self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.path"]}, wait for it to become available...') await asyncio.sleep(10) else: raise RuntimeError("invalid rpc connection type") self._reader = r self._writer = w self.loop.create_task(self._try_read_loop()) self.loop.create_task(self._command_loop()) await self.request("register", peer_id=self.config["appservice.address"]) async def disconnect(self) -> None: if self._writer is not None: self._writer.write_eof() await self._writer.drain() self._writer = None self._reader = None @property def _next_req_id(self) -> int: self._req_id += 1 return self._req_id def add_event_handler(self, method: str, handler: EventHandler) -> None: self._event_handlers.setdefault(method, []).append(handler) def remove_event_handler(self, method: str, handler: EventHandler) -> None: try: self._event_handlers.setdefault(method, []).remove(handler) except ValueError: pass def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None: self._event_handlers[method] = handlers async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None: if req_id > self._min_broadcast_id: self.log.debug(f"Ignoring duplicate broadcast {req_id}") return self._min_broadcast_id = req_id try: handlers = self._event_handlers[command] except KeyError: self.log.warning("No handlers for %s", command) else: for handler in handlers: try: await handler(req) except Exception: self.log.exception("Exception in event handler") async def _handle_incoming_line(self, line: str) -> None: try: req = json.loads(line) except json.JSONDecodeError: self.log.debug(f"Got non-JSON data from server: {line}") return try: req_id = req.pop("id") command = req.pop("command") is_sequential = req.pop("is_sequential", False) except KeyError: self.log.debug(f"Got invalid request from server: {line}") return if req_id < 0: if not is_sequential: self.loop.create_task(self._run_event_handler(req_id, command, req)) else: self._command_queue.put_nowait((req_id, command, req)) return try: waiter = self._response_waiters[req_id] except KeyError: self.log.debug(f"Nobody waiting for response to {req_id}") return if command == "response": waiter.set_result(req.get("response")) elif command == "error": waiter.set_exception(RPCError(req.get("error", line))) else: self.log.warning(f"Unexpected response command to {req_id}: {command} {req}") async def _command_loop(self) -> None: while True: req_id, command, req = await self._command_queue.get() await self._run_event_handler(req_id, command, req) self._command_queue.task_done() async def _try_read_loop(self) -> None: try: await self._read_loop() except Exception: self.log.exception("Fatal error in read loop") async def _read_loop(self) -> None: while self._reader is not None and not self._reader.at_eof(): line = b'' while True: try: line += await self._reader.readuntil() break except asyncio.IncompleteReadError as e: line += e.partial break except asyncio.LimitOverrunError as e: self.log.warning(f"Buffer overrun: {e}") line += await self._reader.read(self._reader._limit) if not line: continue try: line_str = line.decode("utf-8") except UnicodeDecodeError: self.log.exception("Got non-unicode request from server: %s", line) continue try: await self._handle_incoming_line(line_str) except Exception: self.log.exception("Failed to handle incoming request %s", line_str) self.log.debug("Reader disconnected") self._reader = None self._writer = None async def _raw_request(self, command: str, is_secret: bool = False, **data: JSON) -> asyncio.Future[JSON]: req_id = self._next_req_id future = self._response_waiters[req_id] = self.loop.create_future() req = {"id": req_id, "command": command, **data} self.log.debug("Request %d: %s %s", req_id, command, data if not is_secret else "<REDACTED>") assert self._writer is not None self._writer.write(json.dumps(req).encode("utf-8")) self._writer.write(b"\n") await self._writer.drain() return future async def request(self, command: str, **data: JSON) -> JSON: future = await self._raw_request(command, **data) return await future