# matrix-appservice-kakaotalk - A Matrix-KakaoTalk puppeting bridge.
# Copyright (C) 2022 Tulir Asokan, Andrew Ferrazzutti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any , Callable , Awaitable
import asyncio
import json
import logging
from mautrix . types . primitive import JSON
from . . config import Config
from . types import RPCError
# Signature of an event callback: an async callable that receives the decoded
# event payload, i.e. the incoming JSON object after its "id", "command" and
# "is_sequential" keys have been removed.
EventHandler = Callable [ [ dict [ str , Any ] ] , Awaitable [ None ] ]
class RPCClient:
    """Asynchronous client for a newline-delimited JSON RPC server.

    Every outgoing request is one JSON object per line carrying a positive,
    monotonically increasing ``id``. An incoming line with the same ``id`` and
    the command ``response`` or ``error`` settles the future registered for
    that request. Incoming lines with a negative ``id`` are server-initiated
    events and are dispatched to the handlers registered for their ``command``.
    """

    config: Config
    loop: asyncio.AbstractEventLoop
    log: logging.Logger = logging.getLogger("mau.rpc")

    # Seconds to sleep between connection attempts while the server is unreachable.
    connect_retry_interval: float = 10

    _reader: asyncio.StreamReader | None
    _writer: asyncio.StreamWriter | None
    _req_id: int
    _min_broadcast_id: int
    _response_waiters: dict[int, asyncio.Future[JSON]]
    _event_handlers: dict[str, list[EventHandler]]
    _command_queue: asyncio.Queue
    # Strong references to the background tasks: the event loop only keeps weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    _read_task: asyncio.Task | None = None
    _command_task: asyncio.Task | None = None

    def __init__(self, config: Config) -> None:
        """Create an unconnected client.

        Must be called while an event loop is running, because the loop is
        captured with ``asyncio.get_running_loop()``.
        """
        self.config = config
        self.loop = asyncio.get_running_loop()
        self._req_id = 0
        self._min_broadcast_id = 0
        self._event_handlers = {}
        self._response_waiters = {}
        self._writer = None
        self._reader = None
        self._command_queue = asyncio.Queue()
        self._read_task = None
        self._command_task = None

    async def connect(self) -> None:
        """Connect to the RPC server and register this peer.

        Does nothing when a writer is already set. Otherwise the connection is
        retried every ``connect_retry_interval`` seconds until it succeeds, the
        background read/command loops are started and a ``register`` request
        is sent.

        Raises:
            RuntimeError: if ``rpc.connection.type`` is neither ``unix`` nor ``tcp``.
        """
        if self._writer is not None:
            return

        conn_type = self.config["rpc.connection.type"]
        if conn_type == "unix":
            while True:
                path = self.config["rpc.connection.path"]
                try:
                    reader, writer = await asyncio.open_unix_connection(path)
                    break
                # Not a bare ``except``: cancellation and KeyboardInterrupt must
                # be able to abort the retry loop.
                except Exception:
                    self.log.warning(
                        "No unix socket available at %s, wait for it to exist...", path
                    )
                    await asyncio.sleep(self.connect_retry_interval)
        elif conn_type == "tcp":
            while True:
                host = self.config["rpc.connection.host"]
                port = self.config["rpc.connection.port"]
                try:
                    reader, writer = await asyncio.open_connection(host, port)
                    break
                except Exception:
                    self.log.warning(
                        "No TCP connection open at %s:%s, wait for it to become available...",
                        host,
                        port,
                    )
                    await asyncio.sleep(self.connect_retry_interval)
        else:
            raise RuntimeError("invalid rpc connection type")

        self._reader = reader
        self._writer = writer
        self._read_task = self.loop.create_task(self._try_read_loop())
        # The command loop never exits on its own, so reuse it across reconnects;
        # a second consumer on the same queue would break sequential ordering.
        if self._command_task is None or self._command_task.done():
            self._command_task = self.loop.create_task(self._command_loop())
        await self.request("register", peer_id=self.config["appservice.address"])

    async def disconnect(self) -> None:
        """Send EOF to the server and forget the current connection, if any."""
        writer = self._writer
        if writer is None:
            return
        try:
            writer.write_eof()
            await writer.drain()
        finally:
            # Clear the state even if flushing failed, but never clobber a
            # connection that replaced this one in the meantime.
            if self._writer is writer:
                self._writer = None
                self._reader = None

    @property
    def _next_req_id(self) -> int:
        """Allocate and return the next request id (each access increments it)."""
        self._req_id += 1
        return self._req_id

    def add_event_handler(self, method: str, handler: EventHandler) -> None:
        """Append ``handler`` to the handlers called for ``method`` events."""
        self._event_handlers.setdefault(method, []).append(handler)

    def remove_event_handler(self, method: str, handler: EventHandler) -> None:
        """Remove ``handler`` from ``method``; a handler that is not registered is ignored."""
        try:
            self._event_handlers.setdefault(method, []).remove(handler)
        except ValueError:
            pass

    def set_event_handlers(self, method: str, handlers: list[EventHandler]) -> None:
        """Replace the whole handler list for ``method`` with ``handlers``."""
        self._event_handlers[method] = handlers

    async def _run_event_handler(self, req_id: int, command: str, req: dict[str, Any]) -> None:
        """Dispatch one server event to every handler registered for ``command``.

        Event ids are negative; an id greater than the lowest id seen so far is
        dropped as a duplicate. An exception raised by one handler is logged
        and does not prevent the remaining handlers from running.
        """
        # NOTE(review): an id equal to the last one seen is not filtered by this
        # comparison - confirm whether exact repeats can occur.
        if req_id > self._min_broadcast_id:
            self.log.debug("Ignoring duplicate broadcast %s", req_id)
            return
        self._min_broadcast_id = req_id
        try:
            handlers = self._event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for %s", command)
            return
        for handler in handlers:
            try:
                await handler(req)
            except Exception:
                self.log.exception("Exception in event handler")

    async def _handle_incoming_line(self, line: str) -> None:
        """Parse one line from the server and route it.

        Negative ids are events: sequential ones are queued for the command
        loop, the others run in their own task. Non-negative ids settle the
        future of the matching pending request.
        """
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug("Got non-JSON data from server: %s", line)
            return
        # Valid JSON that is not an object (a list, a number, ...) has no fields.
        if not isinstance(req, dict):
            self.log.debug("Got invalid request from server: %s", line)
            return
        try:
            req_id = req.pop("id")
            command = req.pop("command")
            is_sequential = req.pop("is_sequential", False)
        except KeyError:
            self.log.debug("Got invalid request from server: %s", line)
            return

        if req_id < 0:
            if not is_sequential:
                self.loop.create_task(self._run_event_handler(req_id, command, req))
            else:
                self._command_queue.put_nowait((req_id, command, req))
            return

        waiter = self._response_waiters.get(req_id)
        if waiter is None:
            self.log.debug("Nobody waiting for response to %s", req_id)
            return
        if command not in ("response", "error"):
            # Keep the waiter registered: a real reply may still follow.
            self.log.warning("Unexpected response command to %s: %s %s", req_id, command, req)
            return

        # The request is settled now; drop the waiter so the map does not grow forever.
        del self._response_waiters[req_id]
        if waiter.done():
            # E.g. the caller was cancelled while waiting; nothing left to deliver.
            self.log.debug("Waiter for %s already finished, dropping %s", req_id, command)
            return
        if command == "response":
            waiter.set_result(req.get("response"))
        else:
            waiter.set_exception(RPCError(req.get("error", line)))

    async def _command_loop(self) -> None:
        """Run queued sequential events one at a time, in arrival order."""
        while True:
            req_id, command, req = await self._command_queue.get()
            try:
                await self._run_event_handler(req_id, command, req)
            finally:
                self._command_queue.task_done()

    async def _try_read_loop(self) -> None:
        """Run the read loop and log any exception that escapes it."""
        try:
            await self._read_loop()
        except Exception:
            self.log.exception("Fatal error in read loop")

    async def _read_line(self, reader: asyncio.StreamReader) -> bytes:
        """Read up to and including the next newline from ``reader``.

        Lines longer than the stream limit are collected piecewise; at EOF the
        partial data received so far is returned (possibly empty).
        """
        line = b""
        while True:
            try:
                return line + await reader.readuntil()
            except asyncio.IncompleteReadError as e:
                return line + e.partial
            except asyncio.LimitOverrunError as e:
                self.log.warning("Buffer overrun: %s", e)
                # The oversized data is still buffered; ``consumed`` is the
                # number of bytes to take out before retrying.
                line += await reader.read(e.consumed)

    async def _read_loop(self) -> None:
        """Read and handle lines until EOF, disconnect or replacement of the reader."""
        # Work on local references: disconnect() may reset the attributes while
        # this coroutine is suspended in a read.
        reader = self._reader
        writer = self._writer
        try:
            while reader is not None and self._reader is reader and not reader.at_eof():
                line = await self._read_line(reader)
                if not line:
                    continue
                try:
                    line_str = line.decode("utf-8")
                except UnicodeDecodeError:
                    self.log.exception("Got non-unicode request from server: %s", line)
                    continue
                try:
                    await self._handle_incoming_line(line_str)
                except Exception:
                    self.log.exception("Failed to handle incoming request %s", line_str)
        finally:
            # Runs on errors too (e.g. a connection reset), so the client never
            # keeps believing it is connected to a dead socket.
            self.log.debug("Reader disconnected")
            if self._reader is reader:
                self._reader = None
                self._writer = None
            if writer is not None:
                writer.close()

    async def _raw_request(
        self, command: str, is_secret: bool = False, **data: JSON
    ) -> asyncio.Future[JSON]:
        """Send a request and return the future that will hold its reply.

        Args:
            command: name of the RPC command.
            is_secret: when true, the payload is replaced by ``<REDACTED>`` in the debug log.
            **data: extra JSON-serializable fields merged into the request object.

        Returns:
            A future resolved with the ``response`` field of the reply, or
            failed with ``RPCError`` when the server answers with ``error``.
        """
        # Check the connection before allocating anything for the request.
        assert self._writer is not None
        writer = self._writer
        req_id = self._next_req_id
        future = self._response_waiters[req_id] = self.loop.create_future()
        req = {"id": req_id, "command": command, **data}
        self.log.debug(
            "Request %d: %s %s", req_id, command, data if not is_secret else "<REDACTED>"
        )
        try:
            writer.write(json.dumps(req).encode("utf-8") + b"\n")
            await writer.drain()
        except BaseException:
            # Nobody will ever await this future; do not leave it registered.
            self._response_waiters.pop(req_id, None)
            raise
        return future

    async def request(self, command: str, **data: JSON) -> JSON:
        """Send ``command`` with ``data`` and wait for the server's reply."""
        future = await self._raw_request(command, **data)
        return await future