Skip to content

Commit

Permalink
👌 Allow importing data from a mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
s-weigand committed Feb 24, 2023
1 parent 7b2157d commit d4293f7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
22 changes: 14 additions & 8 deletions glotaran/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
if TYPE_CHECKING:
from collections.abc import Hashable

from glotaran.typing.types import LoadableDataset

TEMPLATE = "version: {gta_version}"

PROJECT_FILE_NAME = "project.gta"
Expand Down Expand Up @@ -203,7 +205,7 @@ def load_data(

def import_data(
self,
dataset: str | Path | xr.Dataset | xr.DataArray,
dataset: LoadableDataset | Mapping[str, LoadableDataset],
dataset_name: str | None = None,
allow_overwrite: bool = False,
ignore_existing: bool = True,
Expand All @@ -212,7 +214,7 @@ def import_data(
Parameters
----------
dataset : str | Path | xr.Dataset | xr.DataArray
dataset : LoadableDataset
Dataset instance or path to a dataset.
dataset_name : str | None
The name of the dataset (needs to be provided when dataset is an xarray instance).
Expand All @@ -222,12 +224,16 @@ def import_data(
ignore_existing: bool
Whether to ignore import if the dataset already exists. Defaults to ``True``.
"""
self._data_registry.import_data(
dataset,
dataset_name=dataset_name,
allow_overwrite=allow_overwrite,
ignore_existing=ignore_existing,
)
if not isinstance(dataset, Mapping) or isinstance(dataset, (xr.Dataset, xr.DataArray)):
dataset = {dataset_name: dataset}

for key, value in dataset.items():
self._data_registry.import_data(
value,
dataset_name=key,
allow_overwrite=allow_overwrite,
ignore_existing=ignore_existing,
)

@property
def has_models(self) -> bool:
Expand Down
24 changes: 24 additions & 0 deletions glotaran/project/test/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from glotaran import __version__ as gta_version
from glotaran.builtin.io.yml.utils import load_dict
from glotaran.io import load_dataset
from glotaran.io import load_parameters
from glotaran.io import save_dataset
from glotaran.io import save_parameters
Expand Down Expand Up @@ -288,6 +289,29 @@ def test_import_data_xarray(tmp_path: Path, data: xr.Dataset | xr.DataArray):
assert project.load_data("test_data").equals(xr.Dataset({"data": xr.DataArray([1])}))


@pytest.mark.parametrize(
"data",
(
xr.DataArray([1]),
xr.Dataset({"data": xr.DataArray([1])}),
),
)
def test_import_data_mapping(tmp_path: Path, data: xr.Dataset | xr.DataArray):
"""Import data as a mapping"""
project = Project.open(tmp_path)

test_data = tmp_path / "import_data.nc"
save_dataset(example_dataset, test_data)

project.import_data({"test_data_1": data, "test_data_2": test_data})

assert (tmp_path / "data/test_data_1.nc").is_file() is True
assert (tmp_path / "data/test_data_2.nc").is_file() is True

assert project.load_data("test_data_1").equals(xr.Dataset({"data": xr.DataArray([1])}))
assert project.load_data("test_data_2").equals(load_dataset(test_data))


def test_create_scheme(existing_project: Project):
scheme = existing_project.create_scheme(
model_name="test_model",
Expand Down

0 comments on commit d4293f7

Please sign in to comment.