Skip to content

Commit

Permalink
add utilities for straight-through gradient estimators for discretiza…
Browse files Browse the repository at this point in the history
…tion functions

Summary: see title

Differential Revision: D41475380

fbshipit-source-id: c9f003ec76ccd75e6db42e44f09a12e6fb34feb5
  • Loading branch information
sdaulton authored and facebook-github-bot committed Nov 22, 2022
1 parent 92f0d1d commit be38486
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 2 deletions.
3 changes: 2 additions & 1 deletion botorch/test_functions/multi_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
.. [Daulton2022]
S. Daulton, S. Cakmak, M. Balandat, M. A. Osborne, E. Zhou, and E. Bakshy.
Robust Multi-Objective Bayesian Optimization Under Input Noise. 2022.
Robust Multi-Objective Bayesian Optimization Under Input Noise.
Proceedings of the 39th International Conference on Machine Learning, 2022.
.. [Deb2005dtlz]
K. Deb, L. Thiele, M. Laumanns, E. Zitzler, A. Abraham, L. Jain, and
Expand Down
87 changes: 87 additions & 0 deletions botorch/utils/rounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,22 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Discretization (rounding) functions for acquisition optimization.
.. [Daulton2022bopr]
S. Daulton, X. Wan, D. Eriksson, M. Balandat, M. A. Osborne, E. Bakshy.
Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic
Reparameterization. Advances in Neural Information Processing Systems
35, 2022.
"""

from __future__ import annotations

import torch
from torch import Tensor
from torch.autograd import Function
from torch.nn.functional import one_hot


def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
Expand All @@ -27,3 +39,78 @@ def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
scaled_remainder = (X - offset - 0.5) / tau
rounding_component = (torch.tanh(scaled_remainder) + 1) / 2
return offset + rounding_component


class RoundSTE(Function):
r"""Round the input tensor and use a straight-through gradient estimator.
[Daulton2022bopr]_ proposes using this in acquisition optimization.
"""

@staticmethod
def forward(
ctx,
X: Tensor,
) -> Tensor:
r"""Round the input tensor element-wise.
Args:
X: The tensor to be rounded.
Returns:
A tensor where each element is rounded to the nearest integer.
"""
return X.round()

@staticmethod
def backward(ctx, grad_output: Tensor) -> Tensor:
r"""Use a straight-through estimator the gradient of the rounding function.
This uses the identity function.
Args:
grad_output: A tensor of gradients.
Returns:
The provided tensor.
"""
return grad_output


class OneHotArgmaxSTE(Function):
r"""Discretize a continuous relaxation of a one-hot encoded categorical.
This returns a one-hot encoded categorical and use a straight-through
gradient estimator via an identity function.
[Daulton2022bopr]_ proposes using this in acquisition optimization.
"""

@staticmethod
def forward(ctx, X: Tensor) -> Tensor:
r"""Discretize the input tensor.
This applies a argmax along the last dimensions of the input tensor
and one-hot encodes the result.
Args:
X: The tensor to be rounded.
Returns:
A tensor where each element is rounded to the nearest integer.
"""
return one_hot(X.argmax(dim=-1), num_classes=X.shape[-1]).to(X)

@staticmethod
def backward(ctx, grad_output: Tensor) -> Tensor:
r"""Use a straight-through estimator the gradient of the discretization function.
This uses the identity function.
Args:
grad_output: A tensor of gradients.
Returns:
The provided tensor.
"""
return grad_output
42 changes: 41 additions & 1 deletion test/utils/test_rounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@


import torch
from botorch.utils.rounding import approximate_round
from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
from botorch.utils.testing import BotorchTestCase
from torch.nn.functional import one_hot


class TestApproximateRound(BotorchTestCase):
Expand All @@ -25,3 +26,42 @@ def test_approximate_round(self):
X.requires_grad_(True)
approximate_round(X).sum().backward()
self.assertTrue((X.grad.abs() != 0).any())


class TestRoundSTE(BotorchTestCase):
def test_round_ste(self):
for dtype in (torch.float, torch.double):
# sample uniformly from the interval [-2.5,2.5]
X = torch.rand(5, 2, device=self.device, dtype=dtype) * 5 - 2.5
expected_rounded_X = X.round()
rounded_X = RoundSTE.apply(X)
# test forward
self.assertTrue(torch.equal(expected_rounded_X, rounded_X))
# test backward
X = X.requires_grad_(True)
output = RoundSTE.apply(X)
# sample some weights to checked that gradients are passed
# as intended
w = torch.rand_like(X)
(w * output).sum().backward()
self.assertTrue(torch.equal(w, X.grad))


class TestOneHotArgmaxSTE(BotorchTestCase):
def test_one_hot_argmax_ste(self):
for dtype in (torch.float, torch.double):
X = torch.rand(5, 4, device=self.device, dtype=dtype)
expected_discretized_X = one_hot(
X.argmax(dim=-1), num_classes=X.shape[-1]
).to(X)
discretized_X = OneHotArgmaxSTE.apply(X)
# test forward
self.assertTrue(torch.equal(expected_discretized_X, discretized_X))
# test backward
X = X.requires_grad_(True)
output = OneHotArgmaxSTE.apply(X)
# sample some weights to checked that gradients are passed
# as intended
w = torch.rand_like(X)
(w * output).sum().backward()
self.assertTrue(torch.equal(w, X.grad))

0 comments on commit be38486

Please sign in to comment.