Allow xarray.DataArray input to plots. #1120

Merged 13 commits on Mar 26, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@
* Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090
* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
* Allow `xarray.DataArray` input for plots (#1120)
### Maintenance and fixes
* Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115)
* Fixed hist kind of `plot_dist` with multidimensional input (#1115)
9 changes: 9 additions & 0 deletions arviz/data/converters.py
@@ -33,6 +33,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
| emcee sampler: Automatically extracts data
| pyro MCMC: Automatically extracts data
| xarray.Dataset: adds to InferenceData as only group
| xarray.DataArray: creates an xarray dataset as the only group, giving the
    array an arbitrary name if none is set
| dict: creates an xarray dataset as the only group
| numpy array: creates an xarray dataset as the only group, gives the
array an arbitrary name
@@ -97,6 +99,10 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
    # Cases that convert to xarray
    if isinstance(obj, xr.Dataset):
        dataset = obj
    elif isinstance(obj, xr.DataArray):
        if obj.name is None:
            obj.name = 'x'
        dataset = obj.to_dataset()
    elif isinstance(obj, dict):
        dataset = dict_to_dataset(obj, coords=coords, dims=dims)
    elif isinstance(obj, np.ndarray):
@@ -109,6 +115,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
        return from_cmdstan(**kwargs)
    else:
        allowable_types = (
            "xarray dataarray",
            "xarray dataset",
            "dict",
            "netcdf filename",
@@ -149,6 +156,8 @@ def convert_to_dataset(obj, *, group="posterior", coords=None, dims=None):
pystan fit: Automatically extracts data
pymc3 trace: Automatically extracts data
xarray.Dataset: adds to InferenceData as only group
xarray.DataArray: creates an xarray dataset as the only group, giving the
    array an arbitrary name if none is set
dict: creates an xarray dataset as the only group
numpy array: creates an xarray dataset as the only group, gives the
array an arbitrary name
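
The DataArray branch added in this file can be sketched in isolation with xarray alone. The helper name `dataarray_to_dataset` and its `default_name` parameter are hypothetical and not part of ArviZ. One deliberate difference from the diff: this sketch renames a copy instead of assigning to `obj.name`, so the caller's array is left unchanged.

```python
import numpy as np
import xarray as xr


def dataarray_to_dataset(obj, default_name="x"):
    """Hypothetical stand-in for the DataArray branch shown in the diff."""
    if obj.name is None:
        # to_dataset() needs a name for the variable; rename() with a plain
        # string returns a new DataArray, so the input is not mutated.
        obj = obj.rename(default_name)
    return obj.to_dataset()


unnamed = xr.DataArray(np.random.randn(1, 100), dims=("chain", "draw"))
named = xr.DataArray(np.random.randn(1, 100), dims=("chain", "draw"), name="theta")

print(list(dataarray_to_dataset(unnamed).data_vars))  # falls back to the default name
print(list(dataarray_to_dataset(named).data_vars))    # keeps the user-supplied name
```

Whether to mutate the input or rename a copy is a design choice; the merged code assigns the name in place, which is simpler but visible to the caller.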
38 changes: 38 additions & 0 deletions arviz/tests/base_tests/test_data.py
@@ -873,3 +873,41 @@ def test_id_conversion_args(self):
        assert isinstance(inference_data, InferenceData)
        assert set(inference_data.posterior.coords["Ivies"].values) == set(IVIES)
        assert inference_data.posterior["theta"].dims == ("chain", "draw", "Ivies")


class TestDataArrayToDataset:
    def test_1d_dataset(self):
        size = 100
        dataset = convert_to_dataset(xr.DataArray(np.random.randn(1, size), dims=('chain', 'draw')))
Member

You could set the name of the input DataArray in this test to make sure the name is not overwritten during conversion.

Contributor Author

Yes, I forgot to test the name. I'll make suitable changes.

        assert len(dataset.data_vars) == 1

        assert dataset.chain.shape == (1, )
        assert dataset.draw.shape == (size, )


    def test_nd_to_dataset(self):
        shape = (1, 2, 3, 4, 5)
        dataset = convert_to_dataset(
            xr.DataArray(np.random.randn(*shape),
                         dims=('chain', 'draw', 'dim_0', 'dim_1', 'dim_2')))
        assert len(dataset.data_vars) == 1
        var_name = list(dataset.data_vars)[0]

        assert dataset.chain.shape == shape[:1]
        assert dataset.draw.shape == shape[1:2]
        assert dataset[var_name].shape == shape

    def test_nd_to_inference_data(self):
        shape = (1, 2, 3, 4, 5)
        inference_data = convert_to_inference_data(
            xr.DataArray(np.random.randn(*shape),
                         dims=('chain', 'draw', 'dim_0', 'dim_1', 'dim_2')), group="prior")
        assert hasattr(inference_data, "prior")
        assert len(inference_data.prior.data_vars) == 1
        var_name = list(inference_data.prior.data_vars)[0]
Member

Nitpick: move this up above the asserts so all the asserts are together. That way, when reading the tests, we can see all the logic in one place and all the asserts in another. Same for the test above.

Contributor Author

Yeah, sure.


        assert inference_data.prior.chain.shape == shape[:1]
        assert inference_data.prior.draw.shape == shape[1:2]
        assert inference_data.prior[var_name].shape == shape
        assert repr(inference_data).startswith("Inference data with groups")
Member

This last assert should be removed. It will only fail if the inference data object has not been created, in which case all the asserts before it would already have failed.

Contributor Author

Well observed. I'll change it.
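
Taken together, the review feedback on these tests (check that a user-supplied name survives conversion, and keep setup separate from the asserts) might lead to a test along these lines. This is only a sketch, not the test that was eventually merged: the function name and the variable name `"theta"` are illustrative, and it assumes an ArviZ version that includes this change and exposes `convert_to_dataset` at the top level.

```python
import arviz as az
import numpy as np
import xarray as xr


def test_named_dataarray_keeps_name():
    # Setup: a named DataArray with the usual (chain, draw) dimensions.
    size = 100
    data = xr.DataArray(
        np.random.randn(1, size), dims=("chain", "draw"), name="theta"
    )

    # Conversion under test.
    dataset = az.convert_to_dataset(data)

    # All asserts grouped at the end, as suggested in the review.
    assert list(dataset.data_vars) == ["theta"]
    assert dataset.chain.shape == (1,)
    assert dataset.draw.shape == (size,)
```

Under pytest this would be collected automatically; it can also be called directly as a plain function.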