diff --git a/botorch/optim/__init__.py b/botorch/optim/__init__.py index f4abe3fd87..5156bba684 100644 --- a/botorch/optim/__init__.py +++ b/botorch/optim/__init__.py @@ -22,7 +22,11 @@ LinearHomotopySchedule, LogLinearHomotopySchedule, ) -from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg +from botorch.optim.initializers import ( + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topn, +) from botorch.optim.optimize import ( gen_batch_initial_conditions, optimize_acqf, @@ -43,6 +47,7 @@ "gen_batch_initial_conditions", "initialize_q_batch", "initialize_q_batch_nonneg", + "initialize_q_batch_topn", "OptimizationResult", "OptimizationStatus", "optimize_acqf", diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index af0f918f4a..753ca86124 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -271,13 +271,15 @@ def gen_batch_initial_conditions( fixed_features: A map `{feature_index: value}` for features that should be fixed to a particular value during generation. options: Options for initial condition generation. For valid options see - `initialize_q_batch` and `initialize_q_batch_nonneg`. If `options` - contains a `nonnegative=True` entry, then `acq_function` is - assumed to be non-negative (useful when using custom acquisition - functions). In addition, an "init_batch_limit" option can be passed - to specify the batch limit for the initialization. This is useful - for avoiding memory limits when computing the batch posterior over - raw samples. + `initialize_q_batch_topn`, `initialize_q_batch_nonneg`, and + `initialize_q_batch`. If `options` contains a `topn=True` then + `initialize_q_batch_topn` will be used. Else if `options` contains a + `nonnegative=True` entry, then `acq_function` is assumed to be + non-negative (useful when using custom acquisition functions). + `initialize_q_batch` will be used otherwise. In addition, an + "init_batch_limit" option can be passed to specify the batch limit + for the initialization. This is useful for avoiding memory limits + when computing the batch posterior over raw samples. inequality constraints: A list of tuples (indices, coefficients, rhs), with each tuple encoding an inequality constraint of the form `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. 
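Commentary (not part of the patch): a minimal usage sketch of the option routing documented in the hunk above. Per the updated docstring, `options={"topn": True}` routes initialization through `initialize_q_batch_topn`, `nonnegative=True` routes through `initialize_q_batch_nonneg`, and `initialize_q_batch` is used otherwise; `sorted`/`largest` are forwarded to the top-n path, with `largest` defaulting to `acq_function.maximize` when that attribute exists. The fitted `model` and the bound values below are assumptions for illustration only.

    import torch
    from botorch.acquisition import qUpperConfidenceBound
    from botorch.optim import gen_batch_initial_conditions

    bounds = torch.stack([torch.zeros(2), torch.ones(2)])
    acqf = qUpperConfidenceBound(model, beta=0.1)  # assumes a fitted `model`

    # With topn=True, the raw samples are ranked by acquisition value and the
    # top `num_restarts` q-batches are kept (initialize_q_batch_topn), instead
    # of the Boltzmann-style sampling used by initialize_q_batch.
    ics = gen_batch_initial_conditions(
        acq_function=acqf,
        bounds=bounds,
        q=3,
        num_restarts=10,
        raw_samples=256,
        options={"topn": True, "sorted": True, "init_batch_limit": 32},
    )
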
@@ -328,14 +330,24 @@ def gen_batch_initial_conditions( init_kwargs = {} device = bounds.device bounds_cpu = bounds.cpu() - if "eta" in options: - init_kwargs["eta"] = options.get("eta") - if options.get("nonnegative") or is_nonnegative(acq_function): + + if options.get("topn"): + init_func = initialize_q_batch_topn + init_func_opts = ["sorted", "largest"] + elif options.get("nonnegative") or is_nonnegative(acq_function): init_func = initialize_q_batch_nonneg - if "alpha" in options: - init_kwargs["alpha"] = options.get("alpha") + init_func_opts = ["alpha", "eta"] else: init_func = initialize_q_batch + init_func_opts = ["eta"] + + for opt in init_func_opts: + # default value of "largest" to "acq_function.maximize" if it exists + if opt == "largest" and hasattr(acq_function, "maximize"): + init_kwargs[opt] = acq_function.maximize + + if opt in options: + init_kwargs[opt] = options.get(opt) q = 1 if q is None else q # the dimension the samples are drawn from @@ -363,7 +375,9 @@ def gen_batch_initial_conditions( X_rnd_nlzd = torch.rand( n, q, bounds_cpu.shape[-1], dtype=bounds.dtype ) - X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd + X_rnd = unnormalize( + X_rnd_nlzd, bounds_cpu, update_constant_bounds=False + ) else: X_rnd = sample_q_batches_from_polytope( n=n, @@ -375,7 +389,8 @@ def gen_batch_initial_conditions( equality_constraints=equality_constraints, inequality_constraints=inequality_constraints, ) - # sample points around best + + # sample additional points around best if sample_around_best: X_best_rnd = sample_points_around_best( acq_function=acq_function, @@ -395,6 +410,8 @@ def gen_batch_initial_conditions( ) # Keep X on CPU for consistency & to limit GPU memory usage. X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu() + + # Append the fixed fantasies to the randomly generated points if fixed_X_fantasies is not None: if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]): raise BotorchTensorDimensionError( @@ -411,6 +428,9 @@ def gen_batch_initial_conditions( ], dim=-2, ) + + # Evaluate the acquisition function on `X_rnd` using `batch_limit` + # sized chunks. with torch.no_grad(): if batch_limit is None: batch_limit = X_rnd.shape[0] @@ -423,16 +443,22 @@ def gen_batch_initial_conditions( ], dim=0, ) + + # Downselect the initial conditions based on the acquisition function values batch_initial_conditions, _ = init_func( X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs ) batch_initial_conditions = batch_initial_conditions.to(device=device) + + # Return the initial conditions if no warnings were raised if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws): return batch_initial_conditions + if factor < max_factor: factor += 1 if seed is not None: seed += 1 # make sure to sample different X_rnd + warnings.warn( "Unable to find non-zero acquisition function values - initial conditions " "are being selected randomly.", @@ -1057,6 +1083,56 @@ def initialize_q_batch_nonneg( return X[idcs], acq_vals[idcs] +def initialize_q_batch_topn( + X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True +) -> tuple[Tensor, Tensor]: + r"""Take the top `n` initial conditions for candidate generation. + + Args: + X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim. + feature space. Typically, these are generated using qMC. + acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this + is the value of the batch acquisition function to be maximized. 
+ n: The number of initial condition to be generated. Must be less than `b`. + + Returns: + - An `n x q x d` tensor of `n` `q`-batch initial conditions. + - An `n` tensor of the corresponding acquisition values. + + Example: + >>> # To get `n=10` starting points of q-batch size `q=3` + >>> # for model with `d=6`: + >>> qUCB = qUpperConfidenceBound(model, beta=0.1) + >>> X_rnd = torch.rand(500, 3, 6) + >>> X_init, acq_init = initialize_q_batch_topn( + ... X=X_rnd, acq_vals=qUCB(X_rnd), n=10 + ... ) + + """ + n_samples = X.shape[0] + if n > n_samples: + raise RuntimeError( + f"n ({n}) cannot be larger than the number of " + f"provided samples ({n_samples})" + ) + elif n == n_samples: + return X, acq_vals + + Ystd = acq_vals.std(dim=0) + if torch.any(Ystd == 0): + warnings.warn( + "All acquisition values for raw samples points are the same for " + "at least one batch. Choosing initial conditions at random.", + BadInitialCandidatesWarning, + stacklevel=3, + ) + idcs = torch.randperm(n=n_samples, device=X.device)[:n] + return X[idcs], acq_vals[idcs] + + topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted) + return X[topk_idcs], topk_out + + def sample_points_around_best( acq_function: AcquisitionFunction, n_discrete_points: int, diff --git a/botorch/utils/feasible_volume.py b/botorch/utils/feasible_volume.py index f3b8d2fb76..2608c03c2a 100644 --- a/botorch/utils/feasible_volume.py +++ b/botorch/utils/feasible_volume.py @@ -11,7 +11,7 @@ import botorch.models.model as model import torch from botorch.logging import _get_logger -from botorch.utils.sampling import manual_seed +from botorch.utils.sampling import manual_seed, unnormalize from torch import Tensor @@ -164,9 +164,10 @@ def estimate_feasible_volume( seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item() with manual_seed(seed=seed): - box_samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand( + samples_nlzd = torch.rand( (nsample_feature, bounds.size(1)), dtype=dtype, device=device ) + box_samples = unnormalize(samples_nlzd, bounds, update_constant_bounds=False) features, p_feature = get_feasible_samples( samples=box_samples, inequality_constraints=inequality_constraints diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index 52fe54fbb2..f914dea24d 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -98,14 +98,12 @@ def draw_sobol_samples( batch_shape = batch_shape or torch.Size() batch_size = int(torch.prod(torch.tensor(batch_shape))) d = bounds.shape[-1] - lower = bounds[0] - rng = bounds[1] - bounds[0] sobol_engine = SobolEngine(q * d, scramble=True, seed=seed) - samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype) - samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device) + samples_raw = sobol_engine.draw(batch_size * n, dtype=bounds.dtype) + samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device) if batch_shape != torch.Size(): samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1) - return lower + rng * samples_raw + return unnormalize(samples_raw, bounds, update_constant_bounds=False) def draw_sobol_normal_samples( diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py index 01f34c0da4..b354821cfb 100644 --- a/botorch/utils/transforms.py +++ b/botorch/utils/transforms.py @@ -66,17 +66,18 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor: return bounds -def normalize(X: Tensor, bounds: Tensor) -> Tensor: +def normalize(X: Tensor, bounds: Tensor, 
update_constant_bounds: bool = True) -> Tensor: r"""Min-max normalize X w.r.t. the provided bounds. - NOTE: If the upper and lower bounds are identical for a dimension, that dimension - will not be scaled. Such dimensions will only be shifted as - `new_X[..., i] = X[..., i] - bounds[0, i]`. This avoids division by zero issues. - Args: X: `... x d` tensor of data bounds: `2 x d` tensor of lower and upper bounds for each of the X's d columns. + update_constant_bounds: If `True`, update the constant bounds in order to + avoid division by zero issues. When the upper and lower bounds are + identical for a dimension, that dimension will not be scaled. Such + dimensions will only be shifted as + `new_X[..., i] = X[..., i] - bounds[0, i]`. Returns: A `... x d`-dim tensor of normalized data, given by @@ -89,21 +90,27 @@ def normalize(X: Tensor, bounds: Tensor) -> Tensor: >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)]) >>> X_normalized = normalize(X, bounds) """ - bounds = _update_constant_bounds(bounds=bounds) + bounds = ( + _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds + ) return (X - bounds[0]) / (bounds[1] - bounds[0]) -def unnormalize(X: Tensor, bounds: Tensor) -> Tensor: +def unnormalize( + X: Tensor, bounds: Tensor, update_constant_bounds: bool = True +) -> Tensor: r"""Un-normalizes X w.r.t. the provided bounds. - NOTE: If the upper and lower bounds are identical for a dimension, that dimension - will not be scaled. Such dimensions will only be shifted as - `new_X[..., i] = X[..., i] + bounds[0, i]`, matching the behavior of `normalize`. - Args: X: `... x d` tensor of data bounds: `2 x d` tensor of lower and upper bounds for each of the X's d columns. + update_constant_bounds: If `True`, update the constant bounds in order to + avoid division by zero issues. When the upper and lower bounds are + identical for a dimension, that dimension will not be scaled. Such + dimensions will only be shifted as + `new_X[..., i] = X[..., i] + bounds[0, i]`. This is the inverse of + the behavior of `normalize` when `update_constant_bounds=True`. Returns: A `... 
x d`-dim tensor of unnormalized data, given by @@ -116,7 +123,9 @@ def unnormalize(X: Tensor, bounds: Tensor) -> Tensor: >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)]) >>> X = unnormalize(X_normalized, bounds) """ - bounds = _update_constant_bounds(bounds=bounds) + bounds = ( + _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds + ) return X * (bounds[1] - bounds[0]) + bounds[0] diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 09be6f2326..65c3e2b6bb 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -30,12 +30,14 @@ from botorch.exceptions.warnings import BotorchWarning from botorch.models import SingleTaskGP from botorch.models.model_list_gp_regression import ModelListGP -from botorch.optim import initialize_q_batch, initialize_q_batch_nonneg from botorch.optim.initializers import ( gen_batch_initial_conditions, gen_one_shot_hvkg_initial_conditions, gen_one_shot_kg_initial_conditions, gen_value_function_initial_conditions, + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topn, sample_perturbed_subset_dims, sample_points_around_best, sample_q_batches_from_polytope, @@ -45,7 +47,7 @@ transform_intra_point_constraint, ) from botorch.sampling.normal import IIDNormalSampler -from botorch.utils.sampling import draw_sobol_samples, manual_seed +from botorch.utils.sampling import draw_sobol_samples, manual_seed, unnormalize from botorch.utils.testing import ( _get_max_violation_of_bounds, _get_max_violation_of_constraints, @@ -108,10 +110,8 @@ def test_initialize_q_batch_nonneg(self): self.assertEqual(ics.dtype, X.dtype) # ensure raises correct warning acq_vals = torch.zeros(5, device=self.device, dtype=dtype) - with warnings.catch_warnings(record=True) as w: + with self.assertWarns(BadInitialCandidatesWarning): ics, _ = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) self.assertEqual(ics.shape, torch.Size([2, 3, 4])) with self.assertRaises(RuntimeError): initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=10) @@ -129,31 +129,64 @@ def test_initialize_q_batch_nonneg(self): self.assertEqual(ics.dtype, X.dtype) def test_initialize_q_batch(self): + for dtype, batch_shape in ( + (torch.float, torch.Size()), + (torch.double, [3, 2]), + (torch.float, (2,)), + (torch.double, torch.Size([2, 3, 4])), + (torch.float, []), + ): + # basic test + X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype) + acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype) + ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4])) + self.assertEqual(ics_X.device, X.device) + self.assertEqual(ics_X.dtype, X.dtype) + self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape])) + self.assertEqual(ics_acq_vals.device, acq_vals.device) + self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) + # ensure nothing happens if we want all samples + ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5) + self.assertTrue(torch.equal(X, ics_X)) + self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) + # ensure raises correct warning + acq_vals = torch.zeros(5, device=self.device, dtype=dtype) + with self.assertWarns(BadInitialCandidatesWarning): + ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 
4])) + with self.assertRaises(RuntimeError): + initialize_q_batch(X=X, acq_vals=acq_vals, n=10) + + def test_initialize_q_batch_topn(self): for dtype in (torch.float, torch.double): - for batch_shape in (torch.Size(), [3, 2], (2,), torch.Size([2, 3, 4]), []): - # basic test - X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype) - acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype) - ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4])) - self.assertEqual(ics_X.device, X.device) - self.assertEqual(ics_X.dtype, X.dtype) - self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape])) - self.assertEqual(ics_acq_vals.device, acq_vals.device) - self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) - # ensure nothing happens if we want all samples - ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5) - self.assertTrue(torch.equal(X, ics_X)) - self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) - # ensure raises correct warning - acq_vals = torch.zeros(5, device=self.device, dtype=dtype) - with warnings.catch_warnings(record=True) as w: - ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) - self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4])) - with self.assertRaises(RuntimeError): - initialize_q_batch(X=X, acq_vals=acq_vals, n=10) + # basic test + X = torch.rand(5, 3, 4, device=self.device, dtype=dtype) + acq_vals = torch.rand(5, device=self.device, dtype=dtype) + ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics_X.shape, torch.Size([2, 3, 4])) + self.assertEqual(ics_X.device, X.device) + self.assertEqual(ics_X.dtype, X.dtype) + self.assertEqual(ics_acq_vals.shape, torch.Size([2])) + self.assertEqual(ics_acq_vals.device, acq_vals.device) + self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) + # ensure nothing happens if we want all samples + ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=5) + self.assertTrue(torch.equal(X, ics_X)) + self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) + # make sure things work with constant inputs + acq_vals = torch.ones(5, device=self.device, dtype=dtype) + ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics.shape, torch.Size([2, 3, 4])) + self.assertEqual(ics.device, X.device) + self.assertEqual(ics.dtype, X.dtype) + # ensure raises correct warning + acq_vals = torch.zeros(5, device=self.device, dtype=dtype) + with self.assertWarns(BadInitialCandidatesWarning): + ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics.shape, torch.Size([2, 3, 4])) + with self.assertRaises(RuntimeError): + initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10) def test_initialize_q_batch_largeZ(self): for dtype in (torch.float, torch.double): @@ -187,64 +220,149 @@ def test_gen_batch_initial_conditions(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) - for dtype in (torch.float, torch.double): + for ( + dtype, + nonnegative, + seed, + init_batch_limit, + ffs, + sample_around_best, + ) in ( + (torch.float, True, None, None, None, True), + (torch.double, False, 1234, 1, {0: 0.5}, False), + (torch.double, True, 1234, None, {0: 0.5}, True), + ): bounds = bounds.to(device=self.device, 
dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - for nonnegative, seed, init_batch_limit, ffs, sample_around_best in product( - [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False] - ): - with mock.patch.object( - MockAcquisitionFunction, - "__call__", - wraps=mock_acqf.__call__, - ) as mock_acqf_call, warnings.catch_warnings(): - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=mock_acqf, - bounds=bounds, - q=1, - num_restarts=2, - raw_samples=10, - fixed_features=ffs, - options={ - "nonnegative": nonnegative, - "eta": 0.01, - "alpha": 0.1, - "seed": seed, - "init_batch_limit": init_batch_limit, - "sample_around_best": sample_around_best, - }, - ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + }, + ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + def test_gen_batch_initial_conditions_topn(self): + bounds = torch.stack([torch.zeros(2), torch.ones(2)]) + mock_acqf = MockAcquisitionFunction() + mock_acqf.objective = lambda y: y.squeeze(-1) + mock_acqf.maximize = True # Add maximize attribute + for ( + dtype, + topn, + largest, + is_sorted, + seed, + init_batch_limit, + ffs, + sample_around_best, + 
) in ( + (torch.float, True, True, True, None, None, None, True), + (torch.double, False, False, False, 1234, 1, {0: 0.5}, False), + (torch.float, True, None, True, 1234, None, None, False), + (torch.double, False, True, False, None, 1, {0: 0.5}, True), + (torch.float, True, False, False, 1234, None, {0: 0.5}, True), + (torch.double, False, None, True, None, 1, None, False), + (torch.float, True, True, False, 1234, 1, {0: 0.5}, True), + (torch.double, False, False, True, None, None, None, False), + ): + bounds = bounds.to(device=self.device, dtype=dtype) + mock_acqf.X_baseline = bounds # for testing sample_around_best + mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + options = { + "topn": topn, + "sorted": is_sorted, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + } + if largest is not None: + options["largest"] = largest + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options=options, + ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_highdim(self): d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim) @@ -252,48 +370,46 @@ def test_gen_batch_initial_conditions_highdim(self): ffs_map = {i: random() for i in range(0, d, 2)} mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) - for dtype in (torch.float, torch.double): + for dtype, nonnegative, seed, ffs, sample_around_best in ( + (torch.float, True, None, None, True), + (torch.double, False, 1234, ffs_map, False), + (torch.double, True, 1234, ffs_map, True), + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - - for nonnegative, seed, ffs, sample_around_best in product( - [True, False], [None, 1234], [None, ffs_map], [True, False] - ): - with warnings.catch_warnings(record=True) as ws: - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=MockAcquisitionFunction(), - bounds=bounds, - q=10, - num_restarts=1, - raw_samples=2, - fixed_features=ffs, - options={ - "nonnegative": nonnegative, - "eta": 0.01, - "alpha": 0.1, - "seed": seed, - "sample_around_best": sample_around_best, - }, - ) + with warnings.catch_warnings(record=True) 
as ws: + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=MockAcquisitionFunction(), + bounds=bounds, + q=10, + num_restarts=1, + raw_samples=2, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "sample_around_best": sample_around_best, + }, + ) + self.assertTrue( + any(issubclass(w.category, SamplingWarning) for w in ws) + ) + expected_shape = torch.Size([1, 10, d]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6 + ) + if ffs is not None: + for idx, val in ffs.items(): self.assertTrue( - any(issubclass(w.category, SamplingWarning) for w in ws) + torch.all(batch_initial_conditions[..., idx] == val) ) - expected_shape = torch.Size([1, 10, d]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6 - ) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) def test_gen_batch_initial_conditions_warning(self) -> None: for dtype in (torch.float, torch.double): @@ -581,51 +697,51 @@ def test_gen_batch_initial_conditions_constraints(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - inequality_constraints, - equality=False, - ), - 1e-6, - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - equality_constraints, - equality=True, - ), - 1e-6, - ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + inequality_constraints, + equality=False, + ), + 1e-6, + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + equality_constraints, + equality=True, + ), + 1e-6, + ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - self.assertTrue((raw_samps[..., 0] == 0.5).all()) - self.assertTrue((-4 * raw_samps[..., 1] >= -3).all()) - if ffs is not None: - for idx, val in ffs.items(): - 
self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + self.assertTrue((raw_samps[..., 0] == 0.5).all()) + self.assertTrue((-4 * raw_samps[..., 1] >= -3).all()) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_interpoint_constraints(self): for dtype in (torch.float, torch.double): @@ -679,33 +795,33 @@ def test_gen_batch_initial_conditions_interpoint_constraints(self): inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, ) - expected_shape = torch.Size([2, 3, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - - self.assertTrue((batch_initial_conditions.sum(dim=-1) <= 1).all()) - - self.assertAllClose( - batch_initial_conditions[0, 0, 0], - batch_initial_conditions[0, 1, 0], - batch_initial_conditions[0, 2, 0], - atol=1e-7, - ) + expected_shape = torch.Size([2, 3, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertAllClose( - batch_initial_conditions[1, 0, 0], - batch_initial_conditions[1, 1, 0], - batch_initial_conditions[1, 2, 0], - ) - self.assertLess( - _get_max_violation_of_constraints( - batch_initial_conditions, - inequality_constraints, - equality=False, - ), - 1e-6, - ) + self.assertTrue((batch_initial_conditions.sum(dim=-1) <= 1).all()) + + self.assertAllClose( + batch_initial_conditions[0, 0, 0], + batch_initial_conditions[0, 1, 0], + batch_initial_conditions[0, 2, 0], + atol=1e-7, + ) + + self.assertAllClose( + batch_initial_conditions[1, 0, 0], + batch_initial_conditions[1, 1, 0], + batch_initial_conditions[1, 2, 0], + ) + self.assertLess( + _get_max_violation_of_constraints( + batch_initial_conditions, + inequality_constraints, + equality=False, + ), + 1e-6, + ) def test_gen_batch_initial_conditions_generator(self): mock_acqf = MockAcquisitionFunction() @@ -727,7 +843,9 @@ def generator(n: int, q: int, seed: int | None): dtype=bounds.dtype, device=self.device, ) - X_rnd = bounds[0] + (bounds[1] - bounds[0]) * X_rnd_nlzd + X_rnd = unnormalize( + X_rnd_nlzd, bounds, update_constant_bounds=False + ) X_rnd[..., -1] = 0.42 return X_rnd @@ -756,20 +874,20 @@ def generator(n: int, q: int, seed: int | None): "init_batch_limit": init_batch_limit, }, ) - expected_shape = torch.Size([4, 2, 3]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + expected_shape = torch.Size([4, 2, 3]) 
+ self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all()) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_error_generator_with_sample_around_best(self): tkwargs = {"device": self.device, "dtype": torch.double} @@ -852,39 +970,40 @@ def test_gen_batch_initial_conditions_fixed_X_fantasies(self): }, fixed_X_fantasies=fixed_X_fantasies, ) - expected_shape = torch.Size([2, 4, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([4, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + expected_shape = torch.Size([2, 4, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([4, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., 0, idx] == val) - ) - self.assertTrue( - torch.equal( - batch_initial_conditions[:, 1:], - fixed_X_fantasies.unsqueeze(0).expand(2, 3, 2), + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., 0, idx] == val) ) + self.assertTrue( + torch.equal( + batch_initial_conditions[:, 1:], + fixed_X_fantasies.unsqueeze(0).expand(2, 3, 2), ) + ) + # test wrong shape msg = ( "`fixed_X_fantasies` and `bounds` must both have the same trailing"
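
Commentary (not part of the patch): a short sketch of the new `update_constant_bounds` flag added to `normalize`/`unnormalize` above, which the sampling helpers in this patch call with `update_constant_bounds=False` because their random samples already live in the unit cube. The tensor values are illustrative assumptions.

    import torch
    from botorch.utils.transforms import unnormalize

    bounds = torch.tensor([[0.0, 2.0], [1.0, 2.0]])  # second dimension is constant
    X_unit = torch.rand(4, 2)

    # Default (update_constant_bounds=True): a constant dimension is only
    # shifted by its lower bound, avoiding division-by-zero on the inverse.
    X_default = unnormalize(X_unit, bounds)

    # update_constant_bounds=False: the bounds are used as-is, so the constant
    # dimension collapses to its (identical) lower/upper bound value.
    X_raw = unnormalize(X_unit, bounds, update_constant_bounds=False)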