From 4466f6e4cb3d8a0e635201a23ecaf7be451777c9 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 4 Apr 2024 13:02:23 -0700 Subject: [PATCH] Rename OrderedChoiceEncode => OrderedChoiceToIntegerRange (#2323) Summary: This change renames the OrderedChoiceEncode transform to one which reflects its behavior- see T182722751 for the overall task. - Adds a new "OrderedChoiceToIntegerRange" class with the logic from the original OrderedChoiceEncode - Updates OrderedChoiceEncode to inherit from DeprecatedTransformMixin and OrderedChoiceToIntegerRange - Updates the registry to support the new transform. Initially, the new classes will be decoded into the deprecated classes to maintain backwards compatibility. Once the new classes are landed, call sites will be updated to use the new class instead of the old. Differential Revision: D55754487 --- ax/modelbridge/transforms/choice_encode.py | 14 ++++++-- .../tests/test_choice_encode_transform.py | 32 ++++++++++++++++--- ax/storage/transform_registry.py | 21 +++++++++--- 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/ax/modelbridge/transforms/choice_encode.py b/ax/modelbridge/transforms/choice_encode.py index 14391230f4d..325c771bf02 100644 --- a/ax/modelbridge/transforms/choice_encode.py +++ b/ax/modelbridge/transforms/choice_encode.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import numpy as np from ax.core.observation import Observation, ObservationFeatures @@ -14,6 +14,9 @@ from ax.core.search_space import SearchSpace from ax.core.types import TParamValue from ax.modelbridge.transforms.base import Transform +from ax.modelbridge.transforms.deprecated_transform_mixin import ( + DeprecatedTransformMixin, +) from ax.modelbridge.transforms.utils import ( ClosestLookupDict, construct_new_search_space, @@ -120,7 +123,7 @@ def untransform_observation_features( return observation_features -class OrderedChoiceEncode(ChoiceEncode): +class OrderedChoiceToIntegerRange(ChoiceEncode): """Convert ordered ChoiceParameters to integer RangeParameters. Parameters will be transformed to an integer RangeParameters, mapped from the @@ -187,6 +190,13 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: ) +class OrderedChoiceEncode(DeprecatedTransformMixin, OrderedChoiceToIntegerRange): + """Deprecated alias for OrderedChoiceToIntegerRange.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def transform_choice_values(p: ChoiceParameter) -> Tuple[np.ndarray, ParameterType]: """Transforms the choice values and returns the new parameter type. diff --git a/ax/modelbridge/transforms/tests/test_choice_encode_transform.py b/ax/modelbridge/transforms/tests/test_choice_encode_transform.py index 37425ddbf51..98a4e927d9f 100644 --- a/ax/modelbridge/transforms/tests/test_choice_encode_transform.py +++ b/ax/modelbridge/transforms/tests/test_choice_encode_transform.py @@ -14,7 +14,11 @@ from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.parameter_constraint import ParameterConstraint from ax.core.search_space import RobustSearchSpace, SearchSpace -from ax.modelbridge.transforms.choice_encode import ChoiceEncode, OrderedChoiceEncode +from ax.modelbridge.transforms.choice_encode import ( + ChoiceEncode, + OrderedChoiceEncode, + OrderedChoiceToIntegerRange, +) from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import get_robust_search_space @@ -166,7 +170,7 @@ def test_TransformSearchSpace(self) -> None: ) ] ) - t = OrderedChoiceEncode(search_space=ss3, observations=[]) + t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[]) with self.assertRaises(ValueError): t.transform_search_space(ss3) @@ -208,8 +212,8 @@ def test_w_parameter_distributions(self) -> None: self.assertEqual(rss_new.parameters.get("c").parameter_type, ParameterType.INT) -class OrderedChoiceEncodeTransformTest(ChoiceEncodeTransformTest): - t_class = OrderedChoiceEncode +class OrderedChoiceToIntegerRangeTransformTest(ChoiceEncodeTransformTest): + t_class = OrderedChoiceToIntegerRange def setUp(self) -> None: super().setUp() @@ -258,10 +262,28 @@ def test_TransformSearchSpace(self) -> None: ) ] ) - t = OrderedChoiceEncode(search_space=ss3, observations=[]) + t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[]) with self.assertRaises(ValueError): t.transform_search_space(ss3) + def test_deprecated_OrderedChoiceEncode(self) -> None: + # Ensure we error if we try to transform a fidelity parameter + ss3 = SearchSpace( + parameters=[ + ChoiceParameter( + "b", + parameter_type=ParameterType.FLOAT, + values=[1.0, 10.0, 100.0], + is_ordered=True, + is_fidelity=True, + target_value=100.0, + ) + ] + ) + t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[]) + t_deprecated = OrderedChoiceEncode(search_space=ss3, observations=[]) + self.assertEqual(t.__dict__, t_deprecated.__dict__) + def normalize_values(values: Sized) -> List[float]: values = np.array(values, dtype=float) diff --git a/ax/storage/transform_registry.py b/ax/storage/transform_registry.py index 4ffe9bb43c4..82e4b980d36 100644 --- a/ax/storage/transform_registry.py +++ b/ax/storage/transform_registry.py @@ -6,11 +6,15 @@ # pyre-strict -from typing import Dict, Type +from typing import Dict, List, Type from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.cap_parameter import CapParameter -from ax.modelbridge.transforms.choice_encode import ChoiceEncode, OrderedChoiceEncode +from ax.modelbridge.transforms.choice_encode import ( + ChoiceEncode, + OrderedChoiceEncode, + OrderedChoiceToIntegerRange, +) from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames from ax.modelbridge.transforms.derelativize import Derelativize from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice @@ -56,7 +60,8 @@ IVW: 4, Log: 5, OneHot: 6, - OrderedChoiceEncode: 7, + OrderedChoiceEncode: 7, # TO BE DEPRECATED + OrderedChoiceToIntegerRange: 7, # This transform was upstreamed into the base modelbridge. # Old transforms serialized with this will have the OutOfDesign transform # replaced with a no-op, the base transform. @@ -81,7 +86,15 @@ RelativizeWithConstantControl: 25, } +""" +List of new classes of transforms which will be deprecated. +The reverse transform registry will refer to the old transforms at first, +and will be later migrated to point to the new transforms. +""" +TRANSFORMS_TO_UPDATE: List[Type[Transform]] = [ + OrderedChoiceToIntegerRange # will replace OrderedChoiceEncode +] REVERSE_TRANSFORM_REGISTRY: Dict[int, Type[Transform]] = { - v: k for k, v in TRANSFORM_REGISTRY.items() + v: k for k, v in TRANSFORM_REGISTRY.items() if k not in TRANSFORMS_TO_UPDATE }