diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 20562e7a1e1..62e7d94ad3a 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -666,6 +666,8 @@ def __post_init__(self) -> None: ) model_spec = FactoryFunctionModelSpec( factory_function=self.model, + # Only pass down the model name if it is not empty. + model_key_override=self.model_name if self.model_name else None, model_kwargs=self.model_kwargs, model_gen_kwargs=self.model_gen_kwargs, ) diff --git a/ax/modelbridge/model_spec.py b/ax/modelbridge/model_spec.py index b11195edb29..48b61806f08 100644 --- a/ax/modelbridge/model_spec.py +++ b/ax/modelbridge/model_spec.py @@ -20,7 +20,7 @@ from ax.core.observation import ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace -from ax.exceptions.core import UserInputError +from ax.exceptions.core import AxWarning, UserInputError from ax.modelbridge.base import ModelBridge from ax.modelbridge.cross_validation import ( compute_diagnostics, @@ -61,6 +61,9 @@ class ModelSpec(SortableBase, SerializationMixin): model_gen_kwargs: dict[str, Any] = field(default_factory=dict) # Kwargs to pass to `cross_validate`. model_cv_kwargs: dict[str, Any] = field(default_factory=dict) + # An optional override for the model key. Each `ModelSpec` in a + # `GenerationNode` must have a unique key to ensure identifiability. + model_key_override: Optional[str] = None # Fitted model, constructed using specified `model_kwargs` and `Data` # on `ModelSpec.fit` @@ -106,12 +109,10 @@ def fixed_features(self, value: Optional[ObservationFeatures]) -> None: @property def model_key(self) -> str: """Key string to identify the model used by this ``ModelSpec``.""" - # NOTE: In the future, might need to add more to model key to make - # model specs with the same model (but different kwargs) easier to - # distinguish from their key. Could also add separate property, just - # `key` (for `ModelSpec.key`, which will be unique even between model - # specs with same model type). - return self.model_enum.value + if self.model_key_override is not None: + return self.model_key_override + else: + return self.model_enum.value def fit( self, @@ -342,23 +343,24 @@ def __post_init__(self) -> None: "as the required `factory_function` argument to " "`FactoryFunctionModelSpec`." ) + if self.model_key_override is None: + try: + # `model` is defined via a factory function. + # pyre-ignore[16]: Anonymous callable has no attribute `__name__`. + self.model_key_override = not_none(self.factory_function).__name__ + except Exception: + raise TypeError( + f"{self.factory_function} is not a valid function, cannot extract " + "name. Please provide the model name using `model_key_override`." + ) + warnings.warn( "Using a factory function to describe the model, so optimization state " "cannot be stored and optimization is not resumable if interrupted.", + AxWarning, stacklevel=3, ) - @property - def model_key(self) -> str: - """Key string to identify the model used by this ``ModelSpec``.""" - try: - # `model` is defined via a factory function. - return not_none(self.factory_function).__name__ # pyre-ignore[16] - except Exception: - raise TypeError( - f"{self.factory_function} is not a valid function, cannot extract name." - ) - def fit( self, experiment: Experiment, diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index f122e6d26ca..4e4179d42ba 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -252,6 +252,17 @@ def test_init_factory_function(self) -> None: generation_step.model_specs, [FactoryFunctionModelSpec(factory_function=get_sobol)], ) + generation_step = GenerationStep( + model=get_sobol, num_trials=-1, model_name="test" + ) + self.assertEqual( + generation_step.model_specs, + [ + FactoryFunctionModelSpec( + factory_function=get_sobol, model_key_override="test" + ) + ], + ) def test_properties(self) -> None: self.assertEqual(self.sobol_generation_step.model_spec, self.model_spec) diff --git a/ax/modelbridge/tests/test_model_spec.py b/ax/modelbridge/tests/test_model_spec.py index 76d1d4d6041..ee24532bcf6 100644 --- a/ax/modelbridge/tests/test_model_spec.py +++ b/ax/modelbridge/tests/test_model_spec.py @@ -67,8 +67,12 @@ def test_fit(self, wrapped_extract_ssd: Mock) -> None: wrapped_extract_ssd.assert_called_once() def test_model_key(self) -> None: - ms = ModelSpec(model_enum=Models.GPEI) - self.assertEqual(ms.model_key, "GPEI") + ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR) + self.assertEqual(ms.model_key, "BoTorch") + ms = ModelSpec( + model_enum=Models.BOTORCH_MODULAR, model_key_override="MBM with defaults" + ) + self.assertEqual(ms.model_key, "MBM with defaults") @patch(f"{ModelSpec.__module__}.compute_diagnostics") @patch(f"{ModelSpec.__module__}.cross_validate", return_value=["fake-cv-result"]) @@ -194,3 +198,10 @@ def test_construct(self) -> None: def test_model_key(self) -> None: ms = FactoryFunctionModelSpec(factory_function=get_sobol) self.assertEqual(ms.model_key, "get_sobol") + with self.assertRaisesRegex(TypeError, "cannot extract name"): + # pyre-ignore[6] - Invalid factory function for testing. + FactoryFunctionModelSpec(factory_function="test") + ms = FactoryFunctionModelSpec( + factory_function=get_sobol, model_key_override="fancy sobol" + ) + self.assertEqual(ms.model_key, "fancy sobol")