Skip to content

Commit

Permalink
Change DerelativizeTransform to not use model predictions when `use_r…
Browse files Browse the repository at this point in the history
…aw_status_quo` is `True` or when the status quo is infeasible. (#2036)

Summary:
Pull Request resolved: #2036

The transform was previously ignoring `use_raw_status_quo` and was using the values predicted by the model unless the prediction errored out. This is misleading, since we shouldn't be using the model predictions when `use_raw_status_quo` is `True`.

This also resulted in unexpected behavior in the case where the status quo arm was within the search space bounds but didn't satisfy the parameter constraints. This transform would then use the model to predict the metrics of the status quo (ignoring that `use_raw_status_quo` was `True`), but the model wasn't trained on the status quo since it didn't satisfy the constraints. As a result, the model had to extrapolate and produced very unreliable predictions.

This diff changes the behavior to only use the model to predict the status quo metrics when (1) `use_raw_status_quo` is `False` and (2) the status quo is actually feasible.

Reviewed By: Balandat

Differential Revision: D51690727

fbshipit-source-id: 3f2648a0091c88e5e4c9c2f388421d3f563f6372
  • Loading branch information
David Eriksson authored and facebook-github-bot committed Dec 1, 2023
1 parent 4e32365 commit 470811f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 30 deletions.
32 changes: 16 additions & 16 deletions ax/modelbridge/transforms/derelativize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from logging import Logger
from typing import List, Optional, TYPE_CHECKING

import numpy as np
Expand All @@ -14,6 +15,7 @@
from ax.modelbridge.base import unwrap_observation_data
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.ivw import ivw_metric_merge
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none


Expand All @@ -22,6 +24,9 @@
from ax import modelbridge as modelbridge_module # noqa F401


logger: Logger = get_logger(__name__)


class Derelativize(Transform):
"""Changes relative constraints to not-relative constraints using a plug-in
estimate of the status quo value.
Expand Down Expand Up @@ -59,22 +64,17 @@ def transform_optimization_config(
"Optimization config has relative constraint, but model was "
"not fit with status quo."
)
try:
f, _ = modelbridge.predict([modelbridge.status_quo.features])
except Exception:
# Check if it is out-of-design.
if use_raw_sq or not modelbridge.model_space.check_membership(
modelbridge.status_quo.features.parameters
):
# Out-of-design: use the raw observation
sq_data = ivw_metric_merge(
obsd=not_none(modelbridge.status_quo).data,
conflicting_noiseless="raise",
)
f, _ = unwrap_observation_data([sq_data])
else:
# Should have worked.
raise

sq = not_none(modelbridge.status_quo)
# Only use model predictions if the status quo is in the search space (including
# parameter constraints) and `use_raw_sq` is false.
if not use_raw_sq and modelbridge.model_space.check_membership(
sq.features.parameters
):
f, _ = modelbridge.predict([sq.features])
else:
sq_data = ivw_metric_merge(obsd=sq.data, conflicting_noiseless="raise")
f, _ = unwrap_observation_data([sq_data])

# Plug in the status quo value to each relative constraint.
for c in optimization_config.all_constraints:
Expand Down
28 changes: 14 additions & 14 deletions ax/modelbridge/transforms/tests/test_derelativize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from copy import deepcopy
from unittest import mock
from unittest.mock import patch
from unittest.mock import Mock, patch

import numpy as np
from ax.core.data import Data
Expand Down Expand Up @@ -75,20 +75,13 @@ def setUp(self) -> None:
]
),
)
# pyre-fixme[3]: Return type must be annotated.
def test_DerelativizeTransform(
self,
# pyre-fixme[2]: Parameter must be annotated.
mock_predict,
# pyre-fixme[2]: Parameter must be annotated.
mock_fit,
# pyre-fixme[2]: Parameter must be annotated.
mock_observations_from_data,
):
t = Derelativize(
search_space=None,
observations=[],
)
mock_predict: Mock,
mock_fit: Mock,
mock_observations_from_data: Mock,
) -> None:
t = Derelativize(search_space=None, observations=[])

# ModelBridge with in-design status quo
search_space = SearchSpace(
Expand Down Expand Up @@ -167,6 +160,13 @@ def test_DerelativizeTransform(
obsf = mock_predict.mock_calls[0][1][1][0]
obsf2 = ObservationFeatures(parameters={"x": 2.0, "y": 10.0})
self.assertTrue(obsf == obsf2)
self.assertEqual(mock_predict.call_count, 1)

# The model should not be used when `use_raw_status_quo` is True
t2 = deepcopy(t)
t2.config["use_raw_status_quo"] = True
t2.transform_optimization_config(deepcopy(oc), g, None)
self.assertEqual(mock_predict.call_count, 1)

# Test with relative constraint, out-of-design status quo
mock_predict.side_effect = RuntimeError()
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_DerelativizeTransform(
),
]
)
self.assertEqual(mock_predict.call_count, 2)
self.assertEqual(mock_predict.call_count, 1)

# Raises error if predict fails with in-design status quo
g = ModelBridge(
Expand Down

0 comments on commit 470811f

Please sign in to comment.