Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow xarray.Dataarray input to plots. #1120

Merged
merged 13 commits into from
Mar 26, 2020
9 changes: 9 additions & 0 deletions arviz/data/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
| emcee sampler: Automatically extracts data
| pyro MCMC: Automatically extracts data
| xarray.Dataset: adds to InferenceData as only group
| xarray.DataArray: creates an xarray dataset as the only group, gives the
array an arbitrary name, if name not set
| dict: creates an xarray dataset as the only group
| numpy array: creates an xarray dataset as the only group, gives the
array an arbitrary name
Expand Down Expand Up @@ -97,6 +99,10 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
# Cases that convert to xarray
if isinstance(obj, xr.Dataset):
dataset = obj
elif isinstance(obj, xr.DataArray):
if obj.name is None:
obj.name = 'plot'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the name is None, we should set it to x to follow current convention with numpy arrays.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, true. I will change it.

dataset = obj.to_dataset()
elif isinstance(obj, dict):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif isinstance(obj, np.ndarray):
Expand All @@ -109,6 +115,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
return from_cmdstan(**kwargs)
else:
allowable_types = (
"xarray dataarray",
"xarray dataset",
"dict",
"netcdf filename",
Expand Down Expand Up @@ -149,6 +156,8 @@ def convert_to_dataset(obj, *, group="posterior", coords=None, dims=None):
pystan fit: Automatically extracts data
pymc3 trace: Automatically extracts data
xarray.Dataset: adds to InferenceData as only group
xarray.DataArray: creates an xarray dataset as the only group, gives the
array an arbitrary name, if name not set
dict: creates an xarray dataset as the only group
numpy array: creates an xarray dataset as the only group, gives the
array an arbitrary name
Expand Down