matrix-appservice-kakaotalk/matrix_appservice_kakaotalk/user.py

640 lines
25 KiB
Python

# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
import asyncio
import time
from mautrix.bridge import BaseUser, async_getter_lock
from mautrix.types import (
EventID,
MessageType,
RoomID,
TextMessageEventContent,
UserID,
)
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.opt_prometheus import Gauge, Summary, async_time
from mautrix.util.simple_lock import SimpleLock
from . import portal as po, puppet as pu
from .config import Config
from .db import User as DBUser
from .kt.client import Client
from .kt.client.errors import AuthenticationRequired, ResponseError
from .kt.types.api.struct.profile import ProfileStruct
from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import LoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
from .kt.types.packet.chat.kickout import KnownKickoutType, KickoutRes
# Metrics exported by this module (via mautrix's optional Prometheus wrapper).
# The two Summary objects are attached to methods with the async_time decorator
# (connect_and_sync and on_message); the two Gauge objects are passed to
# _track_metric whenever a user's login or connection status changes.
METRIC_CONNECT_AND_SYNC = Summary("bridge_connect_and_sync", "calls to connect_and_sync")
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")
if TYPE_CHECKING:
from .__main__ import KakaoTalkBridge
# Register human-readable descriptions for the error codes that this module
# passes as `error=` / `error_code=` when pushing bridge states.
# NOTE(review): the "{message}" placeholder is presumably filled in by mautrix
# with the state's message field - confirm against the mautrix implementation.
BridgeState.human_readable_errors.update(
{
"kt-reconnection-error": "Failed to reconnect to KakaoTalk",
# TODO Use for errors in Node backend that cause session to be lost
#"kt-connection-error": "KakaoTalk disconnected unexpectedly",
"kt-auth-error": "Authentication error from KakaoTalk: {message}",
"logged-out": "You're not logged into KakaoTalk",
}
)
class User(DBUser, BaseUser):
"""A Matrix user of the bridge, combining the database row (DBUser) with mautrix's BaseUser."""
# Whether to send a "Connected to KakaoTalk chats" notice on connect;
# overwritten from the bridge config in init_cls.
temp_disconnect_notices: bool = True
shutdown: bool = False
# Bridge configuration, assigned at class level by init_cls.
config: Config
# Class-wide caches of User instances, keyed by Matrix user ID and by
# KakaoTalk user ID. Populated by _add_to_cache.
by_mxid: dict[UserID, User] = {}
by_ktid: dict[int, User] = {}
# KakaoTalk client for this user; None while no session is loaded.
client: Client | None
# Guards creation of the notice room and ordering of notices, respectively.
_notice_room_lock: asyncio.Lock
_notice_send_lock: asyncio.Lock
# Permission details unpacked from config.get_permissions in __init__.
is_admin: bool
permission_level: str
# Tri-state flags: None means "not determined yet".
_is_logged_in: bool | None
_is_connected: bool | None
# time.monotonic() value of the last change to is_connected.
_connection_time: float
_db_instance: DBUser | None
_sync_lock: SimpleLock
# Cached own-profile data and the time.monotonic() value when it was fetched.
_logged_in_info: ProfileStruct | None
_logged_in_info_time: float
def __init__(
    self,
    mxid: UserID,
    ktid: Long | None = None,
    uuid: str | None = None,
    access_token: str | None = None,
    refresh_token: str | None = None,
    notice_room: RoomID | None = None,
) -> None:
    """Build an in-memory user and reset all runtime (non-persisted) state.

    The keyword arguments are forwarded unchanged to the next ``__init__`` in
    the MRO; ``BaseUser.__init__`` is then invoked explicitly.
    """
    super().__init__(
        mxid=mxid,
        ktid=ktid,
        uuid=uuid,
        access_token=access_token,
        refresh_token=refresh_token,
        notice_room=notice_room,
    )
    BaseUser.__init__(self)
    self.notice_room = notice_room

    # Locks used by get_notice_room and send_bridge_notice.
    self._notice_room_lock = asyncio.Lock()
    self._notice_send_lock = asyncio.Lock()

    # The config resolves all permission flags for this Matrix ID at once.
    permissions = self.config.get_permissions(mxid)
    (
        self.relay_whitelisted,
        self.is_whitelisted,
        self.is_admin,
        self.permission_level,
    ) = permissions

    # Login/connection status is unknown until a session has been loaded.
    self._is_logged_in = None
    self._is_connected = None
    self._connection_time = time.monotonic()

    self._sync_lock = SimpleLock(
        "Waiting for thread sync to finish before handling %s", log=self.log
    )

    # No cached profile and no client yet.
    self._logged_in_info = None
    self._logged_in_info_time = 0
    self.client = None
@classmethod
def init_cls(cls, bridge: KakaoTalkBridge) -> AsyncIterable[Awaitable[bool]]:
    """Attach bridge-wide objects to the class and prepare startup session reloads.

    Returns an async iterable that lazily yields one un-awaited
    ``reload_session(is_startup=True)`` coroutine per user returned by
    ``all_logged_in``.
    """
    bridge_config = bridge.config
    cls.bridge = bridge
    cls.config = bridge_config
    cls.az = bridge.az
    cls.temp_disconnect_notices = bridge_config["bridge.temporary_disconnect_notices"]

    async def _startup_reloads():
        async for logged_in_user in cls.all_logged_in():
            yield logged_in_user.reload_session(is_startup=True)

    return _startup_reloads()
@property
def is_connected(self) -> bool | None:
    """The connection state most recently assigned (True, False or None)."""
    return self._is_connected

@is_connected.setter
def is_connected(self, val: bool | None) -> None:
    """Store a new connection state and timestamp the change.

    Assigning a value equal to the current one leaves both the state and
    the timestamp untouched.
    """
    if self._is_connected == val:
        return
    self._is_connected = val
    self._connection_time = time.monotonic()
@property
def connection_time(self) -> float:
    """``time.monotonic()`` value recorded when the connection state last changed."""
    last_change = self._connection_time
    return last_change
@property
def has_state(self) -> bool:
    """Whether both an access token and a refresh token are stored for this user."""
    # TODO If more state is needed, consider returning a saved LoginResult
    if not self.access_token:
        return False
    return bool(self.refresh_token)
# region Database getters
def _add_to_cache(self) -> None:
    """Register this instance in the class-wide lookup tables.

    The Matrix ID entry is always written; the KakaoTalk ID entry only when
    a (truthy) ``ktid`` is set.
    """
    type(self).by_mxid[self.mxid] = self
    kakao_id = self.ktid
    if kakao_id:
        type(self).by_ktid[kakao_id] = self
@classmethod
async def all_logged_in(cls) -> AsyncGenerator["User", None]:
    """Yield every logged-in user, preferring already-cached instances.

    Users come from the parent class's ``all_logged_in``. For each one, the
    instance cached in ``by_mxid`` is yielded if present; otherwise the
    freshly loaded instance is added to the cache and yielded.
    """
    users = await super().all_logged_in()
    for user in users:
        # Keep the try body limited to the cache lookup so that a KeyError
        # raised at the suspension point (thrown in by the consumer) is not
        # mistaken for a cache miss.
        try:
            cached = cls.by_mxid[user.mxid]
        except KeyError:
            user._add_to_cache()
            cached = user
        yield cached
@classmethod
@async_getter_lock
async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> User | None:
    """Look up a user by Matrix ID: cache first, then database, then optional creation.

    Returns None for puppet MXIDs and for the appservice bot, and when the
    user is unknown and ``create`` is false.
    """
    if pu.Puppet.get_id_from_mxid(mxid) or mxid == cls.az.bot_mxid:
        return None

    if mxid in cls.by_mxid:
        return cls.by_mxid[mxid]

    found = cast(cls, await super().get_by_mxid(mxid))
    if found is None:
        if not create:
            return None
        cls.log.debug(f"Creating user instance for {mxid}")
        found = cls(mxid)
        await found.insert()

    found._add_to_cache()
    return found
@classmethod
@async_getter_lock
async def get_by_ktid(cls, ktid: int) -> User | None:
    """Look up a user by KakaoTalk ID: cache first, then database.

    Database hits are added to the cache. Returns None when no such user exists.
    """
    if ktid in cls.by_ktid:
        return cls.by_ktid[ktid]

    found = cast(cls, await super().get_by_ktid(ktid))
    if found is None:
        return None
    found._add_to_cache()
    return found
async def get_uuid(self, force: bool = False) -> str:
    """Return this user's device UUID, generating and saving one when needed.

    A new UUID is generated when none is stored yet or when ``force`` is true.
    """
    if not force and self.uuid is not None:
        return self.uuid
    self.uuid = await self._generate_uuid()
    # TODO Maybe don't save yet
    await self.save()
    return self.uuid
@classmethod
async def _generate_uuid(cls) -> str:
    """Ask the client for a new UUID, handing it the UUIDs already stored.

    NOTE(review): presumably the existing UUIDs are passed so the new one
    does not collide with them - confirm in ``Client.generate_uuid``.
    """
    known_uuids = await super().get_all_uuids()
    new_uuid = await Client.generate_uuid(known_uuids)
    return new_uuid
# endregion
@property
def oauth_credential(self) -> OAuthCredential:
    """Bundle the stored KakaoTalk ID, device UUID and both tokens into a credential."""
    user_id, device_uuid = self.ktid, self.uuid
    access, refresh = self.access_token, self.refresh_token
    return OAuthCredential(user_id, device_uuid, access, refresh)

@oauth_credential.setter
def oauth_credential(self, oauth_credential: OAuthCredential) -> None:
    """Copy all fields of the given credential onto this user.

    A device UUID that differs from the stored one is logged as a warning
    and then adopted.
    """
    new_uuid = oauth_credential.deviceUUID
    if self.uuid != new_uuid:
        self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
        self.uuid = new_uuid
    self.ktid = oauth_credential.userId
    self.access_token = oauth_credential.accessToken
    self.refresh_token = oauth_credential.refreshToken
async def get_own_info(self) -> ProfileStruct:
    """Return the user's own profile, re-fetching it when the cached copy is stale.

    The cached profile is reused for up to one hour (measured with
    ``time.monotonic()``); otherwise it is fetched again through the client.
    """
    still_fresh = bool(self._logged_in_info) and (
        time.monotonic() <= self._logged_in_info_time + 60 * 60
    )
    if not still_fresh:
        self._logged_in_info = await self.client.get_own_profile()
        self._logged_in_info_time = time.monotonic()
    return self._logged_in_info
async def _load_session(self, is_startup: bool) -> bool:
    """Start a KakaoTalk client for this user from the stored credentials.

    Returns True when a session is (or already was) loaded, and False when
    the user has no stored tokens, in which case a BAD_CREDENTIALS bridge
    state with error "logged-out" is pushed. Any exception raised by
    ``client.start()`` propagates to the caller.

    On success the post-login actions are scheduled as a background task.
    """
    if self._is_logged_in and is_startup:
        return True
    if not self.has_state:
        # If we have a user in the DB with no state, we can assume
        # KT logged us out and the bridge has restarted
        await self.push_bridge_state(
            BridgeStateEvent.BAD_CREDENTIALS,
            error="logged-out",
        )
        return False
    client = Client(self, log=self.log.getChild("ktclient"))
    user_info = await client.start()
    # NOTE On failure, client.start throws instead of returning False
    self.log.info("Loaded session successfully")
    self.client = client
    self._logged_in_info = user_info
    self._logged_in_info_time = time.monotonic()
    self._track_metric(METRIC_LOGGED_IN, True)
    self._is_logged_in = True
    self.is_connected = None
    # Hold a strong reference to the task: the event loop only keeps weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    self._post_login_task = asyncio.create_task(self.post_login(is_startup=is_startup))
    return True
async def _send_reset_notice(self, e: AuthenticationRequired, edit: EventID | None = None) -> None:
    """Tell the user about an authentication error, then log them out locally.

    The notice is sent as important with a BAD_CREDENTIALS bridge state;
    the logout keeps the user's KakaoTalk ID (``remove_ktid=False``).
    """
    notice_text = (
        "Got authentication error from KakaoTalk:\n\n"
        f"> {e.message}\n\n"
        "If you changed your KakaoTalk password, this "
        "is normal and you just need to log in again."
    )
    await self.send_bridge_notice(
        notice_text,
        edit=edit,
        important=True,
        state_event=BridgeStateEvent.BAD_CREDENTIALS,
        error_code="kt-auth-error",
        error_message=str(e),
    )
    await self.logout(remove_ktid=False)
async def is_logged_in(self, _override: bool = False) -> bool:
    """Report whether the user is logged in, probing KakaoTalk if that is unknown.

    Without stored tokens or a client the answer is always False. Otherwise
    the cached flag is returned, unless it is still undetermined or
    ``_override`` is set; then the own profile is requested, and any
    exception during that request counts as "not logged in".
    """
    if not (self.has_state and self.client):
        return False
    if self._is_logged_in is not None and not _override:
        return self._is_logged_in
    try:
        self._is_logged_in = bool(await self.get_own_info())
    except Exception:
        self.log.exception("Exception checking login status")
        self._is_logged_in = False
    return self._is_logged_in
async def reload_session(
    self, event_id: EventID | None = None, retries: int = 3, is_startup: bool = False
) -> None:
    """Load the user's session and report any failure through bridge notices.

    - Authentication errors trigger the reset notice (which also logs out).
    - ``ResponseError`` is retried after 60 seconds while ``retries`` is
      positive; once retries are exhausted an important error notice is sent.
    - Any other exception is logged and reported as an unknown error.
    """
    try:
        await self._load_session(is_startup=is_startup)
    except AuthenticationRequired as e:
        await self._send_reset_notice(e, edit=event_id)
    # TODO Throw a ResponseError on network failures
    except ResponseError as e:
        out_of_retries = retries <= 0
        retry = "Not retrying" if out_of_retries else "Retrying in 1 minute"
        notice = f"Failed to connect to KakaoTalk: unknown response error {e}. {retry}"
        if out_of_retries:
            await self.send_bridge_notice(
                notice,
                edit=event_id,
                important=True,
                state_event=BridgeStateEvent.UNKNOWN_ERROR,
                error_code="kt-reconnection-error",
            )
            return
        await self.send_bridge_notice(
            notice,
            edit=event_id,
            state_event=BridgeStateEvent.TRANSIENT_DISCONNECT,
        )
        await asyncio.sleep(60)
        await self.reload_session(event_id, retries - 1, is_startup)
    except Exception:
        self.log.exception("Error connecting to KakaoTalk")
        await self.send_bridge_notice(
            "Failed to connect to KakaoTalk: unknown error (see logs for more details)",
            edit=event_id,
            state_event=BridgeStateEvent.UNKNOWN_ERROR,
            error_code="kt-reconnection-error",
        )
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
    """Stop the client, clear the stored tokens and persist the result.

    :param remove_ktid: Also forget the KakaoTalk ID (dropping the
        ``by_ktid`` cache entry) and push a LOGGED_OUT bridge state.
    :param reset_device: Generate a fresh device UUID for the user.
    """
    if self.client:
        # TODO Look for a logout API call
        await self.client.stop()
    if remove_ktid:
        await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
    self._track_metric(METRIC_LOGGED_IN, False)
    if reset_device:
        self.uuid = await self._generate_uuid()
    self.access_token = None
    self.refresh_token = None
    self._is_logged_in = False
    self.is_connected = None
    self.client = None
    if self.ktid and remove_ktid:
        #await UserPortal.delete_all(self.ktid)
        # The user may not have been cached by ktid yet (caching happens in
        # post_login), so tolerate a missing entry instead of raising
        # KeyError and skipping the save below.
        self.by_ktid.pop(self.ktid, None)
        self.ktid = None
    await self.save()
async def post_login(self, is_startup: bool) -> None:
    """Run the follow-up work after a session has been loaded.

    Caches the user, tries to enable the user's custom puppet automatically
    (failures are logged, not raised), then connects to KakaoTalk and syncs
    channels.

    Channels are synced with the configured ``bridge.initial_chat_sync``
    count, except on startup when ``bridge.sync_on_startup`` is disabled, in
    which case syncing is skipped.
    """
    self.log.info("Running post-login actions")
    self._add_to_cache()
    try:
        puppet = await pu.Puppet.get_by_ktid(self.ktid)
        if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
            self.log.info("Automatically enabling custom puppet")
            await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
    except Exception:
        self.log.exception("Failed to automatically enable custom puppet")
    # TODO Check if things break when a live message comes in during syncing
    if self.config["bridge.sync_on_startup"] or not is_startup:
        sync_count = self.config["bridge.initial_chat_sync"]
    else:
        # Pass 0 rather than None: _sync_channels replaces None with the
        # configured default, which would make this branch sync anyway.
        sync_count = 0
    # TODO Don't auto-connect on startup if user's last state was disconnected
    await self.connect_and_sync(sync_count)
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
    """Map puppet MXIDs to the Matrix room of each of this user's portals.

    Only portals returned by ``Portal.get_all_by_receiver`` that already have
    a Matrix room are included; each value is a single-element list.
    """
    direct_chats = {}
    async for portal in po.Portal.get_all_by_receiver(self.ktid):
        if not portal.mxid:
            continue
        direct_chats[pu.Puppet.get_mxid_from_id(portal.ktid)] = [portal.mxid]
    return direct_chats
@async_time(METRIC_CONNECT_AND_SYNC)
async def connect_and_sync(self, sync_count: int | None) -> bool:
    """Connect the client, run the connect handler and sync channels.

    Returns True on success and False on any failure. Authentication errors
    are reported to the user and followed by a logout that keeps the
    KakaoTalk ID; other errors are logged and pushed as UNKNOWN_ERROR.
    """
    # TODO Look for a way to sync all channels without (re-)logging in
    try:
        login_result = await self.client.connect()
        await self.on_connect()
        await self._sync_channels(login_result, sync_count)
    except AuthenticationRequired as e:
        await self.send_bridge_notice(
            f"Got authentication error from KakaoTalk:\n\n> {e.message}\n\n",
            important=True,
            state_event=BridgeStateEvent.BAD_CREDENTIALS,
            error_code="kt-auth-error",
            error_message=str(e),
        )
        await self.logout(remove_ktid=False)
    except Exception as e:
        self.log.exception("Failed to connect and sync")
        await self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR, message=str(e))
    else:
        return True
    return False
async def _sync_channels(self, login_result: LoginResult, sync_count: int | None) -> None:
    """Sync up to ``sync_count`` channels from the login result.

    ``None`` selects the configured ``bridge.initial_chat_sync`` value, a
    falsy count skips syncing, and a negative count syncs every channel.
    Authentication errors abort the whole sync; any other per-channel error
    is logged and the loop continues. Direct chats are updated afterwards.
    """
    limit = self.config["bridge.initial_chat_sync"] if sync_count is None else sync_count
    if not limit:
        self.log.debug("Skipping channel syncing")
        return
    channels = login_result.channelList
    if not channels:
        self.log.debug("No channels to sync")
        return
    # TODO What about removed channels? Don't early-return then
    total = len(channels)
    if limit < 0:
        limit = total
    else:
        limit = min(limit, total)
    await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
    self.log.debug(f"Syncing {limit} of {total} channels...")
    for login_data in channels[:limit]:
        try:
            await self._sync_channel(login_data)
        except AuthenticationRequired:
            raise
        except Exception:
            self.log.exception(f"Failed to sync channel {login_data.channel.channelId}")
    await self.update_direct_chats()
async def _sync_channel(self, login_data: LoginDataItem) -> None:
# Sync a single channel from the login data: log its details at debug level,
# look up the matching portal, then create or update its Matrix room and
# backfill messages.
channel_data = login_data.channel
self.log.debug(f"Syncing channel {channel_data.channelId} (last updated at {login_data.lastUpdate})")
channel_info = channel_data.info
# The bare annotations in the branches below have no runtime effect; they
# only narrow the types for readers and type checkers.
if isinstance(channel_data, NormalChannelData):
channel_data: NormalChannelData
channel_info: NormalChannelInfo
self.log.debug(f"Join time: {channel_info.joinTime}")
elif isinstance(channel_data, OpenChannelData):
channel_data: OpenChannelData
self.log.debug(f"channel_data link ID: {channel_data.linkId}")
channel_info: OpenChannelInfo
self.log.debug(f"channel_info link ID: {channel_info.linkId}")
self.log.debug(f"openToken: {channel_info.openToken}")
self.log.debug(f"Is direct channel: {channel_info.directChannel}")
self.log.debug(f"Has OpenLink: {channel_info.openLink is not None}")
else:
# An unknown channel type is only logged; processing continues below.
self.log.error(f"Unexpected channel type: {type(channel_data)}")
channel_info: ChannelInfo
# Debug dump of the fields common to all channel types.
self.log.debug(f"channel_info channel ID: {channel_info.channelId}")
self.log.debug(f"Channel data/info IDs match: {channel_data.channelId == channel_info.channelId}")
self.log.debug(f"Channel type: {channel_info.type}")
self.log.debug(f"Active user count: {channel_info.activeUserCount}")
self.log.debug(f"New chat count: {channel_info.newChatCount}")
self.log.debug(f"New chat count invalid: {channel_info.newChatCountInvalid}")
self.log.debug(f"Last chat log ID: {channel_info.lastChatLogId}")
self.log.debug(f"Last seen log ID: {channel_info.lastSeenLogId}")
self.log.debug(f"Has last chat log: {channel_info.lastChatLog is not None}")
self.log.debug(f"metaMap: {channel_info.metaMap}")
self.log.debug(f"User count: {len(channel_info.displayUserList)}")
self.log.debug(f"Has push alert: {channel_info.pushAlert}")
for display_user_info in channel_info.displayUserList:
self.log.debug(f"Member: {display_user_info.nickname} - {display_user_info.profileURL} - {display_user_info.userId}")
# Look up the portal for this channel, with this user as the receiver.
portal = await po.Portal.get_by_ktid(
channel_info.channelId,
kt_receiver=self.ktid,
kt_type=channel_info.type
)
# Fetch portal-level info through the client, then attach the channel info
# that came with the login result.
portal_info = await self.client.get_portal_channel_info(portal.channel_props)
portal_info.channel_info = channel_info
if not portal.mxid:
await portal.create_matrix_room(self, portal_info)
else:
await portal.update_matrix_room(self, portal_info)
# NOTE(review): indentation is lost in this copy of the file, so it is unclear
# whether the backfill below runs only in the else branch (existing rooms) or
# for every channel - confirm against upstream before restructuring.
await portal.backfill(self, is_initial=False, channel_info=channel_info)
async def get_notice_room(self) -> RoomID:
    """Return the user's bridge notice room, creating it on first use.

    Creation happens under ``_notice_room_lock``. The room is created by the
    appservice bot as a direct chat with the user, is marked non-federated
    when ``bridge.federate_rooms`` is disabled, and is saved to the database.
    """
    if self.notice_room:
        return self.notice_room
    async with self._notice_room_lock:
        # If someone already created the room while this call was waiting,
        # don't make a new room
        if self.notice_room:
            return self.notice_room
        if self.config["bridge.federate_rooms"]:
            creation_content = {}
        else:
            creation_content = {"m.federate": False}
        self.notice_room = await self.az.intent.create_room(
            is_direct=True,
            invitees=[self.mxid],
            topic="KakaoTalk bridge notices",
            creation_content=creation_content,
        )
        await self.save()
    return self.notice_room
async def send_bridge_notice(
    self,
    text: str,
    edit: EventID | None = None,
    state_event: BridgeStateEvent | None = None,
    important: bool = False,
    error_code: str | None = None,
    error_message: str | None = None,
) -> EventID | None:
    """Optionally push a bridge state, then post ``text`` to the notice room.

    :param edit: Event ID of an earlier notice to edit instead of posting anew.
    :param state_event: Bridge state to push before sending the notice.
    :param important: Send as a regular text message rather than a notice.
    :param error_code: Error code attached to the pushed bridge state.
    :param error_message: State message used when ``error_code`` is given;
        otherwise ``text`` is used.
    :return: ``edit`` if given, else the ID of the sent event, or None when
        notices are disabled or sending failed.
    """
    if state_event:
        state_message = error_message if error_code else text
        await self.push_bridge_state(
            state_event,
            error=error_code,
            message=state_message,
        )
    if self.config["bridge.disable_bridge_notices"]:
        return None
    sent_event_id = None
    try:
        self.log.debug("Sending bridge notice: %s", text)
        msgtype = MessageType.TEXT if important else MessageType.NOTICE
        content = TextMessageEventContent(body=text, msgtype=msgtype)
        if edit:
            content.set_edit(edit)
        # This is locked to prevent notices going out in the wrong order
        async with self._notice_send_lock:
            room_id = await self.get_notice_room()
            sent_event_id = await self.az.intent.send_message(room_id, content)
    except Exception:
        self.log.warning("Failed to send bridge notice", exc_info=True)
    return edit or sent_event_id
async def fill_bridge_state(self, state: BridgeState) -> None:
    """Add the user's KakaoTalk ID and puppet name to a bridge state.

    The parent implementation runs first; the remote fields are only set
    when the user has a KakaoTalk ID.
    """
    await super().fill_bridge_state(state)
    if not self.ktid:
        return
    state.remote_id = str(self.ktid)
    own_puppet = await pu.Puppet.get_by_ktid(self.ktid)
    state.remote_name = own_puppet.name
async def get_bridge_states(self) -> list[BridgeState]:
    """Return the current bridge state of this user as a list.

    The list is empty when the user has no stored credentials. Otherwise it
    holds one state: CONNECTED while the user is connected, UNKNOWN_ERROR
    otherwise.
    """
    # Use the has_state property: nothing in this class defines a `state`
    # attribute, so reading self.state would raise AttributeError.
    if not self.has_state:
        return []
    state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
    if self.is_connected:
        state.state_event = BridgeStateEvent.CONNECTED
    # TODO
    #elif self._is_logged_in and self._is_reconnecting:
    #    state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
    return [state]
async def get_puppet(self) -> pu.Puppet | None:
    """Return the puppet for this user's own KakaoTalk ID, or None if there is no ID."""
    if self.ktid:
        return await pu.Puppet.get_by_ktid(self.ktid)
    return None
async def get_portal_with(self, puppet: pu.Puppet, create: bool = True) -> po.Portal | None:
    """Return the direct-chat portal between this user and ``puppet``.

    Not implemented yet: always returns None, and both parameters are
    currently unused.
    """
    # TODO Implement. Draft kept for reference (previously an unreachable
    # string literal placed after the return statement):
    #
    # if not self.ktid or not self.client:
    #     return None
    # return await po.Portal.get_by_ktid(
    #     await self.client.get_dm_channel_id_for(puppet.ktid),
    #     kt_receiver=self.ktid,
    #     create=create,
    #     kt_type=KnownChannelType.DirectChat if puppet.ktid != self.ktid else KnownChannelType.MemoChat
    # )
    return None
# region KakaoTalk event handling
async def on_connect(self) -> None:
    """Record a successful connection and announce it.

    Marks the user connected, updates the connected metric and always
    pushes a CONNECTED bridge state. A reconnect after a disconnection
    longer than ``bridge.resync_max_disconnected_time`` only logs how long
    it lasted; otherwise a "connected" notice is sent if such notices are
    enabled.
    """
    now = time.monotonic()
    # Read the previous state before the setter below overwrites it.
    last_state_change = self._connection_time
    resync_threshold = self.config["bridge.resync_max_disconnected_time"]
    is_first_connect = self.is_connected is None
    self.is_connected = True
    self._track_metric(METRIC_CONNECTED, True)

    long_outage = not is_first_connect and last_state_change + resync_threshold < now
    if long_outage:
        duration = int(now - last_state_change)
        self.log.debug(f"Disconnection lasted {duration} seconds")
    elif self.temp_disconnect_notices:
        await self.send_bridge_notice("Connected to KakaoTalk chats")
    await self.push_bridge_state(BridgeStateEvent.CONNECTED)
async def on_disconnect(self, res: KickoutRes | None) -> None:
    """Record a disconnection and, if a kickout reason is given, notify the user.

    :param res: Kickout response from KakaoTalk. When it is falsy, the user
        is only marked disconnected and no notice is sent.

    If the reason is that the account was deleted, the user is also logged out.
    """
    self.is_connected = False
    self._track_metric(METRIC_CONNECTED, False)
    if not res:
        return
    logout = False
    if res.reason == KnownKickoutType.LOGIN_ANOTHER:
        reason_str = "Logged in from another desktop client."
    elif res.reason == KnownKickoutType.CHANGE_SERVER:
        # TODO Reconnect automatically instead
        reason_str = "KakaoTalk backend server changed."
    elif res.reason == KnownKickoutType.ACCOUNT_DELETED:
        reason_str = "Your KakaoTalk account was deleted!"
        logout = True
    else:
        reason_str = f"Unknown reason ({res.reason})."
    if logout:
        reason_suffix = " You are now logged out."
        await self.logout()
    else:
        reason_suffix = " To reconnect, use the `sync` command."
        # TODO What bridge state to push?
    # Each suffix already starts with a space, so none is added here
    # (adding one produced a doubled space in the notice).
    await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str}{reason_suffix}")
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
    """Handle a successful login: store credentials, start the client, run post-login.

    The credential is copied onto the user and saved, a CONNECTING bridge
    state is pushed, and a new client is started. The post-login actions
    are then scheduled as a background task. Exceptions raised by
    ``client.start()`` propagate to the caller.
    """
    self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
    self.oauth_credential = oauth_credential
    await self.push_bridge_state(BridgeStateEvent.CONNECTING)
    self.client = Client(self, log=self.log.getChild("ktclient"))
    await self.save()
    self._is_logged_in = True
    # TODO Retry network connection failures here, or in the client (like token refreshes are)?
    #      Should also catch unlikely authentication errors
    self._logged_in_info = await self.client.start()
    self._logged_in_info_time = time.monotonic()
    # Hold a strong reference to the task: the event loop only keeps weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    self._post_login_task = asyncio.create_task(self.post_login(is_startup=True))
@async_time(METRIC_MESSAGE)
async def on_message(self, evt: Chatlog, channel_id: Long, channel_type: ChannelType) -> None:
    """Forward an incoming KakaoTalk chat log to the portal of its channel.

    The portal is looked up with this user as receiver and the sender's
    puppet is fetched. A portal resync is scheduled when the puppet has no
    name yet.
    """
    target_portal = await po.Portal.get_by_ktid(
        channel_id, kt_receiver=self.ktid, kt_type=channel_type
    )
    sender = await pu.Puppet.get_by_ktid(evt.sender.userId)
    # NOTE(review): presumably waits while a backfill covering this log ID
    # is in progress - confirm against the backfill lock implementation.
    await target_portal.backfill_lock.wait(evt.logId)
    if not sender.name:
        target_portal.schedule_resync(self, sender)
    await target_portal.handle_remote_message(self, sender, evt)
# TODO Many more handlers
# endregion