This implementation supports either Expected Improvement (EI) or Thompson sampling (TS). We optimize the Branin2 function [2] with 498 dummy dimensions and show that BAxUS outperforms EI as well as Sobol.
The goal is to minimize the embedded Branin function

$f(x_1, x_2, \ldots, x_{500}) = \left (x_2-\frac{5.1}{4\pi^2}x_1^2+\frac{5}{\pi}x_1-6\right )^2+10\cdot \left (1-\frac{1}{8\pi}\right )\cos(x_1)+10$

with bounds [-5, 10] for $x_1$ and [0, 15] for $x_2$ (all other dimensions are ignored).
def branin_emb(x):
    """Evaluate the module-level ``branin`` function on the first two coordinates of ``x``.

    ``x`` is assumed to be in [-1, 1]^D. Only ``x[..., :2]`` is used; it is
    mapped affinely from [-1, 1] onto ``branin.bounds`` before evaluation, so
    every remaining coordinate is a dummy dimension.
    """
    lower, upper = branin.bounds
    width = upper - lower
    shifted = x[..., :2] + 1  # [-1, 1] -> [0, 2]
    return branin(lower + width * shifted / 2)
@dataclass
class BaxusState:
    """Mutable bookkeeping for one BAxUS run.

    Holds the trust-region side length, the success/failure counters that
    drive growing and shrinking it, and the quantities that determine the
    current target dimensionality.

    ``d_init``, ``target_dim`` and ``n_splits`` are computed in
    ``__post_init__`` from ``dim`` and ``new_bins_on_split``; their NaN
    defaults are placeholders only.
    """

    dim: int  # dimensionality of the full input space
    eval_budget: int  # evaluation budget from which ``split_budget`` is derived
    new_bins_on_split: int = 3  # passed as ``n_new_bins`` when the embedding is increased
    d_init: int = float("nan")  # Note: post-initialized
    target_dim: int = float("nan")  # Note: post-initialized
    n_splits: int = float("nan")  # Note: post-initialized
    length: float = 0.8  # current trust-region side length
    length_init: float = 0.8
    length_min: float = 0.5**7  # dropping below this sets ``restart_triggered``
    length_max: float = 1.6
    failure_counter: int = 0
    success_counter: int = 0
    success_tolerance: int = 3  # consecutive successes before the region is doubled
    best_value: float = -float("inf")
    restart_triggered: bool = False

    def __post_init__(self):
        # Number of splits: log of ``dim`` to base ``new_bins_on_split + 1``, rounded.
        n_splits = round(math.log(self.dim, self.new_bins_on_split + 1))
        # Pick d_init in {1, ..., new_bins_on_split} such that
        # d_init * (new_bins_on_split + 1) ** n_splits is closest to ``dim``.
        # Cast to a plain ``int`` so the field matches its annotation instead
        # of holding a numpy integer.
        self.d_init = int(
            1
            + np.argmin(
                np.abs(
                    (1 + np.arange(self.new_bins_on_split))
                    * (1 + self.new_bins_on_split) ** n_splits
                    - self.dim
                )
            )
        )
        self.target_dim = self.d_init
        self.n_splits = n_splits

    @property
    def split_budget(self) -> int:
        """Rounded share of ``eval_budget`` for the current ``target_dim``.

        The value is proportional to ``target_dim``, so it grows whenever the
        target space is increased.
        """
        return round(
            -1
            * (self.new_bins_on_split * self.eval_budget * self.target_dim)
            / (self.d_init * (1 - (self.new_bins_on_split + 1) ** (self.n_splits + 1)))
        )

    @property
    def failure_tolerance(self) -> int:
        """Number of consecutive failures after which the trust region is halved."""
        if self.target_dim == self.dim:
            return self.target_dim
        # k: how many halvings take ``length_init`` down to ``length_min`` (floored).
        k = math.floor(math.log(self.length_min / self.length_init, 0.5))
        split_budget = self.split_budget
        return min(self.target_dim, max(1, math.floor(split_budget / k)))


def update_state(state, Y_next):
    """Update counters, trust-region length and incumbent after a new batch.

    Args:
        state: The ``BaxusState`` to update (modified in place).
        Y_next: Tensor with the objective values of the newly evaluated batch.

    Returns:
        The same ``state`` object, for convenience.
    """
    y_next_best = max(Y_next).item()

    if math.isfinite(state.best_value):
        threshold = state.best_value + 1e-3 * math.fabs(state.best_value)
    else:
        # ``best_value`` starts at -inf. ``-inf + 1e-3 * inf`` is NaN and every
        # comparison against NaN is False, which would count the very first
        # observation as a failure. Compare against the raw value instead.
        threshold = state.best_value

    if y_next_best > threshold:
        state.success_counter += 1
        state.failure_counter = 0
    else:
        state.success_counter = 0
        state.failure_counter += 1

    if state.success_counter == state.success_tolerance:  # Expand trust region
        state.length = min(2.0 * state.length, state.length_max)
        state.success_counter = 0
    elif state.failure_counter == state.failure_tolerance:  # Shrink trust region
        state.length /= 2.0
        state.failure_counter = 0

    state.best_value = max(state.best_value, y_next_best)
    if state.length < state.length_min:
        state.restart_triggered = True
    return state
def embedding_matrix(input_dim: int, target_dim: int) -> torch.Tensor:
    """Build a random sparse sign embedding of shape ``(target_dim, input_dim)``.

    Every input dimension (column) is assigned to exactly one target
    dimension (row) with a random sign of +1 or -1. If ``target_dim`` is not
    smaller than ``input_dim`` the identity matrix is returned instead.
    """
    if target_dim >= input_dim:
        # Nothing to compress: the embedding is the identity.
        return torch.eye(input_dim, device=device, dtype=dtype)

    # Shuffle the input dimensions using 1-based indices; index 0 is reserved
    # for the padding column that is dropped at the end.
    shuffled_cols = torch.randperm(input_dim, device=device) + 1

    # Distribute the shuffled columns over almost equally sized bins, one per
    # target dimension, and zero-pad the shorter bins to a rectangular tensor.
    col_groups = torch.tensor_split(shuffled_cols, target_dim)
    padded_groups = torch.nn.utils.rnn.pad_sequence(col_groups, batch_first=True)

    # One extra leading column absorbs the writes caused by the padding.
    padded_matrix = torch.zeros(
        (target_dim, input_dim + 1), dtype=dtype, device=device
    )
    random_signs = (
        2 * torch.randint(2, (target_dim, input_dim), dtype=dtype, device=device) - 1
    )
    padded_matrix = padded_matrix.scatter_(1, padded_groups, random_signs)

    # Drop the padding column.
    return padded_matrix[:, 1:]
def increase_embedding_and_observations(
    S: torch.Tensor, X: torch.Tensor, n_new_bins: int
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Split the rows of an embedding and carry observations to the larger target space.

    For each row of ``S`` the contributing input dimensions (non-zero
    entries) are shuffled and divided into up to ``n_new_bins`` bins. The
    first bin stays in the row; every other bin is moved, with its signs, to
    a new row appended to the embedding. The matching column of ``X`` is
    repeated once per new row, so ``X_update @ S_update == X @ S``.

    Args:
        S: Embedding matrix of shape ``(target_dim, input_dim)``.
        X: Observations in the target space, shape ``(n, target_dim)``.
        n_new_bins: Maximum number of bins each row is divided into.

    Returns:
        ``(S_update, X_update)``: the enlarged embedding and the observations
        expressed in the enlarged target space. ``S`` and ``X`` are not modified.
    """
    assert X.size(1) == S.size(0), "Observations don't lie in row space of S"

    S_update = S.clone()
    X_update = X.clone()

    for row_idx in range(len(S)):
        row = S[row_idx]
        # reshape(-1) keeps the index tensor 1-D even for a single non-zero
        # entry (squeeze() would collapse it to a 0-d tensor).
        idxs_non_zero = torch.nonzero(row).reshape(-1)
        idxs_non_zero = idxs_non_zero[torch.randperm(len(idxs_non_zero))]

        # A row can be divided into at most as many bins as it has input dims.
        n_row_bins = min(n_new_bins, len(idxs_non_zero))
        if n_row_bins < 2:
            # Nothing to move: the row has at most one contributing input
            # dimension (or only a single bin was requested).
            continue

        non_zero_elements = row[idxs_non_zero]

        new_bins = torch.tensor_split(idxs_non_zero, n_row_bins)[
            1:
        ]  # the dims in the first bin won't be moved
        elements_to_move = torch.tensor_split(non_zero_elements, n_row_bins)[1:]

        # Shift to 1-based indices BEFORE zero-padding, so that padded
        # positions point at the extra column 0 (dropped below) and can never
        # collide with input dimension 0.
        new_bins_padded = torch.nn.utils.rnn.pad_sequence(
            [bin_idxs + 1 for bin_idxs in new_bins], batch_first=True
        )
        els_to_move_padded = torch.nn.utils.rnn.pad_sequence(
            elements_to_move, batch_first=True
        )

        S_stack = torch.zeros(
            (n_row_bins - 1, len(row) + 1), device=S.device, dtype=S.dtype
        )  # submatrix to stack on S_update, with one extra padding column

        S_stack = S_stack.scatter_(
            1, new_bins_padded, els_to_move_padded
        )  # fill with the moved values

        S_update[
            row_idx, torch.hstack(new_bins)
        ] = 0  # set values that were moved to zero in current row

        X_update = torch.hstack(
            (X_update, X[:, row_idx].reshape(-1, 1).repeat(1, len(new_bins)))
        )  # repeat observations for row at the end of X (column-wise)
        S_update = torch.vstack(
            (S_update, S_stack[:, 1:])
        )  # stack onto S_update except for padding column

    return S_update, X_update
def get_initial_points(dim, n_pts, seed=0):
    """Draw ``n_pts`` scrambled Sobol points in ``[-1, 1]^dim``.

    The points are created with the module-level ``dtype`` and ``device``.
    """
    engine = SobolEngine(dimension=dim, scramble=True, seed=seed)
    unit_cube_points = engine.draw(n=n_pts).to(dtype=dtype, device=device)
    # Sobol samples live in [0, 1]^d; rescale because points have to be in [-1, 1]^d
    return 2 * unit_cube_points - 1
def create_candidate(
    state,
    model,  # GP model
    X,  # Evaluated points on the domain [-1, 1]^d
    Y,  # Function values
    n_candidates=None,  # Number of candidates for Thompson sampling
    num_restarts=10,
    raw_samples=512,
    acqf="ts",  # "ei" or "ts"
):
    """Propose the next point inside the current trust region.

    The trust region is centred on the best observed point and its side
    lengths are ``state.length`` rescaled per dimension by the normalised
    kernel lengthscales of ``model``; it is clamped to ``[-1, 1]^d``.

    Args:
        state: Current ``BaxusState``; only ``state.length`` is read.
        model: Fitted GP whose ``covar_module.base_kernel.lengthscale`` is used.
        X: Evaluated points in ``[-1, 1]^d``.
        Y: Function values for ``X``; the EI incumbent is ``Y.max()``.
        n_candidates: Number of Thompson-sampling candidates. Defaults to
            ``min(5000, max(2000, 200 * d))``.
        num_restarts: Passed to ``optimize_acqf`` (EI only).
        raw_samples: Passed to ``optimize_acqf`` (EI only).
        acqf: ``"ts"`` for Thompson sampling or ``"ei"`` for Expected Improvement.

    Returns:
        The proposed point(s) ``X_next``.
    """
    assert acqf in ("ts", "ei")
    assert X.min() >= -1.0 and X.max() <= 1.0 and torch.all(torch.isfinite(Y))
    if n_candidates is None:
        n_candidates = min(5000, max(2000, 200 * X.shape[-1]))

    # Scale the TR to be proportional to the lengthscales
    x_center = X[Y.argmax(), :].clone()
    weights = model.covar_module.base_kernel.lengthscale.detach().view(-1)
    weights = weights / weights.mean()
    weights = weights / torch.prod(weights.pow(1.0 / len(weights)))
    tr_lb = torch.clamp(x_center - weights * state.length, -1.0, 1.0)
    tr_ub = torch.clamp(x_center + weights * state.length, -1.0, 1.0)

    if acqf == "ts":
        dim = X.shape[-1]
        sobol = SobolEngine(dim, scramble=True)
        pert = sobol.draw(n_candidates).to(dtype=dtype, device=device)
        pert = tr_lb + (tr_ub - tr_lb) * pert

        # Create a perturbation mask
        prob_perturb = min(20.0 / dim, 1.0)
        mask = torch.rand(n_candidates, dim, dtype=dtype, device=device) <= prob_perturb
        # Make sure every candidate perturbs at least one dimension
        ind = torch.where(mask.sum(dim=1) == 0)[0]
        mask[ind, torch.randint(0, dim, size=(len(ind),), device=device)] = 1

        # Create candidate points from the perturbations and the mask
        X_cand = x_center.expand(n_candidates, dim).clone()
        X_cand[mask] = pert[mask]

        # Sample on the candidate points
        thompson_sampling = MaxPosteriorSampling(model=model, replacement=False)
        with torch.no_grad():  # We don't need gradients when using TS
            X_next = thompson_sampling(X_cand, num_samples=1)

    elif acqf == "ei":
        # Use the ``Y`` argument as the incumbent source; the function must
        # not depend on a notebook-level ``train_Y`` variable.
        ei = ExpectedImprovement(model, Y.max(), maximize=True)
        X_next, _ = optimize_acqf(
            ei,
            bounds=torch.stack([tr_lb, tr_ub]),
            q=1,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
        )

    return X_next
Best value: -0.456, TR length: 0.4\n", + "iteration 26, d=6) Best value: -0.42, TR length: 0.4\n", + "iteration 27, d=6) Best value: -0.42, TR length: 0.2\n", + "iteration 28, d=6) Best value: -0.42, TR length: 0.1\n", + "iteration 29, d=6) Best value: -0.42, TR length: 0.05\n", + "iteration 30, d=6) Best value: -0.42, TR length: 0.025\n", + "iteration 31, d=6) Best value: -0.41, TR length: 0.025\n", + "iteration 32, d=6) Best value: -0.403, TR length: 0.025\n", + "iteration 33, d=6) Best value: -0.4, TR length: 0.05\n", + "iteration 34, d=6) Best value: -0.399, TR length: 0.05\n", + "iteration 35, d=6) Best value: -0.399, TR length: 0.025\n", + "iteration 36, d=6) Best value: -0.398, TR length: 0.025\n", + "iteration 37, d=6) Best value: -0.398, TR length: 0.0125\n", + "iteration 38, d=6) Best value: -0.398, TR length: 0.00625\n", + "increasing target space\n", + "new dimensionality: 18\n", + "iteration 39, d=18) Best value: -0.398, TR length: 0.4\n", + "iteration 40, d=18) Best value: -0.398, TR length: 0.2\n", + "iteration 41, d=18) Best value: -0.398, TR length: 0.1\n", + "iteration 42, d=18) Best value: -0.398, TR length: 0.05\n", + "iteration 43, d=18) Best value: -0.398, TR length: 0.025\n", + "iteration 44, d=18) Best value: -0.398, TR length: 0.0125\n", + "iteration 45, d=18) Best value: -0.398, TR length: 0.00625\n", + "increasing target space\n", + "new dimensionality: 54\n", + "iteration 46, d=54) Best value: -0.398, TR length: 0.4\n", + "iteration 47, d=54) Best value: -0.398, TR length: 0.2\n", + "iteration 48, d=54) Best value: -0.398, TR length: 0.1\n", + "iteration 49, d=54) Best value: -0.398, TR length: 0.05\n", + "iteration 50, d=54) Best value: -0.398, TR length: 0.025\n", + "iteration 51, d=54) Best value: -0.398, TR length: 0.0125\n", + "iteration 52, d=54) Best value: -0.398, TR length: 0.00625\n", + "increasing target space\n", + "new dimensionality: 162\n", + "iteration 53, d=162) Best value: -0.398, TR length: 0.8\n", + "iteration 54, 
d=162) Best value: -0.398, TR length: 0.8\n", + "iteration 55, d=162) Best value: -0.398, TR length: 0.4\n", + "iteration 56, d=162) Best value: -0.398, TR length: 0.4\n", + "iteration 57, d=162) Best value: -0.398, TR length: 0.4\n", + "iteration 58, d=162) Best value: -0.398, TR length: 0.2\n", + "iteration 59, d=162) Best value: -0.398, TR length: 0.2\n", + "iteration 60, d=162) Best value: -0.398, TR length: 0.2\n", + "iteration 61, d=162) Best value: -0.398, TR length: 0.1\n", + "iteration 62, d=162) Best value: -0.398, TR length: 0.1\n", + "iteration 63, d=162) Best value: -0.398, TR length: 0.1\n", + "iteration 64, d=162) Best value: -0.398, TR length: 0.05\n", + "iteration 65, d=162) Best value: -0.398, TR length: 0.05\n", + "iteration 66, d=162) Best value: -0.398, TR length: 0.05\n", + "iteration 67, d=162) Best value: -0.398, TR length: 0.025\n", + "iteration 68, d=162) Best value: -0.398, TR length: 0.025\n", + "iteration 69, d=162) Best value: -0.398, TR length: 0.025\n", + "iteration 70, d=162) Best value: -0.398, TR length: 0.0125\n", + "iteration 71, d=162) Best value: -0.398, TR length: 0.0125\n", + "iteration 72, d=162) Best value: -0.398, TR length: 0.0125\n", + "iteration 73, d=162) Best value: -0.398, TR length: 0.00625\n", + "increasing target space\n", + "new dimensionality: 486\n", + "iteration 74, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 75, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 76, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 77, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 78, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 79, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 80, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 81, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 82, d=486) Best value: -0.398, TR length: 0.8\n", + "iteration 83, d=486) Best value: -0.398, TR length: 0.4\n", + "iteration 84, d=486) Best value: 
evaluation_budget = 100

state = BaxusState(dim=dim, eval_budget=evaluation_budget - n_init)
S = embedding_matrix(input_dim=state.dim, target_dim=state.d_init)

# Initial design: sample in the target space, project to the input space with S
X_baxus_target = get_initial_points(state.d_init, n_init)
X_baxus_input = X_baxus_target @ S
Y_baxus = torch.tensor(
    [branin_emb(x) for x in X_baxus_input], dtype=dtype, device=device
).unsqueeze(-1)

# Seed the incumbent with the best value of the initial design. Otherwise
# ``best_value`` stays at -inf until the first candidate is evaluated, so the
# initial points are ignored by the success/failure bookkeeping and by the
# "Best value" printed below.
state.best_value = Y_baxus.max().item()


NUM_RESTARTS = 10 if not SMOKE_TEST else 2
RAW_SAMPLES = 512 if not SMOKE_TEST else 4
N_CANDIDATES = min(5000, max(2000, 200 * dim)) if not SMOKE_TEST else 4

# Disable input scaling checks as we normalize to [-1, 1]
with botorch.settings.validate_input_scaling(False):

    for _ in range(evaluation_budget - n_init):  # Run until evaluation budget depleted
        # Fit a GP model on standardized observations
        train_Y = (Y_baxus - Y_baxus.mean()) / Y_baxus.std()
        likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))
        covar_module = (
            ScaleKernel(  # Use the same lengthscale prior as in the TuRBO paper
                MaternKernel(
                    nu=2.5,
                    ard_num_dims=state.target_dim,
                    lengthscale_constraint=Interval(0.005, 10),
                ),
                outputscale_constraint=Interval(0.05, 10),
            )
        )
        model = SingleTaskGP(
            X_baxus_target, train_Y, covar_module=covar_module, likelihood=likelihood
        )
        mll = ExactMarginalLogLikelihood(model.likelihood, model)

        # Do the fitting and acquisition function optimization inside the Cholesky context
        with gpytorch.settings.max_cholesky_size(max_cholesky_size):
            # Fit the model
            try:
                fit_gpytorch_mll(mll)
            except ModelFittingError:
                # Right after increasing the target dimensionality, the covariance matrix becomes indefinite
                # In this case, the Cholesky decomposition might fail due to numerical instabilities
                # In this case, we revert to Adam-based optimization
                optimizer = torch.optim.Adam([{"params": model.parameters()}], lr=0.1)

                for _ in range(100):
                    optimizer.zero_grad()
                    output = model(X_baxus_target)
                    loss = -mll(output, train_Y.flatten())
                    loss.backward()
                    optimizer.step()

            # Create a batch
            X_next_target = create_candidate(
                state=state,
                model=model,
                X=X_baxus_target,
                Y=train_Y,
                n_candidates=N_CANDIDATES,
                num_restarts=NUM_RESTARTS,
                raw_samples=RAW_SAMPLES,
                acqf="ts",
            )

        # Project the candidate from the target space to the input space
        X_next_input = X_next_target @ S

        Y_next = torch.tensor(
            [branin_emb(x) for x in X_next_input], dtype=dtype, device=device
        ).unsqueeze(-1)

        # Update state
        state = update_state(state=state, Y_next=Y_next)

        # Append data
        X_baxus_input = torch.cat((X_baxus_input, X_next_input), dim=0)
        X_baxus_target = torch.cat((X_baxus_target, X_next_target), dim=0)
        Y_baxus = torch.cat((Y_baxus, Y_next), dim=0)

        # Print current status
        print(
            f"iteration {len(X_baxus_input)}, d={len(X_baxus_target.T)}) Best value: {state.best_value:.3}, TR length: {state.length:.3}"
        )

        if state.restart_triggered:
            # The trust region fell below its minimal length: split the
            # embedding, carry the observations over and reset the region.
            state.restart_triggered = False
            print("increasing target space")
            S, X_baxus_target = increase_embedding_and_observations(
                S, X_baxus_target, state.new_bins_on_split
            )
            print(f"new dimensionality: {len(S)}")
            state.target_dim = len(S)
            state.length = state.length_init
            state.failure_counter = 0
            state.success_counter = 0
# GP-EI baseline: Expected Improvement optimized over the full [-1, 1]^dim box,
# started from the same kind of Sobol initial design as the BAxUS run.
X_ei = get_initial_points(dim, n_init)
Y_ei = torch.tensor(
    [branin_emb(x) for x in X_ei], dtype=dtype, device=device
).unsqueeze(-1)

# Disable input scaling checks as we normalize to [-1, 1]
with botorch.settings.validate_input_scaling(False):
    # Run until EI has used as many evaluations as the BAxUS run
    while len(Y_ei) < len(Y_baxus):
        # Standardize the observations before fitting
        train_Y = (Y_ei - Y_ei.mean()) / Y_ei.std()
        likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))
        model = SingleTaskGP(X_ei, train_Y, likelihood=likelihood)
        mll = ExactMarginalLogLikelihood(model.likelihood, model)
        # Fit the hyperparameters with 50 Adam steps on the negative marginal log likelihood
        optimizer = torch.optim.Adam([{"params": model.parameters()}], lr=0.1)
        model.train()
        model.likelihood.train()
        for _ in range(50):
            optimizer.zero_grad()
            output = model(X_ei)
            loss = -mll(output, train_Y.squeeze())
            loss.backward()
            optimizer.step()

        # Create a batch
        # The incumbent is the best standardized observation
        ei = ExpectedImprovement(model, train_Y.max(), maximize=True)
        candidate, acq_value = optimize_acqf(
            ei,
            bounds=torch.stack(
                [
                    -torch.ones(dim, dtype=dtype, device=device),
                    torch.ones(dim, dtype=dtype, device=device),
                ]
            ),
            q=1,
            num_restarts=NUM_RESTARTS,
            raw_samples=RAW_SAMPLES,
        )
        Y_next = torch.tensor(
            [branin_emb(x) for x in candidate], dtype=dtype, device=device
        ).unsqueeze(-1)

        # Append data
        X_ei = torch.cat((X_ei, candidate), axis=0)
        Y_ei = torch.cat((Y_ei, Y_next), axis=0)

        # Print current status
        print(f"{len(X_ei)}) Best value: {Y_ei.max().item():.2e}")
# Sobol baseline: evaluate the objective on a scrambled Sobol sequence with as
# many points as there are entries in `X_baxus_input`.
sobol_engine = SobolEngine(dim, scramble=True, seed=0)
sobol_unit_points = sobol_engine.draw(len(X_baxus_input)).to(
    dtype=dtype, device=device
)
# Rescale the unit-cube draws onto the [-1, 1] box used by the other methods.
X_Sobol = sobol_unit_points * 2 - 1

sobol_values = [branin_emb(point) for point in X_Sobol]
Y_Sobol = torch.tensor(sobol_values, dtype=dtype, device=device).unsqueeze(-1)
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%matplotlib inline\n", + "\n", + "names = [\"BAxUS\", \"EI\", \"Sobol\"]\n", + "runs = [Y_baxus, Y_ei, Y_Sobol]\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "for name, run in zip(names, runs):\n", + " fx = np.maximum.accumulate(run.cpu())\n", + " plt.plot(-fx + branin.optimal_value, marker=\"\", lw=3)\n", + "\n", + "plt.ylabel(\"Regret\", fontsize=18)\n", + "plt.xlabel(\"Number of evaluations\", fontsize=18)\n", + "plt.title(f\"{dim}D Embedded Branin\", fontsize=24)\n", + "plt.xlim([0, len(Y_baxus)])\n", + "plt.yscale(\"log\")\n", + "\n", + "plt.grid(True)\n", + "plt.tight_layout()\n", + "plt.legend(\n", + " names + [\"Global optimal value\"],\n", + " loc=\"lower center\",\n", + " bbox_to_anchor=(0, -0.08, 1, 1),\n", + " bbox_transform=plt.gcf().transFigure,\n", + " ncol=4,\n", + " fontsize=16,\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "bento_stylesheets": { + "bento/extensions/flow/main.css": true, + "bento/extensions/kernel_selector/main.css": true, + "bento/extensions/kernel_ui/main.css": true, + "bento/extensions/new_kernel/main.css": true, + "bento/extensions/system_usage/main.css": true, + "bento/extensions/theme/main.css": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/website/tutorials.json b/website/tutorials.json index 1534de39c1..7f61e0b3a1 100644 --- a/website/tutorials.json +++ b/website/tutorials.json @@ -26,6 +26,10 @@ "id": "turbo_1", "title": 
"Trust Region Bayesian Optimization (TuRBO)" }, + { + "id": "baxus", + "title": "Bayesian optimization with adaptively expanding subspaces (BAxUS)" + }, { "id": "scalable_constrained_bo", "title": "Scalable Constrained Bayesian Optimization (SCBO)"