Allow for adhoc computation of multiple CVs at one time #3427

Open · wants to merge 2 commits into main
138 changes: 99 additions & 39 deletions ax/analysis/plotly/cross_validation.py
@@ -11,12 +11,13 @@
 
 from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
 from ax.analysis.plotly.utils import select_metric
+from ax.core.data import Data
 from ax.core.experiment import Experiment
 from ax.exceptions.core import UserInputError
 from ax.generation_strategy.generation_strategy import GenerationStrategy
 from ax.modelbridge.base import Adapter
 from ax.modelbridge.cross_validation import cross_validate
-from plotly import express as px, graph_objects as go
+from plotly import graph_objects as go
 from pyre_extensions import none_throws
 
 
@@ -106,20 +107,22 @@ def compute(
     def _compute_adhoc(
         self,
         adapter: Adapter,
-        metric_name: str,
+        data: Data,
         experiment: Experiment | None = None,
         folds: int = -1,
         untransform: bool = True,
-    ) -> PlotlyAnalysisCard:
+        metric_name_mapping: dict[str, str] | None = None,
+    ) -> list[PlotlyAnalysisCard]:
         """
         Helper method to expose adhoc cross validation plotting. This overrides the
         default assumption that the adapter from the generation strategy should be
         used. Only for advanced users in a notebook setting.
 
         Args:
             adapter: The adapter that will be assessed during cross validation.
-            metric_name: The name of the metric to plot. Must be provided for adhoc
-                plotting.
+            data: The Data that was used to fit the model. Will be used in this
+                adhoc cross validation call to compute the cross validation for all
+                metrics in the Data object.
             experiment: Experiment associated with this analysis. Used to determine
                 the priority of the analysis based on the metric importance in the
                 optimization config.
@@ -136,17 +139,34 @@ def _compute_adhoc(
                 regions where outliers have been removed, we have found it to better
                 reflect how good the model used for candidate generation actually
                 is.
+            metric_name_mapping: Optional mapping from default metric names to more
+                readable metric names.
         """
-        return self._construct_plot(
-            adapter=adapter,
-            metric_name=metric_name,
-            folds=folds,
-            untransform=untransform,
-            # trial_index argument is used with generation strategy since this is an
-            # adhoc plot call, this will be None.
-            trial_index=None,
-            experiment=experiment,
-        )
+        plots = []
+        # Get all unique metric names in the data object; CVs will be computed for
+        # all metrics in the data object.
+        metric_names = list(data.df["metric_name"].unique())
+        for metric_name in metric_names:
+            # Replace the metric name with a human readable name if a mapping is given
+            refined_metric_name = (
+                metric_name_mapping.get(metric_name, metric_name)
+                if metric_name_mapping
+                else metric_name
+            )
+            plots.append(
+                self._construct_plot(
+                    adapter=adapter,
+                    metric_name=metric_name,
+                    folds=folds,
+                    untransform=untransform,
+                    # The trial_index argument is used with a generation strategy;
+                    # since this is an adhoc plot call, it is None here.
+                    trial_index=None,
+                    experiment=experiment,
+                    refined_metric_name=refined_metric_name,
+                )
+            )
+        return plots
 
     def _construct_plot(
         self,
@@ -156,6 +176,7 @@ def _construct_plot(
         untransform: bool,
         trial_index: int | None,
         experiment: Experiment | None = None,
+        refined_metric_name: str | None = None,
     ) -> PlotlyAnalysisCard:
         """
         Args:
@@ -181,6 +202,8 @@ def _construct_plot(
             experiment: Optional Experiment associated with this analysis. Used to set
                 the priority of the analysis based on the metric importance in the
                 optimization config.
+            refined_metric_name: Optional replacement for the raw metric name, useful
+                for improving readability of the plot title.
         """
         df = _prepare_data(
             adapter=adapter,
@@ -209,8 +232,11 @@ def _construct_plot(
         else:
             nudge = 0
 
+        # If a human readable metric name is provided, use it in the title.
+        metric_title = refined_metric_name if refined_metric_name else metric_name
+
         return self._create_plotly_analysis_card(
-            title=f"Cross Validation for {metric_name}",
+            title=f"Cross Validation for {metric_title}",
             subtitle=f"Out-of-sample predictions using {k_folds_substring} CV",
             level=AnalysisCardLevel.LOW.value + nudge,
             df=df,
@@ -263,42 +289,73 @@ def _prepare_data(
                 "arm_name": observed.arm_name,
                 "observed": observed.data.means[observed_i],
                 "predicted": predicted.means[predicted_i],
-                # Take the square root of the SEM to get the standard deviation
-                "observed_sem": observed.data.covariance[observed_i][observed_i] ** 0.5,
-                "predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5,
+                # Compute the 95% confidence intervals for plotting purposes
+                "observed_95_ci": observed.data.covariance[observed_i][observed_i]
+                ** 0.5
+                * 1.96,
+                "predicted_95_ci": predicted.covariance[predicted_i][predicted_i] ** 0.5
+                * 1.96,
             }
             records.append(record)
     return pd.DataFrame.from_records(records)
 
 
-def _prepare_plot(df: pd.DataFrame) -> go.Figure:
-    fig = px.scatter(
-        df,
-        x="observed",
-        y="predicted",
-        error_x="observed_sem",
-        error_y="predicted_sem",
-        hover_data=["arm_name", "observed", "predicted"],
-    )
+def _prepare_plot(
+    df: pd.DataFrame,
+) -> go.Figure:
+    # Create a scatter plot using Plotly Graph Objects for more control
+    fig = go.Figure()
+    fig.add_trace(
+        go.Scatter(
+            x=df["observed"],
+            y=df["predicted"],
+            mode="markers",
+            marker={
+                "color": "rgba(0, 0, 255, 0.3)",  # partially transparent blue
+            },
+            error_x={
+                "type": "data",
+                "array": df["observed_95_ci"],
+                "visible": True,
+                "color": "rgba(0, 0, 255, 0.2)",  # partially transparent blue
+            },
+            error_y={
+                "type": "data",
+                "array": df["predicted_95_ci"],
+                "visible": True,
+                "color": "rgba(0, 0, 255, 0.2)",  # partially transparent blue
+            },
+            text=df["arm_name"],
+            hovertemplate=(
+                "<b>Arm Name: %{text}</b><br>"
+                + "Predicted: %{y}<br>"
+                + "Observed: %{x}<br>"
+                + "<extra></extra>"  # Removes the trace name from the hover
+            ),
+            hoverlabel={
+                "bgcolor": "rgba(0, 0, 255, 0.2)",  # partially transparent blue
+                "font": {"color": "black"},
+            },
+        )
+    )
 
     # Add a gray dashed line at y=x starting and ending just outside of the region of
-    # interest for reference. A well fit model should have points clustered around this
-    # line.
+    # interest for reference. A well fit model should have points clustered around
+    # this line.
     lower_bound = (
         min(
-            (df["observed"] - df["observed_sem"].fillna(0)).min(),
-            (df["predicted"] - df["predicted_sem"].fillna(0)).min(),
+            (df["observed"] - df["observed_95_ci"].fillna(0)).min(),
+            (df["predicted"] - df["predicted_95_ci"].fillna(0)).min(),
         )
-        * 0.99
+        * 0.999  # tight autozoom
     )
     upper_bound = (
         max(
-            (df["observed"] + df["observed_sem"].fillna(0)).max(),
-            (df["predicted"] + df["predicted_sem"].fillna(0)).max(),
+            (df["observed"] + df["observed_95_ci"].fillna(0)).max(),
+            (df["predicted"] + df["predicted_95_ci"].fillna(0)).max(),
         )
-        * 1.01
+        * 1.001  # tight autozoom
     )
 
     fig.add_shape(
         type="line",
         x0=lower_bound,
@@ -308,11 +365,14 @@ def _prepare_plot(df: pd.DataFrame) -> go.Figure:
         line={"color": "gray", "dash": "dot"},
     )
 
-    # Force plot to display as a square
-    fig.update_xaxes(range=[lower_bound, upper_bound], constrain="domain")
+    # Update axes with a tight autozoom that remains square
+    fig.update_xaxes(
+        range=[lower_bound, upper_bound], constrain="domain", title="Actual Outcome"
+    )
     fig.update_yaxes(
         range=[lower_bound, upper_bound],
         scaleanchor="x",
         scaleratio=1,
+        title="Predicted Outcome",
     )
 
     return fig
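
Note on the error-bar change above: the former `observed_sem`/`predicted_sem` columns (one standard error each) are replaced by `observed_95_ci`/`predicted_95_ci` columns holding the 95% confidence interval half-width, i.e. 1.96 standard errors, where the standard error is the square root of the diagonal covariance entry. A minimal sketch of the arithmetic, using a made-up variance value:

```python
# Illustrative only; the variance value is made up, not taken from the diff.
variance = 4.0         # diagonal entry of the covariance matrix
sem = variance ** 0.5  # standard error of the mean: 2.0
ci_95 = 1.96 * sem     # 95% CI half-width plotted as the error bar: 3.92
print(ci_95)           # 3.92
```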
19 changes: 18 additions & 1 deletion ax/analysis/plotly/tests/test_cross_validation.py
@@ -9,6 +9,7 @@
 from ax.analysis.plotly.cross_validation import CrossValidationPlot
 from ax.core.trial import Trial
 from ax.exceptions.core import UserInputError
+from ax.modelbridge.registry import Generators
 from ax.service.ax_client import AxClient, ObjectiveProperties
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.mock import mock_botorch_optimize
@@ -60,7 +61,7 @@ def test_compute(self) -> None:
         self.assertEqual(card.category, AnalysisCardCategory.INSIGHT)
         self.assertEqual(
             {*card.df.columns},
-            {"arm_name", "observed", "observed_sem", "predicted", "predicted_sem"},
+            {"arm_name", "observed", "observed_95_ci", "predicted", "predicted_95_ci"},
         )
         self.assertIsNotNone(card.blob)
         self.assertEqual(card.blob_annotation, "plotly")
@@ -98,3 +99,19 @@ def test_it_can_specify_trial_index_correctly(self) -> None:
             arm_name,
             card.df["arm_name"].unique(),
         )
+
+    @mock_botorch_optimize
+    def test_compute_adhoc(self) -> None:
+        metric_mapping = {"bar": "spunky"}
+        data = self.client.experiment.lookup_data()
+        adapter = Generators.BOTORCH_MODULAR(
+            experiment=self.client.experiment, data=data
+        )
+        analysis = CrossValidationPlot()._compute_adhoc(
+            adapter=adapter, data=data, metric_name_mapping=metric_mapping
+        )
+        self.assertEqual(len(analysis), 1)
+        card = analysis[0]
+        self.assertEqual(card.name, "CrossValidationPlot")
+        # Validate that the metric name replacement occurred
+        self.assertEqual(card.title, "Cross Validation for spunky")
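
For anyone who wants to try the new API end to end, here is a minimal sketch of the adhoc flow the test above exercises. It assumes an `AxClient` named `client` whose experiment already has completed trials with attached data (hypothetical setup, not part of this diff); the metric name `"bar"` and its readable replacement are placeholders:

```python
from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.modelbridge.registry import Generators

# Assumes `client` is an AxClient with completed trials (hypothetical setup).
data = client.experiment.lookup_data()
adapter = Generators.BOTORCH_MODULAR(experiment=client.experiment, data=data)

# One PlotlyAnalysisCard is returned per metric present in `data`; the
# optional mapping swaps raw metric names for readable plot titles.
cards = CrossValidationPlot()._compute_adhoc(
    adapter=adapter,
    data=data,
    metric_name_mapping={"bar": "Time to convergence"},
)
for card in cards:
    print(card.title)  # e.g. "Cross Validation for Time to convergence"
```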