Skip to content

Commit

Permalink
Add support for saving and loading quadrature function
Browse files Browse the repository at this point in the history
  • Loading branch information
finsberg committed Nov 5, 2024
1 parent 0621e4d commit d5bb107
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 20 deletions.
23 changes: 19 additions & 4 deletions src/ldrb/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

from mpi4py import MPI

import adios2
import adios4dolfinx
import dolfinx
import numpy as np
from packaging.version import Version

from . import utils

Expand Down Expand Up @@ -62,7 +64,7 @@ def save(
if i == 0:
adios4dolfinx.write_mesh(mesh=f.function_space.mesh, filename=filename)
adios4dolfinx.write_function_on_input_mesh(u=f, filename=filename)
attributes[name] = utils.element2array(f.ufl_element().basix_element)
attributes[name] = utils.element2array(f.ufl_element())

adios4dolfinx.write_attributes(
comm=comm,
Expand All @@ -76,16 +78,29 @@ def load(
comm: MPI.Comm,
filename: Path,
mesh: dolfinx.mesh.Mesh | None = None,
function_space: dict[str, np.ndarray] | None = None,
) -> dict[str, dolfinx.fem.Function]:
if not Path(filename).exists():
raise FileNotFoundError(f"File {filename} does not exist")

if mesh is None:
mesh = adios4dolfinx.read_mesh(comm=comm, filename=filename)

function_space = adios4dolfinx.read_attributes(
comm=comm, filename=filename, name="function_space"
)
if Version(np.__version__) >= Version("2.11") and Version(adios2.__version__) < Version(
"2.10.2"
):
# Broken on new numpy and old adios2
function_space = adios4dolfinx.read_attributes(
comm=comm, filename=filename, name="function_space"
)
else:
if not function_space:
raise ValueError(
"function_space must be provided if numpy version is lower "
"than 1.21.0 and adios2 version is lower than 2.10."
)
assert function_space is not None

# Assume same function space for all functions
functions = {}
for key, value in function_space.items():
Expand Down
48 changes: 37 additions & 11 deletions src/ldrb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import dolfinx
import numpy as np

QUADRATURE_FAMILY = 100


def default_markers() -> dict[str, list[int]]:
"""
Expand Down Expand Up @@ -60,9 +62,24 @@ def space_from_string(
return dolfinx.fem.functionspace(mesh, el)


def element2array(el: basix.finite_element.FiniteElement) -> np.ndarray:
def element2array(el: basix.ufl._BlockedElement) -> np.ndarray:
try:
el = el.basix_element
family = int(el.family)
cell_type = int(el.cell_type)
degree = int(el.degree)
discontinuous = int(el.discontinuous)

except NotImplementedError:
assert el.family_name == "quadrature"

family = QUADRATURE_FAMILY
cell_type = int(el.cell_type)
degree = int(el.degree)
discontinuous = int(el.discontinuous)

return np.array(
[int(el.family), int(el.cell_type), int(el.degree), int(el.discontinuous)],
[family, cell_type, degree, discontinuous],
dtype=np.uint8,
)

Expand All @@ -75,15 +92,24 @@ def number2Enum(num: int, enum: Iterable) -> Enum:


def array2element(arr: np.ndarray) -> basix.finite_element.FiniteElement:
family = number2Enum(arr[0], basix.ElementFamily)
cell_type = number2Enum(arr[1], basix.CellType)
degree = int(arr[2])
discontinuous = bool(arr[3])
# TODO: Shape is hardcoded to (3,) for now, but this should also be stored
return basix.ufl.element(
family=family,
cell=cell_type,
degree=degree,
discontinuous=discontinuous,
shape=(3,),
)
if arr[0] == QUADRATURE_FAMILY:
return basix.ufl.quadrature_element(
scheme="default",
cell=cell_type,
degree=degree,
value_shape=(3,),
)
else:
family = number2Enum(arr[0], basix.ElementFamily)

# TODO: Shape is hardcoded to (3,) for now, but this should also be stored
return basix.ufl.element(
family=family,
cell=cell_type,
degree=degree,
discontinuous=discontinuous,
shape=(3,),
)
20 changes: 15 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import pytest


@pytest.mark.parametrize("space1", ["P_1", "P_2", "dP_0", "dP_1"])
@pytest.mark.parametrize("space2", ["P_1", "P_2", "dP_0", "dP_1"])
@pytest.mark.parametrize("space1", ["P_1", "P_2", "dP_0", "dP_1", "Q_2"])
@pytest.mark.parametrize("space2", ["P_1", "P_2", "dP_0", "dP_1", "Q_2"])
def test_save_load(tmp_path, space1, space2):
# FIXME: Make it work for Quadrature spaces
mesh = dolfinx.mesh.create_unit_cube(comm=MPI.COMM_WORLD, nx=3, ny=3, nz=3)
U = ldrb.utils.space_from_string(space1, mesh=mesh, dim=3)
u = dolfinx.fem.Function(U, name="u")
Expand All @@ -18,10 +17,21 @@ def test_save_load(tmp_path, space1, space2):
V = ldrb.utils.space_from_string(space2, mesh=mesh, dim=3)
v = dolfinx.fem.Function(V, name="v")
v.interpolate(lambda x: -x)

functions = [u, v]
filename = tmp_path / "test_save_load.bp"
ldrb.io.save(comm=mesh.comm, filename=filename, functions=functions)
loaded_functions = ldrb.io.load(comm=mesh.comm, filename=filename, mesh=mesh)
function_space = {
"u": ldrb.utils.element2array(u.ufl_element()),
"v": ldrb.utils.element2array(v.ufl_element()),
}
ldrb.io.save(
comm=mesh.comm,
filename=filename,
functions=functions,
)
loaded_functions = ldrb.io.load(
comm=mesh.comm, filename=filename, mesh=mesh, function_space=function_space
)
assert len(loaded_functions) == 2

assert loaded_functions["u"].name == "u"
Expand Down

0 comments on commit d5bb107

Please sign in to comment.