-
-
Notifications
You must be signed in to change notification settings - Fork 393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow xarray.Dataarray input to plots. #1120
Changes from 8 commits
a6c1f97
93f829d
20383ce
bdf87c6
837585e
7b311c7
8b85c90
ebaad2c
1cf8951
33648b8
b1b3189
b9c24af
4a59e64
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -873,3 +873,41 @@ def test_id_conversion_args(self): | |
assert isinstance(inference_data, InferenceData) | ||
assert set(inference_data.posterior.coords["Ivies"].values) == set(IVIES) | ||
assert inference_data.posterior["theta"].dims == ("chain", "draw", "Ivies") | ||
|
||
|
||
|
||
class TestDataArrayToDataset: | ||
def test_1d_dataset(self): | ||
size = 100 | ||
dataset = convert_to_dataset(xr.DataArray(np.random.randn(1, size), dims=('chain', 'draw'))) | ||
assert len(dataset.data_vars) == 1 | ||
|
||
assert dataset.chain.shape == (1, ) | ||
assert dataset.draw.shape == (size, ) | ||
|
||
|
||
def test_nd_to_dataset(self): | ||
shape = (1, 2, 3, 4, 5) | ||
dataset = convert_to_dataset( | ||
xr.DataArray(np.random.randn(*shape), | ||
dims=('chain', 'draw', 'dim_0', 'dim_1', 'dim_2'))) | ||
assert len(dataset.data_vars) == 1 | ||
var_name = list(dataset.data_vars)[0] | ||
|
||
assert dataset.chain.shape == shape[:1] | ||
assert dataset.draw.shape == shape[1:2] | ||
assert dataset[var_name].shape == shape | ||
|
||
def test_nd_to_inference_data(self): | ||
shape = (1, 2, 3, 4, 5) | ||
inference_data = convert_to_inference_data( | ||
xr.DataArray(np.random.randn(*shape), | ||
dims=('chain', 'draw', 'dim_0', 'dim_1', 'dim_2')), group="prior") | ||
assert hasattr(inference_data, "prior") | ||
assert len(inference_data.prior.data_vars) == 1 | ||
var_name = list(inference_data.prior.data_vars)[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit pick, move this up above the asserts so all the asserts are together. That way when reading the tests we can see all the logic in one place and all the asserts in another. Same for test above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah sure. |
||
|
||
assert inference_data.prior.chain.shape == shape[:1] | ||
assert inference_data.prior.draw.shape == shape[1:2] | ||
assert inference_data.prior[var_name].shape == shape | ||
assert repr(inference_data).startswith("Inference data with groups") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this last assert should be removed, it will only fail if the inference data object has not been created, and in this case all asserts before this one would have already failed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well observed. I'll change it |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could set the name of the input dataarray in this test to make sure the name is not overwritten during conversion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I forgot to test the name. I'll make suitable changes