diff --git a/birdman/inference.py b/birdman/inference.py index 15fe528..ab3657c 100644 --- a/birdman/inference.py +++ b/birdman/inference.py @@ -84,21 +84,19 @@ def concatenate_inferences( """ group_list = [] group_list.append([x.posterior for x in inf_list]) - group_list.append([x.sample_stats for x in inf_list]) if "log_likelihood" in inf_list[0].groups(): group_list.append([x.log_likelihood for x in inf_list]) if "posterior_predictive" in inf_list[0].groups(): group_list.append([x.posterior_predictive for x in inf_list]) po_ds = xr.concat(group_list[0], concatenation_name) - ss_ds = xr.concat(group_list[1], concatenation_name) - group_dict = {"posterior": po_ds, "sample_stats": ss_ds} + group_dict = {"posterior": po_ds} if "log_likelihood" in inf_list[0].groups(): - ll_ds = xr.concat(group_list[2], concatenation_name) + ll_ds = xr.concat(group_list[1], concatenation_name) group_dict["log_likelihood"] = ll_ds if "posterior_predictive" in inf_list[0].groups(): - pp_ds = xr.concat(group_list[3], concatenation_name) + pp_ds = xr.concat(group_list[2], concatenation_name) group_dict["posterior_predictive"] = pp_ds all_group_inferences = [] diff --git a/tests/test_inference.py b/tests/test_inference.py index cb55ac0..e7bbe79 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,6 +1,9 @@ import numpy as np +import pytest from birdman import inference as mu +from birdman.default_models import NegativeBinomialSingle +from birdman import ModelIterator class TestToInference: @@ -78,3 +81,31 @@ def test_serial_ppll(self, example_model): nb_data = example_model.fit.stan_variable(v) nb_data = np.array(np.split(nb_data, 4, axis=0)) np.testing.assert_array_almost_equal(nb_data, inf_data) + + +@pytest.mark.parametrize("method", ["mcmc", "vi"]) +def test_concat(table_biom, metadata, method): + tbl = table_biom + md = metadata + + model_iterator = ModelIterator( + table=tbl, + model=NegativeBinomialSingle, + formula="host_common_name", + metadata=md, + ) + + infs = [] + for fname, model in model_iterator: + model.compile_model() + model.fit_model(method, num_draws=100) + infs.append(model.to_inference()) + + inf_concat = mu.concatenate_inferences( + infs, + coords={"feature": tbl.ids("observation")}, + ) + print(inf_concat.posterior) + exp_feat_ids = tbl.ids("observation") + feat_ids = inf_concat.posterior.coords["feature"].to_numpy() + assert (exp_feat_ids == feat_ids).all()