Skip to content

Commit

Permalink
Modify get_data to error out on nan/inf
Browse files Browse the repository at this point in the history
Summary:
This method is leveraged by `StandardizeY`, `Winsorize`, `PowerTransformY`, and `PercentileY`.

This change improves the robustness of our transform layer to non-finite values by raising an error before they are passed down to BoTorch.

Differential Revision: D60681606
  • Loading branch information
David Eriksson authored and facebook-github-bot committed Aug 3, 2024
1 parent 8587587 commit 22c0702
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 4 deletions.
11 changes: 11 additions & 0 deletions ax/modelbridge/transforms/tests/test_percentile_y_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.percentile_y import PercentileY
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_observations_with_invalid_value


class PercentileYTransformTest(TestCase):
Expand Down Expand Up @@ -126,3 +127,13 @@ def test_TransformObservationsWithWinsorization(self) -> None:
np.allclose(mean_results, expected),
msg=f"Unexpected mean Results: {mean_results}. Expected: {expected}.",
)

def test_non_finite_data_raises(self) -> None:
    """Constructing ``PercentileY`` on NaN or inf data raises ``ValueError``.

    The expected message names the metric (``m1``) and the offending value.
    """
    for invalid_value in [float("nan"), float("inf")]:
        # One subtest per value: a failure reports which input broke, and
        # the remaining value is still exercised.
        with self.subTest(invalid_value=invalid_value):
            observations = get_observations_with_invalid_value(
                invalid_value=invalid_value
            )
            with self.assertRaisesRegex(
                ValueError, f"Non-finite data found for metric m1: {invalid_value}"
            ):
                PercentileY(observations=observations, config={"metrics": ["m1"]})
9 changes: 9 additions & 0 deletions ax/modelbridge/transforms/tests/test_power_y_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_observations_with_invalid_value
from sklearn.preprocessing import PowerTransformer


Expand Down Expand Up @@ -328,3 +329,11 @@ def test_TransformOptimizationConfig(self) -> None:
"that are part of a ScalarizedOutcomeConstraint.",
str(cm.exception),
)

def test_non_finite_data_raises(self) -> None:
    """Constructing ``PowerTransformY`` on NaN or inf data raises ``ValueError``.

    The expected message names the metric (``m1``) and the offending value.
    """
    for invalid_value in [float("nan"), float("inf")]:
        # One subtest per value: a failure reports which input broke, and
        # the remaining value is still exercised.
        with self.subTest(invalid_value=invalid_value):
            # Keyword form, matching how the other transform tests call
            # this stub.
            observations = get_observations_with_invalid_value(
                invalid_value=invalid_value
            )
            with self.assertRaisesRegex(
                ValueError, f"Non-finite data found for metric m1: {invalid_value}"
            ):
                PowerTransformY(
                    observations=observations, config={"metrics": ["m1"]}
                )
11 changes: 11 additions & 0 deletions ax/modelbridge/transforms/tests/test_standardize_y_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_observations_with_invalid_value


class StandardizeYTransformTest(TestCase):
Expand Down Expand Up @@ -153,6 +154,16 @@ def test_TransformOptimizationConfig(self) -> None:
with self.assertRaises(ValueError):
oc = self.t.transform_optimization_config(oc, None, None)

def test_non_finite_data_raises(self) -> None:
    """Constructing ``StandardizeY`` on NaN or inf data raises ``ValueError``.

    The expected message names the metric (``m1``) and the offending value.
    """
    for invalid_value in [float("nan"), float("inf")]:
        # One subtest per value: a failure reports which input broke, and
        # the remaining value is still exercised.
        with self.subTest(invalid_value=invalid_value):
            observations = get_observations_with_invalid_value(
                invalid_value=invalid_value
            )
            with self.assertRaisesRegex(
                ValueError, f"Non-finite data found for metric m1: {invalid_value}"
            ):
                StandardizeY(observations=observations, config={"metrics": ["m1"]})


def osd_allclose(osd1: ObservationData, osd2: ObservationData) -> bool:
if osd1.metric_names != osd2.metric_names:
Expand Down
18 changes: 17 additions & 1 deletion ax/modelbridge/transforms/tests/test_winsorize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
)
from ax.models.winsorization_config import WinsorizationConfig
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_optimization_config
from ax.utils.testing.core_stubs import (
get_observations_with_invalid_value,
get_optimization_config,
)
from typing_extensions import SupportsIndex

INF = float("inf")
Expand Down Expand Up @@ -665,6 +668,19 @@ def get_transform(observation_data, config=None, optimization_config=None) -> Wi
config=config,
)

def test_non_finite_data_raises(self) -> None:
    """Constructing ``Winsorize`` on NaN or inf data raises ``ValueError``.

    The expected message names the metric (``m1``) and the offending value.
    """
    # The config does not depend on the value under test, so build it once.
    config = {
        "winsorization_config": WinsorizationConfig(upper_quantile_margin=0.2)
    }
    for invalid_value in [float("nan"), float("inf")]:
        # One subtest per value: a failure reports which input broke, and
        # the remaining value is still exercised.
        with self.subTest(invalid_value=invalid_value):
            observations = get_observations_with_invalid_value(
                invalid_value=invalid_value
            )
            with self.assertRaisesRegex(
                ValueError, f"Non-finite data found for metric m1: {invalid_value}"
            ):
                Winsorize(search_space=None, observations=observations, config=config)


def get_default_transform_cutoffs(
optimization_config: OptimizationConfig,
Expand Down
13 changes: 11 additions & 2 deletions ax/modelbridge/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,22 @@ def __getitem__(self, key: Number) -> Any:


def get_data(
    observation_data: List[ObservationData],
    metric_names: Union[List[str], None] = None,
    raise_on_non_finite_data: bool = True,
) -> Dict[str, List[float]]:
    """Group observed means by metric name.

    Args:
        observation_data: Observation data whose ``metric_names`` and
            ``means`` are read.
        metric_names: Metrics to extract. All metrics are extracted if None.
        raise_on_non_finite_data: If True, raise as soon as a selected mean
            is NaN or infinite.

    Returns:
        A mapping from metric name to the list of its means, in the order
        in which they were encountered.

    Raises:
        ValueError: If ``raise_on_non_finite_data`` is True and a selected
            mean is non-finite.
    """
    # Build the filter once so the membership test in the loop is O(1).
    # `None` (extract everything) is kept distinct from an empty selection.
    selected = None if metric_names is None else set(metric_names)
    Ys = defaultdict(list)
    for obsd in observation_data:
        for i, m in enumerate(obsd.metric_names):
            if selected is not None and m not in selected:
                continue
            mean = obsd.means[i]
            if raise_on_non_finite_data and not np.isfinite(mean):
                raise ValueError(f"Non-finite data found for metric {m}: {mean}")
            Ys[m].append(mean)
    return Ys

Expand Down
15 changes: 14 additions & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.observation import ObservationFeatures
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import (
MultiObjectiveOptimizationConfig,
OptimizationConfig,
Expand Down Expand Up @@ -1987,6 +1987,19 @@ def get_map_data(trial_index: int = 0) -> MapData:
)


def get_observations_with_invalid_value(invalid_value: float) -> List[Observation]:
    """Return a single observation whose ``m1`` means include ``invalid_value``.

    The observation carries four entries for metric ``m1`` with means
    ``[-100, 4, invalid_value, 2]``, an identity covariance, and empty
    observation features.
    """
    means = np.array([-100, 4, invalid_value, 2])
    data = ObservationData(
        metric_names=["m1"] * 4,
        means=means,
        covariance=np.eye(4),
    )
    return [Observation(features=ObservationFeatures({}), data=data)]


# pyre-fixme[24]: Generic type `MapKeyInfo` expects 1 type parameter.
def get_map_key_info() -> MapKeyInfo:
    """Return a ``MapKeyInfo`` stub keyed on ``"epoch"`` with default ``0.0``."""
    epoch_key_info = MapKeyInfo(key="epoch", default_value=0.0)
    return epoch_key_info
Expand Down

0 comments on commit 22c0702

Please sign in to comment.