Skip to content

Commit

Permalink
Moved statechange signals to sharedstate
Browse files Browse the repository at this point in the history
  • Loading branch information
jooste committed Dec 20, 2024
1 parent 754d504 commit 14e4378
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 120 deletions.
162 changes: 104 additions & 58 deletions bluesky/network/sharedstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,80 @@
BlueSky's sharedstate is used to keep a shared state
across client(s) and simulation node(s)
'''
from typing import Any, Generic, Optional, TypeVar
from typing import Any, Generic, Optional, Type, TypeVar
import numpy as np
from numbers import Number
from functools import partial

from types import SimpleNamespace
from copy import deepcopy
from collections import defaultdict

import bluesky as bs
from bluesky.core import signal
from bluesky.network import context as ctx
from bluesky.network.common import ActionType


# Keep track of the set of subscribed sharedstate topics. Store signals to emit
# whenever a state update of each topic is received
sigchanged: dict[str, signal.Signal] = dict()


def reset(remote_id=None):
''' Reset shared state data to defaults for remote simulation. '''
remotes[remote_id or bs.net.act_id] = _genstore()

# If this is the active node, also emit a signal about this change
if ctx.sender_id == bs.net.act_id:
ctx.action = ActionType.Reset
ctx.action_content = None
for topic, sig in sigchanged.items():
store = get(group=topic.lower())
sig.emit(store)
ctx.action = ActionType.NoAction


@signal.subscriber(topic='actnode-changed')
def on_actnode_changed(act_id):
ctx.action = ActionType.ActChange
ctx.action_content = None
for topic, sig in sigchanged.items():
store = get(group=topic.lower())
sig.emit(store)
ctx.action = ActionType.NoAction


def on_sharedstate_received(action, data):
''' Retrieve and process state data. '''
store = get(ctx.sender_id, ctx.topic.lower())

# Store sharedstate context
ctx.action = ActionType(action)
ctx.action_content = data

if ctx.action == ActionType.Update:
store.update(data)

elif ctx.action == ActionType.Append:
store.append(data)

elif ctx.action == ActionType.Extend:
store.extend(data)

elif ctx.action == ActionType.Replace:
store.replace(data)

elif ctx.action == ActionType.Delete:
store.delete(data)

# Inform subscribers of state update
# TODO: what to do with act vs all?
if ctx.sender_id == bs.net.act_id:
sigchanged[ctx.topic.lower()].emit(store)

# Reset context variables
ctx.action = ActionType.NoAction
ctx.action_content = None


def get(remote_id=None, group=None):
''' Retrieve a remote store, or a group in a remote store.
Expand All @@ -38,15 +96,19 @@ def setvalue(name, value, remote_id=None, group=None):

def setdefault(name: str, default: Any, group: Optional[str]=None):
''' Set the default value for variable 'name' in group 'group' '''
target = getattr(defaults, group, None) if group else defaults
if not target:
if group is not None:
setattr(defaults, group, Store(**{name:default}))
return
# Add group if it doesn't exist yet
if group is not None:
addtopic(group)

target = getattr(defaults, group) if group else defaults
if not hasattr(target, name):
# In case remote data already exists and this is a previously unknown variable, update stores
for remote_id in remotes.keys():
setvalue(name, deepcopy(default), remote_id, group)
setattr(target, name, default)


def addtopic(topic):
def addtopic(topic: str) -> signal.Signal:
''' Add a sharedstate topic if it doesn't yet exist.
This creates a storage group for this topic, which is added to each
Expand All @@ -56,16 +118,21 @@ def addtopic(topic):
- topic: The sharedstate topic to add
'''
topic = topic.lower()

# Create/get the signal that is emitted when this data changes
sig = signal.Signal(f'state-changed.{topic}')
sigchanged[topic] = sig

# No creation needed if topic is already known
if hasattr(defaults, topic):
return
if not hasattr(defaults, topic):
# Add store to the defaults
setattr(defaults, topic, Store())

# Add store to the defaults
setattr(defaults, topic, Store())
# Also add to existing stores if necessary
for remote in remotes.values():
setattr(remote, topic, Store())

# Also add to existing stores if necessary
for remote in remotes.values():
setattr(remote, topic, Store())
return sig


class Store(SimpleNamespace):
Expand Down Expand Up @@ -177,42 +244,31 @@ def __set_name__(self, owner, name):
# Get name from attribute name if not previously specified
self.name = name

# Retrieve annotated object type if present
objtype = owner.__annotations__.get(name)
if objtype:
self.default = objtype() if self.default is None else objtype(self.default)

# If underlying datatype is mutable, always immediately
# store per remote node
if not isinstance(self.default, (str, tuple, Number, frozenset, bytes)):
# If an annotated object type is specified create a generator for it
if objtype:
generators.append(partial(_generator, name=name, objtype=objtype, ctor_arg=self.default, group=self.group))
# Add group if it doesn't exist yet
if self.group is not None:
addtopic(self.group)

# In case remote data already exists, update stores
for remote_id in remotes.keys():
setvalue(name, objtype(self.default), remote_id, self.group)
# Otherwise assume deepcopy can be used to generate initial values per remote
else:
# If we have a default value, and it is not centrally known yet
# store it for Store generation
store = defaults if self.group is None else getattr(defaults, self.group, None)
if store is None or not hasattr(store, name):
if self.default is not None:
# If specified use our default
setdefault(name, self.default, self.group)

# In case remote data already exists, update stores
for remote_id in remotes.keys():
setvalue(name, deepcopy(self.default), remote_id, self.group)
else:
# Otherwise try to get the default type from annotation
# Look for __origin__ in case we have a GenericAlias like list[int]
tp = owner.__annotations__.get(name)
tp = getattr(tp, '__origin__', tp)
if isinstance(tp, Type):
# Exception case for NumPy ndarray,
# which has shape as a mandatory ctor argument
if tp is np.ndarray:
setdefault(name, np.ndarray(0), self.group)
else:
setdefault(name, tp(), self.group)

def __get__(self, obj, objtype=None) -> T:
''' Return the actual value for the currently active node. '''
# print('GETTER:', self, self.default)
if not bs.net.act_id and self.default is not None:
return self.default
if self.default is not None:
return getattr(
remotes[bs.net.act_id] if self.group is None else getattr(remotes[bs.net.act_id], self.group),
self.name, self.default)
raise KeyError(f'ActData: {self.name} not found, group={self.group}, active node id={bs.net.act_id}')
store = remotes[bs.net.act_id] if bs.net.act_id else defaults
return getattr(store if self.group is None else getattr(store, self.group), self.name)
# raise KeyError(f'ActData: {self.name} not found, group={self.group}, active node id={bs.net.act_id}')
# TODO: What is the active (client) node on the sim-side? Is this always within a currently processed stack command? -> stack.sender_id

def __set__(self, obj, value: T):
Expand All @@ -236,22 +292,12 @@ def _recursive_update(target, source):
def _genstore():
''' Generate a store object for a remote simulation from defaults. '''
store = deepcopy(defaults)
for g in generators:
g(store)
return store


def _generator(store, name, objtype, ctor_arg, group=None):
''' Custom generator for non-base types. '''
setattr(getattr(store, group) if group else store, name, objtype(ctor_arg))


# Keep track of default attribute values of mutable type.
# These always need to be stored per remote node.
defaults = Store()
# In some cases (such as for non-copyable types) a generator is specified
# instead of a default value
generators = list()

# Keep a dict of remote state storage namespaces
remotes = defaultdict(_genstore)
65 changes: 3 additions & 62 deletions bluesky/network/subscriber.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
from typing import Dict

import bluesky as bs
from bluesky.core.funcobject import FuncObject
Expand All @@ -12,10 +11,6 @@
#TODO:
# trigger voor actnode changed?

# Keep track of the set of subscribed sharedstate topics. Store signals to emit
# whenever a state update of each topic is received
changed: Dict[str, signal.Signal] = dict()


def subscriber(func=None, *, topic='', broadcast=True, actonly=False, raw=False, from_group=GROUPID_DEFAULT, to_group=''):
''' BlueSky network subscription decorator.
Expand Down Expand Up @@ -158,16 +153,14 @@ def _detect_type(self, *args, **kwargs):
self.msg_type = MessageType.SharedState
# In this case, all (non-raw) subscribers will be configured
# as sharedstate subscribers
ss.addtopic(self.topic)
sig = signal.Signal(f'state-changed.{self.topic}')
changed[self.topic] = sig
sig = ss.addtopic(self.topic)
while self.deferred_subs:
sig.connect(self.deferred_subs.pop())

# Finally send the sharedstate message on to the subscribers,
# and subscribe the sharedstate processing function to this topic
super().connect(on_sharedstate_received)
on_sharedstate_received(*args, **kwargs)
super().connect(ss.on_sharedstate_received)
ss.on_sharedstate_received(*args, **kwargs)

else:
self.msg_type = MessageType.Regular
Expand Down Expand Up @@ -235,25 +228,6 @@ def reset(*args):
# Clear state data to defaults for this simulation node
ss.reset(ctx.sender_id)

# If this is the active node, also emit a signal about this change
if ctx.sender_id == bs.net.act_id:
ctx.action = ActionType.Reset
ctx.action_content = None
for topic, sig in changed.items():
store = ss.get(group=topic.lower())
sig.emit(store)
ctx.action = None


@signal.subscriber(topic='actnode-changed')
def on_actnode_changed(act_id):
ctx.action = ActionType.ActChange
ctx.action_content = None
for topic, sig in changed.items():
store = ss.get(group=topic.lower())
sig.emit(store)
ctx.action = None


@signal.subscriber(topic='node-added')
def on_node_added(node_id):
Expand All @@ -263,36 +237,3 @@ def on_node_added(node_id):
topics = [topic for topic, sub in SubscriptionFactory.subscriptions.items()
if sub.msg_type in (MessageType.Unknown, MessageType.SharedState)]
bs.net.send('REQUEST', topics, to_group=node_id)


def on_sharedstate_received(action, data):
''' Retrieve and process state data. '''
store = ss.get(ctx.sender_id, ctx.topic.lower())

# Store sharedstate context
ctx.action = ActionType(action)
ctx.action_content = data

if ctx.action == ActionType.Update:
store.update(data)

elif ctx.action == ActionType.Append:
store.append(data)

elif ctx.action == ActionType.Extend:
store.extend(data)

elif ctx.action == ActionType.Replace:
store.replace(data)

elif ctx.action == ActionType.Delete:
store.delete(data)

# Inform subscribers of state update
# TODO: what to do with act vs all?
if ctx.sender_id == bs.net.act_id:
changed[ctx.topic].emit(store)

# Reset context variables
ctx.action = None
ctx.action_content = None

0 comments on commit 14e4378

Please sign in to comment.