Skip to content

Commit

Permalink
ExternalGenerationNode (#2266)
Browse files Browse the repository at this point in the history
Summary:

Implements an `ExternalGenerationNode` class, which defers to arbitrary methods that are not based on Ax / BoTorch for candidate generation. This class mostly retains the signature of `GenerationNode` for compatibility with `GenerationStrategy`, and requires the user to implement only the necessary bits. Since it is compatible with `GenerationStrategy`, it can be easily combined with other `GenerationNode`s, e.g., with a Sobol node to share an initialization with other generation strategies. This makes it particularly suitable for benchmarking against other methods from Ax.

The user needs to implement `__init__`, `update_model_state` and `get_next_trial_parameters`.

Differential Revision: D54500745
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Mar 13, 2024
1 parent c695be6 commit 959fc76
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 20 deletions.
215 changes: 215 additions & 0 deletions ax/modelbridge/external_generation_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
from abc import ABC, abstractmethod
from logging import Logger
from typing import Any, Dict, List, Optional, Sequence

from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import TParameterization
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.transition_criterion import TransitionCriterion
from ax.utils.common.logger import get_logger

logger: Logger = get_logger(__name__)


# TODO[drfreund]: Introduce a `GenerationNodeInterface` to
# make inheritance/overriding of `GenNode` methods cleaner.
class ExternalGenerationNode(GenerationNode, ABC):
    """A generation node intended to be used with non-Ax methods for
    candidate generation.

    To leverage external methods for candidate generation, the user must
    create a subclass that implements ``update_model_state`` and
    ``get_next_trial_parameters`` methods. This can then be provided
    as a node into a ``GenerationStrategy``, either as standalone or as
    part of a larger generation strategy with other generation nodes,
    e.g., with a Sobol node for initialization.

    Example:
        >>> class MyExternalGenerationNode(ExternalGenerationNode):
        >>>     ...
        >>> generation_strategy = GenerationStrategy(
        >>>     nodes = [MyExternalGenerationNode(...)]
        >>> )
        >>> ax_client = AxClient(generation_strategy=generation_strategy)
        >>> ax_client.create_experiment(...)
        >>> ax_client.get_next_trial()  # Generates trials using the new node.
    """

    def __init__(
        self,
        node_name: str,
        should_deduplicate: bool = True,
        transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
        gen_unlimited_trials: bool = True,
    ) -> None:
        """Initialize an external generation node.

        NOTE: The runtime accounting in this method should be replicated by the
        subclasses. This will ensure accurate comparison of runtimes between
        methods, in case a non-negligible amount of compute is spent in the
        constructor.

        Args:
            node_name: Name of the generation node.
            should_deduplicate: Whether to deduplicate the generated points against
                the existing trials on the experiment. If True, the duplicate points
                will be discarded and re-generated up to 5 times, after which a
                `GenerationStrategyRepeatedPoints` exception will be raised.
            transition_criteria: Criteria for determining whether to move to the next
                node in the generation strategy. This is an advanced option that is
                only relevant if the generation strategy consists of multiple nodes.
            gen_unlimited_trials: Whether to generate unlimited trials from this node.
                This should only be False if the generation strategy will transition to
                another node after generating a limited number of trials from this node.
        """
        t_init_start = time.monotonic()
        super().__init__(
            node_name=node_name,
            model_specs=[],
            best_model_selector=None,
            should_deduplicate=should_deduplicate,
            transition_criteria=transition_criteria,
            gen_unlimited_trials=gen_unlimited_trials,
        )
        # Time spent constructing / fitting since the last call to `_gen`.
        # It is reported as `fit_time` on the next generator run and then reset.
        self.fit_time_since_gen: float = time.monotonic() - t_init_start

    @abstractmethod
    def update_model_state(self, experiment: Experiment, data: Data) -> None:
        """A method used to update the state of any models / predictors used by the
        generation node.

        Args:
            experiment: The ``Experiment`` object representing the current state of the
                experiment. The key properties include ``trials``, ``search_space``,
                and ``optimization_config``. The data is provided as a separate arg.
            data: The data / metrics collected on the experiment so far.
        """

    @abstractmethod
    def get_next_trial_parameters(
        self, pending_parameters: List[TParameterization]
    ) -> TParameterization:
        """Get the parameters for the next trial.

        Args:
            pending_parameters: A flat list with the parameters of the trials
                pending evaluation, one parameterization per entry.

        Returns:
            A dictionary mapping parameter names to parameter values for the next
            candidate suggested by the method.
        """

    @property
    def model_enum(self) -> Optional[str]:
        """Always ``None``; this node is not backed by a model enum entry."""
        return None

    @property
    def _fitted_model(self) -> None:
        """Always ``None``; this node does not hold a fitted Ax model."""
        return None

    @property
    def model_spec_to_gen_from(self) -> ModelSpec:
        """Unsupported for this node.

        Raises:
            NotImplementedError: Always, since this node does not use ``ModelSpec``.
        """
        raise NotImplementedError(
            "`ExternalGenerationNode` does not utilize `ModelSpec`s "
            "and does not define methods related to `ModelSpec`."
        )

    def fit(
        self,
        experiment: Experiment,
        data: Data,
        search_space: Optional[SearchSpace] = None,
        optimization_config: Optional[OptimizationConfig] = None,
        **kwargs: Any,
    ) -> None:
        """A method used to initialize or update the experiment state / data
        on any surrogate models or predictors used during candidate generation.

        This method records the time spent during the update and defers to
        `update_model_state` for the actual work.

        Args:
            experiment: The experiment to fit the surrogate model / predictor to.
            data: The experiment data used to fit the model.
            search_space: UNSUPPORTED. An optional override for the experiment
                search space.
            optimization_config: UNSUPPORTED. An optional override for the experiment
                optimization config.
            kwargs: UNSUPPORTED. Additional keyword arguments for model fitting.

        Raises:
            UnsupportedError: If any of the unsupported arguments is provided.
        """
        if search_space is not None or optimization_config is not None or kwargs:
            raise UnsupportedError(
                "Unexpected arguments encountered. `ExternalGenerationNode.fit` only "
                "supports `experiment` and `data` arguments. "
                "Each of the following arguments should be None / empty. "
                f"{search_space=}, {optimization_config=}, {kwargs=}."
            )
        t_fit_start = time.monotonic()
        self.update_model_state(
            experiment=experiment,
            data=data,
        )
        # Accumulate rather than overwrite: `fit` may run several times between
        # two calls to `_gen`, and the constructor time is counted as well.
        self.fit_time_since_gen += time.monotonic() - t_fit_start

    def _gen(
        self,
        n: Optional[int] = None,
        pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
        **model_gen_kwargs: Any,
    ) -> GeneratorRun:
        """Generate new candidates for evaluation.

        This method calls `get_next_trial_parameters` once per requested arm to
        get the parameters for the next trial(s), and packages them as a
        ``GeneratorRun`` for the higher level Ax APIs.

        Args:
            n: Optional integer representing how many arms should be in the generator
                run produced by this method. Defaults to 1.
            pending_observations: A map from metric name to pending
                observations for that metric, used to avoid re-suggesting points
                that are currently being evaluated.
            model_gen_kwargs: UNUSED. Accepted only for signature compatibility;
                this method does not forward these anywhere.

        Returns:
            A ``GeneratorRun`` containing the newly generated candidates.
        """
        t_gen_start = time.monotonic()
        n = 1 if n is None else n
        # Flatten the per-metric pending observations into a single list of
        # parameterizations. The mapping is keyed by metric name, so the same
        # pending point can be listed under several metrics; keep each distinct
        # parameterization only once.
        pending_parameters: List[TParameterization] = []
        for observations in (pending_observations or {}).values():
            for obs_feat in observations:
                if obs_feat.parameters not in pending_parameters:
                    pending_parameters.append(obs_feat.parameters)
        generated_params: List[TParameterization] = []
        for _ in range(n):
            params = self.get_next_trial_parameters(
                pending_parameters=pending_parameters
            )
            generated_params.append(params)
            # Treat the freshly generated point as pending for the remaining
            # arms of this batch.
            pending_parameters.append(params)
        # Return the parameterizations as a generator run.
        generator_run = GeneratorRun(
            arms=[Arm(parameters=params) for params in generated_params],
            fit_time=self.fit_time_since_gen,
            gen_time=time.monotonic() - t_gen_start,
            model_key=self.node_name,
        )
        # TODO: This shares the same bug as ModelBridge.gen. In both cases, after
        # deduplication, the generator run will record fit_time as 0.
        self.fit_time_since_gen = 0.0
        return generator_run
66 changes: 48 additions & 18 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def fit(
**kwargs,
)

# TODO [drfreund]: Move this up to `GenerationNodeInterface` once implemented.
def gen(
self,
n: Optional[int] = None,
Expand All @@ -283,11 +284,8 @@ def gen(
arms_by_signature_for_deduplication: Optional[Dict[str, Arm]] = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
"""Picks a fitted model, from which to generate candidates (via
``self._pick_fitted_model_to_gen_from``) and generates candidates
from it. Uses the ``model_gen_kwargs`` set on the selected ``ModelSpec``
alongside any kwargs passed in to this function (with local kwargs)
taking precedent.
"""This method generates candidates using `self._gen` and handles deduplication
of generated candidates if `self.should_deduplicate=True`.
NOTE: Models must have been fit prior to calling ``gen``.
NOTE: Some underlying models may ignore the ``n`` argument and produce a
Expand All @@ -305,34 +303,25 @@ def gen(
new candidates without duplicates. If non-duplicate candidates are not
generated with these attempts, a ``GenerationStrategyRepeatedPoints``
exception will be raised.
arms_by_signature_for_deduplication: A dictionary mapping arm signatures to
the arms, to be used for deduplicating newly generated arms.
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
these override any pre-specified in ``ModelSpec.model_gen_kwargs``.
Returns:
A ``GeneratorRun`` containing the newly generated candidates.
"""
model_spec = self.model_spec_to_gen_from
should_generate_run = True
generator_run = None
n_gen_draws = 0
# Keep generating until each of `generator_run.arms` is not a duplicate
# of a previous arm, if `should_deduplicate is True`
while should_generate_run:
generator_run = model_spec.gen(
# If `n` is not specified, ensure that the `None` value does not
# override the one set in `model_spec.model_gen_kwargs`.
n=(
model_spec.model_gen_kwargs.get("n")
if n is None and model_spec.model_gen_kwargs
else n
),
# For `pending_observations`, prefer the input to this function, as
# `pending_observations` are dynamic throughout the experiment and thus
# unlikely to be specified in `model_spec.model_gen_kwargs`.
generator_run = self._gen(
n=n,
pending_observations=pending_observations,
**model_gen_kwargs,
)

should_generate_run = (
self.should_deduplicate
and arms_by_signature_for_deduplication
Expand Down Expand Up @@ -360,6 +349,47 @@ def gen(
generator_run._generation_node_name = self.node_name
return generator_run

def _gen(
    self,
    n: Optional[int] = None,
    pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
    **model_gen_kwargs: Any,
) -> GeneratorRun:
    """Generate candidates from the ``ModelSpec`` selected for generation.

    The spec is obtained from ``self.model_spec_to_gen_from`` and its ``gen``
    method produces the candidates. Keyword arguments passed to this method
    are forwarded to ``ModelSpec.gen`` and take precedence over the ones
    pre-specified in ``ModelSpec.model_gen_kwargs``.

    Args:
        n: Optional integer representing how many arms should be in the generator
            run produced by this method. When this is ``None``, ``n`` will be
            determined by the ``ModelSpec`` that we are generating from.
        pending_observations: A map from metric name to pending
            observations for that metric, used by some models to avoid
            resuggesting points that are currently being evaluated.
        model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
            these override any pre-specified in ``ModelSpec.model_gen_kwargs``.

    Returns:
        A ``GeneratorRun`` containing the newly generated candidates.
    """
    spec = self.model_spec_to_gen_from
    # An unspecified `n` must not clobber the value configured on the spec,
    # so fall back to the spec's own `n` (if it has gen kwargs at all).
    num_arms = n
    if n is None and spec.model_gen_kwargs:
        num_arms = spec.model_gen_kwargs.get("n")
    # `pending_observations` always comes from the caller: it changes
    # throughout the experiment and is thus unlikely to be pre-specified
    # in `spec.model_gen_kwargs`.
    return spec.gen(
        n=num_arms,
        pending_observations=pending_observations,
        **model_gen_kwargs,
    )

# ------------------------- Model selection logic helpers. -------------------------

def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
Expand Down
1 change: 0 additions & 1 deletion ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
)

from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.model_spec import FactoryFunctionModelSpec
Expand Down
Loading

0 comments on commit 959fc76

Please sign in to comment.