311 lines
11 KiB
Python
311 lines
11 KiB
Python
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
|
|
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Callable, Awaitable
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
|
|
from mautrix.types.primitive import JSON
|
|
|
|
from ..config import Config
|
|
from .types import RPCError
|
|
|
|
# Coroutine callback type for server-initiated events: it is awaited with the
# decoded event payload (a dict) and its return value is ignored.
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
|
|
|
|
|
|
class CancelableEvent:
    """An ``asyncio.Event`` wrapper whose pending waiters can be cancelled together.

    All callers of :meth:`wait` share one internal task that waits on the
    underlying event.  :meth:`cancel` cancels that task (waking every waiter
    with ``CancelledError``) and makes all future :meth:`wait` calls raise
    ``CancelledError`` immediately.
    """

    _event: asyncio.Event
    _task: asyncio.Task | None
    _cancelled: bool
    _loop: asyncio.AbstractEventLoop

    def __init__(self, loop: asyncio.AbstractEventLoop | None = None) -> None:
        """Create an unset, non-cancelled event.

        :param loop: Event loop to remember; when omitted or ``None`` the
            currently running loop is used (which requires one to be running).
        """
        self._event = asyncio.Event()
        self._task = None
        self._cancelled = False
        self._loop = loop or asyncio.get_running_loop()

    def is_set(self) -> bool:
        """Return whether the underlying event is currently set."""
        return self._event.is_set()

    def set(self) -> None:
        """Set the event, waking every waiter, and forget the shared wait task."""
        self._event.set()
        self._task = None

    def clear(self) -> None:
        """Reset the underlying event to the unset state."""
        self._event.clear()

    async def wait(self) -> None:
        """Wait until the event is set.

        :raises asyncio.CancelledError: if :meth:`cancel` has been called, or
            if the awaiting caller itself is cancelled.
        """
        if self._cancelled:
            raise asyncio.CancelledError()
        if self._event.is_set():
            return
        # Recreate the shared task if it is missing or already finished so that
        # a stale task can never make later waiters fail.
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._event.wait())
        # Shield the shared task: cancelling ONE waiter must not cancel the
        # task that every other waiter is blocked on.  An explicit cancel()
        # still propagates, because cancelling the inner task cancels the
        # shield's outer future as well.
        await asyncio.shield(self._task)

    def cancel(self) -> None:
        """Permanently cancel the event, waking all current waiters with ``CancelledError``."""
        self._cancelled = True
        if self._task is not None:
            self._task.cancel()

    def cancelled(self) -> bool:
        """Return whether :meth:`cancel` has been called."""
        return self._cancelled
|
|
|
|
|
|
class RPCClient:
    """Client for a newline-delimited JSON RPC connection (unix socket or TCP).

    Outgoing requests carry an incrementing positive ``id``; the matching
    ``response``/``error`` line from the server resolves the future returned by
    :meth:`_raw_request`.  Incoming lines with a negative ``id`` are treated as
    server-initiated events and dispatched to the handlers registered for
    their ``command``.
    """

    config: Config
    loop: asyncio.AbstractEventLoop
    log: logging.Logger = logging.getLogger("mau.rpc")

    _reader: asyncio.StreamReader | None
    _writer: asyncio.StreamWriter | None
    _req_id: int
    _min_broadcast_id: int
    _response_waiters: dict[int, asyncio.Future[JSON]]
    _event_handlers: dict[str, list[EventHandler]]
    _command_queue: asyncio.Queue
    _command_task: asyncio.Task | None
    _disconnect_task: asyncio.Task | None
    _read_task: asyncio.Task | None
    _connection_task: asyncio.Task | None
    _is_connected: CancelableEvent
    _is_disconnected: CancelableEvent
    _connection_lock: asyncio.Lock

    def __init__(self, config: Config) -> None:
        """Create a disconnected client and start its sequential-event worker.

        Must be called while an event loop is running, because the running
        loop is captured and a background task is created on it.
        """
        self.config = config
        self.loop = asyncio.get_running_loop()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._command_queue = asyncio.Queue()
        # Keep a strong reference to the worker task; the event loop itself
        # only holds weak references to tasks.
        self._command_task = self.loop.create_task(self._command_loop())
        self._disconnect_task = None
        self._read_task = None
        self._connection_task = None
        self._is_connected = CancelableEvent(self.loop)
        self._is_disconnected = CancelableEvent(self.loop)
        self._is_disconnected.set()
        self._connection_lock = asyncio.Lock()

    async def connect(self) -> None:
        """Connect to the RPC server unless already connected.

        :raises asyncio.CancelledError: if the client was cancelled, or if the
            connection attempt is cancelled via :meth:`cancel`.
        """
        async with self._connection_lock:
            if self._is_connected.cancelled():
                raise asyncio.CancelledError()
            if self._is_connected.is_set():
                return
            # Run the attempt as a task so that cancel() can abort it.
            self._connection_task = self.loop.create_task(self._connect())
            try:
                await self._connection_task
            finally:
                self._connection_task = None

    async def _connect(self) -> None:
        """Open the configured stream (retrying every 10 seconds), then register.

        :raises RuntimeError: if ``rpc.connection.type`` is neither ``unix``
            nor ``tcp``.
        """
        if self.config["rpc.connection.type"] == "unix":
            while True:
                try:
                    r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"])
                    break
                except asyncio.CancelledError:
                    raise
                except Exception:
                    self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...')
                    await asyncio.sleep(10)
        elif self.config["rpc.connection.type"] == "tcp":
            while True:
                try:
                    r, w = await asyncio.open_connection(self.config["rpc.connection.host"],
                                                         self.config["rpc.connection.port"])
                    break
                except asyncio.CancelledError:
                    raise
                except Exception:
                    self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.port"]}, wait for it to become available...')
                    await asyncio.sleep(10)
        else:
            raise RuntimeError("invalid rpc connection type")
        self._reader = r
        self._writer = w
        self._read_task = self.loop.create_task(self._try_read_loop())
        # NOTE(review): only the send is awaited here; the future carrying the
        # server's reply to "register" is not awaited, so a failed
        # registration is not observed by this method -- confirm this is
        # intended.
        await self._raw_request("register", peer_id=self.config["appservice.address"])
        self._is_connected.set()
        self._is_disconnected.clear()

    async def disconnect(self) -> None:
        """Disconnect from the RPC server unless already disconnected.

        :raises asyncio.CancelledError: if the client was cancelled.
        """
        async with self._connection_lock:
            if self._is_disconnected.cancelled():
                raise asyncio.CancelledError()
            if self._is_disconnected.is_set():
                return
            await self._disconnect()

    async def _disconnect(self) -> None:
        """Send EOF to the server, stop the read task and reset connection state."""
        try:
            if self._writer is not None:
                self._writer.write_eof()
                await self._writer.drain()
        finally:
            # Always tear down local state, even if flushing the EOF failed.
            if self._read_task is not None:
                self._read_task.cancel()
                self._read_task = None
            self._on_disconnect()

    def _on_disconnect(self) -> None:
        """Close and drop the streams and flip the connected/disconnected events."""
        writer, self._writer = self._writer, None
        self._reader = None
        if writer is not None:
            try:
                # Release the underlying transport instead of leaking it.
                writer.close()
            except RuntimeError:
                # Raised when the owning event loop has already been closed.
                self.log.debug("Could not close RPC writer cleanly", exc_info=True)
        self._is_connected.clear()
        self._is_disconnected.set()

    def wait_for_connection(self) -> Awaitable[None]:
        """Return an awaitable that completes once the client is connected."""
        return self._is_connected.wait()

    def wait_for_disconnection(self) -> Awaitable[None]:
        """Return an awaitable that completes once the client is disconnected."""
        return self._is_disconnected.wait()

    def cancel(self) -> None:
        """Cancel all waiters and any in-flight connection attempt, then disconnect.

        Works both from inside the running event loop (the teardown is
        scheduled as a task) and from synchronous code while the loop is
        stopped (the teardown is run to completion on the client's own loop).
        """
        self._is_connected.cancel()
        self._is_disconnected.cancel()
        if self._connection_task is not None:
            self._connection_task.cancel()
        # asyncio.run() cannot be used here: it raises RuntimeError when a loop
        # is already running and would otherwise drive the streams from a
        # brand-new, unrelated loop.
        if self.loop.is_running():
            self._disconnect_task = self.loop.create_task(self._disconnect())
        elif not self.loop.is_closed():
            self.loop.run_until_complete(self._disconnect())
        else:
            self._on_disconnect()

    @property
    def _next_req_id(self) -> int:
        """Increment and return the request counter (every access consumes an ID)."""
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Register ``handler`` to be called for events whose command is ``method``."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister ``handler`` for ``method``; unknown methods/handlers are ignored."""
        try:
            # Plain lookup: removing from an unknown method must not create an
            # empty handler list as a side effect.
            self._event_handlers[method].remove(handler)
        except (KeyError, ValueError):
            pass

    def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
        """Replace the complete handler list for ``method``."""
        self._event_handlers[method] = handlers

    async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
        """Dispatch one server event to every handler registered for ``command``.

        Events whose ID is greater than the lowest ID seen so far are skipped
        as duplicates.  Exceptions raised by handlers are logged, not raised.
        """
        # NOTE(review): an event repeating exactly the current minimum ID is
        # NOT skipped by this comparison -- confirm whether ">=" was intended.
        if req_id > self._min_broadcast_id:
            self.log.debug(f"Ignoring duplicate broadcast {req_id}")
            return
        self._min_broadcast_id = req_id
        try:
            handlers = self._event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for %s", command)
        else:
            for handler in handlers:
                try:
                    await handler(req)
                except Exception:
                    # Not a bare except: cancellation must still propagate.
                    self.log.exception("Exception in event handler")

    async def _handle_incoming_line(self, line: str) -> None:
        """Parse one line from the server and route it.

        Negative IDs are events: they are either run concurrently as a new
        task or, when ``is_sequential`` is true, queued for the command loop.
        Other IDs are replies that resolve the matching pending request.
        """
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        if not isinstance(req, dict):
            # Valid JSON, but not an object, so it cannot carry id/command.
            self.log.debug(f"Got invalid request from server: {line}")
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
            is_sequential = req.pop("is_sequential", False)
        except KeyError:
            self.log.debug(f"Got invalid request from server: {line}")
            return
        if req_id < 0:
            if not is_sequential:
                self.loop.create_task(self._run_event_handler(req_id, command, req))
            else:
                self._command_queue.put_nowait((req_id, command, req))
            return
        try:
            waiter = self._response_waiters[req_id]
        except KeyError:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if command not in ("response", "error"):
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")
            return
        # The request is finished: forget the waiter so the dict cannot grow
        # without bound.
        del self._response_waiters[req_id]
        if waiter.done():
            # The requester gave up (e.g. was cancelled); resolving the future
            # again would raise InvalidStateError.
            self.log.debug(f"Nobody waiting for response to {req_id}")
        elif command == "response":
            waiter.set_result(req.get("response"))
        else:
            waiter.set_exception(RPCError(req.get("error", line)))

    async def _command_loop(self) -> None:
        """Run queued sequential events one at a time, in arrival order, forever."""
        while True:
            req_id, command, req = await self._command_queue.get()
            try:
                await self._run_event_handler(req_id, command, req)
            finally:
                self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run the read loop and mark the client disconnected when it ends.

        Cancellation returns silently (the canceller performs the cleanup);
        any other error is logged before the state is reset.
        """
        try:
            await self._read_loop()
        except asyncio.CancelledError:
            return
        except Exception:
            self.log.exception("Fatal error in read loop")
        self.log.debug("Reader disconnected")
        self._on_disconnect()

    async def _read_loop(self) -> None:
        """Read newline-terminated lines until EOF and handle each one."""
        while self._reader is not None and not self._reader.at_eof():
            line = b''
            while True:
                try:
                    line += await self._reader.readuntil()
                    break
                except asyncio.IncompleteReadError as e:
                    # EOF before a newline: use whatever arrived.
                    line += e.partial
                    break
                except asyncio.LimitOverrunError as e:
                    # The line exceeds the stream buffer limit: drain one
                    # buffer's worth and keep looking for the newline.
                    self.log.warning(f"Buffer overrun: {e}")
                    line += await self._reader.read(self._reader._limit)
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except asyncio.CancelledError:
                raise
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)

    async def _raw_request(self, command: str, is_secret: bool = False, **data: JSON) -> asyncio.Future[JSON]:
        """Send one request and return the future that will hold its reply.

        :param command: Command name to send.
        :param is_secret: When true, the payload is redacted from the debug log.
        :param data: Extra JSON-serializable fields merged into the request.
        :returns: A future resolved (or failed with ``RPCError``) when the
            server answers this request ID.
        :raises AssertionError: if there is no open writer.
        """
        # Check before registering a waiter so a failed send leaves nothing
        # behind in the waiter table.
        assert self._writer is not None
        req_id = self._next_req_id
        req = {"id": req_id, "command": command, **data}
        payload = json.dumps(req).encode("utf-8") + b"\n"
        self.log.debug("Request %d: %s %s", req_id, command, data if not is_secret else "<REDACTED>")
        future = self._response_waiters[req_id] = self.loop.create_future()
        try:
            self._writer.write(payload)
            await self._writer.drain()
        except BaseException:
            # The caller never receives the future, so drop its waiter.
            self._response_waiters.pop(req_id, None)
            raise
        return future

    async def request(self, command: str, **data: JSON) -> JSON:
        """Wait for a connection, send ``command`` and return the server's reply.

        :raises RPCError: if the server answers with an error.
        :raises asyncio.CancelledError: if the client was cancelled.
        """
        await self.wait_for_connection()
        future = await self._raw_request(command, **data)
        return await future
|