concatenate_inference method broken #90
Here is a slightly modified version of `concatenate_inferences`:

```python
from typing import List

import arviz as az
import xarray as xr


def concatenate_inferences(
    inf_list: List[az.InferenceData],
    coords: dict,
    concatenation_name: str = "feature"
) -> az.InferenceData:
    """Concatenates multiple single feature fits into one object.

    :param inf_list: List of InferenceData objects for each feature
    :type inf_list: List[az.InferenceData]

    :param coords: Coordinates containing concatenation name labels
    :type coords: dict

    :param concatenation_name: Name of feature dimension used when
        concatenating, defaults to "feature"
    :type concatenation_name: str

    :returns: Combined InferenceData object
    :rtype: az.InferenceData
    """
    # sample_stats is skipped on purpose: some fits don't have it.
    # Groups are looked up by name, so a missing log_likelihood group
    # no longer shifts the position of posterior_predictive.
    group_names = ["posterior"]
    for optional_group in ("log_likelihood", "posterior_predictive"):
        if optional_group in inf_list[0].groups():
            group_names.append(optional_group)

    group_dict = {}
    for group in group_names:
        # Stack this group from every fit along the new dimension
        group_ds = xr.concat(
            [getattr(x, group) for x in inf_list], concatenation_name
        )
        # Set concatenation dim coords
        group_dict[group] = group_ds.assign_coords(
            {concatenation_name: coords[concatenation_name]}
        )

    # Build one InferenceData with all groups at once
    return az.InferenceData(**group_dict)
```

Pull request incoming shortly.
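One way to avoid failing on a missing `sample_stats` group is to concatenate only the groups that every fit actually has. The helper below is a minimal sketch; `shared_groups` and `CANDIDATE_GROUPS` are hypothetical names, and the only assumption about each fit is that it has a `groups()` method like `az.InferenceData`. The stand-in objects in the usage lines are purely illustrative.

```python
from types import SimpleNamespace
from typing import Any, List, Sequence

# Hypothetical list of groups the concatenation knows how to handle.
CANDIDATE_GROUPS = ("posterior", "sample_stats", "log_likelihood", "posterior_predictive")


def shared_groups(inf_list: Sequence[Any], candidates: Sequence[str] = CANDIDATE_GROUPS) -> List[str]:
    """Return the candidate groups present in every fit, in candidate order."""
    if not inf_list:
        raise ValueError("inf_list is empty")
    present = set(inf_list[0].groups())
    for inf in inf_list[1:]:
        # Keep only groups that this fit also has
        present &= set(inf.groups())
    if "posterior" not in present:
        raise ValueError("every fit needs a posterior group")
    return [g for g in candidates if g in present]


# Usage with stand-in objects; real code would pass InferenceData objects.
fit_a = SimpleNamespace(groups=lambda: ["posterior", "sample_stats", "log_likelihood"])
fit_b = SimpleNamespace(groups=lambda: ["posterior", "log_likelihood"])
groups_to_concat = shared_groups([fit_a, fit_b])  # ["posterior", "log_likelihood"]
```

The returned list could then drive the `xr.concat` loop instead of a fixed set of groups, so a fit without `sample_stats` simply drops that group from the result.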
gibsramen added a commit that referenced this issue on Oct 4, 2023.
See https://github.com/biocore/BIRDMAn/blob/main/birdman/inference.py#L65
Getting this fixed would be super helpful, since it is much easier to use the xarray ops to compute summary statistics.
Sometimes InferenceData objects don't have a `sample_stats` group, which shouldn't cause a failure.
But even after commenting that out, merging InferenceData objects can raise confusing indexing errors; the inputs should probably be validated ahead of time.
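Checking the inputs before calling `xr.concat` could surface these problems with a clearer message. This is a minimal sketch, assuming the per-feature groups are `xarray.Dataset` objects (it only reads `.data_vars` and `.sizes`) and that `labels` are the values from `coords[concatenation_name]`; `validate_for_concat` is a hypothetical helper, not part of BIRDMAn or ArviZ. Duplicate labels are flagged because they are a likely source of later label-based indexing errors along the new dimension.

```python
from types import SimpleNamespace
from typing import Any, Hashable, List, Sequence


def validate_for_concat(datasets: Sequence[Any], labels: Sequence[Hashable]) -> List[str]:
    """Return problems that would break or confuse concatenation (empty if none)."""
    problems = []
    # One label is needed per dataset, or assign_coords will fail later
    if len(labels) != len(datasets):
        problems.append(f"{len(labels)} labels for {len(datasets)} datasets")
    if not datasets:
        return problems
    ref_vars = set(datasets[0].data_vars)
    ref_sizes = dict(datasets[0].sizes)
    for i, ds in enumerate(datasets[1:], start=1):
        # Every fit should have the same variables...
        if set(ds.data_vars) != ref_vars:
            problems.append(f"dataset {i}: variables {sorted(ds.data_vars)} != {sorted(ref_vars)}")
        # ...and the same dimension sizes (e.g. chains and draws)
        if dict(ds.sizes) != ref_sizes:
            problems.append(f"dataset {i}: sizes {dict(ds.sizes)} != {ref_sizes}")
    if len(set(labels)) != len(labels):
        problems.append("labels are not unique")
    return problems


# Usage with stand-in objects; real code could pass [x.posterior for x in inf_list].
ok = [SimpleNamespace(data_vars={"beta": None}, sizes={"chain": 4, "draw": 100}) for _ in range(2)]
bad = ok + [SimpleNamespace(data_vars={"beta": None}, sizes={"chain": 4, "draw": 50})]
ok_problems = validate_for_concat(ok, ["f1", "f2"])    # []
bad_problems = validate_for_concat(bad, ["f1", "f2"])  # label count + size mismatch
```

Returning a list rather than raising on the first problem lets a caller report every bad fit at once before deciding whether to abort.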