Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add utility for constructing rounding input transforms #1531

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions botorch/models/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from botorch.models.transforms.factory import get_rounding_input_transform
from botorch.models.transforms.input import (
ChainedInputTransform,
Normalize,
Expand All @@ -20,6 +21,7 @@


__all__ = [
"get_rounding_input_transform",
"Bilog",
"ChainedInputTransform",
"ChainedOutcomeTransform",
Expand Down
125 changes: 125 additions & 0 deletions botorch/models/transforms/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections import OrderedDict
from typing import Dict, List, Optional

from botorch.models.transforms.input import (
ChainedInputTransform,
Normalize,
OneHotToNumeric,
Round,
)
from torch import Tensor


def get_rounding_input_transform(
    one_hot_bounds: Tensor,
    integer_indices: Optional[List[int]] = None,
    categorical_features: Optional[Dict[int, int]] = None,
    initialization: bool = False,
    return_numeric: bool = False,
    approximate: bool = False,
) -> ChainedInputTransform:
    """Build a chained input transform that rounds discrete parameters.

    The resulting transform takes inputs from the unit cube, unnormalizes
    the integer parameters to the raw search space, rounds the inputs, and
    normalizes the integers back to the unit cube. Categoricals are assumed
    to be one-hot encoded. Integers are currently assumed to be contiguous
    ranges (e.g. [1,2,3] and not [1,5,7]).

    TODO: support non-contiguous sets of integers by modifying
        the rounding function.

    Args:
        one_hot_bounds: The raw search space bounds where categoricals are
            encoded in one-hot representation and the integer parameters
            are not normalized.
        integer_indices: The indices of the integer parameters.
        categorical_features: A dictionary mapping indices to cardinalities
            for the categorical features.
        initialization: A boolean indicating whether this exact rounding
            function is for initialization. For initialization, the bounds
            of the integer parameters are expanded such that the end point
            of a range is selected with the same probability that an
            interior point is selected, after rounding.
        return_numeric: A boolean indicating whether to return numeric or
            one-hot encoded categoricals. Returning a numeric
            representation is helpful if the downstream code (e.g. kernel)
            expects a numeric representation of the categoricals.
        approximate: A boolean indicating whether to use an approximate
            rounding function.

    Returns:
        The rounding function as a ChainedInputTransform, cast to the dtype
        and device of `one_hot_bounds` and put into eval mode.

    Raises:
        ValueError: If there are neither integer nor categorical parameters.
    """
    num_integers = 0 if integer_indices is None else len(integer_indices)
    num_categoricals = (
        0 if categorical_features is None else len(categorical_features)
    )
    if num_integers == 0 and num_categoricals == 0:
        raise ValueError(
            "A rounding function is a no-op "
            "if there are no integer or categorical parammeters."
        )

    # Bounds used for the unnormalization step. For initialization they are
    # widened by (almost) half a unit on each side so that, after rounding,
    # the end points of an integer range are as likely as interior values.
    unnormalize_bounds = one_hot_bounds
    if initialization and num_integers > 0:
        unnormalize_bounds = one_hot_bounds.clone()
        unnormalize_bounds[0, integer_indices] -= 0.4999
        unnormalize_bounds[1, integer_indices] += 0.4999

    # Settings shared by the unnormalize / normalize steps.
    normalize_kwargs = {
        "indices": integer_indices,
        "transform_on_train": False,
        "transform_on_eval": True,
        "transform_on_fantasize": True,
    }

    # (name, transform) pairs in the order in which they are applied.
    steps = []
    if num_integers > 0:
        # unit cube -> integer space
        steps.append(
            (
                "unnormalize_tf",
                Normalize(
                    d=unnormalize_bounds.shape[1],
                    bounds=unnormalize_bounds,
                    reverse=True,
                    **normalize_kwargs,
                ),
            )
        )
    steps.append(
        (
            "round",
            Round(
                approximate=approximate,
                transform_on_train=False,
                transform_on_fantasize=True,
                integer_indices=integer_indices,
                categorical_features=categorical_features,
            ),
        )
    )
    if num_integers > 0:
        # integer space -> unit cube (always with the original bounds)
        steps.append(
            (
                "normalize_tf",
                Normalize(
                    d=one_hot_bounds.shape[1],
                    bounds=one_hot_bounds,
                    reverse=False,
                    **normalize_kwargs,
                ),
            )
        )
    if return_numeric and num_categoricals > 0:
        steps.append(
            (
                "one_hot_to_numeric",
                OneHotToNumeric(
                    # dimension of the one-hot encoded representation
                    dim=one_hot_bounds.shape[-1],
                    categorical_features=categorical_features,
                    transform_on_train=True,
                    transform_on_eval=True,
                    transform_on_fantasize=True,
                ),
            )
        )

    chained = ChainedInputTransform(**OrderedDict(steps))
    chained.to(dtype=one_hot_bounds.dtype, device=one_hot_bounds.device)
    chained.eval()
    return chained
135 changes: 134 additions & 1 deletion botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torch import nn, Tensor
from torch.distributions import Kumaraswamy
from torch.nn import Module, ModuleDict
from torch.nn.functional import one_hot


class InputTransform(ABC):
Expand Down Expand Up @@ -364,7 +365,7 @@ def __init__(
raise ValueError("Elements of `indices` have to be smaller than `d`!")
if len(indices.unique()) != len(indices):
raise ValueError("Elements of `indices` tensor must be unique!")
self.indices = indices
self.register_buffer("indices", indices)
torch.broadcast_shapes(coefficient.shape, offset.shape)

self._d = d
Expand Down Expand Up @@ -1370,3 +1371,135 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor:
else:
p = p(X) if self.indices is None else p(X[..., self.indices])
return p.transpose(-3, -2) # p is batch_shape x n_p x n x d


class OneHotToNumeric(InputTransform, Module):
    r"""Transform categorical parameters from a one-hot to a numeric representation.

    This assumes that the categoricals are the trailing dimensions.
    """

    def __init__(
        self,
        dim: int,
        categorical_features: Optional[Dict[int, int]] = None,
        transform_on_train: bool = True,
        transform_on_eval: bool = True,
        transform_on_fantasize: bool = True,
    ) -> None:
        r"""Initialize.

        Args:
            dim: The dimension of the one-hot-encoded input.
            categorical_features: A dictionary mapping the starting index of each
                categorical feature to its cardinality. This assumes that
                categoricals are one-hot encoded.
            transform_on_train: A boolean indicating whether to apply the
                transforms in train() mode. Default: True.
            transform_on_eval: A boolean indicating whether to apply the
                transform in eval() mode. Default: True.
            transform_on_fantasize: A boolean indicating whether to apply the
                transform when called from within a `fantasize` call.
                Default: True.

        Raises:
            ValueError: If the one-hot blocks are not contiguous or do not end
                exactly at `dim` (i.e. the categoricals are not the rightmost
                elements).
        """
        super().__init__()
        self.transform_on_train = transform_on_train
        self.transform_on_eval = transform_on_eval
        self.transform_on_fantasize = transform_on_fantasize
        categorical_features = categorical_features or {}
        # sort by starting index
        self.categorical_features = OrderedDict(
            sorted(categorical_features.items(), key=lambda x: x[0])
        )
        # Without categoricals the transform is the identity: there is no
        # trailing one-hot block and the numeric dimension equals `dim`.
        # Setting these here keeps the attributes defined in every case.
        self.categorical_start_idx = dim
        self.numeric_dim = dim
        if len(self.categorical_features) > 0:
            self.categorical_start_idx = min(self.categorical_features.keys())
            # check that the trailing dimensions are categoricals
            end = self.categorical_start_idx
            err_msg = (
                f"{self.__class__.__name__} requires that the categorical "
                "parameters are the rightmost elements."
            )
            for start, card in self.categorical_features.items():
                # the end of one one-hot representation should be followed
                # by the start of the next
                if end != start:
                    raise ValueError(err_msg)
                end = start + card
            if end != dim:
                # the last one-hot block must end at the last input dimension
                raise ValueError(err_msg)
            # the numeric representation dimension is the total number of
            # parameters (continuous, integer, and categorical)
            self.numeric_dim = self.categorical_start_idx + len(
                self.categorical_features
            )

    def transform(self, X: Tensor) -> Tensor:
        r"""Transform the categorical inputs into integer representation.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d'`-dim tensor of where the one-hot encoded
            categoricals are transformed to integer representation. If there
            are no categorical features, `X` is returned unchanged.
        """
        if len(self.categorical_features) > 0:
            # The leading (non-categorical) columns are copied as-is; the
            # columns from `categorical_start_idx` on are overwritten below.
            X_numeric = X[..., : self.numeric_dim].clone()
            for offset, (start, card) in enumerate(
                self.categorical_features.items()
            ):
                # the category is the position of the largest one-hot entry
                X_numeric[..., self.categorical_start_idx + offset] = X[
                    ..., start : start + card
                ].argmax(dim=-1)
            return X_numeric
        return X

    def untransform(self, X: Tensor) -> Tensor:
        r"""Transform the categoricals from integer representation to one-hot.

        Args:
            X: A `batch_shape x n x d'`-dim tensor of transformed inputs, where
                the categoricals are represented as integers.

        Returns:
            A `batch_shape x n x d`-dim tensor of inputs, where the categoricals
            have been transformed to one-hot representation. If there are no
            categorical features, `X` is returned unchanged.
        """
        if len(self.categorical_features) > 0:
            num_categoricals = len(self.categorical_features)
            one_hot_categoricals = [
                # note that self.categorical_features is sorted by the starting
                # index in one-hot representation; the numeric categoricals are
                # the last `num_categoricals` columns of X, hence the negative
                # index. `one_hot` returns an int64 tensor, so cast back to
                # X's dtype instead of relying on `torch.cat` type promotion.
                one_hot(
                    X[..., idx - num_categoricals].long(),
                    num_classes=cardinality,
                ).to(dtype=X.dtype)
                for idx, cardinality in enumerate(self.categorical_features.values())
            ]
            X = torch.cat(
                [
                    X[..., : self.categorical_start_idx],
                    *one_hot_categoricals,
                ],
                dim=-1,
            )
        return X

    def equals(self, other: InputTransform) -> bool:
        r"""Check if another input transform is equivalent.

        Args:
            other: Another input transform.

        Returns:
            A boolean indicating if the other transform is equivalent, i.e. it
            is of exactly the same type and has the same `transform_on_*`
            flags and categorical features.
        """
        return (
            type(self) is type(other)
            and (self.transform_on_train == other.transform_on_train)
            and (self.transform_on_eval == other.transform_on_eval)
            and (self.transform_on_fantasize == other.transform_on_fantasize)
            and self.categorical_features == other.categorical_features
        )
1 change: 0 additions & 1 deletion botorch/models/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from __future__ import annotations

from functools import wraps

from typing import Tuple

import torch
Expand Down
5 changes: 5 additions & 0 deletions sphinx/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ Input Transforms
.. automodule:: botorch.models.transforms.input
:members:

Transform Factory Methods
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.transforms.factory
:members:

Transform Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.transforms.utils
Expand Down
Loading