diff --git a/docs/source/tutorials/xrt-blop-demo.ipynb b/docs/source/tutorials/xrt-blop-demo.ipynb index 433ff8d..1b5fcc8 100644 --- a/docs/source/tutorials/xrt-blop-demo.ipynb +++ b/docs/source/tutorials/xrt-blop-demo.ipynb @@ -41,13 +41,13 @@ "metadata": {}, "outputs": [], "source": [ - "import sys, os\n", + "import time\n", + "\n", "from matplotlib import pyplot as plt\n", - "from blop.sim.xrt_beamline import Beamline\n", "\n", - "from blop import DOF, Objective, Agent\n", + "from blop import DOF, Agent, Objective\n", "from blop.digestion import beam_stats_digestion\n", - "import time" + "from blop.sim.xrt_beamline import Beamline" ] }, { @@ -78,10 +78,10 @@ "dofs = [\n", " DOF(description=\"KBV R\",\n", " device=beamline.kbv_dsv,\n", - " search_domain=(R1-dR1, R1+dR1)),\n", + " search_domain=(R1 - dR1, R1 + dR1)),\n", " DOF(description=\"KBH R\",\n", " device=beamline.kbh_dsh,\n", - " search_domain=(R2-dR2, R2+dR2)),\n", + " search_domain=(R2 - dR2, R2 + dR2)),\n", "\n", "]" ] @@ -94,7 +94,7 @@ "outputs": [], "source": [ "objectives = [\n", - " Objective(name=\"bl_det_sum\", \n", + " Objective(name=\"bl_det_sum\",\n", " target=\"max\",\n", " transform=\"log\",\n", " trust_domain=(20, 1e12)),\n", diff --git a/src/blop/agent.py b/src/blop/agent.py index c86efa8..bbc5965 100644 --- a/src/blop/agent.py +++ b/src/blop/agent.py @@ -5,7 +5,7 @@ import warnings from collections import OrderedDict from collections.abc import Callable, Generator, Hashable, Iterator, Mapping, Sequence -from typing import Any, cast, Optional, Union +from typing import Any, cast import bluesky.plan_stubs as bps # noqa F401 import botorch # type: ignore[import-untyped] @@ -148,7 +148,7 @@ def random_ref_point(self) -> ArrayLike: raise RuntimeError("'random_ref_point' is not defined for multi-objective optimization.") return train_targets[self.argmax_best_f(weights="random")] - def raw_inputs(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> torch.Tensor: + def raw_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor: """ Get the raw, untransformed inputs for a DOF (or for a subset). """ @@ -156,7 +156,7 @@ def raw_inputs(self, index: Optional[Union[str, int]] = None, **subset_kwargs) - return torch.stack([self.raw_inputs(dof.name) for dof in self.dofs(**subset_kwargs)], dim=-1) return torch.tensor(self._table.loc[:, self.dofs[index].name].values, dtype=torch.double) - def train_inputs(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> torch.Tensor: + def train_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor: """ A two-dimensional tensor of all DOF values for training on. """ @@ -168,7 +168,7 @@ def train_inputs(self, index: Optional[Union[str, int]] = None, **subset_kwargs) raw_inputs = self.raw_inputs(index=index, **subset_kwargs) return dof._transform(raw_inputs) - def raw_targets_dict(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> dict[str, torch.Tensor]: + def raw_targets_dict(self, index: str | int | None = None, **subset_kwargs) -> dict[str, torch.Tensor]: """ Get the raw, untransformed targets for an objective (or for a subset of objectives) as a dict. """ @@ -177,13 +177,13 @@ def raw_targets_dict(self, index: Optional[Union[str, int]] = None, **subset_kwa key = self.objectives[index].name return {key: torch.tensor(self._table.loc[:, key].values, dtype=torch.double)} - def raw_targets(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> torch.Tensor: + def raw_targets(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor: """ Get the raw, untransformed targets for an objective (or for a subset of objectives) as a tensor. """ return torch.stack(list(self.raw_targets_dict(index=index, **subset_kwargs).values()), axis=-1) - def train_targets_dict(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> dict[str, torch.Tensor]: + def train_targets_dict(self, index: str | int | None = None, **subset_kwargs) -> dict[str, torch.Tensor]: """ Returns the values associated with an objective name. """ @@ -205,7 +205,7 @@ def train_targets_dict(self, index: Optional[Union[str, int]] = None, **subset_k return targets_dict - def train_targets(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> torch.Tensor: + def train_targets(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor: """ Returns the values associated with an objective name as an (n_samples, n_objective) tensor. """ diff --git a/src/blop/dofs.py b/src/blop/dofs.py index 1b6dd8d..5266ac3 100644 --- a/src/blop/dofs.py +++ b/src/blop/dofs.py @@ -2,7 +2,7 @@ import time as ttime import uuid from collections.abc import Iterable, Sequence -from typing import Any, Literal, Optional, Union, cast, overload +from typing import Any, Literal, cast, overload import numpy as np import pandas as pd @@ -89,16 +89,16 @@ def __init__( name: str = None, description: str = "", type: Literal["continuous", "binary", "ordinal", "categorical"] = "continuous", - search_domain: Union[tuple[float, float], set[int], set[str], set[bool]] = (-np.inf, np.inf), - trust_domain: Optional[Union[tuple[float, float], set[int], set[str], set[bool]]] = None, - domain: Optional[Union[tuple[float, float], set[int], set[str], set[bool]]] = None, + search_domain: tuple[float, float] | set[int] | set[str] | set[bool] = (-np.inf, np.inf), + trust_domain: tuple[float, float] | set[int] | set[str] | set[bool] | None = None, + domain: tuple[float, float] | set[int] | set[str] | set[bool] | None = None, active: bool = True, read_only: bool = False, - transform: Optional[Literal["log", "logit", "arctanh"]] = None, - device: Optional[Signal] = None, + transform: Literal["log", "logit", "arctanh"] | None = None, + device: Signal | None = None, tags: list[str] = None, travel_expense: float = 1, - units: Optional[str] = None, + units: str | None = None, ): # these should be set first, as they are just variables self.name = name @@ -184,7 +184,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({filling})" @property - def search_domain(self) -> Union[tuple[float, float], set[int], set[str], set[bool]]: + def search_domain(self) -> tuple[float, float] | set[int] | set[str] | set[bool]: """ A writable DOF always has a search domain, and a read-only DOF will return its current value. """ @@ -198,21 +198,21 @@ def search_domain(self) -> Union[tuple[float, float], set[int], set[str], set[bo return self._search_domain @search_domain.setter - def search_domain(self, value: Union[tuple[float, float], set[int], set[str], set[bool]]): + def search_domain(self, value: tuple[float, float] | set[int] | set[str] | set[bool]): """ Make sure that the search domain is within the trust domain before setting it. """ value = validate_set(value, type=self.type) trust_domain = self.trust_domain if is_subset(value, trust_domain, type=self.type, proper=False): - self._search_domain = cast(Union[tuple[float, float], set[int], set[str], set[bool]], value) + self._search_domain = cast(tuple[float, float] | set[int] | set[str] | set[bool], value) else: raise ValueError( f"Cannot set search domain to {value} as it is not a subset of the trust domain {trust_domain}." ) @property - def trust_domain(self) -> Union[tuple[float, float], set[int], set[str], set[bool]]: + def trust_domain(self) -> tuple[float, float] | set[int] | set[str] | set[bool]: """ If _trust_domain is None, then we trust the entire domain (so we return the domain). """ @@ -235,7 +235,7 @@ def trust_domain(self, value): # The search domain must stay a subset of the trust domain, so set it as the intersection. self.search_domain = intersection(self.search_domain, value) - self._trust_domain = cast(Union[tuple[float, float], set[int], set[str], set[bool]], value) + self._trust_domain = cast(tuple[float, float] | set[int] | set[str] | set[bool], value) @property def domain(self) -> tuple[float, float] | set[int] | set[str] | set[bool]: diff --git a/src/blop/utils/sets.py b/src/blop/utils/sets.py index a92fa77..c99e251 100644 --- a/src/blop/utils/sets.py +++ b/src/blop/utils/sets.py @@ -1,9 +1,9 @@ -from typing import Union, cast +from typing import cast import numpy as np -def validate_set(s, type="continuous") -> Union[set, tuple[float, float]]: +def validate_set(s, type="continuous") -> set | tuple[float, float]: """ Check """ @@ -12,7 +12,7 @@ def validate_set(s, type="continuous") -> Union[set, tuple[float, float]]: try: x1, x2 = float(s[0]), float(s[1]) if x1 <= x2: - return cast(Union[tuple[float, float]], (x1, x2)) + return cast(tuple[float, float], (x1, x2)) except Exception: pass raise ValueError(