# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
Client functionality for the KakaoTalk API.
Currently a wrapper around a Node backend, but
the abstraction used here should be compatible
with any other potential backend.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, AsyncIterator, cast, Awaitable, Type, Optional, Union
import asyncio
from contextlib import asynccontextmanager
import logging
import urllib.request

from aiohttp import ClientResponse, ClientSession
from aiohttp.client import _RequestContextManager
from yarl import URL

from mautrix.types import SerializerError
from mautrix.util.logging import TraceLogger

from ...config import Config
from ...rpc import EventHandler, RPCClient

from ..types.api.struct import (
    FriendListStruct,
    FriendReqStruct,
    FriendStruct,
    ProfileReqStruct,
    ProfileStruct,
)
from ..types.bson import Long
from ..types.channel.channel_info import ChannelInfo
from ..types.chat import Chatlog, KnownChatType
from ..types.chat.attachment import MentionStruct, ReplyAttachment
from ..types.client.client_session import LoginResult
from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.openlink.open_link_type import OpenChannelUserPerm
from ..types.openlink.open_link_user_info import OpenLinkChannelUserInfo
from ..types.packet.chat.kickout import KnownKickoutType, KickoutRes
from ..types.request import (
    deserialize_result,
    ResultType,
    ResultListType,
    ResultRawType,
    RootCommandResult,
    CommandResultDoneValue
)

from .types import (
    ChannelProps,
    PortalChannelInfo,
    PortalChannelParticipantInfo,
    Receipt,
    SettingsStruct,
    UserInfoUnion,
)

from .errors import InvalidAccessToken, CommandException
from .error_helper import raise_unsuccessful_response

try:
    from aiohttp_socks import ProxyConnector
except ImportError:
    ProxyConnector = None

if TYPE_CHECKING:
    from mautrix.types import JSON
    from ... import user as u


@asynccontextmanager
async def sandboxed_get(url: URL) -> AsyncIterator[ClientResponse]:
    """
    GET a URL using a throwaway ClientSession, separate from any per-user session.

    Because a fresh ClientSession is created with no connector argument, the request does not
    go through the proxy connector that a Client may have configured for its own session.
    Both the session and the response are closed when the context exits.
    """
    async with ClientSession() as sess, sess.get(url) as resp:
        yield resp


# TODO Consider defining an interface for this, with node/native backend as swappable implementations
# TODO If no state is stored, consider using free functions instead of classmethods
class Client:
    """
    Client for the KakaoTalk API, implemented as RPC calls to a Node backend.

    Classmethods operate on the single RPC connection shared by all users (set up by init_cls);
    instances issue commands on behalf of one bridge user, identified by that user's mxid.
    """

    _rpc_client: RPCClient
    # Background task started by init_cls that keeps the RPC connection alive.
    # A strong reference is kept because the event loop only holds weak references to tasks.
    _keep_connected_task: asyncio.Task | None = None

    @classmethod
    def init_cls(cls, config: Config) -> None:
        """Initialize RPC to the Node backend."""
        cls._rpc_client = RPCClient(config, "kakaotalk")
        # NOTE Cancelling the RPCClient is expected to end this task too
        cls._keep_connected_task = asyncio.create_task(cls._keep_connected())

    @classmethod
    async def _keep_connected(cls) -> None:
        """Connect to the backend, and connect again every time the connection is lost."""
        while True:
            await cls._rpc_client.connect()
            await cls._rpc_client.wait_for_disconnection()

    @classmethod
    def stop_cls(cls) -> None:
        """Stop and disconnect from the Node backend."""
        cls._rpc_client.cancel()


    # region tokenless commands

    @classmethod
    async def generate_uuid(cls, used_uuids: set[str]) -> str:
        """
        Randomly generate a UUID for a (fake) device.

        The backend is asked for a UUID repeatedly until it returns one not in used_uuids.
        Raises an Exception if 10 attempts in a row all produce an already-used UUID.
        """
        tries_remaining = 10
        while True:
            uuid = await cls._rpc_client.request("generate_uuid")
            if uuid not in used_uuids:
                return uuid
            tries_remaining -= 1
            if tries_remaining == 0:
                raise Exception(
                    "Unable to generate a UUID that hasn't been used before. "
                    "Either use a different RNG, or buy a lottery ticket"
                )

    @classmethod
    def register_device(cls, passcode: str, **req: JSON) -> Awaitable[None]:
        """Register a (fake) device that will be associated with the provided login credentials."""
        return cls._api_request_void("register_device", passcode=passcode, **req)

    @classmethod
    def login(cls, uuid: str, form: JSON, forced: bool) -> Awaitable[OAuthCredential]:
        """
        Obtain a session token by logging in with user-provided credentials.
        Must have first called register_device with these credentials.
        """
        # NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential
        return cls._api_request_result(
            OAuthCredential,
            "login",
            uuid=uuid,
            form=form,
            forced=forced,
        )

    # endregion


    # The Matrix user on whose behalf this client acts
    user: u.User
    # Task that waits for the RPC connection to drop; non-None while the client is started
    _rpc_disconnection_task: asyncio.Task | None
    # Session for direct (non-RPC) HTTP requests, proxied if a proxy is configured and supported
    http: ClientSession
    log: TraceLogger
    # Event names (without the ":<mxid>" suffix) that currently have a handler registered
    _handler_methods: list[str]

    def __init__(self, user: u.User, log: Optional[TraceLogger] = None):
        """Create a per-user client object for user-specific client functionality."""
        self.user = user
        self._rpc_disconnection_task = None
        self._handler_methods = []
        # Set up logging first, since the proxy setup below may need to emit a warning.
        self.log = log or logging.getLogger("mw.ktclient")

        # TODO Let the Node backend use a proxy too!
        connector = None
        try:
            http_proxy = urllib.request.getproxies()["http"]
        except KeyError:
            pass
        else:
            if ProxyConnector:
                connector = ProxyConnector.from_url(http_proxy)
            else:
                self.log.warning("http_proxy is set, but aiohttp-socks is not installed")
        self.http = ClientSession(connector=connector)

    @property
    def _oauth_credential(self) -> JSON:
        """The user's current OAuth credential, serialized for sending over RPC."""
        return self.user.oauth_credential.serialize()

    # region HTTP

    def get(
        self,
        url: Union[str, URL],
        headers: Optional[dict[str, str]] = None,
        sandbox: bool = False,
        **kwargs,
    ) -> _RequestContextManager:
        """
        Perform an HTTP GET; the result is meant to be used as an async context manager.

        With sandbox=False, the request goes through this client's own session (self.http).
        With sandbox=True, sandboxed_get is used instead, which opens a separate session;
        note that headers and kwargs are NOT passed along in that case.
        """
        # TODO Is auth ever needed?
        headers = {
            # TODO Are any default headers needed?
            #**self._headers,
            **(headers or {}),
        }
        url = URL(url)
        if sandbox:
            return sandboxed_get(url)
        return self.http.get(url, headers=headers, **kwargs)

    # endregion


    # region post-token commands

    async def start(self) -> SettingsStruct | None:
        """
        Initialize user-specific bridging & state by providing a token obtained from a prior login.
        Receive the user's settings in response, or None if they could not be deserialized.
        Also starts watching for RPC disconnection, unless already doing so.
        """
        try:
            settings_struct = await self._api_user_cred_request_result(SettingsStruct, "start")
        except SerializerError:
            self.log.exception("Unable to deserialize settings struct, but starting client anyways")
            settings_struct = None
        if not self._rpc_disconnection_task:
            self._rpc_disconnection_task = asyncio.create_task(self._rpc_disconnection_handler())
        else:
            self.log.warning("Called \"start\" on an already-started client")
        return settings_struct

    async def stop(self) -> None:
        """
        Immediately stop bridging this user.

        Unregisters this user's event handlers, cancels the disconnection watcher, and sends
        the "stop" command to the backend. The client may be started again afterwards.
        """
        self._stop_listen()
        if self._rpc_disconnection_task:
            self._rpc_disconnection_task.cancel()
            # Forget the cancelled task, so that a later start() installs a fresh watcher
            # and a repeated stop() is reported as such.
            self._rpc_disconnection_task = None
        else:
            self.log.warning("Called \"stop\" on an already-stopped client")
        await self._rpc_client.request("stop", mxid=self.user.mxid)

    async def _rpc_disconnection_handler(self) -> None:
        """Wait for the RPC connection to drop, then unregister handlers and notify the user."""
        await self._rpc_client.wait_for_disconnection()
        self._rpc_disconnection_task = None
        self._stop_listen()
        asyncio.create_task(self.user.on_client_disconnect())

    async def renew_and_save(self) -> None:
        """Renew and save the user's session tokens."""
        # renew=False: an invalid token here must raise, rather than trigger another renewal
        oauth_info = await self._api_user_cred_request_result(OAuthInfo, "renew", renew=False)
        self.user.oauth_credential = oauth_info.credential
        await self.user.save()

    async def connect(self) -> LoginResult | None:
        """
        Start a new talk session by providing a token obtained from a prior login.
        Receive a snapshot of account state in response, or None if it could not be deserialized.
        Also registers this user's event handlers.
        """
        try:
            login_result = await self._api_user_cred_request_result(LoginResult, "connect")
        except SerializerError:
            self.log.exception("Unable to deserialize login result, but connecting anyways")
            login_result = None
        else:
            assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
        # TODO Skip if handlers are already listening. But this is idempotent and thus probably safe
        self._start_listen()
        return login_result

    async def disconnect(self) -> None:
        """Disconnect from the talk session, but remain logged in."""
        await self._rpc_client.request("disconnect", mxid=self.user.mxid)
        await self._on_disconnect(None)

    def is_connected(self) -> Awaitable[bool]:
        """Ask the backend whether this user currently has a talk session."""
        return self._rpc_client.request("is_connected", mxid=self.user.mxid)

    def get_settings(self) -> Awaitable[SettingsStruct]:
        """Fetch the user's account settings."""
        return self._api_user_request_result(SettingsStruct, "get_settings")

    async def get_own_profile(self) -> ProfileStruct:
        """Fetch the profile of the user this client acts for."""
        profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
        return profile_req_struct.profile

    async def get_profile(self, user_id: Long) -> ProfileStruct:
        """Fetch the profile of the KakaoTalk user with the given ID."""
        profile_req_struct = await self._api_user_request_result(
            ProfileReqStruct,
            "get_profile",
            user_id=user_id.serialize(),
        )
        return profile_req_struct.profile

    def get_portal_channel_info(self, channel_props: ChannelProps) -> Awaitable[PortalChannelInfo]:
        """Fetch the info needed to bridge the given channel into a portal."""
        return self._api_user_request_result(
            PortalChannelInfo,
            "get_portal_channel_info",
            channel_props=channel_props.serialize(),
        )

    def get_portal_channel_participant_info(self, channel_props: ChannelProps) -> Awaitable[PortalChannelParticipantInfo]:
        """Fetch participant info for the given channel."""
        return self._api_user_request_result(
            PortalChannelParticipantInfo,
            "get_portal_channel_participant_info",
            channel_props=channel_props.serialize(),
        )

    def get_participants(self, channel_props: ChannelProps) -> Awaitable[list[UserInfoUnion]]:
        """Fetch the list of users in the given channel."""
        return self._api_user_request_result(
            ResultListType(UserInfoUnion),
            "get_participants",
            channel_props=channel_props.serialize(),
        )

    def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> Awaitable[list[Chatlog]]:
        """
        Fetch chat logs of the given channel.

        sync_from, if given, is sent as the chat ID to sync from; limit is passed through as-is.
        """
        return self._api_user_request_result(
            ResultListType(Chatlog),
            "get_chats",
            channel_props=channel_props.serialize(),
            # Compare against None explicitly, so that a falsy ID (e.g. 0) is still sent
            sync_from=sync_from.serialize() if sync_from is not None else None,
            limit=limit,
        )

    def get_read_receipts(self, channel_props: ChannelProps, unread_chat_ids: list[Long]) -> Awaitable[list[Receipt]]:
        """Fetch read receipts in the given channel for the given chat IDs."""
        return self._api_user_request_result(
            ResultListType(Receipt),
            "get_read_receipts",
            channel_props=channel_props.serialize(),
            unread_chat_ids=[c.serialize() for c in unread_chat_ids],
        )

    async def can_change_uuid(self, uuid: str) -> bool:
        """Return True if the backend accepts uuid as a new UUID, False if it raises a CommandException."""
        try:
            await self._api_user_request_void("can_change_uuid", uuid=uuid)
        except CommandException:
            return False
        else:
            return True

    def change_uuid(self, uuid: str) -> Awaitable[None]:
        """Send the "change_uuid" command with the given UUID."""
        return self._api_user_request_void("change_uuid", uuid=uuid)

    def set_uuid_searchable(self, searchable: bool) -> Awaitable[None]:
        """Send the "set_uuid_searchable" command with the given flag."""
        return self._api_user_request_void("set_uuid_searchable", searchable=searchable)

    def list_friends(self) -> Awaitable[FriendListStruct]:
        """Fetch the user's friend list."""
        return self._api_user_request_result(
            FriendListStruct,
            "list_friends",
        )

    async def edit_friend(self, ktid: Long, add: bool) -> FriendStruct | None:
        """
        Edit the friend status of the user with the given ID
        (presumably adding when add is True and removing otherwise).
        Return the resulting friend info, or None if the response could not be deserialized.
        """
        try:
            friend_req_struct = await self._api_user_request_result(
                FriendReqStruct,
                "edit_friend",
                user_id=ktid.serialize(),
                add=add,
            )
            return friend_req_struct.friend
        except SerializerError:
            self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless")
            return None

    async def edit_friend_by_uuid(self, uuid: str, add: bool) -> FriendStruct | None:
        """
        Like edit_friend, but identifies the other user by UUID instead of by ID.
        Return the resulting friend info, or None if the response could not be deserialized.
        """
        try:
            friend_req_struct = await self._api_user_request_result(
                FriendReqStruct,
                "edit_friend_by_uuid",
                uuid=uuid,
                add=add,
            )
            return friend_req_struct.friend
        except SerializerError:
            self.log.exception("Unable to deserialize friend struct, but friend should have been edited nonetheless")
            return None

    async def get_friend_dm_id(self, friend_id: Long) -> Long | None:
        """Return the channel ID of the DM with the given friend, or None (logged) on a CommandException."""
        try:
            return await self._api_user_request_result(
                Long,
                "get_friend_dm_id",
                friend_id=friend_id.serialize(),
            )
        except CommandException:
            self.log.exception(f"Could not find friend with ID {friend_id}")
            return None

    async def get_memo_ids(self) -> list[Long]:
        """
        Fetch memo channel IDs.

        Unlike most commands, the raw RPC response is deserialized directly as a list,
        without the success check that the _api_* helpers perform.
        """
        return ResultListType(Long).deserialize(
            await self._rpc_client.request("get_memo_ids", mxid=self.user.mxid)
        )

    def download_file(
        self,
        channel_props: ChannelProps,
        key: str,
    ) -> Awaitable[bytes]:
        """Download the file with the given key from the given channel, as raw bytes."""
        return self._api_user_request_result(
            ResultRawType(bytes),
            "download_file",
            channel_props=channel_props.serialize(),
            key=key,
        )

    def send_chat(
        self,
        channel_props: ChannelProps,
        text: str,
        reply_to: ReplyAttachment | None,
        mentions: list[MentionStruct] | None,
    ) -> Awaitable[Long]:
        """Send a text message, optionally as a reply and/or with mentions; return its chat ID."""
        return self._api_user_request_result(
            Long,
            "send_chat",
            channel_props=channel_props.serialize(),
            text=text,
            reply_to=reply_to.serialize() if reply_to is not None else None,
            mentions=[m.serialize() for m in mentions] if mentions is not None else None,
        )

    def send_media(
        self,
        channel_props: ChannelProps,
        media_type: KnownChatType,
        data: bytes,
        filename: str,
        *,
        width: int | None = None,
        height: int | None = None,
        ext: str | None = None,
    ) -> Awaitable[Long]:
        """Send a media message of the given type; return its chat ID."""
        return self._api_user_request_result(
            Long,
            "send_media",
            channel_props=channel_props.serialize(),
            type=media_type,
            # bytes are sent as a list of ints, presumably because the RPC payload is JSON
            data=list(data),
            name=filename,
            width=width,
            height=height,
            ext=ext,
        )

    def delete_chat(
        self,
        channel_props: ChannelProps,
        chat_id: Long,
    ) -> Awaitable[None]:
        """Delete the chat with the given ID from the given channel."""
        return self._api_user_request_void(
            "delete_chat",
            channel_props=channel_props.serialize(),
            chat_id=chat_id.serialize(),
        )

    def mark_read(
        self,
        channel_props: ChannelProps,
        read_until_chat_id: Long,
    ) -> Awaitable[None]:
        """Mark the given channel as read up to the given chat ID."""
        return self._api_user_request_void(
            "mark_read",
            channel_props=channel_props.serialize(),
            read_until_chat_id=read_until_chat_id.serialize(),
        )

    def send_perm(
        self,
        channel_props: ChannelProps,
        user_id: Long,
        perm: OpenChannelUserPerm,
    ) -> Awaitable[None]:
        """Send the "send_perm" command to set a user's permission level in the given channel."""
        return self._api_user_request_void(
            "send_perm",
            channel_props=channel_props.serialize(),
            user_id=user_id.serialize(),
            perm=perm,
        )

    def set_channel_name(
        self,
        channel_props: ChannelProps,
        name: str,
    ) -> Awaitable[None]:
        """Send the "set_channel_name" command for the given channel."""
        return self._api_user_request_void(
            "set_channel_name",
            channel_props=channel_props.serialize(),
            name=name,
        )

    def set_channel_description(
        self,
        channel_props: ChannelProps,
        description: str,
    ) -> Awaitable[None]:
        """Send the "set_channel_description" command for the given channel."""
        return self._api_user_request_void(
            "set_channel_description",
            channel_props=channel_props.serialize(),
            description=description,
        )

    def set_channel_photo(
        self,
        channel_props: ChannelProps,
        photo_url: str,
    ) -> Awaitable[None]:
        """Send the "set_channel_photo" command for the given channel."""
        return self._api_user_request_void(
            "set_channel_photo",
            channel_props=channel_props.serialize(),
            photo_url=photo_url,
        )

    def create_direct_chat(self, ktid: Long) -> Awaitable[Long]:
        """Create a direct chat with the given user; return the resulting channel ID."""
        return self._api_user_request_result(
            Long,
            "create_direct_chat",
            user_id=ktid.serialize(),
        )

    def leave_channel(
        self,
        channel_props: ChannelProps,
    ) -> Awaitable[None]:
        """Leave the given channel."""
        return self._api_user_request_void(
            "leave_channel",
            channel_props=channel_props.serialize(),
        )


    # TODO Combine each of these pairs into one
    # All four helpers retry a command once after renewing the session if the backend raises
    # InvalidAccessToken (unless renew=False, in which case that error is re-raised).

    async def _api_user_request_result(
        self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
    ) -> ResultType:
        """Run a user-scoped command and return its deserialized result."""
        while True:
            try:
                return await self._api_request_result(result_type, command, mxid=self.user.mxid, **data)
            except InvalidAccessToken:
                if not renew:
                    raise
                await self.renew_and_save()
                renew = False

    async def _api_user_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
        """Run a user-scoped command that has no result."""
        while True:
            try:
                return await self._api_request_void(command, mxid=self.user.mxid, **data)
            except InvalidAccessToken:
                if not renew:
                    raise
                await self.renew_and_save()
                renew = False


    async def _api_user_cred_request_result(
        self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
    ) -> ResultType:
        """
        Run a user-scoped command that also needs the OAuth credential, and return its result.
        The credential is re-read on each attempt, so a retry uses the renewed one.
        """
        while True:
            try:
                return await self._api_user_request_result(
                    result_type, command, oauth_credential=self._oauth_credential, renew=False, **data
                )
            except InvalidAccessToken:
                if not renew:
                    raise
                await self.renew_and_save()
                renew = False

    async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
        """
        Run a user-scoped command that also needs the OAuth credential and has no result.
        The credential is re-read on each attempt, so a retry uses the renewed one.
        """
        while True:
            try:
                # Return on success; otherwise the loop would re-send the command forever.
                return await self._api_user_request_void(
                    command, oauth_credential=self._oauth_credential, renew=False, **data
                )
            except InvalidAccessToken:
                if not renew:
                    raise
                await self.renew_and_save()
                renew = False

    # endregion


    # region listeners
    # Each _on_* method deserializes an event payload received from the backend,
    # and forwards it to the matching callback on self.user.

    def _on_chat(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_chat(
            Chatlog.deserialize(data["chatlog"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_chat_deleted(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_chat_deleted(
            Long.deserialize(data["chatId"]),
            Long.deserialize(data["senderId"]),
            int(data["timestamp"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_chat_read(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_chat_read(
            Long.deserialize(data["chatId"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_profile_changed(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_profile_changed(
            OpenLinkChannelUserInfo.deserialize(data["info"]),
        )

    def _on_perm_changed(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_perm_changed(
            Long.deserialize(data["userId"]),
            OpenChannelUserPerm(data["perm"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_channel_added(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_added(
            ChannelInfo.deserialize(data["channelInfo"]),
        )

    def _on_channel_join(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_join(
            ChannelInfo.deserialize(data["channelInfo"]),
        )

    def _on_channel_left(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_left(
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_channel_kicked(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_kicked(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["senderId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_user_join(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_user_join(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_user_left(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_user_left(
            Long.deserialize(data["userId"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )

    def _on_channel_meta_change(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_channel_meta_change(
            PortalChannelInfo.deserialize(data["info"]),
            Long.deserialize(data["channelId"]),
            str(data["channelType"]),
        )


    def _on_listen_disconnect(self, data: dict[str, JSON]) -> Awaitable[None]:
        """Report a disconnect; an unparseable kickout reason is logged and reported as None."""
        try:
            res = KickoutRes.deserialize(data)
        except Exception:
            self.log.exception("Invalid kickout reason, defaulting to None")
            res = None
        return self._on_disconnect(res)

    def _on_switch_server(self, _: dict[str, JSON]) -> Awaitable[None]:
        """Report a server switch as a CHANGE_SERVER kickout."""
        # TODO Reconnect automatically instead
        return self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))

    def _on_disconnect(self, res: KickoutRes | None) -> Awaitable[None]:
        """Unregister this user's event handlers, then notify the user of the disconnect."""
        self._stop_listen()
        return self.user.on_disconnect(res)

    def _on_error(self, data: dict[str, JSON]) -> Awaitable[None]:
        return self.user.on_error(data)


    def _start_listen(self) -> None:
        """Register this user's handlers for every backend event type."""
        self._add_event_handler("chat", self._on_chat)
        self._add_event_handler("chat_deleted", self._on_chat_deleted)
        self._add_event_handler("chat_read", self._on_chat_read)
        self._add_event_handler("profile_changed", self._on_profile_changed)
        self._add_event_handler("perm_changed", self._on_perm_changed)
        self._add_event_handler("channel_added", self._on_channel_added)
        self._add_event_handler("channel_join", self._on_channel_join)
        self._add_event_handler("channel_left", self._on_channel_left)
        self._add_event_handler("channel_kicked", self._on_channel_kicked)
        self._add_event_handler("user_join", self._on_user_join)
        self._add_event_handler("user_left", self._on_user_left)
        self._add_event_handler("channel_meta_change", self._on_channel_meta_change)
        self._add_event_handler("disconnected", self._on_listen_disconnect)
        self._add_event_handler("switch_server", self._on_switch_server)
        self._add_event_handler("error", self._on_error)

    def _stop_listen(self) -> None:
        """Clear the handlers registered by _start_listen, and forget about them."""
        for method in self._handler_methods:
            self._rpc_client.set_event_handlers(self._get_user_cmd(method), [])
        # Forget the cleared methods, so the list doesn't grow across connect/disconnect cycles.
        self._handler_methods.clear()


    def _add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Make handler the only handler of this user's event `method`, and remember `method`."""
        self._rpc_client.set_event_handlers(self._get_user_cmd(method), [handler])
        # Track each method once, so repeated _start_listen calls don't add duplicates.
        if method not in self._handler_methods:
            self._handler_methods.append(method)

    def _get_user_cmd(self, command: str) -> str:
        """Return the user-specific event name for `command`, i.e. "<command>:<mxid>"."""
        return f"{command}:{self.user.mxid}"

    # endregion


    @classmethod
    async def _api_request_result(
        cls, result_type: Type[ResultType], command: str, **data: JSON
    ) -> ResultType:
        """
        Call a command via RPC, and return its result object.
        On failure, raise an appropriate exception.
        """
        resp = deserialize_result(
            result_type,
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)
        # NOTE Not asserting against CommandResultDoneValue because it's generic!
        # TODO Check if there really is no way to do it.
        assert type(resp) is not RootCommandResult, "Result object missing from successful response"
        return cast(CommandResultDoneValue[ResultType], resp).result

    @classmethod
    async def _api_request_void(cls, command: str, **data: JSON) -> None:
        """
        Call a command via RPC that has no result object.
        On failure, raise an appropriate exception.
        """
        resp = RootCommandResult.deserialize(
            await cls._rpc_client.request(command, **data)
        )
        if not resp.success:
            raise_unsuccessful_response(resp)