diff --git a/botorch/models/transforms/__init__.py b/botorch/models/transforms/__init__.py
index b56bf19b75..0a913c5d8f 100644
--- a/botorch/models/transforms/__init__.py
+++ b/botorch/models/transforms/__init__.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from botorch.models.transforms.factory import get_rounding_input_transform
 from botorch.models.transforms.input import (
     ChainedInputTransform,
     Normalize,
@@ -20,6 +21,7 @@
 __all__ = [
+    "get_rounding_input_transform",
     "Bilog",
     "ChainedInputTransform",
     "ChainedOutcomeTransform",
diff --git a/botorch/models/transforms/factory.py b/botorch/models/transforms/factory.py
new file mode 100644
index 0000000000..847fdf1b7c
--- /dev/null
+++ b/botorch/models/transforms/factory.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from collections import OrderedDict
+from typing import Dict, List, Optional
+
+from botorch.models.transforms.input import (
+    ChainedInputTransform,
+    Normalize,
+    OneHotToNumeric,
+    Round,
+)
+from torch import Tensor
+
+
+def get_rounding_input_transform(
+    one_hot_bounds: Tensor,
+    integer_indices: Optional[List[int]] = None,
+    categorical_features: Optional[Dict[int, int]] = None,
+    initialization: bool = False,
+    return_numeric: bool = False,
+    approximate: bool = False,
+) -> ChainedInputTransform:
+    """Get a rounding input transform.
+
+    The rounding function will take inputs from the unit cube,
+    unnormalize the integers to the raw search space, round the inputs,
+    and normalize them back to the unit cube.
+
+    Categoricals are assumed to be one-hot encoded. Integers are
+    currently assumed to be contiguous ranges (e.g. [1,2,3] and not
+    [1,5,7]).
+
+    TODO: support non-contiguous sets of integers by modifying
+    the rounding function.
+
+    Args:
+        one_hot_bounds: The raw search space bounds where categoricals are
+            encoded in one-hot representation and the integer parameters
+            are not normalized.
+        integer_indices: The indices of the integer parameters.
+        categorical_features: A dictionary mapping indices to cardinalities
+            for the categorical features.
+        initialization: A boolean indicating whether this exact rounding
+            function is for initialization. For initialization, the bounds
+            are expanded such that the end point of a range is selected
+            with the same probability as an interior point, after
+            rounding.
+        return_numeric: A boolean indicating whether to return numeric or
+            one-hot encoded categoricals. Returning a numeric
+            representation is helpful if the downstream code (e.g. a kernel)
+            expects a numeric representation of the categoricals.
+        approximate: A boolean indicating whether to use an approximate
+            rounding function.
+
+    Returns:
+        The rounding function ChainedInputTransform.
+    """
+    has_integers = integer_indices is not None and len(integer_indices) > 0
+    has_categoricals = (
+        categorical_features is not None and len(categorical_features) > 0
+    )
+    if not (has_integers or has_categoricals):
+        raise ValueError(
+            "A rounding function is a no-op "
+            "if there are no integer or categorical parameters."
+        )
+    if initialization and has_integers:
+        # this gives the extreme integer values (end points)
+        # the same probability as the interior values of the range
+        init_one_hot_bounds = one_hot_bounds.clone()
+        init_one_hot_bounds[0, integer_indices] -= 0.4999
+        init_one_hot_bounds[1, integer_indices] += 0.4999
+    else:
+        init_one_hot_bounds = one_hot_bounds
+
+    tfs = OrderedDict()
+    if has_integers:
+        # unnormalize to integer space
+        tfs["unnormalize_tf"] = Normalize(
+            d=init_one_hot_bounds.shape[1],
+            bounds=init_one_hot_bounds,
+            indices=integer_indices,
+            transform_on_train=False,
+            transform_on_eval=True,
+            transform_on_fantasize=True,
+            reverse=True,
+        )
+    # round
+    tfs["round"] = Round(
+        approximate=approximate,
+        transform_on_train=False,
+        transform_on_fantasize=True,
+        integer_indices=integer_indices,
+        categorical_features=categorical_features,
+    )
+    if has_integers:
+        # renormalize to unit cube
+        tfs["normalize_tf"] = Normalize(
+            d=one_hot_bounds.shape[1],
+            bounds=one_hot_bounds,
+            indices=integer_indices,
+            transform_on_train=False,
+            transform_on_eval=True,
+            transform_on_fantasize=True,
+            reverse=False,
+        )
+    if return_numeric and has_categoricals:
+        tfs["one_hot_to_numeric"] = OneHotToNumeric(
+            # this is the dimension of the one-hot encoded representation
+            dim=one_hot_bounds.shape[-1],
+            categorical_features=categorical_features,
+            transform_on_train=True,
+            transform_on_eval=True,
+            transform_on_fantasize=True,
+        )
+    tf = ChainedInputTransform(**tfs)
+    tf.to(dtype=one_hot_bounds.dtype, device=one_hot_bounds.device)
+    tf.eval()
+    return tf
diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py
index 611944a3d3..09310163b5 100644
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -31,6 +31,7 @@
 from torch import nn, Tensor
 from torch.distributions import Kumaraswamy
 from torch.nn import Module, ModuleDict
+from torch.nn.functional import one_hot
 
 
 class InputTransform(ABC):
@@ -364,7 +365,7 @@ def __init__(
             raise ValueError("Elements of `indices` have to be smaller than `d`!")
         if len(indices.unique()) != len(indices):
             raise ValueError("Elements of `indices` tensor must be unique!")
-        self.indices = indices
+        self.register_buffer("indices", indices)
         torch.broadcast_shapes(coefficient.shape, offset.shape)
         self._d = d
@@ -1370,3 +1371,135 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor:
         else:
             p = p(X) if self.indices is None else p(X[..., self.indices])
         return p.transpose(-3, -2)  # p is batch_shape x n_p x n x d
+
+
+class OneHotToNumeric(InputTransform, Module):
+    r"""Transform categorical parameters from a one-hot to a numeric representation.
+
+    This assumes that the categoricals are the trailing dimensions.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        categorical_features: Optional[Dict[int, int]] = None,
+        transform_on_train: bool = True,
+        transform_on_eval: bool = True,
+        transform_on_fantasize: bool = True,
+    ) -> None:
+        r"""Initialize.
+
+        Args:
+            dim: The dimension of the one-hot-encoded input.
+            categorical_features: A dictionary mapping the starting index of each
+                categorical feature to its cardinality. This assumes that categoricals
+                are one-hot encoded.
+            transform_on_train: A boolean indicating whether to apply the
+                transforms in train() mode. Default: True.
+            transform_on_eval: A boolean indicating whether to apply the
+                transform in eval() mode. Default: True.
+            transform_on_fantasize: A boolean indicating whether to apply the
+                transform when called from within a `fantasize` call.
+                Default: True.
+        """
+        super().__init__()
+        self.transform_on_train = transform_on_train
+        self.transform_on_eval = transform_on_eval
+        self.transform_on_fantasize = transform_on_fantasize
+        categorical_features = categorical_features or {}
+        # sort by starting index
+        self.categorical_features = OrderedDict(
+            sorted(categorical_features.items(), key=lambda x: x[0])
+        )
+        if len(self.categorical_features) > 0:
+            self.categorical_start_idx = min(self.categorical_features.keys())
+            # check that the trailing dimensions are categoricals
+            end = self.categorical_start_idx
+            err_msg = (
+                f"{self.__class__.__name__} requires that the categorical "
+                "parameters are the rightmost elements."
+            )
+            for start, card in self.categorical_features.items():
+                # the end of one one-hot representation should be followed
+                # by the start of the next
+                if end != start:
+                    raise ValueError(err_msg)
+                # This assumes that the categoricals are the trailing
+                # dimensions
+                end = start + card
+            if end != dim:
+                # check end
+                raise ValueError(err_msg)
+            # the numeric representation dimension is the total number of parameters
+            # (continuous, integer, and categorical)
+            self.numeric_dim = self.categorical_start_idx + len(categorical_features)
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Transform the categorical inputs into integer representation.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x n x d'`-dim tensor where the one-hot encoded
+            categoricals are transformed to integer representation.
+        """
+        if len(self.categorical_features) > 0:
+            X_numeric = X[..., : self.numeric_dim].clone()
+            idx = self.categorical_start_idx
+            for start, card in self.categorical_features.items():
+                X_numeric[..., idx] = X[..., start : start + card].argmax(dim=-1)
+                idx += 1
+            return X_numeric
+        return X
+
+    def untransform(self, X: Tensor) -> Tensor:
+        r"""Transform the categoricals from integer representation to one-hot.
+
+        Args:
+            X: A `batch_shape x n x d'`-dim tensor of transformed inputs, where
+                the categoricals are represented as integers.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of inputs, where the categoricals
+            have been transformed to one-hot representation.
+        """
+        if len(self.categorical_features) > 0:
+            one_hot_categoricals = [
+                # note that self.categorical_features is sorted by the starting index
+                # in one-hot representation
+                one_hot(
+                    X[..., idx - len(self.categorical_features)].long(),
+                    num_classes=cardinality,
+                )
+                for idx, cardinality in enumerate(self.categorical_features.values())
+            ]
+            X = torch.cat(
+                [
+                    X[..., : self.categorical_start_idx],
+                    *one_hot_categoricals,
+                ],
+                dim=-1,
+            )
+        return X
+
+    def equals(self, other: InputTransform) -> bool:
+        r"""Check if another input transform is equivalent.
+
+        Args:
+            other: Another input transform.
+
+        Returns:
+            A boolean indicating if the other transform is equivalent.
+ """ + return ( + type(self) == type(other) + and (self.transform_on_train == other.transform_on_train) + and (self.transform_on_eval == other.transform_on_eval) + and (self.transform_on_fantasize == other.transform_on_fantasize) + and self.categorical_features == other.categorical_features + ) diff --git a/botorch/models/transforms/utils.py b/botorch/models/transforms/utils.py index a17e11c494..6d1f9411d2 100644 --- a/botorch/models/transforms/utils.py +++ b/botorch/models/transforms/utils.py @@ -7,7 +7,6 @@ from __future__ import annotations from functools import wraps - from typing import Tuple import torch diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst index fa6d74c5f6..0660a7cc1f 100644 --- a/sphinx/source/models.rst +++ b/sphinx/source/models.rst @@ -136,6 +136,11 @@ Input Transforms .. automodule:: botorch.models.transforms.input :members: +Transform Factory Methods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.transforms.factory + :members: + Transform Utilities ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.models.transforms.utils diff --git a/test/models/transforms/test_factory.py b/test/models/transforms/test_factory.py new file mode 100644 index 0000000000..aa9c11aaac --- /dev/null +++ b/test/models/transforms/test_factory.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from botorch.models.transforms.factory import get_rounding_input_transform +from botorch.models.transforms.input import ChainedInputTransform, Normalize, Round +from botorch.utils.rounding import OneHotArgmaxSTE +from botorch.utils.testing import BotorchTestCase +from botorch.utils.transforms import normalize, unnormalize + + +class TestGetRoundingInputTransform(BotorchTestCase): + def test_get_rounding_input_transform(self): + for dtype in (torch.float, torch.double): + one_hot_bounds = torch.tensor( + [ + [0, 5], + [0, 4], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + ], + dtype=dtype, + device=self.device, + ).t() + with self.assertRaises(ValueError): + # test no integer or categorical + get_rounding_input_transform( + one_hot_bounds=one_hot_bounds, + ) + integer_indices = [1] + categorical_features = {2: 2, 4: 3} + tf = get_rounding_input_transform( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + ) + self.assertIsInstance(tf, ChainedInputTransform) + tfs = list(tf.items()) + self.assertEqual(len(tfs), 3) + # test unnormalize + tf_name_i, tf_i = tfs[0] + self.assertEqual(tf_name_i, "unnormalize_tf") + self.assertIsInstance(tf_i, Normalize) + self.assertTrue(tf_i.reverse) + bounds = one_hot_bounds[:, integer_indices] + offset = bounds[:1, :] + coefficient = bounds[1:2, :] - offset + self.assertTrue(torch.equal(tf_i.coefficient, coefficient)) + self.assertTrue(torch.equal(tf_i.offset, offset)) + self.assertEqual(tf_i._d, one_hot_bounds.shape[1]) + self.assertEqual( + tf_i.indices, torch.tensor(integer_indices, device=self.device) + ) + # test round + tf_name_i, tf_i = tfs[1] + self.assertEqual(tf_name_i, "round") + self.assertIsInstance(tf_i, Round) + self.assertEqual(tf_i.integer_indices.tolist(), integer_indices) + self.assertEqual(tf_i.categorical_features, categorical_features) + # test normalize + tf_name_i, tf_i = tfs[2] + self.assertEqual(tf_name_i, "normalize_tf") + 
+            self.assertIsInstance(tf_i, Normalize)
+            self.assertFalse(tf_i.reverse)
+            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
+            self.assertTrue(torch.equal(tf_i.offset, offset))
+            self.assertEqual(tf_i._d, one_hot_bounds.shape[1])
+
+            # test forward
+            X = torch.rand(
+                2, 4, one_hot_bounds.shape[1], dtype=dtype, device=self.device
+            )
+            X_tf = tf(X)
+            # assert the continuous param is unaffected
+            self.assertTrue(torch.equal(X_tf[..., 0], X[..., 0]))
+            # check that integer params are rounded
+            X_int = X[..., integer_indices]
+            unnormalized_X_int = unnormalize(X_int, bounds)
+            rounded_X_int = normalize(unnormalized_X_int.round(), bounds)
+            self.assertTrue(torch.equal(rounded_X_int, X_tf[..., integer_indices]))
+            # check that categoricals are discretized
+            for start, card in categorical_features.items():
+                end = start + card
+                discretized_feat = OneHotArgmaxSTE.apply(X[..., start:end])
+                self.assertTrue(torch.equal(discretized_feat, X_tf[..., start:end]))
+            # test transform on train/eval/fantasize
+            for tf_i in tf.values():
+                self.assertFalse(tf_i.transform_on_train)
+                self.assertTrue(tf_i.transform_on_eval)
+                self.assertTrue(tf_i.transform_on_fantasize)
+
+            # test no integer
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                categorical_features=categorical_features,
+            )
+            tfs = list(tf.items())
+            # round should be the only transform
+            self.assertEqual(len(tfs), 1)
+            tf_name_i, tf_i = tfs[0]
+            self.assertEqual(tf_name_i, "round")
+            self.assertIsInstance(tf_i, Round)
+            self.assertEqual(tf_i.integer_indices.tolist(), [])
+            self.assertEqual(tf_i.categorical_features, categorical_features)
+            # test no categoricals
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+            )
+            tfs = list(tf.items())
+            self.assertEqual(len(tfs), 3)
+            _, tf_i = tfs[1]
+            self.assertEqual(tf_i.integer_indices.tolist(), integer_indices)
+            self.assertEqual(tf_i.categorical_features, {})
+            # test initialization
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+                categorical_features=categorical_features,
+                initialization=True,
+            )
+            tfs = list(tf.items())
+            self.assertEqual(len(tfs), 3)
+            # check that bounds are adjusted for integers on unnormalize
+            _, tf_i = tfs[0]
+            offset_init = bounds[:1, :] - 0.4999
+            coefficient_init = bounds[1:2, :] + 0.4999 - offset_init
+            self.assertTrue(torch.equal(tf_i.coefficient, coefficient_init))
+            self.assertTrue(torch.equal(tf_i.offset, offset_init))
+            # check that bounds are adjusted for integers on normalize
+            _, tf_i = tfs[2]
+            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
+            self.assertTrue(torch.equal(tf_i.offset, offset))
+            # test return numeric
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+                categorical_features=categorical_features,
+                return_numeric=True,
+            )
+            tfs = list(tf.items())
+            self.assertEqual(len(tfs), 4)
+            tf_name_i, tf_i = tfs[3]
+            self.assertEqual(tf_name_i, "one_hot_to_numeric")
+            # transform to numeric on train
+            # (e.g. for kernels that expect an integer representation)
+            self.assertTrue(tf_i.transform_on_train)
+            self.assertTrue(tf_i.transform_on_eval)
+            self.assertTrue(tf_i.transform_on_fantasize)
+            self.assertEqual(tf_i.categorical_features, categorical_features)
+            self.assertEqual(tf_i.numeric_dim, 4)
+            # test return numeric and no categorical
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+                return_numeric=True,
+            )
+            tfs = list(tf.items())
+            # there should be no one-hot to numeric transform
+            self.assertEqual(len(tfs), 3)
diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py
index 456a8b47ef..1893f3d982 100644
--- a/test/models/transforms/test_input.py
+++ b/test/models/transforms/test_input.py
@@ -21,6 +21,7 @@
     InputTransform,
     Log10,
     Normalize,
+    OneHotToNumeric,
     Round,
     Warp,
 )
@@ -196,7 +197,13 @@ def test_normalize(self):
             self.assertEqual(nlz.mins.shape, torch.Size([1, 1]))
             self.assertEqual(nlz.ranges.shape, torch.Size([1, 1]))
             self.assertEqual(len(nlz.indices), 1)
-            self.assertTrue((nlz.indices == torch.tensor([0], dtype=torch.long)).all())
+            nlz.to(device=self.device)
+            self.assertTrue(
+                (
+                    nlz.indices
+                    == torch.tensor([0], dtype=torch.long, device=self.device)
+                ).all()
+            )
 
             # test .to
             other_dtype = torch.float if dtype == torch.double else torch.double
@@ -231,7 +238,7 @@ def test_normalize(self):
 
             nlz.eval()
             X_unnlzd = nlz.untransform(X_nlzd)
-            self.assertAllClose(X, X_unnlzd, atol=1e-4, rtol=1e-4)
+            self.assertAllClose(X, X_unnlzd, atol=1e-3, rtol=1e-3)
             expected_bounds = torch.cat(
                 [X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
                 dim=-2,
@@ -381,17 +388,25 @@ def test_standardize(self):
             self.assertEqual(stdz.means.shape, torch.Size([1, 1]))
             self.assertEqual(stdz.stds.shape, torch.Size([1, 1]))
             self.assertEqual(len(stdz.indices), 1)
+            stdz.to(device=self.device)
             self.assertTrue(
-                torch.equal(stdz.indices, torch.tensor([0], dtype=torch.long))
+                torch.equal(
+                    stdz.indices,
+                    torch.tensor([0], dtype=torch.long, device=self.device),
+                )
             )
             stdz = InputStandardize(d=2, indices=[0], batch_shape=torch.Size([3]))
+            stdz.to(device=self.device)
             self.assertTrue(stdz.training)
             self.assertEqual(stdz._d, 2)
             self.assertEqual(stdz.means.shape, torch.Size([3, 1, 1]))
             self.assertEqual(stdz.stds.shape, torch.Size([3, 1, 1]))
             self.assertEqual(len(stdz.indices), 1)
             self.assertTrue(
-                torch.equal(stdz.indices, torch.tensor([0], dtype=torch.long))
+                torch.equal(
+                    stdz.indices,
+                    torch.tensor([0], device=self.device, dtype=torch.long),
+                )
             )
 
         # test jitter
@@ -915,6 +930,85 @@ def test_warp_transform(self):
             warp_tf._set_concentration(i=1, value=3.0)
             self.assertTrue((warp_tf.concentration1 == 3.0).all())
 
+    def test_one_hot_to_numeric(self):
+        dim = 8
+        # test exception when categoricals are not the trailing dimensions
+        categorical_features = {0: 2}
+        with self.assertRaises(ValueError):
+            OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+        # categoricals at start and end of X but not in between
+        categorical_features = {0: 3, 6: 2}
+        with self.assertRaises(ValueError):
+            OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+        for dtype in (torch.float, torch.double):
+            categorical_features = {6: 2, 3: 3}
+            tf = OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+            tf.eval()
+            self.assertEqual(tf.categorical_features, {3: 3, 6: 2})
+            cat1_numeric = torch.randint(0, 3, (3,), device=self.device)
+            cat1 = one_hot(cat1_numeric, num_classes=3)
+            cat2_numeric = torch.randint(0, 2, (3,), device=self.device)
+            cat2 = one_hot(cat2_numeric, num_classes=2)
+            cont = torch.rand(3, 3, dtype=dtype, device=self.device)
+            X = torch.cat([cont, cat1, cat2], dim=-1)
+            # test forward
+            X_numeric = tf(X)
+            expected = torch.cat(
+                [
+                    cont,
+                    cat1_numeric.view(-1, 1).to(cont),
+                    cat2_numeric.view(-1, 1).to(cont),
+                ],
+                dim=-1,
+            )
+            self.assertTrue(torch.equal(X_numeric, expected))
+
+            # test untransform
+            X2 = tf.untransform(X_numeric)
+            self.assertTrue(torch.equal(X2, X))
+
+            # test no categoricals
+            tf = OneHotToNumeric(dim=dim, categorical_features={})
+            tf.eval()
+            X_tf = tf(X)
+            self.assertTrue(torch.equal(X, X_tf))
+            X2 = tf(X_tf)
+            self.assertTrue(torch.equal(X2, X_tf))
+
+            # test no transform on eval
+            tf2 = OneHotToNumeric(
+                dim=dim, categorical_features=categorical_features, transform_on_eval=False
+            )
+            tf2.eval()
+            X_tf = tf2(X)
+            self.assertTrue(torch.equal(X, X_tf))
+
+            # test no transform on train
+            tf2 = OneHotToNumeric(
+                dim=dim, categorical_features=categorical_features, transform_on_train=False
+            )
+            X_tf = tf2(X)
+            self.assertTrue(torch.equal(X, X_tf))
+            tf2.eval()
+            X_tf = tf2(X)
+            self.assertFalse(torch.equal(X, X_tf))
+
+            # test equals
+            tf3 = OneHotToNumeric(
+                dim=dim, categorical_features=categorical_features, transform_on_train=False
+            )
+            self.assertTrue(tf3.equals(tf2))
+            # test different transform_on_train
+            tf3 = OneHotToNumeric(
+                dim=dim, categorical_features=categorical_features, transform_on_train=True
+            )
+            self.assertFalse(tf3.equals(tf2))
+            # test categorical features
+            tf3 = OneHotToNumeric(
+                dim=dim, categorical_features={}, transform_on_train=False
+            )
+            self.assertFalse(tf3.equals(tf2))
+
 
 class TestAppendFeatures(BotorchTestCase):
     def test_append_features(self):
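
For reviewers, a minimal usage sketch of the new `get_rounding_input_transform` factory. The search space here is hypothetical and simply mirrors the test case above: one continuous parameter, one integer parameter on [0, 4], and two trailing one-hot encoded categoricals with cardinalities 2 and 3.

```python
import torch

from botorch.models.transforms.factory import get_rounding_input_transform

# Hypothetical raw bounds: columns are [continuous, integer, 2 one-hot
# columns for the first categorical, 3 one-hot columns for the second].
one_hot_bounds = torch.tensor(
    [
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [5.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    ]
)
tf = get_rounding_input_transform(
    one_hot_bounds=one_hot_bounds,
    integer_indices=[1],
    categorical_features={2: 2, 4: 3},
)
X = torch.rand(4, 7)  # candidate points in the normalized unit cube
X_rounded = tf(X)
# Column 0 is unchanged; column 1 is unnormalized to [0, 4], rounded, and
# renormalized to the unit cube; each one-hot block is discretized via argmax.
```

The factory returns the chained transform in `eval()` mode, so calling it directly applies the rounding (all component transforms set `transform_on_eval=True`).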
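A similar sketch for the `OneHotToNumeric` round trip, again with hypothetical dimensions: two continuous columns followed by a single 3-class categorical, so `dim=5` and the numeric representation has 3 columns.

```python
import torch
from torch.nn.functional import one_hot

from botorch.models.transforms.input import OneHotToNumeric

tf = OneHotToNumeric(dim=5, categorical_features={2: 3})
tf.eval()  # the transform is applied in eval mode by default
cont = torch.rand(4, 2)
cat = one_hot(torch.randint(0, 3, (4,)), num_classes=3)
X = torch.cat([cont, cat.to(cont)], dim=-1)  # one-hot representation
X_numeric = tf(X)  # 4 x 3: the categorical block collapses to an integer column
X_round_trip = tf.untransform(X_numeric)  # back to one-hot
assert torch.equal(X_round_trip, X)
```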