Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add input constructor for qMultiFidelityHypervolumeKnowledgeGradient #2524

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 86 additions & 15 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
_get_hv_value_function,
qHypervolumeKnowledgeGradient,
qMultiFidelityHypervolumeKnowledgeGradient,
)
from botorch.acquisition.multi_objective.logei import (
qLogExpectedHypervolumeImprovement,
Expand Down Expand Up @@ -1274,21 +1275,6 @@ def construct_inputs_qKG(
return inputs_qkg


def _get_ref_point(
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
) -> Tensor:

if objective is None:
ref_point = objective_thresholds
elif isinstance(objective, RiskMeasureMCObjective):
ref_point = objective.preprocessing_function(objective_thresholds)
else:
ref_point = objective(objective_thresholds)

return ref_point


@acqf_input_constructor(qHypervolumeKnowledgeGradient)
def construct_inputs_qHVKG(
model: Model,
Expand Down Expand Up @@ -1381,6 +1367,76 @@ def construct_inputs_qMFKG(
}


@acqf_input_constructor(qMultiFidelityHypervolumeKnowledgeGradient)
def construct_inputs_qMFHVKG(
model: Model,
training_data: MaybeDict[SupervisedDataset],
bounds: list[tuple[float, float]],
target_fidelities: dict[int, Union[int, float]],
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
fidelity_weights: Optional[dict[int, float]] = None,
cost_intercept: float = 1.0,
num_trace_observations: int = 0,
num_fantasies: int = 8,
num_pareto: int = 10,
**optimize_objective_kwargs: TOptimizeObjectiveKwargs,
) -> dict[str, Any]:
r"""
Construct kwargs for `qMultiFidelityHypervolumeKnowledgeGradient` constructor.
"""

inputs_mf = construct_inputs_mf_base(
target_fidelities=target_fidelities,
fidelity_weights=fidelity_weights,
cost_intercept=cost_intercept,
num_trace_observations=num_trace_observations,
)

if num_trace_observations > 0:
raise NotImplementedError(
"Trace observations are not currently supported "
"by `qMultiFidelityHypervolumeKnowledgeGradient`."
)

del inputs_mf["expand"]

X = _get_dataset_field(training_data, "X", first_only=True)
_bounds = torch.as_tensor(bounds, dtype=X.dtype, device=X.device)

ref_point = _get_ref_point(
objective_thresholds=objective_thresholds, objective=objective
)

acq_function = _get_hv_value_function(
model=model,
ref_point=ref_point,
use_posterior_mean=True,
objective=objective,
)

_, current_value = optimize_objective(
model=model,
bounds=_bounds.t(),
q=num_pareto,
acq_function=acq_function,
fixed_features=target_fidelities,
**optimize_objective_kwargs,
)

return {
"model": model,
"objective": objective,
"ref_point": ref_point,
"num_fantasies": num_fantasies,
"num_pareto": num_pareto,
"current_value": current_value.detach().cpu().max(),
"target_fidelities": target_fidelities,
**inputs_mf,
}


@acqf_input_constructor(qMultiFidelityMaxValueEntropy)
def construct_inputs_qMFMES(
model: Model,
Expand Down Expand Up @@ -1806,3 +1862,18 @@ def construct_inputs_NIPV(
"posterior_transform": posterior_transform,
}
return inputs


def _get_ref_point(
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
) -> Tensor:

if objective is None:
ref_point = objective_thresholds
elif isinstance(objective, RiskMeasureMCObjective):
ref_point = objective.preprocessing_function(objective_thresholds)
else:
ref_point = objective(objective_thresholds)

return ref_point
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,10 @@ def __init__(
)
self.project = project
if kwargs.get("expand") is not None:
raise NotImplementedError("Trace observations are not currently supported.")
raise NotImplementedError(
"Trace observations are not currently supported "
"by `qMultiFidelityHypervolumeKnowledgeGradient`."
)
self.expand = lambda X: X
self.valfunc_cls = valfunc_cls
self.valfunc_argfac = valfunc_argfac
Expand Down
Loading
Loading