From bc39aa11e1628b5e7c3658378e8c93d5e9bd70d2 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Mon, 24 Oct 2022 13:48:59 -0700
Subject: [PATCH] Let Pyre know that `AcquisitionFunction.model` is a `Model` (#1216)

Summary:
X-link: https://github.com/facebook/Ax/pull/1216

## Motivation

Pyre is not smart enough to understand that calling `self.add_module('model', model)` makes `self.model` have the type of `model`; that behavior comes from some fairly complex logic inherited from `torch.nn.Module`. PyTorch, however, performs the same registration if we simply write `self.model = model`, and an annotated assignment also gives Pyre the type directly. The same holds for tensors, but only if the name is explicitly registered as a buffer before assignment (registering the name is enough; the initial value may be `None`).
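The sketch below illustrates the `torch.nn.Module.__setattr__` behavior this change relies on. It is not part of the diff; `Wrapper` and `buf` are made-up names for illustration:

```python
import torch
from torch import nn


class Wrapper(nn.Module):
    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Plain assignment of an nn.Module goes through nn.Module.__setattr__,
        # which records it in self._modules, just as add_module would.
        self.model = model
        # Tensors are not auto-registered: declare the name as a buffer first
        # (a None value is fine); plain assignment then updates self._buffers.
        self.register_buffer("buf", None)
        self.buf = torch.zeros(3)


w = Wrapper(nn.Linear(2, 2))
assert "model" in w._modules
assert "buf" in w._buffers
```

Annotating the assignment (`self.model: Model = model`) gives Pyre the attribute type directly, while PyTorch still performs the registration.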
""" super().__init__() - self.add_module("model", model) + self.model: Model = model @classmethod def _deprecate_acqf_objective( diff --git a/botorch/acquisition/analytic.py b/botorch/acquisition/analytic.py index 2f59da82a9..a8508b6cb5 100644 --- a/botorch/acquisition/analytic.py +++ b/botorch/acquisition/analytic.py @@ -462,13 +462,22 @@ def _preprocess_constraint_bounds( con_upper_inds.append(k) con_upper.append(constraints[k][1]) # tensor-based indexing is much faster than list-based advanced indexing - self.register_buffer("con_lower_inds", torch.tensor(con_lower_inds)) - self.register_buffer("con_upper_inds", torch.tensor(con_upper_inds)) - self.register_buffer("con_both_inds", torch.tensor(con_both_inds)) - # tensor indexing - self.register_buffer("con_both", torch.tensor(con_both, dtype=torch.float)) - self.register_buffer("con_lower", torch.tensor(con_lower, dtype=torch.float)) - self.register_buffer("con_upper", torch.tensor(con_upper, dtype=torch.float)) + for k in [ + "con_lower_inds", + "con_upper_inds", + "con_both_inds", + "con_both", + "con_lower", + "con_upper", + ]: + self.register_buffer(k, tensor=None) + + self.con_lower_inds = torch.tensor(con_lower_inds) + self.con_upper_inds = torch.tensor(con_upper_inds) + self.con_both_inds = torch.tensor(con_both_inds) + self.con_both = torch.tensor(con_both) + self.con_lower = torch.tensor(con_lower) + self.con_upper = torch.tensor(con_upper) def _compute_prob_feas(self, X: Tensor, means: Tensor, sigmas: Tensor) -> Tensor: r"""Compute feasibility probability for each batch of X. diff --git a/botorch/acquisition/knowledge_gradient.py b/botorch/acquisition/knowledge_gradient.py index be0d3c4256..6f80ebf9f3 100644 --- a/botorch/acquisition/knowledge_gradient.py +++ b/botorch/acquisition/knowledge_gradient.py @@ -150,12 +150,13 @@ def __init__( "If using a multi-output model without an objective, " "posterior_transform must scalarize the output." ) - self.sampler = sampler + self.sampler: MCSampler = sampler self.objective = objective self.posterior_transform = posterior_transform self.set_X_pending(X_pending) + self.X_pending: Tensor = self.X_pending self.inner_sampler = inner_sampler - self.num_fantasies = num_fantasies + self.num_fantasies: int = num_fantasies self.current_value = current_value @t_batch_mode_transform() @@ -338,7 +339,7 @@ def __init__( project: Callable[[Tensor], Tensor] = lambda X: X, expand: Callable[[Tensor], Tensor] = lambda X: X, valfunc_cls: Optional[Type[AcquisitionFunction]] = None, - valfunc_argfac: Optional[Callable[[Model, Dict[str, Any]]]] = None, + valfunc_argfac: Optional[Callable[[Model], Dict[str, Any]]] = None, **kwargs: Any, ) -> None: r"""Multi-Fidelity q-Knowledge Gradient (one-shot optimization). @@ -529,7 +530,7 @@ def _get_value_function( sampler: Optional[MCSampler] = None, project: Optional[Callable[[Tensor], Tensor]] = None, valfunc_cls: Optional[Type[AcquisitionFunction]] = None, - valfunc_argfac: Optional[Callable[[Model, Dict[str, Any]]]] = None, + valfunc_argfac: Optional[Callable[[Model], Dict[str, Any]]] = None, ) -> AcquisitionFunction: r"""Construct value function (i.e. 
     if valfunc_cls is not None:
diff --git a/botorch/acquisition/monte_carlo.py b/botorch/acquisition/monte_carlo.py
index aa70187fdd..5aa0ec081f 100644
--- a/botorch/acquisition/monte_carlo.py
+++ b/botorch/acquisition/monte_carlo.py
@@ -76,7 +76,7 @@ def __init__(
         super().__init__(model=model)
         if sampler is None:
             sampler = SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)
-        self.add_module("sampler", sampler)
+        self.sampler: MCSampler = sampler
         if objective is None and model.num_outputs != 1:
             if posterior_transform is None:
                 raise UnsupportedError(
@@ -91,7 +91,7 @@
         if objective is None:
             objective = IdentityMCObjective()
         self.posterior_transform = posterior_transform
-        self.add_module("objective", objective)
+        self.objective: MCAcquisitionObjective = objective
         self.set_X_pending(X_pending)
 
     @abstractmethod
diff --git a/test/acquisition/multi_objective/test_monte_carlo.py b/test/acquisition/multi_objective/test_monte_carlo.py
index 4b6f6b4043..583fd708ba 100644
--- a/test/acquisition/multi_objective/test_monte_carlo.py
+++ b/test/acquisition/multi_objective/test_monte_carlo.py
@@ -159,6 +159,9 @@ def test_q_expected_hypervolume_improvement(self):
         samples2 = torch.zeros(1, 2, 2, **tkwargs)
         mm2 = MockModel(MockPosterior(samples=samples2))
         acqf.model = mm2
+        self.assertEqual(acqf.model, mm2)
+        self.assertIn("model", acqf._modules)
+        self.assertEqual(acqf._modules["model"], mm2)
         res = acqf(X2)
         self.assertEqual(res.item(), 0.0)
         # check cached indices
diff --git a/test/acquisition/multi_objective/test_multi_fidelity.py b/test/acquisition/multi_objective/test_multi_fidelity.py
index 43986341ae..3891aa32ac 100644
--- a/test/acquisition/multi_objective/test_multi_fidelity.py
+++ b/test/acquisition/multi_objective/test_multi_fidelity.py
@@ -73,6 +73,7 @@ def test_momf(self):
         samples2 = torch.zeros(1, 2, 2, **tkwargs)
         mm2 = MockModel(MockPosterior(samples=samples2))
         acqf.model = mm2
+        self.assertEqual(acqf.model, mm2)
         res = acqf(X2)
         self.assertEqual(res.item(), 0.0)
         # check cached indices
diff --git a/test/acquisition/test_analytic.py b/test/acquisition/test_analytic.py
index 7287c8b29c..578ba30781 100644
--- a/test/acquisition/test_analytic.py
+++ b/test/acquisition/test_analytic.py
@@ -344,6 +344,17 @@ def test_constrained_expected_improvement(self):
         module = ConstrainedExpectedImprovement(
             model=mm, best_f=0.0, objective_index=0, constraints={1: [None, 0]}
         )
+        # test initialization
+        for k in [
+            "con_lower_inds",
+            "con_upper_inds",
+            "con_both_inds",
+            "con_both",
+            "con_lower",
+            "con_upper",
+        ]:
+            self.assertIn(k, module._buffers)
+
         X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
         ei = module(X)
         ei_expected_unconstrained = torch.tensor(
diff --git a/test/acquisition/test_knowledge_gradient.py b/test/acquisition/test_knowledge_gradient.py
index 93ea43f6f9..01fad603ed 100644
--- a/test/acquisition/test_knowledge_gradient.py
+++ b/test/acquisition/test_knowledge_gradient.py
@@ -610,6 +610,10 @@ def test_get_value_function(self):
         mm = MockModel(None)
         # test PosteriorMean
         vf = _get_value_function(mm)
+        # test initialization
+        self.assertIn("model", vf._modules)
+        self.assertEqual(vf._modules["model"], mm)
+
         self.assertIsInstance(vf, PosteriorMean)
         self.assertIsNone(vf.posterior_transform)
         # test SimpleRegret
diff --git a/test/acquisition/test_monte_carlo.py b/test/acquisition/test_monte_carlo.py
index 0ef0d37beb..4d265b6086 100644
--- a/test/acquisition/test_monte_carlo.py
+++ b/test/acquisition/test_monte_carlo.py
@@ -83,6 +83,10 @@ def test_q_expected_improvement(self):
         # basic test
         sampler = IIDNormalSampler(num_samples=2)
         acqf = qExpectedImprovement(model=mm, best_f=0, sampler=sampler)
+        # test initialization
+        for k in ["objective", "sampler"]:
+            self.assertIn(k, acqf._modules)
+
         res = acqf(X)
         self.assertEqual(res.item(), 0.0)