Skip to content

Commit

Permalink
server: clean up python rpc transports
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed Mar 18, 2023
1 parent fae6661 commit 7739903
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 93 deletions.
4 changes: 2 additions & 2 deletions server/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 18 additions & 37 deletions server/python/plugin_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,6 @@ class SystemDeviceState(TypedDict):
stateTime: int
value: any


class StreamPipeReader:
    """Adapts a multiprocessing Connection into an awaitable byte reader.

    Blocking pipe reads are pushed onto a private thread pool so the
    asyncio event loop is never stalled. Callers are responsible for
    shutting down ``executor`` when finished.
    """

    def __init__(self, conn: multiprocessing.connection.Connection) -> None:
        self.conn = conn
        # dedicated pool: blocking readBlocking() calls run off-loop
        self.executor = concurrent.futures.ThreadPoolExecutor()

    def readBlocking(self, n):
        """Block until exactly n bytes have arrived and return them."""
        buf = bytearray()
        while len(buf) < n:
            # wait (forever) for the pipe to become readable
            self.conn.poll(None)
            chunk = os.read(self.conn.fileno(), n - len(buf))
            if not chunk:
                # EOF before n bytes: the peer closed the pipe
                raise Exception('unable to read requested bytes')
            buf.extend(chunk)
        return bytes(buf)

    async def read(self, n):
        """Asynchronously read exactly n bytes via the worker thread."""
        return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.readBlocking(n))

class SystemManager(scrypted_python.scrypted_sdk.types.SystemManager):
def __init__(self, api: Any, systemState: Mapping[str, Mapping[str, SystemDeviceState]]) -> None:
super().__init__()
Expand Down Expand Up @@ -288,8 +269,9 @@ async def loadZip(self, packageJson, zipData, options: dict=None):
clusterSecret = options['clusterSecret']

async def handleClusterClient(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
rpcTransport = rpc_reader.RpcStreamTransport(reader, writer)
peer: rpc.RpcPeer
peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, reader = reader, writer = writer)
peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, rpcTransport)
async def connectRPCObject(id: str, secret: str):
m = hashlib.sha256()
m.update(bytes('%s%s' % (clusterPort, clusterSecret), 'utf8'))
Expand Down Expand Up @@ -324,7 +306,8 @@ async def connectRPCObject(value):
async def connectClusterPeer():
reader, writer = await asyncio.open_connection(
'127.0.0.1', port)
peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, reader = reader, writer = writer)
rpcTransport = rpc_reader.RpcStreamTransport(reader, writer)
peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, rpcTransport)
async def run_loop():
try:
await peerReadLoop()
Expand Down Expand Up @@ -485,8 +468,8 @@ def exit_check():
schedule_exit_check()

async def getFork():
reader = StreamPipeReader(parent_conn)
forkPeer, readLoop = await rpc_reader.prepare_peer_readloop(self.loop, reader = reader, writeFd = parent_conn.fileno())
rpcTransport = rpc_reader.RpcConnectionTransport(parent_conn)
forkPeer, readLoop = await rpc_reader.prepare_peer_readloop(self.loop, rpcTransport)
forkPeer.peerName = 'thread'

async def updateStats(stats):
Expand All @@ -502,7 +485,7 @@ async def forkReadLoop():
finally:
allMemoryStats.pop(forkPeer)
parent_conn.close()
reader.executor.shutdown()
rpcTransport.executor.shutdown()
asyncio.run_coroutine_threadsafe(forkReadLoop(), loop=self.loop)
getRemote = await forkPeer.getParam('getRemote')
remote: PluginRemote = await getRemote(self.api, self.pluginId, self.hostInfo)
Expand Down Expand Up @@ -594,8 +577,8 @@ async def getServicePort(self, name):

allMemoryStats = {}

async def plugin_async_main(loop: AbstractEventLoop, readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None):
peer, readLoop = await rpc_reader.prepare_peer_readloop(loop, readFd=readFd, writeFd=writeFd, reader=reader, writer=writer)
async def plugin_async_main(loop: AbstractEventLoop, rpcTransport: rpc_reader.RpcTransport):
peer, readLoop = await rpc_reader.prepare_peer_readloop(loop, rpcTransport)
peer.params['print'] = print
peer.params['getRemote'] = lambda api, pluginId, hostInfo: PluginRemote(peer, api, pluginId, hostInfo, loop)

Expand Down Expand Up @@ -642,22 +625,22 @@ def stats_runner():
try:
await readLoop()
finally:
if reader and hasattr(reader, 'executor'):
r: StreamPipeReader = reader
if type(rpcTransport) == rpc_reader.RpcConnectionTransport:
r: rpc_reader.RpcConnectionTransport = rpcTransport
r.executor.shutdown()

def main(readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None):
def main(rpcTransport: 'rpc_reader.RpcTransport'):
    """Run the plugin's async entry point on a fresh event loop.

    Also schedules a periodic garbage collection every 10 seconds; the
    long-lived plugin process otherwise accumulates cycles from rpc proxies.
    """
    loop = asyncio.new_event_loop()

    def gc_runner():
        # collect, then reschedule ourselves on the same loop
        gc.collect()
        loop.call_later(10, gc_runner)
    gc_runner()

    loop.run_until_complete(plugin_async_main(loop, rpcTransport))
    loop.close()

def plugin_main(readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None):
def plugin_main(rpcTransport: rpc_reader.RpcTransport):
try:
import gi
gi.require_version('Gst', '1.0')
Expand All @@ -666,18 +649,16 @@ def plugin_main(readFd: int = None, writeFd: int = None, reader: asyncio.StreamR

loop = GLib.MainLoop()

worker = threading.Thread(target=main, args=(readFd, writeFd, reader, writer), name="asyncio-main")
worker = threading.Thread(target=main, args=(rpcTransport,), name="asyncio-main")
worker.start()

loop.run()
except:
main(readFd=readFd, writeFd=writeFd, reader=reader, writer=writer)
main(rpcTransport)


def plugin_fork(conn: multiprocessing.connection.Connection):
    """Entry point for a forked plugin child process.

    Wraps the inherited multiprocessing pipe in an RpcConnectionTransport
    and hands it to plugin_main.
    """
    plugin_main(rpc_reader.RpcConnectionTransport(conn))

if __name__ == "__main__":
    # Launched directly by the scrypted server: the rpc channel is passed
    # as inherited file descriptors 3 (read) and 4 (write).
    plugin_main(rpc_reader.RpcFileTransport(3, 4))
168 changes: 114 additions & 54 deletions server/python/rpc_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import base64
import json
import os
import sys
import threading
from asyncio.events import AbstractEventLoop
from os import sys
from typing import List

from typing import List, Any
import multiprocessing.connection
import aiofiles
import rpc

import concurrent.futures
import json

class BufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
Expand All @@ -36,78 +35,139 @@ def deserialize(self, value, serializationContext):
buffer = buffers.pop()
return buffer

async def readLoop(loop, peer: rpc.RpcPeer, reader: asyncio.StreamReader):
deserializationContext = {
'buffers': []
}
class RpcTransport:
    """Abstract transport used by an RpcPeer.

    A transport carries two kinds of messages: JSON-serializable dicts and
    sideband binary buffers. Subclasses implement the actual framing and
    I/O (file descriptors, asyncio streams, or multiprocessing pipes).
    """

    async def prepare(self):
        # optional async setup hook; default is a no-op
        pass

    async def read(self):
        # return the next incoming message: a dict (JSON) or bytes (buffer)
        pass

    def writeBuffer(self, buffer, reject):
        # send a binary buffer; on failure, call reject(e) if provided
        pass

    def writeJSON(self, json, reject):
        # send a JSON-serializable message; on failure, call reject(e) if provided
        pass

class RpcFileTransport(RpcTransport):
    """RpcTransport over a pair of raw file descriptors.

    Reads go through an aiofiles async file opened on readFd; writes are
    synchronous os.write calls on writeFd. Frame format: 4-byte big-endian
    length (covering type byte + payload), 1 type byte (0 = JSON,
    1 = binary buffer), then the payload.
    """

    # async file handle produced by aiofiles.open(); None until prepare()
    reader: Any

    def __init__(self, readFd: int, writeFd: int) -> None:
        super().__init__()
        self.readFd = readFd
        self.writeFd = writeFd
        self.reader = None

    async def prepare(self):
        await super().prepare()
        # open the inherited read fd as an async binary file
        self.reader = await aiofiles.open(self.readFd, mode='rb')

    async def read(self):
        """Read one framed message; returns bytes for a buffer frame or the
        parsed JSON object for a JSON frame."""
        # NOTE(review): file.read(n) can return fewer than n bytes on a
        # pipe — confirm aiofiles guarantees complete reads here, or loop
        # until the requested count is satisfied.
        lengthBytes = await self.reader.read(4)
        typeBytes = await self.reader.read(1)
        frameType = typeBytes[0]
        length = int.from_bytes(lengthBytes, 'big')
        data = await self.reader.read(length - 1)
        if frameType == 1:
            # sideband binary buffer
            return data
        return json.loads(data)

    def writeMessage(self, type: int, buffer, reject):
        # header length covers the type byte plus the payload
        length = len(buffer) + 1
        lb = length.to_bytes(4, 'big')
        try:
            for b in [lb, bytes([type]), buffer]:
                os.write(self.writeFd, b)
        except Exception as e:
            if reject:
                reject(e)

    def writeJSON(self, j, reject):
        return self.writeMessage(0, bytes(json.dumps(j), 'utf8'), reject)

    def writeBuffer(self, buffer, reject):
        return self.writeMessage(1, buffer, reject)

class RpcStreamTransport(RpcTransport):
    """RpcTransport over an asyncio StreamReader/StreamWriter pair.

    Frame format matches RpcFileTransport: 4-byte big-endian length
    (type byte + payload), 1 type byte (0 = JSON, 1 = buffer), payload.
    """

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        super().__init__()
        self.reader = reader
        self.writer = writer

    async def read(self, n: int = None):
        """Read one framed message (dict for JSON, bytes for a buffer).

        Fix: readLoop invokes rpcTransport.read() with no arguments, so a
        mandatory n raised TypeError on stream transports. The frame is now
        parsed here; n is kept as an optional raw-read escape hatch for
        backward compatibility with existing callers.
        """
        if n is not None:
            # legacy behavior: read exactly n raw bytes
            return await self.reader.readexactly(n)
        # readexactly raises IncompleteReadError if the stream ends early
        lengthBytes = await self.reader.readexactly(4)
        typeBytes = await self.reader.readexactly(1)
        length = int.from_bytes(lengthBytes, 'big')
        data = await self.reader.readexactly(length - 1)
        if typeBytes[0] == 1:
            # sideband binary buffer
            return data
        return json.loads(data)

    def writeMessage(self, type: int, buffer, reject):
        # header length covers the type byte plus the payload
        length = len(buffer) + 1
        lb = length.to_bytes(4, 'big')
        try:
            for b in [lb, bytes([type]), buffer]:
                self.writer.write(b)
        except Exception as e:
            if reject:
                reject(e)

    def writeJSON(self, j, reject):
        return self.writeMessage(0, bytes(json.dumps(j), 'utf8'), reject)

    def writeBuffer(self, buffer, reject):
        return self.writeMessage(1, buffer, reject)

class RpcConnectionTransport(RpcTransport):
    """RpcTransport over a multiprocessing Connection.

    Connection I/O is blocking, so reads are delegated to a single worker
    thread. The pipe pickles whole objects, so no length framing is needed;
    messages arrive as complete dicts or bytes. Callers shut down
    ``executor`` when the transport is finished.
    """

    def __init__(self, connection: multiprocessing.connection.Connection) -> None:
        super().__init__()
        self.connection = connection
        # one worker serializes all blocking recv() calls
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    async def read(self):
        # get_running_loop: read() is always awaited inside a running loop,
        # and get_event_loop() is deprecated in that position
        return await asyncio.get_running_loop().run_in_executor(self.executor, self.connection.recv)

    def writeMessage(self, message, reject):
        # renamed from 'json' to avoid shadowing the json module
        try:
            self.connection.send(message)
        except Exception as e:
            if reject:
                reject(e)

    def writeJSON(self, json, reject):
        return self.writeMessage(json, reject)

    def writeBuffer(self, buffer, reject):
        # normalize memoryview/bytearray buffers to picklable bytes
        return self.writeMessage(bytes(buffer), reject)

async def readLoop(loop, peer: 'rpc.RpcPeer', rpcTransport: 'RpcTransport'):
    """Pump messages from the transport into the peer until read() raises.

    Non-dict messages are sideband binary buffers: they accumulate in the
    deserialization context and attach to the next JSON message. The context
    is replaced after each dispatch so buffers never leak across messages.
    """
    deserializationContext = {
        'buffers': []
    }

    while True:
        message = await rpcTransport.read()

        if not isinstance(message, dict):
            # sideband buffer for the upcoming JSON message
            deserializationContext['buffers'].append(message)
            continue

        # hand off to the peer on its own loop; readLoop may be driven
        # from a different thread than the peer's loop
        asyncio.run_coroutine_threadsafe(
            peer.handleMessage(message, deserializationContext), loop)

        deserializationContext = {
            'buffers': []
        }

async def prepare_peer_readloop(loop: AbstractEventLoop, readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None):
reader = reader or await aiofiles.open(readFd, mode='rb')
async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport):
await rpcTransport.prepare()

mutex = threading.Lock()

if writer:
def write(buffers, reject):
try:
for b in buffers:
writer.write(b)
except Exception as e:
if reject:
reject(e)
return None
else:
def write(buffers, reject):
try:
for b in buffers:
os.write(writeFd, b)
except Exception as e:
if reject:
reject(e)

def send(message, reject=None, serializationContext=None):
with mutex:
if serializationContext:
buffers = serializationContext.get('buffers', None)
if buffers:
for buffer in buffers:
length = len(buffer) + 1
lb = length.to_bytes(4, 'big')
type = 1
write([lb, bytes([type]), buffer], reject)

jsonString = json.dumps(message)
b = bytes(jsonString, 'utf8')
length = len(b) + 1
lb = length.to_bytes(4, 'big')
type = 0
write([lb, bytes([type]), b], reject)
rpcTransport.writeBuffer(buffer, reject)

rpcTransport.writeJSON(message, reject)

peer = rpc.RpcPeer(send)
peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer()
Expand All @@ -117,7 +177,7 @@ def send(message, reject=None, serializationContext=None):

async def peerReadLoop():
try:
await readLoop(loop, peer, reader)
await readLoop(loop, peer, rpcTransport)
except:
peer.kill()
raise
Expand Down

0 comments on commit 7739903

Please sign in to comment.