matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/user.py

654 lines
26 KiB
Python

# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
import asyncio
import time
from mautrix.bridge import BaseUser, async_getter_lock
from mautrix.types import (
EventID,
MessageType,
RoomID,
TextMessageEventContent,
UserID,
)
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.opt_prometheus import Gauge, Summary, async_time
from mautrix.util.simple_lock import SimpleLock
from . import portal as po, puppet as pu
from .config import Config
from .db import User as DBUser
from .kt.client import Client
from .kt.client.errors import AuthenticationRequired, ResponseError
from .kt.types.api.struct.profile import ProfileStruct
from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
# Metrics built with the Summary/Gauge helpers from mautrix.util.opt_prometheus.
# Each Summary's description names the handler it is meant to measure. In the
# visible part of this file only METRIC_SYNC_CHANNELS and METRIC_MESSAGE are
# attached to a method (via @async_time); the rest are declared for handlers
# that are not implemented here.
METRIC_SYNC_CHANNELS = Summary("bridge_sync_channels", "calls to _sync_channels")
METRIC_RESYNC = Summary("bridge_on_resync", "calls to on_resync")
METRIC_UNKNOWN_EVENT = Summary("bridge_on_unknown_event", "calls to on_unknown_event")
METRIC_MEMBERS_ADDED = Summary("bridge_on_members_added", "calls to on_members_added")
METRIC_MEMBER_REMOVED = Summary("bridge_on_member_removed", "calls to on_member_removed")
METRIC_TYPING = Summary("bridge_on_typing", "calls to on_typing")
METRIC_PRESENCE = Summary("bridge_on_presence", "calls to on_presence")
METRIC_REACTION = Summary("bridge_on_reaction", "calls to on_reaction")
METRIC_MESSAGE_UNSENT = Summary("bridge_on_unsent", "calls to on_unsent")
METRIC_MESSAGE_SEEN = Summary("bridge_on_message_seen", "calls to on_message_seen")
METRIC_TITLE_CHANGE = Summary("bridge_on_title_change", "calls to on_title_change")
METRIC_AVATAR_CHANGE = Summary("bridge_on_avatar_change", "calls to on_avatar_change")
METRIC_THREAD_CHANGE = Summary("bridge_on_thread_change", "calls to on_thread_change")
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
# Gauges: METRIC_LOGGED_IN is updated through self._track_metric() when a
# session is loaded or logged out; METRIC_CONNECTED is not referenced in the
# visible part of this file.
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")
if TYPE_CHECKING:
from .__main__ import KakaoTalkBridge
# Register human-readable texts for the KakaoTalk-specific error codes that
# this module passes as `error`/`error_code` when pushing bridge states.
# NOTE(review): "{message}" is presumably filled in from the state's message
# by mautrix when the state is rendered - confirm against BridgeState.
BridgeState.human_readable_errors.update(
    {
        "kt-reconnection-error": "Failed to reconnect to KakaoTalk",
        "kt-connection-error": "KakaoTalk disconnected unexpectedly",
        "kt-auth-error": "Authentication error from KakaoTalk: {message}",
        "kt-start-error": "Startup error from KakaoTalk: {message}",
        # No text is registered for this code (and it is not used in the
        # visible part of this file).
        "kt-disconnected": None,
        "logged-out": "You're not logged into KakaoTalk",
    }
)
class User(DBUser, BaseUser):
    """
    A Matrix user of the bridge.

    Combines the stored user record (``DBUser``) with mautrix's ``BaseUser`` and holds the
    runtime-only state for the user's KakaoTalk session (client, login flags, locks).
    """

    #temp_disconnect_notices: bool = True
    shutdown: bool = False
    # Bridge configuration; assigned on the class by init_cls().
    config: Config

    # Class-level caches of loaded users, populated by _add_to_cache().
    by_mxid: dict[UserID, User] = {}
    by_ktid: dict[int, User] = {}

    # KakaoTalk client for this user; None until a session is loaded or after logout.
    client: Client | None

    # Guards creation of the notice room so that only one room gets created.
    _notice_room_lock: asyncio.Lock
    # Serializes notice sending so that notices go out in order.
    _notice_send_lock: asyncio.Lock

    command_status: dict | None
    # Permission fields, filled from config.get_permissions() in __init__.
    is_admin: bool
    permission_level: str

    # Cached login status; None means "not checked yet" (see is_logged_in()).
    _is_logged_in: bool | None
    #_is_connected: bool | None
    #_connection_time: float
    _prev_reconnect_fail_refresh: float
    _db_instance: DBUser | None

    _sync_lock: SimpleLock
    _is_refreshing: bool

    # Cached profile of the logged-in account and the time.monotonic() value
    # at which it was fetched (see get_own_info()).
    _logged_in_info: ProfileStruct | None
    _logged_in_info_time: float
def __init__(
    self,
    mxid: UserID,
    ktid: Long | None = None,
    uuid: str | None = None,
    access_token: str | None = None,
    refresh_token: str | None = None,
    notice_room: RoomID | None = None,
) -> None:
    """
    Initialize the stored fields and the runtime-only session state.

    :param mxid: Matrix ID of the user.
    :param ktid: KakaoTalk user ID, if the user has one.
    :param uuid: Device UUID used for the KakaoTalk session, if any.
    :param access_token: OAuth access token, if any.
    :param refresh_token: OAuth refresh token, if any.
    :param notice_room: Matrix room used for bridge notices, if already created.
    """
    # The stored fields go to super().__init__(); BaseUser is initialized explicitly.
    super().__init__(
        mxid=mxid,
        ktid=ktid,
        uuid=uuid,
        access_token=access_token,
        refresh_token=refresh_token,
        notice_room=notice_room,
    )
    BaseUser.__init__(self)

    # Notice room bookkeeping
    self.notice_room = notice_room
    self._notice_room_lock = asyncio.Lock()
    self._notice_send_lock = asyncio.Lock()

    self.command_status = None

    # Permissions come from the bridge config, keyed by the user's MXID
    permissions = self.config.get_permissions(mxid)
    (
        self.relay_whitelisted,
        self.is_whitelisted,
        self.is_admin,
        self.permission_level,
    ) = permissions

    # Session state: nothing is known or connected until a session is loaded
    self._is_logged_in = None
    #self._is_connected = None
    #self._connection_time = time.monotonic()
    self._prev_reconnect_fail_refresh = time.monotonic()
    self._sync_lock = SimpleLock(
        "Waiting for thread sync to finish before handling %s", log=self.log
    )
    self._is_refreshing = False
    self._logged_in_info = None
    self._logged_in_info_time = 0
    self.client = None
@classmethod
def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[bool]]:
    """
    Attach the bridge and its shared objects to the class.

    :param bridge: The running bridge instance.
    :return: An async iterable yielding one ``reload_session(is_startup=True)`` awaitable
        for every user returned by ``all_logged_in()``.
    """
    cls.bridge = bridge
    cls.config = bridge.config
    cls.az = bridge.az
    cls.loop = bridge.loop
    #cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]

    logged_in_users = cls.all_logged_in()

    async def _startup_reloads():
        async for user in logged_in_users:
            yield user.reload_session(is_startup=True)

    return _startup_reloads()
"""
@property
def is_connected(self) -> bool | None:
return self._is_connected
@is_connected.setter
def is_connected(self, val: bool | None) -> None:
if self._is_connected != val:
self._is_connected = val
self._connection_time = time.monotonic()
@property
def connection_time(self) -> float:
return self._connection_time
"""
@property
def has_state(self) -> bool:
    """Whether the UUID, KakaoTalk ID, access token and refresh token are all set (truthy)."""
    for field in ("uuid", "ktid", "access_token", "refresh_token"):
        if not getattr(self, field):
            return False
    return True
# region Database getters
def _add_to_cache(self) -> None:
    """Register this user in the MXID cache and, when a KakaoTalk ID is set, in the ktid cache."""
    self.by_mxid[self.mxid] = self
    if not self.ktid:
        # No KakaoTalk ID (yet), so there is nothing to key the second cache on
        return
    self.by_ktid[self.ktid] = self
@classmethod
async def all_logged_in(cls) -> AsyncGenerator["User", None]:
    """
    Yield every user returned by the parent class's ``all_logged_in()``.

    A user that is already in the MXID cache is yielded from the cache; otherwise the
    freshly loaded instance is added to the caches and yielded.
    """
    for loaded in await super().all_logged_in():
        try:
            yield cls.by_mxid[loaded.mxid]
        except KeyError:
            loaded._add_to_cache()
            yield loaded
@classmethod
@async_getter_lock
async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> User | None:
    """
    Look up a user by Matrix ID, checking the cache before the parent class's lookup.

    :param mxid: The Matrix ID to look up.
    :param create: Whether to insert a new user when none is found.
    :return: The user, or None if the MXID belongs to a puppet or to the bridge bot,
        or if no user exists and ``create`` is false.
    """
    is_bridge_owned = pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid
    if is_bridge_owned:
        return None

    if mxid in cls.by_mxid:
        return cls.by_mxid[mxid]

    found = cast(cls, await super().get_by_mxid(mxid))
    if found is None:
        if not create:
            return None
        cls.log.debug(f"Creating user instance for {mxid}")
        found = cls(mxid)
        await found.insert()

    found._add_to_cache()
    return found
@classmethod
@async_getter_lock
async def get_by_ktid(cls, ktid: int) -> User | None:
    """
    Look up a user by KakaoTalk ID, checking the cache before the parent class's lookup.

    :param ktid: The KakaoTalk user ID to look up.
    :return: The matching user, or None if there is none.
    """
    if ktid in cls.by_ktid:
        return cls.by_ktid[ktid]

    found = cast(cls, await super().get_by_ktid(ktid))
    if found is None:
        return None

    found._add_to_cache()
    return found
async def get_uuid(self, force: bool = False) -> str:
    """
    Return this user's device UUID, generating and saving one first when needed.

    :param force: Generate (and save) a new UUID even if one is already set.
    :return: The current UUID.
    """
    needs_new_uuid = force or self.uuid is None
    if needs_new_uuid:
        self.uuid = await self._generate_uuid()
        await self.save()
    return self.uuid
async def _generate_uuid(self) -> str:
    """Ask the client class for a new UUID, passing it the result of ``get_all_uuids()``."""
    known_uuids = await self.get_all_uuids()
    return await Client.generate_uuid(known_uuids)
# endregion
@property
def oauth_credential(self) -> OAuthCredential:
    """The user's stored ktid, UUID, access token and refresh token as an ``OAuthCredential``."""
    return OAuthCredential(
        self.ktid,
        self.uuid,
        self.access_token,
        self.refresh_token,
    )

@oauth_credential.setter
def oauth_credential(self, oauth_credential: OAuthCredential) -> None:
    """
    Copy the fields of an ``OAuthCredential`` onto this user.

    If the credential's device UUID differs from the stored one, a warning is logged and
    the stored UUID is replaced with the credential's.
    """
    self.ktid = oauth_credential.userId
    self.access_token = oauth_credential.accessToken
    self.refresh_token = oauth_credential.refreshToken
    if self.uuid != oauth_credential.deviceUUID:
        # Logger.warn() is a deprecated alias; warning() matches the rest of this file
        self.log.warning(
            f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}"
        )
        self.uuid = oauth_credential.deviceUUID
async def get_own_info(self) -> ProfileStruct:
    """
    Return the logged-in account's profile, re-fetching it from the client when the cached
    copy is missing or was fetched more than an hour (60 * 60 seconds) ago.
    """
    if self._logged_in_info:
        expires_at = self._logged_in_info_time + 60 * 60
        if expires_at >= time.monotonic():
            # Cached copy is still fresh
            return self._logged_in_info
    self._logged_in_info = await self.client.fetch_logged_in_user()
    self._logged_in_info_time = time.monotonic()
    return self._logged_in_info
async def _load_session(self, is_startup: bool) -> bool:
    """
    Create a client from the stored login state and verify it by fetching the own profile.

    :param is_startup: Whether this is the load performed at bridge startup. Outside of
        startup, an already logged-in user is reported as loaded without doing anything.
    :return: True if the session is (or already was) loaded, False otherwise.
    """
    already_loaded = self._is_logged_in and not is_startup
    if already_loaded:
        return True

    if not self.has_state:
        # If we have a user in the DB with no state, we can assume
        # KT logged us out and the bridge has restarted
        await self.push_bridge_state(
            BridgeStateEvent.BAD_CREDENTIALS,
            error="logged-out",
        )
        return False

    kt_client = Client(self, log=self.log.getChild("ktclient"))
    profile = await self.fetch_logged_in_user(kt_client)
    if not profile:
        return False

    self.log.info("Loaded session successfully")
    self.client = kt_client
    self._logged_in_info = profile
    self._logged_in_info_time = time.monotonic()
    self._track_metric(METRIC_LOGGED_IN, True)
    self._is_logged_in = True
    #self.is_connected = None
    self.stop_listen()
    # Post-login actions run as a separate task; this method returns without waiting for it
    asyncio.create_task(self.post_login(is_startup=is_startup))
    return True
async def _send_reset_notice(self, e: AuthenticationRequired, edit: EventID | None = None) -> None:
    """
    Tell the user about an authentication error, then log out while keeping the ktid.

    :param e: The authentication error received from KakaoTalk.
    :param edit: Event ID of an earlier notice to edit instead of sending a new one.
    """
    notice_text = (
        "Got authentication error from KakaoTalk:\n\n"
        f"> {e.message}\n\n"
        "If you changed your KakaoTalk password, this "
        "is normal and you just need to log in again."
    )
    await self.send_bridge_notice(
        notice_text,
        edit=edit,
        important=True,
        state_event=BridgeStateEvent.BAD_CREDENTIALS,
        error_code="kt-auth-error",
        error_message=str(e),
    )
    await self.logout(remove_ktid=False)
async def fetch_logged_in_user(
    self, client: Client | None = None, action: str = "restore session"
) -> ProfileStruct:
    """
    Fetch the profile of the logged-in account.

    :param client: The client to use; falls back to ``self.client`` when not given.
    :param action: Description of what is being attempted, used in the failure log message.
        For any action other than "restore session", an authentication error also sends
        the reset notice before being re-raised.
    :raises AuthenticationRequired: Re-raised from the client.
    :raises Exception: Any other client error is logged and re-raised.
    """
    active_client = client or self.client
    # TODO Retry network connection failures here, or in the client?
    try:
        return await active_client.fetch_logged_in_user()
    # NOTE Not catching InvalidAccessToken here, as client handles it & tries to refresh the token
    except AuthenticationRequired as auth_error:
        if action != "restore session":
            await self._send_reset_notice(auth_error)
        raise
    except Exception:
        self.log.exception(f"Failed to {action}")
        raise
async def is_logged_in(self, _override: bool = False) -> bool:
    """
    Report whether the user is logged in, using the cached answer when there is one.

    :param _override: Re-check via ``get_own_info()`` even if an answer is already cached.
    :return: False when there is no stored state or no client; otherwise the cached or
        freshly checked login status. A failing check is logged and counts as logged out.
    """
    if not (self.has_state and self.client):
        return False

    must_check = self._is_logged_in is None or _override
    if must_check:
        try:
            own_info = await self.get_own_info()
            self._is_logged_in = bool(own_info)
        except Exception:
            self.log.exception("Exception checking login status")
            self._is_logged_in = False
    return self._is_logged_in
async def reload_session(
    self, event_id: EventID | None = None, retries: int = 3, is_startup: bool = False
) -> None:
    """
    Load the user's session, notifying the user about failures and retrying response errors.

    :param event_id: Event ID of an earlier notice to edit with the outcome, if any.
    :param retries: How many more times to retry after a ``ResponseError``, with a
        one-minute pause before each retry.
    :param is_startup: Whether this is the load performed at bridge startup. The flag is
        kept for retries, so a retried startup load is still treated as a startup load.
    """
    try:
        await self._load_session(is_startup=is_startup)
    except AuthenticationRequired as e:
        await self._send_reset_notice(e, edit=event_id)
    # TODO Throw a ResponseError on network failures
    except ResponseError as e:
        will_retry = retries > 0
        retry = "Retrying in 1 minute" if will_retry else "Not retrying"
        notice = f"Failed to connect to KakaoTalk: unknown response error {e}. {retry}"
        if will_retry:
            await self.send_bridge_notice(
                notice,
                edit=event_id,
                state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
            )
            await asyncio.sleep(60)
            # Pass is_startup along: previously it was dropped here, so a retry of the
            # startup load silently ran as a regular (non-startup) load.
            await self.reload_session(event_id, retries - 1, is_startup)
        else:
            await self.send_bridge_notice(
                notice,
                edit=event_id,
                important=True,
                state_event=BridgeStateEvent.UNKNOWN_ERROR,
                error_code="kt-reconnection-error",
            )
    except Exception:
        await self.send_bridge_notice(
            "Failed to connect to KakaoTalk: unknown error (see logs for more details)",
            edit=event_id,
            state_event=BridgeStateEvent.UNKNOWN_ERROR,
            error_code="kt-reconnection-error",
        )
    finally:
        self._is_refreshing = False
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> bool:
    """
    Stop the user's KakaoTalk session and clear the stored tokens.

    :param remove_ktid: Also forget the user's KakaoTalk ID (dropping the ktid cache
        entry) and push the logged-out bridge state and metric.
    :param reset_device: Generate a new device UUID.
    :return: Always True.
    """
    self.stop_listen()
    # TODO Log out of KakaoTalk if an API exists for it
    if remove_ktid:
        await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
        self._track_metric(METRIC_LOGGED_IN, False)
    self._is_logged_in = False
    #self.is_connected = None
    if self.client:
        await self.client.stop()
    self.client = None

    if self.ktid and remove_ktid:
        #await UserPortal.delete_all(self.ktid)
        # The user may have a ktid without having been cached by it yet (the ktid cache
        # is only filled by _add_to_cache()), so a missing entry must not abort logout.
        self.by_ktid.pop(self.ktid, None)
        self.ktid = None

    if reset_device:
        self.uuid = await self._generate_uuid()
    self.access_token = None
    self.refresh_token = None

    await self.save()
    return True
async def post_login(self, is_startup: bool) -> bool:
    """
    Run the actions that follow a successful login.

    Caches the user, tries to enable the user's custom puppet automatically, then starts
    the client, syncs channels and starts listening.

    :param is_startup: Passed on to ``_sync_channels()``.
    :return: True if the client started and the channels were synced; False if that part
        failed (the user is notified, and logged out on an authentication error).
    """
    self.log.info("Running post-login actions")
    self._add_to_cache()

    # Failing to enable the custom puppet is logged but does not stop the login
    try:
        own_puppet = await pu.Puppet.get_by_ktid(self.ktid)
        should_switch = (
            own_puppet.custom_mxid != self.mxid and own_puppet.can_auto_login(self.mxid)
        )
        if should_switch:
            self.log.info("Automatically enabling custom puppet")
            await own_puppet.switch_mxid(access_token="auto", mxid=self.mxid)
    except Exception:
        self.log.exception("Failed to automatically enable custom puppet")

    assert self.client
    try:
        # TODO if not is_startup, close existing listeners
        login_result = await self.client.start()
        await self._sync_channels(login_result, is_startup)
        self.start_listen()
    except AuthenticationRequired as auth_error:
        await self.send_bridge_notice(
            f"Got authentication error from KakaoTalk:\n\n> {auth_error.message}\n\n",
            important=True,
            state_event=BridgeStateEvent.BAD_CREDENTIALS,
            error_code="kt-auth-error",
            error_message=str(auth_error),
        )
        await self.logout(remove_ktid=False)
    except Exception as start_error:
        self.log.exception("Failed to start client")
        await self.send_bridge_notice(
            f"Got error from KakaoTalk:\n\n> {start_error!s}\n\n",
            important=True,
            state_event=BridgeStateEvent.UNKNOWN_ERROR,
            error_code="kt-start-error",
            error_message=str(start_error),
        )
    else:
        return True
    return False
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
    """
    Map puppet MXIDs to the Matrix rooms of this user's portals.

    Iterates the portals returned by ``Portal.get_all_by_receiver()`` for this user's ktid
    and includes only those that have a Matrix room; each entry holds a one-element list.
    """
    direct_chats = {}
    async for portal in po.Portal.get_all_by_receiver(self.ktid):
        if not portal.mxid:
            continue
        puppet_mxid = pu.Puppet.get_mxid_from_id(portal.ktid)
        direct_chats[puppet_mxid] = [portal.mxid]
    return direct_chats
@async_time(METRIC_SYNC_CHANNELS)
async def _sync_channels(self, login_result: LoginResult, is_startup: bool) -> None:
    """
    Sync the first channels of a login result into Matrix.

    At most ``bridge.initial_chat_sync`` channels are synced. Nothing is synced when that
    value is not positive, or when this is a startup sync and ``bridge.sync_on_startup``
    is disabled.

    :param login_result: Result of starting the client; its ``channelList`` is synced.
    :param is_startup: Whether this sync happens as part of bridge startup.
    """
    # TODO Look for a way to sync all channels without (re-)logging in
    sync_count = self.config["bridge.initial_chat_sync"]
    if sync_count <= 0 or (not self.config["bridge.sync_on_startup"] and is_startup):
        self.log.debug(f"Skipping channel syncing{' on startup' if sync_count > 0 else ''}")
        return
    if not login_result.channelList:
        self.log.debug("No channels to sync")
        return
    # TODO What about removed channels? Don't early-return then
    total_channels = len(login_result.channelList)
    sync_count = min(sync_count, total_channels)
    await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
    # Log the number of channels: the list itself used to be interpolated here, which
    # dumped every channel object into the log instead of the "N of M" count.
    self.log.debug(f"Syncing {sync_count} of {total_channels} channels...")
    for channel_item in login_result.channelList[:sync_count]:
        # TODO try-except here, above, below?
        await self._sync_channel(channel_item)
async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None:
    """
    Sync one channel from a login result into its portal.

    Logs the channel's details at debug level, then creates the portal's Matrix room if it
    has none (or updates the existing room), and finally backfills the portal.

    :param channel_item: The channel entry from the login result's channel list.
    """
    data = channel_item.channel
    self.log.debug(f"Syncing channel {data.channelId} (last updated at {channel_item.lastUpdate})")
    info = data.info

    # Details that only exist for one kind of channel
    if isinstance(data, NormalChannelData):
        self.log.debug(f"Join time: {info.joinTime}")
    elif isinstance(data, OpenChannelData):
        self.log.debug(f"channel_data link ID: {data.linkId}")
        self.log.debug(f"channel_info link ID: {info.linkId}")
        self.log.debug(f"openToken: {info.openToken}")
        self.log.debug(f"Is direct channel: {info.directChannel}")
        self.log.debug(f"Has OpenLink: {info.openLink is not None}")
    else:
        # Unknown kinds are reported but still synced using the common fields below
        self.log.error(f"Unexpected channel type: {type(data)}")

    # Details common to every channel
    self.log.debug(f"channel_info channel ID: {info.channelId}")
    self.log.debug(f"Channel data/info IDs match: {data.channelId == info.channelId}")
    self.log.debug(f"Channel type: {info.type}")
    self.log.debug(f"Active user count: {info.activeUserCount}")
    self.log.debug(f"New chat count: {info.newChatCount}")
    self.log.debug(f"New chat count invalid: {info.newChatCountInvalid}")
    self.log.debug(f"Last chat log ID: {info.lastChatLogId}")
    self.log.debug(f"Last seen log ID: {info.lastSeenLogId}")
    self.log.debug(f"Has last chat log: {info.lastChatLog is not None}")
    self.log.debug(f"metaMap: {info.metaMap}")
    self.log.debug(f"User count: {len(info.displayUserList)}")
    self.log.debug(f"Has push alert: {info.pushAlert}")
    for member in info.displayUserList:
        self.log.debug(f"Member: {member.nickname} - {member.profileURL} - {member.userId}")

    portal = await po.Portal.get_by_ktid(
        info.channelId,
        kt_receiver=self.ktid,
        kt_type=info.type,
    )
    portal_info = await self.client.get_portal_channel_info(info.channelId)
    # Attach the channel info from the login result to the fetched portal info
    portal_info.channel_info = info

    if portal.mxid:
        await portal.update_matrix_room(self, portal_info)
    else:
        await portal.create_matrix_room(self, portal_info)
    await portal.backfill(self, is_initial=False, channel_info=info)
async def get_notice_room(self) -> RoomID:
    """
    Return the room used for bridge notices, creating and saving it on first use.

    The room is created as a direct chat with the user invited. When
    ``bridge.federate_rooms`` is disabled, it is created with ``m.federate`` set to False.
    """
    if self.notice_room:
        return self.notice_room

    async with self._notice_room_lock:
        # If someone already created the room while this call was waiting,
        # don't make a new room
        if self.notice_room:
            return self.notice_room
        federate = self.config["bridge.federate_rooms"]
        creation_content = {} if federate else {"m.federate": False}
        self.notice_room = await self.az.intent.create_room(
            is_direct=True,
            invitees=[self.mxid],
            topic="KakaoTalk bridge notices",
            creation_content=creation_content,
        )
        await self.save()
    return self.notice_room
async def send_bridge_notice(
    self,
    text: str,
    edit: EventID | None = None,
    state_event: BridgeStateEvent | None = None,
    important: bool = False,
    error_code: str | None = None,
    error_message: str | None = None,
) -> EventID | None:
    """
    Optionally push a bridge state, then send a notice to the user's notice room.

    :param text: Body of the notice.
    :param edit: Event ID of an earlier notice to edit instead of posting a new message.
    :param state_event: Bridge state to push before sending, if any.
    :param important: Send as a regular text message rather than as a notice.
    :param error_code: Error code for the pushed state.
    :param error_message: Message for the pushed state when an error code is given;
        without an error code, the notice text is used as the state message.
    :return: None when bridge notices are disabled; otherwise ``edit`` if given, else the
        ID of the sent event (None if sending failed).
    """
    if state_event:
        state_message = error_message if error_code else text
        await self.push_bridge_state(
            state_event,
            error=error_code,
            message=state_message,
        )

    if self.config["bridge.disable_bridge_notices"]:
        return None

    sent_event_id = None
    try:
        self.log.debug("Sending bridge notice: %s", text)
        content = TextMessageEventContent(
            body=text,
            msgtype=(MessageType.TEXT if important else MessageType.NOTICE),
        )
        if edit:
            content.set_edit(edit)
        # This is locked to prevent notices going out in the wrong order
        async with self._notice_send_lock:
            room_id = await self.get_notice_room()
            sent_event_id = await self.az.intent.send_message(room_id, content)
    except Exception:
        # Notices are best-effort: a failure is logged and not raised
        self.log.warning("Failed to send bridge notice", exc_info=True)
    return edit or sent_event_id
async def fill_bridge_state(self, state: BridgeState) -> None:
    """
    Fill in a bridge state; when the user has a ktid, add the remote ID and remote name.

    :param state: The state object to fill in place. The remote name is taken from the
        name of the puppet with the user's ktid.
    """
    await super().fill_bridge_state(state)
    if not self.ktid:
        return
    state.remote_id = str(self.ktid)
    own_puppet = await pu.Puppet.get_by_ktid(self.ktid)
    state.remote_name = own_puppet.name
async def get_bridge_states(self) -> list[BridgeState]:
    """Return the bridge states for this user. Not implemented yet: logs a TODO, returns an empty list."""
    self.log.info("TODO: get_bridge_states")
    states = []
    # Disabled draft of the intended implementation, kept for reference:
    #   if not self.state:
    #       return []
    #   state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
    #   if self.is_connected:
    #       state.state_event = BridgeStateEvent.CONNECTED
    #   elif self._is_refreshing or self.mqtt:
    #       state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
    #   return [state]
    return states
async def get_puppet(self) -> pu.Puppet | None:
    """Return the puppet with this user's ktid, or None if the user has no ktid."""
    if self.ktid:
        return await pu.Puppet.get_by_ktid(self.ktid)
    return None
# region KakaoTalk event handling
def start_listen(self) -> None:
    """Start the listener as a background task and keep the task in ``self.listen_task``."""
    listener = self._try_listen()
    self.listen_task = asyncio.create_task(listener)
def _disconnect_listener_after_error(self) -> None:
    """Placeholder called after a fatal listener error; currently only logs a TODO."""
    self.log.info("TODO: _disconnect_listener_after_error")
    return None
async def _try_listen(self) -> None:
    """
    Start the client's listener and register the message handler.

    Any error is logged, reported to the user as a connection error, and followed by a
    call to ``_disconnect_listener_after_error()``; it is not re-raised.
    """
    try:
        # TODO Pass all listeners to start_listen instead of registering them one-by-one?
        await self.client.start_listen()
        await self.client.on_message(self.on_message)
    # TODO Handle auth errors specially?
    #except AuthenticationRequired as e:
    except Exception:
        #self.is_connected = False
        self.log.exception("Fatal error in listener")
        notice_text = "Fatal error in listener (see logs for more info)"
        await self.send_bridge_notice(
            notice_text,
            state_event=BridgeStateEvent.UNKNOWN_ERROR,
            important=True,
            error_code="kt-connection-error",
        )
        self._disconnect_listener_after_error()
def stop_listen(self) -> None:
    """Placeholder for stopping the listener; currently only logs a TODO."""
    self.log.info("TODO: stop_listen")
    return None
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
    """
    Store fresh login credentials and start the session.

    Saves the credential on the user, creates a new client, tries to fetch the own profile
    (a failure is only logged), and schedules the post-login actions as a separate task.

    :param oauth_credential: The credential obtained from a successful login.
    """
    self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
    self.oauth_credential = oauth_credential
    self.client = Client(self, log=self.log.getChild("ktclient"))
    await self.save()
    self._is_logged_in = True

    try:
        profile = await self.client.fetch_logged_in_user(post_login=True)
    except Exception:
        self.log.exception("Failed to fetch post-login info")
    else:
        self._logged_in_info = profile
        self._logged_in_info_time = time.monotonic()

    self.stop_listen()
    # Not awaited: post-login actions continue in the background
    asyncio.create_task(self.post_login(is_startup=True))
@async_time(METRIC_MESSAGE)
async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None:
    """
    Handle a chat message received from KakaoTalk by passing it to its portal.

    :param evt: The received chat log entry.
    :param channel_id: ID of the channel the message was sent in.
    :param channel_type: Type of that channel.
    """
    target_portal = await po.Portal.get_by_ktid(
        channel_id,
        kt_receiver=self.ktid,
        kt_type=channel_type,
    )
    sender = await pu.Puppet.get_by_ktid(evt.sender.userId)
    await target_portal.backfill_lock.wait(evt.logId)
    if not sender.name:
        # The sender's puppet has no name yet, so schedule a resync on the portal
        target_portal.schedule_resync(self, sender)
    # TODO reply_to
    await target_portal.handle_remote_message(self, sender, evt)
# TODO Many more handlers
# endregion