Skip to content

Commit

Permalink
type test option
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Sep 15, 2021
1 parent cd7a643 commit 93ebdc3
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,36 @@ def clone(data: NdarrayTensor) -> NdarrayTensor:
return copy.deepcopy(data)


def assert_allclose(actual: NdarrayOrTensor, desired: NdarrayOrTensor, device_test: bool = False, *args, **kwargs):
def assert_allclose(
actual: NdarrayOrTensor,
desired: NdarrayOrTensor,
type_test: bool = False,
device_test: bool = False,
*args,
**kwargs,
):
"""
Assert that types and all values of two data objects are close.
Args:
actual: Pytorch Tensor or numpy array for comparison.
desired: Pytorch Tensor or numpy array to compare against.
type_test: whether to test that `actual` and `desired` are both numpy arrays or torch tensors.
device_test: whether to test the device property.
args: extra arguments to pass on to `np.testing.assert_allclose`.
kwargs: extra arguments to pass on to `np.testing.assert_allclose`.
"""
if type_test:
# check both actual and desired are of the same type
np.testing.assert_equal(isinstance(actual, np.ndarray), isinstance(desired, np.ndarray))
np.testing.assert_equal(isinstance(actual, torch.Tensor), isinstance(desired, torch.Tensor))

if isinstance(desired, torch.Tensor):
np.testing.assert_equal(isinstance(actual, torch.Tensor), True)
if device_test:
np.testing.assert_equal(str(actual.device), str(desired.device)) # type: ignore
actual = actual.cpu().numpy() # type: ignore
actual = actual.cpu().numpy() if isinstance(actual, torch.Tensor) else actual
desired = desired.cpu().numpy()
np.testing.assert_allclose(actual, desired, *args, **kwargs)

Expand Down

0 comments on commit 93ebdc3

Please sign in to comment.