Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize from_xyz functions #490

Merged
merged 12 commits into the base branch from the contributor's branch
Jan 5, 2019
4 changes: 3 additions & 1 deletion arviz/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Code for loading and manipulating data structures."""
from .inference_data import InferenceData
from .io_netcdf import load_data, save_data
from .io_netcdf import load_data, save_data, from_netcdf, to_netcdf
from .datasets import load_arviz_data, list_datasets, clear_data_home
from .base import numpy_to_data_array, dict_to_dataset
from .converters import convert_to_dataset, convert_to_inference_data
Expand Down Expand Up @@ -28,4 +28,6 @@
"from_cmdstan",
"from_pyro",
"from_tfp",
"from_netcdf",
"to_netcdf",
]
2 changes: 1 addition & 1 deletion arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class InferenceData:
"""Container for accessing netCDF files using xarray."""

def __init__(self, *_, **kwargs):
def __init__(self, **kwargs):
"""Initialize InferenceData object from keyword xarray datasets.

Examples
Expand Down
2 changes: 1 addition & 1 deletion arviz/data/io_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,8 +665,8 @@ def _unpack_dataframes(dfs):


def from_cmdstan(
*,
posterior=None,
*,
posterior_predictive=None,
prior=None,
prior_predictive=None,
Expand Down
6 changes: 3 additions & 3 deletions arviz/data/io_emcee.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _verify_names(sampler, var_names, arg_names):
class EmceeConverter:
"""Encapsulate emcee specific logic."""

def __init__(self, sampler, *_, var_names=None, arg_names=None, coords=None, dims=None):
def __init__(self, *, sampler, var_names=None, arg_names=None, coords=None, dims=None):
var_names, arg_names = _verify_names(sampler, var_names, arg_names)
self.sampler = sampler
self.var_names = var_names
Expand Down Expand Up @@ -94,7 +94,7 @@ def to_inference_data(self):
)


def from_emcee(sampler, *, var_names=None, arg_names=None, coords=None, dims=None):
def from_emcee(sampler=None, *, var_names=None, arg_names=None, coords=None, dims=None):
"""Convert emcee data into an InferenceData object.

Parameters
Expand All @@ -111,5 +111,5 @@ def from_emcee(sampler, *, var_names=None, arg_names=None, coords=None, dims=Non
Map variable names to their coordinates
"""
return EmceeConverter(
sampler, var_names=var_names, arg_names=arg_names, coords=coords, dims=dims
sampler=sampler, var_names=var_names, arg_names=arg_names, coords=coords, dims=dims
).to_inference_data()
4 changes: 4 additions & 0 deletions arviz/data/io_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ def save_data(data, filename, *, group="posterior", coords=None, dims=None):
"""
inference_data = convert_to_inference_data(data, group=group, coords=coords, dims=dims)
return inference_data.to_netcdf(filename)


# Backward-compatible public name for ``load_data``: reads an InferenceData
# object from a netCDF file. Alias kept so existing ``load_data`` callers and
# the new ``from_netcdf`` API both work (deprecation of the old name was
# discussed in review — TODO confirm the deprecation plan).
from_netcdf = load_data # pylint: disable=invalid-name
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just rename the functions rather than alias them and have two of the same function in the API?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could. Then we only need to mention this change in "what's changed". Or deprecate old functionality.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vote for deprecate old functionality. If we want to be safe we can wrap load_data and add a warning

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with deprecation, or we can use the fact that we're not at 1.0 yet to just change it (starting a RELEASE_NOTES.md is a good idea!)

# Backward-compatible public name for ``save_data``: writes an InferenceData
# object to a netCDF file. Mirrors the ``from_netcdf`` alias above in intent.
to_netcdf = save_data # pylint: disable=invalid-name
4 changes: 2 additions & 2 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class PyMC3Converter:
"""Encapsulate PyMC3 specific logic."""

def __init__(
self, *_, trace=None, prior=None, posterior_predictive=None, coords=None, dims=None
self, *, trace=None, prior=None, posterior_predictive=None, coords=None, dims=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for catching this :/

):
self.trace = trace
self.prior = prior
Expand Down Expand Up @@ -142,7 +142,7 @@ def to_inference_data(self):
)


def from_pymc3(*, trace=None, prior=None, posterior_predictive=None, coords=None, dims=None):
def from_pymc3(trace=None, *, prior=None, posterior_predictive=None, coords=None, dims=None):
"""Convert pymc3 data into an InferenceData object."""
return PyMC3Converter(
trace=trace,
Expand Down
4 changes: 2 additions & 2 deletions arviz/data/io_pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _get_var_names(posterior):
class PyroConverter:
"""Encapsulate Pyro specific logic."""

def __init__(self, posterior, *_, coords=None, dims=None):
def __init__(self, *, posterior, coords=None, dims=None):
"""Convert pyro data into an InferenceData object.

Parameters
Expand Down Expand Up @@ -103,7 +103,7 @@ def to_inference_data(self):
)


def from_pyro(posterior, *, coords=None, dims=None):
def from_pyro(posterior=None, *, coords=None, dims=None):
"""Convert pyro data into an InferenceData object.

Parameters
Expand Down
10 changes: 5 additions & 5 deletions arviz/data/io_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class PyStanConverter:

def __init__(
self,
*_,
*,
posterior=None,
posterior_predictive=None,
prior=None,
Expand Down Expand Up @@ -170,7 +170,7 @@ class PyStan3Converter:
# pylint: disable=too-many-instance-attributes
def __init__(
self,
*_,
*,
posterior=None,
posterior_model=None,
posterior_predictive=None,
Expand Down Expand Up @@ -523,8 +523,8 @@ def infer_dtypes(fit, model=None):

# pylint disable=too-many-instance-attributes
def from_pystan(
*,
posterior=None,
*,
posterior_predictive=None,
prior=None,
prior_predictive=None,
Expand All @@ -539,11 +539,11 @@ def from_pystan(

Parameters
----------
posterior : StanFit4Model
posterior : StanFit4Model or stan.fit.Fit
PyStan fit object for posterior.
posterior_predictive : str, a list of str
Posterior predictive samples for the posterior.
prior : StanFit4Model
prior : StanFit4Model or stan.fit.Fit
PyStan fit object for prior.
prior_predictive : str, a list of str
Posterior predictive samples for the prior.
Expand Down
4 changes: 2 additions & 2 deletions arviz/data/io_tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class TfpConverter:

def __init__(
self,
posterior,
*,
posterior,
var_names=None,
model_fn=None,
feed_dict=None,
Expand Down Expand Up @@ -163,7 +163,7 @@ def to_inference_data(self):


def from_tfp(
posterior,
posterior=None,
*,
var_names=None,
model_fn=None,
Expand Down
104 changes: 99 additions & 5 deletions arviz/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
from_cmdstan,
from_pymc3,
from_pystan,
from_pyro,
from_emcee,
from_netcdf,
to_netcdf,
load_arviz_data,
list_datasets,
clear_data_home,
InferenceData,
)
from ..data.datasets import REMOTE_DATASETS, LOCAL_DATASETS, RemoteFileMetadata
from .helpers import ( # pylint: disable=unused-import
Expand Down Expand Up @@ -310,6 +314,72 @@ def test__verify_arg_names(self, obj):
with pytest.raises(ValueError):
from_emcee(obj, arg_names=["not", "enough"])

def test_inference_data(self, obj):
    """Converting the sampler must yield an object with a ``posterior`` group."""
    converted = self.get_inference_data(obj)
    assert hasattr(converted, "posterior")


class TestIONetCDFUtils:
    """Round-trip tests for netCDF I/O: the ``to_netcdf``/``from_netcdf``
    module-level functions and the equivalent ``InferenceData`` methods."""

    @pytest.fixture(scope="class")
    def data(self, draws, chains):
        """Provide the cached PyMC3 model and trace via a namespace class."""

        class Data:
            # Cached eight-schools PyMC3 model + trace, shared across tests.
            model, obj = load_cached_models(eight_schools_params, draws, chains)["pymc3"]

        return Data

    def get_inference_data(self, data, eight_schools_params):
        """Build an InferenceData object (posterior, prior and posterior
        predictive groups) from the cached PyMC3 model and trace."""
        with data.model:
            prior = pm.sample_prior_predictive()
            posterior_predictive = pm.sample_posterior_predictive(data.obj)

        return from_pymc3(
            trace=data.obj,
            prior=prior,
            posterior_predictive=posterior_predictive,
            coords={"school": np.arange(eight_schools_params["J"])},
            dims={"theta": ["school"], "eta": ["school"]},
        )

    def _roundtrip(self, inference_data, filename, save, load):
        """Save ``inference_data`` with ``save``, reload it with ``load``,
        and verify the file and the reloaded object; cleans up afterwards.

        ``save`` is called as ``save(inference_data, filepath)`` and ``load``
        as ``load(filepath)``.
        """
        here = os.path.dirname(os.path.abspath(__file__))
        filepath = os.path.join(here, "saved_models", filename)
        # Start from a clean slate in case a previous run left the file behind.
        if os.path.exists(filepath):
            os.remove(filepath)
        save(inference_data, filepath)
        assert os.path.exists(filepath)
        assert os.path.getsize(filepath) > 0
        reloaded = load(filepath)
        assert hasattr(reloaded, "posterior")
        os.remove(filepath)
        assert not os.path.exists(filepath)

    def test_io_function(self, data, eight_schools_params):
        """Round-trip through the module-level ``to_netcdf``/``from_netcdf``."""
        inference_data = self.get_inference_data(data, eight_schools_params)
        assert hasattr(inference_data, "posterior")
        self._roundtrip(inference_data, "io_function_testfile.nc", to_netcdf, from_netcdf)

    def test_io_method(self, data, eight_schools_params):
        """Round-trip through ``InferenceData.to_netcdf``/``from_netcdf``."""
        inference_data = self.get_inference_data(data, eight_schools_params)
        assert hasattr(inference_data, "posterior")
        self._roundtrip(
            inference_data,
            "io_method_testfile.nc",
            # Adapt the instance method to the save(idata, path) signature.
            lambda idata, path: idata.to_netcdf(path),
            InferenceData.from_netcdf,
        )


class TestPyMC3NetCDFUtils:
@pytest.fixture(scope="class")
Expand All @@ -332,6 +402,10 @@ def get_inference_data(self, data, eight_schools_params):
dims={"theta": ["school"], "eta": ["school"]},
)

def test_posterior(self, data, eight_schools_params):
    """The converted object must carry a ``posterior`` group."""
    converted = self.get_inference_data(data, eight_schools_params)
    assert hasattr(converted, "posterior")

def test_sampler_stats(self, data, eight_schools_params):
    """The converted object must carry a ``sample_stats`` group."""
    converted = self.get_inference_data(data, eight_schools_params)
    assert hasattr(converted, "sample_stats")
Expand All @@ -345,6 +419,22 @@ def test_prior(self, data, eight_schools_params):
assert hasattr(inference_data, "prior")


class TestPyroNetCDFUtils:
    """Checks conversion of a cached Pyro posterior via ``from_pyro``."""

    @pytest.fixture(scope="class")
    def data(self, draws, chains):
        """Expose the cached Pyro fit through a lightweight namespace class."""

        class Data:
            obj = load_cached_models(eight_schools_params, draws, chains)["pyro"]

        return Data

    def get_inference_data(self, data):
        """Convert the cached Pyro posterior into an InferenceData object."""
        return from_pyro(posterior=data.obj)

    def test_inference_data(self, data):
        """The converted object must expose a ``posterior`` group."""
        converted = self.get_inference_data(data)
        assert hasattr(converted, "posterior")


class TestPyStanNetCDFUtils:
@pytest.fixture(scope="class")
def data(self, draws, chains):
Expand Down Expand Up @@ -435,20 +525,24 @@ def test_inference_data(self, data, eight_schools_params):
inference_data2 = self.get_inference_data2(data, eight_schools_params)
inference_data3 = self.get_inference_data3(data, eight_schools_params)
inference_data4 = self.get_inference_data4(data)
# inference_data 1
assert hasattr(inference_data1.sample_stats, "log_likelihood")
assert hasattr(inference_data1.posterior, "theta")
assert hasattr(inference_data1.prior, "theta")
assert hasattr(inference_data1.observed_data, "y")
# inference_data 2
assert hasattr(inference_data2.posterior_predictive, "y_hat")
assert hasattr(inference_data2.prior_predictive, "y_hat")
assert hasattr(inference_data2.sample_stats, "lp")
assert hasattr(inference_data2.sample_stats_prior, "lp")
assert hasattr(inference_data2.observed_data, "y")
# inference_data 3
assert hasattr(inference_data3.posterior_predictive, "y_hat")
assert hasattr(inference_data3.prior_predictive, "y_hat")
assert hasattr(inference_data3.sample_stats, "lp")
assert hasattr(inference_data3.sample_stats_prior, "lp")
assert hasattr(inference_data3.observed_data, "y")
# inference_data 4
assert hasattr(inference_data4.posterior, "theta")
assert hasattr(inference_data4.prior, "theta")

Expand All @@ -457,19 +551,19 @@ class TestTfpNetCDFUtils:
@pytest.fixture(scope="class")
def data(self, draws, chains):
class Data:
# Returns result of from_tfp
obj = load_cached_models({}, draws, chains)[ # pylint: disable=E1120
"tensorflow_probability"
]

return Data

def get_inference_data(self, data, eight_school_params): # pylint: disable=W0613
def get_inference_data(self, data): # pylint: disable=W0613
return data.obj

def test_inference_data(self, data, eight_schools_params):
inference_data1 = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
)
def test_inference_data(self, data):
inference_data = self.get_inference_data(data) # pylint: disable=W0612
assert hasattr(inference_data, "posterior")


class TestCmdStanNetCDFUtils:
Expand Down