# matrix-puppeteer-line - A very hacky Matrix-LINE bridge based on running LINE's Chrome extension in Puppeteer
# Copyright (C) 2020-2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any, Callable, Awaitable, List, Optional, Set
import logging
import asyncio
import json

from mautrix.types import UserID
from mautrix.util.logging import TraceLogger

from ..config import Config
from .types import RPCError

EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]

# Seconds to wait before retrying when the puppeteer socket is not reachable.
_CONNECT_RETRY_DELAY = 10


class RPCClient:
    """Client for the newline-delimited JSON RPC link to the puppeteer process.

    Every message is one JSON object per line with an ``id`` and a ``command``.
    Messages with a non-negative ``id`` are replies to requests sent through
    :meth:`request`; messages with a negative ``id`` are broadcasts that are
    dispatched to the handlers registered with :meth:`add_event_handler`.
    """

    # Only annotated here; presumably assigned on the class by the bridge at
    # startup before any instance connects -- TODO confirm against the caller.
    config: Config
    loop: asyncio.AbstractEventLoop
    log: TraceLogger = logging.getLogger("mau.rpc")

    user_id: UserID
    own_id: str
    ephemeral_events: bool
    _reader: Optional[asyncio.StreamReader]
    _writer: Optional[asyncio.StreamWriter]
    _req_id: int
    _min_broadcast_id: int
    _response_waiters: Dict[int, asyncio.Future]
    _event_handlers: Dict[str, List[EventHandler]]
    _command_queue: asyncio.Queue
    _read_tasks: Set[asyncio.Task]
    _command_task: Optional[asyncio.Task]

    def __init__(self, user_id: UserID, own_id: str, ephemeral_events: bool) -> None:
        """Create a disconnected client; must be called from a running event loop.

        :param user_id: Sent to the server in the ``register`` request and used
            to name the child logger.
        :param own_id: Sent to the server in the ``register`` request.
        :param ephemeral_events: Sent to the server in the ``register`` request.
        """
        self.log = self.log.getChild(user_id)
        self.loop = asyncio.get_running_loop()
        self.user_id = user_id
        self.own_id = own_id
        self.ephemeral_events = ephemeral_events
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._command_queue = asyncio.Queue()
        # Strong references to background tasks: the event loop itself only
        # keeps weak references to tasks.
        self._read_tasks = set()
        self._command_task = None

    async def connect(self) -> None:
        """Connect to the puppeteer process and register this user.

        Does nothing if a connection is already open. While the socket is not
        reachable, the attempt is retried every ``_CONNECT_RETRY_DELAY``
        seconds.

        :raises RuntimeError: if ``puppeteer.connection.type`` is neither
            ``unix`` nor ``tcp``.
        """
        if self._writer is not None:
            return

        conn_type = self.config["puppeteer.connection.type"]
        if conn_type == "unix":
            while True:
                try:
                    r, w = await asyncio.open_unix_connection(self.config["puppeteer.connection.path"])
                    break
                except OSError:
                    # Only connection failures are retried; cancellation and
                    # programming errors must propagate to the caller.
                    self.log.warning(f'No unix socket available at {self.config["puppeteer.connection.path"]}, wait for it to exist...')
                    await asyncio.sleep(_CONNECT_RETRY_DELAY)
        elif conn_type == "tcp":
            while True:
                try:
                    r, w = await asyncio.open_connection(self.config["puppeteer.connection.host"],
                                                         self.config["puppeteer.connection.port"])
                    break
                except OSError:
                    self.log.warning(f'No TCP connection open at {self.config["puppeteer.connection.host"]}:{self.config["puppeteer.connection.port"]}, wait for it to become available...')
                    await asyncio.sleep(_CONNECT_RETRY_DELAY)
        else:
            raise RuntimeError("invalid puppeteer connection type")

        self._reader = r
        self._writer = w

        read_task = self.loop.create_task(self._try_read_loop())
        self._read_tasks.add(read_task)
        read_task.add_done_callback(self._read_tasks.discard)

        # The command loop outlives individual connections, so start it only
        # once; a second consumer on the same queue would let "sequential"
        # events run concurrently.
        if self._command_task is None or self._command_task.done():
            self._command_task = self.loop.create_task(self._command_loop())

        await self.request("register", user_id=self.user_id, own_id=self.own_id,
                           ephemeral_events=self.ephemeral_events)

    async def disconnect(self) -> None:
        """Half-close the connection and forget the current streams.

        Does nothing if the client is not connected.
        """
        writer = self._writer
        if writer is None:
            return
        reader = self._reader
        try:
            writer.write_eof()
            await writer.drain()
        finally:
            # Only clear the state that belongs to the connection being closed;
            # it may have been replaced while drain() was pending.
            if self._writer is writer:
                self._writer = None
            if self._reader is reader:
                self._reader = None

    @property
    def _next_req_id(self) -> int:
        """Increment the request counter and return the new value."""
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Register ``handler`` to be awaited for broadcasts named ``method``."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister ``handler`` from broadcasts named ``method``.

        :raises ValueError: if ``handler`` is not registered for ``method``.
        """
        self._event_handlers.setdefault(method, []).remove(handler)

    async def _run_event_handler(self, req_id: int, command: str, req: Dict[str, Any]) -> None:
        """Dispatch one broadcast to every handler registered for ``command``.

        Exceptions raised by a handler are logged and do not stop the remaining
        handlers from running.
        """
        # NOTE(review): only ids strictly greater than the lowest id seen so
        # far are dropped, so an exact repeat of the most recent id is still
        # dispatched -- confirm against the server's id scheme whether this
        # should be ">=".
        if req_id > self._min_broadcast_id:
            self.log.debug(f"Ignoring duplicate broadcast {req_id}")
            return
        self._min_broadcast_id = req_id
        try:
            handlers = self._event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for %s", command)
        else:
            # Iterate over a snapshot: a handler may add or remove handlers
            # while it is being awaited.
            for handler in list(handlers):
                try:
                    await handler(req)
                except Exception:
                    self.log.exception("Exception in event handler")

    async def _handle_incoming_line(self, line: str) -> None:
        """Parse one line from the server and route it.

        Negative ids are broadcasts: they are either run immediately in their
        own task or, when flagged ``is_sequential``, queued for the command
        loop. Other ids resolve the future of the matching pending request.
        """
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        if not isinstance(req, dict):
            # Valid JSON, but not an object that can carry "id"/"command".
            self.log.debug(f"Got invalid request from server: {line}")
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
            is_sequential = req.pop("is_sequential", False)
        except KeyError:
            self.log.debug(f"Got invalid request from server: {line}")
            return

        if req_id < 0:
            if not is_sequential:
                self.loop.create_task(self._run_event_handler(req_id, command, req))
            else:
                self._command_queue.put_nowait((req_id, command, req))
            return

        try:
            waiter = self._response_waiters[req_id]
        except KeyError:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if command not in ("response", "error"):
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")
            return
        if waiter.done():
            # The caller was cancelled (or a reply was already delivered);
            # setting a result now would raise InvalidStateError.
            self.log.debug(f"Dropping response to already finished request {req_id}")
            return
        if command == "response":
            waiter.set_result(req.get("response"))
        else:
            waiter.set_exception(RPCError(req.get("error", line)))

    async def _command_loop(self) -> None:
        """Run queued sequential broadcasts one at a time, in arrival order."""
        while True:
            req_id, command, req = await self._command_queue.get()
            try:
                await self._run_event_handler(req_id, command, req)
            finally:
                self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run the read loop and log, rather than propagate, any failure."""
        try:
            await self._read_loop()
        except Exception:
            self.log.exception("Fatal error in read loop")

    async def _read_loop(self) -> None:
        """Read lines from the current connection until it ends.

        The loop is bound to the streams that were current when it started, so
        it stops once they are replaced or cleared and never reads from a newer
        connection. On exit, including on error, the connection state is
        cleared and the transport is closed.
        """
        reader = self._reader
        writer = self._writer
        # NOTE(review): requests still pending when the connection ends are
        # left unresolved, as before -- consider failing them here.
        try:
            while reader is not None and self._reader is reader and not reader.at_eof():
                line = b''
                while True:
                    try:
                        line += await reader.readuntil()
                        break
                    except asyncio.IncompleteReadError as e:
                        # EOF before a newline: use whatever was received.
                        line += e.partial
                        break
                    except asyncio.LimitOverrunError as e:
                        # Line longer than the stream buffer: consume one
                        # buffer's worth and keep looking for the newline.
                        self.log.warning(f"Buffer overrun: {e}")
                        line += await reader.read(reader._limit)
                if not line:
                    continue
                try:
                    line_str = line.decode("utf-8")
                except UnicodeDecodeError:
                    self.log.exception("Got non-unicode request from server: %s", line)
                    continue
                try:
                    await self._handle_incoming_line(line_str)
                except Exception:
                    self.log.exception("Failed to handle incoming request %s", line_str)
            self.log.debug("Reader disconnected")
        finally:
            # Clear only this loop's own streams so that connect() can open a
            # fresh connection, even when the loop died with an exception.
            if self._reader is reader:
                self._reader = None
            if self._writer is writer:
                self._writer = None
            if writer is not None:
                # Nothing reads from this connection any more; release it.
                writer.close()

    async def _raw_request(self, command: str, **data: Any) -> asyncio.Future:
        """Send a request and return the future that will hold its reply.

        :param command: Command name to send.
        :param data: Extra fields merged into the request object.
        :return: A future resolved with the ``response`` field of the reply, or
            failed with :class:`RPCError` if the server replies with an error.
        """
        req_id = self._next_req_id
        req = {"id": req_id, "command": command, **data}
        # Serialise before registering the waiter so that unserialisable data
        # cannot leave an orphan entry behind.
        payload = json.dumps(req).encode("utf-8") + b"\n"
        future = self.loop.create_future()
        # Forget the waiter as soon as it is resolved or cancelled.
        future.add_done_callback(lambda _: self._response_waiters.pop(req_id, None))
        self._response_waiters[req_id] = future
        self.log.trace("Request %d: %s %s", req_id, command, data)
        try:
            self._writer.write(payload)
            await self._writer.drain()
        except BaseException:
            # The caller never receives the future, so nobody could await it.
            self._response_waiters.pop(req_id, None)
            raise
        return future

    async def request(self, command: str, **data: Any) -> Any:
        """Send a request and wait for its reply.

        :param command: Command name to send.
        :param data: Extra fields merged into the request object.
        :return: The ``response`` field of the server's reply.
        :raises RPCError: if the server replies with an error.
        """
        future = await self._raw_request(command, **data)
        return await future