diff --git a/n3fit/src/n3fit/backends/keras_backend/operations.py b/n3fit/src/n3fit/backends/keras_backend/operations.py index efff803b92..7491136080 100644 --- a/n3fit/src/n3fit/backends/keras_backend/operations.py +++ b/n3fit/src/n3fit/backends/keras_backend/operations.py @@ -196,18 +196,6 @@ def tensor_ones_like(*args, **kwargs): return K.ones_like(*args, **kwargs) -@tf.function -def many_replication(grid, replications, axis=0, **kwargs): - """ - Generates a tensor with one extra dimension: - a repetition of "grid" n times along the given axis - from keras documentation: - If x has shape (s1, s2, s3) and axis is 1, the output will have shape (s1, s2 * rep, s3) - see full `docs `_ - """ - return K.repeat_elements(grid, rep=replications, axis=axis, **kwargs) - - # Property operations # modify properties of the tensor like the shape or elements it has @tf.function diff --git a/n3fit/src/n3fit/layers/x_operations.py b/n3fit/src/n3fit/layers/x_operations.py index a71a42d93d..1d22b78cca 100644 --- a/n3fit/src/n3fit/layers/x_operations.py +++ b/n3fit/src/n3fit/layers/x_operations.py @@ -56,24 +56,23 @@ def get_config(self): class xIntegrator(MetaLayer): """ - This layer performs a sum of the input layer/tensor on the first axis + This layer performs a sum of the input layer/tensor on the axis corresponding to the x-grid + weighted by the weights of the grid. - Receives as input a rank-n (n > 1) tensor `x` (batch_dims ..., xpoints, flavours) - and returns a summation on the `xpoints` index (i.e., index -2) - weighted by the weights of the grid + The output shape is the input shape with the x-axis removed. Parameters ---------- grid_weights: np.array weights of the grid + x_axis: int (default=1) + axis of the input tensor that corresponds to the x-grid """ - def __init__(self, grid_weights, output_dim=BASIS_SIZE, **kwargs): - grid_weights_tensor = op.numpy_to_tensor(grid_weights) - # Open up the grid weights - self.grid_weights = op.many_replication(grid_weights_tensor, output_dim, axis=1) + def __init__(self, grid_weights, x_axis=1, **kwargs): + self.x_axis = x_axis + self.grid_weights = op.flatten(op.numpy_to_tensor(grid_weights)) super().__init__(**kwargs) - def call(self, x): - xx = x * self.grid_weights - return op.sum(xx, axis=-2) + def call(self, pdf): + return op.tensor_product(pdf, self.grid_weights, axes=[self.x_axis, 0]) diff --git a/n3fit/src/n3fit/tests/test_xops.py b/n3fit/src/n3fit/tests/test_xops.py index 99ed67f64e..db030ae350 100644 --- a/n3fit/src/n3fit/tests/test_xops.py +++ b/n3fit/src/n3fit/tests/test_xops.py @@ -3,7 +3,8 @@ """ import numpy as np -from n3fit.layers import xDivide +from n3fit.backends import operations as op +from n3fit.layers import xDivide, xIntegrator def test_xdivide_default(): @@ -21,7 +22,7 @@ def test_xdivide_default(): def test_xdivide_indices(): - """Check that the default xDivide works as expected""" + """Check that xDivide with custom indices works as expected""" custom_indices = [0, 1, 7] x_div = xDivide(div_list=custom_indices) test_input = np.array([1, 2, 3], dtype=np.float32).reshape((1, 3, 1)) @@ -32,3 +33,15 @@ def test_xdivide_indices(): expected_output[:, :, i] = 1 / test_input[:, :, 0] np.testing.assert_allclose(test_output, expected_output, rtol=1e-05) + + +def test_xintegrator(): + np.random.seed(42) + weights = np.random.rand(5, 1) + pdf = op.numpy_to_tensor(np.random.rand(1, 5, 8)) + xint = xIntegrator(weights) + xint_out = xint(pdf) + xint_out_reference = np.array( + [[0.405455, 0.878931, 0.937715, 0.906214, 1.984154, 1.147975, 1.642387, 1.549858]] + ) + np.testing.assert_allclose(xint_out.numpy(), xint_out_reference, rtol=1e-05)