# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import Any, Awaitable, Callable, Dict, List, Optional
import asyncio
import json
import logging

from mautrix.types.primitive import JSON

from ..config import Config
from .types import RPCError

EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]


class CancelableEvent:
    """An :class:`asyncio.Event` wrapper whose waiters can all be aborted at once.

    After :meth:`cancel` has been called, every pending and every later
    :meth:`wait` raises :class:`asyncio.CancelledError`.
    """

    _event: asyncio.Event
    _task: Optional[asyncio.Task]
    _cancelled: bool
    _loop: asyncio.AbstractEventLoop

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        self._event = asyncio.Event()
        self._task = None
        self._cancelled = False
        self._loop = loop or asyncio.get_running_loop()

    def is_set(self) -> bool:
        return self._event.is_set()

    def set(self) -> None:
        self._event.set()
        # The shared wait task completes on its own now; a wait() after the next
        # clear() must start a fresh one.
        self._task = None

    def clear(self) -> None:
        self._event.clear()

    async def wait(self) -> None:
        """Wait until the event is set.

        Raises:
            asyncio.CancelledError: if :meth:`cancel` was called before or during the wait.
        """
        if self._cancelled:
            raise asyncio.CancelledError()
        if self._event.is_set():
            return
        if not self._task:
            self._task = asyncio.create_task(self._event.wait())
        # shield() keeps the cancellation of one waiter from cancelling the shared
        # task (which would make every other current and future waiter fail too).
        # cancel() still reaches all waiters: cancelling the inner task cancels
        # every shield wrapping it.
        await asyncio.shield(self._task)

    def cancel(self) -> None:
        self._cancelled = True
        if self._task is not None:
            self._task.cancel()

    def cancelled(self) -> bool:
        return self._cancelled


class RPCClient:
    """Client for a newline-delimited JSON RPC connection over a unix socket or TCP.

    Every request is written as one JSON object per line with a positive, increasing
    ``id``. The server answers with a ``response`` or ``error`` command carrying the
    same ``id``. Incoming messages with a negative ``id`` are broadcasts and are
    dispatched to the handlers registered for their ``command``.
    """

    config: Config
    loop: asyncio.AbstractEventLoop
    log: logging.Logger = logging.getLogger("mau.rpc")

    _reader: Optional[asyncio.StreamReader]
    _writer: Optional[asyncio.StreamWriter]
    _req_id: int
    _min_broadcast_id: int
    _response_waiters: Dict[int, asyncio.Future[JSON]]
    _event_handlers: Dict[str, List[EventHandler]]
    _command_queue: asyncio.Queue
    _read_task: Optional[asyncio.Task]
    _connection_task: Optional[asyncio.Task]
    _is_connected: CancelableEvent
    _is_disconnected: CancelableEvent
    _connection_lock: asyncio.Lock

    def __init__(self, config: Config, register_config_key: str) -> None:
        """Create the client. This must be called while an event loop is running.

        Args:
            config: Bridge config. It is read for the ``rpc.connection.*`` keys and
                for ``appservice.address``.
            register_config_key: Config key whose value is sent as ``register_config``
                in the ``register`` request issued after every connect.
        """
        self.config = config
        self.register_config_key = register_config_key
        self.loop = asyncio.get_running_loop()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._command_queue = asyncio.Queue()
        # Broadcasts flagged as sequential are handled one at a time by this task.
        self.loop.create_task(self._command_loop())
        self._read_task = None
        self._connection_task = None
        self._is_connected = CancelableEvent(self.loop)
        self._is_disconnected = CancelableEvent(self.loop)
        self._is_disconnected.set()
        self._connection_lock = asyncio.Lock()

    async def connect(self) -> None:
        """Connect and register with the server, unless already connected.

        While the socket is unavailable, the connection is retried every 10 seconds.

        Raises:
            asyncio.CancelledError: if :meth:`cancel` has been called.
        """
        async with self._connection_lock:
            if self._is_connected.cancelled():
                raise asyncio.CancelledError()
            if self._is_connected.is_set():
                return
            # Run as a separate task so cancel() can abort an attempt that is stuck
            # in the retry loop.
            self._connection_task = self.loop.create_task(self._connect())
            try:
                await self._connection_task
            finally:
                self._connection_task = None

    async def _connect(self) -> None:
        """Open the socket (retrying until it succeeds), start the reader, and register."""
        if self.config["rpc.connection.type"] == "unix":
            while True:
                try:
                    r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"])
                    break
                except asyncio.CancelledError:
                    # Explicit for Python 3.7, where CancelledError subclasses Exception.
                    raise
                except Exception:
                    self.log.warning(
                        f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...'
                    )
                    await asyncio.sleep(10)
        elif self.config["rpc.connection.type"] == "tcp":
            while True:
                try:
                    r, w = await asyncio.open_connection(
                        self.config["rpc.connection.host"], self.config["rpc.connection.port"]
                    )
                    break
                except asyncio.CancelledError:
                    raise
                except Exception:
                    self.log.warning(
                        f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.port"]}, wait for it to become available...'
                    )
                    await asyncio.sleep(10)
        else:
            raise RuntimeError("invalid rpc connection type")
        self._reader = r
        self._writer = w
        self._read_task = self.loop.create_task(self._try_read_loop())
        # NOTE(review): the future returned for "register" is not awaited, so the client
        # is marked connected before the server confirms registration. Confirm that
        # this is intended.
        await self._raw_request(
            "register",
            peer_id=self.config["appservice.address"],
            register_config=self.config[self.register_config_key],
        )
        self._is_connected.set()
        self._is_disconnected.clear()

    async def disconnect(self) -> None:
        """Close the connection, unless already disconnected.

        Raises:
            asyncio.CancelledError: if :meth:`cancel` has been called.
        """
        async with self._connection_lock:
            if self._is_disconnected.cancelled():
                raise asyncio.CancelledError()
            if self._is_disconnected.is_set():
                return
            await self._disconnect()

    async def _disconnect(self) -> None:
        """Half-close the socket, stop the reader, and reset connection state."""
        try:
            if self._writer is not None:
                try:
                    self._writer.write_eof()
                    await self._writer.drain()
                except ConnectionError:
                    # The peer is already gone, so there is nothing left to flush.
                    pass
        finally:
            # Tear down local state even if flushing failed or we were cancelled.
            if self._read_task is not None:
                self._read_task.cancel()
                self._read_task = None
            self._on_disconnect()

    def _on_disconnect(self) -> None:
        """Reset per-connection state and fail requests that are still awaiting a response."""
        if self._writer is not None:
            # Dropping the reference alone would leave the transport open until GC.
            self._writer.close()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._reader = None
        self._writer = None
        # Request ids restart from 0 on the next connection. A leftover waiter could
        # only ever be resolved by an unrelated request's response, so fail it now
        # instead of letting its caller hang forever.
        waiters, self._response_waiters = self._response_waiters, {}
        for waiter in waiters.values():
            if not waiter.done():
                waiter.set_exception(ConnectionError("RPC connection lost"))
        self._is_connected.clear()
        self._is_disconnected.set()

    def wait_for_connection(self) -> Awaitable[None]:
        return self._is_connected.wait()

    def wait_for_disconnection(self) -> Awaitable[None]:
        return self._is_disconnected.wait()

    def cancel(self) -> None:
        """Permanently abort the client: fail all connection waits and tear down the socket."""
        self._is_connected.cancel()
        self._is_disconnected.cancel()
        if self._connection_task is not None:
            self._connection_task.cancel()
        if self.loop.is_running():
            # asyncio.run() raises RuntimeError while a loop is running. It would also
            # drive our loop-bound streams on a different loop.
            self.loop.create_task(self._disconnect())
        else:
            self.loop.run_until_complete(self._disconnect())

    @property
    def _next_req_id(self) -> int:
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        try:
            self._event_handlers.setdefault(method, []).remove(handler)
        except ValueError:
            pass

    def set_event_handlers(self, method: str, handlers: List[EventHandler]) -> None:
        self._event_handlers[method] = handlers

    async def _run_event_handler(self, req_id: int, command: str, req: Dict[str, Any]) -> None:
        """Dispatch a server broadcast to every handler registered for ``command``.

        Broadcast ids are negative. The server is presumably expected to count them
        downward, so an id greater than the lowest one seen on this connection is
        skipped as a duplicate. Handler exceptions are logged, not propagated.
        """
        # NOTE(review): an id *equal* to the lowest one seen is not skipped. Confirm
        # whether ">=" was intended.
        if req_id > self._min_broadcast_id:
            self.log.debug(f"Ignoring duplicate broadcast {req_id}")
            return
        self._min_broadcast_id = req_id
        try:
            handlers = self._event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for %s", command)
        else:
            for handler in handlers:
                try:
                    await handler(req)
                except Exception:
                    self.log.exception("Exception in event handler")

    async def _handle_incoming_line(self, line: str) -> None:
        """Parse one line from the server and route it to a broadcast handler or a response waiter."""
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
            is_sequential = req.pop("is_sequential", False)
        except KeyError:
            self.log.debug(f"Got invalid request from server: {line}")
            return

        if req_id < 0:
            if not is_sequential:
                self.loop.create_task(self._run_event_handler(req_id, command, req))
            else:
                self._command_queue.put_nowait((req_id, command, req))
            return

        waiter = self._response_waiters.get(req_id)
        if waiter is None:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if command == "response":
            self.log.debug("Received response %d", req_id)
            # Unregister the waiter once it is answered; otherwise the dict grows forever.
            del self._response_waiters[req_id]
            # A requester that was cancelled leaves its future cancelled (done).
            if not waiter.done():
                waiter.set_result(req.get("response"))
        elif command == "error":
            del self._response_waiters[req_id]
            if not waiter.done():
                waiter.set_exception(RPCError(req.get("error", line)))
        else:
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")

    async def _command_loop(self) -> None:
        """Handle sequential broadcasts one at a time, in arrival order."""
        while True:
            req_id, command, req = await self._command_queue.get()
            try:
                await self._run_event_handler(req_id, command, req)
            finally:
                self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run the read loop and reset connection state when it ends (unless it was cancelled)."""
        try:
            await self._read_loop()
        except asyncio.CancelledError:
            return
        except Exception:
            self.log.exception("Fatal error in read loop")
        self.log.debug("Reader disconnected")
        self._on_disconnect()

    async def _read_loop(self) -> None:
        """Read newline-terminated messages until EOF and hand each one to the line handler."""
        while self._reader is not None and not self._reader.at_eof():
            line = b""
            while True:
                try:
                    line += await self._reader.readuntil()
                    break
                except asyncio.IncompleteReadError as e:
                    # EOF before a newline: keep whatever partial data arrived.
                    line += e.partial
                    break
                except asyncio.LimitOverrunError as e:
                    # The line exceeds the stream buffer limit: consume what is
                    # buffered, then keep reading until the separator appears.
                    self.log.warning(f"Buffer overrun: {e}")
                    line += await self._reader.read(e.consumed)
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except asyncio.CancelledError:
                # Explicit for Python 3.7, where CancelledError subclasses Exception.
                raise
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)

    async def _raw_request(self, command: str, **data: JSON) -> asyncio.Future[JSON]:
        """Send a request and return the future that will resolve with its response.

        Unlike :meth:`request`, this waits neither for the connection nor for the response.
        """
        req_id = self._next_req_id
        future = self._response_waiters[req_id] = self.loop.create_future()
        req = {"id": req_id, "command": command, **data}
        self.log.debug("Request %d: %s", req_id, command)
        try:
            assert self._writer is not None
            self._writer.write(json.dumps(req).encode("utf-8"))
            self._writer.write(b"\n")
            await self._writer.drain()
        except BaseException:
            # The caller never receives the future on failure, so don't leave it registered.
            self._response_waiters.pop(req_id, None)
            raise
        return future

    async def request(self, command: str, **data: JSON) -> JSON:
        """Send ``command`` once connected and return the server's response.

        Raises:
            RPCError: if the server answers with an ``error`` command.
            ConnectionError: if the connection is lost before a response arrives.
        """
        await self.wait_for_connection()
        future = await self._raw_request(command, **data)
        return await future