-
Notifications
You must be signed in to change notification settings - Fork 321
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Implements a `ExternalGenerationNode` class, which defers to arbitrary non-Ax / BoTorch based methods for candidate generation. This class mostly retains the signature of the `GenerationNode` for compatibility with `GenerationStrategy`, and requires the user to implement only the necessary bits. Since it is compatible with `GenerationStrategy`, it can be easily combined with other `GenerationNode`, e.g., with Sobol to use a shared initialization with other generation strategies. This makes it particularly suitable for benchmarking against other methods from Ax. The user needs to implement `__init__`, `update_model_state` and `get_next_trial_parameters`. Differential Revision: D54500745
- Loading branch information
1 parent
c695be6
commit 959fc76
Showing
6 changed files
with
352 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import time | ||
from abc import ABC, abstractmethod | ||
from logging import Logger | ||
from typing import Any, Dict, List, Optional, Sequence | ||
|
||
from ax.core.arm import Arm | ||
from ax.core.data import Data | ||
from ax.core.experiment import Experiment | ||
from ax.core.generator_run import GeneratorRun | ||
from ax.core.observation import ObservationFeatures | ||
from ax.core.optimization_config import OptimizationConfig | ||
from ax.core.search_space import SearchSpace | ||
from ax.core.types import TParameterization | ||
from ax.exceptions.core import UnsupportedError | ||
from ax.modelbridge.generation_node import GenerationNode | ||
from ax.modelbridge.model_spec import ModelSpec | ||
from ax.modelbridge.transition_criterion import TransitionCriterion | ||
from ax.utils.common.logger import get_logger | ||
|
||
logger: Logger = get_logger(__name__) | ||
|
||
|
||
# TODO[drfreund]: Introduce a `GenerationNodeInterface` to | ||
# make inheritance/overriding of `GenNode` methods cleaner. | ||
class ExternalGenerationNode(GenerationNode, ABC): | ||
"""A generation node intended to be used with non-Ax methods for | ||
candidate generation. | ||
To leverage external methods for candidate generation, the user must | ||
create a subclass that implements ``update_model_state`` and | ||
``get_next_trial_parameters`` methods. This can then be provided | ||
as a node into a ``GenerationStrategy``, either as standalone or as | ||
part of a larger generation strategy with other generation nodes, | ||
e.g., with a Sobol node for initialization. | ||
Example: | ||
>>> class MyExternalGenerationNode(ExternalGenerationNode): | ||
>>> ... | ||
>>> generation_strategy = GenerationStrategy( | ||
>>> nodes = [MyExternalGenerationNode(...)] | ||
>>> ) | ||
>>> ax_client = AxClient(generation_strategy=generation_strategy) | ||
>>> ax_client.create_experiment(...) | ||
>>> ax_client.get_next_trial() # Generates trials using the new generation node. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
node_name: str, | ||
should_deduplicate: bool = True, | ||
transition_criteria: Optional[Sequence[TransitionCriterion]] = None, | ||
gen_unlimited_trials: bool = True, | ||
) -> None: | ||
"""Initialize an external generation node. | ||
NOTE: The runtime accounting in this method should be replicated by the | ||
subclasses. This will ensure accurate comparison of runtimes between | ||
methods, in case a non-significant compute is spent in the constructor. | ||
Args: | ||
node_name: Name of the generation node. | ||
should_deduplicate: Whether to deduplicate the generated points against | ||
the existing trials on the experiment. If True, the duplicate points | ||
will be discarded and re-generated up to 5 times, after which a | ||
`GenerationStrategyRepeatedPoints` exception will be raised. | ||
transition_criteria: Criteria for determining whether to move to the next | ||
node in the generation strategy. This is an advanced option that is | ||
only relevant if the generation strategy consists of multiple nodes. | ||
gen_unlimited_trials: Whether to generate unlimited trials from this node. | ||
This should only be False if the generation strategy will transition to | ||
another node after generating a limited number of trials from this node. | ||
""" | ||
t_init_start = time.monotonic() | ||
super().__init__( | ||
node_name=node_name, | ||
model_specs=[], | ||
best_model_selector=None, | ||
should_deduplicate=should_deduplicate, | ||
transition_criteria=transition_criteria, | ||
gen_unlimited_trials=gen_unlimited_trials, | ||
) | ||
self.fit_time_since_gen: float = time.monotonic() - t_init_start | ||
|
||
@abstractmethod | ||
def update_model_state(self, experiment: Experiment, data: Data) -> None: | ||
"""A method used to update the state of any models / predictors used by the | ||
generation node. | ||
Args: | ||
experiment: The ``Experiment`` object representing the current state of the | ||
experiment. The key properties includes ``trials``, ``search_space``, | ||
and ``optimization_config``. The data is provided as a separate arg. | ||
data: The data / metrics collected on the experiment so far. | ||
""" | ||
|
||
@abstractmethod | ||
def get_next_trial_parameters( | ||
self, pending_parameters: List[TParameterization] | ||
) -> TParameterization: | ||
"""Get the parameters for the next trial. | ||
Args: | ||
pending_parameters: A list of parameters of the trials pending evaluation. | ||
Returns: | ||
A dictionary mapping parameter names to parameter values for the next | ||
candidate suggested by the method. | ||
""" | ||
|
||
@property | ||
def model_enum(self) -> Optional[str]: | ||
return None | ||
|
||
@property | ||
def _fitted_model(self) -> None: | ||
return None | ||
|
||
@property | ||
def model_spec_to_gen_from(self) -> ModelSpec: | ||
raise NotImplementedError( | ||
"`ExternalGenerationNode` does not utilize `ModelSpec`s " | ||
"and does not define methods related to `ModelSpec`." | ||
) | ||
|
||
def fit( | ||
self, | ||
experiment: Experiment, | ||
data: Data, | ||
search_space: Optional[SearchSpace] = None, | ||
optimization_config: Optional[OptimizationConfig] = None, | ||
**kwargs: Any, | ||
) -> None: | ||
"""A method used to initialize or update the experiment state / data | ||
on any surrogate models or predictors used during candidate generation. | ||
This method records the time spent during the update and defers to | ||
`update_model_state` for the actual work. | ||
Args: | ||
experiment: The experiment to fit the surrogate model / predictor to. | ||
data: The experiment data used to fit the model. | ||
search_space: UNSUPPORTED. An optional override for the experiment | ||
search space. | ||
optimization_config: UNSUPPORTED. An optional override for the experiment | ||
optimization config. | ||
kwargs: UNSUPPORTED. Additional keyword arguments for model fitting. | ||
""" | ||
if search_space is not None or optimization_config is not None or kwargs: | ||
raise UnsupportedError( | ||
"Unexpected arguments encountered. `ExternalGenerationNode.fit` only " | ||
"supports `experiment` and `data` arguments. " | ||
"Each of the following arguments should be None / empty. " | ||
f"{search_space=}, {optimization_config=}, {kwargs=}." | ||
) | ||
t_fit_start = time.monotonic() | ||
self.update_model_state( | ||
experiment=experiment, | ||
data=data, | ||
) | ||
self.fit_time_since_gen += time.monotonic() - t_fit_start | ||
|
||
def _gen( | ||
self, | ||
n: Optional[int] = None, | ||
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None, | ||
**model_gen_kwargs: Any, | ||
) -> GeneratorRun: | ||
"""Generate new candidates for evaluation. | ||
This method calls `get_next_trial_parameterizations` to get the parameters | ||
for the next trial(s), and packages it as needed for higher level Ax APIs. | ||
If `should_deduplicate=True`, this also checks for duplicates and re-generates | ||
the parameters as needed. | ||
Args: | ||
n: Optional integer representing how many arms should be in the generator | ||
run produced by this method. Defaults to 1. | ||
pending_observations: A map from metric name to pending | ||
observations for that metric, used by some models to avoid | ||
re-suggesting points that are currently being evaluated. | ||
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``; | ||
these override any pre-specified in ``ModelSpec.model_gen_kwargs``. | ||
Returns: | ||
A ``GeneratorRun`` containing the newly generated candidates. | ||
""" | ||
t_gen_start = time.monotonic() | ||
n = 1 if n is None else n | ||
pending_parameters: List[TParameterization] = [ | ||
[o.parameters for obs in (pending_observations or {}).values() for o in obs] | ||
] | ||
generated_params: List[TParameterization] = [] | ||
for _ in range(n): | ||
params = self.get_next_trial_parameters( | ||
pending_parameters=pending_parameters | ||
) | ||
generated_params.append(params) | ||
pending_parameters.append(params) | ||
# Return the parameterizations as a generator run. | ||
generator_run = GeneratorRun( | ||
arms=[Arm(parameters=params) for params in generated_params], | ||
fit_time=self.fit_time_since_gen, | ||
gen_time=time.monotonic() - t_gen_start, | ||
model_key=self.node_name, | ||
) | ||
# TODO: This shares the same bug as ModelBridge.gen. In both cases, after | ||
# deduplication, the generator run will record fit_time as 0. | ||
self.fit_time_since_gen = 0 | ||
return generator_run |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.