Skip to content

Commit

Permalink
Add **kwargs to MC and KG Acquisition Function Constructors (pytorch#478
Browse files Browse the repository at this point in the history
)

Summary:
Pull Request resolved: pytorch#478

Different acquisition functions take different kwargs as inputs into their constructors. To standardize the inputs, we add `**kwargs` to the constructors, specifically for `qEI`, `qNEI`, `qKG`, and `qMFKG`.

Reviewed By: Balandat, lena-kashtelyan

Differential Revision: D22416290

fbshipit-source-id: 1f64efb6471ea7a43c2e3a057b407ef3b7331d4b
  • Loading branch information
EricZLou authored and facebook-github-bot committed Jul 8, 2020
1 parent 10a71ae commit 5dd7123
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion botorch/acquisition/knowledge_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from __future__ import annotations

from copy import deepcopy
from typing import Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union

import torch
from botorch import settings
Expand Down Expand Up @@ -71,6 +71,7 @@ def __init__(
inner_sampler: Optional[MCSampler] = None,
X_pending: Optional[Tensor] = None,
current_value: Optional[Tensor] = None,
**kwargs: Any,
) -> None:
r"""q-Knowledge Gradient (one-shot optimization).
Expand Down Expand Up @@ -227,6 +228,7 @@ def __init__(
cost_aware_utility: Optional[CostAwareUtility] = None,
project: Callable[[Tensor], Tensor] = lambda X: X,
expand: Callable[[Tensor], Tensor] = lambda X: X,
**kwargs: Any,
) -> None:
r"""Multi-Fidelity q-Knowledge Gradient (one-shot optimization).
Expand Down
4 changes: 3 additions & 1 deletion botorch/acquisition/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import math
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Any, Optional, Union

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
Expand Down Expand Up @@ -116,6 +116,7 @@ def __init__(
sampler: Optional[MCSampler] = None,
objective: Optional[MCAcquisitionObjective] = None,
X_pending: Optional[Tensor] = None,
**kwargs: Any,
) -> None:
r"""q-Expected Improvement.
Expand Down Expand Up @@ -188,6 +189,7 @@ def __init__(
objective: Optional[MCAcquisitionObjective] = None,
X_pending: Optional[Tensor] = None,
prune_baseline: bool = False,
**kwargs: Any,
) -> None:
r"""q-Noisy Expected Improvement.
Expand Down

0 comments on commit 5dd7123

Please sign in to comment.