diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1a0dc852..e08a70e6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,11 +15,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- (huggingface bridge) restructuring: no longer rely on binary blobs; exploit native Arrow types by flattening CGNS trees into constant and variable parts.
 - (sample) Restructuring of the Sample class to store a global (tensor of arbitrary order) at a given time step: replaces scalar and time_series. All Sample data are now stored in CGNS trees.
 
 ### Fixes
 
 - (meshes) fix `get_field_name`, could overwrite arguments during iteration over times, bases, zones and locations.
+- (docs) explain release process in Contributing page.
 
 ### Removed
 
diff --git a/examples/bridges/check_retrocomp_benchmarks.py b/examples/bridges/check_retrocomp_benchmarks.py
new file mode 100644
index 00000000..bc159529
--- /dev/null
+++ b/examples/bridges/check_retrocomp_benchmarks.py
@@ -0,0 +1,15 @@
+"""This file serves to check that the main retrieval command in the PLAID Benchmarks
+does not return an error."""
+
+from plaid.bridges import huggingface_bridge
+
+hf_dataset = huggingface_bridge.load_dataset_from_hub(
+    "PLAID-datasets/Tensile2d", split="all_samples[:5]", num_proc=1
+)
+
+plaid_dataset, pb_def = huggingface_bridge.huggingface_dataset_to_plaid(
+    hf_dataset, processes_number=1, verbose=True
+)
+
+ids_train = pb_def.get_split('train_500')
+sample_train_0 = plaid_dataset[ids_train[0]]
\ No newline at end of file
diff --git a/examples/bridges/huggingface_example.py b/examples/bridges/huggingface_example.py
index f1b22f67..fa9ad021 100644
--- a/examples/bridges/huggingface_example.py
+++ b/examples/bridges/huggingface_example.py
@@ -1,12 +1,14 @@
+# -*- coding: utf-8 -*-
 # ---
 # jupyter:
 #   jupytext:
+#     custom_cell_magics: kql
 #     formats: ipynb,py:percent
 #     text_representation:
 #       extension: .py
 #       format_name: percent
 #       format_version: '1.3'
-#     jupytext_version: 1.17.3
+#     jupytext_version: 1.11.2
 #   kernelspec:
 #     display_name: plaid-dev
 #     language: python
@@ -30,15 +32,18 @@
 # %%
 # Import necessary libraries and functions
 import pickle
+import tempfile
+import shutil
+from time import time
 
 import numpy as np
 from Muscat.Bridges.CGNSBridge import MeshToCGNS
 from Muscat.MeshTools import MeshCreationTools as MCT
 
 from plaid.bridges import huggingface_bridge
-from plaid import Dataset
-from plaid import Sample
-from plaid import ProblemDefinition
+from plaid import Dataset, Sample, ProblemDefinition
+from plaid.types import FeatureIdentifier
+from plaid.utils.base import get_mem
 
 
 # %%
@@ -51,7 +56,7 @@ def show_sample(sample: Sample):
 
 
 # %% [markdown]
-# ## Initialize plaid dataset and problem_definition
+# ## Initialize plaid dataset, infos and problem_definition
 
 # %%
 # Input data
@@ -73,21 +78,23 @@ def show_sample(sample: Sample):
     ]
 )
-
 dataset = Dataset()
+scalar_feat_id = FeatureIdentifier({"type": "scalar", "name": "scalar"})
+node_field_feat_id = FeatureIdentifier({"type": "field", "name": "node_field", "location": "Vertex"})
+cell_field_feat_id = FeatureIdentifier({"type": "field", "name": "cell_field", "location": "CellCenter"})
+
 
 print("Creating meshes dataset...")
 for _ in range(3):
     mesh = MCT.CreateMeshOfTriangles(points, triangles)
 
     sample = Sample()
-    sample.features.add_tree(MeshToCGNS(mesh))
-    sample.add_scalar("scalar", np.random.randn())
-    sample.add_field("node_field", np.random.rand(len(points)), location="Vertex")
-    sample.add_field(
-        "cell_field", np.random.rand(len(triangles)), location="CellCenter"
-    )
+    sample.add_tree(MeshToCGNS(mesh, exportOriginalIDs = False))
+
+    sample.update_features_from_identifier(scalar_feat_id, np.random.randn(), in_place=True)
+    sample.update_features_from_identifier(node_field_feat_id, np.random.rand(len(points)), in_place=True)
+    sample.update_features_from_identifier(cell_field_feat_id, np.random.rand(len(triangles)), in_place=True)
 
     dataset.add_sample(sample)
 
 
@@ -99,214 +106,271 @@ def show_sample(sample: Sample):
 dataset.set_infos(infos)
 
 print(f" {dataset = }")
+print(f" {infos = }")
 
-problem = ProblemDefinition()
-problem.add_output_scalars_names(["scalar"])
-problem.add_output_fields_names(["node_field", "cell_field"])
-problem.add_input_meshes_names(["/Base/Zone"])
+pb_def = ProblemDefinition()
+pb_def.add_in_features_identifiers([scalar_feat_id, node_field_feat_id])
+pb_def.add_out_features_identifiers([cell_field_feat_id])
 
-problem.set_task("regression")
-problem.set_split({"train": [0, 1], "test": [2]})
+pb_def.set_task("regression")
+pb_def.set_split({"train": [0, 1], "test": [2]})
 
-print(f" {problem = }")
+print(f" {pb_def = }")
 
 
 # %% [markdown]
-# ## Section 1: Convert plaid dataset to Hugging Face
-#
-# The description field of Hugging Face dataset is automatically configured to include data from the plaid dataset info and problem_definition to prevent loss of information and equivalence of format.
+# ## Section 1: Convert plaid datasets to Hugging Face DatasetDict
 
 # %%
-hf_dataset = huggingface_bridge.plaid_dataset_to_huggingface(dataset, problem)
-print()
-print(f"{hf_dataset = }")
-print(f"{hf_dataset.description = }")
+main_splits = {
+    split_name: pb_def.get_split(split_name) for split_name in ["train", "test"]
+}
 
-# %% [markdown]
-# The previous code generates a Hugging Face dataset containing all the samples from the plaid dataset, the splits being defined in the hf_dataset descriptions. For splits, Hugging Face proposes `DatasetDict`, which are dictionaries of hf datasets, with keys being the name of the corresponding splits. It is possible de generate a hf datasetdict directly from plaid:
+hf_datasetdict, flat_cst, key_mappings = huggingface_bridge.plaid_dataset_to_huggingface_datasetdict(dataset, main_splits)
 
-# %%
-hf_datasetdict = huggingface_bridge.plaid_dataset_to_huggingface_datasetdict(dataset, problem, main_splits = ['train', 'test'])
-print()
-print(f"{hf_datasetdict['train'] = }")
-print(f"{hf_datasetdict['test'] = }")
+print(f"{hf_datasetdict = }")
+print(f"{flat_cst = }")
+print(f"{key_mappings = }")
+
+# %% [markdown]
+# A partitioning of all the indices is provided in `main_splits`. The conversion outputs `flat_cst` and `key_mappings`, which are central to the Hugging Face support:
+# - **`flat_cst`**: constant features dictionary (path → value): a flattened tree containing the CGNS tree leaves that are constant throughout the plaid dataset.
+# - **`key_mappings`**: metadata dictionary containing keys such as:
+#   - `variable_features`: list of paths for non-constant features.
+#   - `constant_features`: list of paths for constant features.
+#   - `cgns_types`: mapping from paths to CGNS types.
+#
+# `flat_cst` and `cgns_types` are required for reconstructing plaid datasets and samples from the Hugging Face datasets.
 
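+# %% [markdown]
+# For instance, the returned objects can be inspected directly; a minimal sketch relying only on the
+# dictionaries created above (the exact paths depend on the dataset at hand):
+# ```python
+# print(key_mappings["variable_features"][:5])   # flattened CGNS paths that vary from sample to sample
+# print(key_mappings["constant_features"][:5])   # flattened CGNS paths shared by all samples
+# first_cst_path = key_mappings["constant_features"][0]
+# print(first_cst_path, "->", flat_cst[first_cst_path])  # constant leaf stored once for the whole dataset
+# ```
+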
 # %% [markdown]
 # ## Section 2: Generate a Hugging Face dataset with a generator
 
-# %%
-def generator():
-    for id in range(len(dataset)):
-        yield {
-            "sample": pickle.dumps(dataset[id]),
-        }
-
-
-hf_dataset_gen = huggingface_bridge.plaid_generator_to_huggingface(
-    generator, infos, problem
-)
-print()
-print(f"{hf_dataset_gen = }")
-print(f"{hf_dataset_gen.description = }")
-
 # %% [markdown]
-# The same is available with datasetdict:
+# Generators are used to handle large datasets that do not fit in memory:
 
 # %%
-hf_datasetdict_gen = huggingface_bridge.plaid_generator_to_huggingface_datasetdict(
-    generator, infos, problem, main_splits = ['train', 'test']
+generators = {}
+for split_name, ids in main_splits.items():
+    def generator_(ids=ids):
+        for id in ids:
+            yield dataset[id]
+    generators[split_name] = generator_
+
+hf_datasetdict, flat_cst, key_mappings = (
+    huggingface_bridge.plaid_generator_to_huggingface_datasetdict(
+        generators
+    )
 )
-print()
-print(f"{hf_datasetdict['train'] = }")
-print(f"{hf_datasetdict['test'] = }")
+print(f"{hf_datasetdict = }")
+print(f"{flat_cst = }")
+print(f"{key_mappings = }")
+
+# %% [markdown]
+# In this example, the generators are not very useful since the plaid dataset is already loaded in memory. In real settings, one can create generators in the following way to avoid loading all the data beforehand:
+# ```python
+# generators = {}
+# for split_name, ids in main_splits.items():
+#     def generator_(ids=ids):
+#         for id in ids:
+#             loaded_simulation_data = load('path/to/split_name/simulation_id')
+#             sample = convert_to_sample(loaded_simulation_data)
+#             yield sample
+#     generators[split_name] = generator_
+# ```
 
 # %% [markdown]
 # ## Section 3: Convert a Hugging Face dataset to plaid
-#
-# Plaid dataset infos and problem_defitinion are recovered from the huggingface dataset
 
 # %%
-dataset_2, problem_2 = huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset)
+cgns_types = key_mappings["cgns_types"]
+
+dataset_2 = huggingface_bridge.to_plaid_dataset(hf_datasetdict['train'], flat_cst, cgns_types)
 print()
 print(f"{dataset_2 = }")
-print(f"{dataset_2.get_infos() = }")
-print(f"{problem_2 = }")
 
 # %% [markdown]
 # ## Section 4: Save and Load Hugging Face datasets
 #
 # ### From and to disk
+#
+# Saving and loading the datasetdict, infos, tree_struct and problem definition to disk:
 
 # %%
-# Save to disk
-hf_dataset.save_to_disk("/tmp/path/to/dir")
+with tempfile.TemporaryDirectory() as out_dir:
 
-# %%
-# Load from disk
-from datasets import load_from_disk
+    huggingface_bridge.save_dataset_dict_to_disk(out_dir, hf_datasetdict)
+    huggingface_bridge.save_infos_to_disk(out_dir, infos)
+    huggingface_bridge.save_tree_struct_to_disk(out_dir, flat_cst, key_mappings)
+    huggingface_bridge.save_problem_definition_to_disk(out_dir, "task_1", pb_def)
 
-loaded_hf_dataset = load_from_disk("/tmp/path/to/dir")
+    loaded_hf_datasetdict = huggingface_bridge.load_dataset_from_disk(out_dir)
+    loaded_infos = huggingface_bridge.load_infos_from_disk(out_dir)
+    flat_cst, key_mappings = huggingface_bridge.load_tree_struct_from_disk(out_dir)
+    loaded_pb_def = huggingface_bridge.load_problem_definition_from_disk(out_dir, "task_1")
 
-print()
-print(f"{loaded_hf_dataset = }")
-print(f"{loaded_hf_dataset.description = }")
+    shutil.rmtree(out_dir)
+
+print(f"{loaded_hf_datasetdict = }")
+print(f"{loaded_infos = }")
+print(f"{flat_cst = }")
+print(f"{key_mappings = }")
+print(f"{loaded_pb_def = }")
 
 # %% [markdown]
 # ### From and to the Hugging Face hub
 #
-# You need an huggingface account, with a configured access token, and to install huggingface_hub[cli].
-# Pushing and loading a huggingface dataset without loss of information requires the configuration of a DatasetCard.
-#
-# Find below example of instruction (not executed by this notebook).
-#
-# ### Push to the hub
+# Find below examples of instructions (not executed by this notebook).
+
+# %% [markdown]
+# #### Load from hub
 #
-# First login the huggingface cli:
-# ```bash
-# huggingface-cli login
+# To load datasetdict, infos and problem_definitions from the hub:
+# ```python
+# huggingface_bridge.load_dataset_from_hub("chanel/dataset", *args, **kwargs)
+# huggingface_bridge.load_infos_from_hub("chanel/dataset")
+# huggingface_bridge.load_problem_definition_from_hub("chanel/dataset", "name")
 # ```
-# and enter you access token.
 #
-# Then, the following python instruction enable pushing a dataset to the hub:
+# Partial retrievals over samples are possible:
 # ```python
-# hf_dataset.push_to_hub("chanel/dataset")
-#
-# from datasets import load_dataset_builder
-#
-# datasetInfo = load_dataset_builder("chanel/dataset").__getstate__()['info']
-#
-# from huggingface_hub import DatasetCard
-#
-# card_text = create_string_for_huggingface_dataset_card(
-#     description = description,
-#     download_size_bytes = datasetInfo.download_size,
-#     dataset_size_bytes = datasetInfo.dataset_size,
-#     ...)
-# dataset_card = DatasetCard(card_text)
-# dataset_card.push_to_hub("chanel/dataset")
+# huggingface_bridge.load_dataset_from_hub("chanel/dataset", split="train[:10]", *args, **kwargs)
 # ```
 #
-# The second upload of the dataset_card is required to ensure that load_dataset from the hub will populate
-# the hf-dataset.description field, and be compatible for conversion to plaid. Wihtout a dataset_card, the description field is lost.
-#
-#
-# ### Load from hub
-#
-# #### General case
-#
+# Streaming allows handling very large datasets:
 # ```python
-# dataset = load_dataset("chanel/dataset", split="all_samples")
+# hf_dataset_streamed = huggingface_bridge.load_dataset_from_hub("chanel/dataset", split="split", streaming=True, *args, **kwargs)
+# for hf_sample in hf_dataset_streamed:
+#     sample = huggingface_bridge.to_plaid_sample(hf_sample, flat_cst, cgns_types)
 # ```
 #
-# More efficient retrieval are made possible by partial loads and split loads (in the case of a datasetdict):
+# Native Hugging Face `datasets` commands are also possible:
 #
 # ```python
 # dataset_train = load_dataset("chanel/dataset", split="train")
+# dataset_train = load_dataset("chanel/dataset", split="train", streaming=True)
 # dataset_train_extract = load_dataset("chanel/dataset", split="train[:10]")
 # ```
 #
-# #### Proxy
+# If you are behind a proxy and rely on a private mirror, `load_dataset_from_hub` works provided the following environment variables are set:
+# - `HF_ENDPOINT` to your private mirror address
+# - `CURL_CA_BUNDLE` to your trusted CA certificates
+# - `HF_HOME` to a shared cache directory if needed
+
+# %% [markdown]
+# #### Push to the hub
 #
-# A retrieval function robust to cases where you are behind a proxy and relying on a private mirror is avalable;
+# To push a dataset to the Hub, you need a Hugging Face account with a configured access token.
 #
+# First, log in with the Hugging Face CLI:
+# ```bash
+# huggingface-cli login
+# ```
+# and enter your access token.
+#
+# Then, the following Python instructions enable pushing a datasetdict, infos and problem_definitions to the hub:
 # ```python
+# huggingface_bridge.push_dataset_dict_to_hub("chanel/dataset", hf_dataset_dict)
+# huggingface_bridge.push_infos_to_hub("chanel/dataset", infos)
+# huggingface_bridge.push_tree_struct_to_hub("chanel/dataset", flat_cst, key_mappings)
+# huggingface_bridge.push_problem_definition_to_hub("chanel/dataset", "name", pb_def)
 # ```
 #
-# - Streaming mode is not supported when using a private mirror.
-# - Falls back to local download if streaming or public loading fails.
-# - To use behind a proxy, you may need to set:
-#   - `HF_ENDPOINT` to your private mirror address
-#   - `CURL_CA_BUNDLE` to your trusted CA certificates
-#   - `HF_HOME` to a shared cache directory if needed
-
-
+# The dataset card can then be customized online, on the dataset repo page directly.
 
 # %% [markdown]
 # ## Section 5: Handle plaid samples from Hugging Face datasets without converting the complete dataset to plaid
 #
-# To fully exploit optimzed data handling of the Hugging Face datasets library, it is possible to extract information from the huggingface dataset without converting to plaid. The ``description`` atttribute includes the plaid dataset _infos attribute and plaid problem_definition attributes.
+# To fully exploit the optimized data handling of the Hugging Face datasets library, it is possible to extract information from the Hugging Face dataset without converting it to plaid.
 
-# %%
-print(f"{loaded_hf_dataset.description = }")
 
 # %% [markdown]
 # Get the first sample of the first split
 
 # %%
-split_names = list(loaded_hf_dataset.description["split"].keys())
-id = loaded_hf_dataset.description["split"][split_names[0]]
-hf_sample = loaded_hf_dataset[id[0]]
+hf_sample = hf_datasetdict['train'][0]
 
 print(f"{hf_sample = }")
 
 # %% [markdown]
-# We notice that ``hf_sample`` is a binary object efficiently handled by huggingface datasets. It can be converted into a plaid sample using a specific constructor relying on a pydantic validator.
+# We notice that ``hf_sample`` is not a plaid sample, but a dict containing the variable features of the dataset, with keys being the flattened paths of the CGNS tree. It can be converted into a plaid sample using a specific constructor relying on a pydantic validator, together with the required `flat_cst` and `cgns_types`.
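+#
+# Before converting it, the flattened representation can be inspected directly; a minimal sketch using only
+# standard dictionary operations (the paths shown depend on the dataset at hand):
+# ```python
+# variable_paths = list(hf_sample.keys())   # flattened CGNS paths of the variable features
+# print(variable_paths[:5])
+# print(list(flat_cst.keys())[:5])          # flattened CGNS paths of the constant features
+# print(cgns_types[variable_paths[0]])      # CGNS type associated with a given path
+# ```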
 
 # %%
-plaid_sample = huggingface_bridge.to_plaid_sample(hf_sample)
+plaid_sample = huggingface_bridge.to_plaid_sample(hf_sample, flat_cst, cgns_types)
 show_sample(plaid_sample)
+
 
 # %% [markdown]
-# Very large datasets can be streamed directly from the Hugging Face hub:
+# Very large datasets that do not fit on disk can be streamed directly from the Hugging Face hub:
 #
 # ```python
-# hf_dataset_stream = load_dataset("chanel/dataset", split="all_samples", streaming=True)
-#
-# plaid_sample = huggingface_bridge.to_plaid_sample(next(iter(hf_dataset_stream)))
-#
-# show_sample(plaid_sample)
+# hf_dataset_stream = load_dataset("chanel/dataset", split="train", streaming=True)
+# plaid_sample = huggingface_bridge.to_plaid_sample(next(iter(hf_dataset_stream)), flat_cst, cgns_types)
 # ```
 #
-# Or initialize a plaid dataset and problem definition for any number of samples relying on this streaming mechanisme:
+# If you are behind a proxy:
+# ```python
+# hf_dataset_stream = huggingface_bridge.load_dataset_from_hub("chanel/dataset", split="train", streaming=True)
+# plaid_sample = huggingface_bridge.to_plaid_sample(next(iter(hf_dataset_stream)), flat_cst, cgns_types)
+# ```
+
+# %% [markdown]
+# ## Section 6: Advanced concepts
+
+# %% [markdown]
+# In this section, we investigate concepts to better exploit the datasets made available on Hugging Face, by looking into read speed and memory usage. The commands are not executed by this notebook. You can copy/paste the following code to execute it, but be mindful that it will download a 235 MB dataset.
 #
 # ```python
-# from plaid.bridges.huggingface_bridge import streamed_huggingface_dataset_to_plaid
+# repo_id = "fabiencasenave/Tensile2d_DO_NOT_DELETE"
+# split_names = ["train_500", "test", "OOD"]
 #
-# dataset, pb_def = streamed_huggingface_dataset_to_plaid('PLAID-datasets/VKI-LS59', 2)
+# hf_dataset_dict = huggingface_bridge.load_dataset_from_hub(repo_id)
 # ```
+
+# %% [markdown]
+# We investigate the time and memory needed to instantiate the plaid dataset dict from the repo_id, now that the HF datasets have been loaded into the cache:
+# ```python
+# init_ram = get_mem()
+# start = time()
+# dataset_dict = huggingface_bridge.instantiate_plaid_datasetdict_from_hub(repo_id)
+# elapsed = time() - start
+# print(f"Time to instantiate plaid dataset dict from cache: {elapsed:.6g} s, RAM usage increase: {get_mem()-init_ram} MB")
+# ```
+# ```bash
+# >> Time to instantiate plaid dataset dict from cache: 1.37948 s, RAM usage increase: 22.5 MB
+# ```
+# We notice that the RAM usage is lower than the size of the dataset: all the variable-shape 1D arrays and constant-shape 2D arrays in the samples are instantiated in no-copy mode.
+
+# %% [markdown]
+# We now investigate the possible gains when handling the datasets directly. First, bypassing cache checks and constructing the plaid dataset from an already instantiated HF dataset is much faster:
+# ```python
+# flat_cst, key_mappings = huggingface_bridge.load_tree_struct_from_hub(repo_id)
+# pb_def = huggingface_bridge.load_problem_definition_from_hub(repo_id, "task_1")
+# infos = huggingface_bridge.load_infos_from_hub(repo_id)
+# cgns_types = key_mappings["cgns_types"]
+#
+# hf_dataset = hf_dataset_dict[split_names[0]]
+#
+# init_ram = get_mem()
+# start = time()
+# dataset = huggingface_bridge.to_plaid_dataset(hf_dataset, flat_cst, cgns_types)
+# elapsed = time() - start
+# print(f"Time to build dataset on split {split_names[0]}: {elapsed:.6g} s, RAM usage increase: {get_mem()-init_ram} MB")
+# ```
+# ```bash
+# >> Time to build dataset on split train_500: 0.173115 s, RAM usage increase: 16.3125 MB
+# ```
+
+# %% [markdown]
+# It is possible to further remove overhead by directly accessing the 1D arrays of the Arrow table of the HF dataset in no-copy mode:
+# ```python
+# init_ram = get_mem()
+# start = time()
+# data = {}
+# for i in range(len(hf_dataset)):
+#     data[i] = hf_dataset.data["Base_2_2/Zone/PointData/sig12"][i].values.to_numpy(zero_copy_only=True)
+# elapsed = time() - start
+# print(f"Time to read 1D fields of variable size on the complete split {split_names[0]}: {elapsed:.6g} s, RAM usage increase: {get_mem()-init_ram} MB")
+# ```
+# ```bash
+# >> Time to read 1D fields of variable size on the complete split train_500: 0.0021801 s, RAM usage increase: 0.375 MB
+# ```
diff --git a/examples/containers/sample_example.py b/examples/containers/sample_example.py
index be629975..bd731bbe 100644
--- a/examples/containers/sample_example.py
+++ b/examples/containers/sample_example.py
@@ -56,7 +56,7 @@ def show_sample(sample: Sample):
 # %% [markdown]
 # ## Section 1: Initializing an Empty Sample and Adding Data
 #
-# This section demonstrates how to initialize an empty Sample and add scalars, time series data, and meshes / CGNS trees.
+# This section demonstrates how to initialize an empty Sample and add scalars and meshes / CGNS trees.
 
 # %% [markdown]
 # ### Create and display CGNS tree from an unstructured mesh
diff --git a/examples/pipelines/pipeline_example.py b/examples/pipelines/pipeline_example.py
index a1a53355..b6466e65 100644
--- a/examples/pipelines/pipeline_example.py
+++ b/examples/pipelines/pipeline_example.py
@@ -59,7 +59,7 @@
 from sklearn.model_selection import KFold, GridSearchCV
 
-from plaid.bridges.huggingface_bridge import huggingface_dataset_to_plaid, load_hf_dataset_from_hub
+from plaid.bridges.huggingface_bridge import huggingface_dataset_to_plaid, load_dataset_from_hub
 
 from plaid.pipelines.sklearn_block_wrappers import WrappedSklearnTransformer, WrappedSklearnRegressor
 from plaid.pipelines.plaid_blocks import TransformedTargetRegressor, ColumnTransformer
@@ -73,8 +73,8 @@
 # We load the `VKI-LS59` dataset from Hugging Face and restrict ourselves to the first 24 samples of the training set.
 
# %% -hf_dataset = load_hf_dataset_from_hub("PLAID-datasets/VKI-LS59", split="all_samples[:24]") -dataset_train, _ = huggingface_dataset_to_plaid(hf_dataset, processes_number = n_processes, verbose = False) +hf_dataset = load_dataset_from_hub("PLAID-datasets/VKI-LS59", split="all_samples[:24]") +dataset_train, pb_def = huggingface_dataset_to_plaid(hf_dataset, processes_number = n_processes, verbose = False) # %% [markdown] diff --git a/examples/run_examples.bat b/examples/run_examples.bat index 91147c3e..47ec177a 100644 --- a/examples/run_examples.bat +++ b/examples/run_examples.bat @@ -1,5 +1,5 @@ @echo off -for %%f in (*.py utils\*.py containers\*.py post\*.py) do ( +for %%f in (*.py examples\*.py utils\*.py containers\*.py post\*.py) do ( echo -------------------------------------------------------------------------------------- echo #---# run python %%f python %%f || exit /b 1 diff --git a/examples/run_examples.sh b/examples/run_examples.sh index 379fd62e..8853bbd1 100755 --- a/examples/run_examples.sh +++ b/examples/run_examples.sh @@ -1,9 +1,9 @@ #!/bin/bash if [[ "$(uname)" == "Linux" ]]; then - FILES="*.py utils/*.py containers/*.py post/*.py pipelines/*.py" + FILES="*.py examples/*.py bridges/*.py utils/*.py containers/*.py post/*.py pipelines/*.py" else - FILES="*.py utils/*.py containers/*.py post/*.py" + FILES="*.py examples/*.py utils/*.py containers/*.py post/*.py" fi for file in $FILES diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index 730c7c6e..7d06d356 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -6,13 +6,20 @@ # file 'LICENSE.txt', which is part of this source code package. # # +import io +import json +import os import pickle import shutil import sys +from functools import partial from multiprocessing import Pool from pathlib import Path -from typing import Any, Callable, Optional +from typing import Callable, Optional +import numpy as np +import pyarrow as pa +import yaml from tqdm import tqdm if sys.version_info >= (3, 11): @@ -23,38 +30,618 @@ Self = TypeVar("Self") import logging -import os -from typing import Union +from typing import Any, Union import datasets -from datasets import load_dataset -from huggingface_hub import snapshot_download +from datasets import Features, Sequence, Value, load_dataset, load_from_disk +from huggingface_hub import HfApi, hf_hub_download, snapshot_download from pydantic import ValidationError from plaid import Dataset, ProblemDefinition, Sample from plaid.containers.features import SampleFeatures from plaid.types import IndexType +from plaid.types.cgns_types import CGNSTree +from plaid.utils.cgns_helper import flatten_cgns_tree, unflatten_cgns_tree +from plaid.utils.deprecation import deprecated logger = logging.getLogger(__name__) -""" -Convention with hf (Hugging Face) datasets: -- hf-datasets contains a single Hugging Face split, named 'all_samples'. -- samples contains a single Hugging Face feature, named called "sample". -- Samples are instances of :ref:`Sample`. -- Mesh objects included in samples follow the CGNS standard, and can be converted in Muscat.Containers.Mesh.Mesh. 
-- problem_definition info is stored in hf-datasets "description" parameter -""" +# ------------------------------------------------------------------------------ +# HUGGING FACE BRIDGE (with tree flattening and pyarrow tables) +# ------------------------------------------------------------------------------ + + +def to_cgns_tree_columnar( + ds: datasets.Dataset, + i: int, + flat_cst: dict[str, Any], + cgns_types: dict[str, str], + enforce_shapes: bool = False, +) -> CGNSTree: + """Convert a Hugging Face dataset row to a PLAID Sample object. + + This function extracts a single row from a Hugging Face dataset and converts it + into a PLAID Sample by unflattening the CGNS tree structure. Constant features + are added from the flat_cst dictionary. + + Args: + ds (datasets.Dataset): The Hugging Face dataset containing the sample data. + i (int): The index of the row to convert. + flat_cst (dict[str, any]): Dictionary of constant features to add to each sample. + cgns_types (dict[str, str]): Dictionary mapping paths to CGNS types for reconstruction. + enforce_shapes (bool, optional): If True, ensures consistent array shapes during conversion. + Defaults to False. + + Returns: + Sample: A validated PLAID Sample object reconstructed from the Hugging Face dataset row. + + Notes: + - The function uses the dataset's pyarrow table data for efficient access + - When enforce_shapes is False, it uses zero_copy_only=False for numpy conversion + - When enforce_shapes is True, it handles ListArray types specially by stacking them + - Constant features from flat_cst are merged with the variable features from the row + """ + table = ds.data + + row = {} + if not enforce_shapes: + for name in table.column_names: + value = table[name][i].values + if value is None: + row[name] = None + else: + row[name] = value.to_numpy(zero_copy_only=False) + else: + for name in table.column_names: + value = table[name][i].values + if value is None: + row[name] = None + else: + if isinstance(value, pa.ListArray): + row[name] = np.stack(value.to_numpy(zero_copy_only=False)) + else: + row[name] = value.to_numpy(zero_copy_only=True) + + row.update(flat_cst) + return unflatten_cgns_tree(row, cgns_types) + + +def to_cgns_tree( + hf_sample: dict[str, Features], flat_cst: dict[str, Any], cgns_types: dict[str, str] +) -> CGNSTree: + """Convert a Hugging Face dataset row to a PLAID Sample object. + + This function extracts a single row from a Hugging Face dataset and converts it + into a PLAID Sample by unflattening the CGNS tree structure. Constant features + are added from the flat_cst dictionary. + + Args: + hf_sample (dict[str, Features]): row of a Hugging Face dataset + flat_cst (dict[str, any]): Dictionary of constant features to add to each sample. + cgns_types (dict[str, str]): Dictionary mapping paths to CGNS types for reconstruction. + enforce_shapes (bool, optional): If True, ensures consistent array shapes during conversion. + Defaults to False. + + Returns: + Sample: A validated PLAID Sample object reconstructed from the Hugging Face dataset row. 
+ + Notes: + - The function uses the dataset's pyarrow table data for efficient access + - When enforce_shapes is False, it uses zero_copy_only=False for numpy conversion + - When enforce_shapes is True, it handles ListArray types specially by stacking them + - Constant features from flat_cst are merged with the variable features from the row + """ + row = {name: np.array(value) for name, value in hf_sample.items()} + row.update(flat_cst) + return unflatten_cgns_tree(row, cgns_types) + + +def to_plaid_sample_columnar( + ds: datasets.Dataset, + i: int, + flat_cst: dict[str, Any], + cgns_types: dict[str, str], + enforce_shapes: bool = False, +) -> Sample: + """Convert a Hugging Face dataset row to a PLAID Sample object. + + This function extracts a single row from a Hugging Face dataset and converts it + into a PLAID Sample by unflattening the CGNS tree structure. Constant features + are added from the flat_cst dictionary. + + Args: + ds (datasets.Dataset): The Hugging Face dataset containing the sample data. + i (int): The index of the row to convert. + flat_cst (dict[str, any]): Dictionary of constant features to add to each sample. + cgns_types (dict[str, str]): Dictionary mapping paths to CGNS types for reconstruction. + enforce_shapes (bool, optional): If True, ensures consistent array shapes during conversion. + Defaults to False. + + Returns: + Sample: A validated PLAID Sample object reconstructed from the Hugging Face dataset row. + + Notes: + - The function uses the dataset's pyarrow table data for efficient access + - When enforce_shapes is False, it uses zero_copy_only=False for numpy conversion + - When enforce_shapes is True, it handles ListArray types specially by stacking them + - Constant features from flat_cst are merged with the variable features from the row + """ + cgns_tree = to_cgns_tree_columnar(ds, i, flat_cst, cgns_types, enforce_shapes) + + sample = Sample(path=None, features=SampleFeatures({0.0: cgns_tree})) + return Sample.model_validate(sample) + + +def to_plaid_sample( + hf_sample: dict[str, Features], + flat_cst: dict[str, Any], + cgns_types: dict[str, str], +) -> Sample: + """Convert a Hugging Face dataset row to a PLAID Sample object. + + This function extracts a single row from a Hugging Face dataset and converts it + into a PLAID Sample by unflattening the CGNS tree structure. Constant features + are added from the flat_cst dictionary. + + Args: + hf_sample (dict[str, Features]): row of a Hugging Face dataset + flat_cst (dict[str, any]): Dictionary of constant features to add to each sample. + cgns_types (dict[str, str]): Dictionary mapping paths to CGNS types for reconstruction. + + Returns: + Sample: A validated PLAID Sample object reconstructed from the Hugging Face dataset row. + + Notes: + - The function uses the dataset's pyarrow table data for efficient access + - When enforce_shapes is False, it uses zero_copy_only=False for numpy conversion + - When enforce_shapes is True, it handles ListArray types specially by stacking them + - Constant features from flat_cst are merged with the variable features from the row + """ + cgns_tree = to_cgns_tree(hf_sample, flat_cst, cgns_types) + + sample = Sample(path=None, features=SampleFeatures({0.0: cgns_tree})) + return Sample.model_validate(sample) + + +def to_plaid_dataset( + hf_dataset: datasets.Dataset, + flat_cst: dict[str, Any], + cgns_types: dict[str, str], + enforce_shapes: bool = True, +) -> Dataset: + """Convert a Hugging Face dataset into a PLAID dataset. 
-def load_hf_dataset_from_hub( + Iterates over all samples in a Hugging Face `Dataset` and converts each one + into a PLAID-compatible sample using `to_plaid_sample_columnar`. The resulting + samples are then collected into a single PLAID `Dataset`. + + Args: + hf_dataset (datasets.Dataset): + The Hugging Face dataset split to convert. + flat_cst: + Flattened representation of the CGNS tree structure constants, + used to map data fields. + cgns_types: + Mapping of CGNS paths to their expected types. + enforce_shapes (bool, optional): + If True, ensures all arrays strictly follow the reference shapes. + Defaults to False. + + Returns: + Dataset: + A PLAID `Dataset` object containing the converted samples. + + """ + sample_list = [] + for i in range(len(hf_dataset)): + sample_list.append( + to_plaid_sample_columnar( + hf_dataset, i, flat_cst, cgns_types, enforce_shapes + ) + ) + + return Dataset(samples=sample_list) + + +def infer_hf_features_from_value(value: Any) -> Union[Value, Sequence]: + """Infer Hugging Face dataset feature type from a given value. + + This function analyzes the input value and determines the appropriate Hugging Face + feature type representation. It handles None values, scalars, and arrays/lists + of various dimensions, mapping them to corresponding Hugging Face Value or Sequence types. + + Args: + value (Any): The value to infer the feature type from. Can be None, scalar, + list, tuple, or numpy array. + + Returns: + datasets.Feature: A Hugging Face feature type (Value or Sequence) that corresponds + to the input value's structure and data type. + + Raises: + TypeError: If the value type is not supported. + TypeError: If the array dimensionality exceeds 3D for arrays/lists. + + Notes: + - For scalar values, maps numpy dtypes to appropriate Hugging Face Value types: + float types to "float32", int32 to "int32", int64 to "int64", others to "string" + - For arrays/lists, creates nested Sequence structures based on dimensionality: + 1D → Sequence(base_type), 2D → Sequence(Sequence(base_type)), + 3D → Sequence(Sequence(Sequence(base_type))) + - All float values are enforced to "float32" to limit data size + - All int64 values are preserved as "int64" to satisfy CGNS standards + """ + if value is None: + return Value("null") + + # Scalars + if np.isscalar(value): + dtype = np.array(value).dtype + if np.issubdtype( + dtype, np.floating + ): # enforcing float32 for all floats, to be updated in case we want to keep float64 + return Value("float32") + elif np.issubdtype(dtype, np.int32): + return Value("int32") + elif np.issubdtype( + dtype, np.int64 + ): # very important to satisfy the CGNS standard + return Value("int64") + else: + return Value("string") + + # Arrays / lists + elif isinstance(value, (list, tuple, np.ndarray)): + arr = np.array(value) + base_type = infer_hf_features_from_value(arr.flat[0] if arr.size > 0 else None) + if arr.ndim == 1: + return Sequence(base_type) + elif arr.ndim == 2: + return Sequence(Sequence(base_type)) + elif arr.ndim == 3: + return Sequence(Sequence(Sequence(base_type))) + else: + raise TypeError(f"Unsupported ndim: {arr.ndim}") # pragma: no cover + raise TypeError(f"Unsupported type: {type(value)}") # pragma: no cover + + +def plaid_dataset_to_huggingface_datasetdict( + dataset: Dataset, + main_splits: dict[str, IndexType], + processes_number: int = 1, + writer_batch_size: int = 1, + verbose: bool = False, +) -> tuple[datasets.DatasetDict, dict[str, Any], dict[str, Any]]: + """Convert a PLAID dataset into a Hugging Face 
`datasets.DatasetDict`. + + This is a thin wrapper that creates per-split generators from a PLAID dataset + and delegates the actual dataset construction to + `plaid_generator_to_huggingface_datasetdict`. + + Args: + dataset (plaid.Dataset): + The PLAID dataset to be converted. Must support indexing with + a list of IDs (from `main_splits`). + main_splits (dict[str, IndexType]): + Mapping from split names (e.g. "train", "test") to the subset of + sample indices belonging to that split. + processes_number (int, optional, default=1): + Number of parallel processes to use when writing the Hugging Face dataset. + writer_batch_size (int, optional, default=1): + Batch size used when writing samples to disk in Hugging Face format. + verbose (bool, optional, default=False): + If True, print progress and debug information. + + Returns: + datasets.DatasetDict: + A Hugging Face `DatasetDict` containing one dataset per split. + + Example: + >>> ds_dict = plaid_dataset_to_huggingface_datasetdict( + ... dataset=my_plaid_dataset, + ... main_splits={"train": [0, 1, 2], "test": [3]}, + ... processes_number=4, + ... writer_batch_size=3 + ... ) + >>> print(ds_dict) + DatasetDict({ + train: Dataset({ + features: ... + }), + test: Dataset({ + features: ... + }) + }) + """ + + def generator(dataset): + for sample in dataset: + yield sample + + generators = { + split_name: partial(generator, dataset[ids]) + for split_name, ids in main_splits.items() + } + + return plaid_generator_to_huggingface_datasetdict( + generators, processes_number, writer_batch_size, verbose + ) + + +def _generator_prepare_for_huggingface( + generators: dict[str, Callable], + verbose: bool = True, +) -> tuple[dict[str, Any], dict[str, Any], Features]: + """Inspect PLAID dataset generators and infer Hugging Face feature schema. + + This function scans all provided split generators to: + 1. Flatten each CGNS tree into a dictionary of paths → values. + 2. Infer Hugging Face `Features` types for all variable leaves. + 3. Detect constant leaves (values that never change across all samples). + 4. Collect global CGNS type metadata. + + Args: + generators (dict[str, Callable]): + A dictionary mapping split names to callables returning sample generators. + Each sample is expected to have the structure `sample.features.data[0.0]` + compatible with `flatten_cgns_tree`. + verbose (bool, optional, default=True): + If True, displays progress bars while processing splits. + + Returns: + tuple: + - **flat_cst (dict[str, Any])**: + Mapping from feature path to constant values detected across all splits. + - **key_mappings (dict[str, Any])**: + Metadata dictionary with: + - `"variable_features"` (list[str]): paths of non-constant features. + - `"constant_features"` (list[str]): paths of constant features. + - `"cgns_types"` (dict[str, Any]): CGNS type information for all paths. + - **hf_features (datasets.Features)**: + Hugging Face feature specification for variable features. + + Raises: + ValueError: + If inconsistent CGNS types or feature types are found for the same path. + + Example: + >>> flat_cst, key_mappings, hf_features = _generator_prepare_for_huggingface( + ... {"train": lambda: iter(train_samples), + ... "test": lambda: iter(test_samples)} + ... ) + >>> print(key_mappings["variable_features"][:5]) + ['Zone1/FlowSolution/VelocityX', 'Zone1/FlowSolution/VelocityY', ...] 
+ >>> print(flat_cst) + {'Zone1/GridCoordinates': array([0., 0.1, 0.2])} + >>> print(hf_features) + {'Zone1/FlowSolution/VelocityX': Value(dtype='float32', id=None), ...} + """ + + def values_equal(v1, v2): + if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray): + return np.array_equal(v1, v2) + return v1 == v2 + + global_cgns_types = {} + global_feature_types = {} + global_constant_leaves = {} + total_samples = 0 + + for split_name, generator in generators.items(): + for sample in tqdm( + generator(), + disable=not verbose, + desc=f"Prepare for HF on split {split_name}", + ): + total_samples += 1 + tree = sample.features.data[0.0] + flat, cgns_types = flatten_cgns_tree(tree) + + for path, value in flat.items(): + # --- CGNS types --- + if path not in global_cgns_types: + global_cgns_types[path] = cgns_types[path] + elif global_cgns_types[path] != cgns_types[path]: # pragma: no cover + raise ValueError( + f"Conflict for path '{path}': {global_cgns_types[path]} vs {cgns_types[path]}" + ) + + # --- feature types --- + inferred_feature = infer_hf_features_from_value(value) + if path not in global_feature_types: + global_feature_types[path] = inferred_feature + else: + # sanity check: convert to dict before comparing + if repr(global_feature_types[path]) != repr( + inferred_feature + ): # pragma: no cover + raise ValueError( + f"Feature type mismatch for {path}: " + f"{global_feature_types[path]} vs {inferred_feature}" + ) + + if path not in global_constant_leaves: + global_constant_leaves[path] = { + "value": value, + "constant": True, + "count": 1, + } + else: + entry = global_constant_leaves[path] + entry["count"] += 1 + if entry["constant"] and not values_equal(entry["value"], value): + entry["constant"] = False + + # After loop: only keep constants that appeared in all samples + for path, entry in global_constant_leaves.items(): + if entry["count"] != total_samples: + entry["constant"] = False + + # Sort dicts by keys + global_cgns_types = {p: global_cgns_types[p] for p in sorted(global_cgns_types)} + global_feature_types = { + p: global_feature_types[p] for p in sorted(global_feature_types) + } + global_constant_leaves = { + p: global_constant_leaves[p] for p in sorted(global_constant_leaves) + } + + flat_cst = { + p: e["value"] for p, e in global_constant_leaves.items() if e["constant"] + } + + cst_features = list(flat_cst.keys()) + var_features = [k for k in global_cgns_types.keys() if k not in cst_features] + + hf_features = Features( + {k: v for k, v in global_feature_types.items() if k in var_features} + ) + + key_mappings = {} + key_mappings["variable_features"] = var_features + key_mappings["constant_features"] = cst_features + key_mappings["cgns_types"] = global_cgns_types + + return flat_cst, key_mappings, hf_features + + +def plaid_generator_to_huggingface_datasetdict( + generators: dict[str, Callable], + processes_number: int = 1, + writer_batch_size: int = 1, + verbose: bool = False, +) -> tuple[datasets.DatasetDict, dict[str, Any], dict[str, Any]]: + """Convert PLAID dataset generators into a Hugging Face `datasets.DatasetDict`. + + This function inspects samples produced by the given generators, flattens their + CGNS tree structure, infers Hugging Face feature types, and builds one + `datasets.Dataset` per split. Constant features (identical across all samples) + are separated out from variable features. + + Args: + generators (dict[str, Callable]): + Mapping from split names (e.g., "train", "test") to generator functions. 
+ Each generator function must return an iterable of PLAID samples, where + each sample provides `sample.features.data[0.0]` for flattening. + processes_number (int, optional, default=1): + Number of processes used internally by Hugging Face when materializing + the dataset from the generators. + writer_batch_size (int, optional, default=1): + Batch size used when writing samples to disk in Hugging Face format. + verbose (bool, optional, default=False): + If True, displays progress bars and diagnostic messages. + + Returns: + tuple: + - **DatasetDict** (`datasets.DatasetDict`): + A Hugging Face dataset dictionary with one dataset per split. + - **flat_cst** (`dict[str, Any]`): + Dictionary of constant features detected across all splits. + - **key_mappings** (`dict[str, Any]`): + Metadata dictionary containing: + - `"variable_features"`: list of paths for non-constant features. + - `"constant_features"`: list of paths for constant features. + - `"cgns_types"`: inferred CGNS types for all features. + + Example: + >>> ds_dict, flat_cst, key_mappings = plaid_generator_to_huggingface_datasetdict( + ... {"train": lambda: iter(train_samples), + ... "test": lambda: iter(test_samples)}, + ... processes_number=4, + ... writer_batch_size=2, + ... verbose=True + ... ) + >>> print(ds_dict) + DatasetDict({ + train: Dataset({ + features: ... + }), + test: Dataset({ + features: ... + }) + }) + >>> print(flat_cst) + {'Zone1/GridCoordinates': array([0., 0.1, 0.2])} + >>> print(key_mappings["variable_features"][:3]) + ['Zone1/FlowSolution/VelocityX', 'Zone1/FlowSolution/VelocityY', ...] + """ + flat_cst, key_mappings, hf_features = _generator_prepare_for_huggingface( + generators, verbose + ) + + all_features_keys = list(hf_features.keys()) + + def generator_fn(gen_func, all_features_keys): + for sample in gen_func(): + tree = sample.features.data[0.0] + flat, _ = flatten_cgns_tree(tree) + yield {path: flat.get(path, None) for path in all_features_keys} + + _dict = {} + for split_name, gen_func in generators.items(): + gen = partial(generator_fn, gen_func, all_features_keys) + _dict[split_name] = datasets.Dataset.from_generator( + generator=gen, + features=hf_features, + num_proc=processes_number, + writer_batch_size=writer_batch_size, + split=datasets.splits.NamedSplit(split_name), + ) + + return datasets.DatasetDict(_dict), flat_cst, key_mappings + + +# ------------------------------------------------------------------------------ +# HUGGING FACE HUB INTERACTIONS +# ------------------------------------------------------------------------------ + + +def instantiate_plaid_datasetdict_from_hub( + repo_id: str, + enforce_shapes: bool = False, +) -> dict[str, Dataset]: # pragma: no cover (not tested in unit tests) + """Load a Hugging Face dataset from the Hub and instantiate it as a dictionary of PLAID datasets. + + This function retrieves a dataset dictionary from the Hugging Face Hub, + along with its associated CGNS tree structure and type information. Each + split of the Hugging Face dataset is then converted into a PLAID dataset. + + Args: + repo_id (str): + The Hugging Face repository identifier (e.g. `"user/dataset"`). + enforce_shapes (bool, optional): + If True, enforce strict array shapes when converting to PLAID + datasets. Defaults to False. + + Returns: + dict[str, Dataset]: + A dictionary mapping split names (e.g. `"train"`, `"test"`) to + PLAID `Dataset` objects. 
+ + """ + hf_dataset_dict = load_dataset_from_hub(repo_id) + + flat_cst, key_mappings = load_tree_struct_from_hub(repo_id) + cgns_types = key_mappings["cgns_types"] + + datasetdict = {} + for split_name, hf_dataset in hf_dataset_dict.items(): + datasetdict[split_name] = to_plaid_dataset( + hf_dataset, flat_cst, cgns_types, enforce_shapes + ) + + return datasetdict + + +def load_dataset_from_hub( repo_id: str, streaming: bool = False, *args, **kwargs ) -> Union[ datasets.Dataset, datasets.DatasetDict, datasets.IterableDataset, datasets.IterableDatasetDict, -]: # pragma: no cover (to prevent testing from downloading, this is run by examples) +]: # pragma: no cover (not tested in unit tests) """Loads a Hugging Face dataset from the public hub, a private mirror, or local cache, with automatic handling of streaming and download modes. Behavior: @@ -72,8 +659,12 @@ def load_hf_dataset_from_hub( Args: repo_id (str): The Hugging Face dataset repository ID (e.g., 'username/dataset'). streaming (bool, optional): If True, attempts to stream the dataset (only supported on the public hub). - *args: Additional positional arguments passed to `datasets.load_dataset` or `datasets.load_from_disk`. - **kwargs: Additional keyword arguments passed to `datasets.load_dataset` or `datasets.load_from_disk`. + *args: + Positional arguments forwarded to + [`datasets.load_dataset`](https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset). + **kwargs: + Keyword arguments forwarded to + [`datasets.load_dataset`](https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset). Returns: Union[datasets.Dataset, datasets.DatasetDict]: The loaded Hugging Face dataset object. @@ -119,11 +710,460 @@ def _get_cached_path(repo_id_): return load_dataset(repo_id, streaming=streaming, *args, **kwargs) -def to_plaid_sample(hf_sample: dict[str, bytes]) -> Sample: - """Convert a Hugging Face dataset sample to a plaid :class:`Sample `. +def load_infos_from_hub( + repo_id: str, +) -> dict[str, dict[str, str]]: # pragma: no cover (not tested in unit tests) + """Load dataset infos from the Hugging Face Hub. + + Downloads the infos.yaml file from the specified repository and parses it as a dictionary. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + + Returns: + dict[str, dict[str, str]]: Dictionary containing dataset infos. + """ + # Download infos.yaml + yaml_path = hf_hub_download( + repo_id=repo_id, filename="infos.yaml", repo_type="dataset" + ) + with open(yaml_path, "r", encoding="utf-8") as f: + infos = yaml.safe_load(f) + + return infos + + +def load_problem_definition_from_hub( + repo_id: str, name: str +) -> ProblemDefinition: # pragma: no cover (not tested in unit tests) + """Load a ProblemDefinition from the Hugging Face Hub. + + Downloads the problem infos YAML and split JSON files from the specified repository and location, + then initializes a ProblemDefinition object with this information. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + name (str): The name of the problem_definition stored in the repo. + + Returns: + ProblemDefinition: The loaded problem definition. 
+ """ + # Download split.json + json_path = hf_hub_download( + repo_id=repo_id, + filename=f"problem_definitions/{name}/split.json", + repo_type="dataset", + ) + with open(json_path, "r", encoding="utf-8") as f: + json_data = json.load(f) + + # Download problem_infos.yaml + yaml_path = hf_hub_download( + repo_id=repo_id, + filename=f"problem_definitions/{name}/problem_infos.yaml", + repo_type="dataset", + ) + with open(yaml_path, "r", encoding="utf-8") as f: + yaml_data = yaml.safe_load(f) + + prob_def = ProblemDefinition() + prob_def._initialize_from_problem_infos_dict(yaml_data) + prob_def.set_split(json_data) + + return prob_def + + +def load_tree_struct_from_hub( + repo_id: str, +) -> tuple[dict, dict]: # pragma: no cover (not tested in unit tests) + """Load the tree structure metadata of a PLAID dataset from the Hugging Face Hub. + + This function retrieves two artifacts previously uploaded alongside a dataset: + - **tree_constant_part.pkl**: a pickled dictionary of constant feature values + (features that are identical across all samples). + - **key_mappings.yaml**: a YAML file containing metadata about the dataset + feature structure, including variable features, constant features, and CGNS types. + + Args: + repo_id (str): + The repository ID on the Hugging Face Hub + (e.g., `"username/dataset_name"`). + + Returns: + tuple[dict, dict]: + - **flat_cst (dict)**: constant features dictionary (path → value). + - **key_mappings (dict)**: metadata dictionary containing keys such as: + - `"variable_features"`: list of paths for non-constant features. + - `"constant_features"`: list of paths for constant features. + - `"cgns_types"`: mapping from paths to CGNS types. + """ + # constant part of the tree + flat_cst_path = hf_hub_download( + repo_id=repo_id, + filename="tree_constant_part.pkl", + repo_type="dataset", + ) + + with open(flat_cst_path, "rb") as f: + flat_cst = pickle.load(f) + + # key mappings + yaml_path = hf_hub_download( + repo_id=repo_id, + filename="key_mappings.yaml", + repo_type="dataset", + ) + with open(yaml_path, "r", encoding="utf-8") as f: + key_mappings = yaml.safe_load(f) + + return flat_cst, key_mappings + + +def push_dataset_dict_to_hub( + repo_id: str, hf_dataset_dict: datasets.DatasetDict, *args, **kwargs +) -> None: # pragma: no cover (not tested in unit tests) + """Push a Hugging Face `DatasetDict` to the Hugging Face Hub. + + This is a thin wrapper around `datasets.DatasetDict.push_to_hub`, allowing + you to upload a dataset dictionary (with one or more splits such as + `"train"`, `"validation"`, `"test"`) to the Hugging Face Hub. + + Args: + repo_id (str): + The repository ID on the Hugging Face Hub + (e.g. `"username/dataset_name"`). + hf_dataset_dict (datasets.DatasetDict): + The Hugging Face dataset dictionary to push. + *args: + Positional arguments forwarded to + [`DatasetDict.push_to_hub`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub). + **kwargs: + Keyword arguments forwarded to + [`DatasetDict.push_to_hub`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.push_to_hub). + + Returns: + None + """ + hf_dataset_dict.push_to_hub(repo_id, *args, **kwargs) + + +def push_infos_to_hub( + repo_id: str, infos: dict[str, dict[str, str]] +) -> None: # pragma: no cover (not tested in unit tests) + """Upload dataset infos to the Hugging Face Hub. + + Serializes the infos dictionary to YAML and uploads it to the specified repository as infos.yaml. 
+ + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + infos (dict[str, dict[str, str]]): Dictionary containing dataset infos to upload. + + Raises: + ValueError: If the infos dictionary is empty. + """ + if len(infos) > 0: + api = HfApi() + yaml_str = yaml.dump(infos) + yaml_buffer = io.BytesIO(yaml_str.encode("utf-8")) + api.upload_file( + path_or_fileobj=yaml_buffer, + path_in_repo="infos.yaml", + repo_id=repo_id, + repo_type="dataset", + commit_message="Upload infos.yaml", + ) + else: + raise ValueError("'infos' must not be empty") + + +def push_problem_definition_to_hub( + repo_id: str, name: str, pb_def: ProblemDefinition +) -> None: # pragma: no cover (not tested in unit tests) + """Upload a ProblemDefinition and its split information to the Hugging Face Hub. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + name (str): The name of the problem_definition to store in the repo. + pb_def (ProblemDefinition): The problem definition to upload. + """ + api = HfApi() + data = pb_def._generate_problem_infos_dict() + if data is not None: + yaml_str = yaml.dump(data) + yaml_buffer = io.BytesIO(yaml_str.encode("utf-8")) + + api.upload_file( + path_or_fileobj=yaml_buffer, + path_in_repo=f"problem_definitions/{name}/problem_infos.yaml", + repo_id=repo_id, + repo_type="dataset", + commit_message=f"Upload problem_definitions/{name}/problem_infos.yaml", + ) + + data = pb_def.get_split() + json_str = json.dumps(data) + json_buffer = io.BytesIO(json_str.encode("utf-8")) + + api.upload_file( + path_or_fileobj=json_buffer, + path_in_repo=f"problem_definitions/{name}/split.json", + repo_id=repo_id, + repo_type="dataset", + commit_message=f"Upload problem_definitions/{name}/split.json", + ) + + +def push_tree_struct_to_hub( + repo_id: str, + flat_cst: dict[str, Any], + key_mappings: dict[str, Any], +) -> None: # pragma: no cover (not tested in unit tests) + """Upload a dataset's tree structure to a Hugging Face dataset repository. + + This function pushes two components of a dataset tree structure to the specified + Hugging Face Hub repository: + + 1. `flat_cst`: the constant parts of the dataset tree, serialized as a pickle file + (`tree_constant_part.pkl`). + 2. `key_mappings`: the dictionary of key mappings and metadata for the dataset tree, + serialized as a YAML file (`key_mappings.yaml`). + + Both files are uploaded using the Hugging Face `HfApi().upload_file` method. + + Args: + repo_id (str): The Hugging Face dataset repository ID where files will be uploaded. + flat_cst (dict[str, Any]): Dictionary containing constant values in the dataset tree. + key_mappings (dict[str, Any]): Dictionary containing key mappings and additional metadata. + + Returns: + None - If the sample is not valid, it tries to build it from its components. - If it still fails because of a missing key, it raises a KeyError. + Notes: + - Each upload includes a commit message indicating the filename. + - This function is not covered by unit tests (`pragma: no cover`). 
+ """ + api = HfApi() + + # constant part of the tree + api.upload_file( + path_or_fileobj=io.BytesIO(pickle.dumps(flat_cst)), + path_in_repo="tree_constant_part.pkl", + repo_id=repo_id, + repo_type="dataset", + commit_message="Upload tree_constant_part.pkl", + ) + + # key mappings + yaml_str = yaml.dump(key_mappings, sort_keys=False) + yaml_buffer = io.BytesIO(yaml_str.encode("utf-8")) + + api.upload_file( + path_or_fileobj=yaml_buffer, + path_in_repo="key_mappings.yaml", + repo_id=repo_id, + repo_type="dataset", + commit_message="Upload key_mappings.yaml", + ) + + +# ------------------------------------------------------------------------------ +# HUGGING FACE INTERACTIONS ON DISK +# ------------------------------------------------------------------------------ + + +def load_dataset_from_disk( + path: Union[str, Path], *args, **kwargs +) -> Union[datasets.Dataset, datasets.DatasetDict]: + """Load a Hugging Face dataset or dataset dictionary from disk. + + This function wraps `datasets.load_from_disk` to accept either a string path or a + `Path` object and returns the loaded dataset object. + + Args: + path (Union[str, Path]): Path to the directory containing the saved dataset. + *args: + Positional arguments forwarded to + [`datasets.load_from_disk`](https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_from_disk). + **kwargs: + Keyword arguments forwarded to + [`datasets.load_from_disk`](https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_from_disk). + + Returns: + Union[datasets.Dataset, datasets.DatasetDict]: The loaded Hugging Face dataset + object, which may be a single `Dataset` or a `DatasetDict` depending on what + was saved on disk. + """ + return load_from_disk(str(path), *args, **kwargs) + + +def load_infos_from_disk(path: Union[str, Path]) -> dict[str, dict[str, str]]: + """Load dataset information from a YAML file stored on disk. + + Args: + path (Union[str, Path]): Directory path containing the `infos.yaml` file. + + Returns: + dict[str, dict[str, str]]: Dictionary containing dataset infos. + """ + infos_fname = Path(path) / "infos.yaml" + with infos_fname.open("r") as file: + infos = yaml.safe_load(file) + return infos + + +def load_problem_definition_from_disk( + path: Union[str, Path], name: Union[str, Path] +) -> ProblemDefinition: + """Load a ProblemDefinition and its split information from disk. + + Args: + path (Union[str, Path]): The root directory path for loading. + name (str): The name of the problem_definition stored in the disk directory. + + Returns: + ProblemDefinition: The loaded problem definition. + """ + pb_def = ProblemDefinition() + pb_def._load_from_dir_(Path(path) / Path("problem_definitions") / Path(name)) + return pb_def + + +def load_tree_struct_from_disk( + path: Union[str, Path], +) -> tuple[dict[str, Any], dict[str, Any]]: + """Load a tree structure for a dataset from disk. + + This function loads two components from the specified directory: + 1. `tree_constant_part.pkl`: a pickled dictionary containing the constant parts of the tree. + 2. `key_mappings.yaml`: a YAML file containing key mappings and metadata. + + Args: + path (Union[str, Path]): Directory path containing the `tree_constant_part.pkl` + and `key_mappings.yaml` files. + + Returns: + tuple[dict, dict]: A tuple containing: + - `flat_cst` (dict): Dictionary of constant tree values. + - `key_mappings` (dict): Dictionary of key mappings and metadata. 
+ """ + with open(Path(path) / Path("key_mappings.yaml"), "r", encoding="utf-8") as f: + key_mappings = yaml.safe_load(f) + + with open(Path(path) / "tree_constant_part.pkl", "rb") as f: + flat_cst = pickle.load(f) + + return flat_cst, key_mappings + + +def save_dataset_dict_to_disk( + path: Union[str, Path], hf_dataset_dict: datasets.DatasetDict, *args, **kwargs +) -> None: + """Save a Hugging Face DatasetDict to disk. + + This function serializes the provided DatasetDict and writes it to the specified + directory, preserving its features, splits, and data for later loading. + + Args: + path (Union[str, Path]): Directory path where the DatasetDict will be saved. + hf_dataset_dict (datasets.DatasetDict): The Hugging Face DatasetDict to save. + *args: + Positional arguments forwarded to + [`DatasetDict.save_to_disk`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.save_to_disk). + **kwargs: + Keyword arguments forwarded to + [`DatasetDict.save_to_disk`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict.save_to_disk). + + Returns: + None + """ + hf_dataset_dict.save_to_disk(str(path), *args, **kwargs) + + +def save_infos_to_disk( + path: Union[str, Path], infos: dict[str, dict[str, str]] +) -> None: + """Save dataset infos as a YAML file to disk. + + Args: + path (Union[str, Path]): The directory path where the infos file will be saved. + infos (dict[str, dict[str, str]]): Dictionary containing dataset infos. + """ + infos_fname = Path(path) / "infos.yaml" + infos_fname.parent.mkdir(parents=True, exist_ok=True) + with open(infos_fname, "w") as file: + yaml.dump(infos, file, default_flow_style=False, sort_keys=False) + + +def save_problem_definition_to_disk( + path: Union[str, Path], name: Union[str, Path], pb_def: ProblemDefinition +) -> None: + """Save a ProblemDefinition and its split information to disk. + + Args: + path (Union[str, Path]): The root directory path for saving. + name (str): The name of the problem_definition to store in the disk directory. + pb_def (ProblemDefinition): The problem definition to save. + """ + pb_def._save_to_dir_(Path(path) / Path("problem_definitions") / Path(name)) + + +def save_tree_struct_to_disk( + path: Union[str, Path], + flat_cst: dict[str, Any], + key_mappings: dict[str, Any], +) -> None: + """Save the structure of a dataset tree to disk. + + This function writes the constant part of the tree and its key mappings to files + in the specified directory. The constant part is serialized as a pickle file, + while the key mappings are saved in YAML format. + + Args: + path (Union[str, Path]): Directory path where the tree structure files will be saved. + flat_cst (dict): Dictionary containing the constant part of the tree. + key_mappings (dict): Dictionary containing key mappings for the tree structure. 
+ + Returns: + None + """ + Path(path).mkdir(parents=True, exist_ok=True) + + with open(Path(path) / "tree_constant_part.pkl", "wb") as f: # wb = write binary + pickle.dump(flat_cst, f) + + with open(Path(path) / "key_mappings.yaml", "w", encoding="utf-8") as f: + yaml.dump(key_mappings, f, sort_keys=False) + + +# ------------------------------------------------------------------------------ +# DEPRECATED HUGGING FACE BRIDGE (binary blobs) +# ------------------------------------------------------------------------------ + + +@deprecated( + "will be removed (this hf format will not be not maintained)", + version="0.1.10", + removal="1.0.0", +) +def binary_to_plaid_sample(hf_sample: dict[str, bytes]) -> Sample: + """Convert a Hugging Face dataset sample in binary format to a Plaid `Sample`. + + The input `hf_sample` is expected to contain a pickled representation of a sample + under the key `"sample"`. This function attempts to validate the unpickled sample + as a Plaid `Sample`. If validation fails, it reconstructs the sample from its + components (`meshes`, `path`, and optional `scalars`) before validating it. + + Args: + hf_sample (dict[str, bytes]): A dictionary representing a Hugging Face sample, + with the pickled sample stored under the key `"sample"`. + + Returns: + Sample: A validated Plaid `Sample` object. + + Raises: + KeyError: If required keys (`"sample"`, `"meshes"`, `"path"`) are missing + and the sample cannot be reconstructed. + ValidationError: If the reconstructed sample still fails Plaid validation. """ pickled_hf_sample = pickle.loads(hf_sample["sample"]) @@ -149,44 +1189,15 @@ def to_plaid_sample(hf_sample: dict[str, bytes]) -> Sample: return Sample.model_validate(sample) -def generate_huggingface_description( - infos: dict, problem_definition: ProblemDefinition -) -> dict[str, Any]: - """Generates a Hugging Face dataset description field from a plaid dataset infos and problem definition. - - The conventions chosen here ensure working conversion to and from huggingset datasets. 
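A hedged round-trip sketch of the on-disk helpers above, mirroring `test_save_load_to_disk` further down (the output directory and the converted objects are assumptions):

```python
# Sketch assuming hf_dataset_dict, infos, pb_def, flat_cst and key_mappings already exist
# (e.g. returned by the columnar conversion of a plaid dataset).
from pathlib import Path

from plaid.bridges import huggingface_bridge

out_dir = Path("path/to/dir")  # hypothetical output directory

huggingface_bridge.save_dataset_dict_to_disk(out_dir, hf_dataset_dict)
huggingface_bridge.save_infos_to_disk(out_dir, infos)
huggingface_bridge.save_problem_definition_to_disk(out_dir, "task_1", pb_def)
huggingface_bridge.save_tree_struct_to_disk(out_dir, flat_cst, key_mappings)

# ... and back
hf_dataset_dict = huggingface_bridge.load_dataset_from_disk(out_dir)
infos = huggingface_bridge.load_infos_from_disk(out_dir)
pb_def = huggingface_bridge.load_problem_definition_from_disk(out_dir, "task_1")
flat_cst, key_mappings = huggingface_bridge.load_tree_struct_from_disk(out_dir)
```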
- - Args: - infos (dict): infos entry of the plaid dataset from which the Hugging Face description is to be generated - problem_definition (ProblemDefinition): of which the Hugging Face description is to be generated - - Returns: - dict[str]: Hugging Face dataset description - """ - # type hinting the values as Any because they can be of various types - description: dict[str, Any] = {} - - description.update(infos) - - split: dict[str, IndexType] = problem_definition.get_split(indices_name=None) # pyright: ignore[reportAssignmentType] - description["split"] = split - description["task"] = problem_definition.get_task() - - description["in_scalars_names"] = problem_definition.in_scalars_names - description["out_scalars_names"] = problem_definition.out_scalars_names - description["in_timeseries_names"] = problem_definition.in_timeseries_names - description["out_timeseries_names"] = problem_definition.out_timeseries_names - description["in_fields_names"] = problem_definition.in_fields_names - description["out_fields_names"] = problem_definition.out_fields_names - description["in_meshes_names"] = problem_definition.in_meshes_names - description["out_meshes_names"] = problem_definition.out_meshes_names - return description - - -def plaid_dataset_to_huggingface( +@deprecated( + "will be removed (this hf format will not be not maintained)", + version="0.1.10", + removal="1.0.0", +) +def plaid_dataset_to_huggingface_binary( dataset: Dataset, - problem_definition: ProblemDefinition, - split: str = "all_samples", + ids: Optional[list[IndexType]] = None, + split_name: str = "all_samples", processes_number: int = 1, ) -> datasets.Dataset: """Use this function for converting a Hugging Face dataset from a plaid dataset. @@ -195,8 +1206,8 @@ def plaid_dataset_to_huggingface( Args: dataset (Dataset): the plaid dataset to be converted in Hugging Face format - problem_definition (ProblemDefinition): the problem definition is used to generate the description of the Hugging Face dataset. - split (str): The name of the split. Default: "all_samples". + ids (list, optional): The specific sample IDs to convert the dataset. Defaults to None. + split_name (str): The name of the split. Default: "all_samples". processes_number (int): The number of processes used to generate the Hugging Face dataset. Default: 1. Returns: @@ -205,14 +1216,12 @@ def plaid_dataset_to_huggingface( Example: .. 
code-block:: python - dataset = plaid_dataset_to_huggingface(dataset, problem_definition, split) + dataset = plaid_dataset_to_huggingface_binary(dataset, problem_definition, split) dataset.save_to_disk("path/to/dir) dataset.push_to_hub("chanel/dataset") """ - if split == "all_samples": + if ids is None: ids = dataset.get_sample_ids() - else: - ids = problem_definition.get_split(split) def generator(): for sample in dataset[ids]: @@ -220,29 +1229,31 @@ def generator(): "sample": pickle.dumps(sample.model_dump()), } - return plaid_generator_to_huggingface( + return plaid_generator_to_huggingface_binary( generator=generator, - infos=dataset.get_infos(), - problem_definition=problem_definition, - split=split, + split_name=split_name, processes_number=processes_number, ) -def plaid_dataset_to_huggingface_datasetdict( - dataset: Dataset, - problem_definition: ProblemDefinition, - main_splits: list[str], +@deprecated( + "will be removed (this hf format will not be not maintained)", + version="0.1.10", + removal="1.0.0", +) +def plaid_generator_to_huggingface_binary( + generator: Callable, + split_name: str = "all_samples", processes_number: int = 1, -) -> datasets.DatasetDict: - """Use this function for converting a Hugging Face dataset dict from a plaid dataset. +) -> datasets.Dataset: + """Use this function for creating a Hugging Face dataset from a sample generator function. - The dataset can then be saved to disk, or pushed to the Hugging Face hub. + This function can be used when the plaid dataset cannot be loaded in RAM all at once due to its size. + The generator enables loading samples one by one. Args: - dataset (Dataset): the plaid dataset to be converted in Hugging Face format - problem_definition (ProblemDefinition): the problem definition is used to generate the description of the Hugging Face dataset. - main_splits (list[str]): The name of the main splits: defining a partitioning of the sample ids. + generator (Callable): a function yielding a dict {"sample" : sample}, where sample is of type 'bytes' + split_name (str): The name of the split. Default: "all_samples". processes_number (int): The number of processes used to generate the Hugging Face dataset. Default: 1. Returns: @@ -251,41 +1262,36 @@ def plaid_dataset_to_huggingface_datasetdict( Example: .. 
code-block:: python - dataset = plaid_dataset_to_huggingface(dataset, problem_definition, split) - dataset.save_to_disk("path/to/dir) - dataset.push_to_hub("chanel/dataset") + dataset = plaid_generator_to_huggingface_binary(generator, infos, split) """ - _dict = {} - for _, split in enumerate(main_splits): - ds = plaid_dataset_to_huggingface( - dataset=dataset, - problem_definition=problem_definition, - split=split, - processes_number=processes_number, - ) - _dict[split] = ds + ds: datasets.Dataset = datasets.Dataset.from_generator( # pyright: ignore[reportAssignmentType] + generator=generator, + features=datasets.Features({"sample": datasets.Value("binary")}), + num_proc=processes_number, + writer_batch_size=1, + split=datasets.splits.NamedSplit(split_name), + ) - return datasets.DatasetDict(_dict) + return ds -def plaid_generator_to_huggingface( - generator: Callable, - infos: dict, - problem_definition: ProblemDefinition, - split: str = "all_samples", +@deprecated( + "will be removed (this hf format will not be not maintained)", + version="0.1.10", + removal="1.0.0", +) +def plaid_dataset_to_huggingface_datasetdict_binary( + dataset: Dataset, + main_splits: dict[str, IndexType], processes_number: int = 1, -) -> datasets.Dataset: - """Use this function for creating a Hugging Face dataset from a sample generator function. +) -> datasets.DatasetDict: + """Use this function for converting a Hugging Face dataset dict from a plaid dataset. - This function can be used when the plaid dataset cannot be loaded in RAM all at once due to its size. - The generator enables loading samples one by one. The dataset can then be saved to disk, or pushed to the Hugging Face hub. Args: - generator (Callable): a function yielding a dict {"sample" : sample}, where sample is of type 'bytes' - infos (dict): the info is used to generate the description of the Hugging Face dataset. - problem_definition (ProblemDefinition): the problem definition is used to generate the description of the Hugging Face dataset. - split (str): The name of the split. Default: "all_samples". + dataset (Dataset): the plaid dataset to be converted in Hugging Face format. + main_splits (list[str]): The name of the main splits: defining a partitioning of the sample ids. processes_number (int): The number of processes used to generate the Hugging Face dataset. Default: 1. Returns: @@ -294,39 +1300,29 @@ def plaid_generator_to_huggingface( Example: .. code-block:: python - dataset = plaid_generator_to_huggingface(generator, infos, split, problem_definition) + dataset = plaid_dataset_to_huggingface_datasetdict_binary(dataset, problem_definition, split) + dataset.save_to_disk("path/to/dir) dataset.push_to_hub("chanel/dataset") - dataset.save_to_disk("path/to/dir") """ - ds: datasets.Dataset = datasets.Dataset.from_generator( # pyright: ignore[reportAssignmentType] - generator, - features=datasets.Features({"sample": datasets.Value("binary")}), - num_proc=processes_number, - writer_batch_size=1, - split=datasets.splits.NamedSplit(split), - ) - - def update_dataset_description( - ds: datasets.Dataset, new_desc: dict[str, Any] - ) -> datasets.Dataset: - info = ds.info.copy() - info.description = new_desc # pyright: ignore[reportAttributeAccessIssue] -> info.description is HF's DatasetInfo. We might want to correct this later. 
-        ds._info = info
-        return ds
-
-    new_description: dict[str, Any] = generate_huggingface_description(
-        infos, problem_definition
-    )
-    ds = update_dataset_description(ds, new_description)
+    _dict = {}
+    for split_name, ids in main_splits.items():
+        ds = plaid_dataset_to_huggingface_binary(
+            dataset=dataset,
+            ids=ids,
+            processes_number=processes_number,
+        )
+        _dict[split_name] = ds
 
-    return ds
+    return datasets.DatasetDict(_dict)
 
 
-def plaid_generator_to_huggingface_datasetdict(
-    generator: Callable,
-    infos: dict,
-    problem_definition: ProblemDefinition,
-    main_splits: list,
+@deprecated(
+    "will be removed (this hf format will not be maintained)",
+    version="0.1.10",
+    removal="1.0.0",
+)
+def plaid_generator_to_huggingface_datasetdict_binary(
+    generators: dict[str, Callable],
     processes_number: int = 1,
 ) -> datasets.DatasetDict:
     """Use this function for creating a Hugging Face dataset dict (containing multiple splits) from a sample generator function.
@@ -339,10 +1335,7 @@ def plaid_generator_to_huggingface_datasetdict(
     Only the first split will contain the description.
 
     Args:
-        generator (Callable): a function yielding a dict {"sample" : sample}, where sample is of type 'bytes'
-        infos (dict): infos entry of the plaid dataset from which the Hugging Face dataset is to be generated
-        problem_definition (ProblemDefinition): the problem definition is used to generate the description of the Hugging Face dataset.
-        main_splits (str, optional): The name of the main splits: defining a partitioning of the sample ids.
+        generators (dict[str, Callable]): a dict of functions yielding a dict {"sample" : sample}, where sample is of type 'bytes'
         processes_number (int): The number of processes used to generate the Hugging Face dataset. Default: 1.
 
     Returns:
@@ -351,57 +1344,34 @@ def plaid_generator_to_huggingface_datasetdict(
     Example:
        .. code-block:: python
 
-            dataset = plaid_generator_to_huggingface_datasetdict(generator, infos, problem_definition, main_splits)
-            dataset.push_to_hub("chanel/dataset")
-            dataset.save_to_disk("path/to/dir")
+            hf_dataset_dict = plaid_generator_to_huggingface_datasetdict_binary(generators)
+            push_dataset_dict_to_hub("chanel/dataset", hf_dataset_dict)
+            hf_dataset_dict.save_to_disk("path/to/dir")
     """
     _dict = {}
-    for _, split in enumerate(main_splits):
-        ds = plaid_generator_to_huggingface(
-            generator,
-            infos,
-            problem_definition=problem_definition,
-            split=split,
+    for split_name, generator in generators.items():
+        ds = plaid_generator_to_huggingface_binary(
+            generator=generator,
             processes_number=processes_number,
+            split_name=split_name,
         )
-        _dict[split] = ds
+        _dict[split_name] = ds
 
     return datasets.DatasetDict(_dict)
 
 
-def huggingface_description_to_problem_definition(
-    description: dict,
-) -> ProblemDefinition:
-    """Converts a Hugging Face dataset description to a plaid problem definition.
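To build the `generators` dict expected by the deprecated helper above, a hedged sketch along the lines of the `generator_split_binary` test fixture (the split ids are illustrative and `dataset` is assumed to be an existing plaid Dataset):

```python
# Hedged sketch: one generator per split, each yielding pickled plaid samples.
import pickle

from plaid.bridges import huggingface_bridge

splits = {"train": [0, 1], "test": [2]}  # illustrative partition of dataset ids

generators = {}
for split_name, ids in splits.items():

    def generator_(ids=ids):
        for sample_id in ids:
            yield {"sample": pickle.dumps(dataset[sample_id])}

    generators[split_name] = generator_

hf_dataset_dict = huggingface_bridge.plaid_generator_to_huggingface_datasetdict_binary(
    generators
)
```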
- - Args: - description (dict): the description field of a Hugging Face dataset, containing the problem definition - - Returns: - problem_definition (ProblemDefinition): the plaid problem definition initialized from the Hugging Face dataset description - """ - problem_definition = ProblemDefinition() - problem_definition.set_task(description["task"]) - problem_definition.set_split(description["split"]) - problem_definition.add_input_scalars_names(description["in_scalars_names"]) - problem_definition.add_output_scalars_names(description["out_scalars_names"]) - problem_definition.add_input_timeseries_names(description["in_timeseries_names"]) - problem_definition.add_output_timeseries_names(description["out_timeseries_names"]) - problem_definition.add_input_fields_names(description["in_fields_names"]) - problem_definition.add_output_fields_names(description["out_fields_names"]) - problem_definition.add_input_meshes_names(description["in_meshes_names"]) - problem_definition.add_output_meshes_names(description["out_meshes_names"]) - - return problem_definition - - +@deprecated( + "will be removed (this hf format will not be not maintained)", + version="0.1.10", + removal="1.0.0", +) def huggingface_dataset_to_plaid( ds: datasets.Dataset, ids: Optional[list[int]] = None, processes_number: int = 1, large_dataset: bool = False, verbose: bool = True, -) -> tuple[Dataset, ProblemDefinition]: +) -> Union[Dataset, ProblemDefinition]: """Use this function for converting a plaid dataset from a Hugging Face dataset. A Hugging Face dataset can be read from disk or the hub. From the hub, the @@ -441,10 +1411,9 @@ def huggingface_dataset_to_plaid( "Trying to parallelize with more processes than selected samples in dataset" ) - dataset = Dataset() + description = "Converting Hugging Face binary dataset to plaid" - if verbose: - print("Converting Hugging Face dataset to plaid dataset...") + dataset = Dataset() if large_dataset: if ids: @@ -463,6 +1432,7 @@ def parallel_convert(shard_path, n_workers): pool.imap(converter, range(len(converter.hf_ds))), total=len(converter.hf_ds), disable=not verbose, + desc=description, ) ) @@ -485,21 +1455,26 @@ def parallel_convert(shard_path, n_workers): else: indices = range(len(ds)) - with Pool(processes=processes_number) as pool: - for idx, sample in enumerate( - tqdm( - pool.imap(_HFToPlaidSampleConverter(ds), indices), - total=len(indices), - disable=not verbose, - ) + if processes_number == 1: + for idx in tqdm( + indices, total=len(indices), disable=not verbose, desc=description ): - dataset.add_sample(sample, id=indices[idx]) + sample = _HFToPlaidSampleConverter(ds)(idx) + dataset.add_sample(sample, id=idx) - infos = {} - if "legal" in ds.description: - infos["legal"] = ds.description["legal"] - if "data_production" in ds.description: - infos["data_production"] = ds.description["data_production"] + else: + with Pool(processes=processes_number) as pool: + for idx, sample in enumerate( + tqdm( + pool.imap(_HFToPlaidSampleConverter(ds), indices), + total=len(indices), + disable=not verbose, + desc=description, + ) + ): + dataset.add_sample(sample, id=indices[idx]) + + infos = huggingface_description_to_infos(ds.description) dataset.set_infos(infos) @@ -508,55 +1483,75 @@ def parallel_convert(shard_path, n_workers): return dataset, problem_definition -def streamed_huggingface_dataset_to_plaid( - hf_repo: str, - number_of_samples: int, -) -> tuple[ - Dataset, ProblemDefinition -]: # pragma: no cover (to prevent testing from downloading, this is run by examples) - """Use 
this function for creating a plaid dataset by streaming on Hugging Face. - - The indices of the retrieved sample is not controled. +@deprecated( + "will be removed (this hf format will not be not maintained)", + version="0.1.10", + removal="1.0.0", +) +def huggingface_description_to_problem_definition( + description: dict, +) -> ProblemDefinition: + """Converts a Hugging Face dataset description to a plaid problem definition. Args: - hf_repo (str): the name of the repo on Hugging Face - number_of_samples (int): The number of samples to retrieve. + description (dict): the description field of a Hugging Face dataset, containing the problem definition Returns: - dataset (Dataset): the converted dataset. - problem_definition (ProblemDefinition): the problem definition generated from the Hugging Face dataset - - Notes: - .. code-block:: python - - from plaid.bridges.huggingface_bridge import streamed_huggingface_dataset_to_plaid - - dataset, pb_def = streamed_huggingface_dataset_to_plaid('PLAID-datasets/VKI-LS59', 2) + problem_definition (ProblemDefinition): the plaid problem definition initialized from the Hugging Face dataset description """ - ds_stream = load_hf_dataset_from_hub(hf_repo, split="all_samples", streaming=True) - - infos = {} - if "legal" in ds_stream.description: - infos["legal"] = ds_stream.description["legal"] - if "data_production" in ds_stream.description: - infos["data_production"] = ds_stream.description["data_production"] - - problem_definition = huggingface_description_to_problem_definition( - ds_stream.description - ) + description = {} if description == "" else description + problem_definition = ProblemDefinition() + for func, key in [ + (problem_definition.set_task, "task"), + (problem_definition.set_split, "split"), + (problem_definition.add_input_scalars_names, "in_scalars_names"), + (problem_definition.add_output_scalars_names, "out_scalars_names"), + (problem_definition.add_input_fields_names, "in_fields_names"), + (problem_definition.add_output_fields_names, "out_fields_names"), + (problem_definition.add_input_meshes_names, "in_meshes_names"), + (problem_definition.add_output_meshes_names, "out_meshes_names"), + ]: + try: + func(description[key]) + except KeyError: + logger.info(f"Could not retrieve key:'{key}' from description") + pass - samples = [] - for _ in range(number_of_samples): - hf_sample = next(iter(ds_stream)) - samples.append(to_plaid_sample(hf_sample)) + return problem_definition - dataset = Dataset(samples=samples) - dataset.set_infos(infos) +@deprecated( + "will be removed (this hf format will not be not maintained)", + version="0.1.10", + removal="1.0.0", +) +def huggingface_description_to_infos( + description: dict, +) -> dict[str, dict[str, str]]: + """Convert a Hugging Face dataset description dictionary to a PLAID infos dictionary. - return dataset, problem_definition + Extracts the "legal" and "data_production" sections from the Hugging Face description + and returns them in a format compatible with PLAID dataset infos. + Args: + description (dict): The Hugging Face dataset description dictionary. + Returns: + dict[str, dict[str, str]]: Dictionary containing "legal" and "data_production" infos if present. 
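Going back the other way with the deprecated binary format, a hedged sketch (`hf_dataset` is assumed to be a legacy binary-format Hugging Face dataset with one pickled sample per row):

```python
# Hedged sketch of the deprecated binary-to-plaid path.
from plaid.bridges import huggingface_bridge

plaid_dataset, pb_def = huggingface_bridge.huggingface_dataset_to_plaid(
    hf_dataset, processes_number=1, verbose=True
)

# single-row and description-level variants
first_sample = huggingface_bridge.binary_to_plaid_sample(hf_dataset[0])
infos = huggingface_bridge.huggingface_description_to_infos(hf_dataset.description)
```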
+ """ + infos = {} + if "legal" in description: + infos["legal"] = description["legal"] + if "data_production" in description: + infos["data_production"] = description["data_production"] + return infos + + +@deprecated( + "will be removed (this hf format will not be not maintained)", + version="0.1.9", + removal="0.2.0", +) def create_string_for_huggingface_dataset_card( description: dict, download_size_bytes: int, diff --git a/src/plaid/bridges/huggingface_helpers.py b/src/plaid/bridges/huggingface_helpers.py index 4fdd65ba..cabcd728 100644 --- a/src/plaid/bridges/huggingface_helpers.py +++ b/src/plaid/bridges/huggingface_helpers.py @@ -14,7 +14,7 @@ from datasets import load_from_disk from plaid import Sample -from plaid.bridges.huggingface_bridge import to_plaid_sample +from plaid.bridges.huggingface_bridge import binary_to_plaid_sample class _HFToPlaidSampleConverter: @@ -24,7 +24,7 @@ def __init__(self, hf_ds: Union[datasets.Dataset, datasets.DatasetDict]): self.hf_ds = hf_ds def __call__(self, sample_id: int) -> Sample: # pragma: no cover - return to_plaid_sample(self.hf_ds[sample_id]) + return binary_to_plaid_sample(self.hf_ds[sample_id]) class _HFShardToPlaidSampleConverter: @@ -44,4 +44,4 @@ def __call__( self, sample_id: int ) -> Sample: # pragma: no cover (not reported with multiprocessing) """Convert a sample shard from the huggingface dataset to a plaid :class:`Sample `.""" - return to_plaid_sample(self.hf_ds[sample_id]) + return binary_to_plaid_sample(self.hf_ds[sample_id]) diff --git a/src/plaid/containers/features.py b/src/plaid/containers/features.py index 86b4adab..7d623f87 100644 --- a/src/plaid/containers/features.py +++ b/src/plaid/containers/features.py @@ -144,8 +144,6 @@ def set_default_zone_base( self._default_active_zone = zone_name - # -------------------------------------------------------------------------# - def set_default_time(self, time: float) -> None: """Set the default time for the system. diff --git a/src/plaid/examples/dataset.py b/src/plaid/examples/dataset.py index ac1af71a..55f9f7fa 100644 --- a/src/plaid/examples/dataset.py +++ b/src/plaid/examples/dataset.py @@ -7,7 +7,7 @@ # # from plaid import Dataset -from plaid.bridges.huggingface_bridge import streamed_huggingface_dataset_to_plaid +from plaid.bridges.huggingface_bridge import load_dataset_from_hub, binary_to_plaid_sample from plaid.examples.config import _HF_REPOS @@ -40,7 +40,12 @@ def _load_dataset( return self._cache[ex_name] try: - dataset, _ = streamed_huggingface_dataset_to_plaid(hf_repo, 2) + ds_stream = load_dataset_from_hub(hf_repo, split="all_samples", streaming=True) + samples = [] + for _ in range(2): + hf_sample = next(iter(ds_stream)) + samples.append(binary_to_plaid_sample(hf_sample)) + dataset = Dataset(samples=samples) self._cache[ex_name] = dataset return dataset except Exception as e: # pragma: no cover diff --git a/src/plaid/problem_definition.py b/src/plaid/problem_definition.py index 82881f53..e724096c 100644 --- a/src/plaid/problem_definition.py +++ b/src/plaid/problem_definition.py @@ -1126,6 +1126,26 @@ def get_all_indices(self) -> list[int]: return list(set(all_indices)) # -------------------------------------------------------------------------# + def _generate_problem_infos_dict(self) -> dict[str, Union[str, list]]: + """Generate a dictionary containing all relevant problem definition data. + + Returns: + dict[str, Union[str, list]]: A dictionary with keys for task, input/output features, scalars, fields, timeseries, and meshes. 
+ """ + return { + "task": self._task, + "input_features": [dict(**d) for d in self.in_features_identifiers], + "output_features": [dict(**d) for d in self.out_features_identifiers], + "input_scalars": self.in_scalars_names, # list[input scalar name] + "output_scalars": self.out_scalars_names, # list[output scalar name] + "input_fields": self.in_fields_names, # list[input field name] + "output_fields": self.out_fields_names, # list[output field name] + "input_timeseries": self.in_timeseries_names, # list[input timeseries name] + "output_timeseries": self.out_timeseries_names, # list[output timeseries name] + "input_meshes": self.in_meshes_names, # list[input mesh name] + "output_meshes": self.out_meshes_names, # list[output mesh name] + } + def _save_to_dir_(self, path: Union[str, Path]) -> None: """Save problem information, inputs, outputs, and split to the specified directory in YAML and CSV formats. @@ -1142,30 +1162,20 @@ def _save_to_dir_(self, path: Union[str, Path]) -> None: path = Path(path) if not (path.is_dir()): - path.mkdir() + path.mkdir(parents=True) - data = { - "task": self._task, - "input_features": [dict(**d) for d in self.in_features_identifiers], - "output_features": [dict(**d) for d in self.out_features_identifiers], - "input_scalars": self.in_scalars_names, # list[input scalar name] - "output_scalars": self.out_scalars_names, # list[output scalar name] - "input_fields": self.in_fields_names, # list[input field name] - "output_fields": self.out_fields_names, # list[output field name] - "input_timeseries": self.in_timeseries_names, # list[input timeseries name] - "output_timeseries": self.out_timeseries_names, # list[output timeseries name] - "input_meshes": self.in_meshes_names, # list[input mesh name] - "output_meshes": self.out_meshes_names, # list[output mesh name] - } + problem_infos_dict = self._generate_problem_infos_dict() pbdef_fname = path / "problem_infos.yaml" with pbdef_fname.open("w") as file: - yaml.dump(data, file, default_flow_style=False, sort_keys=False) + yaml.dump( + problem_infos_dict, file, default_flow_style=False, sort_keys=False + ) split_fname = path / "split.json" - if self._split is not None: + if self.get_split() is not None: with split_fname.open("w") as file: - json.dump(self._split, file) + json.dump(self.get_split(), file) @classmethod def load(cls, path: Union[str, Path]) -> Self: # pragma: no cover @@ -1181,6 +1191,25 @@ def load(cls, path: Union[str, Path]) -> Self: # pragma: no cover instance._load_from_dir_(path) return instance + def _initialize_from_problem_infos_dict( + self, data: dict[str, Union[str, list]] + ) -> None: + self._task = data["task"] + self.in_features_identifiers = [ + FeatureIdentifier(**tup) for tup in data["input_features"] + ] + self.out_features_identifiers = [ + FeatureIdentifier(**tup) for tup in data["output_features"] + ] + self.in_scalars_names = data["input_scalars"] + self.out_scalars_names = data["output_scalars"] + self.in_fields_names = data["input_fields"] + self.out_fields_names = data["output_fields"] + self.in_timeseries_names = data["input_timeseries"] + self.out_timeseries_names = data["output_timeseries"] + self.in_meshes_names = data["input_meshes"] + self.out_meshes_names = data["output_meshes"] + def _load_from_dir_(self, path: Union[str, Path]) -> None: """Load problem information, inputs, outputs, and split from the specified directory in YAML and CSV formats. 
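A hedged sketch of the dictionary that now feeds both `_save_to_dir_` and the Hub upload (the task and split values are illustrative, and the method is private API):

```python
# Hedged sketch: dump the problem_infos dict that is serialized to problem_infos.yaml.
import yaml

from plaid.problem_definition import ProblemDefinition

pb_def = ProblemDefinition()
pb_def.set_task("regression")
pb_def.set_split({"train": [0, 1], "test": [2]})

print(yaml.dump(pb_def._generate_problem_infos_dict(), sort_keys=False))
```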
@@ -1216,21 +1245,7 @@ def _load_from_dir_(self, path: Union[str, Path]) -> None: f"file with path `{pbdef_fname}` does not exist. Abort" ) - self._task = data["task"] - self.in_features_identifiers = [ - FeatureIdentifier(**tup) for tup in data["input_features"] - ] - self.out_features_identifiers = [ - FeatureIdentifier(**tup) for tup in data["output_features"] - ] - self.in_scalars_names = data["input_scalars"] - self.out_scalars_names = data["output_scalars"] - self.in_fields_names = data["input_fields"] - self.out_fields_names = data["output_fields"] - self.in_timeseries_names = data["input_timeseries"] - self.out_timeseries_names = data["output_timeseries"] - self.in_meshes_names = data["input_meshes"] - self.out_meshes_names = data["output_meshes"] + self._initialize_from_problem_infos_dict(data) # if it was saved with version <=0.1.7 it is a .csv else it is .json split = {} @@ -1248,7 +1263,7 @@ def _load_from_dir_(self, path: Union[str, Path]) -> None: logger.warning( f"file with path `{split_fname_csv}` or `{split_fname_json}` does not exist. Splits will not be set" ) - self._split = split + self.set_split(split) def extract_problem_definition_from_identifiers( self, identifiers: list[FeatureIdentifier] diff --git a/src/plaid/types/__init__.py b/src/plaid/types/__init__.py index b42de25d..d777be58 100644 --- a/src/plaid/types/__init__.py +++ b/src/plaid/types/__init__.py @@ -18,7 +18,6 @@ Field, Scalar, TimeSequence, - TimeSeries, ) from plaid.types.sklearn_types import SklearnBlock @@ -31,7 +30,6 @@ "Scalar", "Field", "TimeSequence", - "TimeSeries", "Feature", "FeatureIdentifier", "SklearnBlock", diff --git a/src/plaid/types/feature_types.py b/src/plaid/types/feature_types.py index 5e0d663e..9187d446 100644 --- a/src/plaid/types/feature_types.py +++ b/src/plaid/types/feature_types.py @@ -22,10 +22,9 @@ Scalar: TypeAlias = Union[float, int] Field: TypeAlias = Array TimeSequence: TypeAlias = Array -TimeSeries: TypeAlias = tuple[TimeSequence, Field] # Feature data types -Feature: TypeAlias = Union[Scalar, Field, TimeSeries, Array] +Feature: TypeAlias = Union[Scalar, Field, Array] # Identifiers diff --git a/src/plaid/utils/base.py b/src/plaid/utils/base.py index 745a0c34..58a1b53d 100644 --- a/src/plaid/utils/base.py +++ b/src/plaid/utils/base.py @@ -9,9 +9,11 @@ # %% Imports +import os from functools import wraps import numpy as np +import psutil # %% Functions @@ -52,6 +54,12 @@ def safe_len(obj): return len(obj) if hasattr(obj, "__len__") else 0 +def get_mem(): + """Get the current memory usage of the process in MB.""" + process = psutil.Process(os.getpid()) + return process.memory_info().rss / (1024**2) # in MB + + def delegate_methods(to: str, methods: list[str]): """Class decorator to forward specific methods from a delegate attribute.""" diff --git a/src/plaid/utils/cgns_helper.py b/src/plaid/utils/cgns_helper.py index 9de7d7d3..228acbd9 100644 --- a/src/plaid/utils/cgns_helper.py +++ b/src/plaid/utils/cgns_helper.py @@ -7,6 +7,8 @@ # # +from typing import Any, Optional + import CGNS.PAT.cgnsutils as CGU import numpy as np @@ -67,11 +69,11 @@ def get_time_values(tree: CGNSTree) -> np.ndarray: return time_values[0] -def show_cgns_tree(pyTree: list, pre: str = ""): +def show_cgns_tree(pyTree: CGNSTree, pre: str = ""): """Pretty print for CGNS Tree. Args: - pyTree (list): CGNS tree to print + pyTree (CGNSTree): CGNS tree to print pre (str, optional): indentation of print. Defaults to ''. 
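The new `psutil`-backed `get_mem` helper returns the resident memory of the current process in MB; a minimal sketch:

```python
import numpy as np

from plaid.utils.base import get_mem

before = get_mem()               # resident set size, in MB
buffer = np.zeros((1024, 1024))  # allocate roughly 8 MB
print(f"memory grew by ~{get_mem() - before:.1f} MB")
```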
""" if not (isinstance(pyTree, list)): @@ -107,11 +109,372 @@ def printValue(node): np.set_printoptions(edgeitems=3, threshold=1000) -def summarize_cgns_tree(pyTree: list, verbose=True) -> str: +def flatten_cgns_tree( + pyTree: CGNSTree, +) -> tuple[dict[str, object], dict[str, str]]: + """Flatten a CGNS tree into dictionaries of primitives for Hugging Face serialization. + + Traverses the CGNS tree and produces: + - flat: a dictionary mapping paths to primitive values (lists, scalars, or None) + - dtypes: a dictionary mapping paths to dtype strings + - extras: a dictionary mapping paths to extra CGNS metadata + + Args: + pyTree (CGNSTree): The CGNS tree to flatten. + + Returns: + tuple[dict[str, object], dict[str, str], dict[str, object]]: + - flat: dict of paths to primitive values + - dtypes: dict of paths to dtype strings + - extras: dict of paths to extra CGNS metadata + + Example: + >>> flat, dtypes, extras = flatten_cgns_tree(pyTree) + >>> flat["Base1/Zone1/Solution1/Field1"] # [1.0, 2.0, ...] + >>> dtypes["Base1/Zone1/Solution1/Field1"] # 'float64' + """ + flat = {} + cgns_types = {} + + def visit(tree, path=""): + for node in tree[2]: + name, data, children, cgns_type = node + new_path = f"{path}/{name}" if path else name + + flat[new_path] = data + cgns_types[new_path] = cgns_type + + if children: + visit(node, new_path) + + visit(pyTree) + return flat, cgns_types + + +def nodes_to_tree(nodes: dict[str, CGNSTree]) -> Optional[CGNSTree]: + """Reconstruct a CGNS tree from a dictionary of nodes keyed by their paths. + + Each node is assumed to follow the CGNSTree format: + [name: str, data: Any, children: List[CGNSTree], cgns_type: str] + + The dictionary keys are the full paths to each node, e.g. "Base1/Zone1/Field1". + + Args: + nodes (Dict[str, CGNSTree]): A dictionary mapping node paths to CGNSTree nodes. + + Returns: + Optional[CGNSTree]: The root CGNSTree node with all children linked, + or None if the input dictionary is empty. + + Notes: + - Nodes with a path of length 1 are treated as root-level nodes. + - The root node is named "CGNSTree" with type "CGNSTree_t". + - Parent-child relationships are reconstructed using path prefixes. + """ + root = None + for path, node in nodes.items(): + parts = path.split("/") + if len(parts) == 1: + # root-level node + if root is None: + root = ["CGNSTree", None, [node], "CGNSTree_t"] + else: + root[2].append(node) + else: + parent_path = "/".join(parts[:-1]) + parent = nodes[parent_path] + parent[2].append(node) + return root + + +def unflatten_cgns_tree( + flat: dict[str, object], + cgns_types: dict[str, str], +) -> CGNSTree: + """Reconstruct a CGNS tree from flattened dictionaries of data and types. + + This function takes a "flat" representation of a CGNS tree, where each node + is stored in a dictionary keyed by its full path (e.g., "Base1/Zone1/Field1"), + and another dictionary mapping each path to its CGNS type. It rebuilds the + original tree structure by creating nodes and linking them according to their paths. + + Args: + flat (dict[str, object]): Dictionary mapping node paths to their data values. + The data can be a scalar, list, numpy array, or None. + cgns_types (dict[str, str]): Dictionary mapping node paths to CGNS type strings + (e.g., "Zone_t", "FlowSolution_t"). + + Returns: + CGNSTree: The reconstructed CGNS tree with nodes properly nested according + to their paths. 
Each node is a list in the format: + [name: str, data: Any, children: List[CGNSTree], cgns_type: str] + + Example: + >>> flat = { + >>> "Base1": None, + >>> "Base1/Zone1": [10, 20], + >>> "Base1/Zone1/Field1": [1.0, 2.0] + >>> } + >>> cgns_types = { + >>> "Base1": "CGNSBase_t", + >>> "Base1/Zone1": "Zone_t", + >>> "Base1/Zone1/Field1": "FlowSolution_t" + >>> } + >>> tree = unflatten_cgns_tree(flat, cgns_types) + """ + # Build all nodes from paths + nodes = {} + + for path, value in flat.items(): + cgns_type = cgns_types.get(path) + nodes[path] = [path.split("/")[-1], value, [], cgns_type] + + # Re-link nodes into tree structure + return nodes_to_tree(nodes) + + +def fix_cgns_tree_types(tree: CGNSTree) -> CGNSTree: + """Recursively fix the data types of a CGNS tree node and its children. + + This function ensures that data arrays match the expected CGNS types: + - "IndexArray_t": converted to integer arrays and stacked + - "Zone_t": stacked as numpy arrays + - "Elements_t", "CGNSBase_t", "BaseIterativeData_t": converted to integer arrays + + Args: + tree (CGNSTree): A CGNS tree of the form + [name: str, data: Any, children: List[CGNSTree], cgns_type: str]. + + Returns: + CGNSTree: A new CGNS tree node with corrected data types and recursively + fixed children. + + Example: + >>> node = ["Zone1", [[1, 2], [3, 4]], [], "Zone_t"] + >>> fixed_node = fix_cgns_tree_types(node) + >>> fixed_node[1].shape + (2, 2) + """ + name, data, children, cgns_type = tree + + # Fix data types according to CGNS type + if data is not None: + if cgns_type == "IndexArray_t": + data = CGU.setIntegerAsArray(*data) + data = np.stack(data) + elif cgns_type == "Zone_t": + data = np.stack(data) + elif cgns_type in ["Elements_t", "CGNSBase_t", "BaseIterativeData_t"]: + data = CGU.setIntegerAsArray(*data) + + # Recursively fix children + new_children = [] + if children: + for child in children: + new_children.append(fix_cgns_tree_types(child)) + + return [name, data, new_children, cgns_type] + + +def compare_cgns_trees( + tree1: CGNSTree, + tree2: CGNSTree, + path: str = "CGNSTree", +) -> bool: + """Recursively compare two CGNS trees for exact equality, ignoring the order of children. + + This function checks: + - Node names + - Node data (numpy arrays or scalars) with exact dtype and values + - Number and names of children nodes + - CGNS type (stored as the extra field) + + It prints informative messages whenever a mismatch is found, including the + path in the tree where the mismatch occurs. + + Args: + tree1 (CGNSTree): The first CGNS tree node to compare. + tree2 (CGNSTree): The second CGNS tree node to compare. + path (str, optional): The current path in the tree for error messages. + Defaults to "CGNSTree". + + Returns: + bool: True if the trees are identical (including node names, data, types, + and children), False otherwise. 
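A self-contained round-trip sketch of the flattening used by the Arrow-based bridge (the toy tree below is illustrative, not a complete CGNS base):

```python
# Illustrative round trip: flatten a toy CGNS-like tree, then rebuild and compare it.
import numpy as np

from plaid.utils import cgns_helper

field = np.array([1.0, 2.0, 3.0])
tree = [
    "CGNSTree",
    None,
    [
        [
            "Base1",
            np.array([2, 2], dtype=np.int32),
            [
                [
                    "Zone1",
                    np.array([[3, 2, 0]], dtype=np.int32),
                    [["Solution1", None, [["Field1", field, [], "DataArray_t"]], "FlowSolution_t"]],
                    "Zone_t",
                ],
            ],
            "CGNSBase_t",
        ],
    ],
    "CGNSTree_t",
]

flat, cgns_types = cgns_helper.flatten_cgns_tree(tree)
assert "Base1/Zone1/Solution1/Field1" in flat  # paths are slash-joined node names

rebuilt = cgns_helper.unflatten_cgns_tree(flat, cgns_types)
assert cgns_helper.compare_cgns_trees(tree, rebuilt)
```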
+ + Example: + >>> identical = compare_cgns_trees(tree1, tree2) + >>> if identical: + >>> print("The trees are identical") + >>> else: + >>> print("The trees differ") + """ + # Compare node name + if tree1[0] != tree2[0]: + print(f"Name mismatch at {path}: {tree1[0]} != {tree2[0]}") + return False + + # Compare data + data1, data2 = tree1[1], tree2[1] + + if data1 is None and data2 is None: + pass + elif isinstance(data1, np.ndarray) and isinstance(data2, np.ndarray): + if data1.dtype != data2.dtype: + print( + f"Dtype mismatch at {path}/{tree1[0]}: {data1.dtype} != {data2.dtype}" + ) + return False + if len(data1) == 0 and len(data2) == 0: + pass + elif not np.array_equal(data1, data2): + print(f"Data mismatch at {path}/{tree1[0]}") + return False + else: + if isinstance(data1, np.ndarray) or isinstance(data2, np.ndarray): + print(f"Data type mismatch at {path}/{tree1[0]}") + return False + + # Compare extra (CGNS type) + extra1, extra2 = tree1[3], tree2[3] + if extra1 != extra2: + print(f"Type mismatch at {path}/{tree1[0]}: {extra1} != {extra2}") + return False + + # Compare children ignoring order + children1_dict = {c[0]: c for c in tree1[2] or []} + children2_dict = {c[0]: c for c in tree2[2] or []} + + if set(children1_dict.keys()) != set(children2_dict.keys()): + print( + f"Children name mismatch at {path}/{tree1[0]}: {set(children1_dict.keys())} != {set(children2_dict.keys())}" + ) + return False + + # Recursively compare children + for name in children1_dict: + if not compare_cgns_trees( + children1_dict[name], children2_dict[name], path=f"{path}/{tree1[0]}" + ): + return False + + return True + + +def compare_leaves(d1: Any, d2: Any) -> bool: + """Compare two leaf values in a CGNS tree or flattened structure, handling arrays and scalars. + + This function supports: + - NumPy arrays, including byte arrays (converted to str) + - Floating-point arrays or scalars (compared with tolerance) + - Integer arrays or scalars (exact comparison) + - Strings and None + + Args: + d1 (Any): First value to compare (scalar or np.ndarray). + d2 (Any): Second value to compare (scalar or np.ndarray). + + Returns: + bool: True if the values are considered equal, False otherwise. + + Notes: + - Floating-point comparisons use `np.allclose` or `np.isclose` with `rtol=1e-7` and `atol=0`. + - Byte arrays (`dtype.kind == "S"`) are converted to string before comparison. + + Examples: + >>> compare_leaves(np.array([1.0, 2.0]), np.array([1.0, 2.0])) + True + >>> compare_leaves(3.0, 3.00000001) + True + >>> compare_leaves(np.array([1, 2]), np.array([2, 1])) + False + """ + # Convert bytes arrays to str + if isinstance(d1, np.ndarray) and d1.dtype.kind == "S": + d1 = d1.astype(str) + if isinstance(d2, np.ndarray) and d2.dtype.kind == "S": + d2 = d2.astype(str) + + # Both arrays + if isinstance(d1, np.ndarray) and isinstance(d2, np.ndarray): + if np.issubdtype(d1.dtype, np.floating) or np.issubdtype(d2.dtype, np.floating): + return np.allclose(d1, d2, rtol=1e-7, atol=0) + else: + return np.array_equal(d1, d2) + + # Scalars (int/float/str/None) + if isinstance(d1, float) or isinstance(d2, float): + return np.isclose(d1, d2, rtol=1e-7, atol=0) + return d1 == d2 + + +def compare_cgns_trees_no_types( + tree1: CGNSTree, tree2: CGNSTree, path: str = "CGNSTree" +) -> bool: + """Recursively compare two CGNS trees ignoring the order of children and relaxing strict type checks. + + This function is useful for heterogeneous or nested CGNS samples, + such as those encountered in Hugging Face Arrow datasets. 
It compares: + - Node names + - Node data using `compare_leaves` (supports arrays, scalars, strings) + - CGNS type (extra field) + - Children nodes by name, ignoring their order + + Args: + tree1 (CGNSTree): The first CGNS tree node to compare. + tree2 (CGNSTree): The second CGNS tree node to compare. + path (str, optional): Path for error reporting. Defaults to "CGNSTree". + + Returns: + bool: True if the trees are considered equivalent, False otherwise. + + Example: + >>> identical = compare_cgns_trees_no_types(tree1, tree2) + >>> if identical: + >>> print("The trees match ignoring types") + >>> else: + >>> print("The trees differ") + """ + if tree1[0] != tree2[0]: + print(f"Name mismatch at {path}: {tree1[0]} != {tree2[0]}") + return False + + # Compare data using recursive helper + data1, data2 = tree1[1], tree2[1] + if not compare_leaves(data1, data2): + print(f"Data mismatch at {path}/{tree1[0]}") + return False + + # Compare extra (CGNS type) + if tree1[3] != tree2[3]: + print(f"Type mismatch at {path}/{tree1[0]}: {tree1[3]} != {tree2[3]}") + return False + + # Compare children ignoring order + children1_dict = {c[0]: c for c in tree1[2] or []} + children2_dict = {c[0]: c for c in tree2[2] or []} + + if set(children1_dict.keys()) != set(children2_dict.keys()): + print( + f"Children name mismatch at {path}/{tree1[0]}: {set(children1_dict.keys())} != {set(children2_dict.keys())}" + ) + return False + + # Recursively compare children + for name in children1_dict: + if not compare_cgns_trees_no_types( + children1_dict[name], children2_dict[name], path=f"{path}/{tree1[0]}" + ): + return False + + return True + + +def summarize_cgns_tree(pyTree: CGNSTree, verbose=True) -> str: """Provide a summary of a CGNS tree's contents. Args: - pyTree (list): The CGNS tree to summarize. + pyTree (CGNSTree): The CGNS tree to summarize. verbose (bool, optional): If True, include detailed field information. Defaults to True. 
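In a nutshell, the tolerance rules implemented by `compare_leaves` (a small hedged sketch):

```python
import numpy as np

from plaid.utils.cgns_helper import compare_leaves

assert compare_leaves(np.array([1.0, 2.0]), np.array([1.0, 2.0 + 1e-9]))  # floats: np.allclose, rtol=1e-7
assert compare_leaves(np.array([b"Vertex"]), np.array(["Vertex"]))        # byte strings are decoded first
assert not compare_leaves(np.array([1, 2]), np.array([2, 1]))             # integers: exact comparison
assert compare_leaves(3.0, 3.0 + 1e-9)                                    # scalars: np.isclose
```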
Example: diff --git a/tests/bridges/test_huggingface_bridge.py b/tests/bridges/test_huggingface_bridge.py index e89ca974..0177b83b 100644 --- a/tests/bridges/test_huggingface_bridge.py +++ b/tests/bridges/test_huggingface_bridge.py @@ -8,22 +8,33 @@ # %% Imports import pickle +import shutil +from pathlib import Path from typing import Callable import pytest from plaid.bridges import huggingface_bridge -from plaid.bridges.huggingface_bridge import to_plaid_sample from plaid.containers.dataset import Dataset from plaid.containers.sample import Sample from plaid.problem_definition import ProblemDefinition +from plaid.utils import cgns_helper + + +@pytest.fixture() +def current_directory(): + return Path(__file__).absolute().parent # %% Fixtures @pytest.fixture() def dataset(samples, infos) -> Dataset: - dataset = Dataset() - dataset.add_samples(samples[:2]) + samples_ = [] + for i, sample in enumerate(samples): + if i == 1: + sample.add_scalar("toto", 1.0) + samples_.append(sample) + dataset = Dataset(samples=samples_) dataset.set_infos(infos) return dataset @@ -39,50 +50,201 @@ def problem_definition() -> ProblemDefinition: @pytest.fixture() def generator(dataset) -> Callable: - def generator(): - for id in range(len(dataset)): + def generator_(): + for sample in dataset: + yield sample + + return generator_ + + +@pytest.fixture() +def generator_split(dataset, problem_definition) -> dict[str, Callable]: + generators_ = {} + for split_name, ids in problem_definition.get_split().items(): + + def generator_(ids=ids): + for id in ids: + yield dataset[id] + + generators_[split_name] = generator_ + return generators_ + + +@pytest.fixture() +def generator_binary(dataset) -> Callable: + def generator_(): + for sample in dataset: yield { - "sample": pickle.dumps(dataset[id]), + "sample": pickle.dumps(sample), } - return generator + return generator_ + + +@pytest.fixture() +def generator_split_binary(dataset, problem_definition) -> dict[str, Callable]: + generators_ = {} + for split_name, ids in problem_definition.get_split().items(): + + def generator_(ids=ids): + for id in ids: + yield {"sample": pickle.dumps(dataset[id])} + + generators_[split_name] = generator_ + return generators_ @pytest.fixture() -def hf_dataset(generator, infos, problem_definition) -> Dataset: - hf_dataset = huggingface_bridge.plaid_generator_to_huggingface( - generator, infos, problem_definition +def hf_dataset(generator_binary) -> Dataset: + hf_dataset = huggingface_bridge.plaid_generator_to_huggingface_binary( + generator_binary ) return hf_dataset class Test_Huggingface_Bridge: - def assert_hf_dataset(self, hfds): - assert hfds.description["legal"] == {"owner": "PLAID2", "license": "BSD-3"} - assert hfds.description["task"] == "regression" - assert hfds.description["in_scalars_names"][0] == "feature_name_1" - assert hfds.description["in_scalars_names"][1] == "feature_name_2" - self.assert_sample(to_plaid_sample(hfds[0])) - - def assert_plaid_dataset(self, ds, pbdef): - assert ds.get_infos()["legal"] == {"owner": "PLAID2", "license": "BSD-3"} - assert pbdef.get_input_scalars_names()[0] == "feature_name_1" - assert pbdef.get_input_scalars_names()[1] == "feature_name_2" - self.assert_sample(ds[0]) - def assert_sample(self, sample): assert isinstance(sample, Sample) assert sample.get_scalar_names()[0] == "test_scalar" assert "test_field_same_size" in sample.get_field_names() assert sample.get_field("test_field_same_size").shape[0] == 17 - def test_to_plaid_sample(self, generator, infos, problem_definition): - hfds = 
huggingface_bridge.plaid_generator_to_huggingface( - generator, infos, problem_definition + def assert_hf_dataset_binary(self, hfds_binary): + self.assert_sample(huggingface_bridge.binary_to_plaid_sample(hfds_binary[0])) + + def assert_plaid_dataset(self, ds): + self.assert_sample(ds[0]) + + # ------------------------------------------------------------------------------ + # HUGGING FACE BRIDGE (with tree flattening and pyarrow tables) + # ------------------------------------------------------------------------------ + + def test_to_cgns_tree_columnar(self, dataset, problem_definition): + main_splits = problem_definition.get_split() + hf_dataset_dict, flat_cst, key_mappings = ( + huggingface_bridge.plaid_dataset_to_huggingface_datasetdict( + dataset, main_splits + ) + ) + huggingface_bridge.to_cgns_tree_columnar( + hf_dataset_dict["train"], 0, flat_cst, key_mappings["cgns_types"] + ) + huggingface_bridge.to_cgns_tree_columnar( + hf_dataset_dict["train"], + 0, + flat_cst, + key_mappings["cgns_types"], + enforce_shapes=True, ) - to_plaid_sample(hfds[0]) - def test_to_plaid_sample_fallback_build_succeeds(self, dataset): + def test_with_datasetdict(self, dataset, problem_definition): + main_splits = problem_definition.get_split() + hf_dataset_dict, flat_cst, key_mappings = ( + huggingface_bridge.plaid_dataset_to_huggingface_datasetdict( + dataset, main_splits + ) + ) + huggingface_bridge.to_plaid_sample_columnar( + hf_dataset_dict["train"], 0, flat_cst, key_mappings["cgns_types"] + ) + huggingface_bridge.to_plaid_sample_columnar( + hf_dataset_dict["test"], + 0, + flat_cst, + key_mappings["cgns_types"], + enforce_shapes=True, + ) + huggingface_bridge.to_plaid_dataset( + hf_dataset_dict["train"], flat_cst, key_mappings["cgns_types"] + ) + huggingface_bridge.to_plaid_dataset( + hf_dataset_dict["test"], + flat_cst=flat_cst, + cgns_types=key_mappings["cgns_types"], + enforce_shapes=True, + ) + huggingface_bridge.to_plaid_sample( + hf_dataset_dict["train"][0], flat_cst, key_mappings["cgns_types"] + ) + cgns_helper.compare_cgns_trees(dataset[0].get_mesh(), dataset[0].get_mesh()) + cgns_helper.compare_cgns_trees_no_types( + dataset[0].get_mesh(), dataset[0].get_mesh() + ) + + def test_with_generator(self, generator_split): + hf_dataset_dict, flat_cst, key_mappings = ( + huggingface_bridge.plaid_generator_to_huggingface_datasetdict( + generator_split + ) + ) + huggingface_bridge.to_plaid_sample_columnar( + hf_dataset_dict["train"], 0, flat_cst, key_mappings["cgns_types"] + ) + huggingface_bridge.to_plaid_sample_columnar( + hf_dataset_dict["test"], + 0, + flat_cst, + key_mappings["cgns_types"], + enforce_shapes=True, + ) + + # ------------------------------------------------------------------------------ + # HUGGING FACE INTERACTIONS ON DISK + # ------------------------------------------------------------------------------ + + def test_save_load_to_disk( + self, current_directory, generator_split, infos, problem_definition + ): + hf_dataset_dict, flat_cst, key_mappings = ( + huggingface_bridge.plaid_generator_to_huggingface_datasetdict( + generator_split + ) + ) + + test_dir = Path(current_directory) / Path("test") + huggingface_bridge.save_dataset_dict_to_disk(test_dir, hf_dataset_dict) + huggingface_bridge.save_infos_to_disk(test_dir, infos) + huggingface_bridge.save_problem_definition_to_disk( + test_dir, "task_1", problem_definition + ) + huggingface_bridge.save_tree_struct_to_disk(test_dir, flat_cst, key_mappings) + + huggingface_bridge.load_dataset_from_disk(test_dir) + 
+        huggingface_bridge.load_infos_from_disk(test_dir)
+        huggingface_bridge.load_problem_definition_from_disk(test_dir, "task_1")
+        huggingface_bridge.load_tree_struct_from_disk(test_dir)
+        shutil.rmtree(test_dir)
+
+    # ------------------------------------------------------------------------------
+    # DEPRECATED HUGGING FACE BRIDGE (binary blobs)
+    # ------------------------------------------------------------------------------
+
+    def test_save_load_to_disk_binary(
+        self, current_directory, generator_split_binary, infos, problem_definition
+    ):
+        hf_dataset_dict = (
+            huggingface_bridge.plaid_generator_to_huggingface_datasetdict_binary(
+                generator_split_binary
+            )
+        )
+        test_dir = Path(current_directory) / Path("test")
+        huggingface_bridge.save_dataset_dict_to_disk(test_dir, hf_dataset_dict)
+        huggingface_bridge.save_infos_to_disk(test_dir, infos)
+        huggingface_bridge.save_problem_definition_to_disk(
+            test_dir, "task_1", problem_definition
+        )
+        huggingface_bridge.load_dataset_from_disk(test_dir)
+        huggingface_bridge.load_infos_from_disk(test_dir)
+        huggingface_bridge.load_problem_definition_from_disk(test_dir, "task_1")
+        shutil.rmtree(test_dir)
+
+    def test_binary_to_plaid_sample(self, generator_binary):
+        hfds = huggingface_bridge.plaid_generator_to_huggingface_binary(
+            generator_binary
+        )
+        huggingface_bridge.binary_to_plaid_sample(hfds[0])
+
+    def test_binary_to_plaid_sample_fallback_build_succeeds(self, dataset):
         sample = dataset[0]
         old_hf_sample = {
             "path": getattr(sample, "path", None),
@@ -90,71 +252,86 @@ def test_to_plaid_sample_fallback_build_succeeds(self, dataset):
             "meshes": sample.features.data,
         }
         old_hf_sample = {"sample": pickle.dumps(old_hf_sample)}
-        plaid_sample = to_plaid_sample(old_hf_sample)
+        plaid_sample = huggingface_bridge.binary_to_plaid_sample(old_hf_sample)
         assert isinstance(plaid_sample, Sample)

-    def test_plaid_dataset_to_huggingface(self, dataset, problem_definition):
-        hfds = huggingface_bridge.plaid_dataset_to_huggingface(
-            dataset, problem_definition, split="train"
+    def test_plaid_dataset_to_huggingface_binary(self, dataset):
+        hfds = huggingface_bridge.plaid_dataset_to_huggingface_binary(dataset)
+        hfds = huggingface_bridge.plaid_dataset_to_huggingface_binary(
+            dataset, ids=[0, 1]
         )
-        hfds = huggingface_bridge.plaid_dataset_to_huggingface(
-            dataset, problem_definition
-        )
-        self.assert_hf_dataset(hfds)
+        self.assert_hf_dataset_binary(hfds)

-    def test_plaid_dataset_to_huggingface_datasetdict(
+    def test_plaid_dataset_to_huggingface_datasetdict_binary(
         self, dataset, problem_definition
     ):
-        huggingface_bridge.plaid_dataset_to_huggingface_datasetdict(
-            dataset, problem_definition, main_splits=["train", "test"]
+        huggingface_bridge.plaid_dataset_to_huggingface_datasetdict_binary(
+            dataset, main_splits=problem_definition.get_split()
         )

-    def test_plaid_generator_to_huggingface(self, generator, infos, problem_definition):
-        hfds = huggingface_bridge.plaid_generator_to_huggingface(
-            generator, infos, problem_definition, split="train"
+    def test_plaid_generator_to_huggingface_binary(self, generator_binary):
+        hfds = huggingface_bridge.plaid_generator_to_huggingface_binary(
+            generator_binary
         )
-        hfds = huggingface_bridge.plaid_generator_to_huggingface(
-            generator, infos, problem_definition
+        hfds = huggingface_bridge.plaid_generator_to_huggingface_binary(
+            generator_binary, processes_number=2
         )
-        self.assert_hf_dataset(hfds)
+        self.assert_hf_dataset_binary(hfds)

-    def test_plaid_generator_to_huggingface_datasetdict(
-        self, generator, infos, problem_definition
+    def test_plaid_generator_to_huggingface_datasetdict_binary(
+        self, generator_split_binary
     ):
-        huggingface_bridge.plaid_generator_to_huggingface_datasetdict(
-            generator, infos, problem_definition, main_splits=["train", "test"]
+        huggingface_bridge.plaid_generator_to_huggingface_datasetdict_binary(
+            generator_split_binary
         )

     def test_huggingface_dataset_to_plaid(self, hf_dataset):
-        ds, pbdef = huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset)
-        self.assert_plaid_dataset(ds, pbdef)
+        ds, _ = huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset)
+        self.assert_plaid_dataset(ds)

-    def test_huggingface_dataset_to_plaid_with_ids(self, hf_dataset):
+    def test_huggingface_dataset_to_plaid_with_ids_binary(self, hf_dataset):
         huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset, ids=[0, 1])

-    def test_huggingface_dataset_to_plaid_large(self, hf_dataset):
+    def test_huggingface_dataset_to_plaid_large_binary(self, hf_dataset):
         huggingface_bridge.huggingface_dataset_to_plaid(
             hf_dataset, processes_number=2, large_dataset=True
         )

-    def test_huggingface_dataset_to_plaid_with_ids_large(self, hf_dataset):
+    def test_huggingface_dataset_to_plaid_large_binary_2(self, hf_dataset):
+        huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset, processes_number=2)
+
+    def test_huggingface_dataset_to_plaid_with_ids_large_binary(self, hf_dataset):
         with pytest.raises(NotImplementedError):
             huggingface_bridge.huggingface_dataset_to_plaid(
                 hf_dataset, ids=[0, 1], processes_number=2, large_dataset=True
             )

-    def test_huggingface_dataset_to_plaid_error_processes_number(self, hf_dataset):
+    def test_huggingface_dataset_to_plaid_error_processes_number_binary(
+        self, hf_dataset
+    ):
         with pytest.raises(AssertionError):
             huggingface_bridge.huggingface_dataset_to_plaid(
                 hf_dataset, processes_number=128
             )

-    def test_huggingface_dataset_to_plaid_error_processes_number_2(self, hf_dataset):
+    def test_huggingface_dataset_to_plaid_error_processes_number_binary_2(
+        self, hf_dataset
+    ):
         with pytest.raises(AssertionError):
             huggingface_bridge.huggingface_dataset_to_plaid(
                 hf_dataset, ids=[0], processes_number=2
             )

+    def test_huggingface_description_to_problem_definition(self, hf_dataset):
+        huggingface_bridge.huggingface_description_to_problem_definition(
+            hf_dataset.description
+        )
+
+    def test_huggingface_description_to_infos(self, infos):
+        hf_description = {}
+        hf_description.update(infos)
+        huggingface_bridge.huggingface_description_to_infos(hf_description)
+
     def test_create_string_for_huggingface_dataset_card(self, hf_dataset):
         huggingface_bridge.create_string_for_huggingface_dataset_card(
             description=hf_dataset.description,
diff --git a/tests/conftest.py b/tests/conftest.py
index 06f006f5..6cc01a47 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -31,6 +31,8 @@ def generate_samples(nb: int, zone_name: str, base_name: str) -> list[Sample]:
         sample.init_zone(np.array([0, 0, 0]), zone_name=zone_name, base_name=base_name)
         sample.add_scalar("test_scalar", float(i))
         sample.add_scalar("test_scalar_2", float(i**2))
+        sample.add_global("global_0", 0.5 + np.ones((2, 3)))
+        sample.add_global("global_1", 1.5 + i + np.ones((2, 3, 2)))
         sample.add_field(
             name="test_field_same_size",
             field=float(i**4) * np.ones(17),
@@ -39,7 +41,7 @@ def generate_samples(nb: int, zone_name: str, base_name: str) -> list[Sample]:
         )
         sample.add_field(
             name="test_field_2785",
-            field=float(i**5) * np.ones(3 * i),
+            field=float(i**5) * np.ones(3 * (i + 1)),
             zone_name=zone_name,
             base_name=base_name,
         )
diff --git a/tests/containers/test_dataset.py b/tests/containers/test_dataset.py
index 5cd3e4b4..e5ed06a9 100644
--- a/tests/containers/test_dataset.py
+++ b/tests/containers/test_dataset.py
@@ -8,6 +8,7 @@
 # %% Imports

 import copy
+import shutil
 from pathlib import Path

 import numpy as np
@@ -119,7 +120,7 @@ def test_get_all_features_identifiers_by_type(
                     feature_type="scalar"
                 )
             )
-            == 2
+            == 4
         )
         assert (
             len(
@@ -144,7 +145,7 @@ def test_get_all_features_identifiers_by_type(
                     ids=np.arange(np.random.randint(2, nb_samples)),
                 )
             )
-            == 2
+            == 4
         )
         assert (
             len(
@@ -152,7 +153,7 @@ def test_get_all_features_identifiers_by_type(
                     feature_type="scalar", ids=[0, 0]
                )
             )
-            == 2
+            == 4
         )

     def test_add_sample(self, dataset, sample):
@@ -1075,6 +1076,7 @@ def test__add_to_dir__both_path_and_save_dir(
         save_dir = current_directory / "my_dataset_dir"
         with pytest.raises(ValueError):
             empty_dataset.add_to_dir(sample, path=save_dir, save_dir=save_dir)
+        shutil.rmtree(Path(save_dir))

     # -------------------------------------------------------------------------#
     def test__save_to_dir_(self, dataset_with_samples, tmp_path):
diff --git a/tests/utils/test_base.py b/tests/utils/test_base.py
new file mode 100644
index 00000000..705a8399
--- /dev/null
+++ b/tests/utils/test_base.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+#
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+#
+#
+
+# %% Imports
+
+
+from plaid.utils import base
+
+
+# %% Tests
+class Test_base:
+    def test_generate_random_ASCII(self):
+        base.generate_random_ASCII()
+
+    def test_safe_len(self):
+        assert base.safe_len([0, 1]) == 2
+        assert base.safe_len(0) == 0
+
+    def test_get_mem(self):
+        base.get_mem()
diff --git a/tests/utils/test_cgns_helper.py b/tests/utils/test_cgns_helper.py
index 6b2c729c..659f41c9 100644
--- a/tests/utils/test_cgns_helper.py
+++ b/tests/utils/test_cgns_helper.py
@@ -7,14 +7,12 @@

 # %% Imports

+import copy
+
+import numpy as np
 import pytest

-from plaid.utils.cgns_helper import (
-    get_base_names,
-    get_time_values,
-    show_cgns_tree,
-    summarize_cgns_tree,
-)
+from plaid.utils import cgns_helper


 # %% Tests
@@ -22,37 +20,85 @@ class Test_cgns_helper:
     def test_get_base_names(self, sample_with_tree):
         tree = sample_with_tree.get_mesh()
         # Test with full_path=False and unique=False
-        base_names = get_base_names(tree, full_path=False, unique=False)
+        base_names = cgns_helper.get_base_names(tree, full_path=False, unique=False)
         assert base_names == ["Base_2_2"]

         # Test with full_path=True and unique=False
-        base_names_full = get_base_names(tree, full_path=True, unique=False)
+        base_names_full = cgns_helper.get_base_names(tree, full_path=True, unique=False)
         print(base_names_full)
         assert base_names_full == ["/Base_2_2"]

         # Test with full_path=False and unique=True
-        base_names_unique = get_base_names(tree, full_path=False, unique=True)
+        base_names_unique = cgns_helper.get_base_names(
+            tree, full_path=False, unique=True
+        )
         print(base_names_unique)
         assert base_names_unique == ["Base_2_2"]

-    def test_get_time_values(self, sample_with_tree):
-        tree = sample_with_tree.get_mesh()
-        time_value = get_time_values(tree)
+    def test_get_time_values(self, samples):
+        tree = samples[0].get_mesh()
+        time_value = cgns_helper.get_time_values(tree)
         assert time_value == 0.0

         empty_tree = []
         with pytest.raises(IndexError):
-            get_time_values(empty_tree)
+            cgns_helper.get_time_values(empty_tree)

     def test_show_cgns_tree(self, tree):
-        show_cgns_tree(tree)
+        cgns_helper.show_cgns_tree(tree)

     def test_show_cgns_tree_not_a_list(self):
         with pytest.raises(TypeError):
-            show_cgns_tree({1: 2})
+            cgns_helper.show_cgns_tree({1: 2})
+
+    def test_fix_cgns_tree_types(self, tree):
+        cgns_helper.fix_cgns_tree_types(tree)
+
+    def test_compare_cgns_trees(self, tree, samples):
+        assert cgns_helper.compare_cgns_trees(tree, tree)
+        assert not cgns_helper.compare_cgns_trees(tree, samples[0].get_mesh())
+
+        tree2 = copy.deepcopy(tree)
+        tree2[0] = "A"
+        assert not cgns_helper.compare_cgns_trees(tree, tree2)
+
+        tree2[0] = tree[0]
+        tree2[1] = np.array([0], dtype=np.float32)
+        tree[1] = np.array([0], dtype=np.float64)
+        assert not cgns_helper.compare_cgns_trees(tree, tree2)
+
+        tree[1] = np.array([1], dtype=np.float32)
+        assert not cgns_helper.compare_cgns_trees(tree, tree2)
+
+        tree[1] = "A"
+        assert not cgns_helper.compare_cgns_trees(tree, tree2)
+
+        tree[1] = tree2[1]
+        tree[3] = "A_t"
+        assert not cgns_helper.compare_cgns_trees(tree, tree2)
+
+        tree[3] = tree2[3]
+        tree[2][0][3] = "A_t"
+        assert not cgns_helper.compare_cgns_trees(tree, tree2)
+
+    def test_compare_cgns_trees_no_types(self, tree, samples):
+        assert cgns_helper.compare_cgns_trees_no_types(tree, tree)
+        assert not cgns_helper.compare_cgns_trees_no_types(tree, samples[0].get_mesh())
+
+        tree2 = copy.deepcopy(tree)
+        tree2[0] = "A"
+        assert not cgns_helper.compare_cgns_trees_no_types(tree, tree2)
+
+        tree2[0] = tree[0]
+        tree[2][0][1] = 1.0
+        assert not cgns_helper.compare_cgns_trees_no_types(tree, tree2)
+
+        tree[2][0][1] = tree2[2][0][1]
+        tree[3] = "A_t"
+        assert not cgns_helper.compare_cgns_trees_no_types(tree, tree2)

     def test_summarize_cgns_tree(self, tree):
-        summarize_cgns_tree(tree, verbose=False)
+        cgns_helper.summarize_cgns_tree(tree, verbose=False)

     def test_summarize_cgns_tree_verbose(self, tree):
-        summarize_cgns_tree(tree, verbose=True)
+        cgns_helper.summarize_cgns_tree(tree, verbose=True)