Skip to content

Commit

Permalink
modify PenalizedMCObjective to support non-batch eval
Browse files Browse the repository at this point in the history
Summary:
To match penalty term with MC objective, we current unsqueeze the first dim which corresponds to the dimension of MC samples. However, when a `qxd`-dim X tensor is evaluated e.g. computing feasibility, it causes shape mismatch. As one would expect `q`-dim tensor returned, it will return `1xq`-dim tensor instead.

To fix, we check the dims of obj; if it is non-mc samples, we will sequeeze the first dim back.

Differential Revision: D49305807
  • Loading branch information
qingfeng10 authored and facebook-github-bot committed Oct 26, 2023
1 parent d81a674 commit 307f98f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,11 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
if self.expand_dim is not None:
# reshape penalty_obj to match the dim
penalty_obj = penalty_obj.unsqueeze(self.expand_dim)
# this happens when samples is `q x m`-dim Tensors and X is `q x d`-dim tensor
# obj returned from GenericMCObjective is `q`-dim Tensor and penalty_obj is `1xq`-dim Tensor
if obj.ndim == 1:
assert penalty_obj.shape == torch.Size([1, samples.shape[-2]])
penalty_obj = penalty_obj.squeeze(dim=0)
return obj - self.regularization_parameter * penalty_obj


Expand Down
2 changes: 1 addition & 1 deletion test/acquisition/test_penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def test_penalized_mc_objective(self):
samples = torch.randn(4, 3, device=self.device, dtype=dtype)
X = torch.randn(4, 5, device=self.device, dtype=dtype)
penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X)
self.assertTrue(torch.equal(obj(samples, X), penalized_obj))
self.assertTrue(torch.equal(obj(samples, X), penalized_obj.squeeze(0)))
# test 'q x d' Tensor X
samples = torch.randn(4, 2, 3, device=self.device, dtype=dtype)
X = torch.randn(2, 5, device=self.device, dtype=dtype)
Expand Down

0 comments on commit 307f98f

Please sign in to comment.