Skip to content

Commit

Permalink
SacessOptimizer: expose more hyperparameters
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Sep 15, 2024
1 parent 36eec04 commit 5c7aa6e
Showing 1 changed file with 109 additions and 36 deletions.
145 changes: 109 additions & 36 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Self-adaptive cooperative enhanced scatter search (SACESS)."""
from __future__ import annotations

import itertools
import logging
Expand All @@ -11,7 +12,7 @@
from multiprocessing import get_context
from multiprocessing.managers import SyncManager
from pathlib import Path
from typing import Any, Callable, Optional, Union
from typing import Any, Callable
from uuid import uuid1
from warnings import warn

Expand Down Expand Up @@ -62,13 +63,14 @@ class SacessOptimizer:

def __init__(
self,
num_workers: Optional[int] = None,
ess_init_args: Optional[list[dict[str, Any]]] = None,
num_workers: int | None = None,
ess_init_args: list[dict[str, Any]] | None = None,
max_walltime_s: float = np.inf,
sacess_loglevel: int = logging.INFO,
ess_loglevel: int = logging.WARNING,
tmpdir: Union[Path, str] = None,
tmpdir: Path | str = None,
mp_start_method: str = "spawn",
options: SacessOptions = None,
):
"""Construct.
Expand Down Expand Up @@ -110,6 +112,8 @@ def __init__(
mp_start_method:
The start method for the multiprocessing context.
See :mod:`multiprocessing` for details.
options:
Further optimizer hyperparameters.
"""
if (num_workers is None and ess_init_args is None) or (
num_workers is not None and ess_init_args is not None
Expand Down Expand Up @@ -138,10 +142,11 @@ def __init__(
self._tmpdir = Path(f"SacessOptimizerTemp-{str(uuid1())[:8]}")
self._tmpdir = Path(self._tmpdir).absolute()
self._tmpdir.mkdir(parents=True, exist_ok=True)
self.histories: Optional[
list["pypesto.history.memory.MemoryHistory"]
] = None
self.histories: list[
pypesto.history.memory.MemoryHistory
] | None = None
self.mp_ctx = get_context(mp_start_method)
self.options = options or SacessOptions()

def minimize(
self,
Expand Down Expand Up @@ -212,6 +217,7 @@ def minimize(
shmem_manager=shmem_manager,
ess_options=ess_init_args,
dim=problem.dim,
options=self.options,
)
# create workers
workers = [
Expand All @@ -225,6 +231,7 @@ def minimize(
tmp_result_file=SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
),
options=self.options,
)
for worker_idx, ess_kwargs in enumerate(ess_init_args)
]
Expand Down Expand Up @@ -348,19 +355,19 @@ class SacessManager:
adaptation of ``_rejection_threshold``.
_rejection_threshold: Threshold for relative objective improvements that
incoming solutions have to pass to be accepted
_rejection_threshold_min: ``_rejection_threshold`` will be reduced (halved)
if too few solutions are accepted. This value is the lower limit for
``_rejection_threshold``.
_lock: Lock for accessing shared state.
_logger: A logger instance
_options: Further optimizer hyperparameters.
"""

def __init__(
self,
shmem_manager: SyncManager,
ess_options: list[dict[str, Any]],
dim: int,
options: SacessOptions = None,
):
self._options = options or SacessOptions()
self._num_workers = len(ess_options)
self._ess_options = [shmem_manager.dict(o) for o in ess_options]
self._best_known_fx = shmem_manager.Value("d", np.inf)
Expand All @@ -370,8 +377,9 @@ def __init__(
# [PenasGon2017]_ p.9 is 0.1.
# However, their implementation uses 0.1 *percent*. I assume this is a
# mistake in the paper.
self._rejection_threshold = shmem_manager.Value("d", 0.001)
self._rejection_threshold_min = 0.001
self._rejection_threshold = shmem_manager.Value(
"d", self._options.manager_initial_rejection_threshold
)

# scores of the workers, ordered by worker-index
# initial score is the worker index
Expand Down Expand Up @@ -480,7 +488,7 @@ def submit_solution(
if self._rejections.value >= self._num_workers:
self._rejection_threshold.value = min(
self._rejection_threshold.value / 2,
self._rejection_threshold_min,
self._options.manager_minimum_rejection_threshold,
)
self._logger.debug(
"Lowered acceptance threshold to "
Expand All @@ -507,9 +515,6 @@ class SacessWorker:
to the manager.
_ess_kwargs: ESSOptimizer options for this worker (may get updated during
the self-adaptive step).
_acceptance_threshold: Minimum relative improvement of the objective
compared to the best known value to be eligible for submission to the
Manager.
_n_sent_solutions: Number of solutions sent to the Manager.
_max_walltime_s: Walltime limit.
_logger: A Logger instance.
Expand All @@ -527,15 +532,14 @@ def __init__(
loglevel: int = logging.INFO,
ess_loglevel: int = logging.WARNING,
tmp_result_file: str = None,
options: SacessOptions = None,
):
self._manager = manager
self._worker_idx = worker_idx
self._best_known_fx = np.inf
self._n_received_solutions = 0
self._neval = 0
self._ess_kwargs = ess_kwargs
# Default value from original SaCeSS implementation
self._acceptance_threshold = 0.0001
self._n_sent_solutions = 0
self._max_walltime_s = max_walltime_s
self._start_time = None
Expand All @@ -544,6 +548,7 @@ def __init__(
self._logger = None
self._tmp_result_file = tmp_result_file
self._refset = None
self._options = options or SacessOptions()

def run(
self,
Expand Down Expand Up @@ -665,15 +670,16 @@ def _cooperate(self):
self.replace_solution(self._refset, x=recv_x, fx=recv_fx)

def _maybe_adapt(self, problem: Problem):
"""Perform adaptation step.
"""Perform the adaptation step if needed.
Update ESS settings if conditions are met.
"""
# Update ESS settings if we received way more solutions than we sent
# Magic numbers from [PenasGon2017]_ algorithm 5
if (
self._n_received_solutions > 10 * self._n_sent_solutions + 20
and self._neval > problem.dim * 5000
self._n_received_solutions
> self._options.adaptation_sent_coeff * self._n_sent_solutions
+ self._options.adaptation_sent_offset
and self._neval > problem.dim * self._options.adaptation_min_evals
):
self._ess_kwargs = self._manager.reconfigure_worker(
self._worker_idx
Expand All @@ -693,7 +699,7 @@ def maybe_update_best(self, x: np.array, fx: float):
f"Worker {self._worker_idx} maybe sending solution {fx}. "
f"best known: {self._best_known_fx}, "
f"rel change: {rel_change:.4g}, "
f"threshold: {self._acceptance_threshold}"
f"threshold: {self._options.worker_acceptance_threshold}"
)

# solution improves best value by at least a factor of ...
Expand All @@ -703,7 +709,7 @@ def maybe_update_best(self, x: np.array, fx: float):
or (
fx < self._best_known_fx
and abs((self._best_known_fx - fx) / fx)
> self._acceptance_threshold
> self._options.worker_acceptance_threshold
)
):
self._logger.debug(
Expand Down Expand Up @@ -767,9 +773,7 @@ def _keep_going(self):
return True

@staticmethod
def get_temp_result_filename(
worker_idx: int, tmpdir: Union[str, Path]
) -> str:
def get_temp_result_filename(worker_idx: int, tmpdir: str | Path) -> str:
return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute())


Expand Down Expand Up @@ -797,11 +801,9 @@ def _run_worker(
def get_default_ess_options(
num_workers: int,
dim: int,
local_optimizer: Union[
bool,
"pypesto.optimize.Optimizer",
Callable[..., "pypesto.optimize.Optimizer"],
] = True,
local_optimizer: bool
| pypesto.optimize.Optimizer
| Callable[..., pypesto.optimize.Optimizer] = True,
) -> list[dict]:
"""Get default ESS settings for (SA)CESS.
Expand Down Expand Up @@ -1017,8 +1019,8 @@ class SacessFidesFactory:

def __init__(
self,
fides_options: Optional[dict[str, Any]] = None,
fides_kwargs: Optional[dict[str, Any]] = None,
fides_options: dict[str, Any] | None = None,
fides_kwargs: dict[str, Any] | None = None,
):
if fides_options is None:
fides_options = {}
Expand All @@ -1038,7 +1040,7 @@ def __init__(

def __call__(
self, max_walltime_s: int, max_eval: int
) -> "pypesto.optimize.FidesOptimizer":
) -> pypesto.optimize.FidesOptimizer:
"""Create a :class:`FidesOptimizer` instance."""

from fides.constants import Options as FidesOptions
Expand Down Expand Up @@ -1085,5 +1087,76 @@ class SacessWorkerResult:
fx: float
n_eval: int
n_iter: int
history: "pypesto.history.memory.MemoryHistory"
history: pypesto.history.memory.MemoryHistory
exit_flag: ESSExitFlag


@dataclass
class SacessOptions:
"""Container for :class:`SacessOptimizer` hyperparameters.
Attributes
----------
manager_initial_rejection_threshold, manager_minimum_rejection_threshold:
Initial and minimum threshold for relative objective improvements that
incoming solutions have to pass to be accepted. If the number of
rejected solutions exceeds the number of workers, the threshold is
halved until it reaches ``manager_minimum_rejection_threshold``.
_acceptance_threshold: Minimum relative improvement of the objective
compared to the best known value to be eligible for submission to the
Manager.
worker_acceptance_threshold:
Minimum relative improvement of the objective compared to the best
known value to be eligible for submission to the Manager.
adaptation_min_evals, adaptation_sent_offset, adaptation_sent_coeff:
Hyperparameters that control when the workers will adapt their settings
based on the performance of the other workers.
The adaptation step is performed if all the following conditions are
met:
* The number of function evaluations since the last solution was sent
to the manager times the number of optimization parameters is greater
than ``adaptation_min_evals``.
* The number of solutions received by the worker since the last
solution it sent to the manager is greater than
``adaptation_sent_coeff * n_sent_solutions + adaptation_sent_offset``,
where ``n_sent_solutions`` is the number of solutions sent to the
manager by the given worker.
"""

manager_initial_rejection_threshold: float = 0.001
manager_minimum_rejection_threshold: float = 0.001

# Default value from original SaCeSS implementation
worker_acceptance_threshold: float = 0.0001

# Magic numbers for adaptation, taken from [PenasGon2017]_ algorithm 5
adaptation_min_evals: int = 5000
adaptation_sent_offset: int = 20
adaptation_sent_coeff: int = 10

def __post_init__(self):
if self.adaptation_min_evals < 0:
raise ValueError("adaptation_min_evals must be non-negative.")
if self.adaptation_sent_offset < 0:
raise ValueError("adaptation_sent_offset must be non-negative.")
if self.adaptation_sent_coeff < 0:
raise ValueError("adaptation_sent_coeff must be non-negative.")
if self.manager_initial_rejection_threshold < 0:
raise ValueError(
"manager_initial_rejection_threshold must be non-negative."
)
if self.manager_minimum_rejection_threshold < 0:
raise ValueError(
"manager_minimum_rejection_threshold must be non-negative."
)
if self.worker_acceptance_threshold < 0:
raise ValueError(
"worker_acceptance_threshold must be non-negative."
)

0 comments on commit 5c7aa6e

Please sign in to comment.