Skip to content

Commit

Permalink
Fixes for Slate
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed May 15, 2023
1 parent 644e5bd commit 11abb75
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 14 deletions.
8 changes: 6 additions & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,8 +1247,12 @@ def iter_active_coefficients(form, kinfo):
@staticmethod
def iter_constants(form, kinfo): # FIXME kinfo not currently needed
from tsfc.ufl_utils import extract_firedrake_constants
for const in extract_firedrake_constants(form):
yield const
if isinstance(form, slate.TensorBase):
for const in form.constants():
yield const
else:
for const in extract_firedrake_constants(form):
yield const

@staticmethod
def index_function_spaces(form, indices):
Expand Down
49 changes: 38 additions & 11 deletions firedrake/slate/slac/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from loopy.symbolic import SubArrayRef
import pymbolic.primitives as pym

from firedrake.constant import Constant
import firedrake.slate.slate as slate
from firedrake.slate.slac.tsfc_driver import compile_terminal_form

Expand Down Expand Up @@ -107,13 +108,16 @@ def shape(self, tensor):
else:
return tensor.shape

def extent(self, coefficient):
""" Calculation of the range of a coefficient."""
element = coefficient.ufl_element()
if element.family() == "Real":
return (coefficient.dat.cdim, )
def extent(self, argument):
""" Calculation of the range of a constant or coefficient."""
if isinstance(argument, Constant):
return (argument.dat.cdim, )
else:
return (create_element(element).space_dimension(), )
element = argument.ufl_element()
if element.family() == "Real":
return (argument.dat.cdim, )
else:
return (create_element(element).space_dimension(), )

def generate_lhs(self, tensor, temp):
""" Generation of an lhs for the loopy kernel,
Expand All @@ -123,7 +127,7 @@ def generate_lhs(self, tensor, temp):
lhs = pym.Subscript(temp, idx)
return SubArrayRef(idx, lhs)

def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, wrapper_coefficients, kinfo):
def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, tsfc_constants, wrapper_coefficients, wrapper_constants, kinfo):
""" Collect the kernel data aka the parameters fed into the subkernel,
that are coordinates, orientations, cell sizes and cofficients.
"""
Expand All @@ -147,11 +151,14 @@ def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, wrapper_coefficients
ind, = tsfc_coefficients[c]
if ind != 0:
raise ValueError(f"Active indices of non-mixed function must be (0, ), not {tsfc_coefficients[c]}")
kernel_data.extend([(c, cinfo[0])])
kernel_data.append((c, cinfo[0]))
else:
for ind, (c_, info) in enumerate(cinfo.items()):
if ind in tsfc_coefficients[c]:
kernel_data.extend([(c_, info[0])])
kernel_data.append((c_, info[0]))

# Pick the constants associated with a Tensor()/TSFC kernel
kernel_data.extend([(c, c.name) for c in wrapper_constants if c in tsfc_constants])
return kernel_data

def loopify_tsfc_kernel_data(self, kernel_data):
Expand Down Expand Up @@ -244,6 +251,10 @@ def collect_coefficients(self):
coeff_dict[coeff] = (f"w_{i}", self.extent(coeff))
return coeff_dict

def collect_constants(self):
""" All constants of self.expression as a list """
return self.expression.constants()

def initialise_terminals(self, var2terminal, coefficients):
""" Initilisation of the variables in which coefficients
and the Tensors coming from TSFC are saved.
Expand Down Expand Up @@ -349,6 +360,14 @@ def generate_wrapper_kernel_args(self, tensor2temp):
dtype=self.tsfc_parameters["scalar_type"])
args.append(kernel_args.CoefficientKernelArg(coeff_loopy_arg))

for constant in self.bag.constants:
constant_loopy_arg = loopy.GlobalArg(
constant.name,
shape=constant.dat.cdim,
dtype=self.tsfc_parameters["scalar_type"]
)
args.append(kernel_args.ConstantKernelArg(constant_loopy_arg))

if self.bag.needs_cell_facets:
# Arg for is exterior (==0)/interior (==1) facet or not
facet_loopy_arg = loopy.GlobalArg(self.cell_facets_arg_name,
Expand Down Expand Up @@ -398,7 +417,14 @@ def generate_tsfc_calls(self, terminal, loopy_tensor):
output_var = pym.Variable(loopy_tensor.name)
reads.append(output_var)
output = self.generate_lhs(slate_tensor, output_var)
kernel_data = self.collect_tsfc_kernel_data(mesh, cxt_kernel.coefficients, self.bag.coefficients, kinfo)
kernel_data = self.collect_tsfc_kernel_data(
mesh,
cxt_kernel.coefficients,
cxt_kernel.constants,
self.bag.coefficients,
self.bag.constants,
kinfo
)
reads.extend(self.loopify_tsfc_kernel_data(kernel_data))

# Generate predicates for different integral types
Expand Down Expand Up @@ -427,8 +453,9 @@ def generate_tsfc_calls(self, terminal, loopy_tensor):

class SlateWrapperBag:

def __init__(self, coeffs):
def __init__(self, coeffs, constants):
self.coefficients = coeffs
self.constants = constants
self.inames = OrderedDict()
self.needs_cell_orientations = False
self.needs_cell_sizes = False
Expand Down
5 changes: 5 additions & 0 deletions firedrake/slate/slac/tsfc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from firedrake.slate.slac.utils import RemoveRestrictions
from firedrake.tsfc_interface import compile_form as tsfc_compile

from tsfc.ufl_utils import extract_firedrake_constants

from ufl.algorithms.map_integrands import map_integrand_dags
from ufl import Form


ContextKernel = collections.namedtuple("ContextKernel",
["tensor",
"coefficients",
"constants",
"original_integral_type",
"tsfc_kernels"])
ContextKernel.__doc__ = """\
Expand All @@ -22,6 +25,7 @@
list of TSFC assembly kernels.
:param coefficients: The local coefficients of the tensor contained
in the integrands (arguments for TSFC subkernels).
:param constants: The local constants of the tensor contained in the integrands.
:param original_integral_type: The unmodified measure type
of the form integrals.
:param tsfc_kernels: A list of local tensor assembly kernels
Expand Down Expand Up @@ -65,6 +69,7 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None):
if kernels:
cxt_k = ContextKernel(tensor=tensor,
coefficients=form.coefficients(),
constants=extract_firedrake_constants(form),
original_integral_type=orig_it_type,
tsfc_kernels=kernels)
cxt_kernels.append(cxt_k)
Expand Down
10 changes: 9 additions & 1 deletion firedrake/slate/slac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ def merge_loopy(slate_loopy, output_arg, builder, var2terminal, name):
""" Merges tsfc loopy kernels and slate loopy kernel into a wrapper kernel."""
from firedrake.slate.slac.kernel_builder import SlateWrapperBag
coeffs = builder.collect_coefficients()
builder.bag = SlateWrapperBag(coeffs)
constants = builder.collect_constants()
builder.bag = SlateWrapperBag(coeffs, constants)

# In the initialisation the loopy tensors for the terminals are generated
# Those are the needed again for generating the TSFC calls
Expand Down Expand Up @@ -224,6 +225,13 @@ def merge_loopy(slate_loopy, output_arg, builder, var2terminal, name):
# Add profiling for the whole kernel
insns, slate_wrapper_event, preamble = profile_insns(name, insns, PETSc.Log.isActive())

# Add a no-op touching all kernel arguments to make sure they are not
# silently dropped
noop = lp.CInstruction(
(), "", read_variables=frozenset({a.name for a in loopy_args}),
within_inames=frozenset(), within_inames_is_final=True)
insns.append(noop)

# Inames come from initialisations + loopyfying kernel args and lhs
domains = builder.bag.index_creator.domains

Expand Down
38 changes: 38 additions & 0 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

from firedrake.formmanipulation import ExtractSubBlock

from tsfc.ufl_utils import extract_firedrake_constants


__all__ = ['AssembledVector', 'Block', 'Factorization', 'Tensor',
'Inverse', 'Transpose', 'Negative',
Expand Down Expand Up @@ -213,6 +215,10 @@ def rank(self):
def coefficients(self):
"""Returns a tuple of coefficients associated with the tensor."""

@abstractmethod
def constants(self):
"""Returns a tuple of constants associated with the tensor."""

@abstractmethod
def slate_coefficients(self):
"""Returns a tuple of Slate coefficients associated with the tensor."""
Expand Down Expand Up @@ -463,6 +469,9 @@ def coefficients(self):
"""Returns a tuple of coefficients associated with the tensor."""
return (self._function,)

def constants(self):
return ()

def slate_coefficients(self):
"""Returns a tuple of coefficients associated with the tensor."""
return self.coefficients()
Expand Down Expand Up @@ -707,6 +716,11 @@ def coefficients(self):
tensor, = self.operands
return tensor.coefficients()

def constants(self):
"""Returns a tuple of constants associated with the tensor."""
tensor, = self.operands
return tensor.constants()

def slate_coefficients(self):
"""Returns a tuple of coefficients associated with the tensor."""
return self.coefficients()
Expand Down Expand Up @@ -797,6 +811,11 @@ def coefficients(self):
tensor, = self.operands
return tensor.coefficients()

def constants(self):
"""Returns a tuple of constants associated with the tensor."""
tensor, = self.operands
return tensor.constants()

def slate_coefficients(self):
"""Returns a tuple of coefficients associated with the tensor."""
return self.coefficients()
Expand Down Expand Up @@ -896,6 +915,10 @@ def coefficients(self):
"""Returns a tuple of coefficients associated with the tensor."""
return self.form.coefficients()

def constants(self):
"""Returns a tuple of constants associated with the tensor."""
return unique(extract_firedrake_constants(self.form))

def slate_coefficients(self):
"""Returns a tuple of coefficients associated with the tensor."""
return self.coefficients()
Expand Down Expand Up @@ -944,6 +967,11 @@ def coefficients(self):
coeffs = [op.coefficients() for op in self.operands]
return tuple(OrderedDict.fromkeys(chain(*coeffs)))

def constants(self):
"""Returns a tuple of constants associated with the tensor."""
const = [op.constants() for op in self.operands]
return unique(chain(*const))

def slate_coefficients(self):
"""Returns the expected coefficients of the resulting tensor."""
coeffs = [op.slate_coefficients() for op in self.operands]
Expand Down Expand Up @@ -1360,6 +1388,16 @@ def space_equivalence(A, B):
return A.mesh() == B.mesh() and A.ufl_element() == B.ufl_element()


def unique(iterable):
""" Return tuple of unique items in iterable, items must be hashable
"""
# Use dict to preserve order and compare by hash
unique_dict = {}
for item in iterable:
unique_dict[item] = None
return tuple(unique_dict.keys())


# Establishes levels of precedence for Slate tensors
precedences = [
[AssembledVector, Block, Factorization, Tensor, DiagonalTensor, Reciprocal],
Expand Down

0 comments on commit 11abb75

Please sign in to comment.