Skip to content

Commit

Permalink
qLogNEI **kwargs removal (#2406)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2406

Removed the ability for qLogNEI to swallow unused kwargs.

Reviewed By: saitcakmak

Differential Revision: D59238948
  • Loading branch information
Carl Hvarfner authored and facebook-github-bot committed Jul 2, 2024
1 parent 42dfd09 commit a577eda
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions botorch/acquisition/logei.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from functools import partial

from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union
from typing import Callable, List, Optional, Tuple, TypeVar, Union

import torch
from botorch.acquisition.cached_cholesky import CachedCholeskyMCSamplerMixin
Expand Down Expand Up @@ -275,7 +275,7 @@ def __init__(
cache_root: bool = True,
tau_max: float = TAU_MAX,
tau_relu: float = TAU_RELU,
**kwargs: Any,
marginalize_dim: Optional[int] = None,
) -> None:
r"""q-Noisy Expected Improvement.
Expand Down Expand Up @@ -314,7 +314,7 @@ def __init__(
approximations to max.
tau_relu: Temperature parameter controlling the sharpness of the smooth
approximations to ReLU.
kwargs: Here for qNEI for compatibility.
marginalize_dim: The dimension to marginalize over.
TODO: similar to qNEHVI, when we are using sequential greedy candidate
selection, we could incorporate pending points X_baseline and compute
Expand Down Expand Up @@ -343,7 +343,7 @@ def __init__(
posterior_transform=posterior_transform,
prune_baseline=prune_baseline,
cache_root=cache_root,
**kwargs,
marginalize_dim=marginalize_dim,
)

def _sample_forward(self, obj: Tensor) -> Tensor:
Expand Down Expand Up @@ -372,7 +372,7 @@ def _init_baseline(
posterior_transform: Optional[PosteriorTransform] = None,
prune_baseline: bool = False,
cache_root: bool = True,
**kwargs: Any,
marginalize_dim: Optional[int] = None,
) -> None:
CachedCholeskyMCSamplerMixin.__init__(
self, model=model, cache_root=cache_root, sampler=sampler
Expand All @@ -383,7 +383,7 @@ def _init_baseline(
X=X_baseline,
objective=objective,
posterior_transform=posterior_transform,
marginalize_dim=kwargs.get("marginalize_dim"),
marginalize_dim=marginalize_dim,
constraints=self._constraints,
)
self.register_buffer("X_baseline", X_baseline)
Expand Down

0 comments on commit a577eda

Please sign in to comment.