2022-02-25 02:22:50 -05:00
# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
2022-03-18 03:52:55 -04:00
from typing import Any , Callable , Awaitable
2022-02-25 02:22:50 -05:00
import asyncio
import json
import logging
from mautrix . types . primitive import JSON
from . . config import Config
from . types import RPCError
# Type of a broadcast event handler: an async callable the RPC client awaits with
# the decoded request dict, after the "id", "command" and "is_sequential" keys
# have been popped from it.
EventHandler = Callable[[dict[str, Any]], Awaitable[None]]
2022-04-08 05:04:46 -04:00
class CancelableEvent:
    """An ``asyncio.Event`` wrapper whose waiters can all be aborted at once.

    All concurrent :meth:`wait` calls share one task that waits on the
    underlying event, so :meth:`cancel` can wake every waiter with
    ``asyncio.CancelledError`` by cancelling that single task. Cancellation is
    permanent: once :meth:`cancel` has been called, later :meth:`wait` calls
    raise immediately.
    """

    _event: asyncio.Event
    _task: asyncio.Task | None
    _cancelled: bool
    _loop: asyncio.AbstractEventLoop

    def __init__(self, loop: asyncio.AbstractEventLoop | None = None):
        """Create an unset, uncancelled event.

        Args:
            loop: Loop on which the shared wait task is created. Defaults to the
                running loop, so without it this must be called from a coroutine.
        """
        self._event = asyncio.Event()
        self._task = None
        self._cancelled = False
        self._loop = loop or asyncio.get_running_loop()

    def is_set(self) -> bool:
        """Return whether the event is currently set."""
        return self._event.is_set()

    def set(self) -> None:
        """Set the event, waking every pending waiter."""
        self._event.set()
        # The shared task completes on its own now; forget it so a later
        # clear() + wait() cycle starts a fresh one.
        self._task = None

    def clear(self) -> None:
        """Reset the event to the unset state."""
        self._event.clear()

    async def wait(self) -> None:
        """Wait until the event is set.

        Raises:
            asyncio.CancelledError: if :meth:`cancel` is called before or
                during the wait.
        """
        if self._cancelled:
            raise asyncio.CancelledError()
        if self._event.is_set():
            return
        if self._task is None or self._task.done():
            self._task = self._loop.create_task(self._event.wait())
        # Shield the shared task: awaiting it directly would make the
        # cancellation of any single caller cancel the task every other waiter
        # depends on. cancel() still reaches everyone, because a cancelled
        # inner task cancels each shield wrapper as well.
        await asyncio.shield(self._task)

    def cancel(self) -> None:
        """Cancel permanently, waking current waiters with ``CancelledError``."""
        self._cancelled = True
        if self._task is not None:
            self._task.cancel()

    def cancelled(self) -> bool:
        """Return whether :meth:`cancel` has been called."""
        return self._cancelled
2022-02-25 02:22:50 -05:00
class RPCClient:
    """JSON-lines RPC client for the server configured under ``rpc.connection``.

    Every message is one JSON object per line. Requests sent by this client use
    increasing positive ``id`` values; the server answers with a line whose
    ``command`` is ``"response"`` or ``"error"`` and which carries the same
    ``id``. Lines with a negative ``id`` are broadcasts: they are passed to the
    handlers registered for their ``command``, concurrently by default, or one
    at a time in arrival order when the line has a truthy ``is_sequential``.
    """

    config: Config
    loop: asyncio.AbstractEventLoop
    log: logging.Logger = logging.getLogger("mau.rpc")

    _reader: asyncio.StreamReader | None
    _writer: asyncio.StreamWriter | None
    _req_id: int
    # Latest (i.e. lowest) broadcast id accepted on this connection; 0 = none yet.
    _min_broadcast_id: int
    _response_waiters: dict[int, asyncio.Future[JSON]]
    _event_handlers: dict[str, list[EventHandler]]
    _command_queue: asyncio.Queue
    # Strong references to fire-and-forget tasks so they are not garbage
    # collected before they finish.
    _background_tasks: set[asyncio.Task]

    _read_task: asyncio.Task | None
    _connection_task: asyncio.Task | None
    _is_connected: CancelableEvent
    _is_disconnected: CancelableEvent
    _connection_lock: asyncio.Lock

    def __init__(self, config: Config) -> None:
        """Create a disconnected client; must be called inside a running loop.

        Also starts the background task that runs sequential broadcasts.
        """
        self.config = config
        self.loop = asyncio.get_running_loop()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._command_queue = asyncio.Queue()
        self._background_tasks = set()
        self._spawn(self._command_loop())
        self._read_task = None
        self._connection_task = None
        self._is_connected = CancelableEvent(self.loop)
        self._is_disconnected = CancelableEvent(self.loop)
        self._is_disconnected.set()
        self._connection_lock = asyncio.Lock()

    def _spawn(self, coro: Any) -> asyncio.Task:
        """Run coroutine *coro* as a task, keeping a reference until it finishes."""
        task = self.loop.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def connect(self) -> None:
        """Connect to the RPC server and register, unless already connected.

        Serialized with :meth:`disconnect` via ``_connection_lock``. The work
        runs as ``_connection_task`` so that :meth:`cancel` can abort it.

        Raises:
            asyncio.CancelledError: if the client has been cancelled.
        """
        async with self._connection_lock:
            if self._is_connected.cancelled():
                raise asyncio.CancelledError()
            if self._is_connected.is_set():
                return
            self._connection_task = self.loop.create_task(self._connect())
            try:
                await self._connection_task
            finally:
                self._connection_task = None

    async def _connect(self) -> None:
        """Open the configured connection (retrying every 10 s), then register.

        Raises:
            RuntimeError: if ``rpc.connection.type`` is neither unix nor tcp.
        """
        conn_type = self.config["rpc.connection.type"]
        if conn_type == "unix":
            path = self.config["rpc.connection.path"]
            while True:
                try:
                    reader, writer = await asyncio.open_unix_connection(path)
                    break
                except OSError:
                    self.log.warning(
                        f"No unix socket available at {path}, wait for it to exist..."
                    )
                    await asyncio.sleep(10)
        elif conn_type == "tcp":
            host = self.config["rpc.connection.host"]
            port = self.config["rpc.connection.port"]
            while True:
                try:
                    reader, writer = await asyncio.open_connection(host, port)
                    break
                except OSError:
                    # Log the port actually dialled (previously printed the unix path).
                    self.log.warning(
                        f"No TCP connection open at {host}:{port}, "
                        "wait for it to become available..."
                    )
                    await asyncio.sleep(10)
        else:
            raise RuntimeError("invalid rpc connection type")
        self._reader = reader
        self._writer = writer
        self._read_task = self.loop.create_task(self._try_read_loop())
        try:
            registration = await self._raw_request(
                "register", peer_id=self.config["appservice.address"]
            )
        except BaseException:
            # Don't leave a half-open connection behind: the next connect() would
            # otherwise open another one and leak this socket and read task.
            self._read_task.cancel()
            self._read_task = None
            self._on_disconnect()
            raise
        # The registration response is not awaited here; observe its outcome in a
        # callback so a failure is logged instead of silently dropped.
        registration.add_done_callback(self._log_registration_result)
        self._is_connected.set()
        self._is_disconnected.clear()

    def _log_registration_result(self, future: asyncio.Future[JSON]) -> None:
        """Log a failed ``register`` request; its result is otherwise unused."""
        if not future.cancelled() and future.exception() is not None:
            self.log.warning("RPC register request failed: %s", future.exception())

    async def disconnect(self) -> None:
        """Disconnect from the RPC server, unless already disconnected.

        Raises:
            asyncio.CancelledError: if the client has been cancelled.
        """
        async with self._connection_lock:
            if self._is_disconnected.cancelled():
                raise asyncio.CancelledError()
            if self._is_disconnected.is_set():
                return
            await self._disconnect()

    async def _disconnect(self) -> None:
        """Half-close the connection, stop reading and reset connection state."""
        writer = self._writer
        if writer is not None:
            try:
                writer.write_eof()
                await writer.drain()
            except OSError as e:
                # The peer may already be gone; the socket is closed below anyway.
                self.log.debug("Error while shutting down RPC connection: %s", e)
        if self._read_task is not None:
            self._read_task.cancel()
            self._read_task = None
        self._on_disconnect()

    def _on_disconnect(self) -> None:
        """Reset per-connection state, close the socket, fail unanswered requests."""
        self._req_id = 0
        self._min_broadcast_id = 0
        self._reader = None
        writer, self._writer = self._writer, None
        if writer is not None:
            # Dropping the reference does not release the socket; close() does.
            writer.close()
        # Request ids restart at 1 on the next connection, so a waiter still
        # registered now can never be answered. Fail it rather than letting the
        # caller hang forever (or be overwritten by a new request with its id).
        pending = list(self._response_waiters.values())
        self._response_waiters.clear()
        for future in pending:
            if not future.done():
                future.set_exception(
                    ConnectionError("RPC connection closed before a response was received")
                )
        self._is_connected.clear()
        self._is_disconnected.set()

    def wait_for_connection(self) -> Awaitable[None]:
        """Return an awaitable that resolves once the client is connected."""
        return self._is_connected.wait()

    def wait_for_disconnection(self) -> Awaitable[None]:
        """Return an awaitable that resolves once the client is disconnected."""
        return self._is_disconnected.wait()

    def cancel(self) -> None:
        """Permanently cancel the client and tear down the connection.

        Current and future connection-state waits raise ``CancelledError``, an
        in-progress :meth:`connect` is aborted, and the connection is shut down
        on the client's own event loop.
        """
        self._is_connected.cancel()
        self._is_disconnected.cancel()
        if self._connection_task is not None:
            self._connection_task.cancel()
        if self.loop.is_closed():
            # Nothing can run on a closed loop any more.
            return
        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            current_loop = None
        if current_loop is self.loop:
            # asyncio.run() raises when called from a running loop; schedule instead.
            self._spawn(self._disconnect())
        elif self.loop.is_running():
            asyncio.run_coroutine_threadsafe(self._disconnect(), self.loop)
        else:
            # Stopped but open: the streams belong to this loop, not a fresh one.
            self.loop.run_until_complete(self._disconnect())

    @property
    def _next_req_id(self) -> int:
        """Advance and return the request id counter (first id is 1)."""
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Append *handler* to the handlers for broadcast command *method*."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Remove *handler* for *method*; silently ignore unknown handlers."""
        handlers = self._event_handlers.get(method)
        if handlers is None:
            return
        try:
            handlers.remove(handler)
        except ValueError:
            pass

    def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
        """Replace all handlers for *method* with *handlers* (stored as given)."""
        self._event_handlers[method] = handlers

    async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
        """Await every handler registered for *command* with payload *req*.

        Duplicate filtering happens on receipt in :meth:`_handle_incoming_line`;
        *req_id* is kept for signature compatibility.
        """
        handlers = self._event_handlers.get(command)
        if not handlers:
            self.log.warning("No handlers for %s", command)
            return
        # Iterate over a snapshot: a handler may add/remove handlers while we await it.
        for handler in list(handlers):
            try:
                await handler(req)
            except (Exception, asyncio.CancelledError):
                # A failing or cancelled handler must not stop the remaining ones.
                self.log.exception("Exception in event handler")

    def _forget_waiter(self, req_id: int, future: asyncio.Future[JSON]) -> None:
        """Drop *future* from the waiter table unless *req_id* was reused since."""
        if self._response_waiters.get(req_id) is future:
            del self._response_waiters[req_id]

    async def _handle_incoming_line(self, line: str) -> None:
        """Decode one line from the server and dispatch it.

        Broadcasts (negative id) go to event handlers; anything else resolves
        the waiter of the request with that id.
        """
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        if not isinstance(req, dict):
            self.log.debug(f"Got invalid request from server: {line}")
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
        except KeyError:
            self.log.debug(f"Got invalid request from server: {line}")
            return
        is_sequential = req.pop("is_sequential", False)
        if req_id < 0:
            # Broadcast ids decrease, so anything not below the latest accepted
            # id is a repeat. Check on receipt (lines arrive in order) rather than
            # at dispatch, where queued sequential events would run late and be
            # wrongly discarded.
            if req_id >= self._min_broadcast_id:
                self.log.debug(f"Ignoring duplicate broadcast {req_id}")
                return
            self._min_broadcast_id = req_id
            if is_sequential:
                self._command_queue.put_nowait((req_id, command, req))
            else:
                self._spawn(self._run_event_handler(req_id, command, req))
            return
        waiter = self._response_waiters.get(req_id)
        if waiter is None:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if waiter.done():
            # Typically the requester was cancelled; resolving again would raise.
            self.log.debug(f"Waiter for {req_id} is already done, dropping {command}")
            return
        if command == "response":
            waiter.set_result(req.get("response"))
        elif command == "error":
            waiter.set_exception(RPCError(req.get("error", line)))
        else:
            self.log.warning(f"Unexpected response command to {req_id}: {command} {req}")

    async def _command_loop(self) -> None:
        """Run sequential broadcasts one at a time, in the order received."""
        while True:
            req_id, command, req = await self._command_queue.get()
            await self._run_event_handler(req_id, command, req)
            self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run the read loop, resetting connection state when it ends on its own."""
        try:
            await self._read_loop()
        except asyncio.CancelledError:
            # Whoever cancelled us is tearing the connection down and resets state.
            return
        except Exception:
            self.log.exception("Fatal error in read loop")
        self.log.debug("Reader disconnected")
        self._on_disconnect()

    async def _read_loop(self) -> None:
        """Read newline-delimited messages until EOF and handle each one."""
        while self._reader is not None and not self._reader.at_eof():
            line = b""
            while True:
                try:
                    line += await self._reader.readuntil()
                    break
                except asyncio.IncompleteReadError as e:
                    # EOF before a newline: keep whatever did arrive.
                    line += e.partial
                    break
                except asyncio.LimitOverrunError as e:
                    # Line exceeds the buffer limit: take one buffer's worth and
                    # keep reading until the newline shows up.
                    self.log.warning(f"Buffer overrun: {e}")
                    line += await self._reader.read(self._reader._limit)
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)

    async def _raw_request(self, command: str, **data: JSON) -> asyncio.Future[JSON]:
        """Send a request and return the future that its response will resolve.

        Raises:
            ConnectionError: if there is no open connection.
        """
        writer = self._writer
        if writer is None:
            raise ConnectionError("Not connected to the RPC server")
        req_id = self._next_req_id
        # Serialize before registering the waiter so a bad payload leaks nothing.
        payload = json.dumps({"id": req_id, "command": command, **data}).encode("utf-8")
        future = self.loop.create_future()
        self._response_waiters[req_id] = future
        # Remove the entry once resolved, failed or cancelled by the caller.
        future.add_done_callback(lambda fut, rid=req_id: self._forget_waiter(rid, fut))
        self.log.debug("Request %d: %s", req_id, command)
        try:
            writer.write(payload + b"\n")
            await writer.drain()
        except BaseException:
            future.cancel()
            raise
        return future

    async def request(self, command: str, **data: JSON) -> JSON:
        """Wait for a connection, send *command* and return the server's response.

        Raises:
            RPCError: if the server answers with an ``error`` command.
            ConnectionError: if the connection is missing or closes before the
                answer arrives.
            asyncio.CancelledError: if the client was cancelled via :meth:`cancel`.
        """
        await self.wait_for_connection()
        future = await self._raw_request(command, **data)
        return await future