Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option to override ModelSpec.model_key #2726

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,8 @@ def __post_init__(self) -> None:
)
model_spec = FactoryFunctionModelSpec(
factory_function=self.model,
# Only pass down the model name if it is not empty.
model_key_override=self.model_name if self.model_name else None,
model_kwargs=self.model_kwargs,
model_gen_kwargs=self.model_gen_kwargs,
)
Expand Down
38 changes: 20 additions & 18 deletions ax/modelbridge/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.exceptions.core import AxWarning, UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.cross_validation import (
compute_diagnostics,
Expand Down Expand Up @@ -61,6 +61,9 @@ class ModelSpec(SortableBase, SerializationMixin):
model_gen_kwargs: dict[str, Any] = field(default_factory=dict)
# Kwargs to pass to `cross_validate`.
model_cv_kwargs: dict[str, Any] = field(default_factory=dict)
# An optional override for the model key. Each `ModelSpec` in a
# `GenerationNode` must have a unique key to ensure identifiability.
model_key_override: Optional[str] = None

# Fitted model, constructed using specified `model_kwargs` and `Data`
# on `ModelSpec.fit`
Expand Down Expand Up @@ -106,12 +109,10 @@ def fixed_features(self, value: Optional[ObservationFeatures]) -> None:
@property
def model_key(self) -> str:
"""Key string to identify the model used by this ``ModelSpec``."""
# NOTE: In the future, might need to add more to model key to make
# model specs with the same model (but different kwargs) easier to
# distinguish from their key. Could also add separate property, just
# `key` (for `ModelSpec.key`, which will be unique even between model
# specs with same model type).
return self.model_enum.value
if self.model_key_override is not None:
return self.model_key_override
else:
return self.model_enum.value

def fit(
self,
Expand Down Expand Up @@ -342,23 +343,24 @@ def __post_init__(self) -> None:
"as the required `factory_function` argument to "
"`FactoryFunctionModelSpec`."
)
if self.model_key_override is None:
try:
# `model` is defined via a factory function.
# pyre-ignore[16]: Anonymous callable has no attribute `__name__`.
self.model_key_override = not_none(self.factory_function).__name__
except Exception:
raise TypeError(
f"{self.factory_function} is not a valid function, cannot extract "
"name. Please provide the model name using `model_key_override`."
)

warnings.warn(
"Using a factory function to describe the model, so optimization state "
"cannot be stored and optimization is not resumable if interrupted.",
AxWarning,
stacklevel=3,
)

@property
def model_key(self) -> str:
"""Key string to identify the model used by this ``ModelSpec``."""
try:
# `model` is defined via a factory function.
return not_none(self.factory_function).__name__ # pyre-ignore[16]
except Exception:
raise TypeError(
f"{self.factory_function} is not a valid function, cannot extract name."
)

def fit(
self,
experiment: Experiment,
Expand Down
11 changes: 11 additions & 0 deletions ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ def test_init_factory_function(self) -> None:
generation_step.model_specs,
[FactoryFunctionModelSpec(factory_function=get_sobol)],
)
generation_step = GenerationStep(
model=get_sobol, num_trials=-1, model_name="test"
)
self.assertEqual(
generation_step.model_specs,
[
FactoryFunctionModelSpec(
factory_function=get_sobol, model_key_override="test"
)
],
)

def test_properties(self) -> None:
self.assertEqual(self.sobol_generation_step.model_spec, self.model_spec)
Expand Down
15 changes: 13 additions & 2 deletions ax/modelbridge/tests/test_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ def test_fit(self, wrapped_extract_ssd: Mock) -> None:
wrapped_extract_ssd.assert_called_once()

def test_model_key(self) -> None:
ms = ModelSpec(model_enum=Models.GPEI)
self.assertEqual(ms.model_key, "GPEI")
ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR)
self.assertEqual(ms.model_key, "BoTorch")
ms = ModelSpec(
model_enum=Models.BOTORCH_MODULAR, model_key_override="MBM with defaults"
)
self.assertEqual(ms.model_key, "MBM with defaults")

@patch(f"{ModelSpec.__module__}.compute_diagnostics")
@patch(f"{ModelSpec.__module__}.cross_validate", return_value=["fake-cv-result"])
Expand Down Expand Up @@ -194,3 +198,10 @@ def test_construct(self) -> None:
def test_model_key(self) -> None:
ms = FactoryFunctionModelSpec(factory_function=get_sobol)
self.assertEqual(ms.model_key, "get_sobol")
with self.assertRaisesRegex(TypeError, "cannot extract name"):
# pyre-ignore[6] - Invalid factory function for testing.
FactoryFunctionModelSpec(factory_function="test")
ms = FactoryFunctionModelSpec(
factory_function=get_sobol, model_key_override="fancy sobol"
)
self.assertEqual(ms.model_key, "fancy sobol")