Add TopK downselection for initial batch generation. #2636

Closed
wants to merge 22 commits

Changes from 11 commits

Commits (22)
8047aa1
wip: topk ic generation
CompRhys Nov 19, 2024
f5a8d64
tests: add tests
CompRhys Nov 20, 2024
96f9bef
Merge remote-tracking branch 'upstream/main' into topk-icgen
CompRhys Nov 24, 2024
a022462
fix: micro-optimization suggestion from review
CompRhys Nov 24, 2024
e75239d
fix: don't use unnormalize due to unexpected behaviour with constant …
CompRhys Nov 25, 2024
8e27422
doc: initialize_q_batch_topk -> initialize_q_batch_topn
CompRhys Nov 25, 2024
662caf1
tests: achieve full coverage
CompRhys Nov 26, 2024
75eea37
clean: remove debug snippet
CompRhys Nov 26, 2024
5e0fe59
Merge remote-tracking branch 'upstream/main' into topk-icgen
CompRhys Nov 27, 2024
88a2e5d
fea: use unnormalize in more places but add flag to turn off the cons…
CompRhys Dec 2, 2024
e0202e2
doc: add docstring for the new update_constant_bounds argument
CompRhys Dec 2, 2024
21bbc27
fix: assert warns rather than catch and check
CompRhys Dec 2, 2024
6e93eba
fix: nit limit scope of context managers
CompRhys Dec 3, 2024
5e706ea
Merge branch 'topk-icgen' of https://github.com/Radical-AI/botorch in…
CompRhys Dec 3, 2024
f364fe1
doc: update the gen_batch_initial_conditions docstring
CompRhys Dec 3, 2024
7d9f9eb
Merge remote-tracking branch 'upstream/main' into topk-icgen
CompRhys Dec 3, 2024
e054fe2
Merge branch 'main' into topk-icgen
CompRhys Dec 3, 2024
1e0828c
test: reduce the number of tests
CompRhys Dec 3, 2024
ec1c167
Merge branch 'topk-icgen' of https://github.com/Radical-AI/botorch in…
CompRhys Dec 3, 2024
1ddc929
revert: redo the changes to reduce context manager scope
CompRhys Dec 3, 2024
61f6ffb
nit: change to assertWarns
CompRhys Dec 3, 2024
975bb29
Update botorch/optim/initializers.py
CompRhys Dec 6, 2024
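
Before the file diffs, a quick end-to-end sketch of how the new option is intended to be used. This is illustrative only: the toy model and data are made up, and it assumes `options` is forwarded from `optimize_acqf` to `gen_batch_initial_conditions` (as it is today) and that the `topn` key lands as written in this diff.

```python
import torch
from botorch.acquisition import qUpperConfidenceBound
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy model on random data, purely for illustration.
train_X = torch.rand(20, 6, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

bounds = torch.stack([torch.zeros(6), torch.ones(6)]).to(torch.double)
candidates, acq_value = optimize_acqf(
    acq_function=qUpperConfidenceBound(model, beta=0.1),
    bounds=bounds,
    q=3,
    num_restarts=10,
    raw_samples=512,
    # New in this PR: downselect initial conditions to the top-n raw
    # samples by acquisition value instead of Boltzmann sampling.
    options={"topn": True},
)
```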
7 changes: 6 additions & 1 deletion botorch/optim/__init__.py
@@ -22,7 +22,11 @@
LinearHomotopySchedule,
LogLinearHomotopySchedule,
)
from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg
from botorch.optim.initializers import (
initialize_q_batch,
initialize_q_batch_nonneg,
initialize_q_batch_topn,
)
from botorch.optim.optimize import (
gen_batch_initial_conditions,
optimize_acqf,
@@ -43,6 +47,7 @@
"gen_batch_initial_conditions",
"initialize_q_batch",
"initialize_q_batch_nonneg",
"initialize_q_batch_topn",
"OptimizationResult",
"OptimizationStatus",
"optimize_acqf",
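
With this change the new initializer is part of the `botorch.optim` public surface next to the existing ones, e.g.:

```python
from botorch.optim import (
    initialize_q_batch,
    initialize_q_batch_nonneg,
    initialize_q_batch_topn,  # added by this PR
)
```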
88 changes: 81 additions & 7 deletions botorch/optim/initializers.py
@@ -328,14 +328,24 @@ def gen_batch_initial_conditions(
init_kwargs = {}
device = bounds.device
bounds_cpu = bounds.cpu()
if "eta" in options:
init_kwargs["eta"] = options.get("eta")
if options.get("nonnegative") or is_nonnegative(acq_function):

if options.get("topn"):
init_func = initialize_q_batch_topn
init_func_opts = ["sorted", "largest"]
elif options.get("nonnegative") or is_nonnegative(acq_function):
init_func = initialize_q_batch_nonneg
if "alpha" in options:
init_kwargs["alpha"] = options.get("alpha")
init_func_opts = ["alpha", "eta"]
else:
init_func = initialize_q_batch
init_func_opts = ["eta"]

for opt in init_func_opts:
# default "largest" to the value of acq_function.maximize, if that attribute exists
if opt == "largest" and hasattr(acq_function, "maximize"):
init_kwargs[opt] = acq_function.maximize

if opt in options:
init_kwargs[opt] = options.get(opt)

q = 1 if q is None else q
# the dimension the samples are drawn from
@@ -363,7 +373,9 @@
X_rnd_nlzd = torch.rand(
n, q, bounds_cpu.shape[-1], dtype=bounds.dtype
)
X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd
X_rnd = unnormalize(
X_rnd_nlzd, bounds, update_constant_bounds=False
)
else:
X_rnd = sample_q_batches_from_polytope(
n=n,
@@ -375,7 +387,8 @@
equality_constraints=equality_constraints,
inequality_constraints=inequality_constraints,
)
# sample points around best

# sample additional points around best
if sample_around_best:
X_best_rnd = sample_points_around_best(
acq_function=acq_function,
@@ -395,6 +408,8 @@
)
# Keep X on CPU for consistency & to limit GPU memory usage.
X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu()

# Append the fixed fantasies to the randomly generated points
if fixed_X_fantasies is not None:
if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]):
raise BotorchTensorDimensionError(
@@ -411,6 +426,9 @@
],
dim=-2,
)

# Evaluate the acquisition function on `X_rnd` using `batch_limit`
# sized chunks.
with torch.no_grad():
if batch_limit is None:
batch_limit = X_rnd.shape[0]
@@ -423,16 +441,22 @@
],
dim=0,
)

# Downselect the initial conditions based on the acquisition function values
batch_initial_conditions, _ = init_func(
X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs
)
batch_initial_conditions = batch_initial_conditions.to(device=device)

# Return the initial conditions if no warnings were raised
if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws):
return batch_initial_conditions

if factor < max_factor:
factor += 1
if seed is not None:
seed += 1 # make sure to sample different X_rnd

warnings.warn(
"Unable to find non-zero acquisition function values - initial conditions "
"are being selected randomly.",
@@ -1057,6 +1081,56 @@ def initialize_q_batch_nonneg(
return X[idcs], acq_vals[idcs]


def initialize_q_batch_topn(
X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
) -> tuple[Tensor, Tensor]:
r"""Take the top `n` initial conditions for candidate generation.

Args:
X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
feature space. Typically, these are generated using qMC.
acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this
is the value of the batch acquisition function to be maximized.
n: The number of initial conditions to be generated. Must not be larger than `b`.
largest: If True (default), select the `n` largest acquisition values; otherwise the `n` smallest.
sorted: If True (default), return the selected batches sorted by acquisition value.

Returns:
- An `n x q x d` tensor of `n` `q`-batch initial conditions.
- An `n` tensor of the corresponding acquisition values.

Example:
>>> # To get `n=10` starting points of q-batch size `q=3`
>>> # for model with `d=6`:
>>> qUCB = qUpperConfidenceBound(model, beta=0.1)
>>> X_rnd = torch.rand(500, 3, 6)
>>> X_init, acq_init = initialize_q_batch_topn(
... X=X_rnd, acq_vals=qUCB(X_rnd), n=10
... )

"""
n_samples = X.shape[0]
if n > n_samples:
raise RuntimeError(
f"n ({n}) cannot be larger than the number of "
f"provided samples ({n_samples})"
)
elif n == n_samples:
return X, acq_vals

Ystd = acq_vals.std(dim=0)
if torch.any(Ystd == 0):
warnings.warn(
"All acquisition values for raw samples points are the same for "
"at least one batch. Choosing initial conditions at random.",
BadInitialCandidatesWarning,
stacklevel=3,
)
idcs = torch.randperm(n=n_samples, device=X.device)[:n]
return X[idcs], acq_vals[idcs]

topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted)
return X[topk_idcs], topk_out


def sample_points_around_best(
acq_function: AcquisitionFunction,
n_discrete_points: int,
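
To make the selection rule concrete, here is a self-contained sketch of the top-n downselection in plain torch. It mirrors the core of `initialize_q_batch_topn` above; `topn_downselect` is a hypothetical name for this comment, not part of BoTorch.

```python
import torch

def topn_downselect(
    X: torch.Tensor, acq_vals: torch.Tensor, n: int, largest: bool = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep the `n` q-batches with the best acquisition values."""
    # X is `b x q x d`, acq_vals is `b`; topk picks the n best batches.
    topk_vals, topk_idcs = acq_vals.topk(n, largest=largest, sorted=True)
    return X[topk_idcs], topk_vals

X_rnd = torch.rand(500, 3, 6)                  # 500 raw q=3 batches in d=6
acq = -((X_rnd - 0.5) ** 2).sum(dim=(-2, -1))  # toy acquisition values
X_init, acq_init = topn_downselect(X_rnd, acq, n=10)
print(X_init.shape, acq_init.shape)  # torch.Size([10, 3, 6]) torch.Size([10])
```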
5 changes: 3 additions & 2 deletions botorch/utils/feasible_volume.py
@@ -11,7 +11,7 @@
import botorch.models.model as model
import torch
from botorch.logging import _get_logger
from botorch.utils.sampling import manual_seed
from botorch.utils.sampling import manual_seed, unnormalize
from torch import Tensor


@@ -164,9 +164,10 @@ def estimate_feasible_volume(
seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()

with manual_seed(seed=seed):
box_samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(
samples_nlzd = torch.rand(
(nsample_feature, bounds.size(1)), dtype=dtype, device=device
)
box_samples = unnormalize(samples_nlzd, bounds, update_constant_bounds=False)

features, p_feature = get_feasible_samples(
samples=box_samples, inequality_constraints=inequality_constraints
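
This rewrite is meant to be behavior-preserving. A quick check of that assumption (requires this PR's new `update_constant_bounds` flag on `unnormalize`):

```python
import torch
from botorch.utils.transforms import unnormalize

bounds = torch.tensor([[0.0, -1.0], [2.0, 1.0]])  # 2 x d lower/upper bounds
X_nlzd = torch.rand(8, 2)
# The manual affine form replaced above vs. the unnormalize call.
manual = bounds[0] + (bounds[1] - bounds[0]) * X_nlzd
via_unnormalize = unnormalize(X_nlzd, bounds, update_constant_bounds=False)
assert torch.allclose(manual, via_unnormalize)
```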
8 changes: 3 additions & 5 deletions botorch/utils/sampling.py
@@ -98,14 +98,12 @@ def draw_sobol_samples(
batch_shape = batch_shape or torch.Size()
batch_size = int(torch.prod(torch.tensor(batch_shape)))
d = bounds.shape[-1]
lower = bounds[0]
rng = bounds[1] - bounds[0]
sobol_engine = SobolEngine(q * d, scramble=True, seed=seed)
samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype)
samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device)
samples_raw = sobol_engine.draw(batch_size * n, dtype=bounds.dtype)
samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device)
if batch_shape != torch.Size():
samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1)
return lower + rng * samples_raw
return unnormalize(samples_raw, bounds, update_constant_bounds=False)


def draw_sobol_normal_samples(
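
Likewise, `draw_sobol_samples` outputs should be unchanged by this refactor; a sanity check one might run against the public API:

```python
import torch
from botorch.utils.sampling import draw_sobol_samples

bounds = torch.tensor([[0.0, -2.0], [1.0, 2.0]])  # d = 2
samples = draw_sobol_samples(bounds=bounds, n=16, q=2, seed=0)
print(samples.shape)  # torch.Size([16, 2, 2]) -> n x q x d
assert (samples >= bounds[0]).all() and (samples <= bounds[1]).all()
```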
33 changes: 21 additions & 12 deletions botorch/utils/transforms.py
@@ -66,17 +66,18 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor:
return bounds


def normalize(X: Tensor, bounds: Tensor) -> Tensor:
def normalize(X: Tensor, bounds: Tensor, update_constant_bounds: bool = True) -> Tensor:
r"""Min-max normalize X w.r.t. the provided bounds.

NOTE: If the upper and lower bounds are identical for a dimension, that dimension
will not be scaled. Such dimensions will only be shifted as
`new_X[..., i] = X[..., i] - bounds[0, i]`. This avoids division by zero issues.

Args:
X: `... x d` tensor of data
bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
columns.
update_constant_bounds: If `True`, update the constant bounds in order to
avoid division by zero issues. When the upper and lower bounds are
identical for a dimension, that dimension will not be scaled. Such
dimensions will only be shifted as
`new_X[..., i] = X[..., i] - bounds[0, i]`.

Returns:
A `... x d`-dim tensor of normalized data, given by
@@ -89,21 +90,27 @@ def normalize(X: Tensor, bounds: Tensor) -> Tensor:
>>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
>>> X_normalized = normalize(X, bounds)
"""
bounds = _update_constant_bounds(bounds=bounds)
bounds = (
_update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds
)
return (X - bounds[0]) / (bounds[1] - bounds[0])


def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
def unnormalize(
X: Tensor, bounds: Tensor, update_constant_bounds: bool = True
) -> Tensor:
r"""Un-normalizes X w.r.t. the provided bounds.

NOTE: If the upper and lower bounds are identical for a dimension, that dimension
will not be scaled. Such dimensions will only be shifted as
`new_X[..., i] = X[..., i] + bounds[0, i]`, matching the behavior of `normalize`.

Args:
X: `... x d` tensor of data
bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
columns.
update_constant_bounds: If `True`, update the constant bounds in order to
avoid division by zero issues. When the upper and lower bounds are
identical for a dimension, that dimension will not be scaled. Such
dimensions will only be shifted as
`new_X[..., i] = X[..., i] + bounds[0, i]`. This is the inverse of
the behavior of `normalize` when `update_constant_bounds=True`.

Returns:
A `... x d`-dim tensor of unnormalized data, given by
@@ -116,7 +123,9 @@ def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
>>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
>>> X = unnormalize(X_normalized, bounds)
"""
bounds = _update_constant_bounds(bounds=bounds)
bounds = (
_update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds
)
return X * (bounds[1] - bounds[0]) + bounds[0]


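
To illustrate why the flag exists (per the "don't use unnormalize due to unexpected behaviour with constant …" commit): with the default `update_constant_bounds=True`, a dimension whose bounds coincide is shifted rather than scaled, so `normalize`/`unnormalize` round-trip cleanly; with the flag off, the raw zero-width range is used and every input collapses onto the bound value, which is exactly what raw `[0, 1]` samples need. A small sketch, assuming the semantics documented above:

```python
import torch
from botorch.utils.transforms import normalize, unnormalize

bounds = torch.tensor([[0.0, 3.0], [1.0, 3.0]])  # second dim is constant
X = torch.tensor([[0.25, 3.0]])

# Default path: the constant dim is shifted only, so the round trip is exact.
X_n = normalize(X, bounds)       # tensor([[0.2500, 0.0000]])
print(unnormalize(X_n, bounds))  # tensor([[0.2500, 3.0000]])

# Flag off: the zero-width range multiplies everything onto the bound value.
print(unnormalize(torch.tensor([[0.25, 0.7]]), bounds,
                  update_constant_bounds=False))  # tensor([[0.2500, 3.0000]])
```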