From 13559a95889a62e3daf669a8d6e04a9c51d0c55a Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 19 Nov 2020 12:07:02 +0000 Subject: [PATCH] Add load_dataset and save_dataset functions --- sgkit/__init__.py | 3 ++ sgkit/io/dataset.py | 53 ++++++++++++++++++++++++++++++++++ sgkit/tests/io/test_dataset.py | 24 +++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 sgkit/io/dataset.py create mode 100644 sgkit/tests/io/test_dataset.py diff --git a/sgkit/__init__.py b/sgkit/__init__.py index 4a2d6862d..48726a155 100644 --- a/sgkit/__init__.py +++ b/sgkit/__init__.py @@ -1,6 +1,7 @@ from pkg_resources import DistributionNotFound, get_distribution from .display import display_genotypes +from .io.dataset import load_dataset, save_dataset from .io.vcfzarr_reader import read_vcfzarr from .model import ( DIM_ALLELE, @@ -59,4 +60,6 @@ "variables", "pca", "window", + "load_dataset", + "save_dataset", ] diff --git a/sgkit/io/dataset.py b/sgkit/io/dataset.py new file mode 100644 index 000000000..e5a06e339 --- /dev/null +++ b/sgkit/io/dataset.py @@ -0,0 +1,53 @@ +from typing import Any + +import xarray as xr +from xarray import Dataset + +from sgkit.typing import PathType + + +def save_dataset(ds: Dataset, path: PathType, **kwargs: Any) -> None: + """Save a dataset to Zarr storage. + + This function is a thin wrapper around :meth:`xarray.Dataset.to_zarr` + that uses sensible defaults and makes it easier to use in a pipeline. + + Parameters + ---------- + ds + Dataset to save. + path + Path to directory in file system to save to. + kwargs + Additional arguments to pass to :meth:`xarray.Dataset.to_zarr`. + """ + store = str(path) + for v in ds: + # Workaround for https://github.com/pydata/xarray/issues/4380 + ds[v].encoding.pop("chunks", None) + ds.to_zarr(store, **kwargs) + + +def load_dataset(path: PathType) -> Dataset: + """Load a dataset from Zarr storage. + + This function is a thin wrapper around :meth:`xarray.open_zarr` + that uses sensible defaults and makes it easier to use in a pipeline. + + Parameters + ---------- + path + Path to directory in file system to load from. + + Returns + ------- + Dataset + The dataset loaded from the file system. + """ + store = str(path) + ds: Dataset = xr.open_zarr(store, concat_characters=False) # type: ignore[no-untyped-call] + for v in ds: + # Workaround for https://github.com/pydata/xarray/issues/4386 + if v.endswith("_mask"): # type: ignore + ds[v] = ds[v].astype(bool) # type: ignore[no-untyped-call] + return ds diff --git a/sgkit/tests/io/test_dataset.py b/sgkit/tests/io/test_dataset.py new file mode 100644 index 000000000..58438a542 --- /dev/null +++ b/sgkit/tests/io/test_dataset.py @@ -0,0 +1,24 @@ +import xarray as xr +from xarray import Dataset + +from sgkit import load_dataset, save_dataset +from sgkit.testing import simulate_genotype_call_dataset + + +def assert_identical(ds1: Dataset, ds2: Dataset) -> None: + """Assert two Datasets are identical, including dtypes for all variables.""" + xr.testing.assert_identical(ds1, ds2) # type: ignore[no-untyped-call] + assert all([ds1[v].dtype == ds2[v].dtype for v in ds1.data_vars]) + + +def test_save_and_load_dataset(tmp_path): + path = str(tmp_path / "ds.zarr") + ds = simulate_genotype_call_dataset(n_variant=10, n_sample=10) + save_dataset(ds, path) + ds2 = load_dataset(path) + assert_identical(ds, ds2) + + # save and load again to test https://github.com/pydata/xarray/issues/4386 + path2 = str(tmp_path / "ds2.zarr") + save_dataset(ds2, path2) + assert_identical(ds, load_dataset(path2))