From 59b3651bae6301f1a7db4044d97a080ea27b8ee6 Mon Sep 17 00:00:00 2001 From: bjp Date: Thu, 3 Oct 2024 05:08:38 -0700 Subject: [PATCH] Update tensor_shape.py for JAX/NumPy backend. PiperOrigin-RevId: 681829342 --- .../python/internal/backend/numpy/gen/tensor_shape.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py index e6747ca9bd..d37bc69e63 100755 --- a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py @@ -1435,11 +1435,12 @@ def is_compatible_with(self, other): """ other = as_shape(other) - if self.dims is not None and other.dims is not None: + if self._dims is not None and other._dims is not None: # pylint: disable=protected-access if self.rank != other.rank: return False - for x_dim, y_dim in zip(self.dims, other.dims): - if not x_dim.is_compatible_with(y_dim): + for x_dim, y_dim in zip(self._dims, other._dims): # pylint: disable=protected-access + # Inline TensorShape.dims logic for performance in tight loops. + if x_dim is not None and y_dim is not None and x_dim != y_dim: return False return True