Skip to content

Commit

Permalink
Fix box decomposition behavior with empty or None Y (pytorch#1489)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1489

See T126108893

Differential Revision: D41170229

fbshipit-source-id: cde97e88fcd9470cd6f225f551809cea09faffa1
  • Loading branch information
esantorella authored and facebook-github-bot committed Nov 10, 2022
1 parent 1484c89 commit 30bf8ed
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
3 changes: 3 additions & 0 deletions botorch/utils/multi_objective/box_decompositions/dominated.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def compute_hypervolume(self) -> Tensor:
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
each Pareto frontier.
"""
if not hasattr(self, "_neg_pareto_Y"):
return torch.tensor(0.0).to(self._neg_ref_point)

if self._neg_pareto_Y.shape[-2] == 0:
return torch.zeros(
self._neg_pareto_Y.shape[:-2],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ def compute_hypervolume(self) -> Tensor:
Returns:
`(batch_shape)`-dim tensor containing the dominated hypervolume.
"""
if not hasattr(self, "_neg_pareto_Y"):
return torch.tensor(0.0).to(self._neg_ref_point)

if self._neg_pareto_Y.shape[-2] == 0:
return torch.zeros(
self._neg_pareto_Y.shape[:-2],
Expand Down Expand Up @@ -460,13 +463,15 @@ def _partition_space_2d(self) -> None:
)
self.register_buffer("hypercell_bounds", cell_bounds)

def compute_hypervolume(self):
def compute_hypervolume(self) -> Tensor:
r"""Compute hypervolume that is dominated by the Pareto Froniter.
Returns:
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
each Pareto frontier.
"""
if not hasattr(self, "_neg_pareto_Y"):
return torch.tensor(0.0).to(self._neg_ref_point)
if self._neg_pareto_Y.shape[-2] == 0:
return torch.zeros(
self._neg_pareto_Y.shape[:-2],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
BoxDecomposition,
FastPartitioning,
)
from botorch.utils.multi_objective.box_decompositions.dominated import (
DominatedPartitioning,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
FastNondominatedPartitioning,
NondominatedPartitioning,
)
from botorch.utils.multi_objective.box_decompositions.utils import (
update_local_upper_bounds_incremental,
)
Expand Down Expand Up @@ -262,3 +269,49 @@ def test_fast_partitioning(self):
if m == 2:
with self.assertRaises(NotImplementedError):
DummyFastPartitioning(ref_point=ref_point, Y=Y.unsqueeze(0))

def helper_hypervolume(self, Box_Decomp_cls: type) -> None:
"""
This test should be run for each non-abstract subclass of `BoxDecomposition`.
"""
# batching
n_outcomes, batch_dim, n = 2, 3, 4

ref_point = torch.zeros(n_outcomes)
Y = torch.ones(batch_dim, n, n_outcomes)

box_decomp = Box_Decomp_cls(ref_point=ref_point, Y=Y)
hv = box_decomp.compute_hypervolume()
self.assertEqual(hv.shape, (batch_dim,))
self.assertTrue(torch.allclose(hv, torch.ones(batch_dim)))

# no batching
Y = torch.ones(n, n_outcomes)

box_decomp = Box_Decomp_cls(ref_point=ref_point, Y=Y)
hv = box_decomp.compute_hypervolume()

self.assertEqual(hv.shape, ())
self.assertTrue(torch.allclose(hv, torch.tensor(1.0)))

# cases where there is nothing in Y, either because n=0 or Y is None
n = 0
Y_and_expected_shape = [
(torch.ones(batch_dim, n, n_outcomes), (batch_dim,)),
(torch.ones(n, n_outcomes), ()),
(None, ()),
]
for Y, expected_shape in Y_and_expected_shape:
box_decomp = Box_Decomp_cls(ref_point=ref_point, Y=Y)
hv = box_decomp.compute_hypervolume()
self.assertEqual(hv.shape, expected_shape)
self.assertTrue(torch.allclose(hv, torch.tensor(0.0)))

def test_hypervolume_nondom(self) -> None:
self.helper_hypervolume(NondominatedPartitioning)

def test_hypervolume_y_dom(self) -> None:
self.helper_hypervolume(DominatedPartitioning)

def test_hypervolume_fast(self) -> None:
self.helper_hypervolume(FastNondominatedPartitioning)

0 comments on commit 30bf8ed

Please sign in to comment.