# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import Any, Callable, Awaitable

import asyncio
import json
import logging

from mautrix.types.primitive import JSON

from ..config import Config
from .types import RPCError

# Signature of a broadcast event handler registered with RPCClient: an async
# callable that receives the decoded message object (with its "id", "command"
# and "is_sequential" keys already removed) and returns nothing.
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]


class CancelableEvent:
    """An :class:`asyncio.Event` whose waiters can all be aborted at once.

    All concurrent :meth:`wait` calls share one helper task that waits on the
    underlying event. :meth:`cancel` cancels that task, which makes every
    current waiter raise :class:`asyncio.CancelledError`, and marks the event
    as cancelled so that any later :meth:`wait` raises immediately as well.
    """

    _event: asyncio.Event
    # Shared task waiting on _event; None when no waiter is pending.
    _task: asyncio.Task | None
    _cancelled: bool
    _loop: asyncio.AbstractEventLoop

    def __init__(self, loop: asyncio.AbstractEventLoop | None = None):
        """Create an unset event bound to *loop* (default: the running loop)."""
        self._event = asyncio.Event()
        self._task = None
        self._cancelled = False
        self._loop = loop or asyncio.get_running_loop()

    def is_set(self) -> bool:
        """Return whether the event is currently set."""
        return self._event.is_set()

    def set(self) -> None:
        """Set the event, releasing all waiters."""
        self._event.set()
        # The shared waiter task completes by itself now that the event is set.
        self._task = None

    def clear(self) -> None:
        """Reset the event so that subsequent waits block again."""
        self._event.clear()

    async def wait(self) -> None:
        """Block until the event is set.

        Raises:
            asyncio.CancelledError: if :meth:`cancel` was called before or
                during the wait, or if the calling task itself is cancelled.
        """
        if self._cancelled:
            raise asyncio.CancelledError()
        if self._event.is_set():
            return
        # Recreate the helper if a previous one was cancelled, so a stale
        # cancelled task cannot fail every later wait.
        if self._task is None or self._task.done():
            self._task = self._loop.create_task(self._event.wait())
        # Shield the shared task: cancelling one waiter must not cancel it,
        # which would spuriously abort every other waiter. cancel() still
        # reaches all waiters because a cancelled inner task cancels the shield.
        await asyncio.shield(self._task)

    def cancel(self) -> None:
        """Abort all current waiters and make every future wait() raise."""
        self._cancelled = True
        if self._task is not None:
            self._task.cancel()

    def cancelled(self) -> bool:
        """Return whether :meth:`cancel` has been called."""
        return self._cancelled


class RPCClient:
    """Line-delimited JSON RPC client for the bridge's companion RPC server.

    Every message is one JSON object per line. Requests sent by this client
    carry increasing positive ``id`` values and are answered by the server
    with a ``response`` or ``error`` command echoing that id. Messages with a
    negative ``id`` are broadcasts and are routed to the handlers registered
    with :meth:`add_event_handler`.
    """

    config: Config
    loop: asyncio.AbstractEventLoop
    log: logging.Logger = logging.getLogger("mau.rpc")

    register_config_key: str
    _reader: asyncio.StreamReader | None
    _writer: asyncio.StreamWriter | None
    _req_id: int
    # Smallest (i.e. most recent) broadcast id accepted on this connection.
    _min_broadcast_id: int
    _response_waiters: dict[int, asyncio.Future[JSON]]
    _event_handlers: dict[str, list[EventHandler]]
    _command_queue: asyncio.Queue
    # Strong references to long-lived / fire-and-forget tasks; the event loop
    # only keeps weak references, so unreferenced tasks may be collected.
    _command_task: asyncio.Task
    _background_tasks: set[asyncio.Task]
    _read_task: asyncio.Task | None
    _connection_task: asyncio.Task | None
    _is_connected: CancelableEvent
    _is_disconnected: CancelableEvent
    _connection_lock: asyncio.Lock

    def __init__(self, config: Config, register_config_key: str) -> None:
        """Create a disconnected client; must be called inside a running loop.

        Args:
            config: bridge config providing the ``rpc.connection.*`` keys and
                ``appservice.address``.
            register_config_key: config key whose value is sent to the server
                in the ``register`` request.
        """
        self.config = config
        self.register_config_key = register_config_key
        self.loop = asyncio.get_running_loop()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._command_queue = asyncio.Queue()
        self._background_tasks = set()
        self._command_task = self.loop.create_task(self._command_loop())
        self._read_task = None
        self._connection_task = None
        self._is_connected = CancelableEvent(self.loop)
        self._is_disconnected = CancelableEvent(self.loop)
        self._is_disconnected.set()
        self._connection_lock = asyncio.Lock()

    async def connect(self) -> None:
        """Connect and register, unless already connected.

        Raises:
            asyncio.CancelledError: if the client has been cancelled.
        """
        async with self._connection_lock:
            if self._is_connected.cancelled():
                raise asyncio.CancelledError()
            if self._is_connected.is_set():
                return
            self._connection_task = self.loop.create_task(self._connect())
            try:
                await self._connection_task
            finally:
                self._connection_task = None

    async def _connect(self) -> None:
        """Open the socket (retrying every 10 s until it succeeds) and register.

        Raises:
            RuntimeError: if ``rpc.connection.type`` is neither ``unix`` nor ``tcp``.
        """
        conn_type = self.config["rpc.connection.type"]
        # Retry on any ordinary error. CancelledError derives from
        # BaseException, so cancellation still propagates out of these loops.
        if conn_type == "unix":
            while True:
                try:
                    r, w = await asyncio.open_unix_connection(self.config["rpc.connection.path"])
                    break
                except Exception:
                    self.log.warning(f'No unix socket available at {self.config["rpc.connection.path"]}, wait for it to exist...')
                    await asyncio.sleep(10)
        elif conn_type == "tcp":
            while True:
                try:
                    r, w = await asyncio.open_connection(self.config["rpc.connection.host"],
                                                         self.config["rpc.connection.port"])
                    break
                except Exception:
                    self.log.warning(f'No TCP connection open at {self.config["rpc.connection.host"]}:{self.config["rpc.connection.port"]}, wait for it to become available...')
                    await asyncio.sleep(10)
        else:
            raise RuntimeError("invalid rpc connection type")
        self._reader = r
        self._writer = w
        self._read_task = self.loop.create_task(self._try_read_loop())
        # The register response is deliberately not awaited; attach a callback
        # so a failure is logged instead of being silently dropped.
        register = await self._raw_request("register",
                                           peer_id=self.config["appservice.address"],
                                           register_config=self.config[self.register_config_key])
        register.add_done_callback(self._on_register_done)
        self._is_connected.set()
        self._is_disconnected.clear()

    def _on_register_done(self, future: asyncio.Future[JSON]) -> None:
        """Log a failed ``register`` request (its result is never awaited)."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            self.log.error("RPC register request failed: %s", error)

    async def disconnect(self) -> None:
        """Disconnect gracefully, unless already disconnected.

        Raises:
            asyncio.CancelledError: if the client has been cancelled.
        """
        async with self._connection_lock:
            if self._is_disconnected.cancelled():
                raise asyncio.CancelledError()
            if self._is_disconnected.is_set():
                return
            await self._disconnect()

    async def _disconnect(self) -> None:
        """Send EOF, stop the reader, and reset connection state."""
        try:
            if self._writer is not None:
                self._writer.write_eof()
                await self._writer.drain()
        finally:
            # Tear down even if flushing failed (e.g. the peer already reset).
            if self._read_task is not None:
                self._read_task.cancel()
                self._read_task = None
            self._on_disconnect()

    def _on_disconnect(self) -> None:
        """Reset per-connection state, close the socket, and fail pending requests."""
        self._req_id = 0
        self._min_broadcast_id = 0
        self._reader = None
        if self._writer is not None:
            self._writer.close()
            self._writer = None
        # Responses can no longer arrive for these, and request ids restart at
        # 1 on the next connection, so fail them rather than let them hang.
        waiters, self._response_waiters = self._response_waiters, {}
        for waiter in waiters.values():
            if not waiter.done():
                waiter.set_exception(ConnectionError("RPC connection closed"))
        self._is_connected.clear()
        self._is_disconnected.set()

    def wait_for_connection(self) -> Awaitable[None]:
        """Return an awaitable that completes once the client is connected."""
        return self._is_connected.wait()

    def wait_for_disconnection(self) -> Awaitable[None]:
        """Return an awaitable that completes once the client is disconnected."""
        return self._is_disconnected.wait()

    def cancel(self) -> None:
        """Permanently stop the client, tearing down the connection synchronously.

        Pending and future :meth:`wait_for_connection` /
        :meth:`wait_for_disconnection` calls raise
        :class:`asyncio.CancelledError`, an in-progress connect is cancelled,
        and the socket is closed without awaiting, so this is safe to call
        whether or not the event loop is running.
        """
        self._is_connected.cancel()
        self._is_disconnected.cancel()
        if self._connection_task is not None:
            self._connection_task.cancel()
        # Cannot await drain() here; the transport flushes buffered data
        # asynchronously before shutting down / closing.
        if self._writer is not None:
            self._writer.write_eof()
        if self._read_task is not None:
            self._read_task.cancel()
            self._read_task = None
        self._on_disconnect()

    @property
    def _next_req_id(self) -> int:
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Register *handler* to be called for broadcasts of *method*."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister *handler* for *method*; no-op if it is not registered."""
        try:
            self._event_handlers.setdefault(method, []).remove(handler)
        except ValueError:
            pass

    def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
        """Replace all handlers for *method* with *handlers*."""
        self._event_handlers[method] = handlers

    async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
        """Run every handler registered for *command* with *req*, in order.

        *req_id* is unused here: duplicate broadcasts are filtered in
        :meth:`_handle_incoming_line` before dispatch. Handler exceptions are
        logged and do not stop the remaining handlers.
        """
        try:
            handlers = self._event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for %s", command)
            return
        # Iterate a snapshot: handlers may be added/removed while we await.
        for handler in list(handlers):
            try:
                await handler(req)
            except Exception:
                self.log.exception("Exception in event handler")

    async def _handle_incoming_line(self, line: str) -> None:
        """Route one decoded line to a broadcast handler or a response waiter."""
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
            is_sequential = req.pop("is_sequential", False)
        except KeyError:
            self.log.debug(f"Got invalid request from server: {line}")
            return
        if req_id < 0:
            # Broadcast ids are negative and expected to decrease, so anything
            # not below the last accepted id was already seen. Checking here,
            # in arrival order, keeps sequential events waiting in the queue
            # from being mistaken for duplicates of later concurrent ones.
            if req_id >= self._min_broadcast_id:
                self.log.debug(f"Ignoring duplicate broadcast {req_id}")
                return
            self._min_broadcast_id = req_id
            if is_sequential:
                self._command_queue.put_nowait((req_id, command, req))
            else:
                task = self.loop.create_task(self._run_event_handler(req_id, command, req))
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
            return
        waiter = self._response_waiters.get(req_id)
        if waiter is None:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if command == "response":
            self.log.debug("Received response %d", req_id)
            del self._response_waiters[req_id]
            # The requester may have been cancelled and cancelled the future.
            if not waiter.done():
                waiter.set_result(req.get("response"))
        elif command == "error":
            del self._response_waiters[req_id]
            if not waiter.done():
                waiter.set_exception(RPCError(req.get("error", line)))
        else:
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")

    async def _command_loop(self) -> None:
        """Run queued sequential broadcasts one at a time, in arrival order."""
        while True:
            req_id, command, req = await self._command_queue.get()
            try:
                await self._run_event_handler(req_id, command, req)
            finally:
                self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run the read loop and mark the client disconnected when it ends."""
        try:
            await self._read_loop()
        except asyncio.CancelledError:
            # Cancelled by _disconnect()/cancel(), which reset state themselves.
            return
        except Exception:
            self.log.exception("Fatal error in read loop")
        self.log.debug("Reader disconnected")
        self._on_disconnect()

    async def _read_loop(self) -> None:
        """Read newline-terminated messages until EOF and handle each one."""
        while self._reader is not None and not self._reader.at_eof():
            line = b''
            while True:
                try:
                    line += await self._reader.readuntil()
                    break
                except asyncio.IncompleteReadError as e:
                    # EOF before a newline: handle whatever was received.
                    line += e.partial
                    break
                except asyncio.LimitOverrunError as e:
                    # Line exceeds the buffer limit: consume it in pieces.
                    self.log.warning(f"Buffer overrun: {e}")
                    line += await self._reader.read(e.consumed)
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)

    async def _raw_request(self, command: str, **data: JSON) -> asyncio.Future[JSON]:
        """Send a request and return the future that its response will resolve.

        Raises:
            ConnectionError: if there is no open connection.
        """
        if self._writer is None:
            raise ConnectionError("Not connected to the RPC server")
        req_id = self._next_req_id
        future = self._response_waiters[req_id] = self.loop.create_future()
        req = {"id": req_id, "command": command, **data}
        self.log.debug("Request %d: %s", req_id, command)
        try:
            self._writer.write(json.dumps(req).encode("utf-8") + b"\n")
            await self._writer.drain()
        except BaseException:
            # The caller never receives the future, so don't keep it around.
            self._response_waiters.pop(req_id, None)
            raise
        return future

    async def request(self, command: str, **data: JSON) -> JSON:
        """Wait for a connection, send *command* and return the server's response.

        Raises:
            RPCError: if the server answers with an error.
            ConnectionError: if the connection is lost before the response.
        """
        await self.wait_for_connection()
        future = await self._raw_request(command, **data)
        return await future