Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExternalGenerationNode #2266

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions ax/modelbridge/external_generation_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
from abc import ABC, abstractmethod
from logging import Logger
from typing import Any, Dict, List, Optional, Sequence

from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import TParameterization
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.transition_criterion import TransitionCriterion
from ax.utils.common.logger import get_logger

logger: Logger = get_logger(__name__)


# TODO[drfreund]: Introduce a `GenerationNodeInterface` to
# make inheritance/overriding of `GenNode` methods cleaner.
class ExternalGenerationNode(GenerationNode, ABC):
"""A generation node intended to be used with non-Ax methods for
candidate generation.

To leverage external methods for candidate generation, the user must
create a subclass that implements ``update_generator_state`` and
``get_next_candidate`` methods. This can then be provided
as a node into a ``GenerationStrategy``, either as standalone or as
part of a larger generation strategy with other generation nodes,
e.g., with a Sobol node for initialization.

Example:
>>> class MyExternalGenerationNode(ExternalGenerationNode):
>>> ...
>>> generation_strategy = GenerationStrategy(
>>> nodes = [MyExternalGenerationNode(...)]
>>> )
>>> ax_client = AxClient(generation_strategy=generation_strategy)
>>> ax_client.create_experiment(...)
>>> ax_client.get_next_trial() # Generates trials using the new generation node.
"""

def __init__(
self,
node_name: str,
should_deduplicate: bool = True,
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
gen_unlimited_trials: bool = True,
) -> None:
"""Initialize an external generation node.

NOTE: The runtime accounting in this method should be replicated by the
subclasses. This will ensure accurate comparison of runtimes between
methods, in case a non-significant compute is spent in the constructor.

Args:
node_name: Name of the generation node.
should_deduplicate: Whether to deduplicate the generated points against
the existing trials on the experiment. If True, the duplicate points
will be discarded and re-generated up to 5 times, after which a
`GenerationStrategyRepeatedPoints` exception will be raised.
NOTE: For this to work, the generator must be able to produce a
different parameterization when called again with the same state.
transition_criteria: Criteria for determining whether to move to the next
node in the generation strategy. This is an advanced option that is
only relevant if the generation strategy consists of multiple nodes.
gen_unlimited_trials: Whether to generate unlimited trials from this node.
This should only be False if the generation strategy will transition to
another node after generating a limited number of trials from this node.
"""
t_init_start = time.monotonic()
super().__init__(
node_name=node_name,
model_specs=[],
best_model_selector=None,
should_deduplicate=should_deduplicate,
transition_criteria=transition_criteria,
gen_unlimited_trials=gen_unlimited_trials,
)
self.fit_time_since_gen: float = time.monotonic() - t_init_start

@abstractmethod
def update_generator_state(self, experiment: Experiment, data: Data) -> None:
"""A method used to update the state of the generator. This includes any
models, predictors or any other custom state used by the generation node.
This method will be called with the up-to-date experiment and data before
``get_next_candidate`` is called to generate the next trial(s). Note
that ``get_next_candidate`` may be called multiple times (to generate
multiple candidates) after a call to ``update_generator_state``.

Args:
experiment: The ``Experiment`` object representing the current state of the
experiment. The key properties includes ``trials``, ``search_space``,
and ``optimization_config``. The data is provided as a separate arg.
data: The data / metrics collected on the experiment so far.
"""

@abstractmethod
def get_next_candidate(
self, pending_parameters: List[TParameterization]
) -> TParameterization:
"""Get the parameters for the next candidate configuration to evaluate.

Args:
pending_parameters: A list of parameters of the candidates pending
evaluation. This is often used to avoid generating duplicate candidates.

Returns:
A dictionary mapping parameter names to parameter values for the next
candidate suggested by the method.
"""

@property
def model_enum(self) -> Optional[str]:
return None

@property
def _fitted_model(self) -> None:
return None

@property
def model_spec_to_gen_from(self) -> ModelSpec:
raise NotImplementedError(
"`ExternalGenerationNode` does not utilize `ModelSpec`s "
"and does not define methods related to `ModelSpec`."
)

def fit(
self,
experiment: Experiment,
data: Data,
search_space: Optional[SearchSpace] = None,
optimization_config: Optional[OptimizationConfig] = None,
**kwargs: Any,
) -> None:
"""A method used to initialize or update the experiment state / data
on any surrogate models or predictors used during candidate generation.

This method records the time spent during the update and defers to
`update_generator_state` for the actual work.

Args:
experiment: The experiment to fit the surrogate model / predictor to.
data: The experiment data used to fit the model.
search_space: UNSUPPORTED. An optional override for the experiment
search space.
optimization_config: UNSUPPORTED. An optional override for the experiment
optimization config.
kwargs: UNSUPPORTED. Additional keyword arguments for model fitting.
"""
if search_space is not None or optimization_config is not None or kwargs:
raise UnsupportedError(
"Unexpected arguments encountered. `ExternalGenerationNode.fit` only "
"supports `experiment` and `data` arguments. "
"Each of the following arguments should be None / empty. "
f"{search_space=}, {optimization_config=}, {kwargs=}."
)
t_fit_start = time.monotonic()
self.update_generator_state(
experiment=experiment,
data=data,
)
self.fit_time_since_gen += time.monotonic() - t_fit_start

def _gen(
self,
n: Optional[int] = None,
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
"""Generate new candidates for evaluation.

This method calls `get_next_trial_parameterizations` to get the parameters
for the next trial(s), and packages it as needed for higher level Ax APIs.
If `should_deduplicate=True`, this also checks for duplicates and re-generates
the parameters as needed.

Args:
n: Optional integer representing how many arms should be in the generator
run produced by this method. Defaults to 1.
pending_observations: A map from metric name to pending
observations for that metric, used by some methods to avoid
re-suggesting candidates that are currently being evaluated.
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
these override any pre-specified in ``ModelSpec.model_gen_kwargs``.

Returns:
A ``GeneratorRun`` containing the newly generated candidates.
"""
t_gen_start = time.monotonic()
n = 1 if n is None else n
pending_parameters: List[TParameterization] = [
[o.parameters for obs in (pending_observations or {}).values() for o in obs]
]
generated_params: List[TParameterization] = []
for _ in range(n):
params = self.get_next_candidate(pending_parameters=pending_parameters)
generated_params.append(params)
pending_parameters.append(params)
# Return the parameterizations as a generator run.
generator_run = GeneratorRun(
arms=[Arm(parameters=params) for params in generated_params],
fit_time=self.fit_time_since_gen,
gen_time=time.monotonic() - t_gen_start,
model_key=self.node_name,
)
# TODO: This shares the same bug as ModelBridge.gen. In both cases, after
# deduplication, the generator run will record fit_time as 0.
self.fit_time_since_gen = 0
return generator_run
64 changes: 46 additions & 18 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def fit(
**kwargs,
)

# TODO [drfreund]: Move this up to `GenerationNodeInterface` once implemented.
def gen(
self,
n: Optional[int] = None,
Expand All @@ -283,11 +284,8 @@ def gen(
arms_by_signature_for_deduplication: Optional[Dict[str, Arm]] = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
"""Picks a fitted model, from which to generate candidates (via
``self._pick_fitted_model_to_gen_from``) and generates candidates
from it. Uses the ``model_gen_kwargs`` set on the selected ``ModelSpec``
alongside any kwargs passed in to this function (with local kwargs)
taking precedent.
"""This method generates candidates using `self._gen` and handles deduplication
of generated candidates if `self.should_deduplicate=True`.

NOTE: Models must have been fit prior to calling ``gen``.
NOTE: Some underlying models may ignore the ``n`` argument and produce a
Expand All @@ -305,34 +303,25 @@ def gen(
new candidates without duplicates. If non-duplicate candidates are not
generated with these attempts, a ``GenerationStrategyRepeatedPoints``
exception will be raised.
arms_by_signature_for_deduplication: A dictionary mapping arm signatures to
the arms, to be used for deduplicating newly generated arms.
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
these override any pre-specified in ``ModelSpec.model_gen_kwargs``.

Returns:
A ``GeneratorRun`` containing the newly generated candidates.
"""
model_spec = self.model_spec_to_gen_from
should_generate_run = True
generator_run = None
n_gen_draws = 0
# Keep generating until each of `generator_run.arms` is not a duplicate
# of a previous arm, if `should_deduplicate is True`
while should_generate_run:
generator_run = model_spec.gen(
# If `n` is not specified, ensure that the `None` value does not
# override the one set in `model_spec.model_gen_kwargs`.
n=(
model_spec.model_gen_kwargs.get("n")
if n is None and model_spec.model_gen_kwargs
else n
),
# For `pending_observations`, prefer the input to this function, as
# `pending_observations` are dynamic throughout the experiment and thus
# unlikely to be specified in `model_spec.model_gen_kwargs`.
generator_run = self._gen(
n=n,
pending_observations=pending_observations,
**model_gen_kwargs,
)

should_generate_run = (
self.should_deduplicate
and arms_by_signature_for_deduplication
Expand Down Expand Up @@ -360,6 +349,45 @@ def gen(
generator_run._generation_node_name = self.node_name
return generator_run

def _gen(
self,
n: Optional[int] = None,
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
"""Picks a fitted model, from which to generate candidates (via
``self._pick_fitted_model_to_gen_from``) and generates candidates
from it. Uses the ``model_gen_kwargs`` set on the selected ``ModelSpec``
alongside any kwargs passed in to this function (with local kwargs)
taking precedent.

Args:
n: Optional integer representing how many arms should be in the generator
run produced by this method. When this is ``None``, ``n`` will be
determined by the ``ModelSpec`` that we are generating from.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
these override any pre-specified in ``ModelSpec.model_gen_kwargs``.

Returns:
A ``GeneratorRun`` containing the newly generated candidates.
"""
model_spec = self.model_spec_to_gen_from
if n is None and model_spec.model_gen_kwargs:
# If `n` is not specified, ensure that the `None` value does not
# override the one set in `model_spec.model_gen_kwargs`.
n = model_spec.model_gen_kwargs.get("n", None)
return model_spec.gen(
n=n,
# For `pending_observations`, prefer the input to this function, as
# `pending_observations` are dynamic throughout the experiment and thus
# unlikely to be specified in `model_spec.model_gen_kwargs`.
pending_observations=pending_observations,
**model_gen_kwargs,
)

# ------------------------- Model selection logic helpers. -------------------------

def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
Expand Down
1 change: 0 additions & 1 deletion ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
)

from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.model_spec import FactoryFunctionModelSpec
Expand Down
Loading
Loading