matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/user.py

615 lines
24 KiB
Python

# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
import asyncio
import time
from mautrix.bridge import BaseUser, async_getter_lock
from mautrix.types import (
EventID,
MessageType,
RoomID,
TextMessageEventContent,
UserID,
)
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.opt_prometheus import Gauge, Summary, async_time
from mautrix.util.simple_lock import SimpleLock
from . import portal as po, puppet as pu
from .config import Config
from .db import User as DBUser
from .kt.client import Client
from .kt.client.errors import AuthenticationRequired, ResponseError
from .kt.types.api.struct.profile import ProfileStruct
from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import LoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
# Summary metrics, one per handler named in their description ("calls to <handler>").
# In this module only METRIC_CONNECT_AND_SYNC and METRIC_MESSAGE are visibly applied
# (via @async_time); the others are presumably reserved for handlers not written yet.
# NOTE(review): METRIC_CONNECT_AND_SYNC is exported as "bridge_sync_channels"; the name
# is kept as-is because external dashboards may already depend on it.
METRIC_CONNECT_AND_SYNC = Summary(
    "bridge_sync_channels",
    "calls to connect_and_sync",
)
METRIC_MEMBERS_ADDED = Summary(
    "bridge_on_members_added",
    "calls to on_members_added",
)
METRIC_MEMBER_REMOVED = Summary(
    "bridge_on_member_removed",
    "calls to on_member_removed",
)
METRIC_TYPING = Summary(
    "bridge_on_typing",
    "calls to on_typing",
)
METRIC_PRESENCE = Summary(
    "bridge_on_presence",
    "calls to on_presence",
)
METRIC_REACTION = Summary(
    "bridge_on_reaction",
    "calls to on_reaction",
)
METRIC_MESSAGE_UNSENT = Summary(
    "bridge_on_unsent",
    "calls to on_unsent",
)
METRIC_MESSAGE_SEEN = Summary(
    "bridge_on_message_seen",
    "calls to on_message_seen",
)
METRIC_TITLE_CHANGE = Summary(
    "bridge_on_title_change",
    "calls to on_title_change",
)
METRIC_AVATAR_CHANGE = Summary(
    "bridge_on_avatar_change",
    "calls to on_avatar_change",
)
METRIC_MESSAGE = Summary(
    "bridge_on_message",
    "calls to on_message",
)

# Gauges for logged-in / connected bridge users (updated via User._track_metric).
METRIC_LOGGED_IN = Gauge(
    "bridge_logged_in",
    "Users logged into the bridge",
)
METRIC_CONNECTED = Gauge(
    "bridge_connected",
    "Bridge users connected to KakaoTalk",
)
if TYPE_CHECKING:
from .__main__ import KakaoTalkBridge
# Register the human-readable text for this bridge's own bridge-state error codes.
# "{message}" is a placeholder presumably filled in by mautrix when the state is
# rendered; "kt-disconnected" is registered with None (no message text).
BridgeState.human_readable_errors.update(
    [
        ("kt-reconnection-error", "Failed to reconnect to KakaoTalk"),
        ("kt-connection-error", "KakaoTalk disconnected unexpectedly"),
        ("kt-auth-error", "Authentication error from KakaoTalk: {message}"),
        ("kt-disconnected", None),
        ("logged-out", "You're not logged into KakaoTalk"),
    ]
)
class User(DBUser, BaseUser):
    """A Matrix user of the bridge, together with their (optional) KakaoTalk session.

    Persistent fields come from ``DBUser``; generic bridge-user behaviour comes from
    mautrix's ``BaseUser``. Loaded instances are cached process-wide in ``by_mxid`` and,
    once a KakaoTalk ID is known, in ``by_ktid`` (see ``_add_to_cache``).
    """

    shutdown: bool = False
    config: Config

    # Class-level caches shared by every instance; filled by _add_to_cache().
    by_mxid: dict[UserID, User] = {}
    by_ktid: dict[int, User] = {}

    # KakaoTalk client for this user; None while logged out.
    client: Client | None

    _notice_room_lock: asyncio.Lock
    _notice_send_lock: asyncio.Lock
    is_admin: bool
    permission_level: str

    # None means "not checked yet"; is_logged_in() re-checks in that case.
    _is_logged_in: bool | None
    _is_connected: bool | None
    _connection_time: float
    _db_instance: DBUser | None
    _sync_lock: SimpleLock
    # Cached result of Client.fetch_logged_in_user() and the time.monotonic() it was taken.
    _logged_in_info: ProfileStruct | None
    _logged_in_info_time: float

    def __init__(
        self,
        mxid: UserID,
        ktid: Long | None = None,
        uuid: str | None = None,
        access_token: str | None = None,
        refresh_token: str | None = None,
        notice_room: RoomID | None = None,
    ) -> None:
        """Build an in-memory user; nothing is persisted here (callers use insert/save)."""
        super().__init__(
            mxid=mxid,
            ktid=ktid,
            uuid=uuid,
            access_token=access_token,
            refresh_token=refresh_token,
            notice_room=notice_room,
        )
        BaseUser.__init__(self)
        self.notice_room = notice_room
        self._notice_room_lock = asyncio.Lock()
        self._notice_send_lock = asyncio.Lock()
        (
            self.relay_whitelisted,
            self.is_whitelisted,
            self.is_admin,
            self.permission_level,
        ) = self.config.get_permissions(mxid)
        self._is_logged_in = None
        self._is_connected = None
        self._connection_time = time.monotonic()
        self._sync_lock = SimpleLock(
            "Waiting for thread sync to finish before handling %s", log=self.log
        )
        self._logged_in_info = None
        self._logged_in_info_time = 0
        self.client = None

    @classmethod
    def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[None]]:
        """Bind bridge-wide objects to the class.

        Returns an async iterable producing one (not yet awaited) ``reload_session``
        coroutine per logged-in user, for the caller to run at startup.
        """
        cls.bridge = bridge
        cls.config = bridge.config
        cls.az = bridge.az
        cls.loop = bridge.loop
        return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())

    @property
    def is_connected(self) -> bool | None:
        """Last known connection state to KakaoTalk (None when not determined)."""
        return self._is_connected

    @is_connected.setter
    def is_connected(self, val: bool | None) -> None:
        # Only a real change of state resets the connection timestamp.
        if self._is_connected != val:
            self._is_connected = val
            self._connection_time = time.monotonic()

    @property
    def connection_time(self) -> float:
        """time.monotonic() value of the last change of ``is_connected``."""
        return self._connection_time

    @property
    def has_state(self) -> bool:
        """True when both an access token and a refresh token are stored."""
        # TODO If more state is needed, consider returning a saved LoginResult
        return bool(self.access_token and self.refresh_token)

    # region Database getters

    def _add_to_cache(self) -> None:
        """Register this instance in ``by_mxid`` and, if the KakaoTalk ID is set, ``by_ktid``."""
        self.by_mxid[self.mxid] = self
        if self.ktid:
            self.by_ktid[self.ktid] = self

    @classmethod
    async def all_logged_in(cls) -> AsyncGenerator["User", None]:
        """Yield every user the database reports as logged in.

        An already-cached instance is preferred over the freshly loaded row so that
        each Matrix user maps to a single live object.
        """
        users = await super().all_logged_in()
        user: cls
        for user in users:
            try:
                yield cls.by_mxid[user.mxid]
            except KeyError:
                user._add_to_cache()
                yield user

    @classmethod
    @async_getter_lock
    async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> User | None:
        """Look up a user by Matrix ID (cache first, then database).

        Returns None for bridge puppets and the appservice bot. When the user does
        not exist and ``create`` is true, a new row is inserted and returned.
        """
        if pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid:
            return None
        try:
            return cls.by_mxid[mxid]
        except KeyError:
            pass

        user = cast(cls, await super().get_by_mxid(mxid))
        if user is not None:
            user._add_to_cache()
            return user

        if create:
            cls.log.debug(f"Creating user instance for {mxid}")
            user = cls(mxid)
            await user.insert()
            user._add_to_cache()
            return user

        return None

    @classmethod
    @async_getter_lock
    async def get_by_ktid(cls, ktid: int) -> User | None:
        """Look up a user by KakaoTalk ID (cache first, then database); never creates."""
        try:
            return cls.by_ktid[ktid]
        except KeyError:
            pass

        user = cast(cls, await super().get_by_ktid(ktid))
        if user is not None:
            user._add_to_cache()
            return user

        return None

    async def get_uuid(self, force: bool = False) -> str:
        """Return the device UUID, generating and saving a new one if unset or ``force``."""
        if self.uuid is None or force:
            self.uuid = await self._generate_uuid()
            # TODO Maybe don't save yet
            await self.save()
        return self.uuid

    @classmethod
    async def _generate_uuid(cls) -> str:
        # All UUIDs already stored are handed to the generator, presumably so it can
        # avoid producing a duplicate.
        return await Client.generate_uuid(await super().get_all_uuids())

    # endregion

    @property
    def oauth_credential(self) -> OAuthCredential:
        """Credentials assembled from the stored KakaoTalk ID, UUID and tokens."""
        return OAuthCredential(
            self.ktid,
            self.uuid,
            self.access_token,
            self.refresh_token,
        )

    @oauth_credential.setter
    def oauth_credential(self, oauth_credential: OAuthCredential) -> None:
        # Store the new credentials; a differing device UUID is logged and then adopted.
        self.ktid = oauth_credential.userId
        self.access_token = oauth_credential.accessToken
        self.refresh_token = oauth_credential.refreshToken
        if self.uuid != oauth_credential.deviceUUID:
            self.log.warning(
                f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}"
            )
            self.uuid = oauth_credential.deviceUUID

    async def get_own_info(self) -> ProfileStruct:
        """Return this user's KakaoTalk profile, refetching it if over an hour old."""
        if not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic():
            self._logged_in_info = await self.client.fetch_logged_in_user()
            self._logged_in_info_time = time.monotonic()
        return self._logged_in_info

    async def _load_session(self, is_startup: bool) -> bool:
        """Restore the KakaoTalk session from stored tokens.

        Returns True if already logged in (outside startup) or if the profile could be
        fetched, in which case post_login() is scheduled. Returns False when no tokens
        are stored (after pushing a BAD_CREDENTIALS state) or no profile came back.
        Errors from fetch_logged_in_user() propagate to the caller.
        """
        if self._is_logged_in and not is_startup:
            return True
        elif not self.has_state:
            # If we have a user in the DB with no state, we can assume
            # KT logged us out and the bridge has restarted
            await self.push_bridge_state(
                BridgeStateEvent.BAD_CREDENTIALS,
                error="logged-out",
            )
            return False
        client = Client(self, log=self.log.getChild("ktclient"))
        user_info = await self.fetch_logged_in_user(client)
        if user_info:
            self.log.info("Loaded session successfully")
            self.client = client
            self._logged_in_info = user_info
            self._logged_in_info_time = time.monotonic()
            self._track_metric(METRIC_LOGGED_IN, True)
            self._is_logged_in = True
            self.is_connected = None
            # NOTE(review): the task object is not kept; asyncio docs advise holding a
            # reference so the task cannot be garbage-collected mid-run.
            asyncio.create_task(self.post_login(is_startup=is_startup))
            return True
        return False

    async def _send_reset_notice(self, e: AuthenticationRequired, edit: EventID | None = None) -> None:
        """Tell the user about an authentication error, then log out keeping the KakaoTalk ID."""
        await self.send_bridge_notice(
            "Got authentication error from KakaoTalk:\n\n"
            f"> {e.message}\n\n"
            "If you changed your KakaoTalk password, this "
            "is normal and you just need to log in again.",
            edit=edit,
            important=True,
            state_event=BridgeStateEvent.BAD_CREDENTIALS,
            error_code="kt-auth-error",
            error_message=str(e),
        )
        await self.logout(remove_ktid=False)

    async def fetch_logged_in_user(
        self, client: Client | None = None, action: str = "restore session"
    ) -> ProfileStruct:
        """Fetch the logged-in user's profile via ``client`` (default: ``self.client``).

        Every exception is re-raised. On AuthenticationRequired a reset notice is
        sent first, except for the default "restore session" action, whose caller
        (reload_session) sends it itself. Other errors are logged using ``action``.
        """
        if not client:
            client = self.client
        # TODO Retry network connection failures here, or in the client (like token refreshes are)?
        try:
            return await client.fetch_logged_in_user()
        except AuthenticationRequired as e:
            if action != "restore session":
                await self._send_reset_notice(e)
            raise
        except Exception:
            self.log.exception(f"Failed to {action}")
            raise

    async def is_logged_in(self, _override: bool = False) -> bool:
        """Return whether the session works; cached unless unknown or ``_override`` is set."""
        if not self.has_state or not self.client:
            return False
        if self._is_logged_in is None or _override:
            try:
                self._is_logged_in = bool(await self.get_own_info())
            except Exception:
                self.log.exception("Exception checking login status")
                self._is_logged_in = False
        return self._is_logged_in

    async def reload_session(
        self, event_id: EventID | None = None, retries: int = 3, is_startup: bool = False
    ) -> None:
        """Load the stored session, reporting failures through bridge notices.

        A ResponseError is retried after 60 seconds, at most ``retries`` more times.
        ``event_id``, when given, is the notice to edit instead of sending new ones.
        """
        try:
            await self._load_session(is_startup=is_startup)
        except AuthenticationRequired as e:
            await self._send_reset_notice(e, edit=event_id)
        # TODO Throw a ResponseError on network failures
        except ResponseError as e:
            will_retry = retries > 0
            retry = "Retrying in 1 minute" if will_retry else "Not retrying"
            notice = f"Failed to connect to KakaoTalk: unknown response error {e}. {retry}"
            if will_retry:
                await self.send_bridge_notice(
                    notice,
                    edit=event_id,
                    state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
                )
                await asyncio.sleep(60)
                await self.reload_session(event_id, retries - 1, is_startup)
            else:
                await self.send_bridge_notice(
                    notice,
                    edit=event_id,
                    important=True,
                    state_event=BridgeStateEvent.UNKNOWN_ERROR,
                    error_code="kt-reconnection-error",
                )
        except Exception:
            self.log.exception("Error connecting to KakaoTalk")
            await self.send_bridge_notice(
                "Failed to connect to KakaoTalk: unknown error (see logs for more details)",
                edit=event_id,
                state_event=BridgeStateEvent.UNKNOWN_ERROR,
                error_code="kt-reconnection-error",
            )

    async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
        """Disconnect the client and clear the stored tokens, then save.

        Args:
            remove_ktid: also push a LOGGED_OUT state and forget the KakaoTalk ID
                (dropping this user from the ``by_ktid`` cache).
            reset_device: also generate a fresh device UUID.
        """
        if self.client:
            # TODO Look for a logout API call
            was_connected = await self.client.disconnect()
            # Compare truthiness, matching how the message below renders both values;
            # otherwise an unknown (None) state logs "was not ... thought it was not".
            if bool(was_connected) != bool(self._is_connected):
                self.log.warning(
                    f"Node backend was{' not' if not was_connected else ''} connected, "
                    f"but we thought it was{' not' if not self._is_connected else ''}"
                )
        if remove_ktid:
            await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
        self._track_metric(METRIC_LOGGED_IN, False)

        if reset_device:
            self.uuid = await self._generate_uuid()
        self.access_token = None
        self.refresh_token = None

        self._is_logged_in = False
        self.is_connected = None
        self.client = None

        if self.ktid and remove_ktid:
            # await UserPortal.delete_all(self.ktid)
            # pop() rather than del: the user may never have been cached by KakaoTalk ID
            # (e.g. logout before post_login() ran _add_to_cache()).
            self.by_ktid.pop(self.ktid, None)
            self.ktid = None

        await self.save()

    async def post_login(self, is_startup: bool) -> None:
        """Run after a successful login or session load.

        Caches the user, tries to enable the custom puppet automatically (errors are
        only logged), then connects and syncs channels.
        """
        self.log.info("Running post-login actions")
        self._add_to_cache()

        try:
            puppet = await pu.Puppet.get_by_ktid(self.ktid)

            if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
                self.log.info("Automatically enabling custom puppet")
                await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
        except Exception:
            self.log.exception("Failed to automatically enable custom puppet")

        # TODO Check if things break when a live message comes in during syncing
        if self.config["bridge.sync_on_startup"] or not is_startup:
            sync_count = self.config["bridge.initial_chat_sync"]
        else:
            # NOTE(review): _sync_channels() replaces a None sync_count with
            # bridge.initial_chat_sync, so this branch does not actually skip syncing
            # when bridge.sync_on_startup is off. Confirm the intended semantics.
            sync_count = None
        await self.connect_and_sync(sync_count)

    async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
        """Map puppet MXIDs to the Matrix room of each bridged portal with this receiver."""
        # NOTE(review): portal.ktid appears to be a channel ID (see _sync_channel) but
        # is turned into a puppet (user) MXID here, and every portal with this receiver
        # is included — confirm this is meant to produce only 1:1 chats.
        return {
            pu.Puppet.get_mxid_from_id(portal.ktid): [portal.mxid]
            async for portal in po.Portal.get_all_by_receiver(self.ktid)
            if portal.mxid
        }

    @async_time(METRIC_CONNECT_AND_SYNC)
    async def connect_and_sync(self, sync_count: int | None) -> bool:
        """Connect the client and sync up to ``sync_count`` channels.

        Returns True on success. On an authentication error the user is notified and
        logged out (keeping the KakaoTalk ID); other errors push UNKNOWN_ERROR.
        """
        # TODO Look for a way to sync all channels without (re-)logging in
        try:
            login_result = await self.client.connect()
            await self.push_bridge_state(BridgeStateEvent.CONNECTED)
            await self._sync_channels(login_result, sync_count)
            return True
        except AuthenticationRequired as e:
            await self.send_bridge_notice(
                f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n",
                important=True,
                state_event=BridgeStateEvent.BAD_CREDENTIALS,
                error_code="kt-auth-error",
                error_message=str(e),
            )
            await self.logout(remove_ktid=False)
        except Exception as e:
            self.log.exception("Failed to connect and sync")
            await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e))
        return False

    async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None:
        """Sync the first ``sync_count`` channels of ``login_result``.

        None means the configured bridge.initial_chat_sync, a negative value means all
        channels, and 0 skips syncing. A failure in one channel is logged and skipped,
        except AuthenticationRequired, which aborts the whole sync.
        """
        if sync_count is None:
            sync_count = self.config["bridge.initial_chat_sync"]
        if not sync_count:
            self.log.debug("Skipping channel syncing")
            return
        if not login_result.channelList:
            self.log.debug("No channels to sync")
            return
        # TODO What about removed channels? Don't early-return then

        num_channels = len(login_result.channelList)
        sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels)
        await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
        self.log.debug(f"Syncing {sync_count} of {num_channels} channels...")

        for login_data in login_result.channelList[:sync_count]:
            try:
                await self._sync_channel(login_data)
            except AuthenticationRequired:
                raise
            except Exception:
                self.log.exception(f"Failed to sync channel {login_data.channel.channelId}")

        await self.update_direct_chats()

    async def _sync_channel(self, login_data: LoginDataItem) -> None:
        """Create or update the portal room for one channel, then backfill it.

        Most of this body is debug logging of the channel's data/info fields.
        """
        channel_data = login_data.channel
        self.log.debug(f"Syncing channel {channel_data.channelId} (last updated at {login_data.lastUpdate})")
        channel_info = channel_data.info

        if isinstance(channel_data, NormalChannelData):
            channel_data: NormalChannelData
            channel_info: NormalChannelInfo
            self.log.debug(f"Join time: {channel_info.joinTime}")
        elif isinstance(channel_data, OpenChannelData):
            channel_data: OpenChannelData
            self.log.debug(f"channel_data link ID: {channel_data.linkId}")
            channel_info: OpenChannelInfo
            self.log.debug(f"channel_info link ID: {channel_info.linkId}")
            self.log.debug(f"openToken: {channel_info.openToken}")
            self.log.debug(f"Is direct channel: {channel_info.directChannel}")
            self.log.debug(f"Has OpenLink: {channel_info.openLink is not None}")
        else:
            # Only logged; the sync still continues with the generic fields below.
            self.log.error(f"Unexpected channel type: {type(channel_data)}")

        channel_info: ChannelInfo
        self.log.debug(f"channel_info channel ID: {channel_info.channelId}")
        self.log.debug(f"Channel data/info IDs match: {channel_data.channelId == channel_info.channelId}")
        self.log.debug(f"Channel type: {channel_info.type}")
        self.log.debug(f"Active user count: {channel_info.activeUserCount}")
        self.log.debug(f"New chat count: {channel_info.newChatCount}")
        self.log.debug(f"New chat count invalid: {channel_info.newChatCountInvalid}")
        self.log.debug(f"Last chat log ID: {channel_info.lastChatLogId}")
        self.log.debug(f"Last seen log ID: {channel_info.lastSeenLogId}")
        self.log.debug(f"Has last chat log: {channel_info.lastChatLog is not None}")
        self.log.debug(f"metaMap: {channel_info.metaMap}")
        self.log.debug(f"User count: {len(channel_info.displayUserList)}")
        self.log.debug(f"Has push alert: {channel_info.pushAlert}")
        for display_user_info in channel_info.displayUserList:
            self.log.debug(f"Member: {display_user_info.nickname} - {display_user_info.profileURL} - {display_user_info.userId}")

        portal = await po.Portal.get_by_ktid(
            channel_info.channelId,
            kt_receiver=self.ktid,
            kt_type=channel_info.type
        )
        portal_info = await self.client.get_portal_channel_info(portal.channel_props)
        portal_info.channel_info = channel_info
        if not portal.mxid:
            await portal.create_matrix_room(self, portal_info)
        else:
            await portal.update_matrix_room(self, portal_info)
            await portal.backfill(self, is_initial=False, channel_info=channel_info)

    async def get_notice_room(self) -> RoomID:
        """Return the bridge-notice DM room, creating and saving it on first use."""
        if not self.notice_room:
            async with self._notice_room_lock:
                # If someone already created the room while this call was waiting,
                # don't make a new room
                if self.notice_room:
                    return self.notice_room
                creation_content = {}
                if not self.config["bridge.federate_rooms"]:
                    creation_content["m.federate"] = False
                self.notice_room = await self.az.intent.create_room(
                    is_direct=True,
                    invitees=[self.mxid],
                    topic="KakaoTalk bridge notices",
                    creation_content=creation_content,
                )
                await self.save()
        return self.notice_room

    async def send_bridge_notice(
        self,
        text: str,
        edit: EventID | None = None,
        state_event: BridgeStateEvent | None = None,
        important: bool = False,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> EventID | None:
        """Optionally push a bridge state, then send ``text`` to the notice room.

        ``important`` notices are sent as m.text instead of m.notice; ``edit`` makes
        the message an edit of that event. Sending failures are only logged. Returns
        ``edit`` if given, else the new event ID (None if nothing was sent).
        """
        if state_event:
            await self.push_bridge_state(
                state_event,
                error=error_code,
                message=error_message if error_code else text,
            )
        if self.config["bridge.disable_bridge_notices"]:
            return None
        event_id = None
        try:
            self.log.debug("Sending bridge notice: %s", text)
            content = TextMessageEventContent(
                body=text,
                msgtype=(MessageType.TEXT if important else MessageType.NOTICE),
            )
            if edit:
                content.set_edit(edit)
            # This is locked to prevent notices going out in the wrong order
            async with self._notice_send_lock:
                event_id = await self.az.intent.send_message(await self.get_notice_room(), content)
        except Exception:
            self.log.warning("Failed to send bridge notice", exc_info=True)
        return edit or event_id

    async def fill_bridge_state(self, state: BridgeState) -> None:
        """Add the KakaoTalk ID and the matching puppet's name to ``state``."""
        await super().fill_bridge_state(state)
        if self.ktid:
            state.remote_id = str(self.ktid)
            puppet = await pu.Puppet.get_by_ktid(self.ktid)
            state.remote_name = puppet.name

    async def get_bridge_states(self) -> list[BridgeState]:
        """Report CONNECTED or UNKNOWN_ERROR; nothing at all when no tokens are stored."""
        # The class has no ``state`` attribute (the old check raised AttributeError);
        # stored credentials are what has_state reports.
        if not self.has_state:
            return []
        state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
        if self.is_connected:
            state.state_event = BridgeStateEvent.CONNECTED
        return [state]

    async def get_puppet(self) -> pu.Puppet | None:
        """Return this user's own puppet, or None if the KakaoTalk ID is unknown."""
        if not self.ktid:
            return None
        return await pu.Puppet.get_by_ktid(self.ktid)

    # region KakaoTalk event handling

    async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
        """Handle a successful KakaoTalk login.

        Stores the credentials, creates a client, saves, fetches the profile
        (failures only logged) and schedules post_login().
        """
        self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
        self.oauth_credential = oauth_credential
        await self.push_bridge_state(BridgeStateEvent.CONNECTING)
        self.client = Client(self, log=self.log.getChild("ktclient"))
        await self.save()
        self._is_logged_in = True
        try:
            self._logged_in_info = await self.client.fetch_logged_in_user(post_login=True)
            self._logged_in_info_time = time.monotonic()
        except Exception:
            self.log.exception("Failed to fetch post-login info")
        # NOTE(review): the task object is not kept; asyncio docs advise holding a
        # reference so the task cannot be garbage-collected mid-run.
        asyncio.create_task(self.post_login(is_startup=True))

    @async_time(METRIC_MESSAGE)
    async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None:
        """Hand an incoming KakaoTalk chat log to its portal.

        Waits on the portal's backfill lock for this log ID first, and schedules a
        resync when the sender's puppet has no name yet.
        """
        portal = await po.Portal.get_by_ktid(
            channel_id,
            kt_receiver=self.ktid,
            kt_type=channel_type
        )
        puppet = await pu.Puppet.get_by_ktid(evt.sender.userId)
        await portal.backfill_lock.wait(evt.logId)
        if not puppet.name:
            portal.schedule_resync(self, puppet)
        # TODO reply_to
        await portal.handle_remote_message(self, puppet, evt)

    # TODO Many more handlers

    # endregion