Skip to content

Commit

Permalink
Add an option to override ModelSpec.model_key
Browse files Browse the repository at this point in the history
Summary:
Most of our models now utilize MBM, which can represent many different models using the same Models.BOTORCH_MODULAR. Being able to overwrite the default model key "BoTorch" will allow us to be more expressive about what model was used to generate a given candidate.

Follow-up diff will utilize this to require unique model key for each ModelSpec in a GenerationNode, which will ensure identifiability when multiple MBM models are used with model selection.

Differential Revision: D61984169
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Aug 29, 2024
1 parent 5d91216 commit ca3ee3e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
38 changes: 20 additions & 18 deletions ax/modelbridge/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.exceptions.core import AxWarning, UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.cross_validation import (
compute_diagnostics,
Expand Down Expand Up @@ -61,6 +61,9 @@ class ModelSpec(SortableBase, SerializationMixin):
model_gen_kwargs: dict[str, Any] = field(default_factory=dict)
# Kwargs to pass to `cross_validate`.
model_cv_kwargs: dict[str, Any] = field(default_factory=dict)
# An optional override for the model key. Each `ModelSpec` in a
# `GenerationNode` must have a unique key to ensure identifiability.
model_key_override: Optional[str] = None

# Fitted model, constructed using specified `model_kwargs` and `Data`
# on `ModelSpec.fit`
Expand Down Expand Up @@ -106,12 +109,10 @@ def fixed_features(self, value: Optional[ObservationFeatures]) -> None:
@property
def model_key(self) -> str:
"""Key string to identify the model used by this ``ModelSpec``."""
# NOTE: In the future, might need to add more to model key to make
# model specs with the same model (but different kwargs) easier to
# distinguish from their key. Could also add separate property, just
# `key` (for `ModelSpec.key`, which will be unique even between model
# specs with same model type).
return self.model_enum.value
if self.model_key_override is not None:
return self.model_key_override
else:
return self.model_enum.value

def fit(
self,
Expand Down Expand Up @@ -342,23 +343,24 @@ def __post_init__(self) -> None:
"as the required `factory_function` argument to "
"`FactoryFunctionModelSpec`."
)
if self.model_key_override is None:
try:
# `model` is defined via a factory function.
# pyre-ignore[16]: Anonymous callable has no attribute `__name__`.
self.model_key_override = not_none(self.factory_function).__name__
except Exception:
raise TypeError(
f"{self.factory_function} is not a valid function, cannot extract "
"name. Please provide the model name using `model_key_override`."
)

warnings.warn(
"Using a factory function to describe the model, so optimization state "
"cannot be stored and optimization is not resumable if interrupted.",
AxWarning,
stacklevel=3,
)

@property
def model_key(self) -> str:
"""Key string to identify the model used by this ``ModelSpec``."""
try:
# `model` is defined via a factory function.
return not_none(self.factory_function).__name__ # pyre-ignore[16]
except Exception:
raise TypeError(
f"{self.factory_function} is not a valid function, cannot extract name."
)

def fit(
self,
experiment: Experiment,
Expand Down
15 changes: 13 additions & 2 deletions ax/modelbridge/tests/test_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ def test_fit(self, wrapped_extract_ssd: Mock) -> None:
wrapped_extract_ssd.assert_called_once()

def test_model_key(self) -> None:
ms = ModelSpec(model_enum=Models.GPEI)
self.assertEqual(ms.model_key, "GPEI")
ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR)
self.assertEqual(ms.model_key, "BoTorch")
ms = ModelSpec(
model_enum=Models.BOTORCH_MODULAR, model_key_override="MBM with defaults"
)
self.assertEqual(ms.model_key, "MBM with defaults")

@patch(f"{ModelSpec.__module__}.compute_diagnostics")
@patch(f"{ModelSpec.__module__}.cross_validate", return_value=["fake-cv-result"])
Expand Down Expand Up @@ -194,3 +198,10 @@ def test_construct(self) -> None:
def test_model_key(self) -> None:
ms = FactoryFunctionModelSpec(factory_function=get_sobol)
self.assertEqual(ms.model_key, "get_sobol")
with self.assertRaisesRegex(TypeError, "cannot extract name"):
# pyre-ignore[6] - Invalid factory function for testing.
FactoryFunctionModelSpec(factory_function="test")
ms = FactoryFunctionModelSpec(
factory_function=get_sobol, model_key_override="fancy sobol"
)
self.assertEqual(ms.model_key, "fancy sobol")

0 comments on commit ca3ee3e

Please sign in to comment.