diff --git a/ax/modelbridge/transforms/tests/test_percentile_y_transform.py b/ax/modelbridge/transforms/tests/test_percentile_y_transform.py index 2ec0ddcb753..83a5fd34d8f 100644 --- a/ax/modelbridge/transforms/tests/test_percentile_y_transform.py +++ b/ax/modelbridge/transforms/tests/test_percentile_y_transform.py @@ -13,6 +13,7 @@ from ax.exceptions.core import DataRequiredError from ax.modelbridge.transforms.percentile_y import PercentileY from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_observations_with_invalid_value class PercentileYTransformTest(TestCase): @@ -126,3 +127,13 @@ def test_TransformObservationsWithWinsorization(self) -> None: np.allclose(mean_results, expected), msg=f"Unexpected mean Results: {mean_results}. Expected: {expected}.", ) + + def test_non_finite_data_raises(self) -> None: + for invalid_value in [float("nan"), float("inf")]: + observations = get_observations_with_invalid_value( + invalid_value=invalid_value + ) + with self.assertRaisesRegex( + ValueError, f"Non-finite data found for metric m1: {invalid_value}" + ): + PercentileY(observations=observations, config={"metrics": ["m1"]}) diff --git a/ax/modelbridge/transforms/tests/test_power_y_transform.py b/ax/modelbridge/transforms/tests/test_power_y_transform.py index 03d63a26627..6e07461127f 100644 --- a/ax/modelbridge/transforms/tests/test_power_y_transform.py +++ b/ax/modelbridge/transforms/tests/test_power_y_transform.py @@ -26,6 +26,7 @@ ) from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_observations_with_invalid_value from sklearn.preprocessing import PowerTransformer @@ -328,3 +329,11 @@ def test_TransformOptimizationConfig(self) -> None: "that are part of a ScalarizedOutcomeConstraint.", str(cm.exception), ) + + def test_non_finite_data_raises(self) -> None: + for invalid_value in [float("nan"), float("inf")]: + observations = get_observations_with_invalid_value(invalid_value) + with self.assertRaisesRegex( + ValueError, f"Non-finite data found for metric m1: {invalid_value}" + ): + PowerTransformY(observations=observations, config={"metrics": ["m1"]}) diff --git a/ax/modelbridge/transforms/tests/test_standardize_y_transform.py b/ax/modelbridge/transforms/tests/test_standardize_y_transform.py index 450775c4497..d38d96ac5eb 100644 --- a/ax/modelbridge/transforms/tests/test_standardize_y_transform.py +++ b/ax/modelbridge/transforms/tests/test_standardize_y_transform.py @@ -20,6 +20,7 @@ from ax.exceptions.core import DataRequiredError from ax.modelbridge.transforms.standardize_y import StandardizeY from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_observations_with_invalid_value class StandardizeYTransformTest(TestCase): @@ -153,6 +154,16 @@ def test_TransformOptimizationConfig(self) -> None: with self.assertRaises(ValueError): oc = self.t.transform_optimization_config(oc, None, None) + def test_non_finite_data_raises(self) -> None: + for invalid_value in [float("nan"), float("inf")]: + observations = get_observations_with_invalid_value( + invalid_value=invalid_value + ) + with self.assertRaisesRegex( + ValueError, f"Non-finite data found for metric m1: {invalid_value}" + ): + StandardizeY(observations=observations, config={"metrics": ["m1"]}) + def osd_allclose(osd1: ObservationData, osd2: ObservationData) -> bool: if osd1.metric_names != osd2.metric_names: diff --git a/ax/modelbridge/transforms/tests/test_winsorize_transform.py b/ax/modelbridge/transforms/tests/test_winsorize_transform.py index 28a47af3306..61a52e97580 100644 --- a/ax/modelbridge/transforms/tests/test_winsorize_transform.py +++ b/ax/modelbridge/transforms/tests/test_winsorize_transform.py @@ -40,7 +40,10 @@ ) from ax.models.winsorization_config import WinsorizationConfig from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_optimization_config +from ax.utils.testing.core_stubs import ( + get_observations_with_invalid_value, + get_optimization_config, +) from typing_extensions import SupportsIndex INF = float("inf") @@ -665,6 +668,19 @@ def get_transform(observation_data, config=None, optimization_config=None) -> Wi config=config, ) + def test_non_finite_data_raises(self) -> None: + for invalid_value in [float("nan"), float("inf")]: + observations = get_observations_with_invalid_value( + invalid_value=invalid_value + ) + config = { + "winsorization_config": WinsorizationConfig(upper_quantile_margin=0.2) + } + with self.assertRaisesRegex( + ValueError, f"Non-finite data found for metric m1: {invalid_value}" + ): + Winsorize(search_space=None, observations=observations, config=config) + def get_default_transform_cutoffs( optimization_config: OptimizationConfig, diff --git a/ax/modelbridge/transforms/utils.py b/ax/modelbridge/transforms/utils.py index 40749651fe6..ddde36fc25c 100644 --- a/ax/modelbridge/transforms/utils.py +++ b/ax/modelbridge/transforms/utils.py @@ -66,13 +66,22 @@ def __getitem__(self, key: Number) -> Any: def get_data( - observation_data: List[ObservationData], metric_names: Union[List[str], None] = None + observation_data: List[ObservationData], + metric_names: Union[List[str], None] = None, + raise_on_non_finite_data: bool = True, ) -> Dict[str, List[float]]: - """Extract all metrics if `metric_names` is None.""" + """Extract all metrics if `metric_names` is None. + + Raises a value error if any data is non-finite. + """ Ys = defaultdict(list) for obsd in observation_data: for i, m in enumerate(obsd.metric_names): if metric_names is None or m in metric_names: + if raise_on_non_finite_data and (not np.isfinite(obsd.means[i])): + raise ValueError( + f"Non-finite data found for metric {m}: {obsd.means[i]}" + ) Ys[m].append(obsd.means[i]) return Ys diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 0dd4574aeaf..668f8024239 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -41,7 +41,7 @@ from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective -from ax.core.observation import ObservationFeatures +from ax.core.observation import Observation, ObservationData, ObservationFeatures from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, OptimizationConfig, @@ -1987,6 +1987,19 @@ def get_map_data(trial_index: int = 0) -> MapData: ) +def get_observations_with_invalid_value(invalid_value: float) -> List[Observation]: + obsd_with_non_finite = ObservationData( + metric_names=["m1"] * 4, + means=np.array([-100, 4, invalid_value, 2]), + covariance=np.eye(4), + ) + observations = [ + Observation(features=ObservationFeatures({}), data=obsd) + for obsd in [obsd_with_non_finite] + ] + return observations + + # pyre-fixme[24]: Generic type `MapKeyInfo` expects 1 type parameter. def get_map_key_info() -> MapKeyInfo: return MapKeyInfo(key="epoch", default_value=0.0)