Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save InferenceData attrs #2131

Merged
merged 16 commits into from
Oct 22, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
* Add testing module for labeller classes ([2095](https://github.com/arviz-devs/arviz/pull/2095))
* Skip compression for object dtype while creating a netcdf file ([2129](https://github.com/arviz-devs/arviz/pull/2129))
* Fix issue in dim generation when default dims are present in user inputted dims ([2138](https://github.com/arviz-devs/arviz/pull/2138))
* Save InferenceData level attrs to netcdf and zarr ([2131](https://github.com/arviz-devs/arviz/pull/2131))

### Deprecation
* Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085))
Expand Down
39 changes: 27 additions & 12 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=too-many-lines,too-many-public-methods
"""Data structure for using netcdf groups with xarray."""
import re
import sys
import uuid
import warnings
Expand All @@ -9,7 +10,6 @@
from copy import deepcopy
from datetime import datetime
from html import escape
import re
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -363,13 +363,13 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
InferenceData object
"""
groups = {}
attrs = {}

try:
with nc.Dataset(filename, mode="r") as data:
data_groups = list(data.groups)

for group in data_groups:

group_kws = {}
if group_kwargs is not None and regex is False:
group_kws = group_kwargs.get(group, {})
Expand All @@ -382,21 +382,24 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
groups[group] = data.load()
else:
groups[group] = data
res = InferenceData(**groups)
return res
except OSError as e: # pylint: disable=invalid-name
if e.errno == -101:
raise type(e)(
str(e)

with xr.open_dataset(filename, mode="r") as data:
attrs.update(data.load().attrs)

return InferenceData(attrs=attrs, **groups)
except OSError as err:
if err.errno == -101:
raise type(err)(
str(err)
+ (
" while reading a NetCDF file. This is probably an error in HDF5, "
"which happens because your OS does not support HDF5 file locking. See "
"https://stackoverflow.com/questions/49317927/"
"errno-101-netcdf-hdf-error-when-opening-netcdf-file#49317928"
" for a possible solution."
)
)
raise e
) from err
raise err

def to_netcdf(
self, filename: str, compress: bool = True, groups: Optional[List[str]] = None
Expand All @@ -419,6 +422,10 @@ def to_netcdf(
Location of netcdf file
"""
mode = "w" # overwrite first, then append
if self._attrs:
xr.Dataset(attrs=self._attrs).to_netcdf(filename, mode=mode)
mode = "a"

if self._groups_all: # check's whether a group is present or not.
if groups is None:
groups = self._groups_all
Expand All @@ -437,7 +444,7 @@ def to_netcdf(
data.to_netcdf(filename, mode=mode, group=group, **kwargs)
data.close()
mode = "a"
else: # creates a netcdf file for an empty InferenceData object.
elif not self._attrs: # creates a netcdf file for an empty InferenceData object.
empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
empty_netcdf_file.close()
return filename
Expand Down Expand Up @@ -688,6 +695,10 @@ def to_zarr(self, store=None):
if not groups:
raise TypeError("No valid groups found!")

# order matters here, saving attrs after the groups will erase the groups.
if self.attrs:
xr.Dataset(attrs=self.attrs).to_zarr(store=store, mode="w")

for group in groups:
# Create zarr group in store with same group name
getattr(self, group).to_zarr(store=store, group=group, mode="w")
Expand Down Expand Up @@ -738,7 +749,11 @@ def from_zarr(store) -> "InferenceData":
for key_group, _ in zarr_handle.groups():
with xr.open_zarr(store=store, group=key_group) as data:
groups[key_group] = data.load() if rcParams["data.load"] == "eager" else data
return InferenceData(**groups)

with xr.open_zarr(store=store) as root:
attrs = root.attrs

return InferenceData(attrs=attrs, **groups)

def __add__(self, other: "InferenceData") -> "InferenceData":
"""Concatenate two InferenceData objects."""
Expand Down
21 changes: 21 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from copy import deepcopy
from html import escape
from typing import Dict
from tempfile import TemporaryDirectory
from urllib.parse import urlunsplit

import numpy as np
Expand Down Expand Up @@ -107,6 +108,26 @@ def test_load_local_arviz_data():
assert inference_data.posterior["theta"].dims == ("chain", "draw", "school")


@pytest.mark.parametrize("fill_attrs", [True, False])
def test_local_save(fill_attrs):
    """Round-trip an InferenceData through to_netcdf/from_netcdf.

    Checks that all groups survive the round trip and, when ``fill_attrs``
    is set, that InferenceData-level attrs are preserved as well.
    """
    idata = load_arviz_data("centered_eight")
    assert isinstance(idata, InferenceData)

    if fill_attrs:
        # Attach a top-level attr so we can verify it survives serialization.
        idata.attrs["test"] = 1

    with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
        file_path = os.path.join(tmp_dir, "test_file.nc")
        idata.to_netcdf(file_path)

        idata_roundtrip = from_netcdf(file_path)
        if fill_attrs:
            assert "test" in idata_roundtrip.attrs
            assert idata_roundtrip.attrs["test"] == 1
        # pylint: disable=protected-access
        assert all(group in idata_roundtrip for group in idata._groups_all)
        # pylint: enable=protected-access


def test_clear_data_home():
resource = REMOTE_DATASETS["test_remote"]
assert not os.path.exists(resource.filename)
Expand Down
75 changes: 40 additions & 35 deletions arviz/tests/base_tests/test_data_zarr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=redefined-outer-name
import os
import shutil
from collections.abc import MutableMapping
from tempfile import TemporaryDirectory
from typing import Mapping

import numpy as np
Expand Down Expand Up @@ -31,7 +31,7 @@ class Data:

return Data

def get_inference_data(self, data, eight_schools_params):
def get_inference_data(self, data, eight_schools_params, fill_attrs):
return from_dict(
posterior=data.obj,
posterior_predictive=data.obj,
Expand All @@ -42,13 +42,15 @@ def get_inference_data(self, data, eight_schools_params):
observed_data=eight_schools_params,
coords={"school": np.arange(8)},
dims={"theta": ["school"], "eta": ["school"]},
attrs={"test": 1} if fill_attrs else None,
)

@pytest.mark.parametrize("store", [0, 1, 2])
def test_io_method(self, data, eight_schools_params, store):
@pytest.mark.parametrize("fill_attrs", [True, False])
def test_io_method(self, data, eight_schools_params, store, fill_attrs):
# create InferenceData and check it has been properly created
inference_data = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
data, eight_schools_params, fill_attrs
)
test_dict = {
"posterior": ["eta", "theta", "mu", "tau"],
Expand All @@ -62,39 +64,42 @@ def test_io_method(self, data, eight_schools_params, store):
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

if fill_attrs:
assert inference_data.attrs["test"] == 1
else:
assert "test" not in inference_data.attrs

# check filename does not exist and use to_zarr method
here = os.path.dirname(os.path.abspath(__file__))
data_directory = os.path.join(here, "..", "saved_models")
filepath = os.path.join(data_directory, "zarr")
assert not os.path.exists(filepath)
with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
filepath = os.path.join(tmp_dir, "zarr")

# InferenceData method
if store == 0:
# Tempdir
store = inference_data.to_zarr(store=None)
assert isinstance(store, MutableMapping)
elif store == 1:
inference_data.to_zarr(store=filepath)
# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0
elif store == 2:
store = zarr.storage.DirectoryStore(filepath)
inference_data.to_zarr(store=store)
# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0
# InferenceData method
if store == 0:
# Tempdir
store = inference_data.to_zarr(store=None)
assert isinstance(store, MutableMapping)
elif store == 1:
inference_data.to_zarr(store=filepath)
# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0
elif store == 2:
store = zarr.storage.DirectoryStore(filepath)
inference_data.to_zarr(store=store)
# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0

if isinstance(store, MutableMapping):
inference_data2 = InferenceData.from_zarr(store)
else:
inference_data2 = InferenceData.from_zarr(filepath)
if isinstance(store, MutableMapping):
inference_data2 = InferenceData.from_zarr(store)
else:
inference_data2 = InferenceData.from_zarr(filepath)

# Everything in dict still available in inference_data2 ?
fails = check_multiple_attrs(test_dict, inference_data2)
assert not fails
# Everything in dict still available in inference_data2 ?
fails = check_multiple_attrs(test_dict, inference_data2)
assert not fails

# Remove created folder structure
if os.path.exists(filepath):
shutil.rmtree(filepath)
assert not os.path.exists(filepath)
if fill_attrs:
assert inference_data2.attrs["test"] == 1
else:
assert "test" not in inference_data2.attrs