Compare commits

...

9 Commits

12 changed files with 74 additions and 33 deletions

View File

@ -69,10 +69,30 @@ async def login(evt: CommandEvent) -> None:
await evt.reply("You're already logged in") await evt.reply("You're already logged in")
return return
save = len(evt.args) > 0 and evt.args[0] == "--save" num_args = len(evt.args)
email = evt.args[0 if not save else 1] if len(evt.args) > 0 else None save = num_args > 0 and evt.args[0] == "--save"
# TODO Once web login is implemented, don't make <email> a mandatory argument
if not save and num_args != 1:
await evt.reply("**Usage:** `$cmdprefix+sp login [--save] <email>`")
return
email = evt.args[0 if not save else 1] if num_args > 0 else None
if email: if email:
try:
creds = await LoginCredential.get_by_mxid(evt.sender.mxid)
except:
evt.log.exception("Exception while looking for saved password")
creds = None
if creds and creds.email == email:
await evt.reply("Logging in with saved password")
evt.sender.command_status = {
"action": "Login with saved password",
"room_id": evt.room_id,
"save": True,
}
await _login_with_password(evt, email, creds.password, evt.sender.force_login)
return
evt.sender.command_status = { evt.sender.command_status = {
"action": "Login", "action": "Login",
"room_id": evt.room_id, "room_id": evt.room_id,
@ -81,20 +101,6 @@ async def login(evt: CommandEvent) -> None:
"save": save, "save": save,
"forced": evt.sender.force_login, "forced": evt.sender.force_login,
} }
try:
creds = await LoginCredential.get_by_mxid(evt.sender.mxid)
except:
evt.log.exception("Exception while looking for saved password")
creds = None
if creds and creds.email == email:
await evt.reply("Logging in with saved password")
await _login_with_password(
evt,
evt.sender.command_status.pop("email"),
creds.password,
evt.sender.command_status.pop("forced"),
)
return
""" TODO Implement web login """ TODO Implement web login
if evt.bridge.public_website: if evt.bridge.public_website:

View File

@ -117,6 +117,7 @@ class Config(BaseBridgeConfig):
else: else:
copy("rpc.connection.host") copy("rpc.connection.host")
copy("rpc.connection.port") copy("rpc.connection.port")
copy("rpc.logging_keys")
def _get_permissions(self, key: str) -> tuple[bool, bool, bool, str]: def _get_permissions(self, key: str) -> tuple[bool, bool, bool, str]:
level = self["bridge.permissions"].get(key, "") level = self["bridge.permissions"].get(key, "")

View File

@ -252,6 +252,11 @@ rpc:
# Only for type: tcp # Only for type: tcp
host: localhost host: localhost
port: 29392 port: 29392
# Command arguments to print in logs. Optional.
# TODO Support nested arguments, like channel_props.ktid
logging_keys:
- mxid
#- channel_props
# Python logging configuration. # Python logging configuration.
# #

View File

@ -109,6 +109,10 @@ class Client:
await cls._rpc_client.connect() await cls._rpc_client.connect()
await cls._rpc_client.wait_for_disconnection() await cls._rpc_client.wait_for_disconnection()
@classmethod
def wait_for_connection(cls) -> Awaitable[None]:
return cls._rpc_client.wait_for_connection()
@classmethod @classmethod
def stop_cls(cls) -> None: def stop_cls(cls) -> None:
"""Stop and disconnect from the Node backend.""" """Stop and disconnect from the Node backend."""

View File

@ -1881,7 +1881,7 @@ class Portal(DBPortal, BasePortal):
if not self.is_direct: if not self.is_direct:
self._main_intent = self.az.intent self._main_intent = self.az.intent
else: else:
# TODO Save kt_sender in DB instead? Depends on if DM channels are shared... # TODO Save kt_sender in DB instead? Only do that if keeping a unique DM portal for each receiver
user = await u.User.get_by_ktid(self.kt_receiver) user = await u.User.get_by_ktid(self.kt_receiver)
assert user, f"Found no user for this portal's receiver of {self.kt_receiver}" assert user, f"Found no user for this portal's receiver of {self.kt_receiver}"
if self.kt_type == KnownChannelType.MemoChat: if self.kt_type == KnownChannelType.MemoChat:
@ -1920,7 +1920,7 @@ class Portal(DBPortal, BasePortal):
create: bool = True, create: bool = True,
kt_type: ChannelType | None = None, kt_type: ChannelType | None = None,
) -> Portal | None: ) -> Portal | None:
# TODO Find out if direct channels are shared. If so, don't need kt_receiver! # TODO Direct chats are shared, so can remove kt_receiver if DM portals should be shared
if kt_type: if kt_type:
kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0 kt_receiver = kt_receiver if KnownChannelType.is_direct(kt_type) else 0
ktid_full = (ktid, kt_receiver) ktid_full = (ktid, kt_receiver)

View File

@ -86,6 +86,7 @@ class RPCClient:
_is_connected: CancelableEvent _is_connected: CancelableEvent
_is_disconnected: CancelableEvent _is_disconnected: CancelableEvent
_connection_lock: asyncio.Lock _connection_lock: asyncio.Lock
_logging_keys: list[str]
def __init__(self, config: Config, register_config_key: str) -> None: def __init__(self, config: Config, register_config_key: str) -> None:
self.config = config self.config = config
@ -105,6 +106,7 @@ class RPCClient:
self._is_disconnected = CancelableEvent(self.loop) self._is_disconnected = CancelableEvent(self.loop)
self._is_disconnected.set() self._is_disconnected.set()
self._connection_lock = asyncio.Lock() self._connection_lock = asyncio.Lock()
self._logging_keys = config["rpc.logging_keys"]
async def connect(self) -> None: async def connect(self) -> None:
async with self._connection_lock: async with self._connection_lock:
@ -147,7 +149,8 @@ class RPCClient:
self._read_task = self.loop.create_task(self._try_read_loop()) self._read_task = self.loop.create_task(self._try_read_loop())
await self._raw_request("register", await self._raw_request("register",
peer_id=self.config["appservice.address"], peer_id=self.config["appservice.address"],
register_config=self.config[self.register_config_key]) register_config=self.config[self.register_config_key],
logging_keys=self._logging_keys)
self._is_connected.set() self._is_connected.set()
self._is_disconnected.clear() self._is_disconnected.clear()
@ -302,7 +305,10 @@ class RPCClient:
req_id = self._next_req_id req_id = self._next_req_id
future = self._response_waiters[req_id] = self.loop.create_future() future = self._response_waiters[req_id] = self.loop.create_future()
req = {"id": req_id, "command": command, **data} req = {"id": req_id, "command": command, **data}
self.log.debug("Request %d: %s", req_id, command) self.log.debug("Request %d: %s", req_id,
', '.join(
[command] +
[f"{k}: {data[k]}" for k in self._logging_keys if k in data]))
assert self._writer is not None assert self._writer is not None
self._writer.write(json.dumps(req).encode("utf-8")) self._writer.write(json.dumps(req).encode("utf-8"))
self._writer.write(b"\n") self._writer.write(b"\n")

View File

@ -317,6 +317,8 @@ class User(DBUser, BaseUser):
oauth_credential = await Client.login(uuid=uuid, form=form, forced=True) oauth_credential = await Client.login(uuid=uuid, form=form, forced=True)
except OAuthException as e: except OAuthException as e:
latest_exc = e latest_exc = e
else:
return False
if oauth_credential: if oauth_credential:
self.oauth_credential = oauth_credential self.oauth_credential = oauth_credential
await self.save() await self.save()
@ -388,6 +390,9 @@ class User(DBUser, BaseUser):
) -> None: ) -> None:
try: try:
if not await self._load_session(is_startup=is_startup) and self.has_state: if not await self._load_session(is_startup=is_startup) and self.has_state:
self.log.debug("reload_session failure: wait for connection to Node module before prompting for manual login")
await Client.wait_for_connection()
self.log.debug("reload_session failure: now connected to Node module")
await self.send_bridge_notice( await self.send_bridge_notice(
"Logged out of KakaoTalk. Must use the `login` command to log back in.", "Logged out of KakaoTalk. Must use the `login` command to log back in.",
important=True, important=True,
@ -429,7 +434,6 @@ class User(DBUser, BaseUser):
async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None: async def logout(self, *, remove_ktid: bool = True, reset_device: bool = False) -> None:
if self._client: if self._client:
# TODO Look for a logout API call
await self._client.stop() await self._client.stop()
if remove_ktid: if remove_ktid:
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT) await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
@ -746,7 +750,10 @@ class User(DBUser, BaseUser):
reason_suffix = "To reconnect, use the `sync` command." reason_suffix = "To reconnect, use the `sync` command."
else: else:
reason_suffix = "You are now logged out. To log back in, use the `login` command." reason_suffix = "You are now logged out. To log back in, use the `login` command."
await self.send_bridge_notice(f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}") await self.send_bridge_notice(
f"Disconnected from KakaoTalk: {reason_str} {reason_suffix}",
important=True,
)
async def on_error(self, error: JSON) -> None: async def on_error(self, error: JSON) -> None:
await self.send_bridge_notice( await self.send_bridge_notice(

View File

@ -2,3 +2,6 @@
If `type` is `unix`, `path` is the path where to create the socket, and `force` is whether to overwrite the socket file if it already exists. If `type` is `unix`, `path` is the path where to create the socket, and `force` is whether to overwrite the socket file if it already exists.
If `type` is `tcp`, `port` and `host` are the host/port where to listen. If `type` is `tcp`, `port` and `host` are the host/port where to listen.
### Register timeout
`register_timeout` is the amount of time (in milliseconds) that a connecting peer must send a "register" command after initiating a connection.

View File

@ -3,5 +3,6 @@
"type": "unix", "type": "unix",
"path": "/var/run/matrix-appservice-kakaotalk/rpc.sock", "path": "/var/run/matrix-appservice-kakaotalk/rpc.sock",
"force": false "force": false
} },
"register_timeout": 3000
} }

View File

@ -422,9 +422,10 @@ export default class PeerClient {
this.stopped = false this.stopped = false
this.notificationID = 0 this.notificationID = 0
this.maxCommandID = 0 this.maxCommandID = 0
this.peerID = null this.peerID = ""
this.deviceName = "KakaoTalk Bridge" this.deviceName = "KakaoTalk Bridge"
/** @type {[string]} */
this.loggingKeys = []
/** @type {Map<string, UserClient>} */ /** @type {Map<string, UserClient>} */
this.userClients = new Map() this.userClients = new Map()
} }
@ -455,10 +456,10 @@ export default class PeerClient {
setTimeout(() => { setTimeout(() => {
if (!this.peerID && !this.stopped) { if (!this.peerID && !this.stopped) {
this.log("Didn't receive register request within 3 seconds, terminating") this.log(`Didn't receive register request within ${this.manager.registerTimeout/1000} seconds, terminating`)
this.stop("Register request timeout") this.stop("Register request timeout")
} }
}, 3000) }, this.manager.registerTimeout)
} }
async stop(error = null) { async stop(error = null) {
@ -482,11 +483,11 @@ export default class PeerClient {
if (this.peerID && this.manager.clients.get(this.peerID) === this) { if (this.peerID && this.manager.clients.get(this.peerID) === this) {
this.manager.clients.delete(this.peerID) this.manager.clients.delete(this.peerID)
} }
this.log(`Connection closed (peer: ${this.peerID})`) this.log(`Connection closed (peer: ${this.peerID || "unknown peer"})`)
} }
#closeUsers() { #closeUsers() {
this.log("Closing all API clients for", this.peerID) this.log(`Closing all API clients for ${this.peerID || "unknown peer"}`)
for (const userClient of this.userClients.values()) { for (const userClient of this.userClients.values()) {
userClient.disconnect() userClient.disconnect()
} }
@ -526,7 +527,6 @@ export default class PeerClient {
* request failed, its status is stored here. * request failed, its status is stored here.
*/ */
handleLogin = async (req) => { handleLogin = async (req) => {
// TODO Look for a logout API call
const authClient = await this.#createAuthClient(req.uuid) const authClient = await this.#createAuthClient(req.uuid)
const loginRes = await authClient.login(req.form, req.forced) const loginRes = await authClient.login(req.form, req.forced)
if (loginRes.status === KnownAuthStatusCode.DEVICE_NOT_REGISTERED) { if (loginRes.status === KnownAuthStatusCode.DEVICE_NOT_REGISTERED) {
@ -1166,10 +1166,12 @@ export default class PeerClient {
* @param {string} req.peer_id * @param {string} req.peer_id
* @param {Object} req.register_config * @param {Object} req.register_config
* @param {string} req.register_config.device_name * @param {string} req.register_config.device_name
* @param {?[string]} req.logging_keys
*/ */
handleRegister = async (req) => { handleRegister = async (req) => {
this.peerID = req.peer_id this.peerID = req.peer_id
this.deviceName = req.register_config.device_name || this.deviceName this.deviceName = req.register_config.device_name || this.deviceName
this.loggingKeys = req.logging_keys || this.loggingKeys
this.log(`Registered socket ${this.connID} -> ${this.peerID}`) this.log(`Registered socket ${this.connID} -> ${this.peerID}`)
if (this.manager.clients.has(this.peerID)) { if (this.manager.clients.has(this.peerID)) {
const oldClient = this.manager.clients.get(this.peerID) const oldClient = this.manager.clients.get(this.peerID)
@ -1200,7 +1202,12 @@ export default class PeerClient {
this.log("Ignoring old request", req.id) this.log("Ignoring old request", req.id)
return return
} }
this.log("Received request", req.id, "with command", req.command) this.log(
`Request ${req.id}:`,
[req.command].concat(
this.loggingKeys.filter(k => k in req).map(k => `${k}: ${JSON.stringify(req[k], this.#writeReplacer)}`))
.join(', ')
)
this.maxCommandID = req.id this.maxCommandID = req.id
let handler let handler
if (!this.peerID) { if (!this.peerID) {

View File

@ -22,8 +22,9 @@ import { promisify } from "./util.js"
export default class ClientManager { export default class ClientManager {
constructor(listenConfig) { constructor(listenConfig, registerTimeout) {
this.listenConfig = listenConfig this.listenConfig = listenConfig
this.registerTimeout = registerTimeout
this.server = net.createServer(this.acceptConnection) this.server = net.createServer(this.acceptConnection)
this.connections = [] this.connections = []
this.clients = new Map() this.clients = new Map()

View File

@ -32,7 +32,7 @@ const configPath = args["--config"] || "config.json"
console.log("[Main] Reading config from", configPath) console.log("[Main] Reading config from", configPath)
const config = JSON.parse(fs.readFileSync(configPath).toString()) const config = JSON.parse(fs.readFileSync(configPath).toString())
const manager = new ClientManager(config.listen) const manager = new ClientManager(config.listen, config.register_timeout)
function stop() { function stop() {
manager.stop().then(() => { manager.stop().then(() => {