Skip to content

Commit

Permalink
✨ Allow import of xarray objects in project API import_data
Browse files Browse the repository at this point in the history
  • Loading branch information
s-weigand committed Feb 10, 2023
1 parent bcf7cc2 commit df5f34f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 18 deletions.
18 changes: 11 additions & 7 deletions glotaran/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def data(self) -> dict[str, Path]:
"""
return self._data_registry.items

def load_data(self, dataset_name: str) -> xr.Dataset | xr.DataArray:
def load_data(self, dataset_name: str) -> xr.Dataset:
"""Load a dataset.
Parameters
Expand All @@ -170,11 +170,14 @@ def load_data(self, dataset_name: str) -> xr.Dataset | xr.DataArray:
.. # noqa: DAR402
"""
return self._data_registry.load_item(dataset_name)
dataset = self._data_registry.load_item(dataset_name)
if isinstance(dataset, xr.DataArray):
dataset = dataset.to_dataset(name="data")
return dataset

def import_data(
self,
path: str | Path,
dataset: str | Path | xr.Dataset | xr.DataArray,
name: str | None = None,
allow_overwrite: bool = False,
ignore_existing: bool = False,
Expand All @@ -183,17 +186,18 @@ def import_data(
Parameters
----------
path : str | Path
The path to the dataset.
dataset : str | Path | xr.Dataset | xr.DataArray
Dataset instance or path to a dataset.
name : str | None
The name of the dataset.
The name of the dataset (needs to be provided when dataset is an xarray instance).
Defaults to None.
allow_overwrite: bool
Whether to overwrite an existing dataset.
ignore_existing: bool
Whether to ignore import if the dataset already exists.
"""
self._data_registry.import_data(
path, name=name, allow_overwrite=allow_overwrite, ignore_existing=ignore_existing
dataset, name=name, allow_overwrite=allow_overwrite, ignore_existing=ignore_existing
)

@property
Expand Down
38 changes: 27 additions & 11 deletions glotaran/project/project_data_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pathlib import Path

import xarray as xr

from glotaran.io import load_dataset
from glotaran.io import save_dataset
from glotaran.plugin_system.data_io_registration import supported_file_extensions_data_io
Expand All @@ -29,7 +31,7 @@ def __init__(self, directory: Path):

def import_data(
self,
path: str | Path,
dataset: str | Path | xr.Dataset | xr.DataArray,
name: str | None = None,
allow_overwrite: bool = False,
ignore_existing: bool = False,
Expand All @@ -38,25 +40,39 @@ def import_data(
Parameters
----------
path : str | Path
The path to the dataset.
dataset : str | Path | xr.Dataset | xr.DataArray
Dataset instance or path to a dataset.
name : str | None
The name of the dataset.
The name of the dataset (needs to be provided when dataset is an xarray instance).
Defaults to None.
allow_overwrite: bool
Whether to overwrite an existing dataset.
ignore_existing: bool
Whether to ignore import if the dataset already exists.
Raises
------
ValueError
When importing from xarray object and not providing a name.
"""
path = Path(path)
if isinstance(dataset, (xr.DataArray, xr.Dataset)) and name is None:
raise ValueError(
"When importing data from a 'xarray.Dataset' or 'xarray.DataArray' "
"it is required to also pass a name."
)
if isinstance(dataset, xr.DataArray):
dataset = dataset.to_dataset(name="data")

if path.is_absolute() is False:
path = (self.directory.parent / path).resolve()
if isinstance(dataset, (str, Path)):
dataset = Path(dataset)

name = name or path.stem
data_path = self.directory / f"{name}.nc"
if dataset.is_absolute() is False:
dataset = (self.directory.parent / dataset).resolve()

name = name or dataset.stem
dataset = load_dataset(dataset)

data_path = self.directory / f"{name}.nc"
if data_path.exists() and ignore_existing:
return

dataset = load_dataset(path)
save_dataset(dataset, data_path, allow_overwrite=allow_overwrite)
34 changes: 34 additions & 0 deletions glotaran/project/test/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Literal

import pytest
import xarray as xr
from _pytest.monkeypatch import MonkeyPatch
from _pytest.recwarn import WarningsRecorder
from IPython.core.formatters import format_display_data
Expand Down Expand Up @@ -197,6 +198,23 @@ def test_import_data(project_folder: Path, project_file: Path, test_data: Path,
assert data == example_dataset


@pytest.mark.parametrize(
    "data",
    [
        xr.DataArray([1]),
        xr.Dataset({"data": xr.DataArray([1])}),
    ],
)
def test_import_data_xarray(tmp_path: Path, data: xr.Dataset | xr.DataArray):
    """Importing an xarray object stores it on disk and loads back as a dataset.

    Both a bare ``DataArray`` and an equivalent ``Dataset`` are imported under
    an explicit name; in either case the loaded result is compared against a
    ``Dataset`` holding the values in its ``data`` variable.
    """
    expected_dataset = xr.Dataset({"data": xr.DataArray([1])})
    saved_file = tmp_path / "data/test_data.nc"

    project = Project.open(tmp_path)
    project.import_data(data, name="test_data")

    # The import must have written a file named after the given dataset name.
    assert saved_file.is_file() is True

    loaded_dataset = project.load_data("test_data")
    assert loaded_dataset.equals(expected_dataset)


def test_create_scheme(project_file: Path):
project = Project.open(project_file)

Expand Down Expand Up @@ -527,6 +545,22 @@ def test_missing_file_errors(tmp_path: Path, project_folder: Path):

project = Project.open(project_folder)

with pytest.raises(ValueError) as exc_info:
project.import_data(xr.Dataset({"data": [1]}))

assert str(exc_info.value) == (
"When importing data from a 'xarray.Dataset' or 'xarray.DataArray' "
"it is required to also pass a name."
)

with pytest.raises(ValueError) as exc_info:
project.import_data(xr.DataArray([1]))

assert str(exc_info.value) == (
"When importing data from a 'xarray.Dataset' or 'xarray.DataArray' "
"it is required to also pass a name."
)

with pytest.raises(ValueError) as exc_info:
project.load_data("not-existing")

Expand Down

0 comments on commit df5f34f

Please sign in to comment.