Skip to content

Commit

Permalink
pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Mar 5, 2025
1 parent d4a71c4 commit d22f991
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 29 deletions.
14 changes: 7 additions & 7 deletions docs/source/tutorials/xrt-blop-demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
"metadata": {},
"outputs": [],
"source": [
"import sys, os\n",
"import time\n",
"\n",
"from matplotlib import pyplot as plt\n",
"from blop.sim.xrt_beamline import Beamline\n",
"\n",
"from blop import DOF, Objective, Agent\n",
"from blop import DOF, Agent, Objective\n",
"from blop.digestion import beam_stats_digestion\n",
"import time"
"from blop.sim.xrt_beamline import Beamline"
]
},
{
Expand Down Expand Up @@ -78,10 +78,10 @@
"dofs = [\n",
" DOF(description=\"KBV R\",\n",
" device=beamline.kbv_dsv,\n",
" search_domain=(R1-dR1, R1+dR1)),\n",
" search_domain=(R1 - dR1, R1 + dR1)),\n",
" DOF(description=\"KBH R\",\n",
" device=beamline.kbh_dsh,\n",
" search_domain=(R2-dR2, R2+dR2)),\n",
" search_domain=(R2 - dR2, R2 + dR2)),\n",
"\n",
"]"
]
Expand All @@ -94,7 +94,7 @@
"outputs": [],
"source": [
"objectives = [\n",
" Objective(name=\"bl_det_sum\", \n",
" Objective(name=\"bl_det_sum\",\n",
" target=\"max\",\n",
" transform=\"log\",\n",
" trust_domain=(20, 1e12)),\n",
Expand Down
14 changes: 7 additions & 7 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from collections import OrderedDict
from collections.abc import Callable, Generator, Hashable, Iterator, Mapping, Sequence
from typing import Any, cast, Optional, Union
from typing import Any, cast

import bluesky.plan_stubs as bps # noqa F401
import botorch # type: ignore[import-untyped]
Expand Down Expand Up @@ -148,15 +148,15 @@ def random_ref_point(self) -> ArrayLike:
raise RuntimeError("'random_ref_point' is not defined for multi-objective optimization.")
return train_targets[self.argmax_best_f(weights="random")]

def raw_inputs(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> torch.Tensor:
def raw_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor:
"""
Get the raw, untransformed inputs for a DOF (or for a subset).
"""
if index is None:
return torch.stack([self.raw_inputs(dof.name) for dof in self.dofs(**subset_kwargs)], dim=-1)
return torch.tensor(self._table.loc[:, self.dofs[index].name].values, dtype=torch.double)

def train_inputs(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> torch.Tensor:
def train_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor:
"""
A two-dimensional tensor of all DOF values for training on.
"""
Expand All @@ -168,7 +168,7 @@ def train_inputs(self, index: Optional[Union[str, int]] = None, **subset_kwargs)
raw_inputs = self.raw_inputs(index=index, **subset_kwargs)
return dof._transform(raw_inputs)

def raw_targets_dict(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> dict[str, torch.Tensor]:
def raw_targets_dict(self, index: str | int | None = None, **subset_kwargs) -> dict[str, torch.Tensor]:
"""
Get the raw, untransformed targets for an objective (or for a subset of objectives) as a dict.
"""
Expand All @@ -177,13 +177,13 @@ def raw_targets_dict(self, index: Optional[Union[str, int]] = None, **subset_kwa
key = self.objectives[index].name
return {key: torch.tensor(self._table.loc[:, key].values, dtype=torch.double)}

def raw_targets(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> torch.Tensor:
def raw_targets(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor:
"""
Get the raw, untransformed targets for an objective (or for a subset of objectives) as a tensor.
"""
return torch.stack(list(self.raw_targets_dict(index=index, **subset_kwargs).values()), axis=-1)

def train_targets_dict(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> dict[str, torch.Tensor]:
def train_targets_dict(self, index: str | int | None = None, **subset_kwargs) -> dict[str, torch.Tensor]:
"""
Returns the values associated with an objective name.
"""
Expand All @@ -205,7 +205,7 @@ def train_targets_dict(self, index: Optional[Union[str, int]] = None, **subset_k

return targets_dict

def train_targets(self, index: Optional[Union[str, int]] = None, **subset_kwargs) -> torch.Tensor:
def train_targets(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor:
"""
Returns the values associated with an objective name as an (n_samples, n_objective) tensor.
"""
Expand Down
24 changes: 12 additions & 12 deletions src/blop/dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time as ttime
import uuid
from collections.abc import Iterable, Sequence
from typing import Any, Literal, Optional, Union, cast, overload
from typing import Any, Literal, cast, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -89,16 +89,16 @@ def __init__(
name: str = None,
description: str = "",
type: Literal["continuous", "binary", "ordinal", "categorical"] = "continuous",
search_domain: Union[tuple[float, float], set[int], set[str], set[bool]] = (-np.inf, np.inf),
trust_domain: Optional[Union[tuple[float, float], set[int], set[str], set[bool]]] = None,
domain: Optional[Union[tuple[float, float], set[int], set[str], set[bool]]] = None,
search_domain: tuple[float, float] | set[int] | set[str] | set[bool] = (-np.inf, np.inf),
trust_domain: tuple[float, float] | set[int] | set[str] | set[bool] | None = None,
domain: tuple[float, float] | set[int] | set[str] | set[bool] | None = None,
active: bool = True,
read_only: bool = False,
transform: Optional[Literal["log", "logit", "arctanh"]] = None,
device: Optional[Signal] = None,
transform: Literal["log", "logit", "arctanh"] | None = None,
device: Signal | None = None,
tags: list[str] = None,
travel_expense: float = 1,
units: Optional[str] = None,
units: str | None = None,
):
# these should be set first, as they are just variables
self.name = name
Expand Down Expand Up @@ -184,7 +184,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({filling})"

@property
def search_domain(self) -> Union[tuple[float, float], set[int], set[str], set[bool]]:
def search_domain(self) -> tuple[float, float] | set[int] | set[str] | set[bool]:
"""
A writable DOF always has a search domain, and a read-only DOF will return its current value.
"""
Expand All @@ -198,21 +198,21 @@ def search_domain(self) -> Union[tuple[float, float], set[int], set[str], set[bo
return self._search_domain

@search_domain.setter
def search_domain(self, value: Union[tuple[float, float], set[int], set[str], set[bool]]):
def search_domain(self, value: tuple[float, float] | set[int] | set[str] | set[bool]):
"""
Make sure that the search domain is within the trust domain before setting it.
"""
value = validate_set(value, type=self.type)
trust_domain = self.trust_domain
if is_subset(value, trust_domain, type=self.type, proper=False):
self._search_domain = cast(Union[tuple[float, float], set[int], set[str], set[bool]], value)
self._search_domain = cast(tuple[float, float] | set[int] | set[str] | set[bool], value)
else:
raise ValueError(
f"Cannot set search domain to {value} as it is not a subset of the trust domain {trust_domain}."
)

@property
def trust_domain(self) -> Union[tuple[float, float], set[int], set[str], set[bool]]:
def trust_domain(self) -> tuple[float, float] | set[int] | set[str] | set[bool]:
"""
If _trust_domain is None, then we trust the entire domain (so we return the domain).
"""
Expand All @@ -235,7 +235,7 @@ def trust_domain(self, value):
# The search domain must stay a subset of the trust domain, so set it as the intersection.
self.search_domain = intersection(self.search_domain, value)

self._trust_domain = cast(Union[tuple[float, float], set[int], set[str], set[bool]], value)
self._trust_domain = cast(tuple[float, float] | set[int] | set[str] | set[bool], value)

@property
def domain(self) -> tuple[float, float] | set[int] | set[str] | set[bool]:
Expand Down
6 changes: 3 additions & 3 deletions src/blop/utils/sets.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Union, cast
from typing import cast

import numpy as np


def validate_set(s, type="continuous") -> Union[set, tuple[float, float]]:
def validate_set(s, type="continuous") -> set | tuple[float, float]:
"""
Check
"""
Expand All @@ -12,7 +12,7 @@ def validate_set(s, type="continuous") -> Union[set, tuple[float, float]]:
try:
x1, x2 = float(s[0]), float(s[1])
if x1 <= x2:
return cast(Union[tuple[float, float]], (x1, x2))
return cast(tuple[float, float], (x1, x2))
except Exception:
pass
raise ValueError(
Expand Down

0 comments on commit d22f991

Please sign in to comment.