Skip to content

Commit

Permalink
NoisyExpectedHypervolumeMixin (pytorch#2045)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2045

X-link: facebook/Ax#1909

This commit introduces `NoisyExpectedHypervolumeMixin`, a derivative of `CachedCholeskyMCSamplerMixin` that separates out much of the Pareto-partitioning required for `qNEHVI`.

Differential Revision: D50337502

fbshipit-source-id: 2c17b7e50fdcbbe667a21a30829f62ddd94c4d69
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Oct 16, 2023
1 parent 7e57f8c commit ba191c0
Show file tree
Hide file tree
Showing 3 changed files with 393 additions and 279 deletions.
302 changes: 27 additions & 275 deletions botorch/acquisition/multi_objective/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,27 @@

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Callable, List, Optional, Union

import torch
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.acquisition.cached_cholesky import CachedCholeskyMCSamplerMixin
from botorch.acquisition.multi_objective.objective import (
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
)
from botorch.acquisition.multi_objective.utils import (
prune_inferior_points_multi_objective,
)
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.model import Model
from botorch.models.transforms.input import InputPerturbation
from botorch.sampling.base import MCSampler
from botorch.utils.multi_objective.box_decompositions.box_decomposition_list import (
BoxDecompositionList,
)
from botorch.utils.multi_objective.box_decompositions.dominated import (
DominatedPartitioning,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
FastNondominatedPartitioning,
NondominatedPartitioning,
)
from botorch.utils.multi_objective.box_decompositions.utils import (
_pad_batch_pareto_frontier,
from botorch.utils.multi_objective.hypervolume import (
NoisyExpectedHypervolumeMixin,
SubsetIndexCachingMixin,
)
from botorch.utils.multi_objective.hypervolume import SubsetIndexCachingMixin
from botorch.utils.objective import compute_smoothed_feasibility_indicator
from botorch.utils.torch import BufferDict
from botorch.utils.transforms import (
concatenate_pending_points,
is_fully_bayesian,
Expand Down Expand Up @@ -250,7 +235,9 @@ def _compute_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
q = obj.shape[-2]
if self.constraints is not None:
feas_weights = compute_smoothed_feasibility_indicator(
constraints=self.constraints, samples=samples, eta=self.eta
constraints=self.constraints,
samples=samples,
eta=self.eta,
) # `sample_shape x batch-shape x q`
device = self.ref_point.device
q_subset_indices = self.compute_q_subset_indices(q_out=q, device=device)
Expand Down Expand Up @@ -326,7 +313,7 @@ def forward(self, X: Tensor) -> Tensor:


class qNoisyExpectedHypervolumeImprovement(
qExpectedHypervolumeImprovement, CachedCholeskyMCSamplerMixin
NoisyExpectedHypervolumeMixin, qExpectedHypervolumeImprovement
):
def __init__(
self,
Expand Down Expand Up @@ -407,268 +394,33 @@ def __init__(
`q` points.
cache_root: A boolean indicating whether to cache the root
decomposition over `X_baseline` and use low-rank updates.
marginalize_dim: A batch dimension that should be marginalized.
"""
if len(ref_point) < 2:
raise ValueError(
"qNoisyExpectedHypervolumeImprovement supports m>=2 outcomes "
f"but ref_point has length {len(ref_point)}, which is smaller than 2."
)
ref_point = torch.as_tensor(
ref_point, dtype=X_baseline.dtype, device=X_baseline.device
)
super(qExpectedHypervolumeImprovement, self).__init__(
MultiObjectiveMCAcquisitionFunction.__init__(
self,
model=model,
sampler=sampler,
objective=objective,
constraints=constraints,
eta=eta,
)
CachedCholeskyMCSamplerMixin.__init__(
self, model=model, cache_root=cache_root, sampler=sampler
)

if X_baseline.ndim > 2:
raise UnsupportedError(
"qNoisyExpectedHypervolumeImprovement does not support batched "
f"X_baseline. Expected 2 dims, got {X_baseline.ndim}."
)
if prune_baseline:
X_baseline = prune_inferior_points_multi_objective(
model=model,
X=X_baseline,
objective=objective,
constraints=constraints,
ref_point=ref_point,
marginalize_dim=marginalize_dim,
)
self.register_buffer("ref_point", ref_point)
self.alpha = alpha
self.q_in = -1
self.q_out = -1
self.q_subset_indices = BufferDict()
self.partitioning = None
# set partitioning class and args
self.p_kwargs = {}
if self.alpha > 0:
self.p_kwargs["alpha"] = self.alpha
self.p_class = NondominatedPartitioning
else:
self.p_class = FastNondominatedPartitioning
self.register_buffer("_X_baseline", X_baseline)
self.register_buffer("_X_baseline_and_pending", X_baseline)
self.register_buffer(
"cache_pending",
torch.tensor(cache_pending, dtype=bool),
)
self.register_buffer(
"_prev_nehvi",
torch.tensor(0.0, dtype=ref_point.dtype, device=ref_point.device),
)
self.register_buffer(
"_max_iep",
torch.tensor(max_iep, dtype=torch.long),
)
self.register_buffer(
"incremental_nehvi",
torch.tensor(incremental_nehvi, dtype=torch.bool),
)

# Base sampler is initialized in _set_cell_bounds.
self.base_sampler = None

if X_pending is not None:
# This will call self._set_cell_bounds if the number of pending
# points is greater than self._max_iep.
self.set_X_pending(X_pending)
# In the case that X_pending is not None, but there are fewer than
# max_iep pending points, the box decompositions are not performed in
# set_X_pending. Therefore, we need to perform a box decomposition over
# f(X_baseline) here.
if X_pending is None or X_pending.shape[-2] <= self._max_iep:
self._set_cell_bounds(num_new_points=X_baseline.shape[0])
# Set q_in=-1 to so that self.sampler is updated at the next forward call.
self.q_in = -1

@property
def X_baseline(self) -> Tensor:
r"""Return X_baseline augmented with pending points cached using CBD."""
return self._X_baseline_and_pending

def _compute_initial_hvs(self, obj: Tensor, feas: Optional[Tensor] = None) -> None:
r"""Compute hypervolume dominated by f(X_baseline) under each sample.
Args:
obj: A `sample_shape x batch_shape x n x m`-dim tensor of samples
of objectives.
feas: `sample_shape x batch_shape x n`-dim tensor of samples
of feasibility indicators.
"""
initial_hvs = []
for i, sample in enumerate(obj):
if self.constraints is not None:
sample = sample[feas[i]]
dominated_partitioning = DominatedPartitioning(
ref_point=self.ref_point,
Y=sample,
)
hv = dominated_partitioning.compute_hypervolume()
initial_hvs.append(hv)
self.register_buffer(
"_initial_hvs",
torch.tensor(initial_hvs, dtype=obj.dtype, device=obj.device).view(
self._batch_sample_shape, *obj.shape[-2:]
),
)

def _set_cell_bounds(self, num_new_points: int) -> None:
r"""Compute the box decomposition under each posterior sample.
Args:
num_new_points: The number of new points (beyond the points
in X_baseline) that were used in the previous box decomposition.
In the first box decomposition, this should be the number of points
in X_baseline.
"""
feas = None
if self.X_baseline.shape[0] > 0:
with torch.no_grad():
posterior = self.model.posterior(self.X_baseline)
# Reset sampler, accounting for possible one-to-many transform.
self.q_in = -1
if self.base_sampler is None:
# Initialize the base sampler if needed.
samples = self.get_posterior_samples(posterior)
self.base_sampler = deepcopy(self.sampler)
else:
samples = self.base_sampler(posterior)
n_w = posterior._extended_shape()[-2] // self.X_baseline.shape[-2]
self._set_sampler(q_in=num_new_points * n_w, posterior=posterior)
# cache posterior
if self._cache_root:
# Note that this implicitly uses LinearOperator's caching to check if
# the proper root decomposition has already been cached to
# `posterior.mvn.lazy_covariance_matrix`, which it may have been in
# the call to `self.base_sampler`, and computes it if not found
self._baseline_L = self._compute_root_decomposition(posterior=posterior)
obj = self.objective(samples, X=self.X_baseline)
if self.constraints is not None:
feas = torch.stack(
[c(samples) <= 0 for c in self.constraints], dim=0
).all(dim=0)
else:
sample_shape = (
self.sampler.sample_shape
if self.sampler is not None
else self._default_sample_shape
)
obj = torch.empty(
*sample_shape,
0,
self.ref_point.shape[-1],
dtype=self.ref_point.dtype,
device=self.ref_point.device,
)
self._batch_sample_shape = obj.shape[:-2]
# collapse batch dimensions
# use numel() rather than view(-1) to handle case of no baseline points
new_batch_shape = self._batch_sample_shape.numel()
obj = obj.view(new_batch_shape, *obj.shape[-2:])
if self.constraints is not None and feas is not None:
feas = feas.view(new_batch_shape, *feas.shape[-1:])

if self.partitioning is None and not self.incremental_nehvi:
self._compute_initial_hvs(obj=obj, feas=feas)
if self.ref_point.shape[-1] > 2:
# the partitioning algorithms run faster on the CPU
# due to advanced indexing
ref_point_cpu = self.ref_point.cpu()
obj_cpu = obj.cpu()
if self.constraints is not None and feas is not None:
feas_cpu = feas.cpu()
obj_cpu = [obj_cpu[i][feas_cpu[i]] for i in range(obj.shape[0])]
partitionings = []
for sample in obj_cpu:
partitioning = self.p_class(
ref_point=ref_point_cpu, Y=sample, **self.p_kwargs
)
partitionings.append(partitioning)
self.partitioning = BoxDecompositionList(*partitionings)
else:
# use batched partitioning
obj = _pad_batch_pareto_frontier(
Y=obj,
ref_point=self.ref_point.unsqueeze(0).expand(
obj.shape[0], self.ref_point.shape[-1]
),
feasibility_mask=feas,
)
self.partitioning = self.p_class(
ref_point=self.ref_point, Y=obj, **self.p_kwargs
)
cell_bounds = self.partitioning.get_hypercell_bounds().to(self.ref_point)
cell_bounds = cell_bounds.view(
2, *self._batch_sample_shape, *cell_bounds.shape[-2:]
)
self.register_buffer("cell_lower_bounds", cell_bounds[0])
self.register_buffer("cell_upper_bounds", cell_bounds[1])

def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
r"""Informs the acquisition function about pending design points.
Args:
X_pending: `n x d` Tensor with `n` `d`-dim design points that have
been submitted for evaluation but have not yet been evaluated.
"""
if X_pending is None:
self.X_pending = None
else:
if X_pending.requires_grad:
warnings.warn(
"Pending points require a gradient but the acquisition function"
" will not provide a gradient to these points.",
BotorchWarning,
)
X_pending = X_pending.detach().clone()
if self.cache_pending:
X_baseline = torch.cat([self._X_baseline, X_pending], dim=-2)
# Number of new points is the total number of points minus
# (the number of previously cached pending points plus the
# of number of baseline points).
num_new_points = X_baseline.shape[0] - self.X_baseline.shape[0]
if num_new_points > 0:
if num_new_points > self._max_iep:
# Set the new baseline points to include pending points.
self.register_buffer("_X_baseline_and_pending", X_baseline)
# Recompute box decompositions.
self._set_cell_bounds(num_new_points=num_new_points)
if not self.incremental_nehvi:
self._prev_nehvi = (
(self._hypervolumes - self._initial_hvs)
.clamp_min(0.0)
.mean()
)
# Set to None so that pending points are not concatenated in
# forward.
self.X_pending = None
# Set q_in=-1 to so that self.sampler is updated at the next
# forward call.
self.q_in = -1
else:
self.X_pending = X_pending[-num_new_points:]
else:
self.X_pending = X_pending

@property
def _hypervolumes(self) -> Tensor:
r"""Compute hypervolume over X_baseline under each posterior sample.
Returns:
A `n_samples`-dim tensor of hypervolumes.
"""
return (
self.partitioning.compute_hypervolume()
.to(self.ref_point) # for m > 2, the partitioning is on the CPU
.view(self._batch_sample_shape)
SubsetIndexCachingMixin.__init__(self)
NoisyExpectedHypervolumeMixin.__init__(
self,
model=model,
ref_point=ref_point,
X_baseline=X_baseline,
sampler=self.sampler,
objective=self.objective,
constraints=self.constraints,
X_pending=X_pending,
prune_baseline=prune_baseline,
alpha=alpha,
cache_pending=cache_pending,
max_iep=max_iep,
incremental_nehvi=incremental_nehvi,
cache_root=cache_root,
marginalize_dim=marginalize_dim,
)

@concatenate_pending_points
Expand Down
Loading

0 comments on commit ba191c0

Please sign in to comment.