Skip to content

Commit

Permalink
Add BotorchTestCase.assertAllClose (#1618)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1618

`BotorchTestCase.assertAllClose` will print more informative error messages on failure than `TestCase.assertTrue(torch.allclose(...))`. It uses `torch.testing.assert_close`.

Old test output:
```AssertionError: False is not true```

New test output:
```
1) AssertionError: Scalars are not close!

Absolute difference: 1.0000034868717194 (up to 0.0001 allowed)
Relative difference: 0.8348668001940709 (up to 1e-05 allowed)
```

This currently replicates the behavior of `torch.allclose` so that tests remain exactly as strict as they used to be, but in the future we might want to use the behavior of `assert_close` instead since it uses higher tolerances for single-precision inputs by default and is more configurable.

Differential Revision: D42402142

fbshipit-source-id: 62dc6df5e786a72a758a4d9a08ca92f88ab2c149
  • Loading branch information
esantorella authored and facebook-github-bot committed Jan 7, 2023
1 parent bb5fc4c commit 19ec6e1
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import math
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union
from unittest import TestCase

import torch
Expand Down Expand Up @@ -51,6 +51,36 @@ def setUp(self):
category=UserWarning,
)

def assertAllClose(
self,
input: torch.Tensor,
other: Union[torch.Tensor, float],
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
) -> None:
r"""
Calls torch.testing.assert_close, using the signature and default behavior
of torch.allclose.
Example output:
AssertionError: Scalars are not close!
Absolute difference: 1.0000034868717194 (up to 0.0001 allowed)
Relative difference: 0.8348668001940709 (up to 1e-05 allowed)
"""
# Why not just use the signature and behavior of `torch.testing.assert_close`?
# Because we used `torch.allclose` for testing in the past, and the two don't
# behave exactly the same. In particular, `assert_close` requires both `atol`
# and `rtol` to be set if either one is.
torch.testing.assert_close(
input,
other,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)


class BaseTestProblemBaseTestCase:

Expand Down

0 comments on commit 19ec6e1

Please sign in to comment.