diff --git a/botorch/models/transforms/__init__.py b/botorch/models/transforms/__init__.py
index b56bf19b75..0a913c5d8f 100644
--- a/botorch/models/transforms/__init__.py
+++ b/botorch/models/transforms/__init__.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from botorch.models.transforms.factory import get_rounding_input_transform
 from botorch.models.transforms.input import (
     ChainedInputTransform,
     Normalize,
@@ -20,6 +21,7 @@
 __all__ = [
+    "get_rounding_input_transform",
     "Bilog",
     "ChainedInputTransform",
     "ChainedOutcomeTransform",
diff --git a/botorch/models/transforms/factory.py b/botorch/models/transforms/factory.py
new file mode 100644
index 0000000000..847fdf1b7c
--- /dev/null
+++ b/botorch/models/transforms/factory.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from collections import OrderedDict
+from typing import Dict, List, Optional
+
+from botorch.models.transforms.input import (
+    ChainedInputTransform,
+    Normalize,
+    OneHotToNumeric,
+    Round,
+)
+from torch import Tensor
+
+
+def get_rounding_input_transform(
+    one_hot_bounds: Tensor,
+    integer_indices: Optional[List[int]] = None,
+    categorical_features: Optional[Dict[int, int]] = None,
+    initialization: bool = False,
+    return_numeric: bool = False,
+    approximate: bool = False,
+) -> ChainedInputTransform:
+    """Get a rounding input transform.
+
+    The rounding function will take inputs from the unit cube,
+    unnormalize the integers to the raw search space, round the inputs,
+    and normalize them back to the unit cube.
+
+    Categoricals are assumed to be one-hot encoded. Integers are
+    currently assumed to be contiguous ranges (e.g. [1,2,3] and not
+    [1,5,7]).
+
+    TODO: support non-contiguous sets of integers by modifying
+    the rounding function.
+
+    Args:
+        one_hot_bounds: The raw search space bounds where categoricals are
+            encoded in one-hot representation and the integer parameters
+            are not normalized.
+        integer_indices: The indices of the integer parameters.
+        categorical_features: A dictionary mapping indices to cardinalities
+            for the categorical features.
+        initialization: A boolean indicating whether this exact rounding
+            function is used for initialization. For initialization, the
+            bounds are expanded such that the end point of a range is
+            selected with the same probability as an interior point, after
+            rounding.
+        return_numeric: A boolean indicating whether to return numeric or
+            one-hot encoded categoricals. Returning a numeric
+            representation is helpful if the downstream code (e.g. kernel)
+            expects a numeric representation of the categoricals.
+        approximate: A boolean indicating whether to use an approximate
+            rounding function.
+
+    Returns:
+        The rounding function ChainedInputTransform.
+    """
+    has_integers = integer_indices is not None and len(integer_indices) > 0
+    has_categoricals = (
+        categorical_features is not None and len(categorical_features) > 0
+    )
+    if not (has_integers or has_categoricals):
+        raise ValueError(
+            "A rounding function is a no-op "
+            "if there are no integer or categorical parameters."
+        )
+    if initialization and has_integers:
+        # this gives the extreme integer values (end points)
+        # the same probability as the interior values of the range
+        init_one_hot_bounds = one_hot_bounds.clone()
+        init_one_hot_bounds[0, integer_indices] -= 0.4999
+        init_one_hot_bounds[1, integer_indices] += 0.4999
+    else:
+        init_one_hot_bounds = one_hot_bounds
+
+    tfs = OrderedDict()
+    if has_integers:
+        # unnormalize to integer space
+        tfs["unnormalize_tf"] = Normalize(
+            d=init_one_hot_bounds.shape[1],
+            bounds=init_one_hot_bounds,
+            indices=integer_indices,
+            transform_on_train=False,
+            transform_on_eval=True,
+            transform_on_fantasize=True,
+            reverse=True,
+        )
+    # round
+    tfs["round"] = Round(
+        approximate=approximate,
+        transform_on_train=False,
+        transform_on_fantasize=True,
+        integer_indices=integer_indices,
+        categorical_features=categorical_features,
+    )
+    if has_integers:
+        # renormalize to unit cube
+        tfs["normalize_tf"] = Normalize(
+            d=one_hot_bounds.shape[1],
+            bounds=one_hot_bounds,
+            indices=integer_indices,
+            transform_on_train=False,
+            transform_on_eval=True,
+            transform_on_fantasize=True,
+            reverse=False,
+        )
+    if return_numeric and has_categoricals:
+        tfs["one_hot_to_numeric"] = OneHotToNumeric(
+            # this is the dimension of the one-hot encoded representation
+            dim=one_hot_bounds.shape[-1],
+            categorical_features=categorical_features,
+            transform_on_train=True,
+            transform_on_eval=True,
+            transform_on_fantasize=True,
+        )
+    tf = ChainedInputTransform(**tfs)
+    tf.to(dtype=one_hot_bounds.dtype, device=one_hot_bounds.device)
+    tf.eval()
+    return tf
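Example usage (an illustrative sketch, not part of the patch): the search space below, with one continuous parameter, one integer parameter, and one 3-category one-hot encoded categorical, is made up for this example.

    import torch
    from botorch.models.transforms.factory import get_rounding_input_transform

    # Raw bounds, one column per feature: a continuous parameter in [0, 5],
    # an integer in {0, ..., 4}, and a categorical with 3 categories one-hot
    # encoded in the last three columns.
    one_hot_bounds = torch.tensor(
        [
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [5.0, 4.0, 1.0, 1.0, 1.0],
        ]
    )
    tf = get_rounding_input_transform(
        one_hot_bounds=one_hot_bounds,
        integer_indices=[1],
        categorical_features={2: 3},
    )
    # Candidates live in the unit cube. The chained transform unnormalizes
    # the integer column to [0, 4], rounds it, renormalizes it, and maps
    # each one-hot block to the one-hot encoding of its argmax.
    X = torch.rand(4, one_hot_bounds.shape[1])
    X_rounded = tf(X)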
diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py
index cd100bea8f..44dbaa2c0a 100644
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -1383,9 +1383,9 @@
     def __init__(
         self,
         dim: int,
         categorical_features: Optional[Dict[int, int]] = None,
-        transform_on_train: bool = False,
+        transform_on_train: bool = True,
         transform_on_eval: bool = True,
-        transform_on_fantasize: bool = False,
+        transform_on_fantasize: bool = True,
     ) -> None:
         r"""Initialize.
diff --git a/botorch/models/transforms/utils.py b/botorch/models/transforms/utils.py
index a17e11c494..6d1f9411d2 100644
--- a/botorch/models/transforms/utils.py
+++ b/botorch/models/transforms/utils.py
@@ -7,7 +7,6 @@
 from __future__ import annotations
 
 from functools import wraps
-
 from typing import Tuple
 
 import torch
diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst
index fa6d74c5f6..0660a7cc1f 100644
--- a/sphinx/source/models.rst
+++ b/sphinx/source/models.rst
@@ -136,6 +136,11 @@ Input Transforms
 .. automodule:: botorch.models.transforms.input
     :members:
 
+Transform Factory Methods
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: botorch.models.transforms.factory
+    :members:
+
 Transform Utilities
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: botorch.models.transforms.utils
     :members:
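The input.py change above flips OneHotToNumeric's defaults so the numeric representation is also used during training and fantasizing. A minimal sketch of the transform's behavior (the feature layout is made up for illustration):

    import torch
    from botorch.models.transforms.input import OneHotToNumeric

    # dim counts columns in the one-hot representation: one numeric column
    # followed by a 3-category one-hot block starting at index 1.
    tf = OneHotToNumeric(dim=4, categorical_features={1: 3})
    X = torch.tensor([[0.5, 0.0, 1.0, 0.0]])
    X_numeric = tf.transform(X)  # tensor([[0.5, 1.0]]): argmax of the block
    X_one_hot = tf.untransform(X_numeric)  # back to the one-hot layout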
diff --git a/test/models/transforms/test_factory.py b/test/models/transforms/test_factory.py
new file mode 100644
index 0000000000..aa9c11aaac
--- /dev/null
+++ b/test/models/transforms/test_factory.py
@@ -0,0 +1,167 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from botorch.models.transforms.factory import get_rounding_input_transform
+from botorch.models.transforms.input import ChainedInputTransform, Normalize, Round
+from botorch.utils.rounding import OneHotArgmaxSTE
+from botorch.utils.testing import BotorchTestCase
+from botorch.utils.transforms import normalize, unnormalize
+
+
+class TestGetRoundingInputTransform(BotorchTestCase):
+    def test_get_rounding_input_transform(self):
+        for dtype in (torch.float, torch.double):
+            one_hot_bounds = torch.tensor(
+                [
+                    [0, 5],
+                    [0, 4],
+                    [0, 1],
+                    [0, 1],
+                    [0, 1],
+                    [0, 1],
+                    [0, 1],
+                ],
+                dtype=dtype,
+                device=self.device,
+            ).t()
+            with self.assertRaises(ValueError):
+                # test no integer or categorical
+                get_rounding_input_transform(
+                    one_hot_bounds=one_hot_bounds,
+                )
+            integer_indices = [1]
+            categorical_features = {2: 2, 4: 3}
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+                categorical_features=categorical_features,
+            )
+            self.assertIsInstance(tf, ChainedInputTransform)
+            tfs = list(tf.items())
+            self.assertEqual(len(tfs), 3)
+            # test unnormalize
+            tf_name_i, tf_i = tfs[0]
+            self.assertEqual(tf_name_i, "unnormalize_tf")
+            self.assertIsInstance(tf_i, Normalize)
+            self.assertTrue(tf_i.reverse)
+            bounds = one_hot_bounds[:, integer_indices]
+            offset = bounds[:1, :]
+            coefficient = bounds[1:2, :] - offset
+            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
+            self.assertTrue(torch.equal(tf_i.offset, offset))
+            self.assertEqual(tf_i._d, one_hot_bounds.shape[1])
+            self.assertEqual(
+                tf_i.indices, torch.tensor(integer_indices, device=self.device)
+            )
+            # test round
+            tf_name_i, tf_i = tfs[1]
+            self.assertEqual(tf_name_i, "round")
+            self.assertIsInstance(tf_i, Round)
+            self.assertEqual(tf_i.integer_indices.tolist(), integer_indices)
+            self.assertEqual(tf_i.categorical_features, categorical_features)
+            # test normalize
+            tf_name_i, tf_i = tfs[2]
+            self.assertEqual(tf_name_i, "normalize_tf")
+            self.assertIsInstance(tf_i, Normalize)
+            self.assertFalse(tf_i.reverse)
+            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
+            self.assertTrue(torch.equal(tf_i.offset, offset))
+            self.assertEqual(tf_i._d, one_hot_bounds.shape[1])
+
+            # test forward
+            X = torch.rand(
+                2, 4, one_hot_bounds.shape[1], dtype=dtype, device=self.device
+            )
+            X_tf = tf(X)
+            # assert the continuous param is unaffected
+            self.assertTrue(torch.equal(X_tf[..., 0], X[..., 0]))
+            # check that integer params are rounded
+            X_int = X[..., integer_indices]
+            unnormalized_X_int = unnormalize(X_int, bounds)
+            rounded_X_int = normalize(unnormalized_X_int.round(), bounds)
+            self.assertTrue(torch.equal(rounded_X_int, X_tf[..., integer_indices]))
+            # check that categoricals are discretized
+            for start, card in categorical_features.items():
+                end = start + card
+                discretized_feat = OneHotArgmaxSTE.apply(X[..., start:end])
+                self.assertTrue(torch.equal(discretized_feat, X_tf[..., start:end]))
+            # test transform on train/eval/fantasize
+            for tf_i in tf.values():
+                self.assertFalse(tf_i.transform_on_train)
+                self.assertTrue(tf_i.transform_on_eval)
+                self.assertTrue(tf_i.transform_on_fantasize)
+
+            # test no integer
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                categorical_features=categorical_features,
+            )
+            tfs = list(tf.items())
+            # round should be the only transform
+            self.assertEqual(len(tfs), 1)
+            tf_name_i, tf_i = tfs[0]
+            self.assertEqual(tf_name_i, "round")
+            self.assertIsInstance(tf_i, Round)
+            self.assertEqual(tf_i.integer_indices.tolist(), [])
+            self.assertEqual(tf_i.categorical_features, categorical_features)
+            # test no categoricals
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+            )
+            tfs = list(tf.items())
+            self.assertEqual(len(tfs), 3)
+            _, tf_i = tfs[1]
+            self.assertEqual(tf_i.integer_indices.tolist(), integer_indices)
+            self.assertEqual(tf_i.categorical_features, {})
+            # test initialization
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+                categorical_features=categorical_features,
+                initialization=True,
+            )
+            tfs = list(tf.items())
+            self.assertEqual(len(tfs), 3)
+            # check that bounds are adjusted for integers on unnormalize
+            _, tf_i = tfs[0]
+            offset_init = bounds[:1, :] - 0.4999
+            coefficient_init = bounds[1:2, :] + 0.4999 - offset_init
+            self.assertTrue(torch.equal(tf_i.coefficient, coefficient_init))
+            self.assertTrue(torch.equal(tf_i.offset, offset_init))
+            # check that bounds are adjusted for integers on normalize
+            _, tf_i = tfs[2]
+            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
+            self.assertTrue(torch.equal(tf_i.offset, offset))
+            # test return numeric
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+                categorical_features=categorical_features,
+                return_numeric=True,
+            )
+            tfs = list(tf.items())
+            self.assertEqual(len(tfs), 4)
+            tf_name_i, tf_i = tfs[3]
+            self.assertEqual(tf_name_i, "one_hot_to_numeric")
+            # transform to numeric on train
+            # (e.g. for kernels that expect an integer representation)
+            self.assertTrue(tf_i.transform_on_train)
+            self.assertTrue(tf_i.transform_on_eval)
+            self.assertTrue(tf_i.transform_on_fantasize)
+            self.assertEqual(tf_i.categorical_features, categorical_features)
+            self.assertEqual(tf_i.numeric_dim, 4)
+            # test return numeric and no categorical
+            tf = get_rounding_input_transform(
+                one_hot_bounds=one_hot_bounds,
+                integer_indices=integer_indices,
+                return_numeric=True,
+            )
+            tfs = list(tf.items())
+            # there should be no one-hot-to-numeric transform
+            self.assertEqual(len(tfs), 3)
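Finally, a small sketch (bounds again made up for illustration) of what the initialization flag tested above does when drawing raw initial points:

    import torch
    from botorch.models.transforms.factory import get_rounding_input_transform

    one_hot_bounds = torch.tensor([[0.0, 0.0], [5.0, 4.0]])
    init_tf = get_rounding_input_transform(
        one_hot_bounds=one_hot_bounds,
        integer_indices=[1],
        initialization=True,
    )
    # Without the 0.4999 expansion, unnormalizing [0, 1] to [0, 4] and rounding
    # would select the end points 0 and 4 only half as often as 1, 2, or 3.
    # With the expanded bounds, all five integer values are (approximately)
    # equally likely after rounding.
    raw = torch.rand(1000, 2)
    rounded = init_tf(raw)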