matrix-puppeteer-line/mautrix_amp/rpc/rpc.py

167 lines
6.2 KiB
Python
Raw Normal View History

2020-08-28 09:38:06 -04:00
# mautrix-amp - A very hacky Matrix-SMS bridge based on using Android Messages for Web in Puppeteer
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any, Callable, Awaitable, List, Optional
import logging
import asyncio
import json
from mautrix.types import UserID
from mautrix.util.logging import TraceLogger
from ..config import Config
from .types import RPCError
# Signature of a broadcast event handler: an async callable that receives the
# decoded message dict (RPCClient pops its "id" and "command" keys before
# dispatching it).
EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
class RPCClient:
    """Newline-delimited JSON-RPC client for the Puppeteer helper process.

    Every message is one JSON object per line with an ``id`` and a ``command``.
    Requests sent by this client get increasing positive ids; the reply to a
    request carries the same id and ``command`` set to ``"response"`` or
    ``"error"``. Messages with a negative id are server-initiated broadcasts,
    which are dispatched to the handlers registered via
    :meth:`add_event_handler`.
    """

    # NOTE(review): ``config`` is never assigned in this class; it is presumably
    # set on the class by the bridge before connect() is called — confirm.
    config: Config
    loop: asyncio.AbstractEventLoop
    log: TraceLogger = logging.getLogger("mau.rpc")
    user_id: UserID

    _reader: Optional[asyncio.StreamReader]
    _writer: Optional[asyncio.StreamWriter]
    # Background task running _try_read_loop(); kept so it is not garbage-collected.
    _read_task: Optional[asyncio.Task]
    _req_id: int
    # Lowest (most negative) broadcast id handled so far; used to drop duplicates.
    _min_broadcast_id: int
    # Pending requests: request id -> future resolved by the matching reply.
    _response_waiters: Dict[int, asyncio.Future]
    _event_handlers: Dict[str, List[EventHandler]]

    def __init__(self, user_id: UserID) -> None:
        """Create a disconnected client for ``user_id``.

        Must be called while an event loop is running
        (``asyncio.get_running_loop()`` raises ``RuntimeError`` otherwise).
        """
        self.log = self.log.getChild(user_id)
        self.loop = asyncio.get_running_loop()
        self.user_id = user_id
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._read_task = None

    async def connect(self) -> None:
        """Open the configured connection, start reading, and register the user.

        Does nothing if a connection is already open. The transport is chosen by
        ``puppeteer.connection.type`` ("unix" uses ``.path``; "tcp" uses
        ``.host`` and ``.port``). After connecting, a ``register`` request is
        sent with this client's ``user_id``.

        :raises RuntimeError: if the connection type is neither "unix" nor "tcp".
        """
        if self._writer is not None:
            return
        conn_type = self.config["puppeteer.connection.type"]
        if conn_type == "unix":
            r, w = await asyncio.open_unix_connection(self.config["puppeteer.connection.path"])
        elif conn_type == "tcp":
            r, w = await asyncio.open_connection(self.config["puppeteer.connection.host"],
                                                 self.config["puppeteer.connection.port"])
        else:
            raise RuntimeError("invalid puppeteer connection type")
        self._reader = r
        self._writer = w
        # asyncio only keeps weak references to tasks, so hold on to this one.
        self._read_task = self.loop.create_task(self._try_read_loop())
        await self.request("register", user_id=self.user_id)

    async def disconnect(self) -> None:
        """Half-close, then close, the current connection.

        Does nothing when not connected. Requests still waiting for a reply are
        failed with ``ConnectionError`` instead of being left pending forever.
        """
        writer = self._writer
        if writer is None:
            return
        # Clear state first so the read loop stops and new requests fail fast.
        self._writer = None
        self._reader = None
        try:
            if writer.can_write_eof():
                writer.write_eof()
            await writer.drain()
        finally:
            writer.close()
            self._fail_pending_waiters("RPC client disconnected")

    @property
    def _next_req_id(self) -> int:
        """Allocate the next request id (1, 2, 3, ...)."""
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Register ``handler`` to be awaited for broadcasts whose command is ``method``."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister ``handler`` for ``method``.

        :raises ValueError: if ``handler`` is not registered for ``method``.
        """
        self._event_handlers.setdefault(method, []).remove(handler)

    async def _run_event_handler(self, req_id: int, command: str, req: Dict[str, Any]) -> None:
        """Dispatch one broadcast to every handler registered for ``command``.

        Broadcast ids are negative. Any id not lower than the lowest one handled
        so far (including an exact repeat) is treated as a duplicate and dropped.
        Handler exceptions are logged and do not stop the remaining handlers.
        """
        # ``>=`` rather than ``>``: receiving the same broadcast id twice is the
        # textbook duplicate and must not run the handlers a second time.
        if req_id >= self._min_broadcast_id:
            self.log.debug(f"Ignoring duplicate broadcast {req_id}")
            return
        self._min_broadcast_id = req_id
        try:
            handlers = self._event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for %s", command)
        else:
            for handler in handlers:
                try:
                    await handler(req)
                except Exception:
                    self.log.exception("Exception in event handler")

    def _fail_pending_waiters(self, reason: str) -> None:
        """Fail every outstanding request future with ``ConnectionError(reason)``."""
        waiters, self._response_waiters = self._response_waiters, {}
        for waiter in waiters.values():
            if not waiter.done():
                waiter.set_exception(ConnectionError(reason))

    async def _handle_incoming_line(self, line: str) -> None:
        """Parse one line from the server and route it.

        A negative id is scheduled as a broadcast. Otherwise the line resolves
        the pending request with that id: with the ``response`` field for
        ``"response"``, or with :class:`RPCError` for ``"error"``. Malformed
        input is logged at debug level and ignored.
        """
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        # Valid JSON that is not an object, has no int id, or has no command
        # would otherwise raise below and be reported as a handler failure.
        if (not isinstance(req, dict) or not isinstance(req.get("id"), int)
                or "command" not in req):
            self.log.debug(f"Got invalid request from server: {line}")
            return
        req_id = req.pop("id")
        command = req.pop("command")

        if req_id < 0:
            self.loop.create_task(self._run_event_handler(req_id, command, req))
            return

        waiter = self._response_waiters.get(req_id)
        if waiter is None:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if command not in ("response", "error"):
            # Keep the waiter registered: a proper reply may still follow.
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")
            return
        # This is the final reply; drop the entry so waiters don't accumulate.
        del self._response_waiters[req_id]
        if waiter.done():
            # The caller already gave up (e.g. was cancelled); setting a result
            # on a done future would raise InvalidStateError.
            return
        if command == "response":
            waiter.set_result(req.get("response"))
        else:
            waiter.set_exception(RPCError(req.get("error", line)))

    async def _try_read_loop(self) -> None:
        """Run the read loop, logging (rather than propagating) any fatal error."""
        try:
            await self._read_loop()
        except Exception:
            self.log.exception("Fatal error in read loop")

    async def _read_loop(self) -> None:
        """Read and dispatch newline-terminated messages until the connection ends.

        Stops at EOF, or once the reader this loop started with is no longer the
        client's current reader (after disconnect() or a reconnect). On exit,
        for whatever reason, the connection is closed. If it is still the
        current one, the client's state is cleared and pending requests fail.
        """
        # Capture this connection's streams so a later reconnect's streams are
        # never read or torn down by this (stale) loop.
        reader = self._reader
        writer = self._writer
        try:
            while reader is not None and self._reader is reader and not reader.at_eof():
                line = await reader.readline()
                if not line:
                    continue
                try:
                    line_str = line.decode("utf-8")
                except UnicodeDecodeError:
                    self.log.exception("Got non-unicode request from server: %s", line)
                    continue
                try:
                    await self._handle_incoming_line(line_str)
                except Exception:
                    self.log.exception("Failed to handle incoming request %s", line_str)
        finally:
            # Runs on errors too (e.g. ValueError from readline on an overlong
            # line), so connect() is not fooled by stale reader/writer state.
            self.log.debug("Reader disconnected")
            if reader is not None and self._reader is reader:
                self._reader = None
                self._writer = None
                self._fail_pending_waiters("RPC connection closed")
            if writer is not None:
                writer.close()

    async def _raw_request(self, command: str, **data: Any) -> asyncio.Future:
        """Send a request and return the future that will hold its reply.

        :raises ConnectionError: if the client is not connected.
        """
        writer = self._writer
        if writer is None:
            raise ConnectionError("RPC client is not connected")
        req_id = self._next_req_id
        future = self._response_waiters[req_id] = self.loop.create_future()
        req = {"id": req_id, "command": command, **data}
        self.log.trace("Request %d: %s %s", req_id, command, data)
        try:
            writer.write(json.dumps(req).encode("utf-8") + b"\n")
            await writer.drain()
        except BaseException:
            # The request may never have been sent, or the caller was cancelled
            # mid-write; either way nobody will collect this future.
            self._response_waiters.pop(req_id, None)
            raise
        return future

    async def request(self, command: str, **data: Any) -> Dict[str, Any]:
        """Send ``command`` with keyword ``data`` and wait for the reply.

        :returns: the ``response`` field of the server's reply.
        :raises RPCError: if the server replies with ``"error"``.
        :raises ConnectionError: if not connected, or the connection closes
            before the reply arrives.
        """
        future = await self._raw_request(command, **data)
        return await future