Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require unique model_key in GenerationNode, clean up model selection errors #2730

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@
logger: Logger = get_logger(__name__)

TModelFactory = Callable[..., ModelBridge]
CANNOT_SELECT_ONE_MODEL_MSG = """\
Base `GenerationNode` does not implement selection among fitted \
models, so exactly one `ModelSpec` must be specified when using \
`GenerationNode._pick_fitted_model_to_gen_from` (usually called \
by `GenerationNode.gen`.
"""
MISSING_MODEL_SELECTOR_MESSAGE = (
"A `BestModelSelector` must be provided when using multiple "
"`ModelSpec`s in a `GenerationNode`. After fitting all `ModelSpec`s, "
"the `BestModelSelector` will be used to select the `ModelSpec` to "
"use for candidate generation."
)
MAX_GEN_DRAWS = 5
MAX_GEN_DRAWS_EXCEEDED_MESSAGE = (
f"GenerationStrategy exceeded `MAX_GEN_DRAWS` of {MAX_GEN_DRAWS} while trying to "
Expand Down Expand Up @@ -116,11 +116,14 @@ def __init__(
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
) -> None:
self._node_name = node_name
# While `GenerationNode` only handles a single `ModelSpec` in the `gen`
# and `_pick_fitted_model_to_gen_from` methods, we validate the
# length of `model_specs` in `_pick_fitted_model_to_gen_from` in order
# to not require all `GenerationNode` subclasses to override an `__init__`
# method to bypass that validation.
# Check that the model specs have unique model keys.
model_keys = {model_spec.model_key for model_spec in model_specs}
if len(model_keys) != len(model_specs):
raise UserInputError(
"Model keys must be unique across all model specs in a GenerationNode."
)
if len(model_specs) > 1 and best_model_selector is None:
raise UserInputError(MISSING_MODEL_SELECTOR_MESSAGE)
self.model_specs = model_specs
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
Expand Down Expand Up @@ -358,8 +361,8 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
`ModelSpec` and select it.
"""
if self.best_model_selector is None:
if len(self.model_specs) != 1:
raise NotImplementedError(CANNOT_SELECT_ONE_MODEL_MSG)
if len(self.model_specs) != 1: # pragma: no cover -- raised in __init__.
raise UserInputError(MISSING_MODEL_SELECTOR_MESSAGE)
return self.model_specs[0]

best_model = not_none(self.best_model_selector).best_model(
Expand Down Expand Up @@ -678,11 +681,7 @@ def __post_init__(self) -> None:
model_gen_kwargs=self.model_gen_kwargs,
)
if self.model_name == "":
try:
self.model_name = model_spec.model_key
except TypeError:
# Factory functions may not always have a model key defined.
self.model_name = f"Unknown {model_spec.__class__.__name__}"
self.model_name = model_spec.model_key

# Create transition criteria for this step. MaximumTrialsInStatus can be used
# to ensure that requirements related to num_trials and unlimited trials
Expand Down
58 changes: 32 additions & 26 deletions ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

from logging import Logger
from unittest.mock import MagicMock, patch, PropertyMock
from unittest.mock import MagicMock, patch

from ax.core.base_trial import TrialStatus
from ax.core.observation import ObservationFeatures
Expand All @@ -17,7 +17,11 @@
SingleDiagnosticBestModelSelector,
)
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.generation_node import (
GenerationNode,
GenerationStep,
MISSING_MODEL_SELECTOR_MESSAGE,
)
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
Expand Down Expand Up @@ -46,6 +50,32 @@ def test_init(self) -> None:
self.assertEqual(
self.sobol_generation_node.model_specs, [self.sobol_model_spec]
)
with self.assertRaisesRegex(UserInputError, "Model keys must be unique"):
GenerationNode(
node_name="test",
model_specs=[self.sobol_model_spec, self.sobol_model_spec],
)
mbm_specs = [
ModelSpec(model_enum=Models.BOTORCH_MODULAR),
ModelSpec(model_enum=Models.BOTORCH_MODULAR, model_key_override="MBM v2"),
]
with self.assertRaisesRegex(UserInputError, MISSING_MODEL_SELECTOR_MESSAGE):
GenerationNode(
node_name="test",
model_specs=mbm_specs,
)
model_selector = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
metric_aggregation=ReductionCriterion.MEAN,
criterion=ReductionCriterion.MIN,
)
node = GenerationNode(
node_name="test",
model_specs=mbm_specs,
best_model_selector=model_selector,
)
self.assertEqual(node.model_specs, mbm_specs)
self.assertIs(node.best_model_selector, model_selector)

def test_fit(self) -> None:
dat = self.branin_experiment.lookup_data()
Expand Down Expand Up @@ -80,20 +110,6 @@ def test_gen(self) -> None:
# pyre-fixme[16]: Optional type has no attribute `get`.
self.assertEqual(gr._model_kwargs.get("init_position"), 3)

def test_gen_validates_one_model_spec(self) -> None:
generation_node = GenerationNode(
node_name="test",
model_specs=[self.sobol_model_spec, self.sobol_model_spec],
)
# Base generation node can only handle one model spec at the moment
# (this might change in the future), so it should raise a
# `NotImplementedError` if we attempt to generate from a generation node
# that has more than one model spec. Note that the check is done in `gen`
# and not in the constructor to make `GenerationNode` more convenient to
# subclass.
with self.assertRaises(NotImplementedError):
generation_node.gen()

@fast_botorch_optimize
def test_properties(self) -> None:
node = GenerationNode(
Expand Down Expand Up @@ -227,16 +243,6 @@ def test_init(self) -> None:
)
self.assertEqual(named_generation_step.model_name, "Custom Sobol")

with patch.object(
ModelSpec, "model_key", new=PropertyMock(side_effect=TypeError)
):
unknown_generation_step = GenerationStep(
model=Models.SOBOL,
num_trials=5,
model_kwargs=self.model_kwargs,
)
self.assertEqual(unknown_generation_step.model_name, "Unknown ModelSpec")

def test_min_trials_observed(self) -> None:
with self.assertRaisesRegex(UserInputError, "min_trials_observed > num_trials"):
GenerationStep(
Expand Down