diff --git a/arviz/data/__init__.py b/arviz/data/__init__.py index 4d8d619c79..e08039916b 100644 --- a/arviz/data/__init__.py +++ b/arviz/data/__init__.py @@ -1,6 +1,6 @@ """Code for loading and manipulating data structures.""" from .inference_data import InferenceData -from .io_netcdf import load_data, save_data +from .io_netcdf import from_netcdf, to_netcdf, load_data, save_data from .datasets import load_arviz_data, list_datasets, clear_data_home from .base import numpy_to_data_array, dict_to_dataset from .converters import convert_to_dataset, convert_to_inference_data @@ -13,8 +13,6 @@ __all__ = [ "InferenceData", - "load_data", - "save_data", "load_arviz_data", "list_datasets", "clear_data_home", @@ -28,4 +26,8 @@ "from_cmdstan", "from_pyro", "from_tfp", + "from_netcdf", + "to_netcdf", + "load_data", + "save_data", ] diff --git a/arviz/data/datasets.py b/arviz/data/datasets.py index d3bc2a72d6..afbb7bcd22 100644 --- a/arviz/data/datasets.py +++ b/arviz/data/datasets.py @@ -6,7 +6,7 @@ import shutil from urllib.request import urlretrieve -from .io_netcdf import load_data +from .io_netcdf import from_netcdf LocalFileMetadata = namedtuple("LocalFileMetadata", ["filename", "description"]) @@ -143,7 +143,7 @@ def load_arviz_data(dataset=None, data_home=None): """ if dataset in LOCAL_DATASETS: resource = LOCAL_DATASETS[dataset] - return load_data(resource.filename) + return from_netcdf(resource.filename) elif dataset in REMOTE_DATASETS: remote = REMOTE_DATASETS[dataset] @@ -158,7 +158,7 @@ def load_arviz_data(dataset=None, data_home=None): "file may be corrupted. Run `arviz.clear_data_home()` and try " "again, or please open an issue.".format(file_path, checksum, remote.checksum) ) - return load_data(file_path) + return from_netcdf(file_path) else: raise ValueError( "Dataset {} not found! 
The following are available:\n{}".format( diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 8d80e96c83..ffc05d1a59 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -6,7 +6,7 @@ class InferenceData: """Container for accessing netCDF files using xarray.""" - def __init__(self, *_, **kwargs): + def __init__(self, **kwargs): """Initialize InferenceData object from keyword xarray datasets. Examples diff --git a/arviz/data/io_cmdstan.py b/arviz/data/io_cmdstan.py index 5ae7cfcc16..ba2cacba66 100644 --- a/arviz/data/io_cmdstan.py +++ b/arviz/data/io_cmdstan.py @@ -665,8 +665,8 @@ def _unpack_dataframes(dfs): def from_cmdstan( - *, posterior=None, + *, posterior_predictive=None, prior=None, prior_predictive=None, diff --git a/arviz/data/io_emcee.py b/arviz/data/io_emcee.py index 3d583e1ba6..1ccba5890a 100644 --- a/arviz/data/io_emcee.py +++ b/arviz/data/io_emcee.py @@ -54,7 +54,7 @@ def _verify_names(sampler, var_names, arg_names): class EmceeConverter: """Encapsulate emcee specific logic.""" - def __init__(self, sampler, *_, var_names=None, arg_names=None, coords=None, dims=None): + def __init__(self, *, sampler, var_names=None, arg_names=None, coords=None, dims=None): var_names, arg_names = _verify_names(sampler, var_names, arg_names) self.sampler = sampler self.var_names = var_names @@ -94,7 +94,7 @@ def to_inference_data(self): ) -def from_emcee(sampler, *, var_names=None, arg_names=None, coords=None, dims=None): +def from_emcee(sampler=None, *, var_names=None, arg_names=None, coords=None, dims=None): """Convert emcee data into an InferenceData object. 
Parameters @@ -111,5 +111,5 @@ def from_emcee(sampler, *, var_names=None, arg_names=None, coords=None, dims=Non Map variable names to their coordinates """ return EmceeConverter( - sampler, var_names=var_names, arg_names=arg_names, coords=coords, dims=dims + sampler=sampler, var_names=var_names, arg_names=arg_names, coords=coords, dims=dims ).to_inference_data() diff --git a/arviz/data/io_netcdf.py b/arviz/data/io_netcdf.py index 8e2d665127..8dc1057ef7 100644 --- a/arviz/data/io_netcdf.py +++ b/arviz/data/io_netcdf.py @@ -1,9 +1,10 @@ """Input and output support for data.""" +import warnings from .inference_data import InferenceData from .converters import convert_to_inference_data -def load_data(filename): +def from_netcdf(filename): """Load netcdf file back into an arviz.InferenceData. Parameters @@ -14,7 +15,7 @@ def load_data(filename): return InferenceData.from_netcdf(filename) -def save_data(data, filename, *, group="posterior", coords=None, dims=None): +def to_netcdf(data, filename, *, group="posterior", coords=None, dims=None): """Save dataset as a netcdf file. WARNING: Only idempotent in case `data` is InferenceData @@ -39,3 +40,58 @@ def save_data(data, filename, *, group="posterior", coords=None, dims=None): """ inference_data = convert_to_inference_data(data, group=group, coords=coords, dims=dims) return inference_data.to_netcdf(filename) + + +def load_data(filename): + """Load netcdf file back into an arviz.InferenceData. + + Parameters + ---------- + filename : str + name or path of the file to load trace + + Note + ---- + This function is deprecated and will be removed in 0.4. + Use `from_netcdf` instead. + """ + warnings.warn( + "The 'load_data' function is deprecated as of 0.3.2, use 'from_netcdf' instead", + DeprecationWarning, + ) + return from_netcdf(filename=filename) + + +def save_data(data, filename, *, group="posterior", coords=None, dims=None): + """Save dataset as a netcdf file. 
+ + WARNING: Only idempotent in case `data` is InferenceData + + Parameters + ---------- + data : InferenceData, or any object accepted by `convert_to_inference_data` + Object to be saved + filename : str + name or path of the file to save the data to + group : str (optional) + In case `data` is not InferenceData, this is the group it will be saved to + coords : dict (optional) + See `convert_to_inference_data` + dims : dict (optional) + See `convert_to_inference_data` + + Returns + ------- + str + filename saved to + + Note + ---- + This function is deprecated and will be removed in 0.4. + Use `to_netcdf` instead. + """ + warnings.warn( + "The 'save_data' function is deprecated as of 0.3.2, use 'to_netcdf' instead", + DeprecationWarning, + ) + return to_netcdf(data=data, filename=filename, group=group, coords=coords, dims=dims) diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index 8b658c29e5..e928fdf666 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -10,7 +10,7 @@ class PyMC3Converter: """Encapsulate PyMC3 specific logic.""" def __init__( - self, *_, trace=None, prior=None, posterior_predictive=None, coords=None, dims=None + self, *, trace=None, prior=None, posterior_predictive=None, coords=None, dims=None ): self.trace = trace self.prior = prior @@ -142,7 +142,7 @@ def to_inference_data(self): ) -def from_pymc3(*, trace=None, prior=None, posterior_predictive=None, coords=None, dims=None): +def from_pymc3(trace=None, *, prior=None, posterior_predictive=None, coords=None, dims=None): """Convert pymc3 data into an InferenceData object.""" return PyMC3Converter( trace=trace, diff --git a/arviz/data/io_pyro.py b/arviz/data/io_pyro.py index 9ca1d1fb08..db20545be5 100644 --- a/arviz/data/io_pyro.py +++ b/arviz/data/io_pyro.py @@ -28,7 +28,7 @@ def _get_var_names(posterior): class PyroConverter: """Encapsulate Pyro specific logic.""" - def __init__(self, posterior, *_, coords=None, dims=None): + def __init__(self, *, posterior, coords=None, 
dims=None): """Convert pyro data into an InferenceData object. Parameters @@ -103,7 +103,7 @@ def to_inference_data(self): ) -def from_pyro(posterior, *, coords=None, dims=None): +def from_pyro(posterior=None, *, coords=None, dims=None): """Convert pyro data into an InferenceData object. Parameters diff --git a/arviz/data/io_pystan.py b/arviz/data/io_pystan.py index 00d9ec38a3..ee5828e845 100644 --- a/arviz/data/io_pystan.py +++ b/arviz/data/io_pystan.py @@ -15,7 +15,7 @@ class PyStanConverter: def __init__( self, - *_, + *, posterior=None, posterior_predictive=None, prior=None, @@ -168,7 +168,7 @@ class PyStan3Converter: # pylint: disable=too-many-instance-attributes def __init__( self, - *_, + *, posterior=None, posterior_model=None, posterior_predictive=None, @@ -519,8 +519,8 @@ def infer_dtypes(fit, model=None): # pylint disable=too-many-instance-attributes def from_pystan( - *, posterior=None, + *, posterior_predictive=None, prior=None, prior_predictive=None, @@ -535,11 +535,11 @@ def from_pystan( Parameters ---------- - posterior : StanFit4Model + posterior : StanFit4Model or stan.fit.Fit PyStan fit object for posterior. posterior_predictive : str, a list of str Posterior predictive samples for the posterior. - prior : StanFit4Model + prior : StanFit4Model or stan.fit.Fit PyStan fit object for prior. prior_predictive : str, a list of str Posterior predictive samples for the prior. 
diff --git a/arviz/data/io_tfp.py b/arviz/data/io_tfp.py index 61bcaaac55..364bfeeb12 100644 --- a/arviz/data/io_tfp.py +++ b/arviz/data/io_tfp.py @@ -12,8 +12,8 @@ class TfpConverter: def __init__( self, - posterior, *, + posterior, var_names=None, model_fn=None, feed_dict=None, @@ -167,7 +167,7 @@ def to_inference_data(self): def from_tfp( - posterior, + posterior=None, *, var_names=None, model_fn=None, diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index dd6b8b08da..79fedda2d6 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -350,11 +350,7 @@ def stan_extract_dict(fit, var_names=None): continue # in future fix the correct number of draws if fit.save_warmup is True - new_shape = ( - *fit.dims[fit.param_names.index(var)], - -1, - fit.num_chains, - ) # pylint: disable=protected-access + new_shape = (*fit.dims[fit.param_names.index(var)], -1, fit.num_chains) values = fit._draws[fit._parameter_indexes(var), :] # pylint: disable=protected-access values = values.reshape(new_shape, order="F") values = np.moveaxis(values, [-2, -1], [1, 0]) diff --git a/arviz/tests/test_data.py b/arviz/tests/test_data.py index a8814185bf..13fb6f91c6 100644 --- a/arviz/tests/test_data.py +++ b/arviz/tests/test_data.py @@ -12,10 +12,16 @@ from_cmdstan, from_pymc3, from_pystan, + from_pyro, from_emcee, + from_netcdf, + to_netcdf, + load_data, + save_data, load_arviz_data, list_datasets, clear_data_home, + InferenceData, ) from ..data.io_pystan import get_draws, get_draws_stan3 # pylint: disable=unused-import from ..data.datasets import REMOTE_DATASETS, LOCAL_DATASETS, RemoteFileMetadata @@ -311,6 +317,78 @@ def test__verify_arg_names(self, obj): with pytest.raises(ValueError): from_emcee(obj, arg_names=["not", "enough"]) + def test_inference_data(self, obj): + inference_data = self.get_inference_data(obj) + assert hasattr(inference_data, "posterior") + + +class TestIONetCDFUtils: + @pytest.fixture(scope="class") + def data(self, draws, chains): + class Data: 
+ model, obj = load_cached_models(eight_schools_params, draws, chains)["pymc3"] + + return Data + + def get_inference_data(self, data, eight_schools_params): # pylint: disable=W0613 + with data.model: + prior = pm.sample_prior_predictive() + posterior_predictive = pm.sample_posterior_predictive(data.obj) + + return from_pymc3( + trace=data.obj, + prior=prior, + posterior_predictive=posterior_predictive, + coords={"school": np.arange(eight_schools_params["J"])}, + dims={"theta": ["school"], "eta": ["school"]}, + ) + + def test_io_function(self, data, eight_schools_params): + inference_data = self.get_inference_data( # pylint: disable=W0612 + data, eight_schools_params + ) + assert hasattr(inference_data, "posterior") + here = os.path.dirname(os.path.abspath(__file__)) + data_directory = os.path.join(here, "saved_models") + filepath = os.path.join(data_directory, "io_function_testfile.nc") + # az -function + to_netcdf(inference_data, filepath) + assert os.path.exists(filepath) + assert os.path.getsize(filepath) > 0 + inference_data2 = from_netcdf(filepath) + assert hasattr(inference_data2, "posterior") + os.remove(filepath) + assert not os.path.exists(filepath) + # Test deprecated functions + save_data(inference_data, filepath) + assert os.path.exists(filepath) + assert os.path.getsize(filepath) > 0 + inference_data3 = load_data(filepath) + assert hasattr(inference_data3, "posterior") + os.remove(filepath) + assert not os.path.exists(filepath) + + def test_io_method(self, data, eight_schools_params): + inference_data = self.get_inference_data( # pylint: disable=W0612 + data, eight_schools_params + ) + assert hasattr(inference_data, "posterior") + here = os.path.dirname(os.path.abspath(__file__)) + data_directory = os.path.join(here, "saved_models") + filepath = os.path.join(data_directory, "io_method_testfile.nc") + if os.path.exists(filepath): + os.remove(filepath) + # InferenceData method + if os.path.exists(filepath): + os.remove(filepath) + 
inference_data.to_netcdf(filepath) + assert os.path.exists(filepath) + assert os.path.getsize(filepath) > 0 + inference_data2 = InferenceData.from_netcdf(filepath) + assert hasattr(inference_data2, "posterior") + os.remove(filepath) + assert not os.path.exists(filepath) + class TestPyMC3NetCDFUtils: @pytest.fixture(scope="class") @@ -333,6 +411,10 @@ def get_inference_data(self, data, eight_schools_params): dims={"theta": ["school"], "eta": ["school"]}, ) + def test_posterior(self, data, eight_schools_params): + inference_data = self.get_inference_data(data, eight_schools_params) + assert hasattr(inference_data, "posterior") + def test_sampler_stats(self, data, eight_schools_params): inference_data = self.get_inference_data(data, eight_schools_params) assert hasattr(inference_data, "sample_stats") @@ -346,6 +428,22 @@ def test_prior(self, data, eight_schools_params): assert hasattr(inference_data, "prior") +class TestPyroNetCDFUtils: + @pytest.fixture(scope="class") + def data(self, draws, chains): + class Data: + obj = load_cached_models(eight_schools_params, draws, chains)["pyro"] + + return Data + + def get_inference_data(self, data): + return from_pyro(posterior=data.obj) + + def test_inference_data(self, data): + inference_data = self.get_inference_data(data) + assert hasattr(inference_data, "posterior") + + class TestPyStanNetCDFUtils: @pytest.fixture(scope="class") def data(self, draws, chains): @@ -436,20 +534,24 @@ def test_inference_data(self, data, eight_schools_params): inference_data2 = self.get_inference_data2(data, eight_schools_params) inference_data3 = self.get_inference_data3(data, eight_schools_params) inference_data4 = self.get_inference_data4(data) + # inference_data 1 assert hasattr(inference_data1.sample_stats, "log_likelihood") assert hasattr(inference_data1.posterior, "theta") assert hasattr(inference_data1.prior, "theta") assert hasattr(inference_data1.observed_data, "y") + # inference_data 2 assert 
hasattr(inference_data2.posterior_predictive, "y_hat") assert hasattr(inference_data2.prior_predictive, "y_hat") assert hasattr(inference_data2.sample_stats, "lp") assert hasattr(inference_data2.sample_stats_prior, "lp") assert hasattr(inference_data2.observed_data, "y") + # inference_data 3 assert hasattr(inference_data3.posterior_predictive, "y_hat") assert hasattr(inference_data3.prior_predictive, "y_hat") assert hasattr(inference_data3.sample_stats, "lp") assert hasattr(inference_data3.sample_stats_prior, "lp") assert hasattr(inference_data3.observed_data, "y") + # inference_data 4 assert hasattr(inference_data4.posterior, "theta") assert hasattr(inference_data4.prior, "theta") @@ -505,19 +607,19 @@ class TestTfpNetCDFUtils: @pytest.fixture(scope="class") def data(self, draws, chains): class Data: - obj = load_cached_models({}, draws, chains)[ # pylint: disable=E1120 + # Returns result of from_tfp + obj = load_cached_models({}, draws, chains)[ # pylint: disable=no-value-for-parameter "tensorflow_probability" ] return Data - def get_inference_data(self, data, eight_schools_params): # pylint: disable=W0613 + def get_inference_data(self, data): return data.obj - def test_inference_data(self, data, eight_schools_params): - inference_data1 = self.get_inference_data( # pylint: disable=W0612 - data, eight_schools_params - ) + def test_inference_data(self, data): + inference_data = self.get_inference_data(data) + assert hasattr(inference_data, "posterior") class TestCmdStanNetCDFUtils: