matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/rpc/rpc.py

210 lines
8.0 KiB
Python

# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any, Callable, Awaitable
import asyncio
import json
import logging
from mautrix.types.primitive import JSON
from ..config import Config
from .types import RPCError
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
class RPCClient:
config: Config
loop: asyncio.AbstractEventLoop
log: logging.Logger = logging.getLogger("mau.rpc")
_reader: asyncio.StreamReader | None
_writer: asyncio.StreamWriter | None
_req_id: int
_min_broadcast_id: int
_response_waiters: dict[int, asyncio.Future[JSON]]
_event_handlers: dict[str, list[EventHandler]]
_command_queue: asyncio.Queue
def __init__(self, config: Config) -> None:
self.config = config
self.loop = asyncio.get_running_loop()
self._req_id = 0
self._min_broadcast_id = 0
self._event_handlers = {}
self._response_waiters = {}
self._writer = None
self._reader = None
self._command_queue = asyncio.Queue()
async def connect(self) -> None:
if self._writer is not None:
return
if self.config["rpc.connection.type"] == "unix":
while True:
try:
r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"])
break
except:
self.log.warn(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...')
await asyncio.sleep(10)
elif self.config["rpc.connection.type"] == "tcp":
while True:
try:
r, w = await asyncio.open_connection(self.config["rpc.connection.host"],
self.config["rpc.connection.port"])
break
except:
self.log.warn(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.path"]}, wait for it to become available...')
await asyncio.sleep(10)
else:
raise RuntimeError("invalid rpc connection type")
self._reader = r
self._writer = w
self.loop.create_task(self._try_read_loop())
self.loop.create_task(self._command_loop())
await self.request("register", peer_id=self.config["appservice.address"])
async def disconnect(self) -> None:
if self._writer is not None:
self._writer.write_eof()
await self._writer.drain()
self._writer = None
self._reader = None
@property
def _next_req_id(self) -> int:
self._req_id += 1
return self._req_id
def add_event_handler(self, method: str, handler: EventHandler) -> None:
self._event_handlers.setdefault(method, []).append(handler)
def remove_event_handler(self, method: str, handler: EventHandler) -> None:
try:
self._event_handlers.setdefault(method, []).remove(handler)
except ValueError:
pass
def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
self._event_handlers[method] = handlers
async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
if req_id > self._min_broadcast_id:
self.log.debug(f"Ignoring duplicate broadcast {req_id}")
return
self._min_broadcast_id = req_id
try:
handlers = self._event_handlers[command]
except KeyError:
self.log.warning("No handlers for %s", command)
else:
for handler in handlers:
try:
await handler(req)
except Exception:
self.log.exception("Exception in event handler")
async def _handle_incoming_line(self, line: str) -> None:
try:
req = json.loads(line)
except json.JSONDecodeError:
self.log.debug(f"Got non-JSON data from server: {line}")
return
try:
req_id = req.pop("id")
command = req.pop("command")
is_sequential = req.pop("is_sequential", False)
except KeyError:
self.log.debug(f"Got invalid request from server: {line}")
return
if req_id < 0:
if not is_sequential:
self.loop.create_task(self._run_event_handler(req_id, command, req))
else:
self._command_queue.put_nowait((req_id, command, req))
return
try:
waiter = self._response_waiters[req_id]
except KeyError:
self.log.debug(f"Nobody waiting for response to {req_id}")
return
if command == "response":
waiter.set_result(req.get("response"))
elif command == "error":
waiter.set_exception(RPCError(req.get("error", line)))
else:
self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")
async def _command_loop(self) -> None:
while True:
req_id, command, req = await self._command_queue.get()
await self._run_event_handler(req_id, command, req)
self._command_queue.task_done()
async def _try_read_loop(self) -> None:
try:
await self._read_loop()
except Exception:
self.log.exception("Fatal error in read loop")
async def _read_loop(self) -> None:
while self._reader is not None and not self._reader.at_eof():
line = b''
while True:
try:
line += await self._reader.readuntil()
break
except asyncio.IncompleteReadError as e:
line += e.partial
break
except asyncio.LimitOverrunError as e:
self.log.warning(f"Buffer overrun: {e}")
line += await self._reader.read(self._reader._limit)
if not line:
continue
try:
line_str = line.decode("utf-8")
except UnicodeDecodeError:
self.log.exception("Got non-unicode request from server: %s", line)
continue
try:
await self._handle_incoming_line(line_str)
except Exception:
self.log.exception("Failed to handle incoming request %s", line_str)
self.log.debug("Reader disconnected")
self._reader = None
self._writer = None
async def _raw_request(self, command: str, is_secret: bool = False, **data: JSON) -> asyncio.Future[JSON]:
req_id = self._next_req_id
future = self._response_waiters[req_id] = self.loop.create_future()
req = {"id": req_id, "command": command, **data}
self.log.debug("Request %d: %s %s", req_id, command, data if not is_secret else "<REDACTED>")
assert self._writer is not None
self._writer.write(json.dumps(req).encode("utf-8"))
self._writer.write(b"\n")
await self._writer.drain()
return future
async def request(self, command: str, **data: JSON) -> JSON:
future = await self._raw_request(command, **data)
return await future