From 159df93393118bb76740e8899cb6f4444c7a187c Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 26 Nov 2023 16:54:22 +0100 Subject: [PATCH] always_close() now returns False for incompatible tensors --- phiml/math/_ops.py | 19 +++++++++++++++---- tests/commit/math/test__ops.py | 1 + 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 990bfa8e..a773c8df 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -2624,11 +2624,15 @@ def dtype(x) -> DType: def always_close(t1: Union[Number, Tensor, bool], t2: Union[Number, Tensor, bool], rel_tolerance=1e-5, abs_tolerance=0, equal_nan=False) -> bool: """ Checks whether two tensors are guaranteed to be `close` in all values. - Unlike `close()`, this function can be used with JIT compilation. + Unlike `close()`, this function can be used with JIT compilation and with tensors of incompatible shapes. + Incompatible tensors are never close. If one of the given tensors is being traced, the tensors are only equal if they reference the same native tensor. Otherwise, an element-wise equality check is performed. + See Also: + `close()`. + Args: t1: First tensor or number to compare. t2: Second tensor or number to compare. @@ -2644,7 +2648,10 @@ def always_close(t1: Union[Number, Tensor, bool], t2: Union[Number, Tensor, bool if t1.available != t2.available: return False if t1.available and t2.available: - return close(t1, t2, rel_tolerance=rel_tolerance, abs_tolerance=abs_tolerance, equal_nan=equal_nan) + try: + return close(t1, t2, rel_tolerance=rel_tolerance, abs_tolerance=abs_tolerance, equal_nan=equal_nan) + except IncompatibleShapes: + return False elif isinstance(t1, NativeTensor) and isinstance(t2, NativeTensor): return t1._native is t2._native else: @@ -2656,10 +2663,14 @@ def close(*tensors, rel_tolerance=1e-5, abs_tolerance=0, equal_nan=False) -> boo Checks whether all tensors have equal values within the specified tolerance. Does not check that the shapes exactly match. - Tensors with different shapes are reshaped before comparing. + Unlike with `always_close()`, all shapes must be compatible and tensors with different shapes are reshaped before comparing. + + See Also: + `always_close()`. Args: - *tensors: `Tensor` or tensor-like (constant) each + *tensors: At least two `Tensor` or tensor-like objects. + The shapes of all tensors must be compatible but not all tensors must have all dimensions. rel_tolerance: Relative tolerance abs_tolerance: Absolute tolerance equal_nan: If `True`, tensors are considered close if they are NaN in the same places. diff --git a/tests/commit/math/test__ops.py b/tests/commit/math/test__ops.py index 904a2906..d6efbe80 100644 --- a/tests/commit/math/test__ops.py +++ b/tests/commit/math/test__ops.py @@ -51,6 +51,7 @@ def jit(x, y): x = math.tensor(0) y = math.tensor(0) self.assertEqual(1, jit(x, y).native(), msg=b.name) + self.assertFalse(math.always_close(vec(x=0), vec(x=0, y=1))) def test_assert_close_non_uniform(self): t = math.stack([math.zeros(spatial(x=4)), math.zeros(spatial(x=3))], channel('stack'))