Skip to content

Commit

Permalink
Rename moveaxis to swapaxes
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jul 27, 2023
1 parent becbf9d commit f830a74
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
6 changes: 4 additions & 2 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,11 @@ def op_subtract(inputs, **kwargs):
"""
return keras_subtract(inputs, **kwargs)

def moveaxis(tensor, source, destination):

def swapaxes(tensor, source, destination):
"""
Moves the axis of the tensor from source to destination
Moves the axis of the tensor from source to destination, as in numpy.swapaxes.
see full `docs <https://numpy.org/doc/stable/reference/generated/numpy.swapaxes.html>`_
"""
indices = list(range(tensor.shape.rank))
if source < 0:
Expand Down
38 changes: 21 additions & 17 deletions n3fit/src/n3fit/layers/rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ def is_identity(self):
def call(self, x_raw):
rotated = op.tensor_product(x_raw, self.rotation_matrix, [self.rotation_axis, 0])
# this puts the rotated axis back in the original place
return op.moveaxis(rotated, -1, self.rotation_axis)
return op.swapaxes(rotated, -1, self.rotation_axis)


class FlavourToEvolution(Rotation):
"""
Rotates from the flavour basis to
the evolution basis.
"""

def __init__(
self,
flav_info,
Expand All @@ -64,29 +65,32 @@ class FkRotation(Rotation):
The input to this layer is a `pdf_raw` variable which is expected to have
a shape (1, None, 9), and it is then rotated to an output (1, None, 14)
"""

def __init__(self, output_dim=14, name="evolution", **kwargs):
self.output_dim = output_dim
rotation_matrix = self._create_rotation_matrix()
super().__init__(rotation_matrix, name=name, **kwargs)

def _create_rotation_matrix(self):
"""Create the rotation matrix"""
array = np.array([
[0, 0, 0, 0, 0, 0, 0, 0, 0], # photon
[1, 0, 0, 0, 0, 0, 0, 0, 0], # sigma
[0, 1, 0, 0, 0, 0, 0, 0, 0], # g
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v
[0, 0, 0, 1, 0, 0, 0, 0, 0], # v3
[0, 0, 0, 0, 1, 0, 0, 0, 0], # v8
[0, 0, 0, 0, 0, 0, 0, 0, 1], # v15
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v24
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v35
[0, 0, 0, 0, 0, 1, 0, 0, 0], # t3
[0, 0, 0, 0, 0, 0, 1, 0, 0], # t8
[1, 0, 0, 0, 0, 0, 0,-4, 0], # t15 (c-)
[1, 0, 0, 0, 0, 0, 0, 0, 0], # t24
[1, 0, 0, 0, 0, 0, 0, 0, 0], # t35
])
array = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0], # photon
[1, 0, 0, 0, 0, 0, 0, 0, 0], # sigma
[0, 1, 0, 0, 0, 0, 0, 0, 0], # g
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v
[0, 0, 0, 1, 0, 0, 0, 0, 0], # v3
[0, 0, 0, 0, 1, 0, 0, 0, 0], # v8
[0, 0, 0, 0, 0, 0, 0, 0, 1], # v15
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v24
[0, 0, 1, 0, 0, 0, 0, 0, 0], # v35
[0, 0, 0, 0, 0, 1, 0, 0, 0], # t3
[0, 0, 0, 0, 0, 0, 1, 0, 0], # t8
[1, 0, 0, 0, 0, 0, 0, -4, 0], # t15 (c-)
[1, 0, 0, 0, 0, 0, 0, 0, 0], # t24
[1, 0, 0, 0, 0, 0, 0, 0, 0], # t35
]
)
tensor = op.numpy_to_tensor(array.T)
return tensor

Expand Down

0 comments on commit f830a74

Please sign in to comment.