Skip to content

Commit

Permalink
rpc: implement python async iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed Mar 3, 2023
1 parent b2e5801 commit 096c036
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 93 deletions.
27 changes: 1 addition & 26 deletions server/python/plugin_remote.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import base64
import gc
import sys
import os
Expand All @@ -18,7 +17,7 @@
from collections.abc import Mapping
from io import StringIO
from os import sys
from typing import Any, List, Optional, Set, Tuple
from typing import Any, Optional, Set, Tuple

import scrypted_python.scrypted_sdk.types
from scrypted_python.scrypted_sdk import ScryptedStatic, PluginFork
Expand Down Expand Up @@ -202,30 +201,6 @@ async def requestRestart(self) -> None:
def getDeviceStorage(self, nativeId: str = None) -> Storage:
return self.nativeIds.get(nativeId, None)


class BufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
return base64.b64encode(value).decode('utf8')

def deserialize(self, value, serializationContext):
return base64.b64decode(value)


class SidebandBufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
buffers = serializationContext.get('buffers', None)
if not buffers:
buffers = []
serializationContext['buffers'] = buffers
buffers.append(value)
return len(buffers) - 1

def deserialize(self, value, serializationContext):
buffers: List = serializationContext.get('buffers', None)
buffer = buffers.pop()
return buffer


class PluginRemote:
systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {}
nativeIds: Mapping[str, DeviceStorage] = {}
Expand Down
21 changes: 21 additions & 0 deletions server/python/rpc-iterator-test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import sys
import asyncio
from rpc_reader import prepare_peer_readloop

async def main():
peer, peerReadLoop = await prepare_peer_readloop(loop, 4, 3)
peer.params['foo'] = 3

async def ticker(delay, to):
for i in range(to):
# print(i)
yield i
await asyncio.sleep(delay)

peer.params['ticker'] = ticker(0, 3)

print('python starting')
await peerReadLoop()

loop = asyncio.new_event_loop()
loop.run_until_complete(main())
103 changes: 78 additions & 25 deletions server/python/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import inspect
from typing_extensions import TypedDict
import weakref
import sys

jsonSerializable = set()
jsonSerializable.add(float)
Expand All @@ -16,14 +15,16 @@


async def maybe_await(value):
if (inspect.iscoroutinefunction(value) or inspect.iscoroutine(value)):
if (inspect.isawaitable(value)):
return await value
return value


class RpcResultException(Exception):
name = None
stack = None
class RPCResultError(Exception):
name: str
stack: str
message: str
caught: Exception

def __init__(self, caught, message):
self.caught = caught
Expand Down Expand Up @@ -85,6 +86,8 @@ def __apply__(self, method: str, args: list):


class RpcPeer:
RPC_RESULT_ERROR_NAME = 'RPCResultError'

def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None:
self.send = send
self.idCounter = 1
Expand Down Expand Up @@ -127,10 +130,52 @@ async def send(id: str, reject: Callable[[Exception], None]):
def kill(self):
self.killed = True

def createErrorResult(self, result: Any, name: str, message: str, tb: str):
result['stack'] = tb if tb else 'no stack'
result['result'] = name if name else 'no name'
result['message'] = message if message else 'no message'
def createErrorResult(self, result: Any, e: Exception):
s = self.serializeError(e)
result['result'] = s
result['throw'] = True

# TODO 3/2/2023 deprecate these properties
tb = traceback.format_exc()
message = str(e)
result['stack'] = tb or '[no stack]',
result['message'] = message or '[no message]',
# END TODO


def deserializeError(e: Dict) -> RPCResultError:
error = RPCResultError(None, e.get('message'))
error.stack = e.get('stack')
error.name = e.get('name')
return error

def serializeError(self, e: Exception):
tb = traceback.format_exc()
name = type(e).__name__
message = str(e)

serialized = {
'stack': tb or '[no stack]',
'name': name or '[no name]',
'message': message or '[no message]',
}

return {
'__remote_constructor_name': RpcPeer.RPC_RESULT_ERROR_NAME,
'__serialized_value': serialized,
}

def getProxyProperties(self, value):
if not hasattr(value, '__aiter__') or not hasattr(value, '__anext__'):
return getattr(value, '__proxy_props', None)

props = getattr(value, '__proxy_props', None) or {}
props['Symbol(Symbol.asyncIterator)'] = {
'next': '__anext__',
'throw': 'athrow',
'return': 'asend',
}
return props

def serialize(self, value, requireProxy, serializationContext: Dict):
if (not value or (not requireProxy and type(value) in jsonSerializable)):
Expand All @@ -139,6 +184,9 @@ def serialize(self, value, requireProxy, serializationContext: Dict):
__remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr(
value, '__proxy_constructor') else type(value).__name__

if isinstance(value, Exception):
return self.serializeError(value)

proxiedEntry = self.localProxied.get(value, None)
if proxiedEntry:
proxiedEntry['finalizerId'] = str(self.proxyCounter)
Expand All @@ -147,7 +195,7 @@ def serialize(self, value, requireProxy, serializationContext: Dict):
'__remote_proxy_id': proxiedEntry['id'],
'__remote_proxy_finalizer_id': proxiedEntry['finalizerId'],
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
}
return ret
Expand All @@ -170,7 +218,7 @@ def serialize(self, value, requireProxy, serializationContext: Dict):
'__remote_proxy_id': None,
'__remote_proxy_finalizer_id': None,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
'__serialized_value': serialized,
}
Expand All @@ -189,7 +237,7 @@ def serialize(self, value, requireProxy, serializationContext: Dict):
'__remote_proxy_id': proxyId,
'__remote_proxy_finalizer_id': proxyId,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
}

Expand Down Expand Up @@ -235,6 +283,9 @@ def deserialize(self, value, deserializationContext: Dict):
__remote_proxy_oneway_methods = value.get(
'__remote_proxy_oneway_methods', None)

if __remote_constructor_name == RpcPeer.RPC_RESULT_ERROR_NAME:
return self.deserializeError(__serialized_value);

if __remote_proxy_id:
weakref = self.remoteWeakProxies.get('__remote_proxy_id', None)
proxy = weakref() if weakref else None
Expand All @@ -247,7 +298,7 @@ def deserialize(self, value, deserializationContext: Dict):
if __local_proxy_id:
ret = self.localProxyMap.get(__local_proxy_id, None)
if not ret:
raise RpcResultException(
raise RPCResultError(
None, 'invalid local proxy id %s' % __local_proxy_id)
return ret

Expand All @@ -258,7 +309,7 @@ def deserialize(self, value, deserializationContext: Dict):

return value

async def handleMessage(self, message: Any, deserializationContext: Dict):
async def handleMessage(self, message: Dict, deserializationContext: Dict):
try:
messageType = message['type']
if messageType == 'param':
Expand Down Expand Up @@ -310,11 +361,10 @@ async def handleMessage(self, message: Any, deserializationContext: Dict):
value = await maybe_await(target(*args))

result['result'] = self.serialize(value, False, serializationContext)
except StopAsyncIteration as e:
self.createErrorResult(result, e)
except Exception as e:
tb = traceback.format_exc()
# print('failure', method, e, tb)
self.createErrorResult(
result, type(e).__name__, str(e), tb)
self.createErrorResult(result, e)

if not message.get('oneway', False):
self.send(result, None, serializationContext)
Expand All @@ -323,18 +373,21 @@ async def handleMessage(self, message: Any, deserializationContext: Dict):
id = message['id']
future = self.pendingResults.get(id, None)
if not future:
raise RpcResultException(
raise RPCResultError(
None, 'unknown result %s' % id)
del self.pendingResults[id]
if hasattr(message, 'message') or hasattr(message, 'stack'):
e = RpcResultException(
if (hasattr(message, 'message') or hasattr(message, 'stack')) and not hasattr(message, 'throw'):
e = RPCResultError(
None, message.get('message', None))
e.stack = message.get('stack', None)
e.name = message.get('name', None)
future.set_exception(e)
return
future.set_result(self.deserialize(
message.get('result', None), deserializationContext))
deserialized = self.deserialize(message.get('result', None), deserializationContext)
if message.get('throw'):
future.set_exception(deserialized)
else:
future.set_result(deserialized)
elif messageType == 'finalize':
finalizerId = message.get('__local_proxy_finalizer_id', None)
proxyId = message['__local_proxy_id']
Expand All @@ -347,7 +400,7 @@ async def handleMessage(self, message: Any, deserializationContext: Dict):
self.localProxied.pop(local, None)
local = self.localProxyMap.pop(proxyId, None)
else:
raise RpcResultException(
raise RPCResultError(
None, 'unknown rpc message type %s' % type)
except Exception as e:
print("unhandled rpc error", self.peerName, e)
Expand All @@ -361,7 +414,7 @@ async def createPendingResult(self, cb: Callable[[str, Callable[[Exception], Non
self.idCounter = self.idCounter + 1
future = Future()
self.pendingResults[id] = future
await cb(id, lambda e: future.set_exception(RpcResultException(e, None)))
await cb(id, lambda e: future.set_exception(RPCResultError(e, None)))
return await future

async def getParam(self, param):
Expand Down
21 changes: 2 additions & 19 deletions server/python/rpc_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,16 @@

import asyncio
import base64
import gc
import json
import sys
import os
import platform
import shutil
import subprocess
import sys
import threading
import time
import traceback
import zipfile
from asyncio.events import AbstractEventLoop
from asyncio.futures import Future
from asyncio.streams import StreamReader, StreamWriter
from collections.abc import Mapping
from io import StringIO
from os import sys
from typing import Any, List, Optional, Set, Tuple
from typing import List

import aiofiles
import scrypted_python.scrypted_sdk.types
from scrypted_python.scrypted_sdk import ScryptedStatic, PluginFork
from scrypted_python.scrypted_sdk.types import Device, DeviceManifest, EventDetails, ScryptedInterfaceProperty, Storage
from typing_extensions import TypedDict
import rpc
import multiprocessing
import multiprocessing.connection


class BufferSerializer(rpc.RpcSerializer):
Expand Down
Loading

0 comments on commit 096c036

Please sign in to comment.