concatenate_inference method broken #90

Closed
mortonjt opened this issue Sep 25, 2023 · 1 comment

mortonjt commented Sep 25, 2023

See https://github.com/biocore/BIRDMAn/blob/main/birdman/inference.py#L65

Getting this fixed would be super helpful, since it is much easier to use the xarray ops to compute summary statistics.
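
As a rough illustration of the kind of xarray reduction this enables, here is a minimal sketch on a synthetic dataset. The variable name `beta`, the feature labels, and the array shapes are all made up for the example and are not taken from BIRDMAn.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for a concatenated posterior group:
# 2 chains, 50 draws, 3 features. "beta" is a hypothetical parameter name.
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {"beta": (("chain", "draw", "feature"), rng.normal(size=(2, 50, 3)))},
    coords={"feature": ["f0", "f1", "f2"]},
)

# Reduce over the sampling dimensions to get one summary value per feature.
beta_mean = posterior["beta"].mean(dim=["chain", "draw"])
beta_std = posterior["beta"].std(dim=["chain", "draw"])
```

Once every feature lives along a single labeled dimension, each summary is a one-line reduction rather than a loop over per-feature objects.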

Sometimes inference objects don't have a sample_stats group, and that shouldn't cause a failure.

But even after commenting that out, merging inference data objects can still raise odd indexing errors, so the inputs should probably be validated ahead of time.

Traceback (most recent call last):
  File "/home/centos/birdman/summarize.py", line 136, in <module>
    samples = concatenate_inferences(inf_list, 'y_predict', 'log_lhood', coords)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/centos/birdman/summarize.py", line 116, in concatenate_inferences
    {concatenation_name: coords[concatenation_name],
                         ~~~~~~^^^^^^^^^^^^^^^^^^^^
KeyError: 'feature'
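
One possible shape for that up-front validation, as a sketch only: the helper name `validate_concat_inputs` is hypothetical and not part of BIRDMAn. It assumes each fit exposes a `groups()` method, as `az.InferenceData` does, and that `coords` maps the concatenation name to one label per fit.

```python
from typing import Any, Dict, List, Sequence


def validate_concat_inputs(
    inf_list: Sequence[Any],
    coords: Dict[str, Sequence],
    concatenation_name: str = "feature",
) -> List[str]:
    """Return the groups shared by every fit, raising early on bad inputs.

    Hypothetical helper; each item in inf_list only needs a groups() method.
    """
    if not inf_list:
        raise ValueError("inf_list is empty")

    # Catch the missing-key case before any concatenation work is done.
    if concatenation_name not in coords:
        raise KeyError(
            f"coords has no {concatenation_name!r} entry; found {sorted(coords)}"
        )

    # One label is needed per fit being concatenated.
    n_labels = len(coords[concatenation_name])
    if n_labels != len(inf_list):
        raise ValueError(
            f"{n_labels} labels for {concatenation_name!r} "
            f"but {len(inf_list)} fits"
        )

    # Keep only groups present in every fit, in the first fit's order,
    # so an absent sample_stats group is skipped instead of failing.
    shared = [
        g for g in inf_list[0].groups()
        if all(g in x.groups() for x in inf_list)
    ]
    if "posterior" not in shared:
        raise ValueError("every fit needs a posterior group")
    return shared
```

Calling this before concatenating turns the late `KeyError: 'feature'` above into an error that lists which keys `coords` actually contains.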

mortonjt commented

Here is a slightly modified version of concatenate_inferences

from typing import List

import arviz as az
import xarray as xr


def concatenate_inferences(
    inf_list: List[az.InferenceData],
    coords: dict,
    concatenation_name: str = "feature"
) -> az.InferenceData:
    """Concatenates multiple single feature fits into one object.

    :param inf_list: List of InferenceData objects for each feature
    :type inf_list: List[az.InferenceData]

    :param coords: Coordinates containing concatenation name labels
    :type coords: dict

    :param concatenation_name: Name of feature dimension used when
        concatenating, defaults to "feature"
    :type concatenation_name: str

    :returns: Combined InferenceData object
    :rtype: az.InferenceData
    """
    # sample_stats is skipped on purpose: some fits do not have it.
    # Optional groups are looked up by name, so a missing log_likelihood
    # no longer shifts the position of posterior_predictive.
    group_names = ["posterior"]
    for optional in ("log_likelihood", "posterior_predictive"):
        if optional in inf_list[0].groups():
            group_names.append(optional)

    all_group_inferences = []
    for group in group_names:
        # Stack the per-feature datasets along the concatenation dimension
        group_ds = xr.concat(
            [getattr(x, group) for x in inf_list], concatenation_name
        )
        # Set concatenation dim coords
        group_ds = group_ds.assign_coords(
            {concatenation_name: coords[concatenation_name]}
        )

        group_inf = az.InferenceData(**{group: group_ds})  # hacky
        all_group_inferences.append(group_inf)

    return az.concat(*all_group_inferences)

Pull request incoming shortly

gibsramen added a commit that referenced this issue Oct 4, 2023