From be384869efd52fbd5b0f37423a48b10eb8c8e20b Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Tue, 22 Nov 2022 10:39:47 -0800
Subject: [PATCH] add utilities for straight-through gradient estimators for discretization functions

Summary: see title

Differential Revision: D41475380

fbshipit-source-id: c9f003ec76ccd75e6db42e44f09a12e6fb34feb5
---
 botorch/test_functions/multi_objective.py |  3 +-
 botorch/utils/rounding.py                 | 87 +++++++++++++++++++++++
 test/utils/test_rounding.py               | 42 ++++++++++-
 3 files changed, 130 insertions(+), 2 deletions(-)

diff --git a/botorch/test_functions/multi_objective.py b/botorch/test_functions/multi_objective.py
index f430496c8d..a46ea5373a 100644
--- a/botorch/test_functions/multi_objective.py
+++ b/botorch/test_functions/multi_objective.py
@@ -11,7 +11,8 @@
 
 .. [Daulton2022]
     S. Daulton, S. Cakmak, M. Balandat, M. A. Osborne, E. Zhou, and E. Bakshy.
-    Robust Multi-Objective Bayesian Optimization Under Input Noise. 2022.
+    Robust Multi-Objective Bayesian Optimization Under Input Noise.
+    Proceedings of the 39th International Conference on Machine Learning, 2022.
 
 .. [Deb2005dtlz]
     K. Deb, L. Thiele, M. Laumanns, E. Zitzler, A. Abraham, L. Jain, and
diff --git a/botorch/utils/rounding.py b/botorch/utils/rounding.py
index 1da2ff8524..5efe0bc73b 100644
--- a/botorch/utils/rounding.py
+++ b/botorch/utils/rounding.py
@@ -4,10 +4,22 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+"""
+Discretization (rounding) functions for acquisition optimization.
+
+.. [Daulton2022bopr]
+    S. Daulton, X. Wan, D. Eriksson, M. Balandat, M. A. Osborne, E. Bakshy.
+    Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic
+    Reparameterization. Advances in Neural Information Processing Systems
+    35, 2022.
+"""
+
 from __future__ import annotations
 
 import torch
 from torch import Tensor
+from torch.autograd import Function
+from torch.nn.functional import one_hot
 
 
 def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
@@ -27,3 +39,78 @@ def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
     scaled_remainder = (X - offset - 0.5) / tau
     rounding_component = (torch.tanh(scaled_remainder) + 1) / 2
     return offset + rounding_component
+
+
+class RoundSTE(Function):
+    r"""Round the input tensor and use a straight-through gradient estimator.
+
+    [Daulton2022bopr]_ proposes using this in acquisition optimization.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        X: Tensor,
+    ) -> Tensor:
+        r"""Round the input tensor element-wise.
+
+        Args:
+            X: The tensor to be rounded.
+
+        Returns:
+            A tensor where each element is rounded to the nearest integer.
+        """
+        return X.round()
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor) -> Tensor:
+        r"""Use a straight-through estimator for the gradient of the rounding function.
+
+        This uses the identity function.
+
+        Args:
+            grad_output: A tensor of gradients.
+
+        Returns:
+            The provided tensor.
+        """
+        return grad_output
+
+
+class OneHotArgmaxSTE(Function):
+    r"""Discretize a continuous relaxation of a one-hot encoded categorical.
+
+    This returns a one-hot encoded categorical and uses a straight-through
+    gradient estimator via an identity function.
+
+    [Daulton2022bopr]_ proposes using this in acquisition optimization.
+    """
+
+    @staticmethod
+    def forward(ctx, X: Tensor) -> Tensor:
+        r"""Discretize the input tensor.
+
+        This applies an argmax along the last dimension of the input tensor
+        and one-hot encodes the result.
+
+        Args:
+            X: The tensor to be discretized.
+
+        Returns:
+            A tensor with a one-hot encoding of the argmax along the last dimension.
+        """
+        return one_hot(X.argmax(dim=-1), num_classes=X.shape[-1]).to(X)
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor) -> Tensor:
+        r"""Use a straight-through estimator for the gradient of the discretization.
+
+        This uses the identity function.
+
+        Args:
+            grad_output: A tensor of gradients.
+
+        Returns:
+            The provided tensor.
+        """
+        return grad_output
diff --git a/test/utils/test_rounding.py b/test/utils/test_rounding.py
index ebe17c1e26..7ca4fd2ca9 100644
--- a/test/utils/test_rounding.py
+++ b/test/utils/test_rounding.py
@@ -6,8 +6,9 @@
 
 
 import torch
-from botorch.utils.rounding import approximate_round
+from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
 from botorch.utils.testing import BotorchTestCase
+from torch.nn.functional import one_hot
 
 
 class TestApproximateRound(BotorchTestCase):
@@ -25,3 +26,42 @@ def test_approximate_round(self):
             X.requires_grad_(True)
             approximate_round(X).sum().backward()
             self.assertTrue((X.grad.abs() != 0).any())
+
+
+class TestRoundSTE(BotorchTestCase):
+    def test_round_ste(self):
+        for dtype in (torch.float, torch.double):
+            # sample uniformly from the interval [-2.5,2.5]
+            X = torch.rand(5, 2, device=self.device, dtype=dtype) * 5 - 2.5
+            expected_rounded_X = X.round()
+            rounded_X = RoundSTE.apply(X)
+            # test forward
+            self.assertTrue(torch.equal(expected_rounded_X, rounded_X))
+            # test backward
+            X = X.requires_grad_(True)
+            output = RoundSTE.apply(X)
+            # sample some weights to check that gradients are passed
+            # as intended
+            w = torch.rand_like(X)
+            (w * output).sum().backward()
+            self.assertTrue(torch.equal(w, X.grad))
+
+
+class TestOneHotArgmaxSTE(BotorchTestCase):
+    def test_one_hot_argmax_ste(self):
+        for dtype in (torch.float, torch.double):
+            X = torch.rand(5, 4, device=self.device, dtype=dtype)
+            expected_discretized_X = one_hot(
+                X.argmax(dim=-1), num_classes=X.shape[-1]
+            ).to(X)
+            discretized_X = OneHotArgmaxSTE.apply(X)
+            # test forward
+            self.assertTrue(torch.equal(expected_discretized_X, discretized_X))
+            # test backward
+            X = X.requires_grad_(True)
+            output = OneHotArgmaxSTE.apply(X)
+            # sample some weights to check that gradients are passed
+            # as intended
+            w = torch.rand_like(X)
+            (w * output).sum().backward()
+            self.assertTrue(torch.equal(w, X.grad))