Skip to content

Commit

Permalink
Fix assembly of split forms with subdomain ids (#3993)
Browse files Browse the repository at this point in the history
* Fix assembly of split forms with subdomain ids
  • Loading branch information
pbrubeck authored Jan 24, 2025
1 parent 2394d9e commit 58f3d6f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
15 changes: 10 additions & 5 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ def parloops(self, tensor):
self._bcs,
local_kernel,
subdomain_id,
self.all_integer_subdomain_ids,
self.all_integer_subdomain_ids[local_kernel.indices],
diagonal=self.diagonal,
)
pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices)
Expand Down Expand Up @@ -1088,9 +1088,14 @@ def diagonal(self):

@cached_property
def all_integer_subdomain_ids(self):
return tsfc_interface.gather_integer_subdomain_ids(
{k for k, _ in self.local_kernels}
)
"""Return a dict mapping local_kernel.indices to all integer subdomain ids."""
all_indices = {k.indices for k, _ in self.local_kernels}
return {
i: tsfc_interface.gather_integer_subdomain_ids(
{k for k, _ in self.local_kernels if k.indices == i}
)
for i in all_indices
}

@abc.abstractmethod
def result(self, tensor):
Expand Down Expand Up @@ -1371,7 +1376,7 @@ def _make_maps_and_regions(self):
i, j = local_kernel.indices
mesh = all_meshes[local_kernel.kinfo.domain_number] # integration domain
integral_type = local_kernel.kinfo.integral_type
all_subdomain_ids = assembler.all_integer_subdomain_ids
all_subdomain_ids = assembler.all_integer_subdomain_ids[local_kernel.indices]
# Make Sparsity independent of the subdomain of integration for better reusability;
# subdomain_id is passed here only to determine the integration_type on the target domain
# (see ``entity_node_map``).
Expand Down
18 changes: 18 additions & 0 deletions tests/firedrake/regression/test_assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,21 @@ def test_assemble_power_zero_minmax():
g = Function(V).assign(2.)
assert assemble(zero()**min_value(f, g) * dx) == 0.0
assert assemble(zero()**max_value(f, g) * dx) == 0.0


def test_split_subdomain_ids():
mesh = UnitSquareMesh(1, 1)
q = Function(FunctionSpace(mesh, "DG", 0), dtype=int)
q.dat.data[1] = 1
rmesh = RelabeledMesh(mesh, (q,), (1,))

V = FunctionSpace(rmesh, "DG", 0)
Z = V * V
v0, v1 = TestFunctions(Z)

a = assemble(conj(v0)*dx + conj(v1)*dx)
b = assemble(conj(v0)*dx + conj(v1)*dx(1))

assert (a.dat[0].data == b.dat[0].data).all()
assert b.dat[1].data[0] == 0.0
assert b.dat[1].data[1] == a.dat[1].data[1]

0 comments on commit 58f3d6f

Please sign in to comment.