Skip to content

Commit

Permalink
Delete checked_cast and replace `checked_cast_(list|dict|to_tuple|o…
Browse files Browse the repository at this point in the history
…ptional)

Summary:
Make the below replacements:

`checked_cast_list` -> `assert_is_instance_list`
`checked_cast_dict` -> `assert_is_instance_dict`
`checked_cast_to_tuple` -> `assert_is_instance_of_tuple`
`checked_cast_optional` -> `assert_is_instance_optional`
`_argparse_type_encoder` untouched

Differential Revision: D67993468
  • Loading branch information
eonofrey authored and facebook-github-bot committed Jan 13, 2025
1 parent eef4d37 commit e6b6fd4
Showing 13 changed files with 1,297 additions and 1,322 deletions.
4 changes: 2 additions & 2 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@
from ax.modelbridge.model_spec import FactoryFunctionModelSpec
from ax.modelbridge.transition_criterion import TrialBasedCriterion
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import checked_cast_list
from ax.utils.common.typeutils import assert_is_instance_list
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)
@@ -626,7 +626,7 @@ def clone_reset(self) -> GenerationStrategy:
return GenerationStrategy(name=self.name, nodes=cloned_nodes)

return GenerationStrategy(
name=self.name, steps=checked_cast_list(GenerationStep, cloned_nodes)
name=self.name, steps=assert_is_instance_list(cloned_nodes, GenerationStep)
)

def _unset_non_persistent_state_fields(self) -> None:
15 changes: 10 additions & 5 deletions ax/modelbridge/modelbridge_utils.py
Original file line number Diff line number Diff line change
@@ -56,7 +56,10 @@
pareto_frontier_evaluator,
)
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast_optional, checked_cast_to_tuple
from ax.utils.common.typeutils import (
assert_is_instance_of_tuple,
assert_is_instance_optional,
)
from botorch.acquisition.multi_objective.multi_output_risk_measures import (
IndependentCVaR,
IndependentVaR,
@@ -218,7 +221,9 @@ def extract_search_space_digest(
if isinstance(p, ChoiceParameter):
if p.is_task:
task_features.append(i)
target_values[i] = checked_cast_to_tuple((int, float), p.target_value)
target_values[i] = assert_is_instance_of_tuple(
p.target_value, (int, float)
)
elif p.is_ordered:
ordinal_features.append(i)
else:
@@ -243,7 +248,7 @@ def extract_search_space_digest(
raise ValueError(f"Unknown parameter type {type(p)}")
if p.is_fidelity:
fidelity_features.append(i)
target_values[i] = checked_cast_to_tuple((int, float), p.target_value)
target_values[i] = assert_is_instance_of_tuple(p.target_value, (int, float))

return SearchSpaceDigest(
feature_names=param_names,
@@ -1054,8 +1059,8 @@ def _get_multiobjective_optimization_config(
objective_thresholds: TRefPoint | None = None,
) -> MultiObjectiveOptimizationConfig:
# Optimization_config
mooc = optimization_config or checked_cast_optional(
MultiObjectiveOptimizationConfig, modelbridge._optimization_config
mooc = optimization_config or assert_is_instance_optional(
modelbridge._optimization_config, MultiObjectiveOptimizationConfig
)
if not mooc:
raise ValueError(
4 changes: 2 additions & 2 deletions ax/modelbridge/transforms/power_transform_y.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast_list
from ax.utils.common.typeutils import assert_is_instance_list
from pyre_extensions import assert_is_instance
from sklearn.preprocessing import PowerTransformer

@@ -216,5 +216,5 @@ def _compute_inverse_bounds(
bounds[1] = (-1.0 / lambda_ - mu) / sigma
elif lambda_ > 2.0 + tol:
bounds[0] = (1.0 / (2.0 - lambda_) - mu) / sigma
inv_bounds[k] = tuple(checked_cast_list(float, bounds))
inv_bounds[k] = tuple(assert_is_instance_list(bounds, float))
return inv_bounds
4 changes: 2 additions & 2 deletions ax/models/random/base.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
from ax.models.types import TConfig
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast_to_tuple
from ax.utils.common.typeutils import assert_is_instance_of_tuple
from botorch.utils.sampling import HitAndRunPolytopeSampler
from pyre_extensions import assert_is_instance
from torch import Tensor
@@ -129,7 +129,7 @@ def gen(
if model_gen_options:
max_draws = model_gen_options.get("max_rs_draws")
if max_draws is not None:
max_draws = int(checked_cast_to_tuple((int, float), max_draws))
max_draws = int(assert_is_instance_of_tuple(max_draws, (int, float)))
try:
# Always rejection sample, but this only rejects if there are
# constraints or actual duplicates and deduplicate is specified.
9 changes: 7 additions & 2 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,10 @@
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import _argparse_type_encoder, checked_cast_optional
from ax.utils.common.typeutils import (
_argparse_type_encoder,
assert_is_instance_optional,
)
from ax.utils.stats.model_fit_stats import (
DIAGNOSTIC_FN_DIRECTIONS,
DIAGNOSTIC_FNS,
@@ -1277,7 +1280,9 @@ def best_out_of_sample_point(
options = options or {}
acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class(
outcome_constraints=torch_opt_config.outcome_constraints,
seed_inner=checked_cast_optional(int, options.get(Keys.SEED_INNER, None)),
seed_inner=assert_is_instance_optional(
options.get(Keys.SEED_INNER, None), int
),
qmc=assert_is_instance(
options.get(Keys.QMC, True),
bool,
4 changes: 2 additions & 2 deletions ax/plot/scatter.py
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@
TNullableGeneratorRunsDict,
)
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast_optional
from ax.utils.common.typeutils import assert_is_instance_optional
from ax.utils.stats.statstools import relativize
from plotly import subplots

@@ -419,7 +419,7 @@ def plot_multiple_metrics(
layout_offset_x = 0.15
else:
layout_offset_x = 0
rel = checked_cast_optional(bool, kwargs.get("rel"))
rel = assert_is_instance_optional(kwargs.get("rel"), bool)
if rel is not None:
warnings.warn(
"Use `rel_x` and `rel_y` instead of `rel`.",
5 changes: 4 additions & 1 deletion ax/service/tests/test_instantiation_utils.py
Original file line number Diff line number Diff line change
@@ -385,7 +385,10 @@ def test_choice_with_is_sorted(self) -> None:
else:
self.assertEqual(output.sort_values, sort_values)

with self.assertRaisesRegex(ValueError, "Value was not of type <class 'bool'>"):
with self.assertRaisesRegex(
TypeError,
r"obj is not an instance of cls: obj=\['Foo'\] cls=<class 'bool'>",
):
representation: dict[str, Any] = {
"name": "foo_or_bar",
"type": "choice",
21 changes: 13 additions & 8 deletions ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
@@ -47,7 +47,10 @@
from ax.exceptions.core import UnsupportedError
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast_optional, checked_cast_to_tuple
from ax.utils.common.typeutils import (
assert_is_instance_of_tuple,
assert_is_instance_optional,
)
from pyre_extensions import assert_is_instance, none_throws

DEFAULT_OBJECTIVE_NAME = "objective"
@@ -227,8 +230,8 @@ def _make_range_param(
parameter_type=cls._to_parameter_type(
bounds, parameter_type, name, "bounds"
),
lower=checked_cast_to_tuple((float, int), bounds[0]),
upper=checked_cast_to_tuple((float, int), bounds[1]),
lower=assert_is_instance_of_tuple(bounds[0], (float, int)),
upper=assert_is_instance_of_tuple(bounds[1], (float, int)),
log_scale=assert_is_instance(representation.get("log_scale", False), bool),
digits=representation.get("digits", None), # pyre-ignore[6]
is_fidelity=assert_is_instance(
@@ -258,17 +261,19 @@ def _make_choice_param(
values, parameter_type, name, "values"
),
values=values,
is_ordered=checked_cast_optional(bool, representation.get("is_ordered")),
is_ordered=assert_is_instance_optional(
representation.get("is_ordered"), bool
),
is_fidelity=assert_is_instance(
representation.get("is_fidelity", False), bool
),
is_task=assert_is_instance(representation.get("is_task", False), bool),
target_value=representation.get("target_value", None), # pyre-ignore[6]
sort_values=checked_cast_optional(
bool, representation.get("sort_values", None)
sort_values=assert_is_instance_optional(
representation.get("sort_values", None), bool
),
dependents=checked_cast_optional(
dict, representation.get("dependents", None)
dependents=assert_is_instance_optional(
representation.get("dependents", None), dict
),
)

57 changes: 25 additions & 32 deletions ax/utils/common/tests/test_typeutils.py
Original file line number Diff line number Diff line change
@@ -10,43 +10,36 @@
import numpy as np
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import (
checked_cast,
checked_cast_dict,
checked_cast_list,
checked_cast_optional,
assert_is_instance_dict,
assert_is_instance_list,
assert_is_instance_optional,
)
from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type
from pyre_extensions import assert_is_instance


class TestTypeUtils(TestCase):
def test_checked_cast(self) -> None:
self.assertEqual(checked_cast(float, 2.0), 2.0)
with self.assertRaises(ValueError):
checked_cast(float, 2)

def test_checked_cast_with_error_override(self) -> None:
self.assertEqual(checked_cast(float, 2.0), 2.0)
with self.assertRaises(NotImplementedError):
checked_cast(
float, 2, exception=NotImplementedError("foo() doesn't support ints")
)

def test_checked_cast_list(self) -> None:
self.assertEqual(checked_cast_list(float, [1.0, 2.0]), [1.0, 2.0])
with self.assertRaises(ValueError):
checked_cast_list(float, [1.0, 2])

def test_checked_cast_optional(self) -> None:
self.assertEqual(checked_cast_optional(float, None), None)
with self.assertRaises(ValueError):
checked_cast_optional(float, 2)

def test_checked_cast_dict(self) -> None:
self.assertEqual(checked_cast_dict(str, int, {"some": 1}), {"some": 1})
with self.assertRaises(ValueError):
checked_cast_dict(str, int, {"some": 1.0})
with self.assertRaises(ValueError):
checked_cast_dict(str, int, {1: 1})
def test_assert_is_instance(self) -> None:
self.assertEqual(assert_is_instance(2.0, float), 2.0)
with self.assertRaises(TypeError):
assert_is_instance(2, float)

def test_assert_is_instance_list(self) -> None:
self.assertEqual(assert_is_instance_list([1.0, 2.0], float), [1.0, 2.0])
with self.assertRaises(TypeError):
assert_is_instance_list([1.0, 2], float)

def test_assert_is_instance_optional(self) -> None:
self.assertEqual(assert_is_instance_optional(None, float), None)
with self.assertRaises(TypeError):
assert_is_instance_optional(2, float)

def test_assert_is_instance_dict(self) -> None:
self.assertEqual(assert_is_instance_dict({"some": 1}, str, int), {"some": 1})
with self.assertRaises(TypeError):
assert_is_instance_dict({"some": 1.0}, str, int)
with self.assertRaises(TypeError):
assert_is_instance_dict({1: 1}, str, int)

def test_numpy_type_to_python_type(self) -> None:
self.assertEqual(type(numpy_type_to_python_type(np.int64(2))), int)
71 changes: 18 additions & 53 deletions ax/utils/common/typeutils.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@

from typing import Any, TypeVar

from pyre_extensions import assert_is_instance

T = TypeVar("T")
V = TypeVar("V")
@@ -15,79 +16,43 @@
Y = TypeVar("Y")


def checked_cast(typ: type[T], val: V, exception: Exception | None = None) -> T:
"""
Cast a value to a type (with a runtime safety check).
Returns the value unchanged and checks its type at runtime. This signals to the
typechecker that the value has the designated type.
Like `typing.cast`_ ``check_cast`` performs no runtime conversion on its argument,
but, unlike ``typing.cast``, ``checked_cast`` will throw an error if the value is
not of the expected type. The type passed as an argument should be a python class.
Args:
typ: the type to cast to
val: the value that we are casting
exception: override exception to raise if typecheck fails
Returns:
the ``val`` argument, unchanged
.. _typing.cast: https://docs.python.org/3/library/typing.html#typing.cast
"""
if not isinstance(val, typ):
raise (
exception
if exception is not None
else ValueError(f"Value was not of type {typ}:\n{val}")
)
return val


def checked_cast_optional(typ: type[T], val: V | None) -> T | None:
"""Calls checked_cast only if value is not None."""
def assert_is_instance_optional(val: V | None, typ: type[T]) -> T | None:
"""Asserts that the value is an instance of the given type if it is not None."""
if val is None:
return val
return checked_cast(typ, val)
return assert_is_instance(val, typ)


def checked_cast_list(typ: type[T], old_l: list[V]) -> list[T]:
"""Calls checked_cast on all items in a list."""
new_l = []
for val in old_l:
val = checked_cast(typ, val)
new_l.append(val)
return new_l
def assert_is_instance_list(old_l: list[V], typ: type[T]) -> list[T]:
"""Asserts that all items in a list are instances of the given type."""
return [assert_is_instance(val, typ) for val in old_l]


def checked_cast_dict(
key_typ: type[K], value_typ: type[V], d: dict[X, Y]
def assert_is_instance_dict(
d: dict[X, Y], key_type: type[K], val_type: type[V]
) -> dict[K, V]:
"""Calls checked_cast on all keys and values in the dictionary."""
"""Asserts that all keys and values in the dictionary are instances of the given classes."""
new_dict = {}
for key, val in d.items():
val = checked_cast(value_typ, val)
key = checked_cast(key_typ, key)
key = assert_is_instance(key, key_type)
val = assert_is_instance(val, val_type)
new_dict[key] = val
return new_dict


# pyre-fixme[34]: `T` isn't present in the function's parameters.
def checked_cast_to_tuple(typ: tuple[type[V], ...], val: V) -> T:
def assert_is_instance_of_tuple(val: V, typ: tuple[type[V], ...]) -> T:
"""
Cast a value to a union of multiple types (with a runtime safety check).
This function is similar to `checked_cast`, but allows for the type to be
defined as a tuple of types, in which case the value is cast as a union of
the types in the tuple.
Asserts that a value is an instance of any type in a tuple of types.
Args:
typ: the tuple of types to cast to
val: the value that we are casting
typ: the tuple of types to check against
val: the value that we are checking
Returns:
the ``val`` argument, unchanged
the `val` argument, unchanged
"""
if not isinstance(val, typ):
raise ValueError(f"Value was not of type {type!r}:\n{val!r}")
raise TypeError(f"Value was not of any type {typ!r}:\n{val!r}")
# pyre-fixme[7]: Expected `T` but got `V`.
return val

Loading

0 comments on commit e6b6fd4

Please sign in to comment.