Recreate ServiceApiClient on token refresh
This commit is contained in:
parent
eebcef6b08
commit
51d02d3c34
|
@ -172,13 +172,6 @@ class Client:
|
||||||
def _oauth_credential(self) -> JSON:
|
def _oauth_credential(self) -> JSON:
|
||||||
return self.user.oauth_credential.serialize()
|
return self.user.oauth_credential.serialize()
|
||||||
|
|
||||||
@property
|
|
||||||
def _user_data(self) -> JSON:
|
|
||||||
return {
|
|
||||||
"mxid": self.user.mxid,
|
|
||||||
"oauth_credential": self._oauth_credential,
|
|
||||||
}
|
|
||||||
|
|
||||||
# region HTTP
|
# region HTTP
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
|
@ -210,7 +203,7 @@ class Client:
|
||||||
Receive the user's profile info in response.
|
Receive the user's profile info in response.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
settings_struct = await self._api_user_request_result(SettingsStruct, "start")
|
settings_struct = await self._api_user_cred_request_result(SettingsStruct, "start")
|
||||||
except SerializerError:
|
except SerializerError:
|
||||||
self.log.exception("Unable to deserialize settings struct, but starting client anyways")
|
self.log.exception("Unable to deserialize settings struct, but starting client anyways")
|
||||||
settings_struct = None
|
settings_struct = None
|
||||||
|
@ -237,7 +230,7 @@ class Client:
|
||||||
|
|
||||||
async def renew_and_save(self) -> None:
|
async def renew_and_save(self) -> None:
|
||||||
"""Renew and save the user's session tokens."""
|
"""Renew and save the user's session tokens."""
|
||||||
oauth_info = await self._api_request_result(OAuthInfo, "renew", oauth_credential=self._oauth_credential)
|
oauth_info = await self._api_user_cred_request_result(OAuthInfo, "renew", renew=False)
|
||||||
self.user.oauth_credential = oauth_info.credential
|
self.user.oauth_credential = oauth_info.credential
|
||||||
await self.user.save()
|
await self.user.save()
|
||||||
|
|
||||||
|
@ -247,7 +240,7 @@ class Client:
|
||||||
Receive a snapshot of account state in response.
|
Receive a snapshot of account state in response.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
login_result = await self._api_user_request_result(LoginResult, "connect")
|
login_result = await self._api_user_cred_request_result(LoginResult, "connect")
|
||||||
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
|
assert self.user.ktid == login_result.userId, f"User ID mismatch: expected {self.user.ktid}, got {login_result.userId}"
|
||||||
except SerializerError:
|
except SerializerError:
|
||||||
self.log.exception("Unable to deserialize login result, but connecting anyways")
|
self.log.exception("Unable to deserialize login result, but connecting anyways")
|
||||||
|
@ -429,31 +422,56 @@ class Client:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO Combine these into one
|
# TODO Combine each of these pairs into one
|
||||||
|
|
||||||
async def _api_user_request_result(
|
async def _api_user_request_result(
|
||||||
self, result_type: Type[ResultType], command: str, **data: JSON
|
self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
|
||||||
) -> ResultType:
|
) -> ResultType:
|
||||||
renewed = False
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
return await self._api_request_result(result_type, command, **self._user_data, **data)
|
return await self._api_request_result(result_type, command, mxid=self.user.mxid, **data)
|
||||||
except InvalidAccessToken:
|
except InvalidAccessToken:
|
||||||
if renewed:
|
if not renew:
|
||||||
raise
|
raise
|
||||||
await self.renew_and_save()
|
await self.renew_and_save()
|
||||||
renewed = True
|
renew = False
|
||||||
|
|
||||||
async def _api_user_request_void(self, command: str, **data: JSON) -> None:
|
async def _api_user_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
|
||||||
renewed = False
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
return await self._api_request_void(command, **self._user_data, **data)
|
return await self._api_request_void(command, mxid=self.user.mxid, **data)
|
||||||
except InvalidAccessToken:
|
except InvalidAccessToken:
|
||||||
if renewed:
|
if not renew:
|
||||||
raise
|
raise
|
||||||
await self.renew_and_save()
|
await self.renew_and_save()
|
||||||
renewed = True
|
renew = False
|
||||||
|
|
||||||
|
|
||||||
|
async def _api_user_cred_request_result(
|
||||||
|
self, result_type: Type[ResultType], command: str, *, renew: bool = True, **data: JSON
|
||||||
|
) -> ResultType:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return await self._api_user_request_result(
|
||||||
|
result_type, command, oauth_credential=self._oauth_credential, renew=False, **data
|
||||||
|
)
|
||||||
|
except InvalidAccessToken:
|
||||||
|
if not renew:
|
||||||
|
raise
|
||||||
|
await self.renew_and_save()
|
||||||
|
renew = False
|
||||||
|
|
||||||
|
async def _api_user_cred_request_void(self, command: str, *, renew: bool = True, **data: JSON) -> None:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return await self._api_user_request_result(
|
||||||
|
command, oauth_credential=self._oauth_credential, renew=False, **data
|
||||||
|
)
|
||||||
|
except InvalidAccessToken:
|
||||||
|
if not renew:
|
||||||
|
raise
|
||||||
|
await self.renew_and_save()
|
||||||
|
renew = False
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
|
@ -99,19 +99,21 @@ class UserClient {
|
||||||
#serviceClient
|
#serviceClient
|
||||||
get serviceClient() { return this.#serviceClient }
|
get serviceClient() { return this.#serviceClient }
|
||||||
|
|
||||||
|
/** @type {OAuthCredential} */
|
||||||
|
#credential
|
||||||
|
get userId() { return this.#credential.userId }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DO NOT CONSTRUCT DIRECTLY. Callers should use {@link UserClient#create} instead.
|
* DO NOT CONSTRUCT DIRECTLY. Callers should use {@link UserClient#create} instead.
|
||||||
* @param {Long} userId
|
|
||||||
* @param {string} mxid
|
* @param {string} mxid
|
||||||
* @param {PeerClient} peerClient TODO Make RPC user-specific instead of needing this
|
* @param {PeerClient} peerClient TODO Make RPC user-specific instead of needing this
|
||||||
*/
|
*/
|
||||||
constructor(userId, mxid, peerClient) {
|
constructor(mxid, peerClient) {
|
||||||
if (!UserClient.#initializing) {
|
if (!UserClient.#initializing) {
|
||||||
throw new Error("Private constructor")
|
throw new Error("Private constructor")
|
||||||
}
|
}
|
||||||
UserClient.#initializing = false
|
UserClient.#initializing = false
|
||||||
|
|
||||||
this.userId = userId
|
|
||||||
this.mxid = mxid
|
this.mxid = mxid
|
||||||
this.peerClient = peerClient
|
this.peerClient = peerClient
|
||||||
|
|
||||||
|
@ -280,9 +282,9 @@ class UserClient {
|
||||||
*/
|
*/
|
||||||
static async create(mxid, credential, peerClient) {
|
static async create(mxid, credential, peerClient) {
|
||||||
this.#initializing = true
|
this.#initializing = true
|
||||||
const userClient = new UserClient(credential.userId, mxid, peerClient)
|
const userClient = new UserClient(mxid, peerClient)
|
||||||
|
|
||||||
userClient.#serviceClient = await ServiceApiClient.create(credential)
|
await userClient.setCredential(credential)
|
||||||
return userClient
|
return userClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,6 +297,14 @@ class UserClient {
|
||||||
console.error(`[API/${this.mxid}]`, ...text)
|
console.error(`[API/${this.mxid}]`, ...text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {OAuthCredential} credential
|
||||||
|
*/
|
||||||
|
async setCredential(credential) {
|
||||||
|
this.#serviceClient = await ServiceApiClient.create(credential)
|
||||||
|
this.#credential = credential
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {ChannelProps} channelProps
|
* @param {ChannelProps} channelProps
|
||||||
*/
|
*/
|
||||||
|
@ -318,12 +328,15 @@ class UserClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {OAuthCredential} credential The token to log in with, obtained from prior login
|
* @param {?OAuthCredential} credential The token to log in with, obtained from prior login
|
||||||
*/
|
*/
|
||||||
async connect(credential) {
|
async connect(credential) {
|
||||||
// TODO Don't re-login if possible. But must still return a LoginResult!
|
// TODO Don't re-login if possible. But must still return a LoginResult!
|
||||||
this.disconnect()
|
this.disconnect()
|
||||||
return await this.#talkClient.login(credential)
|
if (credential && this.#credential != credential) {
|
||||||
|
await this.setCredential(credential)
|
||||||
|
}
|
||||||
|
return await this.#talkClient.login(this.#credential)
|
||||||
}
|
}
|
||||||
|
|
||||||
disconnect() {
|
disconnect() {
|
||||||
|
@ -529,12 +542,17 @@ export default class PeerClient {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {Object} req
|
* @param {Object} req
|
||||||
|
* @param {string} req.mxid
|
||||||
* @param {OAuthCredential} req.oauth_credential
|
* @param {OAuthCredential} req.oauth_credential
|
||||||
*/
|
*/
|
||||||
handleRenew = async (req) => {
|
handleRenew = async (req) => {
|
||||||
// TODO Cache per user? Reset API client objects?
|
const userClient = this.#tryGetUser(req.mxid)
|
||||||
const oAuthClient = await OAuthApiClient.create()
|
const oAuthClient = await OAuthApiClient.create()
|
||||||
return await oAuthClient.renew(req.oauth_credential)
|
const res = await oAuthClient.renew(req.oauth_credential)
|
||||||
|
if (res.success && userClient) {
|
||||||
|
await userClient.setCredential(res.result)
|
||||||
|
}
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -543,7 +561,12 @@ export default class PeerClient {
|
||||||
* @param {OAuthCredential} req.oauth_credential
|
* @param {OAuthCredential} req.oauth_credential
|
||||||
*/
|
*/
|
||||||
userStart = async (req) => {
|
userStart = async (req) => {
|
||||||
const userClient = this.#tryGetUser(req.mxid) || await UserClient.create(req.mxid, req.oauth_credential, this)
|
let userClient = this.#tryGetUser(req.mxid)
|
||||||
|
if (!userClient) {
|
||||||
|
userClient = await UserClient.create(req.mxid, req.oauth_credential, this)
|
||||||
|
} else {
|
||||||
|
await userClient.setCredential(req.oauth_credential)
|
||||||
|
}
|
||||||
const res = await this.#getSettings(userClient.serviceClient)
|
const res = await this.#getSettings(userClient.serviceClient)
|
||||||
if (res.success) {
|
if (res.success) {
|
||||||
this.userClients.set(req.mxid, userClient)
|
this.userClients.set(req.mxid, userClient)
|
||||||
|
@ -563,7 +586,7 @@ export default class PeerClient {
|
||||||
/**
|
/**
|
||||||
* @param {Object} req
|
* @param {Object} req
|
||||||
* @param {string} req.mxid
|
* @param {string} req.mxid
|
||||||
* @param {OAuthCredential} req.oauth_credential
|
* @param {?OAuthCredential} req.oauth_credential
|
||||||
*/
|
*/
|
||||||
handleConnect = async (req) => {
|
handleConnect = async (req) => {
|
||||||
return await this.#getUser(req.mxid).connect(req.oauth_credential)
|
return await this.#getUser(req.mxid).connect(req.oauth_credential)
|
||||||
|
|
Loading…
Reference in New Issue