add one hot to numeric input transform (pytorch#1517)
Summary:
Pull Request resolved: pytorch#1517

see title

Differential Revision: https://internalfb.com/D41482322

fbshipit-source-id: 824c1955d71054101c8317b2d6799aa96611799a
sdaulton authored and facebook-github-bot committed Feb 4, 2023
1 parent d026e80 commit a5cb0d2
Showing 2 changed files with 213 additions and 0 deletions.
133 changes: 133 additions & 0 deletions botorch/models/transforms/input.py
@@ -31,6 +31,7 @@
from torch import nn, Tensor
from torch.distributions import Kumaraswamy
from torch.nn import Module, ModuleDict
from torch.nn.functional import one_hot


class InputTransform(ABC):
@@ -1370,3 +1371,135 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor:
else:
p = p(X) if self.indices is None else p(X[..., self.indices])
return p.transpose(-3, -2) # p is batch_shape x n_p x n x d


class OneHotToNumeric(InputTransform, Module):
r"""Transform categorical parameters from a one-hot to a numeric representation.
This assumes that the categoricals are the trailing dimensions.
"""

def __init__(
self,
dim: int,
categorical_features: Optional[Dict[int, int]] = None,
transform_on_train: bool = False,
transform_on_eval: bool = True,
transform_on_fantasize: bool = False,
) -> None:
r"""Initialize.
Args:
dim: The dimension of the one-hot-encoded input.
categorical_features: A dictionary mapping the starting index of each
categorical feature to its cardinality. This assumes that categoricals
are one-hot encoded.
transform_on_train: A boolean indicating whether to apply the
transforms in train() mode. Default: False.
transform_on_eval: A boolean indicating whether to apply the
transform in eval() mode. Default: True.
transform_on_fantasize: A boolean indicating whether to apply the
transform when called from within a `fantasize` call. Default: False.
"""
super().__init__()
self.transform_on_train = transform_on_train
self.transform_on_eval = transform_on_eval
self.transform_on_fantasize = transform_on_fantasize
categorical_features = categorical_features or {}
# sort by starting index
self.categorical_features = OrderedDict(
sorted(categorical_features.items(), key=lambda x: x[0])
)
if len(self.categorical_features) > 0:
self.categorical_start_idx = min(self.categorical_features.keys())
# check that the trailing dimensions are categoricals
end = self.categorical_start_idx
err_msg = (
f"{self.__class__.__name__} requires that the categorical "
"parameters are the rightmost elements."
)
for start, card in self.categorical_features.items():
# the end of one one-hot representation should be followed
# by the start of the next
if end != start:
raise ValueError(err_msg)
# This assumes that the categoricals are the trailing
# dimensions
end = start + card
if end != dim:
# check end
raise ValueError(err_msg)
# the numeric representation dimension is the total number of parameters
# (continuous, integer, and categorical)
self.numeric_dim = self.categorical_start_idx + len(categorical_features)

def transform(self, X: Tensor) -> Tensor:
r"""Transform the categorical inputs into integer representation.
Args:
X: A `batch_shape x n x d`-dim tensor of inputs.
Returns:
A `batch_shape x n x d'`-dim tensor in which the one-hot encoded
categoricals are transformed to integer representation.
"""
if len(self.categorical_features) > 0:
X_numeric = X[..., : self.numeric_dim].clone()
idx = self.categorical_start_idx
for start, card in self.categorical_features.items():
X_numeric[..., idx] = X[..., start : start + card].argmax(dim=-1)
idx += 1
return X_numeric
return X

def untransform(self, X: Tensor) -> Tensor:
r"""Transform the categoricals from integer representation to one-hot.
Args:
X: A `batch_shape x n x d'`-dim tensor of transformed inputs, where
the categoricals are represented as integers.
Returns:
A `batch_shape x n x d`-dim tensor of inputs, where the categoricals
have been transformed to one-hot representation.
"""
if len(self.categorical_features) > 0:
one_hot_categoricals = [
# note that self.categorical_features is sorted by the starting index
# in one-hot representation
one_hot(
X[..., idx - len(self.categorical_features)].long(),
num_classes=cardinality,
)
for idx, cardinality in enumerate(self.categorical_features.values())
]
X = torch.cat(
[
X[..., : self.categorical_start_idx],
*one_hot_categoricals,
],
dim=-1,
)
return X

def equals(self, other: InputTransform) -> bool:
r"""Check if another input transform is equivalent.
Args:
other: Another input transform.
Returns:
A boolean indicating if the other transform is equivalent.
"""
return (
type(self) == type(other)
and (self.transform_on_train == other.transform_on_train)
and (self.transform_on_eval == other.transform_on_eval)
and (self.transform_on_fantasize == other.transform_on_fantasize)
and self.categorical_features == other.categorical_features
)
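
For reference, here is a minimal usage sketch of the new transform (not part of this diff; the dimension, the categorical_features spec, and the tensors below are illustrative). With dim=8 and categorical_features={3: 3, 6: 2}, columns 0-2 are numerical, columns 3-5 one-hot encode a three-way categorical, and columns 6-7 one-hot encode a binary categorical, so the numeric representation has 3 + 2 = 5 columns:

import torch
from torch.nn.functional import one_hot

from botorch.models.transforms.input import OneHotToNumeric

# three numerical columns, then a three-way and a binary categorical, one-hot encoded
cont = torch.rand(4, 3, dtype=torch.double)
cat1 = one_hot(torch.randint(0, 3, (4,)), num_classes=3).to(cont)
cat2 = one_hot(torch.randint(0, 2, (4,)), num_classes=2).to(cont)
X = torch.cat([cont, cat1, cat2], dim=-1)  # 4 x 8

tf = OneHotToNumeric(dim=8, categorical_features={3: 3, 6: 2})
tf.eval()  # by default the transform is only applied in eval() mode

X_numeric = tf(X)  # 4 x 5; the last two columns hold integer category labels
X_back = tf.untransform(X_numeric)  # back to the one-hot representation
assert torch.equal(X_back, X)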
80 changes: 80 additions & 0 deletions test/models/transforms/test_input.py
@@ -21,6 +21,7 @@
InputTransform,
Log10,
Normalize,
OneHotToNumeric,
Round,
Warp,
)
@@ -915,6 +916,85 @@ def test_warp_transform(self):
warp_tf._set_concentration(i=1, value=3.0)
self.assertTrue((warp_tf.concentration1 == 3.0).all())

def test_one_hot_to_numeric(self):
dim = 8
# test exception when categoricals are not the trailing dimensions
categorical_features = {0: 2}
with self.assertRaises(ValueError):
OneHotToNumeric(dim=dim, categorical_features=categorical_features)
# categoricals at start and end of X but not in between
categorical_features = {0: 3, 6: 2}
with self.assertRaises(ValueError):
OneHotToNumeric(dim=dim, categorical_features=categorical_features)
for dtype in (torch.float, torch.double):
categorical_features = {6: 2, 3: 3}
tf = OneHotToNumeric(dim=dim, categorical_features=categorical_features)
tf.eval()
self.assertEqual(tf.categorical_features, {3: 3, 6: 2})
cat1_numeric = torch.randint(0, 3, (3,), device=self.device)
cat1 = one_hot(cat1_numeric, num_classes=3)
cat2_numeric = torch.randint(0, 2, (3,), device=self.device)
cat2 = one_hot(cat2_numeric, num_classes=2)
cont = torch.rand(3, 3, dtype=dtype, device=self.device)
X = torch.cat([cont, cat1, cat2], dim=-1)
# test forward
X_numeric = tf(X)
expected = torch.cat(
[
cont,
cat1_numeric.view(-1, 1).to(cont),
cat2_numeric.view(-1, 1).to(cont),
],
dim=-1,
)
self.assertTrue(torch.equal(X_numeric, expected))

# test untransform
X2 = tf.untransform(X_numeric)
self.assertTrue(torch.equal(X2, X))

# test no categorical features
tf = OneHotToNumeric(dim=dim, categorical_features={})
tf.eval()
X_tf = tf(X)
self.assertTrue(torch.equal(X, X_tf))
X2 = tf(X_tf)
self.assertTrue(torch.equal(X2, X_tf))

# test no transform on eval
tf2 = OneHotToNumeric(
dim=dim, categorical_features=categorical_features, transform_on_eval=False
)
tf2.eval()
X_tf = tf2(X)
self.assertTrue(torch.equal(X, X_tf))

# test no transform on train
tf2 = OneHotToNumeric(
dim=dim, categorical_features=categorical_features, transform_on_train=False
)
X_tf = tf2(X)
self.assertTrue(torch.equal(X, X_tf))
tf2.eval()
X_tf = tf2(X)
self.assertFalse(torch.equal(X, X_tf))

# test equals
tf3 = OneHotToNumeric(
dim=dim, categorical_features=categorical_features, transform_on_train=False
)
self.assertTrue(tf3.equals(tf2))
# test different transform_on_train
tf3 = OneHotToNumeric(
dim=dim, categorical_features=categorical_features, transform_on_train=True
)
self.assertFalse(tf3.equals(tf2))
# test categorical features
tf3 = OneHotToNumeric(
dim=dim, categorical_features={}, transform_on_train=False
)
self.assertFalse(tf3.equals(tf2))


class TestAppendFeatures(BotorchTestCase):
def test_append_features(self):

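As a side note on the transform_on_train / transform_on_eval flags exercised in the tests above, a small sketch of the expected behavior (the tensor and settings here are illustrative):

import torch
from torch.nn.functional import one_hot

from botorch.models.transforms.input import OneHotToNumeric

X = torch.cat(
    [
        torch.rand(2, 3, dtype=torch.double),
        one_hot(torch.tensor([0, 2]), num_classes=3).double(),
        one_hot(torch.tensor([1, 0]), num_classes=2).double(),
    ],
    dim=-1,
)
tf = OneHotToNumeric(dim=8, categorical_features={3: 3, 6: 2}, transform_on_train=False)
# a fresh module is in train() mode, where this transform is configured to be skipped
assert torch.equal(tf(X), X)
tf.eval()
# in eval() mode (transform_on_eval=True by default) the categoricals are collapsed
assert tf(X).shape[-1] == 5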