diff --git a/ax/core/observation.py b/ax/core/observation.py
index ba4955ce588..690d9543c52 100644
--- a/ax/core/observation.py
+++ b/ax/core/observation.py
@@ -9,6 +9,7 @@
 from __future__ import annotations
 
 import json
+import warnings
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional, Set, Tuple
 
@@ -23,7 +24,7 @@
 from ax.core.types import TCandidateMetadata, TParameterization
 from ax.utils.common.base import Base
 from ax.utils.common.constants import Keys
-from ax.utils.common.typeutils import not_none
+from ax.utils.common.typeutils import checked_cast, not_none
 
 TIME_COLS = {"start_time", "end_time"}
 
@@ -315,6 +316,13 @@ def _observations_from_dataframe(
         for f, val in features.items():
             if f in OBS_KWARGS:
                 obs_kwargs[f] = val
+        # Attach the trial's start and end times to the observation only
+        # when the value is the same for all metrics and arms in the group.
+        for col in TIME_COLS:
+            if col in d.columns:
+                times = d[col]
+                if times.nunique() == 1 and not times.isnull().any():
+                    obs_kwargs[col] = times.iloc[0]
         fidelities = features.get("fidelities")
         if fidelities is not None:
             obs_parameters.update(json.loads(fidelities))
@@ -338,12 +346,26 @@ def _observations_from_dataframe(
     return observations
 
 
-def get_feature_cols(data: Data) -> List[str]:
-    return list(OBS_COLS.intersection(data.df.columns))
-
+def get_feature_cols(data: Data, is_map_data: bool = False) -> List[str]:
+    feature_cols = OBS_COLS.intersection(data.df.columns)
+    # NOTE: We use this flag rather than an isinstance check, since only
+    # some ModelBridges (e.g. MapTorchModelBridge) use
+    # observations_from_map_data, which is required to properly handle
+    # MapData features (e.g. fidelity).
+    if is_map_data:
+        data = checked_cast(MapData, data)
+        feature_cols = feature_cols.union(data.map_keys)
+
+    for column in TIME_COLS:
+        if column in feature_cols and len(data.df[column].unique()) > 1:
+            warnings.warn(
+                f"`{column}` is not consistent and is being discarded from "
+                "observation data",
+                stacklevel=5,
+            )
+            feature_cols.discard(column)
 
-def get_feature_cols_from_map_data(map_data: MapData) -> List[str]:
-    return list(OBS_COLS.intersection(map_data.df.columns).union(map_data.map_keys))
+    return list(feature_cols)
 
 
 def observations_from_data(
@@ -458,7 +480,7 @@ def observations_from_map_data(
             limit_rows_per_group=limit_rows_per_group,
             include_first_last=True,
         )
-    feature_cols = get_feature_cols_from_map_data(map_data)
+    feature_cols = get_feature_cols(map_data, is_map_data=True)
     observations = []
     arm_name_only = len(feature_cols) == 1  # there will always be an arm name
     # One DataFrame where all rows have all features.
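The two time checks above operate at different granularities, which is easy to miss in the diff: get_feature_cols discards a start_time/end_time column as a feature whenever its values are inconsistent anywhere in the data, while _observations_from_dataframe still attaches a time to an individual observation whenever it is consistent within that arm's group. Below is a minimal, self-contained sketch of the combined behavior (plain pandas, not Ax code; the column values are made up):

import warnings

import pandas as pd

TIME_COLS = {"start_time", "end_time"}

df = pd.DataFrame(
    {
        "arm_name": ["0_0", "0_0", "0_1", "0_1"],
        "metric_name": ["a", "b", "a", "b"],
        "start_time": ["08:45", "08:45", "08:43", "08:45"],
        "end_time": ["08:47", None, "08:46", "08:46"],
    }
)

# Column-level check (mirrors get_feature_cols): both time columns vary
# across the DataFrame as a whole, so both are dropped as feature columns.
feature_cols = {"arm_name", "start_time", "end_time"}
for column in TIME_COLS:
    if column in feature_cols and len(df[column].unique()) > 1:
        warnings.warn(f"`{column}` is not consistent; discarding it")
        feature_cols.discard(column)

# Group-level check (mirrors _observations_from_dataframe): per arm, a time
# is attached only if it is unique and never null within the group.
for arm, d in df.groupby("arm_name"):
    obs_kwargs = {}
    for col in TIME_COLS:
        times = d[col]
        # Series.nunique() ignores nulls, so the isnull() guard is what
        # keeps arm 0_0's partially missing end_time from being attached.
        if times.nunique() == 1 and not times.isnull().any():
            obs_kwargs[col] = times.iloc[0]
    print(arm, obs_kwargs)
# 0_0 {'start_time': '08:45'}
# 0_1 {'end_time': '08:46'}

This is also exactly the behavior the new test below pins down, arm by arm.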
diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py
index a840096ff84..3aded6b6031 100644
--- a/ax/core/tests/test_observation.py
+++ b/ax/core/tests/test_observation.py
@@ -27,7 +27,9 @@
     separate_observations,
 )
 from ax.core.trial import Trial
+from ax.core.types import TParameterization
 from ax.utils.common.testutils import TestCase
+from ax.utils.common.typeutils import not_none
 
 
 class ObservationsTest(TestCase):
@@ -662,6 +664,116 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None:
             )
             self.assertEqual(obs.arm_name, cname_truth[i])
 
+    def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
+        params0: TParameterization = {"x": 0, "y": "a"}
+        params1: TParameterization = {"x": 1, "y": "a"}
+        truth = [
+            {
+                "arm_name": "0_0",
+                "parameters": params0,
+                "mean": 2.0,
+                "sem": 2.0,
+                "trial_index": 0,
+                "metric_name": "a",
+                "start_time": "2024-03-20 08:45:00",
+                "end_time": "2024-03-20 08:47:00",
+            },
+            {
+                "arm_name": "0_0",
+                "parameters": params0,
+                "mean": 3.0,
+                "sem": 3.0,
+                "trial_index": 0,
+                "metric_name": "b",
+                "start_time": "2024-03-20 08:45:00",
+            },
+            {
+                "arm_name": "0_1",
+                "parameters": params1,
+                "mean": 4.0,
+                "sem": 4.0,
+                "trial_index": 0,
+                "metric_name": "a",
+                "start_time": "2024-03-20 08:43:00",
+                "end_time": "2024-03-20 08:46:00",
+            },
+            {
+                "arm_name": "0_1",
+                "parameters": params1,
+                "mean": 5.0,
+                "sem": 5.0,
+                "trial_index": 0,
+                "metric_name": "b",
+                "start_time": "2024-03-20 08:45:00",
+                "end_time": "2024-03-20 08:46:00",
+            },
+        ]
+        arms_by_name = {
+            "0_0": Arm(name="0_0", parameters=params0),
+            "0_1": Arm(name="0_1", parameters=params1),
+        }
+        experiment = Mock()
+        experiment._trial_indices_by_status = {status: set() for status in TrialStatus}
+        trials = {
+            0: BatchTrial(experiment, GeneratorRun(arms=list(arms_by_name.values())))
+        }
+        type(experiment).arms_by_name = PropertyMock(return_value=arms_by_name)
+        type(experiment).trials = PropertyMock(return_value=trials)
+
+        df = pd.DataFrame(truth)[
+            [
+                "arm_name",
+                "trial_index",
+                "mean",
+                "sem",
+                "metric_name",
+                "start_time",
+                "end_time",
+            ]
+        ]
+        data = Data(df=df)
+        observations = observations_from_data(experiment, data)
+
+        self.assertEqual(len(observations), 2)
+        # Get them in the order we want for tests below
+        if observations[0].features.parameters["x"] == 1:
+            observations.reverse()
+
+        obs_truth = {
+            "arm_name": ["0_0", "0_1"],
+            "parameters": [{"x": 0, "y": "a"}, {"x": 1, "y": "a"}],
+            "metric_names": [["a", "b"], ["a", "b"]],
+            "means": [np.array([2.0, 3.0]), np.array([4.0, 5.0])],
+            "covariance": [np.diag([4.0, 9.0]), np.diag([16.0, 25.0])],
+        }
+
+        for i, obs in enumerate(observations):
+            self.assertEqual(obs.features.parameters, obs_truth["parameters"][i])
+            self.assertEqual(
+                obs.features.trial_index,
+                0,
+            )
+            self.assertEqual(obs.data.metric_names, obs_truth["metric_names"][i])
+            self.assertTrue(np.array_equal(obs.data.means, obs_truth["means"][i]))
+            self.assertTrue(
+                np.array_equal(obs.data.covariance, obs_truth["covariance"][i])
+            )
+            self.assertEqual(obs.arm_name, obs_truth["arm_name"][i])
+            if i == 0:
+                # 0_0: consistent start times; metric "b" lacks an end time.
+                self.assertEqual(
+                    not_none(obs.features.start_time).strftime("%Y-%m-%d %X"),
+                    "2024-03-20 08:45:00",
+                )
+                self.assertIsNone(obs.features.end_time)
+            else:
+                # 0_1: start times differ; end times are consistent.
+                self.assertIsNone(obs.features.start_time)
+                self.assertEqual(
+                    not_none(obs.features.end_time).strftime("%Y-%m-%d %X"),
+                    "2024-03-20 08:46:00",
+                )
+
     def test_SeparateObservations(self) -> None:
         obs_arm_name = "0_0"
         obs = Observation(
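One non-obvious part of the test setup above is the Experiment stub: read-only properties such as arms_by_name cannot be assigned on a Mock instance directly and must instead be patched on type(mock) with PropertyMock. A minimal sketch of that pattern (hypothetical values, not Ax code):

from unittest.mock import Mock, PropertyMock

experiment = Mock()
# Each Mock instance gets its own type object, so patching type(experiment)
# affects only this mock, and attribute access triggers the PropertyMock.
type(experiment).arms_by_name = PropertyMock(return_value={"0_0": "arm-0_0"})

print(experiment.arms_by_name)  # {'0_0': 'arm-0_0'}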
diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py
index 766b9e4f586..6a49dc7113b 100644
--- a/ax/modelbridge/base.py
+++ b/ax/modelbridge/base.py
@@ -431,13 +431,12 @@ def _set_status_quo(
 
             if len(sq_obs) == 0:
                 logger.warning(f"Status quo {status_quo_name} not present in data")
+            elif len(sq_obs) > 1:
+                logger.warning(
+                    f"Status quo {status_quo_name} found in data with multiple "
+                    "features. Use status_quo_features to specify which to use."
+                )
             else:
-                if len(sq_obs) > 1:
-                    logger.warning(
-                        f"Status quo {status_quo_name} found in data with multiple "
-                        "features. Use status_quo_features to specify which to use."
-                        " Defaulting to the first observation."
-                    )
                 self._status_quo = sq_obs[0]
         elif status_quo_features is not None:
diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py
index 3d3e39b9721..cb3ce7e9e53 100644
--- a/ax/modelbridge/tests/test_base_modelbridge.py
+++ b/ax/modelbridge/tests/test_base_modelbridge.py
@@ -578,14 +578,17 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen):
 
         # create data where metrics vary in start and end times
         data = get_non_monolithic_branin_moo_data()
-        bridge = ModelBridge(
-            experiment=exp,
-            data=data,
-            model=Model(),
-            search_space=exp.search_space,
-        )
+        with warnings.catch_warnings(record=True) as ws:
+            bridge = ModelBridge(
+                experiment=exp,
+                data=data,
+                model=Model(),
+                search_space=exp.search_space,
+            )
         # just testing it doesn't error
         bridge.gen(5)
+        self.assertTrue(any("start_time" in str(w.message) for w in ws))
+        self.assertTrue(any("end_time" in str(w.message) for w in ws))
 
         # pyre-fixme[16]: Optional type has no attribute `arm_name`.
         self.assertEqual(bridge.status_quo.arm_name, "status_quo")
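The catch_warnings(record=True) block in the last hunk collects every warning raised during ModelBridge construction so that both messages can be asserted afterwards. A standalone sketch of the pattern (the emit_warnings helper is hypothetical, not Ax code); note that it adds simplefilter("always"), which the test above omits: without it, the default once-per-location filter can suppress a warning that already fired earlier in the process.

import warnings


def emit_warnings() -> None:
    warnings.warn("`start_time` is not consistent and is being discarded")
    warnings.warn("`end_time` is not consistent and is being discarded")


with warnings.catch_warnings(record=True) as ws:
    # Record everything; otherwise the "default" filter may hide repeats.
    warnings.simplefilter("always")
    emit_warnings()

assert any("start_time" in str(w.message) for w in ws)
assert any("end_time" in str(w.message) for w in ws)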