Skip to content

Commit

Permalink
Rename OrderedChoiceEncode => OrderedChoiceToIntegerRange (#2323)
Browse files Browse the repository at this point in the history
Summary:

This change renames the OrderedChoiceEncode transform to one which reflects its behavior- see T182722751 for the overall task. 

- Adds a new "OrderedChoiceToIntegerRange" class with the logic from the original OrderedChoiceEncode
- Updates OrderedChoiceEncode to inherit from DeprecatedTransformMixin and OrderedChoiceToIntegerRange
- Updates the registry to support the new transform.

Initially, the new classes will be decoded into the deprecated classes to maintain backwards compatibility. Once the new classes are landed, call sites will be updated to use the new class instead of the old.

Differential Revision: D55754487
  • Loading branch information
mgrange1998 authored and facebook-github-bot committed Apr 4, 2024
1 parent 4e9f6a8 commit 4466f6e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
14 changes: 12 additions & 2 deletions ax/modelbridge/transforms/choice_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@

# pyre-strict

from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

import numpy as np
from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TParamValue
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.deprecated_transform_mixin import (
DeprecatedTransformMixin,
)
from ax.modelbridge.transforms.utils import (
ClosestLookupDict,
construct_new_search_space,
Expand Down Expand Up @@ -120,7 +123,7 @@ def untransform_observation_features(
return observation_features


class OrderedChoiceEncode(ChoiceEncode):
class OrderedChoiceToIntegerRange(ChoiceEncode):
"""Convert ordered ChoiceParameters to integer RangeParameters.
Parameters will be transformed to an integer RangeParameters, mapped from the
Expand Down Expand Up @@ -187,6 +190,13 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
)


class OrderedChoiceEncode(DeprecatedTransformMixin, OrderedChoiceToIntegerRange):
"""Deprecated alias for OrderedChoiceToIntegerRange."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


def transform_choice_values(p: ChoiceParameter) -> Tuple[np.ndarray, ParameterType]:
"""Transforms the choice values and returns the new parameter type.
Expand Down
32 changes: 27 additions & 5 deletions ax/modelbridge/transforms/tests/test_choice_encode_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import RobustSearchSpace, SearchSpace
from ax.modelbridge.transforms.choice_encode import ChoiceEncode, OrderedChoiceEncode
from ax.modelbridge.transforms.choice_encode import (
ChoiceEncode,
OrderedChoiceEncode,
OrderedChoiceToIntegerRange,
)
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.core_stubs import get_robust_search_space
Expand Down Expand Up @@ -166,7 +170,7 @@ def test_TransformSearchSpace(self) -> None:
)
]
)
t = OrderedChoiceEncode(search_space=ss3, observations=[])
t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[])
with self.assertRaises(ValueError):
t.transform_search_space(ss3)

Expand Down Expand Up @@ -208,8 +212,8 @@ def test_w_parameter_distributions(self) -> None:
self.assertEqual(rss_new.parameters.get("c").parameter_type, ParameterType.INT)


class OrderedChoiceEncodeTransformTest(ChoiceEncodeTransformTest):
t_class = OrderedChoiceEncode
class OrderedChoiceToIntegerRangeTransformTest(ChoiceEncodeTransformTest):
t_class = OrderedChoiceToIntegerRange

def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -258,10 +262,28 @@ def test_TransformSearchSpace(self) -> None:
)
]
)
t = OrderedChoiceEncode(search_space=ss3, observations=[])
t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[])
with self.assertRaises(ValueError):
t.transform_search_space(ss3)

def test_deprecated_OrderedChoiceEncode(self) -> None:
# Ensure we error if we try to transform a fidelity parameter
ss3 = SearchSpace(
parameters=[
ChoiceParameter(
"b",
parameter_type=ParameterType.FLOAT,
values=[1.0, 10.0, 100.0],
is_ordered=True,
is_fidelity=True,
target_value=100.0,
)
]
)
t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[])
t_deprecated = OrderedChoiceEncode(search_space=ss3, observations=[])
self.assertEqual(t.__dict__, t_deprecated.__dict__)


def normalize_values(values: Sized) -> List[float]:
values = np.array(values, dtype=float)
Expand Down
21 changes: 17 additions & 4 deletions ax/storage/transform_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

# pyre-strict

from typing import Dict, Type
from typing import Dict, List, Type

from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.cap_parameter import CapParameter
from ax.modelbridge.transforms.choice_encode import ChoiceEncode, OrderedChoiceEncode
from ax.modelbridge.transforms.choice_encode import (
ChoiceEncode,
OrderedChoiceEncode,
OrderedChoiceToIntegerRange,
)
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
Expand Down Expand Up @@ -56,7 +60,8 @@
IVW: 4,
Log: 5,
OneHot: 6,
OrderedChoiceEncode: 7,
OrderedChoiceEncode: 7, # TO BE DEPRECATED
OrderedChoiceToIntegerRange: 7,
# This transform was upstreamed into the base modelbridge.
# Old transforms serialized with this will have the OutOfDesign transform
# replaced with a no-op, the base transform.
Expand All @@ -81,7 +86,15 @@
RelativizeWithConstantControl: 25,
}

"""
List of new classes of transforms which will be deprecated.
The reverse transform registry will refer to the old transforms at first,
and will be later migrated to point to the new transforms.
"""
TRANSFORMS_TO_UPDATE: List[Type[Transform]] = [
OrderedChoiceToIntegerRange # will replace OrderedChoiceEncode
]

REVERSE_TRANSFORM_REGISTRY: Dict[int, Type[Transform]] = {
v: k for k, v in TRANSFORM_REGISTRY.items()
v: k for k, v in TRANSFORM_REGISTRY.items() if k not in TRANSFORMS_TO_UPDATE
}

0 comments on commit 4466f6e

Please sign in to comment.