From c4cead97cd4efcad4781416491ae8d51f075a451 Mon Sep 17 00:00:00 2001 From: David Eriksson Date: Fri, 1 Dec 2023 08:19:07 -0800 Subject: [PATCH] Change DerelativizeTransform to not use model predictions when `use_raw_status_quo` is `True` or when the status quo is infeasible. Summary: The was previously ignoring `use_raw_status_quo` and was using the values predicted by the model unless this errored out. This is misleading since we shouldn't be using the model predictions when `use_raw_status_quo` is `True`. This also resulted in weird behavior in the case where the status quo arm was within the search space bounds but didn't satisfy the parameter constraints. This transform would then use the model to predict the metrics of the status quo (ignoring that `use_raw_status_quo` was `True`), but the model wasn't trained on the status quo since it didn't satisfy the constraints. Thus, this resulted in the model having to extrapolate and producing very weird predictions. This diff changes the behavior to only use the model to predict the status quo metrics when (1) `use_raw_status_quo` is `False` and (2) the status quo is actually feasible. Reviewed By: Balandat Differential Revision: D51690727 --- .../tests/test_derelativize_transform.py | 28 ++++++++-------- ax/modelbridge/transforms/derelativize.py | 32 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/ax/modelbridge/tests/test_derelativize_transform.py b/ax/modelbridge/tests/test_derelativize_transform.py index 270e559bd4a..3874e4a9cbe 100644 --- a/ax/modelbridge/tests/test_derelativize_transform.py +++ b/ax/modelbridge/tests/test_derelativize_transform.py @@ -6,7 +6,7 @@ from copy import deepcopy from unittest import mock -from unittest.mock import patch +from unittest.mock import Mock, patch import numpy as np from ax.core.data import Data @@ -75,20 +75,13 @@ def setUp(self) -> None: ] ), ) - # pyre-fixme[3]: Return type must be annotated. def test_DerelativizeTransform( self, - # pyre-fixme[2]: Parameter must be annotated. - mock_predict, - # pyre-fixme[2]: Parameter must be annotated. - mock_fit, - # pyre-fixme[2]: Parameter must be annotated. - mock_observations_from_data, - ): - t = Derelativize( - search_space=None, - observations=[], - ) + mock_predict: Mock, + mock_fit: Mock, + mock_observations_from_data: Mock, + ) -> None: + t = Derelativize(search_space=None, observations=[]) # ModelBridge with in-design status quo search_space = SearchSpace( @@ -167,6 +160,13 @@ def test_DerelativizeTransform( obsf = mock_predict.mock_calls[0][1][1][0] obsf2 = ObservationFeatures(parameters={"x": 2.0, "y": 10.0}) self.assertTrue(obsf == obsf2) + self.assertEqual(mock_predict.call_count, 1) + + # The model should not be used when `use_raw_status_quo` is True + t2 = deepcopy(t) + t2.config["use_raw_status_quo"] = True + t2.transform_optimization_config(deepcopy(oc), g, None) + self.assertEqual(mock_predict.call_count, 1) # Test with relative constraint, out-of-design status quo mock_predict.side_effect = RuntimeError() @@ -215,7 +215,7 @@ def test_DerelativizeTransform( ), ] ) - self.assertEqual(mock_predict.call_count, 2) + self.assertEqual(mock_predict.call_count, 1) # Raises error if predict fails with in-design status quo g = ModelBridge( diff --git a/ax/modelbridge/transforms/derelativize.py b/ax/modelbridge/transforms/derelativize.py index 57b36385ddb..64563733c0d 100644 --- a/ax/modelbridge/transforms/derelativize.py +++ b/ax/modelbridge/transforms/derelativize.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from logging import Logger from typing import List, Optional, TYPE_CHECKING import numpy as np @@ -14,6 +15,7 @@ from ax.modelbridge.base import unwrap_observation_data from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.ivw import ivw_metric_merge +from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import not_none @@ -22,6 +24,9 @@ from ax import modelbridge as modelbridge_module # noqa F401 +logger: Logger = get_logger(__name__) + + class Derelativize(Transform): """Changes relative constraints to not-relative constraints using a plug-in estimate of the status quo value. @@ -59,22 +64,17 @@ def transform_optimization_config( "Optimization config has relative constraint, but model was " "not fit with status quo." ) - try: - f, _ = modelbridge.predict([modelbridge.status_quo.features]) - except Exception: - # Check if it is out-of-design. - if use_raw_sq or not modelbridge.model_space.check_membership( - modelbridge.status_quo.features.parameters - ): - # Out-of-design: use the raw observation - sq_data = ivw_metric_merge( - obsd=not_none(modelbridge.status_quo).data, - conflicting_noiseless="raise", - ) - f, _ = unwrap_observation_data([sq_data]) - else: - # Should have worked. - raise + + sq = not_none(modelbridge.status_quo) + # Only use model predictions if the status quo is in the search space (including + # parameter constraints) and `use_raw_sq` is false. + if not use_raw_sq and modelbridge.model_space.check_membership( + sq.features.parameters + ): + f, _ = modelbridge.predict([sq.features]) + else: + sq_data = ivw_metric_merge(obsd=sq.data, conflicting_noiseless="raise") + f, _ = unwrap_observation_data([sq_data]) # Plug in the status quo value to each relative constraint. for c in optimization_config.all_constraints: