Skip to content

Commit

Permalink
Add TODO notes
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic authored and julianStreibel committed May 7, 2024
1 parent c721015 commit 1389b1d
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions baybe/surrogates/multi_armed_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class BernoulliMultiArmedBanditSurrogate(Surrogate):
supports_transfer_learning: ClassVar[bool] = False
# See base class.

# TODO: Introduce BetaPrior class
prior: tuple[float, float] = field(
default=(1, 1), converter=lambda x: cattrs.structure(x, tuple[float, float])
)
Expand Down Expand Up @@ -74,6 +75,8 @@ def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None:
# See base class.

# TODO: Fix requirement of OHE encoding
# TODO: Generalize to arbitrary number of categorical parameters
if not (
(len(searchspace.parameters) == 1)
and isinstance(p := searchspace.parameters[0], CategoricalParameter)
Expand All @@ -84,6 +87,8 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No
f"spanned by exactly one categorical parameter using one-hot encoding."
)

# TODO: Incorporate training target validation at the appropriate place in
# the BayBE ecosystem.
wins = (train_x * train_y).sum(axis=0)
losses = (train_x * (1 - train_y)).sum(axis=0)
self._win_lose_counts = np.vstack([wins, losses])

0 comments on commit 1389b1d

Please sign in to comment.