Let Pyre know that AcquisitionFunction.model is a Model (#1216)
Summary:
X-link: facebook/Ax#1216

## Motivation

Pyre is not smart enough to understand that calling `self.add_module('model', model)` makes `self.model` have the type of `model`, which is true due to some fairly complex underlying logic inherited from `torch.nn.Module`. However, PyTorch is smart enough to properly `add_module` if we just do `self.model = model`. This also works for tensors, but only if the tensor is explicitly registered as a buffer (by name, not necessarily by value) before assignment.
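
For illustration, a minimal self-contained sketch of both assignment patterns (not part of this diff; `ToyModule` and the `weights` buffer are made-up names):

```python
import torch
from torch import nn


class ToyModule(nn.Module):
    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Plain assignment of an nn.Module goes through nn.Module.__setattr__,
        # which registers it in self._modules exactly as
        # self.add_module("model", model) would; the annotation lets Pyre
        # infer the attribute's type.
        self.model: nn.Module = model
        # For tensors, pre-register the *name* as a buffer (the value may be
        # None); later plain assignments to that name land in self._buffers.
        self.register_buffer("weights", tensor=None)
        self.weights = torch.ones(3)


m = ToyModule(nn.Linear(2, 2))
assert "model" in m._modules and "weights" in m._buffers
```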

### Have you read the Contributing Guidelines on pull requests?
Yes

Pull Request resolved: #1452

Test Plan:
- Unit tests should be unaffected
- Pyre error count drops from 1379 to 1309 (down 5%).
- Added explicit tests that `_modules` and `_buffers` are properly initialized

Reviewed By: Balandat

Differential Revision: D40469725

Pulled By: esantorella

fbshipit-source-id: 531cec5b77fc74faf478c4c96f1ceaa596ca8162
esantorella authored and facebook-github-bot committed Oct 24, 2022
1 parent 8441d62 · commit bc39aa1
Showing 9 changed files with 47 additions and 14 deletions.
botorch/acquisition/acquisition.py (1 addition, 1 deletion)
@@ -36,7 +36,7 @@ def __init__(self, model: Model) -> None:
            model: A fitted model.
        """
        super().__init__()
-       self.add_module("model", model)
+       self.model: Model = model

    @classmethod
    def _deprecate_acqf_objective(
botorch/acquisition/analytic.py (16 additions, 7 deletions)
@@ -462,13 +462,22 @@ def _preprocess_constraint_bounds(
                con_upper_inds.append(k)
                con_upper.append(constraints[k][1])
        # tensor-based indexing is much faster than list-based advanced indexing
-       self.register_buffer("con_lower_inds", torch.tensor(con_lower_inds))
-       self.register_buffer("con_upper_inds", torch.tensor(con_upper_inds))
-       self.register_buffer("con_both_inds", torch.tensor(con_both_inds))
-       # tensor indexing
-       self.register_buffer("con_both", torch.tensor(con_both, dtype=torch.float))
-       self.register_buffer("con_lower", torch.tensor(con_lower, dtype=torch.float))
-       self.register_buffer("con_upper", torch.tensor(con_upper, dtype=torch.float))
+       for k in [
+           "con_lower_inds",
+           "con_upper_inds",
+           "con_both_inds",
+           "con_both",
+           "con_lower",
+           "con_upper",
+       ]:
+           self.register_buffer(k, tensor=None)
+
+       self.con_lower_inds = torch.tensor(con_lower_inds)
+       self.con_upper_inds = torch.tensor(con_upper_inds)
+       self.con_both_inds = torch.tensor(con_both_inds)
+       self.con_both = torch.tensor(con_both)
+       self.con_lower = torch.tensor(con_lower)
+       self.con_upper = torch.tensor(con_upper)

    def _compute_prob_feas(self, X: Tensor, means: Tensor, sigmas: Tensor) -> Tensor:
        r"""Compute feasibility probability for each batch of X.
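
As an aside, a standalone sketch of the buffer pattern used in the hunk above (the `ConstraintHolder` class is hypothetical, not BoTorch code): registering the names with `tensor=None` up front means the later plain assignments are routed into `_buffers`, so the tensors show up in `state_dict()` and follow the module across device moves:

```python
import torch
from torch import nn


class ConstraintHolder(nn.Module):
    def __init__(self, lower_inds: list, upper_inds: list) -> None:
        super().__init__()
        # Register the names first so that the plain assignments below are
        # routed into self._buffers by nn.Module.__setattr__.
        for name in ("con_lower_inds", "con_upper_inds"):
            self.register_buffer(name, tensor=None)
        self.con_lower_inds = torch.tensor(lower_inds)
        self.con_upper_inds = torch.tensor(upper_inds)


holder = ConstraintHolder([0, 2], [1])
assert set(holder.state_dict()) == {"con_lower_inds", "con_upper_inds"}
```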
botorch/acquisition/knowledge_gradient.py (5 additions, 4 deletions)
@@ -150,12 +150,13 @@ def __init__(
"If using a multi-output model without an objective, "
"posterior_transform must scalarize the output."
)
self.sampler = sampler
self.sampler: MCSampler = sampler
self.objective = objective
self.posterior_transform = posterior_transform
self.set_X_pending(X_pending)
self.X_pending: Tensor = self.X_pending
self.inner_sampler = inner_sampler
self.num_fantasies = num_fantasies
self.num_fantasies: int = num_fantasies
self.current_value = current_value

@t_batch_mode_transform()
@@ -338,7 +339,7 @@ def __init__(
        project: Callable[[Tensor], Tensor] = lambda X: X,
        expand: Callable[[Tensor], Tensor] = lambda X: X,
        valfunc_cls: Optional[Type[AcquisitionFunction]] = None,
-       valfunc_argfac: Optional[Callable[[Model, Dict[str, Any]]]] = None,
+       valfunc_argfac: Optional[Callable[[Model], Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> None:
        r"""Multi-Fidelity q-Knowledge Gradient (one-shot optimization).
@@ -529,7 +530,7 @@ def _get_value_function(
    sampler: Optional[MCSampler] = None,
    project: Optional[Callable[[Tensor], Tensor]] = None,
    valfunc_cls: Optional[Type[AcquisitionFunction]] = None,
-   valfunc_argfac: Optional[Callable[[Model, Dict[str, Any]]]] = None,
+   valfunc_argfac: Optional[Callable[[Model], Dict[str, Any]]] = None,
) -> AcquisitionFunction:
    r"""Construct value function (i.e. inner acquisition function)."""
    if valfunc_cls is not None:
botorch/acquisition/monte_carlo.py (2 additions, 2 deletions)
@@ -76,7 +76,7 @@ def __init__(
        super().__init__(model=model)
        if sampler is None:
            sampler = SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)
-       self.add_module("sampler", sampler)
+       self.sampler: MCSampler = sampler
        if objective is None and model.num_outputs != 1:
            if posterior_transform is None:
                raise UnsupportedError(
@@ -91,7 +91,7 @@
        if objective is None:
            objective = IdentityMCObjective()
        self.posterior_transform = posterior_transform
-       self.add_module("objective", objective)
+       self.objective: MCAcquisitionObjective = objective
        self.set_X_pending(X_pending)

    @abstractmethod
test/acquisition/multi_objective/test_monte_carlo.py (3 additions, 0 deletions)
@@ -159,6 +159,9 @@ def test_q_expected_hypervolume_improvement(self):
        samples2 = torch.zeros(1, 2, 2, **tkwargs)
        mm2 = MockModel(MockPosterior(samples=samples2))
        acqf.model = mm2
+       self.assertEqual(acqf.model, mm2)
+       self.assertIn("model", acqf._modules)
+       self.assertEqual(acqf._modules["model"], mm2)
        res = acqf(X2)
        self.assertEqual(res.item(), 0.0)
        # check cached indices
test/acquisition/multi_objective/test_multi_fidelity.py (1 addition, 0 deletions)
@@ -73,6 +73,7 @@ def test_momf(self):
        samples2 = torch.zeros(1, 2, 2, **tkwargs)
        mm2 = MockModel(MockPosterior(samples=samples2))
        acqf.model = mm2
+       self.assertEqual(acqf.model, mm2)
        res = acqf(X2)
        self.assertEqual(res.item(), 0.0)
        # check cached indices
test/acquisition/test_analytic.py (11 additions, 0 deletions)
@@ -344,6 +344,17 @@ def test_constrained_expected_improvement(self):
        module = ConstrainedExpectedImprovement(
            model=mm, best_f=0.0, objective_index=0, constraints={1: [None, 0]}
        )
+       # test initialization
+       for k in [
+           "con_lower_inds",
+           "con_upper_inds",
+           "con_both_inds",
+           "con_both",
+           "con_lower",
+           "con_upper",
+       ]:
+           self.assertIn(k, module._buffers)
+
        X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
        ei = module(X)
        ei_expected_unconstrained = torch.tensor(
test/acquisition/test_knowledge_gradient.py (4 additions, 0 deletions)
@@ -610,6 +610,10 @@ def test_get_value_function(self):
        mm = MockModel(None)
        # test PosteriorMean
        vf = _get_value_function(mm)
+       # test initialization
+       self.assertIn("model", vf._modules)
+       self.assertEqual(vf._modules["model"], mm)
+
        self.assertIsInstance(vf, PosteriorMean)
        self.assertIsNone(vf.posterior_transform)
        # test SimpleRegret
test/acquisition/test_monte_carlo.py (4 additions, 0 deletions)
@@ -83,6 +83,10 @@ def test_q_expected_improvement(self):
        # basic test
        sampler = IIDNormalSampler(num_samples=2)
        acqf = qExpectedImprovement(model=mm, best_f=0, sampler=sampler)
+       # test initialization
+       for k in ["objective", "sampler"]:
+           self.assertIn(k, acqf._modules)
+
        res = acqf(X)
        self.assertEqual(res.item(), 0.0)

