Recreate ServiceApiClient on token refresh
This commit is contained in:
parent
eebcef6b08
commit
51d02d3c34
|
@ -172,13 +172,6 @@ class Client:
|
|||
def _oauth_credential(self) -> JSON:
|
||||
return self.user.oauth_credential.serialize()
|
||||
|
||||
@property
|
||||
def _user_data(self) -> JSON:
|
||||
return {
|
||||
"mxid": self.user.mxid,
|
||||
"oauth_credential": self._oauth_credential,
|
||||
}
|
||||
|
||||
# region HTTP
|
||||
|
||||
def get(
|
||||
|
@ -210,7 +203,7 @@ class Client:
|
|||
Receive the user's profile info in response.
|
||||
"""
|
||||
try:
|
||||
settings_struct = await self._api_user_request_result(SettingsStruct, "start")
|
||||
settings_struct = await self._api_user_cred_request_result(SettingsStruct, "start")
|
||||
except SerializerError:
|
||||
self.log.exception("Unable to deserialize settings struct, but starting client anyways")
|
||||
settings_struct = None
|
||||
|
@ -237,7 +230,7 @@ class Client:
|
|||
|
||||
async def renew_and_save(self) -> None:
|
||||
"""Renew and save the user's session tokens."""
|
||||
oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)
|
||||
oauth_info = await self._api_user_cred_request_result(OAuthInfo, "renew", renew=False)
|
||||
self.user.oauth_credential = oauth_info.credential
|
||||
await self.user.save()
|
||||
|
||||
|
@ -247,7 +240,7 @@ class Client:
|
|||
Receive a snapshot of account state in response.
|
||||
"""
|
||||
try:
|
||||
login_result = await self._api_user_request_result(LoginResult, "connect")
|
||||
login_result = await self._api_user_cred_request_result(LoginResult, "connect")
|
||||
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
|
||||
except SerializerError:
|
||||
self.log.exception("Unable to deserialize login result, but connecting anyways")
|
||||
|
@ -429,31 +422,56 @@ class Client:
|
|||
)
|
||||
|
||||
|
||||
# TODO Combine these into one
|
||||
# TODO Combine each of these pairs into one
|
||||
|
||||
async def _api_user_request_result(
|
||||
self, result_type: Type[ResultType], command: str, **data: JSON
|
||||
self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
|
||||
) -> ResultType:
|
||||
renewed = False
|
||||
while True:
|
||||
try:
|
||||
return await self._api_request_result(result_type, command, **self._user_data, **data)
|
||||
return await self._api_request_result(result_type, command, mxid=self.user.mxid, **data)
|
||||
except InvalidAccessToken:
|
||||
if renewed:
|
||||
if not renew:
|
||||
raise
|
||||
await self.renew_and_save()
|
||||
renewed = True
|
||||
renew = False
|
||||
|
||||
async def _api_user_request_void(self, command: str, **data: JSON) -> None:
|
||||
renewed = False
|
||||
async def _api_user_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
|
||||
while True:
|
||||
try:
|
||||
return await self._api_request_void(command, **self._user_data, **data)
|
||||
return await self._api_request_void(command, mxid=self.user.mxid, **data)
|
||||
except InvalidAccessToken:
|
||||
if renewed:
|
||||
if not renew:
|
||||
raise
|
||||
await self.renew_and_save()
|
||||
renewed = True
|
||||
renew = False
|
||||
|
||||
|
||||
async def _api_user_cred_request_result(
|
||||
self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
|
||||
) -> ResultType:
|
||||
while True:
|
||||
try:
|
||||
return await self._api_user_request_result(
|
||||
result_type, command, oauth_credential=self._oauth_credential, renew=False, **data
|
||||
)
|
||||
except InvalidAccessToken:
|
||||
if not renew:
|
||||
raise
|
||||
await self.renew_and_save()
|
||||
renew = False
|
||||
|
||||
async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
|
||||
while True:
|
||||
try:
|
||||
return await self._api_user_request_result(
|
||||
command, oauth_credential=self._oauth_credential, renew=False, **data
|
||||
)
|
||||
except InvalidAccessToken:
|
||||
if not renew:
|
||||
raise
|
||||
await self.renew_and_save()
|
||||
renew = False
|
||||
|
||||
# endregion
|
||||
|
||||
|
|
|
@ -99,19 +99,21 @@ class UserClient {
|
|||
#serviceClient
|
||||
get serviceClient() { return this.#serviceClient }
|
||||
|
||||
/** @type {OAuthCredential} */
|
||||
#credential
|
||||
get userId() { return this.#credential.userId }
|
||||
|
||||
/**
|
||||
* DO NOT CONSTRUCT DIRECTLY. Callers should use {@link UserClient#create} instead.
|
||||
* @param {Long} userId
|
||||
* @param {string} mxid
|
||||
* @param {PeerClient} peerClient TODO Make RPC user-specific instead of needing this
|
||||
*/
|
||||
constructor(userId, mxid, peerClient) {
|
||||
constructor(mxid, peerClient) {
|
||||
if (!UserClient.#initializing) {
|
||||
throw new Error("Private constructor")
|
||||
}
|
||||
UserClient.#initializing = false
|
||||
|
||||
this.userId = userId
|
||||
this.mxid = mxid
|
||||
this.peerClient = peerClient
|
||||
|
||||
|
@ -280,9 +282,9 @@ class UserClient {
|
|||
*/
|
||||
static async create(mxid, credential, peerClient) {
|
||||
this.#initializing = true
|
||||
const userClient = new UserClient(credential.userId, mxid, peerClient)
|
||||
const userClient = new UserClient(mxid, peerClient)
|
||||
|
||||
userClient.#serviceClient = await ServiceApiClient.create(credential)
|
||||
await userClient.setCredential(credential)
|
||||
return userClient
|
||||
}
|
||||
|
||||
|
@ -295,6 +297,14 @@ class UserClient {
|
|||
console.error(`[API/${this.mxid}]`, ...text)
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {OAuthCredential} credential
|
||||
*/
|
||||
async setCredential(credential) {
|
||||
this.#serviceClient = await ServiceApiClient.create(credential)
|
||||
this.#credential = credential
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {ChannelProps} channelProps
|
||||
*/
|
||||
|
@ -318,12 +328,15 @@ class UserClient {
|
|||
}
|
||||
|
||||
/**
|
||||
* @param {OAuthCredential} credential The token to log in with, obtained from prior login
|
||||
* @param {?OAuthCredential} credential The token to log in with, obtained from prior login
|
||||
*/
|
||||
async connect(credential) {
|
||||
// TODO Don't re-login if possible. But must still return a LoginResult!
|
||||
this.disconnect()
|
||||
return await this.#talkClient.login(credential)
|
||||
if (credential && this.#credential != credential) {
|
||||
await this.setCredential(credential)
|
||||
}
|
||||
return await this.#talkClient.login(this.#credential)
|
||||
}
|
||||
|
||||
disconnect() {
|
||||
|
@ -529,12 +542,17 @@ export default class PeerClient {
|
|||
|
||||
/**
|
||||
* @param {Object} req
|
||||
* @param {string} req.mxid
|
||||
* @param {OAuthCredential} req.oauth_credential
|
||||
*/
|
||||
handleRenew = async (req) => {
|
||||
// TODO Cache per user? Reset API client objects?
|
||||
const userClient = this.#tryGetUser(req.mxid)
|
||||
const oAuthClient = await OAuthApiClient.create()
|
||||
return await oAuthClient.renew(req.oauth_credential)
|
||||
const res = await oAuthClient.renew(req.oauth_credential)
|
||||
if (res.success && userClient) {
|
||||
await userClient.setCredential(res.result)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -543,7 +561,12 @@ export default class PeerClient {
|
|||
* @param {OAuthCredential} req.oauth_credential
|
||||
*/
|
||||
userStart = async (req) => {
|
||||
const userClient = this.#tryGetUser(req.mxid) || await UserClient.create(req.mxid, req.oauth_credential, this)
|
||||
let userClient = this.#tryGetUser(req.mxid)
|
||||
if (!userClient) {
|
||||
userClient = await UserClient.create(req.mxid, req.oauth_credential, this)
|
||||
} else {
|
||||
await userClient.setCredential(req.oauth_credential)
|
||||
}
|
||||
const res = await this.#getSettings(userClient.serviceClient)
|
||||
if (res.success) {
|
||||
this.userClients.set(req.mxid, userClient)
|
||||
|
@ -563,7 +586,7 @@ export default class PeerClient {
|
|||
/**
|
||||
* @param {Object} req
|
||||
* @param {string} req.mxid
|
||||
* @param {OAuthCredential} req.oauth_credential
|
||||
* @param {?OAuthCredential} req.oauth_credential
|
||||
*/
|
||||
handleConnect = async (req) => {
|
||||
return await this.#getUser(req.mxid).connect(req.oauth_credential)
|
||||
|
|
Loading…
Reference in New Issue