diff --git a/baybe/surrogates/multi_armed_bandit.py b/baybe/surrogates/multi_armed_bandit.py index cbebb2f68..e7efd50c3 100644 --- a/baybe/surrogates/multi_armed_bandit.py +++ b/baybe/surrogates/multi_armed_bandit.py @@ -29,6 +29,7 @@ class BernoulliMultiArmedBanditSurrogate(Surrogate): supports_transfer_learning: ClassVar[bool] = False # See base class. + # TODO: Introduce BetaPrior class prior: tuple[float, float] = field( default=(1, 1), converter=lambda x: cattrs.structure(x, tuple[float, float]) ) @@ -74,6 +75,8 @@ def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None: # See base class. + # TODO: Fix requirement of OHE encoding + # TODO: Generalize to arbitrary number of categorical parameters if not ( (len(searchspace.parameters) == 1) and isinstance(p := searchspace.parameters[0], CategoricalParameter) @@ -84,6 +87,8 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No f"spanned by exactly one categorical parameter using one-hot encoding." ) + # TODO: Incorporate training target validation at the appropriate place in + # the BayBE ecosystem. wins = (train_x * train_y).sum(axis=0) losses = (train_x * (1 - train_y)).sum(axis=0) self._win_lose_counts = np.vstack([wins, losses])