Recreate ServiceApiClient on token refresh

This commit is contained in:
Andrew Ferrazzutti 2022-04-23 13:51:39 -04:00
parent eebcef6b08
commit 51d02d3c34
2 changed files with 73 additions and 32 deletions

View File

@ -172,13 +172,6 @@ class Client:
def _oauth_credential(self) -> JSON:
return self.user.oauth_credential.serialize()
@property
def _user_data(self) -> JSON:
return {
"mxid": self.user.mxid,
"oauth_credential": self._oauth_credential,
}
# region HTTP
def get(
@ -210,7 +203,7 @@ class Client:
Receive the user's profile info in response.
"""
try:
settings_struct = await self._api_user_request_result(SettingsStruct, "start")
settings_struct = await self._api_user_cred_request_result(SettingsStruct, "start")
except SerializerError:
self.log.exception("Unable to deserialize settings struct, but starting client anyways")
settings_struct = None
@ -237,7 +230,7 @@ class Client:
async def renew_and_save(self) -> None:
"""Renew and save the user's session tokens."""
oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)
oauth_info = await self._api_user_cred_request_result(OAuthInfo, "renew", renew=False)
self.user.oauth_credential = oauth_info.credential
await self.user.save()
@ -247,7 +240,7 @@ class Client:
Receive a snapshot of account state in response.
"""
try:
login_result = await self._api_user_request_result(LoginResult, "connect")
login_result = await self._api_user_cred_request_result(LoginResult, "connect")
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
except SerializerError:
self.log.exception("Unable to deserialize login result, but connecting anyways")
@ -429,31 +422,56 @@ class Client:
)
# TODO Combine these into one
# TODO Combine each of these pairs into one
async def _api_user_request_result(
self, result_type: Type[ResultType], command: str, **data: JSON
self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
) -> ResultType:
renewed = False
while True:
try:
return await self._api_request_result(result_type, command, **self._user_data, **data)
return await self._api_request_result(result_type, command, mxid=self.user.mxid, **data)
except InvalidAccessToken:
if renewed:
if not renew:
raise
await self.renew_and_save()
renewed = True
renew = False
async def _api_user_request_void(self, command: str, **data: JSON) -> None:
renewed = False
async def _api_user_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
while True:
try:
return await self._api_request_void(command, **self._user_data, **data)
return await self._api_request_void(command, mxid=self.user.mxid, **data)
except InvalidAccessToken:
if renewed:
if not renew:
raise
await self.renew_and_save()
renewed = True
renew = False
async def _api_user_cred_request_result(
self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
) -> ResultType:
while True:
try:
return await self._api_user_request_result(
result_type, command, oauth_credential=self._oauth_credential, renew=False, **data
)
except InvalidAccessToken:
if not renew:
raise
await self.renew_and_save()
renew = False
async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
while True:
try:
return await self._api_user_request_result(
command, oauth_credential=self._oauth_credential, renew=False, **data
)
except InvalidAccessToken:
if not renew:
raise
await self.renew_and_save()
renew = False
# endregion

View File

@ -99,19 +99,21 @@ class UserClient {
#serviceClient
get serviceClient() { return this.#serviceClient }
/** @type {OAuthCredential} */
#credential
get userId() { return this.#credential.userId }
/**
* DO NOT CONSTRUCT DIRECTLY. Callers should use {@link UserClient#create} instead.
* @param {Long} userId
* @param {string} mxid
* @param {PeerClient} peerClient TODO Make RPC user-specific instead of needing this
*/
constructor(userId, mxid, peerClient) {
constructor(mxid, peerClient) {
if (!UserClient.#initializing) {
throw new Error("Private constructor")
}
UserClient.#initializing = false
this.userId = userId
this.mxid = mxid
this.peerClient = peerClient
@ -280,9 +282,9 @@ class UserClient {
*/
static async create(mxid, credential, peerClient) {
this.#initializing = true
const userClient = new UserClient(credential.userId, mxid, peerClient)
const userClient = new UserClient(mxid, peerClient)
userClient.#serviceClient = await ServiceApiClient.create(credential)
await userClient.setCredential(credential)
return userClient
}
@ -295,6 +297,14 @@ class UserClient {
console.error(`[API/${this.mxid}]`, ...text)
}
/**
* @param {OAuthCredential} credential
*/
async setCredential(credential) {
this.#serviceClient = await ServiceApiClient.create(credential)
this.#credential = credential
}
/**
* @param {ChannelProps} channelProps
*/
@ -318,12 +328,15 @@ class UserClient {
}
/**
* @param {OAuthCredential} credential The token to log in with, obtained from prior login
* @param {?OAuthCredential} credential The token to log in with, obtained from prior login
*/
async connect(credential) {
// TODO Don't re-login if possible. But must still return a LoginResult!
this.disconnect()
return await this.#talkClient.login(credential)
if (credential && this.#credential != credential) {
await this.setCredential(credential)
}
return await this.#talkClient.login(this.#credential)
}
disconnect() {
@ -529,12 +542,17 @@ export default class PeerClient {
/**
* @param {Object} req
* @param {string} req.mxid
* @param {OAuthCredential} req.oauth_credential
*/
handleRenew = async (req) => {
// TODO Cache per user? Reset API client objects?
const userClient = this.#tryGetUser(req.mxid)
const oAuthClient = await OAuthApiClient.create()
return await oAuthClient.renew(req.oauth_credential)
const res = await oAuthClient.renew(req.oauth_credential)
if (res.success && userClient) {
await userClient.setCredential(res.result)
}
return res
}
/**
@ -543,7 +561,12 @@ export default class PeerClient {
* @param {OAuthCredential} req.oauth_credential
*/
userStart = async (req) => {
const userClient = this.#tryGetUser(req.mxid) || await UserClient.create(req.mxid, req.oauth_credential, this)
let userClient = this.#tryGetUser(req.mxid)
if (!userClient) {
userClient = await UserClient.create(req.mxid, req.oauth_credential, this)
} else {
await userClient.setCredential(req.oauth_credential)
}
const res = await this.#getSettings(userClient.serviceClient)
if (res.success) {
this.userClients.set(req.mxid, userClient)
@ -563,7 +586,7 @@ export default class PeerClient {
/**
* @param {Object} req
* @param {string} req.mxid
* @param {OAuthCredential} req.oauth_credential
* @param {?OAuthCredential} req.oauth_credential
*/
handleConnect = async (req) => {
return await this.#getUser(req.mxid).connect(req.oauth_credential)