Recreate ServiceApiClient on token refresh

This commit is contained in:
Andrew Ferrazzutti 2022-04-23 13:51:39 -04:00
parent eebcef6b08
commit 51d02d3c34
2 changed files with 73 additions and 32 deletions

View File

@ -172,13 +172,6 @@ class Client:
def _oauth_credential(self) -> JSON: def _oauth_credential(self) -> JSON:
return self.user.oauth_credential.serialize() return self.user.oauth_credential.serialize()
@property
def _user_data(self) -> JSON:
return {
"mxid": self.user.mxid,
"oauth_credential": self._oauth_credential,
}
# region HTTP # region HTTP
def get( def get(
@ -210,7 +203,7 @@ class Client:
Receive the user's profile info in response. Receive the user's profile info in response.
""" """
try: try:
settings_struct = await self._api_user_request_result(SettingsStruct, "start") settings_struct = await self._api_user_cred_request_result(SettingsStruct, "start")
except SerializerError: except SerializerError:
self.log.exception("Unable to deserialize settings struct, but starting client anyways") self.log.exception("Unable to deserialize settings struct, but starting client anyways")
settings_struct = None settings_struct = None
@ -237,7 +230,7 @@ class Client:
async def renew_and_save(self) -> None: async def renew_and_save(self) -> None:
"""Renew and save the user's session tokens.""" """Renew and save the user's session tokens."""
oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential) oauth_info = await self._api_user_cred_request_result(OAuthInfo, "renew", renew=False)
self.user.oauth_credential = oauth_info.credential self.user.oauth_credential = oauth_info.credential
await self.user.save() await self.user.save()
@ -247,7 +240,7 @@ class Client:
Receive a snapshot of account state in response. Receive a snapshot of account state in response.
""" """
try: try:
login_result = await self._api_user_request_result(LoginResult, "connect") login_result = await self._api_user_cred_request_result(LoginResult, "connect")
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}" assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
except SerializerError: except SerializerError:
self.log.exception("Unable to deserialize login result, but connecting anyways") self.log.exception("Unable to deserialize login result, but connecting anyways")
@ -429,31 +422,56 @@ class Client:
) )
# TODO Combine these into one # TODO Combine each of these pairs into one
async def _api_user_request_result( async def _api_user_request_result(
self, result_type: Type[ResultType], command: str, **data: JSON self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
) -> ResultType: ) -> ResultType:
renewed = False
while True: while True:
try: try:
return await self._api_request_result(result_type, command, **self._user_data, **data) return await self._api_request_result(result_type, command, mxid=self.user.mxid, **data)
except InvalidAccessToken: except InvalidAccessToken:
if renewed: if not renew:
raise raise
await self.renew_and_save() await self.renew_and_save()
renewed = True renew = False
async def _api_user_request_void(self, command: str, **data: JSON) -> None: async def _api_user_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
renewed = False
while True: while True:
try: try:
return await self._api_request_void(command, **self._user_data, **data) return await self._api_request_void(command, mxid=self.user.mxid, **data)
except InvalidAccessToken: except InvalidAccessToken:
if renewed: if not renew:
raise raise
await self.renew_and_save() await self.renew_and_save()
renewed = True renew = False
async def _api_user_cred_request_result(
self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
) -> ResultType:
while True:
try:
return await self._api_user_request_result(
result_type, command, oauth_credential=self._oauth_credential, renew=False, **data
)
except InvalidAccessToken:
if not renew:
raise
await self.renew_and_save()
renew = False
async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
while True:
try:
return await self._api_user_request_result(
command, oauth_credential=self._oauth_credential, renew=False, **data
)
except InvalidAccessToken:
if not renew:
raise
await self.renew_and_save()
renew = False
# endregion # endregion

View File

@ -99,19 +99,21 @@ class UserClient {
#serviceClient #serviceClient
get serviceClient() { return this.#serviceClient } get serviceClient() { return this.#serviceClient }
/** @type {OAuthCredential} */
#credential
get userId() { return this.#credential.userId }
/** /**
* DO NOT CONSTRUCT DIRECTLY. Callers should use {@link UserClient#create} instead. * DO NOT CONSTRUCT DIRECTLY. Callers should use {@link UserClient#create} instead.
* @param {Long} userId
* @param {string} mxid * @param {string} mxid
* @param {PeerClient} peerClient TODO Make RPC user-specific instead of needing this * @param {PeerClient} peerClient TODO Make RPC user-specific instead of needing this
*/ */
constructor(userId, mxid, peerClient) { constructor(mxid, peerClient) {
if (!UserClient.#initializing) { if (!UserClient.#initializing) {
throw new Error("Private constructor") throw new Error("Private constructor")
} }
UserClient.#initializing = false UserClient.#initializing = false
this.userId = userId
this.mxid = mxid this.mxid = mxid
this.peerClient = peerClient this.peerClient = peerClient
@ -280,9 +282,9 @@ class UserClient {
*/ */
static async create(mxid, credential, peerClient) { static async create(mxid, credential, peerClient) {
this.#initializing = true this.#initializing = true
const userClient = new UserClient(credential.userId, mxid, peerClient) const userClient = new UserClient(mxid, peerClient)
userClient.#serviceClient = await ServiceApiClient.create(credential) await userClient.setCredential(credential)
return userClient return userClient
} }
@ -295,6 +297,14 @@ class UserClient {
console.error(`[API/${this.mxid}]`, ...text) console.error(`[API/${this.mxid}]`, ...text)
} }
/**
* @param {OAuthCredential} credential
*/
async setCredential(credential) {
this.#serviceClient = await ServiceApiClient.create(credential)
this.#credential = credential
}
/** /**
* @param {ChannelProps} channelProps * @param {ChannelProps} channelProps
*/ */
@ -318,12 +328,15 @@ class UserClient {
} }
/** /**
* @param {OAuthCredential} credential The token to log in with, obtained from prior login * @param {?OAuthCredential} credential The token to log in with, obtained from prior login
*/ */
async connect(credential) { async connect(credential) {
// TODO Don't re-login if possible. But must still return a LoginResult! // TODO Don't re-login if possible. But must still return a LoginResult!
this.disconnect() this.disconnect()
return await this.#talkClient.login(credential) if (credential && this.#credential != credential) {
await this.setCredential(credential)
}
return await this.#talkClient.login(this.#credential)
} }
disconnect() { disconnect() {
@ -529,12 +542,17 @@ export default class PeerClient {
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid
* @param {OAuthCredential} req.oauth_credential * @param {OAuthCredential} req.oauth_credential
*/ */
handleRenew = async (req) => { handleRenew = async (req) => {
// TODO Cache per user? Reset API client objects? const userClient = this.#tryGetUser(req.mxid)
const oAuthClient = await OAuthApiClient.create() const oAuthClient = await OAuthApiClient.create()
return await oAuthClient.renew(req.oauth_credential) const res = await oAuthClient.renew(req.oauth_credential)
if (res.success && userClient) {
await userClient.setCredential(res.result)
}
return res
} }
/** /**
@ -543,7 +561,12 @@ export default class PeerClient {
* @param {OAuthCredential} req.oauth_credential * @param {OAuthCredential} req.oauth_credential
*/ */
userStart = async (req) => { userStart = async (req) => {
const userClient = this.#tryGetUser(req.mxid) || await UserClient.create(req.mxid, req.oauth_credential, this) let userClient = this.#tryGetUser(req.mxid)
if (!userClient) {
userClient = await UserClient.create(req.mxid, req.oauth_credential, this)
} else {
await userClient.setCredential(req.oauth_credential)
}
const res = await this.#getSettings(userClient.serviceClient) const res = await this.#getSettings(userClient.serviceClient)
if (res.success) { if (res.success) {
this.userClients.set(req.mxid, userClient) this.userClients.set(req.mxid, userClient)
@ -563,7 +586,7 @@ export default class PeerClient {
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid * @param {string} req.mxid
* @param {OAuthCredential} req.oauth_credential * @param {?OAuthCredential} req.oauth_credential
*/ */
handleConnect = async (req) => { handleConnect = async (req) => {
return await this.#getUser(req.mxid).connect(req.oauth_credential) return await this.#getUser(req.mxid).connect(req.oauth_credential)