Skip to content

Commit

Permalink
feat: Add transpose function to tensorlib (#1696)
Browse files Browse the repository at this point in the history
* Add transpose function to tensorlib
* Add test for transpose
* Fix a backend name typo in the JAX backend docstrings
  • Loading branch information
matthewfeickert authored Nov 11, 2021
1 parent e7c0fce commit 2b2b281
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 1 deletion.
27 changes: 26 additions & 1 deletion src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def normal_dist(self, mu, sigma):

def to_numpy(self, tensor_in):
"""
Convert the TensorFlow tensor to a :class:`numpy.ndarray`.
Convert the JAX tensor to a :class:`numpy.ndarray`.
Example:
>>> import pyhf
Expand All @@ -591,3 +591,28 @@ def to_numpy(self, tensor_in):
"""
return np.asarray(tensor_in, dtype=tensor_in.dtype)

def transpose(self, tensor_in):
"""
Transpose the tensor.
Example:
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
DeviceArray([[1., 2., 3.],
[4., 5., 6.]], dtype=float64)
>>> pyhf.tensorlib.transpose(tensor)
DeviceArray([[1., 4.],
[2., 5.],
[3., 6.]], dtype=float64)
Args:
tensor_in (:obj:`tensor`): The input tensor object.
Returns:
JAX ndarray: The transpose of the input tensor.
"""
return tensor_in.transpose()
25 changes: 25 additions & 0 deletions src/pyhf/tensor/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,28 @@ def to_numpy(self, tensor_in):
"""
return tensor_in

def transpose(self, tensor_in):
"""
Transpose the tensor.
Example:
>>> import pyhf
>>> pyhf.set_backend("numpy")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
array([[1., 2., 3.],
[4., 5., 6.]])
>>> pyhf.tensorlib.transpose(tensor)
array([[1., 4.],
[2., 5.],
[3., 6.]])
Args:
tensor_in (:obj:`tensor`): The input tensor object.
Returns:
:class:`numpy.ndarray`: The transpose of the input tensor.
"""
return tensor_in.transpose()
25 changes: 25 additions & 0 deletions src/pyhf/tensor/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,28 @@ def to_numpy(self, tensor_in):
"""
return tensor_in.numpy()

def transpose(self, tensor_in):
"""
Transpose the tensor.
Example:
>>> import pyhf
>>> pyhf.set_backend("pytorch")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
tensor([[1., 2., 3.],
[4., 5., 6.]])
>>> pyhf.tensorlib.transpose(tensor)
tensor([[1., 4.],
[2., 5.],
[3., 6.]])
Args:
tensor_in (:obj:`tensor`): The input tensor object.
Returns:
PyTorch FloatTensor: The transpose of the input tensor.
"""
return tensor_in.transpose(0, 1)
28 changes: 28 additions & 0 deletions src/pyhf/tensor/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,31 @@ def to_numpy(self, tensor_in):
"""
return tensor_in.numpy()

def transpose(self, tensor_in):
"""
Transpose the tensor.
Example:
>>> import pyhf
>>> pyhf.set_backend("tensorflow")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> print(tensor)
tf.Tensor(
[[1. 2. 3.]
[4. 5. 6.]], shape=(2, 3), dtype=float64)
>>> tensor_T = pyhf.tensorlib.transpose(tensor)
>>> print(tensor_T)
tf.Tensor(
[[1. 4.]
[2. 5.]
[3. 6.]], shape=(3, 2), dtype=float64)
Args:
tensor_in (:obj:`tensor`): The input tensor object.
Returns:
TensorFlow Tensor: The transpose of the input tensor.
"""
return tf.transpose(tensor_in)
6 changes: 6 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def test_simple_tensor_ops(backend):
assert tb.tolist(tb.conditional((a < b), lambda: a + b, lambda: a - b)) == 9.0
assert tb.tolist(tb.conditional((a > b), lambda: a + b, lambda: a - b)) == -1.0

assert tb.tolist(tb.transpose(tb.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))) == [
[1.0, 4.0],
[2.0, 5.0],
[3.0, 6.0],
]


def test_tensor_where_scalar(backend):
tb = pyhf.tensorlib
Expand Down

0 comments on commit 2b2b281

Please sign in to comment.