diff --git a/ax/modelbridge/transforms/tests/test_percentile_y_transform.py b/ax/modelbridge/transforms/tests/test_percentile_y_transform.py
index 2ec0ddcb753..83a5fd34d8f 100644
--- a/ax/modelbridge/transforms/tests/test_percentile_y_transform.py
+++ b/ax/modelbridge/transforms/tests/test_percentile_y_transform.py
@@ -13,6 +13,7 @@
 from ax.exceptions.core import DataRequiredError
 from ax.modelbridge.transforms.percentile_y import PercentileY
 from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import get_observations_with_invalid_value
 
 
 class PercentileYTransformTest(TestCase):
@@ -126,3 +127,13 @@ def test_TransformObservationsWithWinsorization(self) -> None:
             np.allclose(mean_results, expected),
             msg=f"Unexpected mean Results: {mean_results}. Expected: {expected}.",
         )
+
+    def test_non_finite_data_raises(self) -> None:
+        for invalid_value in [float("nan"), float("inf")]:
+            observations = get_observations_with_invalid_value(
+                invalid_value=invalid_value
+            )
+            with self.assertRaisesRegex(
+                ValueError, f"Non-finite data found for metric m1: {invalid_value}"
+            ):
+                PercentileY(observations=observations, config={"metrics": ["m1"]})
diff --git a/ax/modelbridge/transforms/tests/test_power_y_transform.py b/ax/modelbridge/transforms/tests/test_power_y_transform.py
index 03d63a26627..6e07461127f 100644
--- a/ax/modelbridge/transforms/tests/test_power_y_transform.py
+++ b/ax/modelbridge/transforms/tests/test_power_y_transform.py
@@ -26,6 +26,7 @@
 )
 from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
 from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import get_observations_with_invalid_value
 from sklearn.preprocessing import PowerTransformer
 
 
@@ -328,3 +329,11 @@ def test_TransformOptimizationConfig(self) -> None:
             "that are part of a ScalarizedOutcomeConstraint.",
             str(cm.exception),
         )
+
+    def test_non_finite_data_raises(self) -> None:
+        for invalid_value in [float("nan"), float("inf")]:
+            observations = get_observations_with_invalid_value(invalid_value)
+            with self.assertRaisesRegex(
+                ValueError, f"Non-finite data found for metric m1: {invalid_value}"
+            ):
+                PowerTransformY(observations=observations, config={"metrics": ["m1"]})
diff --git a/ax/modelbridge/transforms/tests/test_standardize_y_transform.py b/ax/modelbridge/transforms/tests/test_standardize_y_transform.py
index 450775c4497..d38d96ac5eb 100644
--- a/ax/modelbridge/transforms/tests/test_standardize_y_transform.py
+++ b/ax/modelbridge/transforms/tests/test_standardize_y_transform.py
@@ -20,6 +20,7 @@
 from ax.exceptions.core import DataRequiredError
 from ax.modelbridge.transforms.standardize_y import StandardizeY
 from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import get_observations_with_invalid_value
 
 
 class StandardizeYTransformTest(TestCase):
@@ -153,6 +154,16 @@ def test_TransformOptimizationConfig(self) -> None:
         with self.assertRaises(ValueError):
             oc = self.t.transform_optimization_config(oc, None, None)
 
+    def test_non_finite_data_raises(self) -> None:
+        for invalid_value in [float("nan"), float("inf")]:
+            observations = get_observations_with_invalid_value(
+                invalid_value=invalid_value
+            )
+            with self.assertRaisesRegex(
+                ValueError, f"Non-finite data found for metric m1: {invalid_value}"
+            ):
+                StandardizeY(observations=observations, config={"metrics": ["m1"]})
+
 
 def osd_allclose(osd1: ObservationData, osd2: ObservationData) -> bool:
     if osd1.metric_names != osd2.metric_names:
diff --git a/ax/modelbridge/transforms/tests/test_winsorize_transform.py b/ax/modelbridge/transforms/tests/test_winsorize_transform.py
index 28a47af3306..8f20cf25520 100644
--- a/ax/modelbridge/transforms/tests/test_winsorize_transform.py
+++ b/ax/modelbridge/transforms/tests/test_winsorize_transform.py
@@ -8,7 +8,7 @@
 
 import warnings
 from copy import deepcopy
-from typing import Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 from unittest import mock
 
 import numpy as np
@@ -40,7 +40,10 @@
 )
 from ax.models.winsorization_config import WinsorizationConfig
 from ax.utils.common.testutils import TestCase
-from ax.utils.testing.core_stubs import get_optimization_config
+from ax.utils.testing.core_stubs import (
+    get_observations_with_invalid_value,
+    get_optimization_config,
+)
 from typing_extensions import SupportsIndex
 
 INF = float("inf")
@@ -642,6 +645,19 @@ def test_relative_constraints(
         )
         self.assertDictEqual(t.cutoffs, {"a": (-INF, 3.5), "b": (-INF, 12.0)})
 
+    def test_non_finite_data_raises(self) -> None:
+        for invalid_value in [float("nan"), float("inf")]:
+            observations = get_observations_with_invalid_value(
+                invalid_value=invalid_value
+            )
+            config: Dict[str, Any] = {
+                "winsorization_config": WinsorizationConfig(upper_quantile_margin=0.2)
+            }
+            with self.assertRaisesRegex(
+                ValueError, f"Non-finite data found for metric m1: {invalid_value}"
+            ):
+                Winsorize(search_space=None, observations=observations, config=config)
+
 
 # pyre-fixme[2]: Parameter must be annotated.
 def get_transform(observation_data, config=None, optimization_config=None) -> Winsorize:
diff --git a/ax/modelbridge/transforms/utils.py b/ax/modelbridge/transforms/utils.py
index 40749651fe6..0f4385deec8 100644
--- a/ax/modelbridge/transforms/utils.py
+++ b/ax/modelbridge/transforms/utils.py
@@ -66,14 +66,30 @@ def __getitem__(self, key: Number) -> Any:
 
 
 def get_data(
-    observation_data: List[ObservationData], metric_names: Union[List[str], None] = None
+    observation_data: List[ObservationData],
+    metric_names: Union[List[str], None] = None,
+    raise_on_non_finite_data: bool = True,
 ) -> Dict[str, List[float]]:
-    """Extract all metrics if `metric_names` is None."""
+    """Extract all metrics if `metric_names` is None.
+
+    Raises a ValueError on non-finite data when `raise_on_non_finite_data` is set.
+
+    Args:
+        observation_data: List of observation data.
+        metric_names: Metrics to extract; if None, all metrics are extracted.
+        raise_on_non_finite_data: If True, raise a ValueError on any nan/inf mean.
+
+    Returns:
+        A dictionary mapping each metric name to the list of its observed means.
+    """
     Ys = defaultdict(list)
     for obsd in observation_data:
         for i, m in enumerate(obsd.metric_names):
             if metric_names is None or m in metric_names:
-                Ys[m].append(obsd.means[i])
+                val = obsd.means[i]
+                if raise_on_non_finite_data and (not np.isfinite(val)):
+                    raise ValueError(f"Non-finite data found for metric {m}: {val}")
+                Ys[m].append(val)
     return Ys
 
 
diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py
index 0dd4574aeaf..8e7d596b0e6 100644
--- a/ax/utils/testing/core_stubs.py
+++ b/ax/utils/testing/core_stubs.py
@@ -41,7 +41,7 @@
 from ax.core.metric import Metric
 from ax.core.multi_type_experiment import MultiTypeExperiment
 from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
-from ax.core.observation import ObservationFeatures
+from ax.core.observation import Observation, ObservationData, ObservationFeatures
 from ax.core.optimization_config import (
     MultiObjectiveOptimizationConfig,
     OptimizationConfig,
@@ -1987,6 +1987,18 @@ def get_map_data(trial_index: int = 0) -> MapData:
     )
 
 
+def get_observations_with_invalid_value(invalid_value: float) -> List[Observation]:
+    obsd_with_non_finite = ObservationData(
+        metric_names=["m1"] * 4,
+        means=np.array([-100, 4, invalid_value, 2]),
+        covariance=np.eye(4),
+    )
+    observations = [
+        Observation(features=ObservationFeatures({}), data=obsd_with_non_finite)
+    ]
+    return observations
+
+
 # pyre-fixme[24]: Generic type `MapKeyInfo` expects 1 type parameter.
 def get_map_key_info() -> MapKeyInfo:
     return MapKeyInfo(key="epoch", default_value=0.0)