diff --git a/mautrix_amp/user.py b/mautrix_amp/user.py index b84e7fc..2584fcc 100644 --- a/mautrix_amp/user.py +++ b/mautrix_amp/user.py @@ -54,7 +54,6 @@ class User(DBUser, BaseUser): self._metric_value = defaultdict(lambda: False) self._connection_check_task = None self.client = None - self.username = None self.intent = None @classmethod diff --git a/mautrix_amp/web/provisioning_api.py b/mautrix_amp/web/provisioning_api.py index ab8370b..e3da47e 100644 --- a/mautrix_amp/web/provisioning_api.py +++ b/mautrix_amp/web/provisioning_api.py @@ -13,9 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Dict +from typing import Awaitable, Dict, Optional import logging -import json +import asyncio from aiohttp import web @@ -33,16 +33,14 @@ class ProvisioningAPI: self.app = web.Application() self.shared_secret = shared_secret self.app.router.add_get("/api/whoami", self.status) - self.app.router.add_options("/api/login", self.login_options) - self.app.router.add_post("/api/login", self.login) - self.app.router.add_post("/api/logout", self.logout) + self.app.router.add_get("/api/login", self.login) @property def _acao_headers(self) -> Dict[str, str]: return { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "Authorization, Content-Type", - "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Methods": "GET", } @property @@ -55,22 +53,39 @@ class ProvisioningAPI: async def login_options(self, _: web.Request) -> web.Response: return web.Response(status=200, headers=self._headers) + @staticmethod + def _get_ws_token(request: web.Request) -> Optional[str]: + if not request.path.endswith("/login"): + return None + + try: + auth_parts = request.headers["Sec-WebSocket-Protocol"].split(",") + except KeyError: + return None + for part in auth_parts: + part = part.strip() + if part.startswith("net.maunium.amp.auth-"): + return part[len("net.maunium.amp.auth-"):] + return None + def check_token(self, request: web.Request) -> Awaitable['u.User']: try: token = request.headers["Authorization"] token = token[len("Bearer "):] except KeyError: - raise web.HTTPBadRequest(body='{"error": "Missing Authorization header"}', - headers=self._headers) + token = self._get_ws_token(request) + if not token: + raise web.HTTPBadRequest(text='{"error": "Missing Authorization header"}', + headers=self._headers) except IndexError: - raise web.HTTPBadRequest(body='{"error": "Malformed Authorization header"}', + raise web.HTTPBadRequest(text='{"error": "Malformed Authorization header"}', headers=self._headers) if token != self.shared_secret: - raise web.HTTPForbidden(body='{"error": "Invalid token"}', headers=self._headers) + raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers) try: user_id = request.query["user_id"] except KeyError: - raise web.HTTPBadRequest(body='{"error": "Missing user_id query param"}', + raise web.HTTPBadRequest(text='{"error": "Missing user_id query param"}', headers=self._headers) return u.User.get_by_mxid(UserID(user_id)) @@ -78,37 +93,31 @@ class ProvisioningAPI: async def status(self, request: web.Request) -> web.Response: user = await self.check_token(request) data = { - "permissions": user.permission_level, "mxid": user.mxid, - "twitter": None, + "amp": { + "connected": True, + } if await user.is_logged_in() else None, } - if await user.is_logged_in(): - data["twitter"] = (await user.get_info()).serialize() return web.json_response(data, headers=self._acao_headers) - async def login(self, request: web.Request) -> web.Response: + async def login(self, request: web.Request) -> web.WebSocketResponse: user = await self.check_token(request) - try: - data = await request.json() - except json.JSONDecodeError: - raise web.HTTPBadRequest(body='{"error": "Malformed JSON"}', headers=self._headers) + status = await user.client.start() + if status.is_logged_in: + raise web.HTTPConflict(text='{"error": "Already logged in"}', headers=self._headers) + ws = web.WebSocketResponse(protocols=["net.maunium.amp.login"]) + await ws.prepare(request) try: - auth_token = data["auth_token"] - csrf_token = data["csrf_token"] - except KeyError: - raise web.HTTPBadRequest(body='{"error": "Missing keys"}', headers=self._headers) - - try: - await user.connect(auth_token=auth_token, csrf_token=csrf_token) + async for url in user.client.login(): + self.log.debug("Sending QR URL %s to websocket", url) + await ws.send_json({"url": url}) except Exception: - self.log.debug("Failed to log in", exc_info=True) - raise web.HTTPUnauthorized(body='{"error": "Twitter authorization failed"}', - headers=self._headers) - return web.Response(body='{}', status=200, headers=self._headers) - - async def logout(self, request: web.Request) -> web.Response: - user = await self.check_token(request) - await user.logout() - return web.json_response({}, headers=self._acao_headers) + await ws.send_json({"success": False}) + self.log.exception("Error logging in") + else: + await ws.send_json({"success": True}) + asyncio.create_task(user.sync()) + await ws.close() + return ws diff --git a/puppet/src/puppet.js b/puppet/src/puppet.js index 91867e3..038b7ee 100644 --- a/puppet/src/puppet.js +++ b/puppet/src/puppet.js @@ -115,8 +115,12 @@ export default class MessagesPuppeteer { return } const qrSelector = "mw-authentication-container mw-qr-code" - this.log("Clicking Remember Me button") - await this.page.click("mat-slide-toggle:not(.mat-checked) > label") + if (!await this.page.$("mat-slide-toggle.mat-checked")) { + this.log("Clicking Remember Me button") + await this.page.click("mat-slide-toggle:not(.mat-checked) > label") + } else { + this.log("Remember Me button already clicked") + } this.log("Fetching current QR code") const currentQR = await this.page.$eval(qrSelector, element => element.getAttribute("data-qr-code"))