-
Notifications
You must be signed in to change notification settings - Fork 321
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor BestModelSelector to operate on ModelSpecs (#2557)
Summary: Pull Request resolved: #2557. `BestModelSelector` was previously limited to selecting the best out of a given dictionary of CV diagnostics that were computed in `ModelSpec.cross_validate`. This setup limited extensibility, since any change would require updating the `ModelSpec` code to change which diagnostics are computed. This diff refactors `BestModelSelector` to operate directly on the `ModelSpecs`. This new modular design lets each `BestModelSelector` class compute the necessary diagnostics internally, without locking us into any pre-specified list. Other minor changes: - Removed `CallableEnum` and its subclasses and replaced them with a single `ReductionCriterion` enum. - Split off `BestModelSelector` into a separate file to avoid circular imports. Differential Revision: D59249657
- Loading branch information
1 parent
f6bf1c6
commit 91640d8
Showing
7 changed files
with
212 additions
and
147 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
from functools import partial | ||
from typing import Callable, List, Union | ||
|
||
import numpy as np | ||
from ax.exceptions.core import UserInputError | ||
from ax.modelbridge.model_spec import ModelSpec | ||
from ax.utils.common.typeutils import not_none | ||
|
||
# Input types accepted by the ``ReductionCriterion`` callables: a numpy array,
# a list of floats, or a list of numpy arrays.
ARRAYLIKE = Union[np.ndarray, List[float], List[np.ndarray]]
|
||
|
||
class BestModelSelector(ABC):
    """Abstract interface for picking the best model out of several candidates.

    Concrete subclasses implement ``best_model``, which receives the candidate
    ``ModelSpec``s and reports the position of the preferred one.
    """

    @abstractmethod
    def best_model(self, model_specs: List[ModelSpec]) -> int:
        """Return the index of the best ``ModelSpec``.

        Args:
            model_specs: The candidate ``ModelSpec``s to choose from.

        Returns:
            The index into ``model_specs`` of the selected ``ModelSpec``.
        """
|
||
|
||
try:
    # ``enum.member`` (Python 3.11+) explicitly marks a value as an enum member.
    from enum import member as _enum_member
except ImportError:  # Python < 3.11: ``partial`` objects are registered as-is.

    def _enum_member(
        value: Callable[[ARRAYLIKE], np.ndarray],
    ) -> Callable[[ARRAYLIKE], np.ndarray]:
        """Return ``value`` unchanged; no explicit member marker is available."""
        return value


class ReductionCriterion(Enum):
    """An enum for callables that are used for aggregating diagnostics over metrics
    and selecting the best diagnostic in ``SingleDiagnosticBestModelSelector``.

    Each member is itself callable: ``ReductionCriterion.MEAN(values)`` forwards
    ``values`` to the wrapped numpy reduction and returns its result.

    NOTE: This is used to ensure serializability of the callables.
    """

    # NOTE: Plain functions assigned in an ``Enum`` body are treated as methods
    # rather than members, so each callable is wrapped in `partial`. The extra
    # ``_enum_member`` marker keeps them registered as members on newer Python
    # versions, which warn that `partial` objects will become method descriptors
    # (and would then no longer be picked up as members automatically).
    MEAN: Callable[[ARRAYLIKE], np.ndarray] = _enum_member(partial(np.mean))
    MIN: Callable[[ARRAYLIKE], np.ndarray] = _enum_member(partial(np.min))
    MAX: Callable[[ARRAYLIKE], np.ndarray] = _enum_member(partial(np.max))

    def __call__(self, array_like: ARRAYLIKE) -> np.ndarray:
        """Apply the wrapped reduction to ``array_like`` and return the result."""
        return self.value(array_like)
|
||
|
||
class SingleDiagnosticBestModelSelector(BestModelSelector):
    """Choose the best model using a single cross-validation diagnostic.

    The input is a list of ``ModelSpec``, each corresponding to one model.
    The specified diagnostic is extracted from each of the models,
    its values (each of which corresponds to a separate metric) are
    aggregated with the aggregation function, the best one is determined
    with the criterion, and the index of the best model is returned.

    Example:
        ::

            s = SingleDiagnosticBestModelSelector(
                diagnostic="Fisher exact test p",
                metric_aggregation=ReductionCriterion.MEAN,
                criterion=ReductionCriterion.MIN,
            )
            best_model_index = s.best_model(model_specs=model_specs)

    Args:
        diagnostic: The name of the diagnostic to use, which should be
            a key in ``CVDiagnostic``.
        metric_aggregation: ``ReductionCriterion`` applied to the values of the
            diagnostic for a single model to produce a single number.
        criterion: ``ReductionCriterion`` used to determine which of the
            (aggregated) diagnostics is the best. Must be ``MIN`` or ``MAX``.

    Raises:
        UserInputError: If ``metric_aggregation`` or ``criterion`` is not a
            ``ReductionCriterion``, or if ``criterion`` is ``MEAN``.
    """

    def __init__(
        self,
        diagnostic: str,
        metric_aggregation: ReductionCriterion,
        criterion: ReductionCriterion,
    ) -> None:
        self.diagnostic = diagnostic
        if not isinstance(metric_aggregation, ReductionCriterion) or not isinstance(
            criterion, ReductionCriterion
        ):
            raise UserInputError(
                "Both `metric_aggregation` and `criterion` must be "
                f"`ReductionCriterion`. Got {metric_aggregation=}, {criterion=}."
            )
        # The criterion's result is looked up in the list of aggregated values
        # in ``best_model``; a mean of those values is generally not one of them.
        if criterion == ReductionCriterion.MEAN:
            raise UserInputError(
                f"{criterion=} is not supported. Please use MIN or MAX."
            )
        self.metric_aggregation = metric_aggregation
        self.criterion = criterion

    def best_model(self, model_specs: List[ModelSpec]) -> int:
        """Return the index of the best ``ModelSpec`` by the configured diagnostic.

        Args:
            model_specs: The candidate ``ModelSpec``s. ``cross_validate`` is
                called on each of them before their diagnostics are read.

        Returns:
            The index of the ``ModelSpec`` whose aggregated diagnostic value is
            selected by ``criterion``. If several models share that value, the
            first one is returned.
        """
        for model_spec in model_specs:
            model_spec.cross_validate()
        # One number per model: the per-metric values of the chosen diagnostic,
        # collapsed with ``metric_aggregation``.
        aggregated_diagnostic_values = [
            self.metric_aggregation(
                list(not_none(model_spec.diagnostics)[self.diagnostic].values())
            )
            for model_spec in model_specs
        ]
        # ``criterion`` yields the winning value; map it back to its position.
        best_diagnostic = self.criterion(aggregated_diagnostic_values).item()
        return aggregated_diagnostic_values.index(best_diagnostic)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from unittest.mock import Mock | ||
|
||
from ax.exceptions.core import UserInputError | ||
from ax.modelbridge.best_model_selector import ( | ||
ReductionCriterion, | ||
SingleDiagnosticBestModelSelector, | ||
) | ||
from ax.modelbridge.model_spec import ModelSpec | ||
from ax.modelbridge.registry import Models | ||
from ax.utils.common.testutils import TestCase | ||
|
||
|
||
class TestBestModelSelector(TestCase):
    """Tests for ``SingleDiagnosticBestModelSelector`` on canned diagnostics."""

    def setUp(self) -> None:
        super().setUp()

        # One diagnostics dict per candidate model, attached directly to the spec.
        per_model_diagnostics = (
            {"Fisher exact test p": {"y_a": 0.0, "y_b": 0.4}},
            {"Fisher exact test p": {"y_a": 0.1, "y_b": 0.1}},
            {"Fisher exact test p": {"y_a": 0.5, "y_b": 0.6}},
        )
        self.model_specs = []
        for cv_diagnostics in per_model_diagnostics:
            spec = ModelSpec(model_enum=Models.BOTORCH_MODULAR)
            # NOTE(review): presumably a non-None ``_cv_results`` lets
            # ``cross_validate`` return early without a fitted model - confirm.
            spec._cv_results = Mock()
            spec._diagnostics = cv_diagnostics
            self.model_specs.append(spec)

    def _select(
        self, criterion: ReductionCriterion, metric_aggregation: ReductionCriterion
    ) -> int:
        """Run a selector with the given reductions over the dummy model specs."""
        selector = SingleDiagnosticBestModelSelector(
            diagnostic="Fisher exact test p",
            criterion=criterion,
            metric_aggregation=metric_aggregation,
        )
        return selector.best_model(model_specs=self.model_specs)

    def test_user_input_error(self) -> None:
        # Plain callables are rejected; only ``ReductionCriterion`` is accepted.
        with self.assertRaisesRegex(UserInputError, "ReductionCriterion"):
            SingleDiagnosticBestModelSelector(
                "Fisher exact test p", metric_aggregation=min, criterion=max
            )
        # ``MEAN`` is a valid aggregation but not a valid selection criterion.
        with self.assertRaisesRegex(UserInputError, "use MIN or MAX"):
            SingleDiagnosticBestModelSelector(
                "Fisher exact test p",
                metric_aggregation=ReductionCriterion.MEAN,
                criterion=ReductionCriterion.MEAN,
            )

    def test_SingleDiagnosticBestModelSelector_min_mean(self) -> None:
        # Per-model means are 0.2, 0.1 and 0.55; the smallest is at index 1.
        chosen = self._select(
            criterion=ReductionCriterion.MIN,
            metric_aggregation=ReductionCriterion.MEAN,
        )
        self.assertEqual(chosen, 1)

    def test_SingleDiagnosticBestModelSelector_min_min(self) -> None:
        # Per-model minima are 0.0, 0.1 and 0.5; the smallest is at index 0.
        chosen = self._select(
            criterion=ReductionCriterion.MIN,
            metric_aggregation=ReductionCriterion.MIN,
        )
        self.assertEqual(chosen, 0)

    def test_SingleDiagnosticBestModelSelector_max_mean(self) -> None:
        # Per-model means are 0.2, 0.1 and 0.55; the largest is at index 2.
        chosen = self._select(
            criterion=ReductionCriterion.MAX,
            metric_aggregation=ReductionCriterion.MEAN,
        )
        self.assertEqual(chosen, 2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.