Skip to content

Commit

Permalink
BoxDecomposition cleanup (pytorch#1490)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1490

- Change `compute_hypervolume` so that each BoxDecomposition subclass uses shared logic for the no-data case
- [debatable] When `Y` is `None`, functions of Y like `box_decomp._neg_Y` are `None` rather than being unset attributes, so we do "if self._neg_Y is None" rather than catching an AttributeError. This makes catching type errors easier since otherwise Pyre is unhappy about references to the potentially-uninitialized attribute.
- Took out unnecessary "register_buffer" calls (this happens automatically with `torch.nn.Module.setattr`)

Differential Revision: D41172490

fbshipit-source-id: 01e3e159846b5b6cbd77763bdec00ee5d86aa50d
  • Loading branch information
esantorella authored and facebook-github-bot committed Nov 10, 2022
1 parent 7eada74 commit 95fe8bc
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,18 @@ def __init__(
Y: A `(batch_shape) x n x m`-dim tensor of outcomes.
"""
super().__init__()
self.register_buffer("_neg_ref_point", -ref_point)
self.register_buffer("sort", torch.tensor(sort, dtype=torch.bool))
self._neg_ref_point = -ref_point
self.sort = torch.tensor(sort, dtype=torch.bool)
self.num_outcomes = ref_point.shape[-1]

if Y is not None:
self._update_neg_Y(Y=Y)
self.reset()
self._neg_Y = -Y
self._validate_inputs()
self._neg_pareto_Y = self._compute_pareto_Y()
self.partition_space()
else:
self._neg_Y = None
self._neg_pareto_Y = None

@property
def pareto_Y(self) -> Tensor:
Expand All @@ -65,9 +71,9 @@ def pareto_Y(self) -> Tensor:
Returns:
A `n_pareto x m`-dim tensor of outcomes.
"""
try:
if self._neg_pareto_Y is not None:
return -self._neg_pareto_Y
except AttributeError:
else:
raise BotorchError("pareto_Y has not been initialized")

@property
Expand All @@ -86,14 +92,14 @@ def Y(self) -> Tensor:
Returns:
A `n x m`-dim tensor of outcomes.
"""
return -self._neg_Y

def _reset_pareto_Y(self) -> bool:
r"""Update the non-dominated front.
if self._neg_Y is not None:
return -self._neg_Y
else:
raise BotorchError("Y data has not been initialized")

Returns:
A boolean indicating whether the Pareto frontier has changed.
"""
def _compute_pareto_Y(self) -> Tensor:
if self._neg_Y is None:
raise BotorchError("Y data has not been initialized")
# is_non_dominated assumes maximization
if self._neg_Y.shape[-2] == 0:
pareto_Y = self._neg_Y
Expand All @@ -116,11 +122,20 @@ def _reset_pareto_Y(self) -> bool:
)
else:
pareto_Y = pareto_Y[torch.argsort(pareto_Y[:, 0])]
return pareto_Y

def _reset_pareto_Y(self) -> bool:
r"""Update the non-dominated front.
if not hasattr(self, "_neg_pareto_Y") or not torch.equal(
Returns:
A boolean indicating whether the Pareto frontier has changed.
"""
pareto_Y = self._compute_pareto_Y()

if (self._neg_pareto_Y is None) or not torch.equal(
pareto_Y, self._neg_pareto_Y
):
self.register_buffer("_neg_pareto_Y", pareto_Y)
self._neg_pareto_Y = pareto_Y
return True
return False

Expand All @@ -139,13 +154,12 @@ def _partition_space_2d(self) -> None:
raise NotImplementedError

@abstractmethod
def _partition_space(self):
def _partition_space(self) -> None:
r"""Partition the non-dominated space into disjoint hypercells.
This method supports an arbitrary number of outcomes, but is
less efficient than `partition_space_2d` for the 2-outcome case.
"""
pass # pragma: no cover

@abstractmethod
def get_hypercell_bounds(self) -> Tensor:
Expand All @@ -155,7 +169,6 @@ def get_hypercell_bounds(self) -> Tensor:
A `2 x num_cells x num_outcomes`-dim tensor containing the
lower and upper vertices bounding each hypercell.
"""
pass # pragma: no cover

def _update_neg_Y(self, Y: Tensor) -> bool:
r"""Update the set of outcomes.
Expand All @@ -164,11 +177,11 @@ def _update_neg_Y(self, Y: Tensor) -> bool:
A boolean indicating if _neg_Y was initialized.
"""
# multiply by -1, since internally we minimize.
try:
if self._neg_Y is not None:
self._neg_Y = torch.cat([self._neg_Y, -Y], dim=-2)
return False
except AttributeError:
self.register_buffer("_neg_Y", -Y)
else:
self._neg_Y = -Y
return True

def update(self, Y: Tensor) -> None:
Expand All @@ -183,8 +196,7 @@ def update(self, Y: Tensor) -> None:
self._update_neg_Y(Y=Y)
self.reset()

def reset(self) -> None:
r"""Reset non-dominated front and decomposition."""
def _validate_inputs(self) -> None:
self.batch_shape = self.Y.shape[:-2]
self.num_outcomes = self.Y.shape[-1]
if len(self.batch_shape) > 1:
Expand All @@ -198,20 +210,36 @@ def reset(self) -> None:
f"{type(self).__name__} only supports a batched box "
f"decompositions in the 2-objective setting."
)

def reset(self) -> None:
r"""Reset non-dominated front and decomposition."""
self._validate_inputs()
is_new_pareto = self._reset_pareto_Y()
# Update decomposition if the Pareto front changed
if is_new_pareto:
self.partition_space()

@abstractmethod
def _compute_hypervolume_if_y_has_data(self) -> Tensor:
"""Compute hypervolume for the case that there is data in self._neg_pareto_Y."""

def compute_hypervolume(self) -> Tensor:
r"""Compute hypervolume that is dominated by the Pareto Froniter.
Returns:
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
each Pareto frontier.
"""
pass # pragma: no cover
if self._neg_pareto_Y is None:
return torch.tensor(0.0)

if self._neg_pareto_Y.shape[-2] == 0:
return torch.zeros(
self._neg_pareto_Y.shape[:-2],
dtype=self._neg_pareto_Y.dtype,
device=self._neg_pareto_Y.device,
)
return self._compute_hypervolume_if_y_has_data()


class FastPartitioning(BoxDecomposition, ABC):
Expand Down
16 changes: 3 additions & 13 deletions botorch/utils/multi_objective/box_decompositions/dominated.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from __future__ import annotations

import torch
from botorch.utils.multi_objective.box_decompositions.box_decomposition import (
FastPartitioning,
)
Expand Down Expand Up @@ -39,7 +38,7 @@ def _partition_space_2d(self) -> None:
pareto_Y_sorted=self.pareto_Y.flip(-2),
ref_point=self.ref_point,
)
self.register_buffer("hypercell_bounds", cell_bounds)
self.hypercell_bounds = cell_bounds

def _get_partitioning(self) -> None:
r"""Get the bounds of each hypercell in the decomposition."""
Expand All @@ -49,22 +48,13 @@ def _get_partitioning(self) -> None:
cell_bounds = -minimization_cell_bounds.flip(0)
self.register_buffer("hypercell_bounds", cell_bounds)

def compute_hypervolume(self) -> Tensor:
def _compute_hypervolume_if_y_has_data(self) -> Tensor:
r"""Compute hypervolume that is dominated by the Pareto Frontier.
Returns:
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
each Pareto frontier.
"""
if not hasattr(self, "_neg_pareto_Y"):
return torch.tensor(0.0).to(self._neg_ref_point)

if self._neg_pareto_Y.shape[-2] == 0:
return torch.zeros(
self._neg_pareto_Y.shape[:-2],
dtype=self._neg_pareto_Y.dtype,
device=self._neg_pareto_Y.device,
)
return (
(self.hypercell_bounds[1] - self.hypercell_bounds[0])
.prod(dim=-1)
Expand All @@ -77,4 +67,4 @@ def _get_single_cell(self) -> None:
cell_bounds = self.ref_point.expand(
2, *self._neg_pareto_Y.shape[:-2], 1, self.num_outcomes
).clone()
self.register_buffer("hypercell_bounds", cell_bounds)
self.hypercell_bounds = cell_bounds
53 changes: 9 additions & 44 deletions botorch/utils/multi_objective/box_decompositions/non_dominated.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,8 @@ def _partition_space(self) -> None:
# hypercells contains the indices of the (augmented) Pareto front
# that specify that bounds of the each hypercell.
# It is a `2 x num_cells x m`-dim tensor
self.register_buffer(
"hypercells",
torch.empty(
2, 0, self.num_outcomes, dtype=torch.long, device=self._neg_Y.device
),
self.hypercells = torch.empty(
2, 0, self.num_outcomes, dtype=torch.long, device=self._neg_Y.device
)
outcome_idxr = torch.arange(
self.num_outcomes, dtype=torch.long, device=self._neg_Y.device
Expand Down Expand Up @@ -216,7 +213,7 @@ def _partition_space_2d(self) -> None:
dim=-1,
)
# 2 x batch_shape x n_cells x 2
self.register_buffer("hypercells", torch.stack([lower, upper], dim=0))
self.hypercells = torch.stack([lower, upper], dim=0)

def _get_augmented_pareto_front_indices(self) -> Tensor:
r"""Get indices of augmented Pareto front."""
Expand Down Expand Up @@ -337,25 +334,7 @@ def _get_hypercell_bounds(self, aug_pareto_Y: Tensor) -> Tensor:
view_shape = (2, *self.batch_shape, num_cells, self.num_outcomes)
return cell_bounds_values.view(view_shape)

def compute_hypervolume(self) -> Tensor:
r"""Compute the hypervolume for the given reference point.
This method computes the hypervolume of the non-dominated space
and computes the difference between the hypervolume between the
ideal point and hypervolume of the non-dominated space.
Returns:
`(batch_shape)`-dim tensor containing the dominated hypervolume.
"""
if not hasattr(self, "_neg_pareto_Y"):
return torch.tensor(0.0).to(self._neg_ref_point)

if self._neg_pareto_Y.shape[-2] == 0:
return torch.zeros(
self._neg_pareto_Y.shape[:-2],
dtype=self._neg_pareto_Y.dtype,
device=self._neg_pareto_Y.device,
)
def _compute_hypervolume_if_y_has_data(self) -> Tensor:
ref_point = _expand_ref_point(
ref_point=self.ref_point, batch_shape=self.batch_shape
)
Expand Down Expand Up @@ -413,7 +392,7 @@ def _get_single_cell(self) -> None:
device=self._neg_pareto_Y.device,
)
cell_bounds[0] = self.ref_point
self.register_buffer("hypercell_bounds", cell_bounds)
self.hypercell_bounds = cell_bounds

def _get_partitioning(self) -> None:
r"""Compute non-dominated partitioning.
Expand All @@ -432,7 +411,7 @@ def _get_partitioning(self) -> None:
device=self._neg_ref_point.device,
)
# initialize local upper bounds for the second minimization problem
self.register_buffer("_U2", new_ref_point)
self._U2 = new_ref_point
# initialize defining points for the second minimization problem
# use ref point for maximization as the ideal point for minimization.
self._Z2 = self.ref_point.expand(
Expand All @@ -450,7 +429,7 @@ def _get_partitioning(self) -> None:
cell_bounds = get_partition_bounds(
Z=self._Z2, U=self._U2, ref_point=new_ref_point.view(-1)
)
self.register_buffer("hypercell_bounds", cell_bounds)
self.hypercell_bounds = cell_bounds

def _partition_space_2d(self) -> None:
r"""Partition the non-dominated space into disjoint hypercells.
Expand All @@ -461,23 +440,9 @@ def _partition_space_2d(self) -> None:
pareto_Y_sorted=self.pareto_Y.flip(-2),
ref_point=self.ref_point,
)
self.register_buffer("hypercell_bounds", cell_bounds)

def compute_hypervolume(self) -> Tensor:
r"""Compute hypervolume that is dominated by the Pareto Froniter.
self.hypercell_bounds = cell_bounds

Returns:
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
each Pareto frontier.
"""
if not hasattr(self, "_neg_pareto_Y"):
return torch.tensor(0.0).to(self._neg_ref_point)
if self._neg_pareto_Y.shape[-2] == 0:
return torch.zeros(
self._neg_pareto_Y.shape[:-2],
dtype=self._neg_pareto_Y.dtype,
device=self._neg_pareto_Y.device,
)
def _compute_hypervolume_if_y_has_data(self) -> Tensor:
ideal_point = self.pareto_Y.max(dim=-2, keepdim=True).values
total_volume = (
(ideal_point.squeeze(-2) - self.ref_point).clamp_min(0.0).prod(dim=-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class DummyBoxDecomposition(BoxDecomposition):
def _partition_space(self):
pass

def compute_hypervolume(self):
def _compute_hypervolume_if_y_has_data(self):
pass

def get_hypercell_bounds(self):
Expand Down Expand Up @@ -66,7 +66,7 @@ def setUp(self):
device=self.device,
)

def test_box_decomposition(self):
def test_box_decomposition(self) -> None:
with self.assertRaises(TypeError):
BoxDecomposition()
for dtype, m, sort in product(
Expand Down Expand Up @@ -271,7 +271,7 @@ def test_fast_partitioning(self):
DummyFastPartitioning(ref_point=ref_point, Y=Y.unsqueeze(0))


class TestBoxDecomposition_Hypervolume(BotorchTestCase):
class TestBoxDecomposition_no_set_up(BotorchTestCase):
def helper_hypervolume(self, Box_Decomp_cls: type) -> None:
"""
This test should be run for each non-abstract subclass of `BoxDecomposition`.
Expand All @@ -292,7 +292,6 @@ def helper_hypervolume(self, Box_Decomp_cls: type) -> None:

box_decomp = Box_Decomp_cls(ref_point=ref_point, Y=Y)
hv = box_decomp.compute_hypervolume()

self.assertEqual(hv.shape, ())
self.assertTrue(torch.allclose(hv, torch.tensor(1.0)))

Expand All @@ -316,3 +315,11 @@ def test_hypervolume(self) -> None:
FastNondominatedPartitioning,
]:
self.helper_hypervolume(cl)

def test_uninitialized_y(self) -> None:
ref_point = torch.zeros(2)
box_decomp = NondominatedPartitioning(ref_point=ref_point)
with self.assertRaises(BotorchError):
box_decomp.Y
with self.assertRaises(BotorchError):
box_decomp._compute_pareto_Y()

0 comments on commit 95fe8bc

Please sign in to comment.