Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for PyJAGS in ArviZ #1219

Merged
merged 60 commits into from
Jun 9, 2020
Merged

Conversation

michaelnowotny
Copy link
Contributor

@michaelnowotny michaelnowotny commented Jun 2, 2020

This pull request adds functionality to translate posterior samples generated by PyJAGS to an ArviZ inference data object via a new public function 'from_pyjags'. A set of unit tests checks that a round trip translation from PyJAGS to ArviZ and back retains the same information. A new section has been added to the documentation notebook illustrating the use ArviZ in conjunction with PyJAGS.

Description

Checklist

  • Follows official PR format
  • Includes a sample plot to visually illustrate the changes (only for plot-related functions)
  • New features are properly documented (with an example if appropriate)?
  • Includes new or updated tests to cover the new feature
  • Code style correct (follows pylint and black guidelines)
  • Changes are listed in changelog

@StanczakDominik
Copy link
Contributor

With respect to the Pylint failure: StackOverflow at https://stackoverflow.com/a/54045715/4417567 and existing arviz code such as

# pylint: disable=no-member, invalid-name, redefined-outer-name
both seem to suggest disabling that particular linting. There's another workaround at https://stackoverflow.com/a/57015304/4417567 with the name parameter to fixture, but it seems to be rather more complicated.

@ahartikainen
Copy link
Contributor

Hi, thanks for the PR.

I think io_pyjags.py could follow similar structure as other _io.py files. This will make it easier to expand the functionality and add special handling for different groups.

For e.g

https://github.com/arviz-devs/arviz/blob/master/arviz/data/io_pystan.py

Create a converter class

class PyStanConverter:
    """Encapsulate PyStan specific logic."""

    def __init__(
        self,
        *,
        posterior=None,
        posterior_predictive=None,
        predictions=None,
        prior=None,
        prior_predictive=None,
        observed_data=None,
        constant_data=None,
        predictions_constant_data=None,
        log_likelihood=None,
        coords=None,
        dims=None,
        save_warmup=None,
    ):
        self.posterior = posterior
        self.posterior_predictive = posterior_predictive
        self.predictions = predictions
        self.prior = prior
        self.prior_predictive = prior_predictive
        self.observed_data = observed_data
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.log_likelihood = log_likelihood
        self.coords = coords
        self.dims = dims
        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup

        import pystan  # pylint: disable=import-error

        self.pystan = pystan

Then each group is handled with a class method which return xarray datasets

    @requires("posterior")
    def posterior_to_xarray(self):
        """Extract posterior samples from fit."""
        posterior = self.posterior
       ....
        data, data_warmup = get_draws(posterior, ignore=ignore, warmup=self.save_warmup)

        return (
            dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims),
            dict_to_dataset(data_warmup, library=self.pystan, coords=self.coords, dims=self.dims),
        )

Then finally a method to transform class to InferenceData

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.
        Note that if groups can not be created (i.e., there is no `fit`, so
        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
        will not have those groups.
        """
        data_dict = self.data_to_xarray()
        return InferenceData(
            save_warmup=self.save_warmup,
            **{
                "posterior": self.posterior_to_xarray(),
                "sample_stats": self.sample_stats_to_xarray(),
                "log_likelihood": self.log_likelihood_to_xarray(),
                "posterior_predictive": self.posterior_predictive_to_xarray(),
                "predictions": self.predictions_to_xarray(),
                "prior": self.prior_to_xarray(),
                "sample_stats_prior": self.sample_stats_prior_to_xarray(),
                "prior_predictive": self.prior_predictive_to_xarray(),
                **({} if data_dict is None else data_dict),
            },
        )

To extract library specific data can use "external" functions as is the case in PyStan which return dictionaries containing data

def get_draws(fit, variables=None, ignore=None, warmup=False):
    """Extract draws from PyStan fit."""
    ....
    return data, data_warmup
def get_sample_stats(fit, warmup=False):
    """Extract sample stats from PyStan fit."""
    ...
    return data, data_warmup
```

And `from_xyz` function then can be defined as 

def from_pystan(
    posterior=None,
    *,
    posterior_predictive=None,
    predictions=None,
    prior=None,
    ...
):
    """Convert PyStan data into an InferenceData object.
    For a usage example read the
    :doc:`Cookbook section on from_pystan </notebooks/InferenceDataCookbook>`
    Parameters
    ----------
    posterior : StanFit4Model or stan.fit.Fit
        PyStan fit object for posterior.
    posterior_predictive : str, a list of str
        Posterior predictive samples for the posterior.
    predictions : str, a list of str
        Out-of-sample predictions for the posterior.
    prior : StanFit4Model or stan.fit.Fit
        PyStan fit object for prior.
    ...
    Returns
    -------
    InferenceData object
    """
        return PyStanConverter(
            posterior=posterior,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            prior=prior,
            ...
        ).to_inference_data()
```

@ahartikainen
Copy link
Contributor

Also, notice that PyStan returns also the warmup iterations, but other libs can return only posterior samples. So the structure might be a bit different.

- removed dependence on az.convert_to_inference_data
- added support for prior distribution
@michaelnowotny
Copy link
Contributor Author

Thank you for the input Ari! As you suspected, the solution looks slightly different from the PyStan implementation. All that JAGS sample returns is a dictionary mapping variable names to chains. There is no fit object as in PyStan.

Warm-up iterations can be handled by drawing samples and simply not saving them and then drawing the actual samples that are kept. With ArviZ it is now possible to simply draw both warm-up and actual samples in one step using PyJAGS and letting the from_pyjags function know how many of those were warmup samples.

Copy link
Contributor

@ahartikainen ahartikainen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.

Added some comments

arviz/data/io_pyjags.py Outdated Show resolved Hide resolved
arviz/data/io_pyjags.py Outdated Show resolved Hide resolved
arviz/data/io_pyjags.py Outdated Show resolved Hide resolved
arviz/data/io_pyjags.py Outdated Show resolved Hide resolved
@OriolAbril
Copy link
Member

Could one of you take the notebook and rerun this particular section once everything else is set for release?

We are rerunning all notebooks in #1217 so this should be no problem.

Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given that there is the try except block when importing pyjags, do we want to maybe test with pyjags installed on latest external build and no pyjags in special?

I can help with Azure to do this and also adding even another test to check library attrs are correctly stored.

arviz/data/io_pyjags.py Outdated Show resolved Hide resolved
arviz/data/io_pyjags.py Outdated Show resolved Hide resolved
arviz/tests/external_tests/test_data_pyjags.py Outdated Show resolved Hide resolved
@OriolAbril OriolAbril closed this Jun 6, 2020
@OriolAbril
Copy link
Member

Sorry, big fingers on mobile

@OriolAbril OriolAbril reopened this Jun 6, 2020
Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, I don't want to block merging.

I am still curious about warmup_prior existence, does this mean that jags does not use forward sampling to get prior samples? Also why are posterior and prior warmups the same?

Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nits on documentation

arviz/data/io_pyjags.py Show resolved Hide resolved
doc/api.rst Outdated Show resolved Hide resolved
Copy link
Contributor

@ahartikainen ahartikainen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

At some point I think we should use pyjags to create a prior and posterior for tests.

But that can be done later.

michaelnowotny and others added 2 commits June 8, 2020 11:32
Co-authored-by: Oriol Abril <oriol.abril.pla@gmail.com>
Co-authored-by: Oriol Abril <oriol.abril.pla@gmail.com>
@michaelnowotny
Copy link
Contributor Author

LGTM, I don't want to block merging.

I am still curious about warmup_prior existence, does this mean that jags does not use forward sampling to get prior samples? Also why are posterior and prior warmups the same?

Good question. As far as I can see JAGS itself has no functionality to explicitly sample from the prior. However one can simply create a model that leaves out the likelihood (and doesn't use any data). I have included such a separate prior model in the eight schools example in the inference data cookbook. Theoretically, one could directly sample hierarchically from the graph of (hyper)priors. One would not need to resort to Metropolis-Hastings (or even Gibbs) sampling. But I do not know whether JAGS realized that and samples directly. The possibility of warmup iterations seems like more robust move.

More modern MCMC packages such as PyMC3 have features to sample from all kinds of distributions associated with Bayesian analysis including predictive distributions and so it makes sense for ArviZ to support those. JAGS is basically a Gibbs sampler that draws from a target distribution described by a graphical model. It is not a Turing complete language like Stan that lets you generate quantities and solve ODEs while sampling.

@OriolAbril OriolAbril merged commit d8b14cf into arviz-devs:master Jun 9, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants