Skip to content

Commit

Permalink
Add type hints (NSLS-II#87)
Browse files Browse the repository at this point in the history
* Add mypy dev dependency; Add typing to hardware_flyer and handlers

* Added typing for dofs.py

* Add typing to bayesian/kernels.py

* Added typing to monte_carlo.py

* black formatting

* Progress on typing objectives.py

* Added typing for objectives.py and utils/__init__.py

* Added typing for bayesian/acquisition/__init__.py and bayesian/models.py

* Add typing for sim/beamline.py

* Added more typing for objectives.py

* Added typing and vectorization to tests.py

* Type hinted and vectorized beam_stats_digestion

* Move prepare_re_env.py to examples

* Started typing agent.py

* A lot of progress on typing agent.py

* Fixed typing on agent.py

* No typing issues!

* Fixed tests and a few errors

* isort

* isort examples...

* isort and black were contradicting each other

* Ignore flake8 E704 errors

* Fixed tutorial notebooks' use of prepare_re_env.py

* Strip output of notebooks

* ruff check --fix

* Fixed issues with merge

* Revert change to skew_matrix_indices

* Fix kb sim test
  • Loading branch information
thomashopkins32 authored Feb 20, 2025
1 parent d736bda commit 7dc3cc9
Show file tree
Hide file tree
Showing 27 changed files with 1,064 additions and 890 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ exclude =
versioneer.py,
max-line-length = 125
# Ignore some style 'errors' produced while formatting by 'black'
ignore = E203, W503
ignore = E203, W503, E704
3 changes: 1 addition & 2 deletions docs/source/tutorials/hyperparameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,8 @@
},
"outputs": [],
"source": [
"from blop.utils import prepare_re_env # noqa: F401\n",
"%run -i prepare_re_env.py --db-type=temp\n",
"\n",
"%run -i $prepare_re_env.__file__ --db-type=temp\n",
"\n",
"from blop import DOF, Agent, Objective\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions docs/source/tutorials/introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"from blop.utils import prepare_re_env # noqa: F401\n",
"\n",
"%run -i $prepare_re_env.__file__ --db-type=temp"
"%run -i prepare_re_env.py --db-type=temp"
]
},
{
Expand Down
4 changes: 1 addition & 3 deletions docs/source/tutorials/kb-mirrors.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
"metadata": {},
"outputs": [],
"source": [
"from blop.utils import prepare_re_env # noqa: F401\n",
"\n",
"%run -i $prepare_re_env.__file__ --db-type=temp\n",
"%run -i prepare_re_env.py --db-type=temp\n",
"bec.disable_plots()"
]
},
Expand Down
4 changes: 1 addition & 3 deletions docs/source/tutorials/pareto-fronts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
"metadata": {},
"outputs": [],
"source": [
"from blop.utils import prepare_re_env # noqa: F401\n",
"\n",
"%run -i $prepare_re_env.__file__ --db-type=temp"
"%run -i prepare_re_env.py --db-type=temp"
]
},
{
Expand Down
4 changes: 1 addition & 3 deletions docs/source/tutorials/passive-dofs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"from blop.utils import prepare_re_env # noqa: F401\n",
"\n",
"%run -i $prepare_re_env.__file__ --db-type=temp"
"%run -i prepare_re_env.py --db-type=temp"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

import bluesky.plan_stubs as bps # noqa F401
import bluesky.plans as bp # noqa F401
import databroker
import databroker # type: ignore[import-untyped]
import matplotlib.pyplot as plt
import numpy as np # noqa F401
from bluesky.callbacks import best_effort
from bluesky.run_engine import RunEngine
from databroker import Broker
from ophyd.utils import make_dir_tree
from ophyd.utils import make_dir_tree # type: ignore[import-untyped]

from blop.sim import HDF5Handler

Expand Down Expand Up @@ -93,17 +93,19 @@ def register_handlers(db, handlers):
kwargs_re = {"db_type": args.db_type, "root_dir": args.root_dir}
ret = re_env(**kwargs_re)
globals().update(**ret)
bec = ret["bec"]
db = ret["db"]

if args.use_sirepo:
from sirepo_bluesky.srw_handler import SRWFileHandler
from sirepo_bluesky.srw_handler import SRWFileHandler # type: ignore[import-untyped]

if args.env_type == "stepper":
from sirepo_bluesky.shadow_handler import ShadowFileHandler
from sirepo_bluesky.shadow_handler import ShadowFileHandler # type: ignore[import-untyped]

handlers = {"srw": SRWFileHandler, "SIREPO_FLYER": SRWFileHandler, "shadow": ShadowFileHandler}
plt.ion()
elif args.env_type == "flyer":
from sirepo_bluesky.madx_handler import MADXFileHandler
from sirepo_bluesky.madx_handler import MADXFileHandler # type: ignore[import-untyped]

handlers = {"srw": SRWFileHandler, "SIREPO_FLYER": SRWFileHandler, "madx": MADXFileHandler}
bec.disable_plots() # noqa: F821
Expand Down
199 changes: 199 additions & 0 deletions examples/bluesky_adaptive_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
from collections.abc import Sequence
from typing import Any, Callable, Optional

import pandas as pd
from bluesky_adaptive.agents.base import Agent as BlueskyAdaptiveBaseAgent # type: ignore[import-untyped]
from databroker.client import BlueskyRun # type: ignore[import-untyped]
from numpy.typing import ArrayLike

from blop.agent import BaseAgent as BlopAgent # type: ignore[import-untyped]
from blop.digestion import default_digestion_function # type: ignore[import-untyped]


class BlueskyAdaptiveAgent(BlueskyAdaptiveBaseAgent, BlopAgent):
"""A BlueskyAdaptiveAgent that uses Blop for the underlying agent."""

# TODO: Move into main package once databroker V2 is supported

def __init__(
self,
*,
acqf_string: str,
route: bool,
sequential: bool,
upsample: int,
acqf_kwargs: dict[str, Any],
detector_names: Optional[list[str]] = None,
**kwargs,
):
super().__init__(**kwargs)
self._acqf_string = acqf_string
self._route = route
self._sequential = sequential
self._upsample = upsample
self._acqf_kwargs = acqf_kwargs
self._detector_names = detector_names or []

@property
def detector_names(self) -> list[str]:
return [str(name) for name in self._detector_names]

@detector_names.setter
def detector_names(self, names: list[str]):
self._detector_names = list(names)

@property
def acquisition_function(self) -> str:
return str(self._acqf_string)

@acquisition_function.setter
def acquisition_function(self, acqf_string: str):
self._acqf_string = str(acqf_string)

@property
def route(self) -> bool:
return bool(self._route)

@route.setter
def route(self, route: bool):
self._route = route

@property
def sequential(self) -> bool:
return bool(self._sequential)

@sequential.setter
def sequential(self, sequential: bool):
self._sequential = sequential

@property
def upsample(self) -> int:
return int(self._upsample)

@upsample.setter
def upsample(self, upsample: int):
self._upsample = int(upsample)

@property
def acqf_kwargs(self) -> dict[str, str]:
return {str(k): str(v) for k, v in self._acqf_kwargs.items()}

def update_acqf_kwargs(self, **kwargs):
self._acqf_kwargs.update(kwargs)

def server_registrations(self) -> list[str]:
"""This is how we make these avaialble to the REST API."""
self._register_method("Update Acquistion Function Kwargs", self.update_acqf_kwargs)
self._register_property("Acquisition Function", self.acquisition_function, self.acquisition_function)
self._register_property("Route Points", self.route, self.route)
self._register_property("Sequential Points", self.sequential, self.sequential)
self._register_property("Upsample Points", self.upsample, self.upsample)
return super().server_registrations()

def ask(self, batch_size: int) -> tuple[Sequence[dict[str, ArrayLike]], Sequence[ArrayLike]]:
default_result = super().ask(
n=batch_size,
acqf=self._acqf_string,
route=self._route,
sequential=self._sequential,
upsample=self._upsample,
**self._acqf_kwargs,
)

"""res = {
"points": {dof.name: list(points[..., i]) for i, dof in enumerate(active_dofs(read_only=False))},
"acqf_name": acqf_config["name"],
"acqf_obj": list(np.atleast_1d(acqf_obj.numpy())),
"acqf_kwargs": acqf_kwargs,
"duration_ms": duration,
"sequential": sequential,
"upsample": upsample,
"read_only_values": read_only_values,
# "posterior": p,
}
"""

points: dict[str, list[ArrayLike]] = default_result.pop("points")
acqf_obj: list[ArrayLike] = default_result.pop("acqf_obj")
# Turn dict of list of points into list of consistently sized points
points: list[tuple[ArrayLike]] = list(zip(*[value for _, value in points.items()]))
dicts = []
for point, obj in zip(points, acqf_obj):
d = default_result.copy()
d["point"] = point
d["acqf_obj"] = obj
dicts.append(d)
return points, dicts

def tell(self, x: dict[str, ArrayLike], y: dict[str, ArrayLike]):
x = {key: x_i for x_i, key in zip(x, self.dofs.names)}
y = {key: y_i for y_i, key in zip(y, self.objectives.names)}
super().tell(data={**x, **y})
return {**x, **y}

def report(self) -> dict[str, Any]:
raise NotImplementedError("Report is not implmented for BlueskyAdaptiveAgent")

def unpack_run(self, run: BlueskyRun) -> tuple[list[ArrayLike], list[ArrayLike]]:
"""Use my DOFs to convert the run into an independent array, and my objectives to create the dependent array.
In practice for shape management, we will use lists not np.arrays at this stage.
Parameters
----------
run : BlueskyRun
Returns
-------
independent_var :
The independent variable of the measurement
dependent_var :
The measured data, processed for relevance
"""
if not self.digestion or self.digestion == default_digestion_function:
# Assume all raw data is available in primary stream as keys
return (
[run.primary.data[key].read() for key in self.dofs.names],
[run.primary.data[key].read() for key in self.objectives.names],
)
else:
# Hope and pray that the digestion function designed for DataFrame can handle the XArray
data: pd.DataFrame = self.digestion(run.primary.data.read(), **self.digestion_kwargs)
return [data.loc[:, key] for key in self.dofs.names], [data.loc[:, key] for key in self.objectives.names]

def measurement_plan(self, point: ArrayLike) -> tuple[str, list[Any], dict[str, Any]]:
"""Fetch the string name of a registered plan, as well as the positional and keyword
arguments to pass that plan.
Args/Kwargs is a common place to transform relative into absolute motor coords, or
other device specific parameters.
By default, this measurement plan attempts to use in the built in functionality in a QueueServer compatible way.
Signals and Devices are not passed as objects, but serialized as strings for the RE as a service to use.
Parameters
----------
point : ArrayLike
Next point to measure using a given plan
Returns
-------
plan_name : str
plan_args : List
List of arguments to pass to plan from a point to measure.
plan_kwargs : dict
Dictionary of keyword arguments to pass the plan, from a point to measure.
"""
if isinstance(self.acquisition_plan, Callable):
plan_name = self.acquisition_plan.__name__
else:
plan_name = self.acquisition_plan
if plan_name == "default_acquisition_plan":
# Convert point back to dict form for the sake of compatability with default plan
acquisition_dofs = self.dofs(active=True, read_only=False)

return self.acquisition_plan.__name__, [
acquisition_dofs,
{dof.name: point[i] for i, dof in enumerate(acquisition_dofs)},
[*self.detector_names, *[dev.__name__ for dev in self.dofs.devices]],
]
else:
raise NotImplementedError("Only default_acquisition_plan is implemented")
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pre-commit = [
"import-linter",
"nbstripout",
]
adaptive = ["bluesky-adaptive"]

dev = [
"pytest-codecov",
"coverage",
Expand All @@ -69,9 +71,11 @@ dev = [
"sphinx_rtd_theme",
"ruff",
"import-linter",
"pandas-stubs",
"types-PyYAML",
"mypy",
]


[project.urls]
Homepage = "https://github.com/NSLS-II/blop"
Documentation = "https://nsls-ii.github.io/blop"
Expand Down
Loading

0 comments on commit 7dc3cc9

Please sign in to comment.