From d5bb1078dfba9dc5fa218b48f095498fda362e7b Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Tue, 5 Nov 2024 12:28:59 +0100 Subject: [PATCH] Add support for saving and loading quadrature function --- src/ldrb/io.py | 23 +++++++++++++++++++---- src/ldrb/utils.py | 48 ++++++++++++++++++++++++++++++++++++----------- tests/test_io.py | 20 +++++++++++++++----- 3 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/ldrb/io.py b/src/ldrb/io.py index 2bd6b12..ccf3730 100644 --- a/src/ldrb/io.py +++ b/src/ldrb/io.py @@ -5,9 +5,11 @@ from mpi4py import MPI +import adios2 import adios4dolfinx import dolfinx import numpy as np +from packaging.version import Version from . import utils @@ -62,7 +64,7 @@ def save( if i == 0: adios4dolfinx.write_mesh(mesh=f.function_space.mesh, filename=filename) adios4dolfinx.write_function_on_input_mesh(u=f, filename=filename) - attributes[name] = utils.element2array(f.ufl_element().basix_element) + attributes[name] = utils.element2array(f.ufl_element()) adios4dolfinx.write_attributes( comm=comm, @@ -76,6 +78,7 @@ def load( comm: MPI.Comm, filename: Path, mesh: dolfinx.mesh.Mesh | None = None, + function_space: dict[str, np.ndarray] | None = None, ) -> dict[str, dolfinx.fem.Function]: if not Path(filename).exists(): raise FileNotFoundError(f"File {filename} does not exist") @@ -83,9 +86,21 @@ def load( if mesh is None: mesh = adios4dolfinx.read_mesh(comm=comm, filename=filename) - function_space = adios4dolfinx.read_attributes( - comm=comm, filename=filename, name="function_space" - ) + if Version(np.__version__) >= Version("2.11") and Version(adios2.__version__) < Version( + "2.10.2" + ): + # Broken on new numpy and old adios2 + function_space = adios4dolfinx.read_attributes( + comm=comm, filename=filename, name="function_space" + ) + else: + if not function_space: + raise ValueError( + "function_space must be provided if numpy version is lower " + "than 1.21.0 and adios2 version is lower than 2.10." + ) + assert function_space is not None + # Assume same function space for all functions functions = {} for key, value in function_space.items(): diff --git a/src/ldrb/utils.py b/src/ldrb/utils.py index 4c64b2d..4e5b8d7 100644 --- a/src/ldrb/utils.py +++ b/src/ldrb/utils.py @@ -5,6 +5,8 @@ import dolfinx import numpy as np +QUADRATURE_FAMILY = 100 + def default_markers() -> dict[str, list[int]]: """ @@ -60,9 +62,24 @@ def space_from_string( return dolfinx.fem.functionspace(mesh, el) -def element2array(el: basix.finite_element.FiniteElement) -> np.ndarray: +def element2array(el: basix.ufl._BlockedElement) -> np.ndarray: + try: + el = el.basix_element + family = int(el.family) + cell_type = int(el.cell_type) + degree = int(el.degree) + discontinuous = int(el.discontinuous) + + except NotImplementedError: + assert el.family_name == "quadrature" + + family = QUADRATURE_FAMILY + cell_type = int(el.cell_type) + degree = int(el.degree) + discontinuous = int(el.discontinuous) + return np.array( - [int(el.family), int(el.cell_type), int(el.degree), int(el.discontinuous)], + [family, cell_type, degree, discontinuous], dtype=np.uint8, ) @@ -75,15 +92,24 @@ def number2Enum(num: int, enum: Iterable) -> Enum: def array2element(arr: np.ndarray) -> basix.finite_element.FiniteElement: - family = number2Enum(arr[0], basix.ElementFamily) cell_type = number2Enum(arr[1], basix.CellType) degree = int(arr[2]) discontinuous = bool(arr[3]) - # TODO: Shape is hardcoded to (3,) for now, but this should also be stored - return basix.ufl.element( - family=family, - cell=cell_type, - degree=degree, - discontinuous=discontinuous, - shape=(3,), - ) + if arr[0] == QUADRATURE_FAMILY: + return basix.ufl.quadrature_element( + scheme="default", + cell=cell_type, + degree=degree, + value_shape=(3,), + ) + else: + family = number2Enum(arr[0], basix.ElementFamily) + + # TODO: Shape is hardcoded to (3,) for now, but this should also be stored + return basix.ufl.element( + family=family, + cell=cell_type, + degree=degree, + discontinuous=discontinuous, + shape=(3,), + ) diff --git a/tests/test_io.py b/tests/test_io.py index 68d59b5..abe70d3 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -6,10 +6,9 @@ import pytest -@pytest.mark.parametrize("space1", ["P_1", "P_2", "dP_0", "dP_1"]) -@pytest.mark.parametrize("space2", ["P_1", "P_2", "dP_0", "dP_1"]) +@pytest.mark.parametrize("space1", ["P_1", "P_2", "dP_0", "dP_1", "Q_2"]) +@pytest.mark.parametrize("space2", ["P_1", "P_2", "dP_0", "dP_1", "Q_2"]) def test_save_load(tmp_path, space1, space2): - # FIXME: Make it work for Quadrature spaces mesh = dolfinx.mesh.create_unit_cube(comm=MPI.COMM_WORLD, nx=3, ny=3, nz=3) U = ldrb.utils.space_from_string(space1, mesh=mesh, dim=3) u = dolfinx.fem.Function(U, name="u") @@ -18,10 +17,21 @@ def test_save_load(tmp_path, space1, space2): V = ldrb.utils.space_from_string(space2, mesh=mesh, dim=3) v = dolfinx.fem.Function(V, name="v") v.interpolate(lambda x: -x) + functions = [u, v] filename = tmp_path / "test_save_load.bp" - ldrb.io.save(comm=mesh.comm, filename=filename, functions=functions) - loaded_functions = ldrb.io.load(comm=mesh.comm, filename=filename, mesh=mesh) + function_space = { + "u": ldrb.utils.element2array(u.ufl_element()), + "v": ldrb.utils.element2array(v.ufl_element()), + } + ldrb.io.save( + comm=mesh.comm, + filename=filename, + functions=functions, + ) + loaded_functions = ldrb.io.load( + comm=mesh.comm, filename=filename, mesh=mesh, function_space=function_space + ) assert len(loaded_functions) == 2 assert loaded_functions["u"].name == "u"