From 2bc108c03ca5b1813c9e1d9daee165eac381520b Mon Sep 17 00:00:00 2001
From: Mia Garrard
Date: Thu, 27 Feb 2025 13:37:39 -0800
Subject: [PATCH] Allow for adhoc computation of multiple CVs at one time
 (#3427)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/3427

This is the second UX improvement diff for adhoc cross validation
computation. It allows multiple metrics to be computed at one time for a
single adapter. It does not currently tile the plots or provide a dropdown
for the metric names; this is a nice improvement we'd like to make, but it
wasn't immediately clear how to go about it.

Differential Revision: D70200056
---
 ax/analysis/plotly/cross_validation.py            | 55 +++++++++++++------
 .../plotly/tests/test_cross_validation.py         | 18 ++++++
 2 files changed, 57 insertions(+), 16 deletions(-)

diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py
index f27d8fb0875..48658b50e73 100644
--- a/ax/analysis/plotly/cross_validation.py
+++ b/ax/analysis/plotly/cross_validation.py
@@ -106,11 +106,12 @@ def compute(
     def _compute_adhoc(
         self,
         adapter: Adapter,
-        metric_name: str,
+        metric_names: list[str],
         experiment: Experiment | None = None,
         folds: int = -1,
         untransform: bool = True,
-    ) -> PlotlyAnalysisCard:
+        metric_name_mapping: dict[str, str] | None = None,
+    ) -> list[PlotlyAnalysisCard]:
         """
         Helper method to expose adhoc cross validation plotting. This overrides the
         default assumption that the adapter from the generation strategy should be
@@ -118,8 +119,8 @@

         Args:
             adapter: The adapter that will be assessed during cross validation.
-            metric_name: The name of the metric to plot. Must be provided for adhoc
-                plotting.
+            metric_names: A list of all the metrics to perform cross validation on.
+                Must be provided for adhoc plotting.
             experiment: Experiment associated with this analysis. Used to determine
                 the priority of the analysis based on the metric importance in the
                 optimization config.
@@ -136,17 +137,31 @@
                 regions where outliers have been removed, we have found it to better
                 reflect how good the model used for candidate generation
                 actually is.
+            metric_name_mapping: Optional mapping from default metric names to more
+                readable metric names.
        """
-        return self._construct_plot(
-            adapter=adapter,
-            metric_name=metric_name,
-            folds=folds,
-            untransform=untransform,
-            # trial_index argument is used with generation strategy since this is an
-            # adhoc plot call, this will be None.
-            trial_index=None,
-            experiment=experiment,
-        )
+        plots = []
+        for metric_name in metric_names:
+            # Replace metric name with human readable name if mapping is provided.
+            refined_metric_name = (
+                metric_name_mapping.get(metric_name, metric_name)
+                if metric_name_mapping
+                else metric_name
+            )
+            plots.append(
+                self._construct_plot(
+                    adapter=adapter,
+                    metric_name=metric_name,
+                    folds=folds,
+                    untransform=untransform,
+                    # The trial_index argument is only used with a generation
+                    # strategy; since this is an adhoc plot call, it is None.
+                    trial_index=None,
+                    experiment=experiment,
+                    refined_metric_name=refined_metric_name,
+                )
+            )
+        return plots

     def _construct_plot(
         self,
@@ -156,6 +171,7 @@
         untransform: bool,
         trial_index: int | None,
         experiment: Experiment | None = None,
+        refined_metric_name: str | None = None,
     ) -> PlotlyAnalysisCard:
         """
         Args:
@@ -181,6 +197,8 @@
             experiment: Optional Experiment associated with this analysis.
                Used to set the priority of the analysis based on the metric
                importance in the optimization config.
+            refined_metric_name: Optional human readable metric name to use in
+                place of the default metric name in the plot title.
         """
         df = _prepare_data(
             adapter=adapter,
@@ -209,8 +227,11 @@
         else:
             nudge = 0

+        # If a human readable metric name is provided, use it in the title.
+        metric_title = refined_metric_name if refined_metric_name else metric_name
+
         return self._create_plotly_analysis_card(
-            title=f"Cross Validation for {metric_name}",
+            title=f"Cross Validation for {metric_title}",
             subtitle=f"Out-of-sample predictions using {k_folds_substring} CV",
             level=AnalysisCardLevel.LOW.value + nudge,
             df=df,
@@ -271,7 +292,9 @@
     return pd.DataFrame.from_records(records)


-def _prepare_plot(df: pd.DataFrame) -> go.Figure:
+def _prepare_plot(
+    df: pd.DataFrame,
+) -> go.Figure:
     # Create a scatter plot using Plotly Graph Objects for more control
     fig = go.Figure()
     # Add scatter trace with error bars
diff --git a/ax/analysis/plotly/tests/test_cross_validation.py b/ax/analysis/plotly/tests/test_cross_validation.py
index 94bf11ca749..5ea85ec5faf 100644
--- a/ax/analysis/plotly/tests/test_cross_validation.py
+++ b/ax/analysis/plotly/tests/test_cross_validation.py
@@ -9,6 +9,7 @@
 from ax.analysis.plotly.cross_validation import CrossValidationPlot
 from ax.core.trial import Trial
 from ax.exceptions.core import UserInputError
+from ax.modelbridge.registry import Generators
 from ax.service.ax_client import AxClient, ObjectiveProperties
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.mock import mock_botorch_optimize
@@ -98,3 +99,20 @@
                 arm_name,
                 card.df["arm_name"].unique(),
             )
+
+    @mock_botorch_optimize
+    def test_compute_adhoc(self) -> None:
+        metrics = ["bar"]
+        metric_mapping = {"bar": "spunky"}
+        data = self.client.experiment.lookup_data()
+        adapter = Generators.BOTORCH_MODULAR(
+            experiment=self.client.experiment, data=data
+        )
+        analysis = CrossValidationPlot()._compute_adhoc(
+            adapter=adapter, metric_names=metrics, metric_name_mapping=metric_mapping
+        )
+        self.assertEqual(len(analysis), 1)
+        card = analysis[0]
+        self.assertEqual(card.name, "CrossValidationPlot")
+        # Validate that the metric name replacement occurred.
+        self.assertEqual(card.title, "Cross Validation for spunky")
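
Usage sketch (not part of the patch): a minimal, hypothetical example of the
new adhoc entry point, mirroring test_compute_adhoc above. It assumes an
AxClient named `client` whose experiment already has attached data; the metric
names and the mapping values are placeholders.

    from ax.analysis.plotly.cross_validation import CrossValidationPlot
    from ax.modelbridge.registry import Generators

    # Fit an adapter outside of any generation strategy, as in the unit test.
    data = client.experiment.lookup_data()
    adapter = Generators.BOTORCH_MODULAR(experiment=client.experiment, data=data)

    # One card is returned per requested metric; the optional mapping only
    # affects the card titles, not which metrics are cross validated.
    cards = CrossValidationPlot()._compute_adhoc(
        adapter=adapter,
        metric_names=["bar", "baz"],
        metric_name_mapping={"bar": "Relevance (bar)"},
    )
    for card in cards:
        print(card.title)  # e.g. "Cross Validation for Relevance (bar)"

Each metric gets its own PlotlyAnalysisCard rather than one tiled figure,
which is why the method now returns a list.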