# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

from asyncpg import Record
from attr import dataclass, field

from mautrix.types import EventID, RoomID
from mautrix.util.async_db import Database, Scheme

from ..kt.types.bson import Long

# Placeholder default for ``Message.db``. Static type checkers see a ``Database``
# instance here, so the class attribute is typed correctly. At runtime the value
# is None; presumably the real database is assigned by setup code elsewhere.
if TYPE_CHECKING:
    fake_db = Database.create("")
else:
    fake_db = None


@dataclass
class Message:
    """A Matrix event mapped to (part of) a KakaoTalk message.

    One KakaoTalk message can be bridged as several Matrix events (see
    ``bulk_create``); each event is stored as its own row, distinguished by ``index``.
    """

    db: ClassVar[Database] = fake_db

    # Matrix event ID and the Matrix room it lives in.
    mxid: EventID
    mx_room: RoomID
    # KakaoTalk message ID. May be None; the "most recent"/"closest" lookups skip such rows.
    ktid: Long | None = field(converter=lambda ktid: Long(ktid) if ktid is not None else None)
    # Position of this event among the events bridged from the same KakaoTalk message.
    index: int
    # KakaoTalk chat ID, and the receiver ID that scopes every ktid/chat lookup below.
    kt_chat: Long = field(converter=Long)
    kt_receiver: Long = field(converter=Long)
    # Message timestamp; used for ordering in get_most_recent/get_closest_before.
    timestamp: int

    @classmethod
    def _from_row(cls, row: Record) -> Message:
        """Build a Message from a row whose column names match the field names."""
        return cls(**row)

    @classmethod
    def _from_optional_row(cls, row: Record | None) -> Message | None:
        """Like ``_from_row``, but pass a missing row (``None``) through unchanged."""
        return cls._from_row(row) if row is not None else None

    # Column list shared by every SELECT and by the INSERT below. "index" is quoted
    # because ``index`` is an SQL keyword.
    columns = 'mxid, mx_room, ktid, "index", kt_chat, kt_receiver, timestamp'
    # The same columns as bare names (quotes stripped), computed once for
    # asyncpg's copy_records_to_table, which takes unquoted column names.
    _column_names = tuple(col.strip('"') for col in columns.split(", "))
    # Derived from ``columns`` so the INSERT column order can never drift from it.
    _insert_query = f"INSERT INTO message ({columns}) VALUES ($1, $2, $3, $4, $5, $6, $7)"

    @classmethod
    async def get_all_by_ktid(cls, ktid: int, kt_receiver: int) -> list[Message]:
        """Return every row (all indexes) bridged from the given KakaoTalk message."""
        q = f"SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2"
        rows = await cls.db.fetch(q, ktid, kt_receiver)
        return [cls._from_row(row) for row in rows]

    @classmethod
    async def get_by_ktid(cls, ktid: int, kt_receiver: int, index: int = 0) -> Message | None:
        """Return the row for one part (``index``) of a KakaoTalk message, or None."""
        q = f'SELECT {cls.columns} FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
        row = await cls.db.fetchrow(q, ktid, kt_receiver, index)
        return cls._from_optional_row(row)

    @classmethod
    async def delete_all_by_room(cls, room_id: RoomID) -> None:
        """Delete every row belonging to the given Matrix room."""
        await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)

    @classmethod
    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
        """Return the row for a Matrix event in a room, or None if it isn't bridged."""
        q = f"SELECT {cls.columns} FROM message WHERE mxid=$1 AND mx_room=$2"
        row = await cls.db.fetchrow(q, mxid, mx_room)
        return cls._from_optional_row(row)

    @classmethod
    async def get_most_recent(cls, kt_chat: int, kt_receiver: int) -> Message | None:
        """Return the latest row with a non-NULL ktid in a chat, by timestamp.

        Rows with equal timestamps are ordered arbitrarily by the database.
        """
        q = (
            f"SELECT {cls.columns} "
            "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND ktid IS NOT NULL "
            "ORDER BY timestamp DESC LIMIT 1"
        )
        row = await cls.db.fetchrow(q, kt_chat, kt_receiver)
        return cls._from_optional_row(row)

    @classmethod
    async def get_closest_before(
        cls, kt_chat: int, kt_receiver: int, timestamp: int
    ) -> Message | None:
        """Return the latest row with a non-NULL ktid at or before ``timestamp``.

        Rows with equal timestamps are ordered arbitrarily by the database.
        """
        q = (
            f"SELECT {cls.columns} "
            "FROM message WHERE kt_chat=$1 AND kt_receiver=$2 AND timestamp<=$3 AND "
            "ktid IS NOT NULL "
            "ORDER BY timestamp DESC LIMIT 1"
        )
        row = await cls.db.fetchrow(q, kt_chat, kt_receiver, timestamp)
        return cls._from_optional_row(row)

    @classmethod
    async def bulk_create(
        cls,
        ktid: Long,
        kt_chat: Long,
        kt_receiver: Long,
        event_ids: list[EventID],
        timestamp: int,
        mx_room: RoomID,
    ) -> None:
        """Insert one row per Matrix event bridged from a single KakaoTalk message.

        ``event_ids[i]`` is stored with ``index=i``. All rows are written in one
        transaction: via COPY on PostgreSQL, otherwise via ``executemany``.
        Does nothing when ``event_ids`` is empty.
        """
        if not event_ids:
            return
        # Tuple order must match ``columns`` / ``_column_names``.
        records = [
            (mxid, mx_room, ktid, index, kt_chat, kt_receiver, timestamp)
            for index, mxid in enumerate(event_ids)
        ]
        async with cls.db.acquire() as conn, conn.transaction():
            if cls.db.scheme == Scheme.POSTGRES:
                await conn.copy_records_to_table(
                    "message", records=records, columns=list(cls._column_names)
                )
            else:
                await conn.executemany(cls._insert_query, records)

    async def insert(self) -> None:
        """Insert this message as a new row."""
        q = self._insert_query
        await self.db.execute(
            q,
            self.mxid,
            self.mx_room,
            self.ktid,
            self.index,
            self.kt_chat,
            self.kt_receiver,
            self.timestamp,
        )

    async def delete(self) -> None:
        """Delete the row identified by (ktid, kt_receiver, index)."""
        # NOTE(review): ``ktid=$1`` never matches a NULL ktid in SQL, so a message
        # stored without a ktid cannot be removed through this method. Confirm
        # whether that is intended.
        q = 'DELETE FROM message WHERE ktid=$1 AND kt_receiver=$2 AND "index"=$3'
        await self.db.execute(q, self.ktid, self.kt_receiver, self.index)

    async def update(self) -> None:
        """Persist this message's ``ktid`` and ``timestamp``, matched by (mxid, mx_room)."""
        q = "UPDATE message SET ktid=$1, timestamp=$2 WHERE mxid=$3 AND mx_room=$4"
        await self.db.execute(q, self.ktid, self.timestamp, self.mxid, self.mx_room)