Keep map_key columns in MapData.df attribute (facebook#2701)
Summary:
`MapData` has an attribute `df` that takes only the last row from each of `map_df`'s trial-arm-metric groups (when sorted by map_key values), and drops `map_key` columns.

This diff makes it so that `map_key` columns are still present in the `df` attribute, showing their values for each kept row. This will make it easier to understand and model partially complete trials in the future.
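
A minimal pandas sketch of the before/after behavior described above. The toy data, the `step` map key, and the contents of `DEDUPLICATE_BY_COLUMNS` are assumptions for illustration only; they are not taken from the Ax source.

```python
import pandas as pd

# Assumed for illustration: the trial-arm-metric grouping columns and a
# single map key called "step".
DEDUPLICATE_BY_COLUMNS = ["trial_index", "arm_name", "metric_name"]
map_keys = ["step"]

# Hypothetical long-format data with several readings per trial.
map_df = pd.DataFrame(
    {
        "trial_index": [0, 0, 0, 1, 1],
        "arm_name": ["0_0", "0_0", "0_0", "1_0", "1_0"],
        "metric_name": ["loss"] * 5,
        "step": [2, 0, 1, 1, 0],
        "mean": [0.5, 0.9, 0.7, 0.6, 0.8],
    }
)

# New behavior: sort by the map key and keep the last row of each
# trial-arm-metric group; the "step" column is retained.
latest = map_df.sort_values(map_keys).drop_duplicates(
    DEDUPLICATE_BY_COLUMNS, keep="last"
)

# Previous behavior: same rows, but with the map_key columns dropped.
without_map_keys = latest.loc[:, ~latest.columns.isin(map_keys)]
```

With the `step` column kept, each surviving row shows how far its trial had progressed when the reading was taken, which is the information a partially complete trial would otherwise lose.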

[RfC] I've patched up a few tests that assert these dataframes have a specific form, but there don't seem to be any tests failing for functional reasons (since the unused map_key columns would just go unused if not needed). Putting this up as an RfC though in case others feel strongly that these columns should be dropped.

Reviewed By: Balandat

Differential Revision: D61730570

fbshipit-source-id: 0f09b357e0382b48a42b61d5c0f5a63206736196
bernardbeckerman authored and facebook-github-bot committed Aug 26, 2024
1 parent c76f52f commit cc05718
Showing 3 changed files with 8 additions and 7 deletions.
8 changes: 3 additions & 5 deletions ax/core/map_data.py
@@ -277,13 +277,11 @@ def df(self) -> pd.DataFrame:
         if self._memo_df is not None:
             return self._memo_df

-        if not any(True for _ in self.map_keys):
+        if len(self.map_keys) == 0:
             return self.map_df

-        self._memo_df = (
-            self.map_df.sort_values(list(self.map_keys))
-            .drop_duplicates(MapData.DEDUPLICATE_BY_COLUMNS, keep="last")
-            .loc[:, ~self.map_df.columns.isin(self.map_keys)]
+        self._memo_df = self.map_df.sort_values(self.map_keys).drop_duplicates(
+            MapData.DEDUPLICATE_BY_COLUMNS, keep="last"
         )

         return self._memo_df
5 changes: 4 additions & 1 deletion ax/core/tests/test_experiment.py
@@ -1422,7 +1422,10 @@ def test_WarmStartMapData(self) -> None:
             i_old_trial += 1

         # Check that the data was attached for correct trials
-        old_df = old_experiment.fetch_data().df
+
+        # Old experiment has already been fetched, and re-fetching will add readings to
+        # still-running map metrics.
+        old_df = old_experiment.lookup_data().df
         new_df = new_experiment.fetch_data().df

         old_df = old_df.sort_values(by=["arm_name", "metric_name"], ignore_index=True)
2 changes: 1 addition & 1 deletion ax/core/tests/test_map_data.py
@@ -218,7 +218,7 @@ def test_upcast(self) -> None:

         self.assertEqual(
             fresh.df.columns.size,
-            fresh.map_df.columns.size - len(self.mmd.map_key_infos),
+            fresh.map_df.columns.size,
         )

         self.assertIsNotNone(fresh._memo_df)  # Assert df is cached after first call
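
The two assertions in this test (equal column counts, and a populated cache after the first access) can be reproduced against a toy stand-in. `ToyMapData`, its constructor, and the sample data below are hypothetical and only mirror the shape of the `df` property shown in the diff; this is a sketch, not the Ax class.

```python
import pandas as pd


class ToyMapData:
    """Hypothetical stand-in for illustration; not the Ax MapData class."""

    # Assumed grouping columns (trial-arm-metric).
    DEDUPLICATE_BY_COLUMNS = ["trial_index", "arm_name", "metric_name"]

    def __init__(self, map_df: pd.DataFrame, map_keys: list[str]) -> None:
        self.map_df = map_df
        self.map_keys = list(map_keys)
        self._memo_df = None  # filled in on first access to `df`

    @property
    def df(self) -> pd.DataFrame:
        # Return the cached frame if it has already been computed.
        if self._memo_df is not None:
            return self._memo_df

        # With no map keys there is nothing to deduplicate.
        if len(self.map_keys) == 0:
            return self.map_df

        # Keep the last row per group; map_key columns are retained.
        self._memo_df = self.map_df.sort_values(self.map_keys).drop_duplicates(
            self.DEDUPLICATE_BY_COLUMNS, keep="last"
        )
        return self._memo_df


toy_frame = pd.DataFrame(
    {
        "trial_index": [0, 0],
        "arm_name": ["0_0", "0_0"],
        "metric_name": ["loss", "loss"],
        "step": [0, 1],
        "mean": [0.9, 0.7],
    }
)

fresh = ToyMapData(toy_frame, map_keys=["step"])
first = fresh.df   # computes and caches
second = fresh.df  # served from the cache
```

Because the map_key columns are no longer dropped, `fresh.df` and `fresh.map_df` have the same number of columns, which is what the updated assertion checks.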
