Skip to content

Commit

Permalink
Remove unused experiment arg from compute_model_fit_metrics_from_mode…
Browse files Browse the repository at this point in the history
…lbridge (#2504)

Summary:

This will make it easier to call on model spec.

Reviewed By: saitcakmak

Differential Revision: D58208209
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jun 6, 2024
1 parent 689eeed commit 6fbaf7c
Show file tree
Hide file tree
Showing 6 changed files with 0 additions and 13 deletions.
2 changes: 0 additions & 2 deletions ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from ax.core.experiment import Experiment
from ax.core.observation import Observation, ObservationData, recombine_observations
from ax.core.optimization_config import OptimizationConfig
from ax.modelbridge.base import ModelBridge, unwrap_observation_data
Expand Down Expand Up @@ -495,7 +494,6 @@ def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:

def compute_model_fit_metrics_from_modelbridge(
model_bridge: ModelBridge,
experiment: Experiment,
fit_metrics_dict: Optional[Dict[str, ModelFitMetricProtocol]] = None,
generalization: bool = False,
untransform: bool = False,
Expand Down
4 changes: 0 additions & 4 deletions ax/modelbridge/tests/test_model_fit_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def test_model_fit_metrics(self) -> None:
# testing compute_model_fit_metrics_from_modelbridge with default metrics
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=self.branin_experiment,
untransform=False,
)
r2 = fit_metrics.get("coefficient_of_determination")
Expand All @@ -101,7 +100,6 @@ def test_model_fit_metrics(self) -> None:
with self.subTest(untransform=untransform):
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
generalization=generalization,
untransform=untransform,
fit_metrics_dict={"Entropy": entropy_of_observations},
Expand All @@ -128,7 +126,6 @@ def test_model_fit_metrics(self) -> None:
# testing with empty metrics
empty_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=self.branin_experiment,
fit_metrics_dict={},
)
self.assertIsInstance(empty_metrics, dict)
Expand All @@ -138,7 +135,6 @@ def test_model_fit_metrics(self) -> None:
with warnings.catch_warnings(record=True) as ws:
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=self.branin_experiment,
untransform=untransform,
generalization=generalization,
)
Expand Down
2 changes: 0 additions & 2 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,7 +1929,6 @@ def _helper_path_that_refits_the_model_if_it_is_not_already_initialized(
# testing compatibility with compute_model_fit_metrics_from_modelbridge
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
untransform=False,
)
r2 = fit_metrics.get("coefficient_of_determination")
Expand All @@ -1949,7 +1948,6 @@ def _helper_path_that_refits_the_model_if_it_is_not_already_initialized(
# testing with empty metrics dict
empty_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
fit_metrics_dict={},
untransform=False,
)
Expand Down
1 change: 0 additions & 1 deletion ax/service/utils/report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,7 +1522,6 @@ def warn_if_unpredictable_metrics(
return None
model_fit_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=experiment,
generalization=True, # use generalization metrics for user warning
untransform=False,
)
Expand Down
2 changes: 0 additions & 2 deletions ax/telemetry/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
model_bridge = get_fitted_model_bridge(scheduler)
model_fit_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
generalization=False,
untransform=False,
)
Expand All @@ -131,7 +130,6 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
# generalization metrics
model_gen_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
generalization=True,
untransform=False,
)
Expand Down
2 changes: 0 additions & 2 deletions ax/telemetry/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def test_scheduler_model_fit_metrics_logging(self) -> None:

fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
generalization=False,
untransform=False,
)
Expand All @@ -212,7 +211,6 @@ def test_scheduler_model_fit_metrics_logging(self) -> None:
# check generalization metrics
gen_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
generalization=True,
untransform=False,
)
Expand Down

0 comments on commit 6fbaf7c

Please sign in to comment.