Skip to content

Commit

Permalink
Merge pull request #1877 from NNPDF/replica-axis-first
Browse files Browse the repository at this point in the history
Move replica axis to front everywhere
  • Loading branch information
scarlehoff authored Dec 8, 2023
2 parents 5eebfba + d5e14f9 commit 8cbe0cf
Show file tree
Hide file tree
Showing 19 changed files with 160 additions and 201 deletions.
10 changes: 6 additions & 4 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,7 @@ def split_replicas(self):
if self.single_replica_generator is None:
raise ValueError("Trying to generate single replica models with no generator set.")
replicas = []
num_replicas = self.output.shape[-1]
for i_replica in range(num_replicas):
for i_replica in range(self.num_replicas):
replica = self.single_replica_generator()
replica.set_replica_weights(self.get_replica_weights(i_replica))

Expand All @@ -410,14 +409,17 @@ def split_replicas(self):

return replicas

@property
def num_replicas(self):
return self.output.shape[1]

def load_identical_replicas(self, model_file):
"""
From a single replica model, load the same weights into all replicas.
"""
weights = self._format_weights_from_file(model_file)

num_replicas = self.output.shape[-1]
for i_replica in range(num_replicas):
for i_replica in range(self.num_replicas):
self.set_replica_weights(weights, i_replica)

def _format_weights_from_file(self, model_file):
Expand Down
14 changes: 7 additions & 7 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,26 +261,26 @@ def pdf_masked_convolution(raw_pdf, basis_mask):
Parameters
----------
pdf: tf.tensor
rank 4 (batchsize, xgrid, flavours, replicas)
rank 4 (batchsize, replicas, xgrid, flavours)
basis_mask: tf.tensor
rank 2 tensor (flavours, flavours)
mask to apply to the pdf convolution
Return
------
pdf_x_pdf: tf.tensor
rank3 (len(mask_true), xgrid, xgrid, replicas)
rank3 (replicas, len(mask_true), xgrid, xgrid)
"""
if raw_pdf.shape[-1] == 1: # only one replica!
pdf = tf.squeeze(raw_pdf, axis=(0, -1))
if raw_pdf.shape[1] == 1: # only one replica!
pdf = tf.squeeze(raw_pdf, axis=(0, 1))
luminosity = tensor_product(pdf, pdf, axes=0)
lumi_tmp = K.permute_dimensions(luminosity, (3, 1, 2, 0))
pdf_x_pdf = batchit(boolean_mask(lumi_tmp, basis_mask), -1)
pdf_x_pdf = batchit(boolean_mask(lumi_tmp, basis_mask), 0)
else:
pdf = tf.squeeze(raw_pdf, axis=0) # remove the batchsize
luminosity = tf.einsum('air,bjr->jibar', pdf, pdf)
luminosity = tf.einsum('rai,rbj->rjiba', pdf, pdf)
# (xgrid, flavour, xgrid, flavour)
pdf_x_pdf = boolean_mask(luminosity, basis_mask)
pdf_x_pdf = boolean_mask(luminosity, basis_mask, axis=1)
return pdf_x_pdf


Expand Down
13 changes: 7 additions & 6 deletions n3fit/src/n3fit/hyper_optimization/penalties.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,18 @@ def saturation(pdf_model=None, n=100, min_x=1e-6, max_x=1e-4, flavors=None, **_k
if flavors is None:
flavors = [1, 2]
x = np.logspace(np.log10(min_x), np.log10(max_x), n)
x = np.expand_dims(x, axis=[0, -1])
extra_loss = 0.0

y = pdf_model.predict({"pdf_input": x})
xpdf = y[0, :, flavors]
x_input = np.expand_dims(x, axis=[0, -1])
y = pdf_model.predict({"pdf_input": x_input})
xpdf = y[0, :, :, flavors] # this is now of shape (flavors, replicas, xgrid)

delta_logx = np.diff(np.log10(x), axis=1)
delta_xpdf = np.diff(xpdf, axis=1)
x = np.expand_dims(x, axis=[0, 1])
delta_logx = np.diff(np.log10(x), axis=2)
delta_xpdf = np.diff(xpdf, axis=2)
slope = delta_xpdf / delta_logx

pen = abs(np.mean(slope, axis=1)) + np.std(slope, axis=1)
pen = abs(np.mean(slope, axis=2)) + np.std(slope, axis=2)

# sum over flavors
# Add a small offset to avoid ZeroDivisionError
Expand Down
58 changes: 30 additions & 28 deletions n3fit/src/n3fit/layers/DIS.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,31 @@
"""

import numpy as np
from .observable import Observable

from n3fit.backends import operations as op

from .observable import Observable


class DIS(Observable):
"""
The DIS class receives a list of active flavours and a fktable
and prepares a layer that performs the convolution of said fktable with
the incoming pdf.
The DIS class receives a list of active flavours and a fktable
and prepares a layer that performs the convolution of said fktable with
the incoming pdf.
The fktable is expected to be rank 3 (ndata, xgrid, flavours)
while the input pdf is rank 4 where the first dimension is the batch dimension
and the last dimension the number of replicas being fitted (1, xgrid, flavours, replicas)
The fktable is expected to be rank 3 (ndata, xgrid, flavours)
while the input pdf is rank 4 where the first dimension is the batch dimension
and the last dimension the number of replicas being fitted (1, replicas, xgrid, flavours)
"""

def gen_mask(self, basis):
"""
Receives a list of active flavours and generates a boolean mask tensor
Receives a list of active flavours and generates a boolean mask tensor
Parameters
----------
basis: list(int)
list of active flavours
Parameters
----------
basis: list(int)
list of active flavours
"""
if basis is None:
self.basis = np.ones(self.nfl, dtype=bool)
Expand All @@ -41,21 +43,21 @@ def gen_mask(self, basis):

def call(self, pdf):
"""
This function perform the fktable \otimes pdf convolution.
This function perform the fktable \otimes pdf convolution.
First pass the input PDF through a mask to remove the unactive flavors,
then a tensor_product between the PDF and each fktable is performed
finally the defined operation is applied to all the results
First pass the input PDF through a mask to remove the unactive flavors,
then a tensor_product between the PDF and each fktable is performed
finally the defined operation is applied to all the results
Parameters
----------
pdf: backend tensor
rank 4 tensor (batch_size, xgrid, flavours, replicas)
Parameters
----------
pdf: backend tensor
rank 4 tensor (batch_size, replicas, xgrid, flavours)
Returns
-------
result: backend tensor
rank 3 tensor (batchsize, replicas, ndata)
Returns
-------
result: backend tensor
rank 3 tensor (batchsize, replicas, ndata)
"""
# DIS never needs splitting
if self.splitting is not None:
Expand All @@ -65,13 +67,13 @@ def call(self, pdf):
# Separate the two possible paths this layer can take
if self.many_masks:
for mask, fktable in zip(self.all_masks, self.fktables):
pdf_masked = op.boolean_mask(pdf, mask, axis=2)
res = op.tensor_product(pdf_masked, fktable, axes=[(1, 2), (2, 1)])
pdf_masked = op.boolean_mask(pdf, mask, axis=3)
res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)])
results.append(res)
else:
pdf_masked = op.boolean_mask(pdf, self.all_masks[0], axis=2)
pdf_masked = op.boolean_mask(pdf, self.all_masks[0], axis=3)
for fktable in self.fktables:
res = op.tensor_product(pdf_masked, fktable, axes=[(1, 2), (2, 1)])
res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)])
results.append(res)

return self.operation(results)
14 changes: 8 additions & 6 deletions n3fit/src/n3fit/layers/DY.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from .observable import Observable

from n3fit.backends import operations as op

from .observable import Observable


class DY(Observable):
"""
Expand Down Expand Up @@ -30,7 +32,7 @@ def call(self, pdf_raw):
Parameters
----------
pdf_in: tensor
rank 4 tensor (batchsize, xgrid, flavours, replicas)
rank 4 tensor (batchsize, replicas, xgrid, flavours)
Returns
-------
Expand All @@ -43,20 +45,20 @@ def call(self, pdf_raw):
results = []
if self.many_masks:
if self.splitting:
splitted_pdf = op.split(pdf_raw, self.splitting, axis=1)
splitted_pdf = op.split(pdf_raw, self.splitting, axis=2)
for mask, pdf, fk in zip(self.all_masks, splitted_pdf, self.fktables):
pdf_x_pdf = op.pdf_masked_convolution(pdf, mask)
res = op.tensor_product(fk, pdf_x_pdf, axes=3)
res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)])
results.append(res)
else:
for mask, fk in zip(self.all_masks, self.fktables):
pdf_x_pdf = op.pdf_masked_convolution(pdf_raw, mask)
res = op.tensor_product(fk, pdf_x_pdf, axes=3)
res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)])
results.append(res)
else:
pdf_x_pdf = op.pdf_masked_convolution(pdf_raw, self.all_masks[0])
for fk in self.fktables:
res = op.tensor_product(fk, pdf_x_pdf, axes=3)
res = op.tensor_product(fk, pdf_x_pdf, axes=[(1, 2, 3), (1, 2, 3)])
results.append(res)

# the masked convolution removes the batch dimension
Expand Down
44 changes: 11 additions & 33 deletions n3fit/src/n3fit/layers/msr_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,14 @@
from n3fit.backends import MetaLayer
from n3fit.backends import operations as op

IDX = {
'photon': 0,
'sigma': 1,
'g': 2,
'v': 3,
'v3': 4,
'v8': 5,
'v15': 6,
'v24': 7,
'v35': 8,
}
IDX = {'photon': 0, 'sigma': 1, 'g': 2, 'v': 3, 'v3': 4, 'v8': 5, 'v15': 6, 'v24': 7, 'v35': 8}
MSR_COMPONENTS = ['g']
MSR_DENOMINATORS = {'g': 'g'}
# The VSR normalization factor of component f is given by
# VSR_CONSTANTS[f] / VSR_DENOMINATORS[f]
VSR_COMPONENTS = ['v', 'v35', 'v24', 'v3', 'v8', 'v15']
VSR_CONSTANTS = {
'v': 3.0,
'v35': 3.0,
'v24': 3.0,
'v3': 1.0,
'v8': 3.0,
'v15': 3.0,
}
VSR_DENOMINATORS = {
'v': 'v',
'v35': 'v',
'v24': 'v',
'v3': 'v3',
'v8': 'v8',
'v15': 'v15',
}
VSR_CONSTANTS = {'v': 3.0, 'v35': 3.0, 'v24': 3.0, 'v3': 1.0, 'v8': 3.0, 'v15': 3.0}
VSR_DENOMINATORS = {'v': 'v', 'v35': 'v', 'v24': 'v', 'v3': 'v3', 'v8': 'v8', 'v15': 'v15'}


class MSR_Normalization(MetaLayer):
Expand Down Expand Up @@ -93,18 +69,20 @@ def call(self, pdf_integrated, photon_integral):
Parameters
----------
pdf_integrated: (Tensor(1, 14, replicas))
pdf_integrated: (Tensor(1, replicas, 14))
the integrated PDF
photon_integral: (Tensor(1, 1, replicas))
photon_integral: (Tensor(1, replicas, 1))
the integrated photon PDF
Returns
-------
normalization_factor: Tensor(14, replicas)
normalization_factor: Tensor(replicas, 1, 14)
The normalization factors per flavour.
"""
y = pdf_integrated[0] # get rid of the batch dimension
photon_integral = photon_integral[0] # get rid of the batch dimension
# get rid of batch dimension and put replicas last
reshape = lambda x: op.transpose(x[0])
y = reshape(pdf_integrated)
photon_integral = reshape(photon_integral)
numerators = []

if self._msr_enabled:
Expand All @@ -122,4 +100,4 @@ def call(self, pdf_integrated, photon_integral):
numerators / divisors, indices=self.indices, output_shape=y.shape
)

return norm_constants
return op.batchit(op.transpose(norm_constants), batch_dimension=1)
13 changes: 4 additions & 9 deletions n3fit/src/n3fit/layers/rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Rotation(MetaLayer):
rotation_axis of input to be rotated
"""

def __init__(self, rotation_matrix, rotation_axis=2, **kwargs):
def __init__(self, rotation_matrix, rotation_axis=3, **kwargs):
self.rotation_matrix = op.numpy_to_tensor(rotation_matrix)
self.rotation_axis = rotation_axis
super().__init__(**kwargs)
Expand All @@ -47,12 +47,7 @@ class FlavourToEvolution(Rotation):
the evolution basis.
"""

def __init__(
self,
flav_info,
fitbasis,
**kwargs,
):
def __init__(self, flav_info, fitbasis, **kwargs):
rotation_matrix = pdfbases.fitbasis_to_NN31IC(flav_info, fitbasis)
super().__init__(rotation_matrix, **kwargs)

Expand Down Expand Up @@ -111,15 +106,15 @@ def __init__(self, photons, **kwargs):
super().__init__(**kwargs)

def register_photon(self, xgrid):
"""Compute the photon array of shape (1, xgrid, 1, replicas) and set the layer to be rebuilt"""
"""Compute the photon array of shape (1, replicas, xgrid, 1) and set the layer to be rebuilt"""
if self._photons_generator:
self._pdf_ph = self._photons_generator(xgrid)
self.built = False

def call(self, pdfs):
if self._pdf_ph is None:
return pdfs
return op.concatenate([self._pdf_ph, pdfs[:, :, 1:]], axis=2)
return op.concatenate([self._pdf_ph, pdfs[:, :, :, 1:]], axis=3)


class ObsRotation(MetaLayer):
Expand Down
4 changes: 2 additions & 2 deletions n3fit/src/n3fit/layers/x_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ class xIntegrator(MetaLayer):
----------
grid_weights: np.array
weights of the grid
x_axis: int (default=1)
x_axis: int (default=2)
axis of the input tensor that corresponds to the x-grid
"""

def __init__(self, grid_weights, x_axis=1, **kwargs):
def __init__(self, grid_weights, x_axis=2, **kwargs):
self.x_axis = x_axis
self.grid_weights = op.flatten(op.numpy_to_tensor(grid_weights))
super().__init__(**kwargs)
Expand Down
Loading

0 comments on commit 8cbe0cf

Please sign in to comment.