Skip to content

Commit

Permalink
Allow for adhoc computation of multiple CVs at one time (facebook#3427)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3427

This is the second UX improvement diff for adhoc cross validation computation. This allows for multiple metrics to be computed at one time for a single adapter.

It does not currently tile or drop down the metric names. This is a nice improvement we'd like to make, but it wasn't immediately clear how to go about this.

Differential Revision: D70200056
  • Loading branch information
mgarrard authored and facebook-github-bot committed Feb 27, 2025
1 parent 3315a14 commit 2bc108c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
55 changes: 39 additions & 16 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,21 @@ def compute(
def _compute_adhoc(
self,
adapter: Adapter,
metric_name: str,
metric_names: list[str],
experiment: Experiment | None = None,
folds: int = -1,
untransform: bool = True,
) -> PlotlyAnalysisCard:
metric_name_mapping: dict[str, str] | None = None,
) -> list[PlotlyAnalysisCard]:
"""
Helper method to expose adhoc cross validation plotting. This overrides the
default assumption that the adapter from the generation strategy should be
used. Only for advanced users in a notebook setting.
Args:
adapter: The adapter that will be assessed during cross validation.
metric_name: The name of the metric to plot. Must be provided for adhoc
plotting.
metric_names: A list of all the metrics to perform cross validation on.
Must be provided for adhoc plotting.
experiment: Experiment associated with this analysis. Used to determine
the priority of the analysis based on the metric importance in the
optimization config.
Expand All @@ -136,17 +137,31 @@ def _compute_adhoc(
regions where outliers have been removed, we have found it to better
reflect how good the model used for candidate generation actually is.
metric_name_mapping: Optional mapping from default metric names to more
readable metric names.
"""
return self._construct_plot(
adapter=adapter,
metric_name=metric_name,
folds=folds,
untransform=untransform,
# trial_index argument is used with generation strategy since this is an
# adhoc plot call, this will be None.
trial_index=None,
experiment=experiment,
)
plots = []
for metric_name in metric_names:
# replace metric name with human readable name if mapping is provided
refined_metric_name = (
metric_name_mapping.get(metric_name, metric_name)
if metric_name_mapping
else metric_name
)
plots.append(
self._construct_plot(
adapter=adapter,
metric_name=metric_name,
folds=folds,
untransform=untransform,
# trial_index argument is used with generation strategy; since this
# is an adhoc plot call, this will be None.
trial_index=None,
experiment=experiment,
refined_metric_name=refined_metric_name,
)
)
return plots

def _construct_plot(
self,
Expand All @@ -156,6 +171,7 @@ def _construct_plot(
untransform: bool,
trial_index: int | None,
experiment: Experiment | None = None,
refined_metric_name: str | None = None,
) -> PlotlyAnalysisCard:
"""
Args:
Expand All @@ -181,6 +197,8 @@ def _construct_plot(
experiment: Optional Experiment associated with this analysis. Used to set
the priority of the analysis based on the metric importance in the
optimization config.
refined_metric_name: Optional human-readable metric name to use in place
of the default metric name in the plot title.
"""
df = _prepare_data(
adapter=adapter,
Expand Down Expand Up @@ -209,8 +227,11 @@ def _construct_plot(
else:
nudge = 0

# If a human readable metric name is provided, use it in the title
metric_title = refined_metric_name if refined_metric_name else metric_name

return self._create_plotly_analysis_card(
title=f"Cross Validation for {metric_name}",
title=f"Cross Validation for {metric_title}",
subtitle=f"Out-of-sample predictions using {k_folds_substring} CV",
level=AnalysisCardLevel.LOW.value + nudge,
df=df,
Expand Down Expand Up @@ -271,7 +292,9 @@ def _prepare_data(
return pd.DataFrame.from_records(records)


def _prepare_plot(df: pd.DataFrame) -> go.Figure:
def _prepare_plot(
df: pd.DataFrame,
) -> go.Figure:
# Create a scatter plot using Plotly Graph Objects for more control
fig = go.Figure()
# Add scatter trace with error bars
Expand Down
18 changes: 18 additions & 0 deletions ax/analysis/plotly/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.modelbridge.registry import Generators
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.common.testutils import TestCase
from ax.utils.testing.mock import mock_botorch_optimize
Expand Down Expand Up @@ -98,3 +99,20 @@ def test_it_can_specify_trial_index_correctly(self) -> None:
arm_name,
card.df["arm_name"].unique(),
)

@mock_botorch_optimize
def test_compute_adhoc(self) -> None:
metrics = ["bar"]
metric_mapping = {"bar": "spunky"}
data = self.client.experiment.lookup_data()
adapter = Generators.BOTORCH_MODULAR(
experiment=self.client.experiment, data=data
)
analysis = CrossValidationPlot()._compute_adhoc(
adapter=adapter, metric_names=metrics, metric_name_mapping=metric_mapping
)
self.assertEqual(len(analysis), 1)
card = analysis[0]
self.assertEqual(card.name, "CrossValidationPlot")
# validate that the metric name replacement occurred
self.assertEqual(card.title, "Cross Validation for spunky")

0 comments on commit 2bc108c

Please sign in to comment.