forked from pytorch/botorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add utility for constructing rounding input transforms (pytorch#1531)
Summary: Pull Request resolved: pytorch#1531 see title Differential Revision: https://internalfb.com/D41497584 fbshipit-source-id: f9a4b481473393a7664766a4d16f7e2051dcd190
- Loading branch information
1 parent
ef44773
commit 3401848
Showing
6 changed files
with
296 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from __future__ import annotations | ||
|
||
from collections import OrderedDict | ||
from typing import Dict, List, Optional | ||
|
||
from botorch.models.transforms.input import ( | ||
ChainedInputTransform, | ||
Normalize, | ||
OneHotToNumeric, | ||
Round, | ||
) | ||
from torch import Tensor | ||
|
||
|
||
def get_rounding_input_transform(
    one_hot_bounds: Tensor,
    integer_indices: Optional[List[int]] = None,
    categorical_features: Optional[Dict[int, int]] = None,
    initialization: bool = False,
    return_numeric: bool = False,
    approximate: bool = False,
) -> ChainedInputTransform:
    """Get a rounding input transform.

    The rounding function will take inputs from the unit cube,
    unnormalize the integers to the raw search space, round the inputs,
    and normalize them back to the unit cube.

    Categoricals are assumed to be one-hot encoded. Integers are
    currently assumed to be contiguous ranges (e.g. [1,2,3] and not
    [1,5,7]).

    TODO: support non-contiguous sets of integers by modifying
    the rounding function.

    Args:
        one_hot_bounds: The raw search space bounds where categoricals are
            encoded in one-hot representation and the integer parameters
            are not normalized. Indexed as a `2 x d` tensor (row 0 holds the
            lower bounds, row 1 the upper bounds).
        integer_indices: The indices of the integer parameters.
        categorical_features: A dictionary mapping indices to cardinalities
            for the categorical features.
        initialization: A boolean indicating whether this exact rounding
            function is for initialization. If True (and there are integer
            parameters), the bounds used for unnormalizing the integers are
            widened by 0.4999 on each side.
        return_numeric: A boolean indicating whether to return numeric or
            one-hot encoded categoricals. Returning a numeric
            representation is helpful if the downstream code (e.g. kernel)
            expects a numeric representation of the categoricals.
        approximate: Forwarded unchanged to `Round` as its `approximate`
            argument.

    Returns:
        The rounding function ChainedInputTransform, moved to the dtype and
        device of `one_hot_bounds` and put in eval mode.

    Raises:
        ValueError: If neither integer nor categorical parameters are
            specified (the transform would be a no-op).
    """
    has_integers = integer_indices is not None and len(integer_indices) > 0
    has_categoricals = (
        categorical_features is not None and len(categorical_features) > 0
    )
    if not (has_integers or has_categoricals):
        raise ValueError(
            "A rounding function is a no-op "
            "if there are no integer or categorical parameters."
        )
    if initialization and has_integers:
        # this gives the extreme integer values (end points)
        # the same probability as the interior values of the range
        # (clone so that the caller's bounds tensor is not modified)
        init_one_hot_bounds = one_hot_bounds.clone()
        init_one_hot_bounds[0, integer_indices] -= 0.4999
        init_one_hot_bounds[1, integer_indices] += 0.4999
    else:
        init_one_hot_bounds = one_hot_bounds

    # insertion order determines the order in which the transforms are chained
    tfs = OrderedDict()
    if has_integers:
        # unnormalize to integer space
        # (uses the possibly widened initialization bounds)
        tfs["unnormalize_tf"] = Normalize(
            d=init_one_hot_bounds.shape[1],
            bounds=init_one_hot_bounds,
            indices=integer_indices,
            transform_on_train=False,
            transform_on_eval=True,
            transform_on_fantasize=True,
            reverse=True,
        )
    # round
    tfs["round"] = Round(
        approximate=approximate,
        transform_on_train=False,
        transform_on_fantasize=True,
        integer_indices=integer_indices,
        categorical_features=categorical_features,
    )
    if has_integers:
        # renormalize to unit cube
        # (always uses the original, unwidened bounds)
        tfs["normalize_tf"] = Normalize(
            d=one_hot_bounds.shape[1],
            bounds=one_hot_bounds,
            indices=integer_indices,
            transform_on_train=False,
            transform_on_eval=True,
            transform_on_fantasize=True,
            reverse=False,
        )
    if return_numeric and has_categoricals:
        tfs["one_hot_to_numeric"] = OneHotToNumeric(
            # this is the dimension using one-hot encoded representation
            dim=one_hot_bounds.shape[-1],
            categorical_features=categorical_features,
            transform_on_train=True,
            transform_on_eval=True,
            transform_on_fantasize=True,
        )
    tf = ChainedInputTransform(**tfs)
    tf.to(dtype=one_hot_bounds.dtype, device=one_hot_bounds.device)
    tf.eval()
    return tf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,6 @@ | |
from __future__ import annotations | ||
|
||
from functools import wraps | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import torch | ||
from botorch.models.transforms.factory import get_rounding_input_transform | ||
from botorch.models.transforms.input import ChainedInputTransform, Normalize, Round | ||
from botorch.utils.rounding import OneHotArgmaxSTE | ||
from botorch.utils.testing import BotorchTestCase | ||
from botorch.utils.transforms import normalize, unnormalize | ||
|
||
|
||
class TestGetRoundingInputTransform(BotorchTestCase):
    """Tests for the `get_rounding_input_transform` factory."""

    def test_get_rounding_input_transform(self):
        """Check the structure and forward behavior of the built transform.

        Covers the error for no discrete parameters, the default chain
        (unnormalize -> round -> normalize), the integer-only and
        categorical-only variants, the widened initialization bounds and
        the optional one-hot-to-numeric step.
        """
        for dtype in (torch.float, torch.double):
            # `2 x 7` bounds: one continuous dim, one integer dim and
            # five one-hot encoded categorical dims
            one_hot_bounds = torch.tensor(
                [
                    [0, 5],
                    [0, 4],
                    [0, 1],
                    [0, 1],
                    [0, 1],
                    [0, 1],
                    [0, 1],
                ],
                dtype=dtype,
                device=self.device,
            ).t()
            with self.assertRaises(ValueError):
                # test no integer or categorical
                get_rounding_input_transform(
                    one_hot_bounds=one_hot_bounds,
                )
            integer_indices = [1]
            categorical_features = {2: 2, 4: 3}
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
                categorical_features=categorical_features,
            )
            self.assertIsInstance(tf, ChainedInputTransform)
            tfs = list(tf.items())
            self.assertEqual(len(tfs), 3)
            # test unnormalize
            tf_name_i, tf_i = tfs[0]
            self.assertEqual(tf_name_i, "unnormalize_tf")
            self.assertIsInstance(tf_i, Normalize)
            self.assertTrue(tf_i.reverse)
            bounds = one_hot_bounds[:, integer_indices]
            offset = bounds[:1, :]
            coefficient = bounds[1:2, :] - offset
            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
            self.assertTrue(torch.equal(tf_i.offset, offset))
            self.assertEqual(tf_i._d, one_hot_bounds.shape[1])
            # compare as a list: `assertEqual` on tensors relies on the
            # truthiness of an elementwise `==`, which only works for
            # single-element tensors
            self.assertEqual(tf_i.indices.tolist(), integer_indices)
            # test round
            tf_name_i, tf_i = tfs[1]
            self.assertEqual(tf_name_i, "round")
            self.assertIsInstance(tf_i, Round)
            self.assertEqual(tf_i.integer_indices.tolist(), integer_indices)
            self.assertEqual(tf_i.categorical_features, categorical_features)
            # test normalize
            tf_name_i, tf_i = tfs[2]
            self.assertEqual(tf_name_i, "normalize_tf")
            self.assertIsInstance(tf_i, Normalize)
            self.assertFalse(tf_i.reverse)
            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
            self.assertTrue(torch.equal(tf_i.offset, offset))
            self.assertEqual(tf_i._d, one_hot_bounds.shape[1])

            # test forward
            X = torch.rand(
                2, 4, one_hot_bounds.shape[1], dtype=dtype, device=self.device
            )
            X_tf = tf(X)
            # assert the continuous param is unaffected
            self.assertTrue(torch.equal(X_tf[..., 0], X[..., 0]))
            # check that integer params are rounded
            X_int = X[..., integer_indices]
            unnormalized_X_int = unnormalize(X_int, bounds)
            rounded_X_int = normalize(unnormalized_X_int.round(), bounds)
            self.assertTrue(torch.equal(rounded_X_int, X_tf[..., integer_indices]))
            # check that categoricals are discretized
            for start, card in categorical_features.items():
                end = start + card
                discretized_feat = OneHotArgmaxSTE.apply(X[..., start:end])
                self.assertTrue(torch.equal(discretized_feat, X_tf[..., start:end]))
            # test transform on train/eval/fantasize
            for tf_i in tf.values():
                self.assertFalse(tf_i.transform_on_train)
                self.assertTrue(tf_i.transform_on_eval)
                self.assertTrue(tf_i.transform_on_fantasize)

            # test no integer
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                categorical_features=categorical_features,
            )
            tfs = list(tf.items())
            # round should be the only transform
            self.assertEqual(len(tfs), 1)
            tf_name_i, tf_i = tfs[0]
            self.assertEqual(tf_name_i, "round")
            self.assertIsInstance(tf_i, Round)
            self.assertEqual(tf_i.integer_indices.tolist(), [])
            self.assertEqual(tf_i.categorical_features, categorical_features)
            # test no categoricals
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
            )
            tfs = list(tf.items())
            self.assertEqual(len(tfs), 3)
            _, tf_i = tfs[1]
            self.assertEqual(tf_i.integer_indices.tolist(), integer_indices)
            self.assertEqual(tf_i.categorical_features, {})
            # test initialization
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
                categorical_features=categorical_features,
                initialization=True,
            )
            tfs = list(tf.items())
            self.assertEqual(len(tfs), 3)
            # check that bounds are adjusted for integers on unnormalize
            _, tf_i = tfs[0]
            offset_init = bounds[:1, :] - 0.4999
            coefficient_init = bounds[1:2, :] + 0.4999 - offset_init
            self.assertTrue(torch.equal(tf_i.coefficient, coefficient_init))
            self.assertTrue(torch.equal(tf_i.offset, offset_init))
            # check that the original (unadjusted) bounds are used on normalize
            _, tf_i = tfs[2]
            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
            self.assertTrue(torch.equal(tf_i.offset, offset))
            # test return numeric
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
                categorical_features=categorical_features,
                return_numeric=True,
            )
            tfs = list(tf.items())
            self.assertEqual(len(tfs), 4)
            tf_name_i, tf_i = tfs[3]
            self.assertEqual(tf_name_i, "one_hot_to_numeric")
            # transform to numeric on train
            # (e.g. for kernels that expect a integer representation)
            self.assertTrue(tf_i.transform_on_train)
            self.assertTrue(tf_i.transform_on_eval)
            self.assertTrue(tf_i.transform_on_fantasize)
            self.assertEqual(tf_i.categorical_features, categorical_features)
            # 7 one-hot dims: the 5 categorical columns collapse into
            # 2 numeric columns, leaving 2 + 2 = 4 dims
            self.assertEqual(tf_i.numeric_dim, 4)
            # test return numeric and no categorical
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
                return_numeric=True,
            )
            tfs = list(tf.items())
            # there should be no one hot to numeric transform
            self.assertEqual(len(tfs), 3)