matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/rpc/rpc.py

317 lines
11 KiB
Python

# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any, Awaitable, Callable, Dict, List, Optional
import asyncio
import json
import logging
from mautrix.types.primitive import JSON
from ..config import Config
from .types import RPCError
# Async callback invoked for a server broadcast. It receives the decoded request dict
# after RPCClient has popped the "id", "command" and "is_sequential" keys from it.
EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
class CancelableEvent:
    """An ``asyncio.Event`` wrapper whose waiters can all be aborted at once.

    Waiters share one background task that waits on the underlying event. Once
    :meth:`cancel` has been called, every pending and future :meth:`wait` raises
    ``asyncio.CancelledError``.
    """

    _event: asyncio.Event
    _task: Optional[asyncio.Task]
    _cancelled: bool
    _loop: asyncio.AbstractEventLoop

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
        """
        Args:
            loop: Event loop to associate with; defaults to the running loop, in which
                case the constructor must be called from inside a coroutine.
        """
        self._event = asyncio.Event()
        self._task = None
        self._cancelled = False
        self._loop = loop or asyncio.get_running_loop()

    def is_set(self) -> bool:
        """Return whether the underlying event is currently set."""
        return self._event.is_set()

    def set(self) -> None:
        """Set the event, waking every pending waiter."""
        self._event.set()
        # The shared wait task finishes on its own now; the next wait() after a
        # clear() must start a fresh one.
        self._task = None

    def clear(self) -> None:
        """Clear the event so that subsequent waits block again."""
        self._event.clear()

    async def wait(self) -> None:
        """Block until the event is set.

        Raises:
            asyncio.CancelledError: if :meth:`cancel` was called before or during the wait.
        """
        if self._cancelled:
            raise asyncio.CancelledError()
        if self._event.is_set():
            return
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._event.wait())
        # shield() keeps the cancellation of one waiter (e.g. a wait_for timeout) from
        # propagating into the task that all other waiters share. cancel() still
        # reaches everyone, because cancelling the inner task cancels each shield.
        await asyncio.shield(self._task)

    def cancel(self) -> None:
        """Permanently cancel this event: wake current waiters with CancelledError."""
        self._cancelled = True
        if self._task is not None:
            self._task.cancel()

    def cancelled(self) -> bool:
        """Return whether :meth:`cancel` has been called."""
        return self._cancelled
class RPCClient:
    """Line-delimited JSON RPC client for the server configured under ``rpc.connection``.

    Outgoing requests carry positive, increasing ids and are answered by ``response`` or
    ``error`` lines with the same id. Incoming lines with a negative id are broadcasts,
    dispatched to the handlers registered for their ``command``: broadcasts flagged
    ``is_sequential`` run one at a time in arrival order, the rest run concurrently.
    """

    config: Config
    loop: asyncio.AbstractEventLoop
    log: logging.Logger = logging.getLogger("mau.rpc")

    _reader: Optional[asyncio.StreamReader]
    _writer: Optional[asyncio.StreamWriter]
    _req_id: int
    _min_broadcast_id: int
    _response_waiters: Dict[int, asyncio.Future[JSON]]
    _event_handlers: Dict[str, List[EventHandler]]
    _command_queue: asyncio.Queue
    _command_task: asyncio.Task
    _background_tasks: set  # strong references to fire-and-forget tasks (see _spawn)
    _read_task: Optional[asyncio.Task]
    _connection_task: Optional[asyncio.Task]
    _is_connected: CancelableEvent
    _is_disconnected: CancelableEvent
    _connection_lock: asyncio.Lock

    def __init__(self, config: Config, register_config_key: str) -> None:
        """Create a disconnected client. Must be called while an event loop is running.

        Args:
            config: Bridge config; ``rpc.connection.*`` and ``appservice.address`` are
                read when connecting.
            register_config_key: Config key whose value is sent with ``register``.
        """
        self.config = config
        self.register_config_key = register_config_key
        self.loop = asyncio.get_running_loop()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._background_tasks = set()
        self._command_queue = asyncio.Queue()
        self._command_task = self.loop.create_task(self._command_loop())
        self._read_task = None
        self._connection_task = None
        self._is_connected = CancelableEvent(self.loop)
        self._is_disconnected = CancelableEvent(self.loop)
        self._is_disconnected.set()
        self._connection_lock = asyncio.Lock()

    def _spawn(self, coro: Awaitable[None]) -> asyncio.Task:
        """Run ``coro`` as a task, holding a strong reference until it finishes.

        The event loop keeps only weak references to tasks, so an unreferenced task can
        be garbage-collected before it completes.
        """
        task = self.loop.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def connect(self) -> None:
        """Connect and register with the server, unless already connected.

        Raises:
            asyncio.CancelledError: if :meth:`cancel` has been called.
            RuntimeError: if ``rpc.connection.type`` is not ``unix`` or ``tcp``.
        """
        async with self._connection_lock:
            if self._is_connected.cancelled():
                raise asyncio.CancelledError()
            if self._is_connected.is_set():
                return
            # Run as a task so cancel() can interrupt the retry loop in _connect().
            self._connection_task = self.loop.create_task(self._connect())
            try:
                await self._connection_task
            finally:
                self._connection_task = None

    async def _connect(self) -> None:
        """Open the configured socket (retrying every 10 s), start reading and register."""
        conn_type = self.config["rpc.connection.type"]
        if conn_type == "unix":
            path = self.config["rpc.connection.path"]
            while True:
                try:
                    r, w = await asyncio.open_unix_connection(path)
                    break
                except OSError:
                    self.log.warning(f'No unix socket available at {path}, wait for it to exist...')
                    await asyncio.sleep(10)
        elif conn_type == "tcp":
            host = self.config["rpc.connection.host"]
            port = self.config["rpc.connection.port"]
            while True:
                try:
                    r, w = await asyncio.open_connection(host, port)
                    break
                except OSError:
                    self.log.warning(f'No TCP connection open at {host}:{port}, wait for it to become available...')
                    await asyncio.sleep(10)
        else:
            raise RuntimeError("invalid rpc connection type")
        self._reader = r
        self._writer = w
        self._read_task = self.loop.create_task(self._try_read_loop())
        # NOTE(review): only the *sending* of "register" is awaited, not its response, so
        # the connection is reported as up before the server confirms registration.
        await self._raw_request("register",
                                peer_id=self.config["appservice.address"],
                                register_config=self.config[self.register_config_key])
        self._is_connected.set()
        self._is_disconnected.clear()

    async def disconnect(self) -> None:
        """Close the connection, unless already disconnected.

        Raises:
            asyncio.CancelledError: if :meth:`cancel` has been called.
        """
        async with self._connection_lock:
            if self._is_disconnected.cancelled():
                raise asyncio.CancelledError()
            if self._is_disconnected.is_set():
                return
            await self._disconnect()

    async def _disconnect(self) -> None:
        """Half-close and close the socket, stop the reader and reset connection state.

        Cleanup always runs, even if flushing fails because the peer already went away.
        """
        writer = self._writer
        try:
            if writer is not None:
                try:
                    writer.write_eof()
                    await writer.drain()
                except OSError as e:
                    self.log.debug(f"Error while closing RPC connection: {e}")
                finally:
                    writer.close()
        finally:
            if self._read_task is not None:
                self._read_task.cancel()
                self._read_task = None
            self._on_disconnect()

    def _on_disconnect(self) -> None:
        """Reset per-connection state and fail requests still awaiting an answer."""
        self._req_id = 0
        self._min_broadcast_id = 0
        self._reader = None
        self._writer = None
        # Request ids restart at 1 on the next connection, and no answer can arrive for
        # these anymore: fail them so their callers don't wait forever.
        waiters, self._response_waiters = self._response_waiters, {}
        for waiter in waiters.values():
            if not waiter.done():
                waiter.set_exception(ConnectionError("RPC connection lost"))
        self._is_connected.clear()
        self._is_disconnected.set()

    def wait_for_connection(self) -> Awaitable[None]:
        """Return an awaitable that completes once connected (raises after cancel())."""
        return self._is_connected.wait()

    def wait_for_disconnection(self) -> Awaitable[None]:
        """Return an awaitable that completes once disconnected (raises after cancel())."""
        return self._is_disconnected.wait()

    def cancel(self) -> None:
        """Abort connection waits and attempts for good, then tear the connection down.

        From inside a running event loop the teardown is scheduled as a task (the call
        returns before it finishes); otherwise it is run to completion on ``self.loop``.
        """
        self._is_connected.cancel()
        self._is_disconnected.cancel()
        if self._connection_task is not None:
            self._connection_task.cancel()
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: drive the teardown on the loop that owns
            # the transports (asyncio.run() would use a brand-new loop instead).
            if not self.loop.is_closed():
                self.loop.run_until_complete(self._disconnect())
        else:
            # asyncio.run() cannot be called from a running loop, so schedule it.
            self._spawn(self._disconnect())

    @property
    def _next_req_id(self) -> int:
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Register ``handler`` for broadcasts whose command is ``method``."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister ``handler`` for ``method``; silently ignored if not registered."""
        handlers = self._event_handlers.get(method)
        if handlers is None:
            return
        try:
            handlers.remove(handler)
        except ValueError:
            pass

    def set_event_handlers(self, method: str, handlers: List[EventHandler]) -> None:
        """Replace every handler registered for ``method``."""
        self._event_handlers[method] = handlers

    def _is_new_broadcast(self, req_id: int) -> bool:
        """Record broadcast ``req_id`` and return False if it was already seen.

        Broadcast ids are negative and decrease, so anything not below the lowest id
        seen on this connection is a repeat. This is checked on arrival rather than when
        a handler runs, so queued sequential broadcasts aren't mistaken for repeats.
        """
        if req_id >= self._min_broadcast_id:
            self.log.debug(f"Ignoring duplicate broadcast {req_id}")
            return False
        self._min_broadcast_id = req_id
        return True

    async def _run_event_handler(self, req_id: int, command: str, req: Dict[str, Any]) -> None:
        """Run every handler for ``command`` on ``req``, logging (not raising) failures.

        ``req_id`` is unused here; duplicate filtering happens in _is_new_broadcast.
        """
        handlers = self._event_handlers.get(command)
        if handlers is None:
            self.log.warning("No handlers for %s", command)
            return
        # Iterate over a copy: a handler may add or remove handlers while it runs.
        for handler in list(handlers):
            try:
                await handler(req)
            except (Exception, asyncio.CancelledError):
                # CancelledError is included so a handler surfacing one (e.g. from a
                # cancelled wait) cannot stop broadcast dispatch.
                self.log.exception("Exception in event handler")

    async def _handle_incoming_line(self, line: str) -> None:
        """Decode one line from the server and route it as a broadcast or a reply."""
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        if not isinstance(req, dict):
            self.log.debug(f"Got invalid request from server: {line}")
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
        except KeyError:
            self.log.debug(f"Got invalid request from server: {line}")
            return
        is_sequential = req.pop("is_sequential", False)

        if req_id < 0:
            if not self._is_new_broadcast(req_id):
                return
            if is_sequential:
                self._command_queue.put_nowait((req_id, command, req))
            else:
                self._spawn(self._run_event_handler(req_id, command, req))
            return

        waiter = self._response_waiters.get(req_id)
        if waiter is None:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if command not in ("response", "error"):
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")
            return
        del self._response_waiters[req_id]
        if waiter.done():
            # The requester stopped waiting (e.g. it was cancelled); drop the late reply.
            self.log.debug(f"Dropping late {command} for {req_id}")
            return
        if command == "response":
            self.log.debug("Received response %d", req_id)
            waiter.set_result(req.get("response"))
        else:
            waiter.set_exception(RPCError(req.get("error", line)))

    async def _command_loop(self) -> None:
        """Run sequential broadcasts one at a time, in the order they arrived."""
        while True:
            req_id, command, req = await self._command_queue.get()
            try:
                await self._run_event_handler(req_id, command, req)
            finally:
                self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run the read loop; when it ends on its own, close the socket and reset state."""
        try:
            await self._read_loop()
        except asyncio.CancelledError:
            # Cancelled by _disconnect(), which does the cleanup itself.
            return
        except Exception:
            self.log.exception("Fatal error in read loop")
        self.log.debug("Reader disconnected")
        if self._writer is not None:
            self._writer.close()
        self._on_disconnect()

    async def _read_loop(self) -> None:
        """Read newline-terminated lines until EOF and hand each to _handle_incoming_line."""
        while self._reader is not None and not self._reader.at_eof():
            line = b''
            while True:
                try:
                    line += await self._reader.readuntil()
                    break
                except asyncio.IncompleteReadError as e:
                    # EOF before a newline: keep whatever arrived.
                    line += e.partial
                    break
                except asyncio.LimitOverrunError as e:
                    # Line longer than the stream limit: take what is buffered and retry.
                    self.log.warning(f"Buffer overrun: {e}")
                    line += await self._reader.read(e.consumed)
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except asyncio.CancelledError:
                raise
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)

    async def _raw_request(self, command: str, **data: JSON) -> asyncio.Future[JSON]:
        """Send ``command`` and return a future for its answer, without waiting for it.

        Raises:
            ConnectionError: if there is no open connection.
        """
        if self._writer is None:
            raise ConnectionError("RPC client is not connected")
        req_id = self._next_req_id
        future = self.loop.create_future()
        self._response_waiters[req_id] = future
        req = {"id": req_id, "command": command, **data}
        self.log.debug("Request %d: %s", req_id, command)
        try:
            self._writer.write(json.dumps(req).encode("utf-8") + b"\n")
            await self._writer.drain()
        except BaseException:
            # Don't leave an orphaned waiter behind; a reconnect may already have reused
            # the id, so only remove the entry if it is still ours.
            if self._response_waiters.get(req_id) is future:
                del self._response_waiters[req_id]
            raise
        return future

    async def request(self, command: str, **data: JSON) -> JSON:
        """Wait until connected, send ``command`` and return the server's response value.

        Raises:
            RPCError: if the server answers with an ``error``.
            ConnectionError: if the connection is lost before the answer arrives.
            asyncio.CancelledError: if the client has been cancelled.
        """
        await self.wait_for_connection()
        future = await self._raw_request(command, **data)
        return await future