# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import Any, Callable, Awaitable

import asyncio
import json
import logging

from mautrix.types.primitive import JSON

from ..config import Config
from .types import RPCError

# Payload handed to broadcast handlers: the decoded JSON message with its
# "id", "command" and "is_sequential" keys removed.
_EventPayload = dict[str, Any]
# Async callback awaited once per matching broadcast.
EventHandler = Callable[[_EventPayload], Awaitable[None]]


class RPCClient:
    """Client for an RPC server speaking newline-delimited JSON.

    Every message is one JSON object per line. Requests made through :meth:`request`
    carry a positive, increasing ``id``; the server replies with the same ``id`` and
    ``command`` set to ``"response"`` (payload under ``"response"``) or ``"error"``
    (payload under ``"error"``, raised to the caller as :class:`RPCError`). Messages
    with a negative ``id`` are broadcasts and are dispatched to the handlers registered
    for their ``command`` via :meth:`add_event_handler`.
    """

    config: Config
    loop: asyncio.AbstractEventLoop
    log: logging.Logger = logging.getLogger("mau.rpc")

    _reader: asyncio.StreamReader | None
    _writer: asyncio.StreamWriter | None
    # Last request ID handed out; incremented before every request.
    _req_id: int
    # Lowest (i.e. most recent) broadcast ID accepted so far; used to drop repeats.
    _min_broadcast_id: int
    # Futures of in-flight requests keyed by request ID; removed once answered.
    _response_waiters: dict[int, asyncio.Future[JSON]]
    _event_handlers: dict[str, list[EventHandler]]
    # Broadcasts flagged "is_sequential" wait here and are handled one at a time.
    _command_queue: asyncio.Queue
    # Strong references to background tasks so they are not garbage-collected while
    # still running (see the asyncio.create_task documentation).
    _read_task: asyncio.Task | None
    _command_task: asyncio.Task | None
    _background_tasks: set[asyncio.Task]

    def __init__(self, config: Config) -> None:
        self.config = config
        # Must be constructed from inside a running event loop.
        self.loop = asyncio.get_running_loop()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._command_queue = asyncio.Queue()
        self._read_task = None
        self._command_task = None
        self._background_tasks = set()

    async def connect(self) -> None:
        """Connect to the RPC server and register with it.

        Does nothing if a connection is already open. Otherwise retries every 10 seconds
        until the configured unix socket or TCP endpoint accepts the connection, starts
        the read loop (and the sequential command loop, if not already running), and
        sends a ``register`` request carrying ``appservice.address`` as ``peer_id``.

        Raises:
            RuntimeError: if ``rpc.connection.type`` is neither ``"unix"`` nor ``"tcp"``.
        """
        if self._writer is not None:
            return

        conn_type = self.config["rpc.connection.type"]
        if conn_type == "unix":
            path = self.config["rpc.connection.path"]
            while True:
                try:
                    r, w = await asyncio.open_unix_connection(path)
                    break
                except OSError:
                    # Only connection failures are retried; cancellation and
                    # programming errors propagate instead of looping forever.
                    self.log.warning(
                        f"No unix socket available at {path}, wait for it to exist..."
                    )
                    await asyncio.sleep(10)
        elif conn_type == "tcp":
            host = self.config["rpc.connection.host"]
            port = self.config["rpc.connection.port"]
            while True:
                try:
                    r, w = await asyncio.open_connection(host, port)
                    break
                except OSError:
                    self.log.warning(
                        f"No TCP connection open at {host}:{port}, "
                        "wait for it to become available..."
                    )
                    await asyncio.sleep(10)
        else:
            raise RuntimeError("invalid rpc connection type")
        self._reader = r
        self._writer = w
        self._read_task = self.loop.create_task(self._try_read_loop())
        # Exactly one command loop must drain the queue across reconnects; two loops
        # would run "sequential" broadcasts concurrently.
        if self._command_task is None or self._command_task.done():
            self._command_task = self.loop.create_task(self._command_loop())
        await self.request("register", peer_id=self.config["appservice.address"])

    async def disconnect(self) -> None:
        """Send EOF, flush, and close the connection, if one is open."""
        writer = self._writer
        if writer is None:
            return
        # Clear the state first so the client never looks connected to a closing socket.
        self._writer = None
        self._reader = None
        try:
            writer.write_eof()
            await writer.drain()
        finally:
            writer.close()
        try:
            await writer.wait_closed()
        except OSError as e:
            self.log.debug(f"Error while closing RPC connection: {e}")

    @property
    def _next_req_id(self) -> int:
        """Allocate and return the next request ID (side effect: increments the counter)."""
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Register ``handler`` to be awaited for broadcasts whose command is ``method``."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister ``handler`` from ``method``; a no-op if it was not registered."""
        # Use .get() so that an unknown method does not gain an empty handler list,
        # which would silence the "No handlers" warning for it.
        handlers = self._event_handlers.get(method)
        if handlers is None:
            return
        try:
            handlers.remove(handler)
        except ValueError:
            pass

    def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
        """Replace all handlers registered for ``method`` with ``handlers``."""
        self._event_handlers[method] = handlers

    async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
        """Await every handler registered for ``command`` with ``req``.

        Handler exceptions are logged and do not stop the remaining handlers.
        ``req_id`` is accepted for signature compatibility; duplicate filtering
        happens on receipt in :meth:`_dispatch_broadcast`.
        """
        handlers = self._event_handlers.get(command)
        if handlers is None:
            self.log.warning("No handlers for %s", command)
            return
        # Iterate over a snapshot so a handler that (un)registers handlers does not
        # cause others to be skipped.
        for handler in list(handlers):
            try:
                await handler(req)
            except Exception:
                self.log.exception("Exception in event handler")

    def _dispatch_broadcast(
        self, req_id: int, command: str, req: dict[str, Any], is_sequential: bool
    ) -> None:
        """Drop repeated broadcasts, then queue or immediately schedule their handlers.

        The duplicate check runs in arrival order. If it ran when the handler started,
        a sequential broadcast still waiting in the queue would be rejected once a
        later non-sequential one had already been handled.
        """
        # Broadcast IDs are negative and accepted only when not above the lowest ID
        # seen so far.
        # NOTE(review): an exact repeat of the most recent ID is still accepted
        # (strict ">"), as in the original; confirm whether ">=" is intended.
        if req_id > self._min_broadcast_id:
            self.log.debug(f"Ignoring duplicate broadcast {req_id}")
            return
        self._min_broadcast_id = req_id
        if is_sequential:
            self._command_queue.put_nowait((req_id, command, req))
            return
        task = self.loop.create_task(self._run_event_handler(req_id, command, req))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _handle_incoming_line(self, line: str) -> None:
        """Parse one line from the server and route it as a broadcast or a response."""
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        if not isinstance(req, dict):
            self.log.debug(f"Got invalid request from server: {line}")
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
            is_sequential = req.pop("is_sequential", False)
        except KeyError:
            self.log.debug(f"Got invalid request from server: {line}")
            return
        if req_id < 0:
            self._dispatch_broadcast(req_id, command, req, is_sequential)
            return
        waiter = self._response_waiters.get(req_id)
        if waiter is None:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if command not in ("response", "error"):
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")
            return
        # The request is answered: drop its waiter so the table does not grow forever.
        del self._response_waiters[req_id]
        if waiter.done():
            # The caller stopped waiting (e.g. it was cancelled); nothing to deliver.
            return
        if command == "response":
            waiter.set_result(req.get("response"))
        else:
            waiter.set_exception(RPCError(req.get("error", line)))

    async def _command_loop(self) -> None:
        """Handle queued sequential broadcasts one at a time, forever."""
        while True:
            req_id, command, req = await self._command_queue.get()
            try:
                await self._run_event_handler(req_id, command, req)
            finally:
                self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run :meth:`_read_loop`, logging any exception that escapes it."""
        try:
            await self._read_loop()
        except Exception:
            self.log.exception("Fatal error in read loop")

    async def _read_line(self, reader: asyncio.StreamReader) -> bytes:
        """Read one newline-terminated line, tolerating lines longer than the stream limit.

        At EOF, returns whatever partial data was received (possibly ``b""``).
        """
        chunks: list[bytes] = []
        while True:
            try:
                chunks.append(await reader.readuntil())
                break
            except asyncio.IncompleteReadError as e:
                chunks.append(e.partial)
                break
            except asyncio.LimitOverrunError as e:
                self.log.warning(f"Buffer overrun: {e}")
                # The data stays buffered; consume it in limit-sized pieces and retry.
                # NOTE: _limit is a private StreamReader attribute with no public accessor.
                chunks.append(await reader.read(reader._limit))
        return b"".join(chunks)

    async def _read_loop(self) -> None:
        """Read and handle lines until EOF or until the connection is replaced or closed."""
        reader = self._reader
        writer = self._writer
        try:
            while reader is not None and self._reader is reader and not reader.at_eof():
                line = await self._read_line(reader)
                if not line:
                    continue
                try:
                    line_str = line.decode("utf-8")
                except UnicodeDecodeError:
                    self.log.exception("Got non-unicode request from server: %s", line)
                    continue
                try:
                    await self._handle_incoming_line(line_str)
                except Exception:
                    self.log.exception("Failed to handle incoming request %s", line_str)
        finally:
            # Also runs when reading raises, so a dead connection is never left installed.
            self.log.debug("Reader disconnected")
            # Only clear state that still belongs to this connection; a reconnect may
            # already have installed a new reader/writer pair.
            if self._reader is reader:
                self._reader = None
                self._writer = None
            if writer is not None:
                # Make sure our side of the socket is released too (close is idempotent).
                writer.close()
            # NOTE(review): futures of requests still in flight are not failed here, so
            # their callers keep waiting; consider failing them on disconnect.

    async def _raw_request(self, command: str, is_secret: bool = False, **data: JSON) -> asyncio.Future[JSON]:
        """Send a request and return the future that will hold its reply.

        ``data`` becomes extra top-level fields of the request object. With
        ``is_secret`` set, the data is redacted from the debug log.

        Raises:
            RuntimeError: if the client is not connected.
        """
        writer = self._writer
        # Checked before touching any state so a failed call leaks no waiter.
        if writer is None:
            raise RuntimeError("RPC client is not connected")
        req_id = self._next_req_id
        future = self._response_waiters[req_id] = self.loop.create_future()
        req = {"id": req_id, "command": command, **data}
        self.log.debug("Request %d: %s %s", req_id, command, data if not is_secret else "<REDACTED>")
        try:
            writer.write(json.dumps(req).encode("utf-8") + b"\n")
            await writer.drain()
        except Exception:
            # The request did not make it out; nobody will ever answer it.
            self._response_waiters.pop(req_id, None)
            raise
        return future

    async def request(self, command: str, **data: JSON) -> JSON:
        """Send ``command`` with ``data`` as extra fields and await the server's reply.

        Pass ``is_secret=True`` to redact ``data`` from debug logs.

        Raises:
            RPCError: if the server answers with an ``"error"`` command.
            RuntimeError: if the client is not connected.
        """
        future = await self._raw_request(command, **data)
        return await future