Split state into "logged in" and "connected"

Logged in = have (supposedly valid) session tokens
Connected = have an active TalkClient session
This commit is contained in:
Andrew Ferrazzutti 2022-04-01 05:07:05 -04:00
parent 15415a5eec
commit 59ea91519a
9 changed files with 303 additions and 163 deletions

View File

@ -38,7 +38,22 @@ async def set_notice_room(evt: CommandEvent) -> None:
needs_auth=True, needs_auth=True,
management_only=True, management_only=True,
help_section=SECTION_CONNECTION, help_section=SECTION_CONNECTION,
help_text="Check if you're logged into KakaoTalk", help_text="Disconnect from KakaoTalk chats, but remain logged into profile-management commands",
)
async def disconnect(evt: CommandEvent) -> None:
if not evt.sender.is_connected:
await evt.reply("You are already disconnected from KakaoTalk chats")
return
await evt.mark_read()
await evt.sender.client.disconnect()
await evt.reply("Successfully disconnected from KakaoTalk chats. To reconnect, use the `sync` command.")
@command_handler(
needs_auth=True,
management_only=True,
help_section=SECTION_CONNECTION,
help_text="Check if you're logged into KakaoTalk & connected to chats",
) )
async def ping(evt: CommandEvent) -> None: async def ping(evt: CommandEvent) -> None:
if not await evt.sender.is_logged_in(): if not await evt.sender.is_logged_in():
@ -47,7 +62,11 @@ async def ping(evt: CommandEvent) -> None:
await evt.mark_read() await evt.mark_read()
try: try:
own_info = await evt.sender.get_own_info() own_info = await evt.sender.get_own_info()
await evt.reply(f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})") await evt.reply(
f"You're logged in as {own_info.nickname} (user ID {evt.sender.ktid})."
"\n\n"
f"You are {'connected to' if evt.sender.is_connected else '**disconnected** from'} KakaoTalk chats.\n\n"
)
except CommandException as e: except CommandException as e:
await evt.reply(f"Error from KakaoTalk: {e}") await evt.reply(f"Error from KakaoTalk: {e}")
@ -56,8 +75,8 @@ async def ping(evt: CommandEvent) -> None:
needs_auth=True, needs_auth=True,
management_only=True, management_only=True,
help_section=SECTION_CONNECTION, help_section=SECTION_CONNECTION,
help_text="Resync chats", help_text="(Re)connect to KakaoTalk chats & sync any missed chat updates",
help_args="[count]", help_args="[number_of_channels_to_sync]",
) )
async def sync(evt: CommandEvent) -> None: async def sync(evt: CommandEvent) -> None:
try: try:
@ -65,7 +84,7 @@ async def sync(evt: CommandEvent) -> None:
except IndexError: except IndexError:
sync_count = None sync_count = None
except ValueError: except ValueError:
await evt.reply("**Usage:** `$cmdprefix+sp logout [--reset-device]`") await evt.reply("**Usage:** `sync [number_of_channels_to_sync]`")
return return
await evt.mark_read() await evt.mark_read()

View File

@ -213,23 +213,15 @@ bridge:
# The number of seconds that a disconnection can last without triggering an automatic re-sync # The number of seconds that a disconnection can last without triggering an automatic re-sync
# and missed message backfilling when reconnecting. # and missed message backfilling when reconnecting.
# Set to 0 to always re-sync, or -1 to never re-sync automatically. # Set to 0 to always re-sync, or -1 to never re-sync automatically.
#resync_max_disconnected_time: 5 resync_max_disconnected_time: 5
# Should the bridge do a resync on startup? # Should the bridge do a resync on startup?
sync_on_startup: true sync_on_startup: true
# Whether or not temporary disconnections should send notices to the notice room. # Whether or not temporary disconnections should send notices to the notice room.
# If this is false, disconnections will never send messages and connections will only send # If this is false, disconnections will never send messages and connections will only send
# messages if it was disconnected for more than resync_max_disconnected_time seconds. # messages if it was disconnected for more than resync_max_disconnected_time seconds.
# TODO Probably don't need this
temporary_disconnect_notices: true temporary_disconnect_notices: true
# Disable bridge notices entirely # Disable bridge notices entirely
disable_bridge_notices: false disable_bridge_notices: false
on_reconnection_fail:
# Whether or not the bridge should try to "refresh" the connection if a normal reconnection
# attempt fails.
refresh: false
# Seconds to wait before attempting to refresh the connection, set a list of two items to
# to randomize the interval (min, max).
wait_for: 0
# Set this to true to tell the bridge to re-send m.bridge events to all rooms on the next run. # Set this to true to tell the bridge to re-send m.bridge events to all rooms on the next run.
# This field will automatically be changed back to false after it, # This field will automatically be changed back to false after it,
# except if the config file is not writable. # except if the config file is not writable.

View File

@ -42,6 +42,7 @@ from ..types.bson import Long
from ..types.client.client_session import LoginResult from ..types.client.client_session import LoginResult
from ..types.chat import Chatlog, KnownChatType from ..types.chat import Chatlog, KnownChatType
from ..types.oauth import OAuthCredential, OAuthInfo from ..types.oauth import OAuthCredential, OAuthInfo
from ..types.packet.chat.kickout import KnownKickoutType, KickoutRes
from ..types.request import ( from ..types.request import (
deserialize_result, deserialize_result,
ResultType, ResultType,
@ -148,11 +149,12 @@ class Client:
def _oauth_credential(self) -> JSON: def _oauth_credential(self) -> JSON:
return self.user.oauth_credential.serialize() return self.user.oauth_credential.serialize()
def _get_user_data(self) -> JSON: @property
return dict( def _user_data(self) -> JSON:
mxid=self.user.mxid, return {
oauth_credential=self._oauth_credential "mxid": self.user.mxid,
) "oauth_credential": self._oauth_credential,
}
# region HTTP # region HTTP
@ -179,19 +181,28 @@ class Client:
# region post-token commands # region post-token commands
async def renew(self) -> OAuthInfo: async def start(self) -> ProfileStruct:
"""Get a new set of tokens from a refresh token.""" """
return await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential) Initialize user-specific bridging & state by providing a token obtained from a prior login.
Receive the user's profile info in response.
"""
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "start")
return profile_req_struct.profile
async def stop(self) -> None:
"""Immediately stop bridging this user."""
self._stop_listen()
await self._rpc_client.request("stop", mxid=self.user.mxid)
async def renew_and_save(self) -> None: async def renew_and_save(self) -> None:
"""Renew and save the user's session tokens.""" """Renew and save the user's session tokens."""
oauth_info = await self.renew() oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)
self.user.oauth_credential = oauth_info.credential self.user.oauth_credential = oauth_info.credential
await self.user.save() await self.user.save()
async def connect(self) -> LoginResult: async def connect(self) -> LoginResult:
""" """
Start a new session by providing a token obtained from a prior login. Start a new talk session by providing a token obtained from a prior login.
Receive a snapshot of account state in response. Receive a snapshot of account state in response.
""" """
login_result = await self._api_user_request_result(LoginResult, "connect") login_result = await self._api_user_request_result(LoginResult, "connect")
@ -200,12 +211,12 @@ class Client:
self._start_listen() self._start_listen()
return login_result return login_result
async def disconnect(self) -> bool: async def disconnect(self) -> None:
connection_existed = await self._rpc_client.request("disconnect", mxid=self.user.mxid) """Disconnect from the talk session, but remain logged in."""
self._stop_listen() await self._rpc_client.request("disconnect", mxid=self.user.mxid)
return connection_existed await self._on_disconnect(None)
async def fetch_logged_in_user(self, post_login: bool = False) -> ProfileStruct: async def get_own_profile(self) -> ProfileStruct:
profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile") profile_req_struct = await self._api_user_request_result(ProfileReqStruct, "get_own_profile")
return profile_req_struct.profile return profile_req_struct.profile
@ -289,7 +300,7 @@ class Client:
renewed = False renewed = False
while True: while True:
try: try:
return await self._api_request_result(result_type, command, **self._get_user_data(), **data) return await self._api_request_result(result_type, command, **self._user_data, **data)
except InvalidAccessToken: except InvalidAccessToken:
if renewed: if renewed:
raise raise
@ -300,7 +311,7 @@ class Client:
renewed = False renewed = False
while True: while True:
try: try:
return await self._api_request_void(command, **self._get_user_data(), **data) return await self._api_request_void(command, **self._user_data, **data)
except InvalidAccessToken: except InvalidAccessToken:
if renewed: if renewed:
raise raise
@ -316,7 +327,7 @@ class Client:
await self.user.on_message( await self.user.on_message(
Chatlog.deserialize(data["chatlog"]), Chatlog.deserialize(data["chatlog"]),
Long.deserialize(data["channelId"]), Long.deserialize(data["channelId"]),
data["channelType"] data["channelType"],
) )
""" TODO """ TODO
@ -324,16 +335,36 @@ class Client:
await self.user.on_receipt(Receipt.deserialize(data["receipt"])) await self.user.on_receipt(Receipt.deserialize(data["receipt"]))
""" """
async def _on_listen_disconnect(self, data: dict[str, JSON]) -> None:
try:
res = KickoutRes.deserialize(data)
except Exception:
self.log.exception("Invalid kickout reason, defaulting to None")
res = None
await self._on_disconnect(res)
async def _on_switch_server(self) -> None:
# TODO Reconnect automatically instead
await self._on_disconnect(KickoutRes(KnownKickoutType.CHANGE_SERVER))
async def _on_disconnect(self, res: KickoutRes | None) -> None:
self._stop_listen()
await self.user.on_disconnect(res)
def _start_listen(self) -> None: def _start_listen(self) -> None:
# TODO Automate this somehow, like with a fancy enum # TODO Automate this somehow, like with a fancy enum
self._rpc_client.set_event_handlers(self._get_user_cmd("message"), [self._on_message]) self._rpc_client.set_event_handlers(self._get_user_cmd("chat"), [self._on_message])
# TODO many more listeners # TODO many more listeners
self._rpc_client.set_event_handlers(self._get_user_cmd("disconnected"), [self._on_listen_disconnect])
self._rpc_client.set_event_handlers(self._get_user_cmd("switch_server"), [self._on_switch_server])
def _stop_listen(self) -> None: def _stop_listen(self) -> None:
# TODO Automate this somehow, like with a fancy enum # TODO Automate this somehow, like with a fancy enum
self._rpc_client.set_event_handlers(self._get_user_cmd("message"), []) self._rpc_client.set_event_handlers(self._get_user_cmd("chat"), [])
# TODO many more listeners # TODO many more listeners
self._rpc_client.set_event_handlers(self._get_user_cmd("disconnected"), [])
self._rpc_client.set_event_handlers(self._get_user_cmd("switch_server"), [])
def _get_user_cmd(self, command) -> str: def _get_user_cmd(self, command) -> str:
return f"{command}:{self.user.mxid}" return f"{command}:{self.user.mxid}"

View File

@ -0,0 +1,41 @@
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union
from enum import IntEnum
from attr import dataclass
from mautrix.types import SerializableAttrs
@dataclass
class KnownKickoutType(IntEnum):
CHANGE_SERVER = -2
LOGIN_ANOTHER = 0
ACCOUNT_DELETED = 1
KickoutType = Union[KnownKickoutType, int]
@dataclass
class KickoutRes(SerializableAttrs):
reason: KickoutType
__all__ = [
"KnownKickoutType",
"KickoutType",
"KickoutRes",
]

View File

@ -53,6 +53,10 @@ class MatrixHandler(BaseMatrixHandler):
self.user_id_suffix = f"{suffix}:{homeserver}" self.user_id_suffix = f"{suffix}:{homeserver}"
super().__init__(bridge=bridge) super().__init__(bridge=bridge)
@staticmethod
async def allow_bridging_message(user: u.User, portal: po.Portal) -> bool:
return user.is_connected or (user.relay_whitelisted and portal.has_relay)
async def send_welcome_message(self, room_id: RoomID, inviter: u.User) -> None: async def send_welcome_message(self, room_id: RoomID, inviter: u.User) -> None:
await super().send_welcome_message(room_id, inviter) await super().send_welcome_message(room_id, inviter)
if not inviter.notice_room: if not inviter.notice_room:

View File

@ -1260,7 +1260,10 @@ class Portal(DBPortal, BasePortal):
# TODO Save kt_sender in DB instead? Depends on if DM channels are shared... # TODO Save kt_sender in DB instead? Depends on if DM channels are shared...
user = await u.User.get_by_ktid(self.kt_receiver) user = await u.User.get_by_ktid(self.kt_receiver)
assert user, f"Found no user for this portal's receiver of {self.kt_receiver}" assert user, f"Found no user for this portal's receiver of {self.kt_receiver}"
await self._update_participants(user) if user.is_connected:
await self._update_participants(user)
else:
self.log.debug(f"Not setting _main_intent of new direct chat for disconnected user {user.ktid}")
else: else:
self.log.debug("Not setting _main_intent of new direct chat until after checking participant list") self.log.debug("Not setting _main_intent of new direct chat until after checking participant list")

View File

@ -45,17 +45,9 @@ from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import LoginDataItem, LoginResult from .kt.types.client.client_session import LoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
from .kt.types.packet.chat.kickout import KnownKickoutType, KickoutRes
METRIC_CONNECT_AND_SYNC = Summary("bridge_sync_channels", "calls to connect_and_sync") METRIC_CONNECT_AND_SYNC = Summary("bridge_connect_and_sync", "calls to connect_and_sync")
METRIC_MEMBERS_ADDED = Summary("bridge_on_members_added", "calls to on_members_added")
METRIC_MEMBER_REMOVED = Summary("bridge_on_member_removed", "calls to on_member_removed")
METRIC_TYPING = Summary("bridge_on_typing", "calls to on_typing")
METRIC_PRESENCE = Summary("bridge_on_presence", "calls to on_presence")
METRIC_REACTION = Summary("bridge_on_reaction", "calls to on_reaction")
METRIC_MESSAGE_UNSENT = Summary("bridge_on_unsent", "calls to on_unsent")
METRIC_MESSAGE_SEEN = Summary("bridge_on_message_seen", "calls to on_message_seen")
METRIC_TITLE_CHANGE = Summary("bridge_on_title_change", "calls to on_title_change")
METRIC_AVATAR_CHANGE = Summary("bridge_on_avatar_change", "calls to on_avatar_change")
METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message") METRIC_MESSAGE = Summary("bridge_on_message", "calls to on_message")
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge") METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk") METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to KakaoTalk")
@ -66,15 +58,16 @@ if TYPE_CHECKING:
BridgeState.human_readable_errors.update( BridgeState.human_readable_errors.update(
{ {
"kt-reconnection-error": "Failed to reconnect to KakaoTalk", "kt-reconnection-error": "Failed to reconnect to KakaoTalk",
"kt-connection-error": "KakaoTalk disconnected unexpectedly", # TODO Use for errors in Node backend that cause session to be lost
#"kt-connection-error": "KakaoTalk disconnected unexpectedly",
"kt-auth-error": "Authentication error from KakaoTalk: {message}", "kt-auth-error": "Authentication error from KakaoTalk: {message}",
"kt-disconnected": None,
"logged-out": "You're not logged into KakaoTalk", "logged-out": "You're not logged into KakaoTalk",
} }
) )
class User(DBUser, BaseUser): class User(DBUser, BaseUser):
temp_disconnect_notices: bool = True
shutdown: bool = False shutdown: bool = False
config: Config config: Config
@ -138,7 +131,7 @@ class User(DBUser, BaseUser):
cls.bridge = bridge cls.bridge = bridge
cls.config = bridge.config cls.config = bridge.config
cls.az = bridge.az cls.az = bridge.az
cls.loop = bridge.loop cls.temp_disconnect_notices = bridge.config["bridge.temporary_disconnect_notices"]
return (user.reload_session(is_startup=True) async for user in cls.all_logged_in()) return (user.reload_session(is_startup=True) async for user in cls.all_logged_in())
@property @property
@ -245,17 +238,17 @@ class User(DBUser, BaseUser):
self.access_token = oauth_credential.accessToken self.access_token = oauth_credential.accessToken
self.refresh_token = oauth_credential.refreshToken self.refresh_token = oauth_credential.refreshToken
if self.uuid != oauth_credential.deviceUUID: if self.uuid != oauth_credential.deviceUUID:
self.log.warn(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}") self.log.warning(f"UUID mismatch: expected {self.uuid}, got {oauth_credential.deviceUUID}")
self.uuid = oauth_credential.deviceUUID self.uuid = oauth_credential.deviceUUID
async def get_own_info(self) -> ProfileStruct: async def get_own_info(self) -> ProfileStruct:
if not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic(): if not self._logged_in_info or self._logged_in_info_time + 60 * 60 < time.monotonic():
self._logged_in_info = await self.client.fetch_logged_in_user() self._logged_in_info = await self.client.get_own_profile()
self._logged_in_info_time = time.monotonic() self._logged_in_info_time = time.monotonic()
return self._logged_in_info return self._logged_in_info
async def _load_session(self, is_startup: bool) -> bool: async def _load_session(self, is_startup: bool) -> bool:
if self._is_logged_in and not is_startup: if self._is_logged_in and is_startup:
return True return True
elif not self.has_state: elif not self.has_state:
# If we have a user in the DB with no state, we can assume # If we have a user in the DB with no state, we can assume
@ -266,18 +259,17 @@ class User(DBUser, BaseUser):
) )
return False return False
client = Client(self, log=self.log.getChild("ktclient")) client = Client(self, log=self.log.getChild("ktclient"))
user_info = await self.fetch_logged_in_user(client) user_info = await client.start()
if user_info: # NOTE On failure, client.start throws instead of returning False
self.log.info("Loaded session successfully") self.log.info("Loaded session successfully")
self.client = client self.client = client
self._logged_in_info = user_info self._logged_in_info = user_info
self._logged_in_info_time = time.monotonic() self._logged_in_info_time = time.monotonic()
self._track_metric(METRIC_LOGGED_IN, True) self._track_metric(METRIC_LOGGED_IN, True)
self._is_logged_in = True self._is_logged_in = True
self.is_connected = None self.is_connected = None
asyncio.create_task(self.post_login(is_startup=is_startup)) asyncio.create_task(self.post_login(is_startup=is_startup))
return True return True
return False
async def _send_reset_notice(self, e: AuthenticationRequired, edit: EventID | None = None) -> None: async def _send_reset_notice(self, e: AuthenticationRequired, edit: EventID | None = None) -> None:
await self.send_bridge_notice( await self.send_bridge_notice(
@ -293,22 +285,6 @@ class User(DBUser, BaseUser):
) )
await self.logout(remove_ktid=False) await self.logout(remove_ktid=False)
async def fetch_logged_in_user(
self, client: Client | None = None, action: str = "restore session"
) -> ProfileStruct:
if not client:
client = self.client
# TODO Retry network connection failures here, or in the client (like token refreshes are)?
try:
return await client.fetch_logged_in_user()
except AuthenticationRequired as e:
if action != "restore session":
await self._send_reset_notice(e)
raise
except Exception:
self.log.exception(f"Failed to {action}")
raise
async def is_logged_in(self, _override: bool = False) -> bool: async def is_logged_in(self, _override: bool = False) -> bool:
if not self.has_state or not self.client: if not self.has_state or not self.client:
return False return False
@ -360,11 +336,7 @@ class User(DBUser, BaseUser):
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None: async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
if self.client: if self.client:
# TODO Look for a logout API call # TODO Look for a logout API call
was_connected = await self.client.disconnect() await self.client.stop()
if was_connected != self._is_connected:
self.log.warn(
f"Node backend was{' not' if not was_connected else ''} connected, "
f"but we thought it was{' not' if not self._is_connected else ''}")
if remove_ktid: if remove_ktid:
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT) await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
self._track_metric(METRIC_LOGGED_IN, False) self._track_metric(METRIC_LOGGED_IN, False)
@ -403,6 +375,7 @@ class User(DBUser, BaseUser):
sync_count = self.config["bridge.initial_chat_sync"] sync_count = self.config["bridge.initial_chat_sync"]
else: else:
sync_count = None sync_count = None
# TODO Don't auto-connect on startup if user's last state was disconnected
await self.connect_and_sync(sync_count) await self.connect_and_sync(sync_count)
async def get_direct_chats(self) -> dict[UserID, list[RoomID]]: async def get_direct_chats(self) -> dict[UserID, list[RoomID]]:
@ -417,7 +390,7 @@ class User(DBUser, BaseUser):
# TODO Look for a way to sync all channels without (re-)logging in # TODO Look for a way to sync all channels without (re-)logging in
try: try:
login_result = await self.client.connect() login_result = await self.client.connect()
await self.push_bridge_state(BridgeStateEvent.CONNECTED) await self.on_connect()
await self._sync_channels(login_result, sync_count) await self._sync_channels(login_result, sync_count)
return True return True
except AuthenticationRequired as e: except AuthenticationRequired as e:
@ -572,6 +545,9 @@ class User(DBUser, BaseUser):
state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR) state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
if self.is_connected: if self.is_connected:
state.state_event = BridgeStateEvent.CONNECTED state.state_event = BridgeStateEvent.CONNECTED
# TODO
#elif self._is_logged_in and self._is_reconnecting:
# state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
return [state] return [state]
async def get_puppet(self) -> pu.Puppet | None: async def get_puppet(self) -> pu.Puppet | None:
@ -581,6 +557,43 @@ class User(DBUser, BaseUser):
# region KakaoTalk event handling # region KakaoTalk event handling
async def on_connect(self) -> None:
now = time.monotonic()
disconnected_at = self._connection_time
max_delay = self.config["bridge.resync_max_disconnected_time"]
first_connect = self.is_connected is None
self.is_connected = True
self._track_metric(METRIC_CONNECTED, True)
if not first_connect and disconnected_at + max_delay < now:
duration = int(now - disconnected_at)
self.log.debug(f"Disconnection lasted {duration} seconds")
elif self.temp_disconnect_notices:
await self.send_bridge_notice("Connected to KakaoTalk chats")
await self.push_bridge_state(BridgeStateEvent.CONNECTED)
async def on_disconnect(self, res: KickoutRes | None) -> None:
self.is_connected = False
self._track_metric(METRIC_CONNECTED, False)
if res:
logout = False
if res.reason == KnownKickoutType.LOGIN_ANOTHER:
reason_str = "Logged in from another desktop client."
elif res.reason == KnownKickoutType.CHANGE_SERVER:
# TODO Reconnect automatically instead
reason_str = "KakaoTalk backend server changed."
elif res.reason == KnownKickoutType.ACCOUNT_DELETED:
reason_str = "Your KakaoTalk account was deleted!"
logout = True
else:
reason_str = f"Unknown reason ({res.reason})."
if not logout:
reason_suffix = " To reconnect, use the `sync` command."
# TODO What bridge state to push?
else:
reason_suffix = " You are now logged out."
await self.logout()
await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}")
async def on_logged_in(self, oauth_credential: OAuthCredential) -> None: async def on_logged_in(self, oauth_credential: OAuthCredential) -> None:
self.log.debug(f"Successfully logged in as {oauth_credential.userId}") self.log.debug(f"Successfully logged in as {oauth_credential.userId}")
self.oauth_credential = oauth_credential self.oauth_credential = oauth_credential
@ -588,11 +601,10 @@ class User(DBUser, BaseUser):
self.client = Client(self, log=self.log.getChild("ktclient")) self.client = Client(self, log=self.log.getChild("ktclient"))
await self.save() await self.save()
self._is_logged_in = True self._is_logged_in = True
try: # TODO Retry network connection failures here, or in the client (like token refreshes are)?
self._logged_in_info = await self.client.fetch_logged_in_user(post_login=True) # Should also catch unlikely authentication errors
self._logged_in_info_time = time.monotonic() self._logged_in_info = await self.client.start()
except Exception: self._logged_in_info_time = time.monotonic()
self.log.exception("Failed to fetch post-login info")
asyncio.create_task(self.post_login(is_startup=True)) asyncio.create_task(self.post_login(is_startup=True))
@async_time(METRIC_MESSAGE) @async_time(METRIC_MESSAGE)

View File

@ -63,30 +63,40 @@ class UserClient {
/** /**
* DO NOT CONSTRUCT DIRECTLY. Callers should use {@link UserClient#create} instead. * DO NOT CONSTRUCT DIRECTLY. Callers should use {@link UserClient#create} instead.
* @param {string} mxid * @param {string} mxid
* @param {OAuthCredential} credential * @param {PeerClient} peerClient TODO Make RPC user-specific instead of needing this
*/ */
constructor(mxid, credential) { constructor(mxid, peerClient) {
if (!UserClient.#initializing) { if (!UserClient.#initializing) {
throw new Error("Private constructor") throw new Error("Private constructor")
} }
UserClient.#initializing = false UserClient.#initializing = false
this.mxid = mxid this.mxid = mxid
this.credential = credential this.peerClient = peerClient
} }
/** /**
* @param {string} mxid The ID of the associated Matrix user * @param {string} mxid The ID of the associated Matrix user
* @param {OAuthCredential} credential The tokens that API calls may use * @param {OAuthCredential} credential The token to log in with, obtained from prior login
* @param {PeerClient} peerClient What handles RPC
*/ */
static async create(mxid, credential) { static async create(mxid, credential, peerClient) {
this.#initializing = true this.#initializing = true
const userClient = new UserClient(mxid, credential) const userClient = new UserClient(mxid, peerClient)
userClient.#serviceClient = await ServiceApiClient.create(credential) userClient.#serviceClient = await ServiceApiClient.create(credential)
return userClient return userClient
} }
log(...text) {
console.log(`[API/${this.mxid}]`, ...text)
}
error(...text) {
console.error(`[API/${this.mxid}]`, ...text)
}
/** /**
* @param {Object} channel_props * @param {Object} channel_props
* @param {Long} channel_props.id * @param {Long} channel_props.id
@ -111,16 +121,68 @@ class UserClient {
} }
} }
close() { /**
this.#talkClient.close() * @param {OAuthCredential} credential The token to log in with, obtained from prior login
*/
async connect(credential) {
// TODO Don't re-login if possible. But must still return a LoginResult!
this.disconnect()
const res = await this.#talkClient.login(credential)
if (!res.success) return res
this.#talkClient.on("chat", (data, channel) => {
this.log(`Received chat message ${data.chat.logId} in channel ${channel.channelId}`)
return this.write("chat", {
//is_sequential: true, // TODO Make sequential per user & channel (if it isn't already)
chatlog: data.chat,
channelId: channel.channelId,
channelType: channel.info.type,
})
})
/* TODO Many more listeners
this.#talkClient.on("chat_read", (chat, channel, reader) => {
this.log(`chat_read in channel ${channel.channelId}`)
//chat.logId
})
*/
this.#talkClient.on("disconnected", (reason) => {
this.log(`Disconnected (reason=${reason})`)
this.disconnect()
return this.write("disconnected", {
reason: reason,
})
})
this.#talkClient.on("switch_server", () => {
this.log(`Server switch requested`)
return this.write("switch_server", {
is_sequential: true,
})
})
return res
}
disconnect() {
if (this.#talkClient.logon) {
this.#talkClient.close()
}
} }
/** /**
* TODO Maybe use a "write" method instead * Send a user-specific command with (optional) data to the socket.
* @param {string} command *
* @param {string} command - The data to write.
* @param {?object} data - The data to write.
*/ */
getCmd(command) { write(command, data) {
return `${command}:${this.mxid}` return this.peerClient.write({
id: --this.peerClient.notificationID,
command: `${command}:${this.mxid}`,
...data
})
} }
} }
@ -183,7 +245,7 @@ export default class PeerClient {
this.stopped = true this.stopped = true
this.#closeUsers() this.#closeUsers()
try { try {
await this.#write({ id: --this.notificationID, command: "quit", error }) await this.write({ id: --this.notificationID, command: "quit", error })
await promisify(cb => this.socket.end(cb)) await promisify(cb => this.socket.end(cb))
} catch (err) { } catch (err) {
this.error("Failed to end connection:", err) this.error("Failed to end connection:", err)
@ -203,7 +265,7 @@ export default class PeerClient {
#closeUsers() { #closeUsers() {
this.log("Closing all API clients for", this.peerID) this.log("Closing all API clients for", this.peerID)
for (const userClient of this.userClients.values()) { for (const userClient of this.userClients.values()) {
userClient.close() userClient.disconnect()
} }
this.userClients.clear() this.userClients.clear()
} }
@ -214,7 +276,7 @@ export default class PeerClient {
* @param {object} data - The data to write. * @param {object} data - The data to write.
* @returns {Promise<void>} * @returns {Promise<void>}
*/ */
#write(data) { write(data) {
return promisify(cb => this.socket.write(JSON.stringify(data, this.#writeReplacer) + "\n", cb)) return promisify(cb => this.socket.write(JSON.stringify(data, this.#writeReplacer) + "\n", cb))
} }
@ -231,7 +293,7 @@ export default class PeerClient {
} }
/** /**
* Log in. If this fails due to not having a device, also request a device passcode. * Obtain login tokens. If this fails due to not having a device, also request a device passcode.
* @param {Object} req * @param {Object} req
* @param {string} req.uuid * @param {string} req.uuid
* @param {Object} req.form * @param {Object} req.form
@ -284,22 +346,12 @@ export default class PeerClient {
return this.userClients.get(mxid) return this.userClients.get(mxid)
} }
/**
* Get the service client for the specified user ID, or create
* and return a new service client if no user ID is provided.
* @param {?string} mxid
* @param {?OAuthCredential} oauth_credential
*/
async #getServiceClient(mxid, oauth_credential) {
return this.#tryGetUser(mxid)?.serviceClient ||
await ServiceApiClient.create(oauth_credential)
}
/** /**
* @param {Object} req * @param {Object} req
* @param {OAuthCredential} req.oauth_credential * @param {OAuthCredential} req.oauth_credential
*/ */
handleRenew = async (req) => { handleRenew = async (req) => {
// TODO Cache per user? Reset API client objects?
const oAuthClient = await OAuthApiClient.create() const oAuthClient = await OAuthApiClient.create()
return await oAuthClient.renew(req.oauth_credential) return await oAuthClient.renew(req.oauth_credential)
} }
@ -309,72 +361,57 @@ export default class PeerClient {
* @param {string} req.mxid * @param {string} req.mxid
* @param {OAuthCredential} req.oauth_credential * @param {OAuthCredential} req.oauth_credential
*/ */
handleConnect = async (req) => { userStart = async (req) => {
// TODO Don't re-login if possible. But must still return a LoginResult! const userClient = this.#tryGetUser(req.mxid) || await UserClient.create(req.mxid, req.oauth_credential, this)
this.handleDisconnect(req) // TODO Should call requestMore/LessSettings instead
const res = await userClient.serviceClient.requestMyProfile()
const userClient = await UserClient.create(req.mxid, req.oauth_credential) if (res.success) {
const res = await userClient.talkClient.login(req.oauth_credential) this.userClients.set(req.mxid, userClient)
if (!res.success) return res }
this.userClients.set(req.mxid, userClient)
userClient.talkClient.on("chat", (data, channel) => {
this.log(`Received message ${data.chat.logId} in channel ${channel.channelId}`)
return this.#write({
id: --this.notificationID,
command: userClient.getCmd("message"),
//is_sequential: true, // TODO Make sequential per user & channel (if it isn't already)
chatlog: data.chat,
channelId: channel.channelId,
channelType: channel.info.type,
})
})
/* TODO Many more listeners
userClient.talkClient.on("chat_read", (chat, channel, reader) => {
this.log(`chat_read in channel ${channel.channelId}`)
//chat.logId
})
*/
return res return res
} }
/**
* @param {Object} req
* @param {string} req.mxid
*/
userStop = async (req) => {
this.handleDisconnect(req)
this.userClients.delete(req.mxid)
}
/**
* @param {Object} req
* @param {string} req.mxid
* @param {OAuthCredential} req.oauth_credential
*/
handleConnect = async (req) => {
return await this.#getUser(req.mxid).connect(req.oauth_credential)
}
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid * @param {string} req.mxid
*/ */
handleDisconnect = (req) => { handleDisconnect = (req) => {
const userClient = this.#tryGetUser(req.mxid) this.#tryGetUser(req.mxid)?.disconnect()
if (!!userClient) {
userClient.close()
this.userClients.delete(req.mxid)
return true
} else {
return false
}
} }
/** /**
* @param {Object} req * @param {Object} req
* @param {?string} req.mxid * @param {?string} req.mxid
* @param {?OAuthCredential} req.oauth_credential
*/ */
getOwnProfile = async (req) => { getOwnProfile = async (req) => {
const serviceClient = await this.#getServiceClient(req.mxid, req.oauth_credential) return await this.#getUser(req.mxid).serviceClient.requestMyProfile()
return await serviceClient.requestMyProfile()
} }
/** /**
* @param {Object} req * @param {Object} req
* @param {?string} req.mxid * @param {?string} req.mxid
* @param {?OAuthCredential} req.oauth_credential
* @param {Long} req.user_id * @param {Long} req.user_id
*/ */
getProfile = async (req) => { getProfile = async (req) => {
const serviceClient = await this.#getServiceClient(req.mxid, req.oauth_credential) return await this.#getUser(req.mxid).serviceClient.requestProfile(req.user_id)
return await serviceClient.requestProfile(req.user_id)
} }
/** /**
@ -431,8 +468,7 @@ export default class PeerClient {
* @param {?OAuthCredential} req.oauth_credential * @param {?OAuthCredential} req.oauth_credential
*/ */
listFriends = async (req) => { listFriends = async (req) => {
const serviceClient = await this.#getServiceClient(req.mxid, req.oauth_credential) return await this.#getUser(req.mxid).serviceClient.requestFriendList()
return await serviceClient.requestFriendList()
} }
/** /**
@ -544,6 +580,8 @@ export default class PeerClient {
register_device: this.registerDevice, register_device: this.registerDevice,
login: this.handleLogin, login: this.handleLogin,
renew: this.handleRenew, renew: this.handleRenew,
start: this.userStart,
stop: this.userStop,
connect: this.handleConnect, connect: this.handleConnect,
disconnect: this.handleDisconnect, disconnect: this.handleDisconnect,
get_own_profile: this.getOwnProfile, get_own_profile: this.getOwnProfile,
@ -575,7 +613,7 @@ export default class PeerClient {
// TODO Check if session is broken. If it is, close the PeerClient // TODO Check if session is broken. If it is, close the PeerClient
} }
} }
await this.#write(resp) await this.write(resp)
} }
#writeReplacer = function(key, value) { #writeReplacer = function(key, value) {