matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/rpc/rpc.py

317 lines
11 KiB
Python

# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any, Callable, Awaitable
import asyncio
import json
import logging
from mautrix.types.primitive import JSON
from ..config import Config
from .types import RPCError
# Signature of a callback registered for server-initiated commands: it is
# awaited with the decoded request dict after the "id", "command" and
# "is_sequential" keys have been popped off.
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
class CancelableEvent:
    """An ``asyncio.Event`` wrapper whose waiters can be cancelled as a group.

    All callers of :meth:`wait` share a single inner task that waits on the
    underlying event. :meth:`cancel` cancels that task, which makes every
    pending waiter raise ``asyncio.CancelledError``, and marks the object so
    that later :meth:`wait` calls raise immediately.
    """

    _event: asyncio.Event
    _task: asyncio.Task | None
    _cancelled: bool
    _loop: asyncio.AbstractEventLoop

    def __init__(self, loop: asyncio.AbstractEventLoop | None = None) -> None:
        """Create an unset, non-cancelled event.

        :param loop: Event loop to remember; when ``None`` the currently
            running loop is used (``RuntimeError`` if there is none).
        """
        self._event = asyncio.Event()
        self._task = None
        self._cancelled = False
        self._loop = loop or asyncio.get_running_loop()

    def is_set(self) -> bool:
        """Return whether the underlying event is currently set."""
        return self._event.is_set()

    def set(self) -> None:
        """Set the event, waking every waiter, and forget the shared task."""
        self._event.set()
        self._task = None

    def clear(self) -> None:
        """Reset the underlying event to the unset state."""
        self._event.clear()

    async def wait(self) -> None:
        """Block until the event is set.

        :raises asyncio.CancelledError: if :meth:`cancel` was called before
            or while waiting.
        """
        if self._cancelled:
            raise asyncio.CancelledError()
        if self._event.is_set():
            return
        # A finished (e.g. cancelled) task must not be reused: awaiting it
        # would just replay its old outcome instead of waiting again.
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._event.wait())
        # Shield the shared task so that cancelling ONE waiter does not
        # cancel it for everyone else. cancel() still reaches all waiters,
        # because cancelling the inner task cancels every shield around it.
        await asyncio.shield(self._task)

    def cancel(self) -> None:
        """Permanently cancel this event and wake all current waiters."""
        self._cancelled = True
        if self._task is not None:
            self._task.cancel()

    def cancelled(self) -> bool:
        """Return whether :meth:`cancel` has been called."""
        return self._cancelled
class RPCClient:
    """Newline-delimited JSON RPC client over a unix socket or TCP stream.

    Outgoing requests are JSON objects carrying a positive ``id`` and a
    ``command``; the matching incoming line (same ``id``, command
    ``response`` or ``error``) resolves the future handed out for that
    request. Incoming lines with a negative ``id`` are server-initiated
    commands and are dispatched to the handlers registered for their
    ``command``, either concurrently or, when flagged ``is_sequential``,
    one at a time through an internal queue.
    """

    config: Config
    loop: asyncio.AbstractEventLoop
    log: logging.Logger = logging.getLogger("mau.rpc")

    _reader: asyncio.StreamReader | None
    _writer: asyncio.StreamWriter | None
    _req_id: int
    _min_broadcast_id: int
    _response_waiters: dict[int, asyncio.Future[JSON]]
    _event_handlers: dict[str, list[EventHandler]]
    _command_queue: asyncio.Queue
    _read_task: asyncio.Task | None
    _connection_task: asyncio.Task | None
    _is_connected: CancelableEvent
    _is_disconnected: CancelableEvent
    _connection_lock: asyncio.Lock

    def __init__(self, config: Config, register_config_key: str) -> None:
        """Set up client state and start the sequential command worker.

        Must be called while an event loop is running, since the running
        loop is captured and a background task is created on it.

        :param config: Bridge configuration; the ``rpc.connection.*`` and
            ``appservice.address`` keys are read when connecting.
        :param register_config_key: Config key whose value is sent as
            ``register_config`` in the initial ``register`` request.
        """
        self.config = config
        self.register_config_key = register_config_key
        self.loop = asyncio.get_running_loop()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._command_queue = asyncio.Queue()
        self.loop.create_task(self._command_loop())
        self._read_task = None
        self._connection_task = None
        self._is_connected = CancelableEvent(self.loop)
        self._is_disconnected = CancelableEvent(self.loop)
        self._is_disconnected.set()
        self._connection_lock = asyncio.Lock()

    async def connect(self) -> None:
        """Connect to the RPC server unless already connected.

        :raises asyncio.CancelledError: if :meth:`cancel` was called.
        """
        async with self._connection_lock:
            if self._is_connected.cancelled():
                raise asyncio.CancelledError()
            if self._is_connected.is_set():
                return
            # Run as a separate task so that cancel() can abort the retry loop.
            self._connection_task = self.loop.create_task(self._connect())
            try:
                await self._connection_task
            finally:
                self._connection_task = None

    async def _connect(self) -> None:
        """Open the stream (retrying every 10 s), start reading and register.

        :raises RuntimeError: if ``rpc.connection.type`` is neither
            ``unix`` nor ``tcp``.
        """
        if self.config["rpc.connection.type"] == "unix":
            while True:
                try:
                    r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"])
                    break
                except asyncio.CancelledError:
                    raise
                except Exception:
                    self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...')
                    await asyncio.sleep(10)
        elif self.config["rpc.connection.type"] == "tcp":
            while True:
                try:
                    r, w = await asyncio.open_connection(self.config["rpc.connection.host"],
                                                         self.config["rpc.connection.port"])
                    break
                except asyncio.CancelledError:
                    raise
                except Exception:
                    # Report the port that was actually dialled (not the unix socket path).
                    self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.port"]}, wait for it to become available...')
                    await asyncio.sleep(10)
        else:
            raise RuntimeError("invalid rpc connection type")
        self._reader = r
        self._writer = w
        self._read_task = self.loop.create_task(self._try_read_loop())
        # Only the sending of the register request is awaited here; the
        # returned response future is not.
        await self._raw_request("register",
                                peer_id=self.config["appservice.address"],
                                register_config=self.config[self.register_config_key])
        self._is_connected.set()
        self._is_disconnected.clear()

    async def disconnect(self) -> None:
        """Disconnect from the RPC server unless already disconnected.

        :raises asyncio.CancelledError: if :meth:`cancel` was called.
        """
        async with self._connection_lock:
            if self._is_disconnected.cancelled():
                raise asyncio.CancelledError()
            if self._is_disconnected.is_set():
                return
            await self._disconnect()

    async def _disconnect(self) -> None:
        """Send EOF to the server, stop the read task and reset state."""
        if self._writer is not None:
            self._writer.write_eof()
            await self._writer.drain()
        if self._read_task is not None:
            self._read_task.cancel()
            self._read_task = None
        self._on_disconnect()

    def _on_disconnect(self) -> None:
        """Reset per-connection state and flip the connection events."""
        self._req_id = 0
        self._min_broadcast_id = 0
        self._reader = None
        self._writer = None
        self._is_connected.clear()
        self._is_disconnected.set()

    def wait_for_connection(self) -> Awaitable[None]:
        """Return an awaitable that completes once the client is connected."""
        return self._is_connected.wait()

    def wait_for_disconnection(self) -> Awaitable[None]:
        """Return an awaitable that completes once the client is disconnected."""
        return self._is_disconnected.wait()

    def cancel(self) -> None:
        """Cancel all connection waiters and any connect attempt, then disconnect."""
        self._is_connected.cancel()
        self._is_disconnected.cancel()
        if self._connection_task is not None:
            self._connection_task.cancel()
        # NOTE(review): asyncio.run() raises RuntimeError when invoked from a
        # thread whose event loop is running, so this presumably is only
        # called after the main loop has stopped - confirm against callers.
        asyncio.run(self._disconnect())

    @property
    def _next_req_id(self) -> int:
        """Increment and return the request id counter (each access advances it)."""
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Register ``handler`` for server commands named ``method``."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister ``handler`` for ``method``; a no-op if it is not registered."""
        try:
            self._event_handlers.setdefault(method, []).remove(handler)
        except ValueError:
            pass

    def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
        """Replace the complete handler list for ``method``."""
        self._event_handlers[method] = handlers

    async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
        """Run every handler registered for ``command`` with ``req``.

        Callers only pass negative ids here, and an id greater than the
        lowest one seen so far is treated as a duplicate and skipped
        (presumably the server counts broadcast ids downwards - confirm).
        """
        if req_id > self._min_broadcast_id:
            self.log.debug(f"Ignoring duplicate broadcast {req_id}")
            return
        self._min_broadcast_id = req_id
        try:
            handlers = self._event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for %s", command)
        else:
            for handler in handlers:
                try:
                    await handler(req)
                # One failing handler must not stop the others (or kill the
                # sequential worker), so cancellation raised inside a handler
                # is contained as well. KeyboardInterrupt, SystemExit and
                # GeneratorExit are deliberately allowed to propagate.
                except (Exception, asyncio.CancelledError):
                    self.log.exception("Exception in event handler")

    async def _handle_incoming_line(self, line: str) -> None:
        """Decode one line from the server and route it.

        Negative ids go to the event handlers; other ids resolve the future
        of the matching pending request.
        """
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
            is_sequential = req.pop("is_sequential", False)
        except KeyError:
            self.log.debug(f"Got invalid request from server: {line}")
            return

        if req_id < 0:
            if not is_sequential:
                self.loop.create_task(self._run_event_handler(req_id, command, req))
            else:
                self._command_queue.put_nowait((req_id, command, req))
            return

        waiter = self._response_waiters.get(req_id)
        if waiter is None:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if command not in ("response", "error"):
            # Keep the waiter registered: a proper reply may still follow.
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")
            return
        # The request is being settled now; drop the entry so the table
        # does not grow for the lifetime of the client.
        del self._response_waiters[req_id]
        if waiter.done():
            # E.g. the requester was cancelled; resolving a finished future
            # would raise InvalidStateError.
            self.log.debug(f"Discarding {command} for already finished request {req_id}")
            return
        if command == "response":
            self.log.debug("Received response %d", req_id)
            waiter.set_result(req.get("response"))
        else:
            waiter.set_exception(RPCError(req.get("error", line)))

    async def _command_loop(self) -> None:
        """Forever process queued sequential commands one at a time."""
        while True:
            req_id, command, req = await self._command_queue.get()
            await self._run_event_handler(req_id, command, req)
            self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run the read loop; when it ends other than by cancellation, mark the client disconnected."""
        try:
            await self._read_loop()
        except asyncio.CancelledError:
            return
        except Exception:
            self.log.exception("Fatal error in read loop")
        self.log.debug("Reader disconnected")
        self._on_disconnect()

    async def _read_loop(self) -> None:
        """Read newline-terminated lines until EOF and handle each one."""
        while self._reader is not None and not self._reader.at_eof():
            line = b''
            while True:
                try:
                    line += await self._reader.readuntil()
                    break
                except asyncio.IncompleteReadError as e:
                    # EOF before a newline: process whatever was received.
                    line += e.partial
                    break
                except asyncio.LimitOverrunError as e:
                    # Line exceeds the stream buffer limit: drain the
                    # buffered part and keep looking for the separator.
                    self.log.warning(f"Buffer overrun: {e}")
                    line += await self._reader.read(e.consumed)
                except asyncio.CancelledError:
                    raise
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except asyncio.CancelledError:
                raise
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)

    async def _raw_request(self, command: str, **data: JSON) -> asyncio.Future[JSON]:
        """Send a request without waiting for the connection or the reply.

        :param command: Command name to send.
        :param data: Extra JSON-serializable fields merged into the request.
        :returns: Future that is resolved with the server's response or
            failed with ``RPCError``.
        """
        req_id = self._next_req_id
        future = self._response_waiters[req_id] = self.loop.create_future()
        req = {"id": req_id, "command": command, **data}
        self.log.debug("Request %d: %s", req_id, command)
        try:
            assert self._writer is not None
            self._writer.write(json.dumps(req).encode("utf-8"))
            self._writer.write(b"\n")
            await self._writer.drain()
        except BaseException:
            # The caller never receives the future if sending failed, so do
            # not leave an orphaned waiter behind.
            self._response_waiters.pop(req_id, None)
            raise
        return future

    async def request(self, command: str, **data: JSON) -> JSON:
        """Wait for the connection, send ``command`` and return its response.

        :raises RPCError: if the server answers with an ``error`` command.
        :raises asyncio.CancelledError: if the client was cancelled.
        """
        await self.wait_for_connection()
        future = await self._raw_request(command, **data)
        return await future