Skip to content

Commit

Permalink
add user facing interface
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Feb 7, 2025
1 parent 46fd17c commit e55800f
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 36 deletions.
113 changes: 113 additions & 0 deletions bgtrees/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,114 @@
"""
Public-facing functions
compute_current_j_mu:
compute current per event given arrays of momenta and polarization
generate_batch_phase_space
generate a batch momentum-conserving phase space point for a given field
"""

import numpy as np

from ._version import __version__ # noqa
from .currents import J_μ, another_j
from .finite_gpufields.finite_fields_tf import FiniteField
from .phase_space import random_phase_space_point
from .settings import settings
from .states import ε1, ε2, ε3, ε4

PVAL = settings.p


def generate_batch_points(multiplicity=4, dimension=6, batch_size=3, field_type="ff", helconf="ppmm"):
"""Generate a batch of random momentum conserving phase space points.
Requires syngular.
Returns a tuple containing momenta and polarization for each phase space point,
both arrays of shape (events, multiplicity, dimension)
"""
# TODO: pass seed through
try:
from syngular import Field
except ModuleNotFoundError as e:
raise ModuleNotFoundError("Please install `syngular` to generate phase space points.") from e

if len(helconf) != multiplicity:
raise ValueError(f"Please, make sure that multiplicity ({multiplicity}) and helconf ({helconf}) are consistent")

if field_type.lower() in ("ff", "finitefield"):
field = Field("finite field", PVAL, 1)
settings.dtype = np.int64

def convert(xinput):
# Make it into a container
# TODO: check whether there's a benefit on the container in CPU as well, otherwise keep the modP type
return FiniteField(np.array(xinput).astype(int), PVAL)

elif field_type.lower() in ("mpc", "float"):
field = Field("mpc", 0, 300)
# settings.dtype = np.float64

def convert(xinput):
return np.array(xinput)

else:
raise ValueError(f"Field type not understood: {field_type}")

lmoms = []
lpols = []

reference_vector = random_phase_space_point(2, 4, field, seed=74)[0]
for _ in range(batch_size):
momenta = np.array(random_phase_space_point(multiplicity, dimension, field))

tmp = []
for idx, hel in enumerate(helconf):
if hel in ("1", "m"):
polarization_function = ε1
elif hel in ("2", "p"):
polarization_function = ε2
elif hel == 3:
polarization_function = ε3
elif hel == 4:
polarization_function = ε4
else:
raise Exception(f"Polarization not understood {hel}.")

pol = polarization_function(momenta[idx], reference_vector, field)
tmp.append(np.block([pol]))

lmoms.append(momenta)
lpols.append(tmp)

lmoms = convert(lmoms)
lpols = convert(lpols)
return lmoms, lpols


def compute_current_j_mu(lmoms, lpols, put_propagator=True):
"""
Recursive vectorized current builder. End of recursion is polarization tensors.
The momenta and polarization arrays can be any number of phase space points in
any multiplicity or dimensionality.
The shape of the input arrays must be (events, multiplicity, dimensions)
"""
# Check whether we are working with a Finite Field TF container
if isinstance(lmoms, FiniteField) or isinstance(lpols, FiniteField):
# Safety check: they are both finite fields
if (tm := type(lmoms)) != (tp := type(lpols)):
raise TypeError(
f"If a either momenta or polarization are Finite Fields both must be. Momenta: {tm}, Polarizations: {tp}"
)
# If we have a Finite Field container and they are both finite field, we are good to go
return another_j(lmoms, lpols, put_propagator=put_propagator)

settings.use_gpu = False

if (tm := type(lmoms.flat[0])) != (tp := type(lpols.flat[0])):
raise TypeError(f"The type of momenta ({tm}) and polarizations ({tp}) differ.")

# Depending on the type of the input we can use different versions of J_mu
return J_μ(lmoms, lpols, put_propagator, einsum=np.einsum)
12 changes: 8 additions & 4 deletions bgtrees/currents.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
Low-level vectorized current builders.
By using bgtrees.compute_current_j_mu (from __init__.py) the right function
will be used automagically.
"""

import functools

import numpy
Expand All @@ -12,10 +19,7 @@

# @gpu_function
def J_μ(lmoms, lpols, put_propagator=True, depth=0, verbose=False, einsum=numpy.einsum):
"""Recursive vectorized current builder. End of recursion is polarization tensors.
TODO: try to merge this and another_j
"""
"""Recursive vectorized current builder. End of recursion is polarization tensors."""

assert lmoms.shape[:2] == lpols.shape[:2]
replicas, multiplicity, D = lmoms.shape
Expand Down
1 change: 0 additions & 1 deletion bgtrees/finite_gpufields/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def ff_dot_product_tris(x, y, rank_x=None, rank_y=None):
@tf.function(reduce_retracing=False)
def ff_dot_product_tris_single_batch(x, y, rank_x=None, rank_y=None):
"""Single batched version of ff_dot_product_tris
TODO: it should eventually go as well
"""
if rank_x is None:
rank_x = len(x.shape)
Expand Down
1 change: 0 additions & 1 deletion bgtrees/metric_and_vertices.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def V4g(D):


@gpu_function
@tf.function(reduce_retracing=False, jit_compile=True)
def V3g(lp1, lp2, einsum=numpy.einsum):
"""3-gluon vertex, upper indices μνρ, D-dimensional"""
D = lp1.shape[1]
Expand Down
7 changes: 3 additions & 4 deletions bgtrees/phase_space.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy
import sympy

from syngular import Ring, Ideal, QRing
from syngular import Ideal, QRing, Ring

from .metric_and_vertices import η

Expand All @@ -11,7 +10,7 @@ def dDimPhaseSpaceQRing(m, D):
momenta = numpy.array([numpy.array([sympy.symbols(f"p{i}_{j}") for j in range(D)]) for i in range(1, m + 1)])
on_shell_relations = [momentum @ η(D) @ momentum for momentum in momenta]
momentum_conservation = sum(momenta).tolist()
r = Ring('0', tuple(momenta.flatten().tolist()), 'dp')
r = Ring("0", tuple(momenta.flatten().tolist()), "dp")
i = Ideal(r, on_shell_relations + momentum_conservation)
q = QRing(r, i)
return q
Expand All @@ -35,7 +34,7 @@ def μ2(momD, d=None):
D = momD.shape[0]
if d is None:
d = D
return - momD[4:d] @ η(D)[4:d, 4:d] @ momD[4:d]
return -momD[4:d] @ η(D)[4:d, 4:d] @ momD[4:d]


def momflat(momD, momχ):
Expand Down
70 changes: 44 additions & 26 deletions bgtrees/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,50 @@
See eq.~11 to 18 of arXiv:250X.XXXXX.
"""


import numpy

from lips import Particle, Particles
import numpy

from .phase_space import momflat, μ2


def ε1(momD, momχ, field):
"""Corresponds to plus helicity in D=4."""
D, momFlat = len(momD), momflat(momD, momχ)
ε1 = Particle(Particles([Particle(momFlat, field=field), Particle(momχ, field=field)],
field=field, fix_mom_cons=False)("(-|2]⟨1|)/([1|2])"), field=field)
return numpy.append(ε1.four_mom, (field(0), ) * (D - 4))
ε1 = Particle(
Particles([Particle(momFlat, field=field), Particle(momχ, field=field)], field=field, fix_mom_cons=False)(
"(-|2]⟨1|)/([1|2])"
),
field=field,
)
return numpy.append(ε1.four_mom, (field(0),) * (D - 4))


def ε2(momD, momχ, field):
"""Corresponds to minus helicity in D=4"""
D, momFlat = len(momD), momflat(momD, momχ)
ε2 = Particle(Particles([Particle(momFlat, field=field), Particle(momχ, field=field)],
field=field, fix_mom_cons=False)("2(|1]⟨2|)/(⟨1|2⟩)"), field=field)
return numpy.append(ε2.four_mom, (field(0), ) * (D - 4))
ε2 = Particle(
Particles([Particle(momFlat, field=field), Particle(momχ, field=field)], field=field, fix_mom_cons=False)(
"2(|1]⟨2|)/(⟨1|2⟩)"
),
field=field,
)
return numpy.append(ε2.four_mom, (field(0),) * (D - 4))


def ε3(momD, momχ, field):
D, momFlat = len(momD), momflat(momD, momχ)
if D < 5:
raise ValueError(f"Not enough dimensions for ε3, need D>=5, was given D={D}.")
ε3 = Particle(Particles([Particle(momFlat, field=field), Particle(momχ, field=field),
Particle(momD[:4], field=field)], field=field, fix_mom_cons=False,
internal_masses={'μ2': μ2(momD)})("(|1]⟨1|)-μ2*|2]⟨2|/(⟨2|3|2])"), field=field)
return numpy.append(ε3.four_mom, (field(0), ) * (D - 4))
ε3 = Particle(
Particles(
[Particle(momFlat, field=field), Particle(momχ, field=field), Particle(momD[:4], field=field)],
field=field,
fix_mom_cons=False,
internal_masses={"μ2": μ2(momD)},
)("(|1]⟨1|)-μ2*|2]⟨2|/(⟨2|3|2])"),
field=field,
)
return numpy.append(ε3.four_mom, (field(0),) * (D - 4))


def ε3c(momD, momχ, field):
Expand All @@ -45,9 +57,9 @@ def ε4(momD):
D = len(momD)
if D < 6:
raise ValueError(f"Not enough dimensions for ε4, need D>=6, was given D={D}.")
ε4 = numpy.block([numpy.array([0, ] * 4),
numpy.array([[1, 0], [0, -1]]) @ momD[4:6][::-1],
numpy.array([0, ] * (D - 6))]) / μ2(momD, 6)
ε4 = numpy.block(
[numpy.array([0] * 4), numpy.array([[1, 0], [0, -1]]) @ momD[4:6][::-1], numpy.array([0] * (D - 6))]
) / μ2(momD, 6)
return ε4


Expand All @@ -60,10 +72,16 @@ def εxs(momD, x):
D = len(momD)
if D < 7:
raise ValueError(f"Not enough dimensions for εx, need D>=7, was given D={D}.")
return numpy.block([numpy.array([0, ] * 4),
numpy.array([momD[j] * momD[x + 1] for j in range(4, x + 1)] +
[-μ2(momD, x + 1)] + [0, ] * (D - x - 2))
]) / μ2(momD, x + 1) / μ2(momD, x + 2)
return (
numpy.block(
[
numpy.array([0] * 4),
numpy.array([momD[j] * momD[x + 1] for j in range(4, x + 1)] + [-μ2(momD, x + 1)] + [0] * (D - x - 2)),
]
)
/ μ2(momD, x + 1)
/ μ2(momD, x + 2)
)


def εxcs(momD, x):
Expand All @@ -82,22 +100,22 @@ def all_states(momD, momχ, field):
e1 = e2c = ε1(momD, momχ, field)
e2 = e1c = ε2(momD, momχ, field)

states = [e1, e2, ]
states_conj = [e1c, e2c, ]
states = [e1, e2]
states_conj = [e1c, e2c]

if D == 4:
return states, states_conj

e3, e3c = ε3(momD, momχ, field), ε3c(momD, momχ, field)
states += [e3, ]
states_conj += [e3c, ]
states += [e3]
states_conj += [e3c]

if D == 5:
return states, states_conj

e4, e4c = ε4(momD), ε4c(momD)
states += [e4, ]
states_conj += [e4c, ]
states += [e4]
states_conj += [e4c]

if D == 5:
return states, states_conj
Expand Down
56 changes: 56 additions & 0 deletions bgtrees/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import operator

import numpy
import tensorflow
Expand Down Expand Up @@ -54,3 +55,58 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


def _oinsum(eq, *arrays):
"""A ``einsum`` implementation for ``numpy`` object arrays."""
lhs, output = eq.split("->")
inputs = lhs.split(",")

sizes = {}
for term, array in zip(inputs, arrays):
for k, d in zip(term, array.shape):
sizes[k] = d

out_size = tuple(sizes[k] for k in output)
out = numpy.empty(out_size, dtype=object)

inner = [k for k in sizes if k not in output]
inner_size = [sizes[k] for k in inner]

for coo_o in numpy.ndindex(*out_size):
coord = dict(zip(output, coo_o))

def gen_inner_sum():
for coo_i in numpy.ndindex(*inner_size):
coord.update(dict(zip(inner, coo_i)))

locs = []
for term in inputs:
locs.append(tuple(coord[k] for k in term))

elements = []
for array, loc in zip(arrays, locs):
elements.append(array[loc])

yield functools.reduce(operator.mul, elements)

tmp = functools.reduce(operator.add, gen_inner_sum())
out[coo_o] = tmp

# if the output is made of finite fields, take them out
if isinstance(tmp, FiniteField) and len(out_size) == 0:
out = tmp
elif isinstance(tmp, FiniteField):
p = tmp.p

def unff(x):
if isinstance(x, FiniteField):
return x.n.numpy()
return x

vunff = numpy.vectorize(unff)

new_out = vunff(out)
out = FiniteField(new_out, p)

return out
1 change: 1 addition & 0 deletions tests/test_currents.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def test_D_dim_amplitude_vs_Caravel(verbose=False, nt=NTEST):
prev_setting = settings.use_gpu
settings.use_gpu = False

import ipdb; ipdb.set_trace()
res_cpu = numpy.einsum("rm,rm->r", lpols[:, 0], J_μ(lmoms[:, 1:], lpols[:, 1:], put_propagator=False, verbose=verbose))

settings.use_gpu = True
Expand Down
1 change: 1 addition & 0 deletions tests/test_finite_gpufields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tensorflow as tf

from bgtrees.finite_gpufields import FiniteField
from bgtrees.tools import _oinsum as oinsum
from bgtrees.finite_gpufields.operations import (
ff_dot_product,
ff_dot_product_single_batch,
Expand Down

0 comments on commit e55800f

Please sign in to comment.