From 2b2b28181a038b25bd4634808cdfec869b30d834 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Thu, 11 Nov 2021 17:42:15 -0600 Subject: [PATCH] feat: Add transpose function to tensorlib (#1696) * Add transpose function to tensorlib * Add test for transpose * Fix a backend name typo in the JAX backend docstrings --- src/pyhf/tensor/jax_backend.py | 27 +++++++++++++++++++++++++- src/pyhf/tensor/numpy_backend.py | 25 ++++++++++++++++++++++++ src/pyhf/tensor/pytorch_backend.py | 25 ++++++++++++++++++++++++ src/pyhf/tensor/tensorflow_backend.py | 28 +++++++++++++++++++++++++++ tests/test_tensor.py | 6 ++++++ 5 files changed, 110 insertions(+), 1 deletion(-) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index 2da67574d7..f6466357b2 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -567,7 +567,7 @@ def normal_dist(self, mu, sigma): def to_numpy(self, tensor_in): """ - Convert the TensorFlow tensor to a :class:`numpy.ndarray`. + Convert the JAX tensor to a :class:`numpy.ndarray`. Example: >>> import pyhf @@ -591,3 +591,28 @@ def to_numpy(self, tensor_in): """ return np.asarray(tensor_in, dtype=tensor_in.dtype) + + def transpose(self, tensor_in): + """ + Transpose the tensor. + + Example: + >>> import pyhf + >>> pyhf.set_backend("jax") + >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + >>> tensor + DeviceArray([[1., 2., 3.], + [4., 5., 6.]], dtype=float64) + >>> pyhf.tensorlib.transpose(tensor) + DeviceArray([[1., 4.], + [2., 5.], + [3., 6.]], dtype=float64) + + Args: + tensor_in (:obj:`tensor`): The input tensor object. + + Returns: + JAX ndarray: The transpose of the input tensor. + + """ + return tensor_in.transpose() diff --git a/src/pyhf/tensor/numpy_backend.py b/src/pyhf/tensor/numpy_backend.py index c0b2b6dbf6..c27e78cdca 100644 --- a/src/pyhf/tensor/numpy_backend.py +++ b/src/pyhf/tensor/numpy_backend.py @@ -570,3 +570,28 @@ def to_numpy(self, tensor_in): """ return tensor_in + + def transpose(self, tensor_in): + """ + Transpose the tensor. + + Example: + >>> import pyhf + >>> pyhf.set_backend("numpy") + >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + >>> tensor + array([[1., 2., 3.], + [4., 5., 6.]]) + >>> pyhf.tensorlib.transpose(tensor) + array([[1., 4.], + [2., 5.], + [3., 6.]]) + + Args: + tensor_in (:obj:`tensor`): The input tensor object. + + Returns: + :class:`numpy.ndarray`: The transpose of the input tensor. + + """ + return tensor_in.transpose() diff --git a/src/pyhf/tensor/pytorch_backend.py b/src/pyhf/tensor/pytorch_backend.py index dd6207f634..d82c297edf 100644 --- a/src/pyhf/tensor/pytorch_backend.py +++ b/src/pyhf/tensor/pytorch_backend.py @@ -592,3 +592,28 @@ def to_numpy(self, tensor_in): """ return tensor_in.numpy() + + def transpose(self, tensor_in): + """ + Transpose the tensor. + + Example: + >>> import pyhf + >>> pyhf.set_backend("pytorch") + >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + >>> tensor + tensor([[1., 2., 3.], + [4., 5., 6.]]) + >>> pyhf.tensorlib.transpose(tensor) + tensor([[1., 4.], + [2., 5.], + [3., 6.]]) + + Args: + tensor_in (:obj:`tensor`): The input tensor object. + + Returns: + PyTorch FloatTensor: The transpose of the input tensor. + + """ + return tensor_in.transpose(0, 1) diff --git a/src/pyhf/tensor/tensorflow_backend.py b/src/pyhf/tensor/tensorflow_backend.py index a0a8337d04..3168ac6a35 100644 --- a/src/pyhf/tensor/tensorflow_backend.py +++ b/src/pyhf/tensor/tensorflow_backend.py @@ -686,3 +686,31 @@ def to_numpy(self, tensor_in): """ return tensor_in.numpy() + + def transpose(self, tensor_in): + """ + Transpose the tensor. + + Example: + >>> import pyhf + >>> pyhf.set_backend("tensorflow") + >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + >>> print(tensor) + tf.Tensor( + [[1. 2. 3.] + [4. 5. 6.]], shape=(2, 3), dtype=float64) + >>> tensor_T = pyhf.tensorlib.transpose(tensor) + >>> print(tensor_T) + tf.Tensor( + [[1. 4.] + [2. 5.] + [3. 6.]], shape=(3, 2), dtype=float64) + + Args: + tensor_in (:obj:`tensor`): The input tensor object. + + Returns: + TensorFlow Tensor: The transpose of the input tensor. + + """ + return tf.transpose(tensor_in) diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 59b7496c78..194344a0a6 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -79,6 +79,12 @@ def test_simple_tensor_ops(backend): assert tb.tolist(tb.conditional((a < b), lambda: a + b, lambda: a - b)) == 9.0 assert tb.tolist(tb.conditional((a > b), lambda: a + b, lambda: a - b)) == -1.0 + assert tb.tolist(tb.transpose(tb.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))) == [ + [1.0, 4.0], + [2.0, 5.0], + [3.0, 6.0], + ] + def test_tensor_where_scalar(backend): tb = pyhf.tensorlib