From 0963976237c03e7f223c9af72fda718f523a8f8a Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Fri, 5 Apr 2024 00:51:21 +0300 Subject: [PATCH 1/6] Added Metropolis-Hastings algorithm based resampler (pyro.infer.predictive.MHResampler) that converts weighed samples into equally weighed samples. --- pyro/infer/__init__.py | 3 +- pyro/infer/predictive.py | 107 +++++++++++++++++++++++++++++++++++++-- pyro/infer/util.py | 20 ++++++++ 3 files changed, 125 insertions(+), 5 deletions(-) diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index 6934bd29fe..3a6a37ce5b 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -12,7 +12,7 @@ from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.nuts import NUTS from pyro.infer.mcmc.rwkernel import RandomWalkKernel -from pyro.infer.predictive import Predictive, WeighedPredictive +from pyro.infer.predictive import MHResampler, Predictive, WeighedPredictive from pyro.infer.renyi_elbo import RenyiELBO from pyro.infer.rws import ReweightedWakeSleep from pyro.infer.smcfilter import SMCFilter @@ -44,6 +44,7 @@ "JitTraceMeanField_ELBO", "JitTrace_ELBO", "MCMC", + "MHResampler", "NUTS", "Predictive", "RandomWalkKernel", diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index ea89aff5e5..8244fd7274 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -from dataclasses import dataclass +from dataclasses import dataclass, fields from functools import reduce -from typing import List, Union +from typing import Callable, List, Union import torch import pyro import pyro.poutine as poutine from pyro.infer.importance import LogWeightsMixin -from pyro.infer.util import plate_log_prob_sum +from pyro.infer.util import CloneMixin, plate_log_prob_sum from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites @@ -320,7 +320,7 @@ def get_vectorized_trace(self, *args, **kwargs): @dataclass(frozen=True, eq=False) -class WeighedPredictiveResults(LogWeightsMixin): +class WeighedPredictiveResults(LogWeightsMixin, CloneMixin): """ Return value of call to instance of :class:`WeighedPredictive`. """ @@ -450,3 +450,102 @@ def forward(self, *args, **kwargs): guide_log_prob=guide_log_prob, model_log_prob=model_log_prob, ) + + +class MHResampler(torch.nn.Module): + """ + Resampler for weighed samples that is based on the Metropolis-Hastings algorithm. + """ + + def __init__( + self, + sampler: Callable, + source_samples_slice: slice = slice(0), + stored_samples_slice: slice = slice(0), + ): + super().__init__() + self.sampler = sampler + self.samples = None + self.transition_count = torch.tensor(0, dtype=torch.long) + self.source_samples = [] + self.source_samples_slice = source_samples_slice + self.stored_samples = [] + self.stored_samples_slice = stored_samples_slice + + def forward(self, *args, **kwargs): + """ + Perform single resampling step. 
+        """
+        with torch.no_grad():
+            new_samples = self.sampler(*args, **kwargs)
+            # Store samples
+            self.source_samples.append(new_samples)
+            self.source_samples = self.source_samples[self.source_samples_slice]
+            if self.samples is None:
+                # First set of samples
+                self.samples = new_samples.clone()
+                self.transition_count = torch.zeros_like(
+                    new_samples.log_weights, dtype=torch.long
+                )
+            else:
+                # Apply Metropolis-Hastings algorithm
+                prob = torch.clamp(
+                    new_samples.log_weights - self.samples.log_weights, max=0.0
+                ).exp()
+                idx = torch.rand(*prob.shape) <= prob
+                self.transition_count[idx] += 1
+                for field_desc in fields(self.samples):
+                    field, new_field = getattr(self.samples, field_desc.name), getattr(
+                        new_samples, field_desc.name
+                    )
+                    if isinstance(field, dict):
+                        for key in field:
+                            field[key][idx] = new_field[key][idx]
+                    else:
+                        field[idx] = new_field[idx]
+            self.stored_samples.append(self.samples.clone())
+            self.stored_samples = self.stored_samples[self.stored_samples_slice]
+        return self.samples
+
+    def get_min_sample_transition_count(self):
+        """
+        Return the transition count of the sample with the minimal number of transitions.
+        """
+        return self.transition_count.min()
+
+    def get_total_transition_count(self):
+        """
+        Return total number of transitions.
+        """
+        return self.transition_count.sum()
+
+    def get_source_samples(self):
+        """
+        Return source samples that were the input to the Metropolis-Hastings algorithm.
+        """
+        return self.get_samples(self.source_samples)
+
+    def get_stored_samples(self):
+        """
+        Return stored samples that were the output of the Metropolis-Hastings algorithm.
+        """
+        return self.get_samples(self.stored_samples)
+
+    def get_samples(self, samples):
+        """
+        Return concatenated samples collected during execution of the Metropolis-Hastings algorithm.
+        """
+        retval = dict()
+        for field_desc in fields(self.samples):
+            field_name, value = field_desc.name, getattr(self.samples, field_desc.name)
+            if isinstance(value, dict):
+                retval[field_name] = dict()
+                for key in value:
+                    retval[field_name][key] = torch.cat(
+                        [getattr(sample, field_name)[key] for sample in samples]
+                    )
+            else:
+                retval[field_name] = torch.cat(
+                    [getattr(sample, field_name) for sample in samples]
+                )
+        return self.samples.__class__(**retval)
diff --git a/pyro/infer/util.py b/pyro/infer/util.py
index 13e1d9e12f..2efbb60ed8 100644
--- a/pyro/infer/util.py
+++ b/pyro/infer/util.py
@@ -5,6 +5,7 @@
 import numbers
 from collections import Counter, defaultdict
 from contextlib import contextmanager
+from dataclasses import fields
 
 import torch
 from opt_einsum import shared_intermediates
@@ -358,3 +359,22 @@ def plate_log_prob_sum(trace: Trace, plate_symbol: str) -> torch.Tensor:
         [site["packed"]["log_prob"]],
     )
     return log_prob_sum
+
+
+class CloneMixin:
+    """
+    Mixin class that adds a ``.clone`` method to ``@dataclasses.dataclass`` decorated classes
+    that are made up of ``torch.Tensor`` fields.
+    """
+
+    def clone(self):
+        retval = dict()
+        for field_desc in fields(self):
+            field_name, value = field_desc.name, getattr(self, field_desc.name)
+            if isinstance(value, dict):
+                retval[field_name] = dict()
+                for key in value:
+                    retval[field_name][key] = value[key].clone()
+            else:
+                retval[field_name] = value.clone()
+        return self.__class__(**retval)

From 09e6aab9bd18810e6d6d9367932851662f7d8290 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 5 Apr 2024 16:07:42 +0300
Subject: [PATCH 2/6] Add tests for the pyro.infer.predictive.MHResampler
 weighed samples resampler.

--- tests/infer/test_predictive.py | 72 +++++++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index 319a1196dd..05ab1ceb63 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -1,6 +1,8 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +import logging + import pytest import torch @@ -8,8 +10,9 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine -from pyro.infer import SVI, Predictive, Trace_ELBO, WeighedPredictive +from pyro.infer import SVI, MHResampler, Predictive, Trace_ELBO, WeighedPredictive from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal +from pyro.ops.stats import quantile, weighed_quantile from tests.common import assert_close @@ -39,9 +42,18 @@ def beta_guide(num_trials): pyro.sample("phi", phi_posterior) -@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive]) +@pytest.mark.parametrize( + "predictive, num_svi_steps, test_unweighed_convergence", + [ + (Predictive, 5000, None), + (WeighedPredictive, 5000, True), + (WeighedPredictive, 1000, False), + ], +) @pytest.mark.parametrize("parallel", [False, True]) -def test_posterior_predictive_svi_manual_guide(parallel, predictive): +def test_posterior_predictive_svi_manual_guide( + parallel, predictive, num_svi_steps, test_unweighed_convergence +): true_probs = torch.ones(5) * 0.7 num_trials = ( torch.ones(5) * 400 @@ -51,9 +63,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): conditioned_model = poutine.condition(model, data={"obs": num_success}) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=3.0)), elbo) - for i in range( - 5000 - ): # Increased to 5000 from 1000 in order for guide optimization to converge + for i in range(num_svi_steps): svi.step(num_trials) posterior_predictive = predictive( model, @@ -70,10 +80,52 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): ) marginal_return_vals = weighed_samples.samples["_RETURN"] assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape - # Weights should be uniform as the guide has the same distribution as the model - assert weighed_samples.log_weights.std() < 0.6 - # Effective sample size should be close to actual number of samples taken from the guide - assert weighed_samples.get_ESS() > 0.8 * num_samples + # Resample weighed samples + resampler = MHResampler(posterior_predictive) + for resampling_count in range(10): + resampled_weighed_samples = resampler( + num_trials, model_guide=conditioned_model + ) + resampled_marginal_return_vals = resampled_weighed_samples.samples["_RETURN"] + # Calculate CDF quantiles + quantile_test_point = 0.95 + quantile_test_point_value = quantile( + marginal_return_vals, [quantile_test_point] + )[0] + weighed_quantile_test_point_value = weighed_quantile( + marginal_return_vals, [quantile_test_point], weighed_samples.log_weights + )[0] + resampled_quantile_tesT_point_value = quantile( + resampled_marginal_return_vals, [quantile_test_point] + )[0] + logging.info( + "Unweighed quantile at test point is: " + str(quantile_test_point_value) + ) + logging.info( + "Weighed quantile at test point is: " + + str(weighed_quantile_test_point_value) + ) + logging.info( + "Resampled quantile at test point is: " + + str(resampled_quantile_tesT_point_value) + ) + # Weighed and resampled 
quantiles should match
+        assert_close(
+            weighed_quantile_test_point_value,
+            resampled_quantile_tesT_point_value,
+            rtol=0.01,
+        )
+        if test_unweighed_convergence:
+            # Weights should be uniform as the guide has the same distribution as the model
+            assert weighed_samples.log_weights.std() < 0.6
+            # Effective sample size should be close to actual number of samples taken from the guide
+            assert weighed_samples.get_ESS() > 0.8 * num_samples
+            # Weighed and unweighed quantiles should match if guide converged to true model
+            assert_close(
+                quantile_test_point_value,
+                resampled_quantile_tesT_point_value,
+                rtol=0.01,
+            )
     assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)

From eae323b7693fa009da8a2ee12c1b7d0085107c51 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 5 Apr 2024 17:56:52 +0300
Subject: [PATCH 3/6] Add documentation for the
 pyro.infer.predictive.MHResampler weighed samples resampler.

---
 pyro/infer/predictive.py | 27 +++++++++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 8244fd7274..f2715da381 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -453,8 +453,30 @@ def forward(self, *args, **kwargs):
 
 
 class MHResampler(torch.nn.Module):
-    """
-    Resampler for weighed samples that is based on the Metropolis-Hastings algorithm.
+    r"""
+    Resampler for weighed samples that generates equally weighed samples from the distribution
+    specified by the weighed samples returned by ``sampler``.
+
+    The resampling is based on the Metropolis-Hastings algorithm.
+    Given an initial sample :math:`x`, subsequent samples are generated by:
+
+    - Sample from the ``guide`` a new sample candidate :math:`x'` with probability :math:`g(x')`.
+    - Calculate an acceptance probability
+      :math:`A(x', x) = \min\left(1, \frac{P(x')}{P(x)} \frac{g(x)}{g(x')}\right)`
+      with :math:`P` being the ``model``.
+    - With probability :math:`A(x', x)` accept the new sample candidate :math:`x'`
+      as the next sample, otherwise keep the current sample :math:`x` as the next sample.
+
+    The above is the Metropolis-Hastings algorithm with the new sample candidate
+    proposal distribution being equal to the ``guide`` and independent of the
+    current sample such that :math:`g(x')=g(x' \mid x)`.
+
+    In case the ``guide`` perfectly tracks the ``model`` this sampler will do nothing
+    as the acceptance probability :math:`A(x', x)` will always be one.
+
+    :param callable sampler: When called returns :class:`WeighedPredictiveResults`.
+    :param slice source_samples_slice: Select source samples for storage (default is `slice(0)`, i.e. none).
+    :param slice stored_samples_slice: Select output samples for storage (default is `slice(0)`, i.e. none).
     """
 
     def __init__(
@@ -475,6 +497,7 @@ def __init__(
     def forward(self, *args, **kwargs):
         """
         Perform single resampling step.
+        Returns :class:`WeighedPredictiveResults`.
         """
         with torch.no_grad():
             new_samples = self.sampler(*args, **kwargs)

From 9b83ace8c550613a7ec8e73163ac63c2241b9720 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Sun, 7 Apr 2024 16:19:47 +0300
Subject: [PATCH 4/6] Add notes on pyro.infer.predictive.MHResampler behavior
 to the documentation.

---
 pyro/infer/predictive.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index f2715da381..fa0298066b 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -471,12 +471,23 @@ class MHResampler(torch.nn.Module):
     proposal distribution being equal to the ``guide`` and independent of the
     current sample such that :math:`g(x')=g(x' \mid x)`.
 
-    In case the ``guide`` perfectly tracks the ``model`` this sampler will do nothing
-    as the acceptance probability :math:`A(x', x)` will always be one.
-
     :param callable sampler: When called returns :class:`WeighedPredictiveResults`.
     :param slice source_samples_slice: Select source samples for storage (default is `slice(0)`, i.e. none).
     :param slice stored_samples_slice: Select output samples for storage (default is `slice(0)`, i.e. none).
+
+    .. _mhsampler-behavior:
+
+    **Notes on Sampler Behavior:**
+
+    - In case the ``guide`` perfectly tracks the ``model`` this sampler will do nothing
+      as the acceptance probability :math:`A(x', x)` will always be one.
+    - Furthermore, if the guide is approximately separable, i.e. :math:`g(z_A, z_B) \approx g_A(z_A) g_B(z_B)`,
+      with :math:`g_A(z_A)` perfectly tracking the ``model`` and :math:`g_B(z_B)` poorly tracking the ``model``,
+      quantiles of :math:`z_A` calculated from samples taken from :class:`MHResampler` will have much lower
+      variance than quantiles of :math:`z_A` calculated by using :any:`weighed_quantile`, as the effective sample size
+      of the calculation using :any:`weighed_quantile` will be low due to :math:`g_B(z_B)` poorly tracking
+      the ``model``, whereas when using :class:`MHResampler` the poor ``model`` tracking of :math:`g_B(z_B)` has
+      negligible effect on the effective sample size of :math:`z_A` samples.
     """
 
     def __init__(

From 98d6692db4cee2e6d16e1181cb01b5c0e9dfbb4a Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Tue, 9 Apr 2024 11:12:57 +0300
Subject: [PATCH 5/6] Add example to docstring of
 pyro.infer.predictive.MHResampler.

---
 pyro/infer/predictive.py       | 45 ++++++++++++++++++++++++++++++++++
 tests/infer/test_predictive.py |  8 +++---
 2 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index fa0298066b..43e90d1bda 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -475,6 +475,51 @@ class MHResampler(torch.nn.Module):
     :param slice source_samples_slice: Select source samples for storage (default is `slice(0)`, i.e. none).
     :param slice stored_samples_slice: Select output samples for storage (default is `slice(0)`, i.e. none).
 
+    The typical use case of :class:`MHResampler` would be to convert weighed samples
+    generated by :class:`WeighedPredictive` into equally weighed samples from the target distribution.
+
+    Example::
+
+        def model():
+            ...
+
+        def guide():
+            ...
+
+        def conditioned_model():
+            ...
+
+        # Fit guide
+        elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
+        svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=3.0)), elbo)
+        for i in range(num_svi_steps):
+            svi.step()
+
+        # Create callable that returns weighed samples
+        posterior_predictive = WeighedPredictive(model,
+                                                 guide=beta_guide,
+                                                 num_samples=num_samples,
+                                                 parallel=parallel,
+                                                 return_sites=["_RETURN"])
+
+        prob = 0.95
+
+        weighed_samples = posterior_predictive(model_guide=conditioned_model)
+        # Calculate quantile directly from weighed samples
+        weighed_samples_quantile = weighed_quantile(weighed_samples.samples['_RETURN'],
+                                                    [prob],
+                                                    weighed_samples.log_weights)[0]
+
+        resampler = MHResampler(posterior_predictive)
+        for resampling_count in range(10):
+            resampled_weighed_samples = resampler(model_guide=conditioned_model)
+            # Calculate quantile from resampled weighed samples (samples are equally weighed)
+            resampled_weighed_samples_quantile = quantile(resampled_weighed_samples.samples["_RETURN"],
+                                                          [prob])[0]
+
+        # Quantiles calculated using both methods should match
+        assert_close(weighed_samples_quantile, resampled_weighed_samples_quantile, rtol=0.01)
+
     .. _mhsampler-behavior:
 
     **Notes on Sampler Behavior:**
diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py
index 05ab1ceb63..b6f608343c 100644
--- a/tests/infer/test_predictive.py
+++ b/tests/infer/test_predictive.py
@@ -95,7 +95,7 @@ def test_posterior_predictive_svi_manual_guide(
         weighed_quantile_test_point_value = weighed_quantile(
             marginal_return_vals, [quantile_test_point], weighed_samples.log_weights
         )[0]
-        resampled_quantile_tesT_point_value = quantile(
+        resampled_quantile_test_point_value = quantile(
            resampled_marginal_return_vals, [quantile_test_point]
         )[0]
         logging.info(
@@ -107,12 +107,12 @@ def test_posterior_predictive_svi_manual_guide(
         )
         logging.info(
             "Resampled quantile at test point is: "
-            + str(resampled_quantile_tesT_point_value)
+            + str(resampled_quantile_test_point_value)
         )
         # Weighed and resampled quantiles should match
         assert_close(
             weighed_quantile_test_point_value,
-            resampled_quantile_tesT_point_value,
+            resampled_quantile_test_point_value,
             rtol=0.01,
         )
         if test_unweighed_convergence:
@@ -123,7 +123,7 @@ def test_posterior_predictive_svi_manual_guide(
             # Weighed and unweighed quantiles should match if guide converged to true model
             assert_close(
                 quantile_test_point_value,
-                resampled_quantile_tesT_point_value,
+                resampled_quantile_test_point_value,
                 rtol=0.01,
             )
         assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)

From e52b512d6b4d4601f30984dcdc7771d92b1eddc4 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 12 Apr 2024 23:43:59 +0300
Subject: [PATCH 6/6] Elaborated and fixed documentation of
 pyro.infer.predictive.MHResampler.

---
 pyro/infer/predictive.py       | 11 +++++++++--
 tests/infer/test_predictive.py |  3 ++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 43e90d1bda..e30099c85e 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -477,6 +477,12 @@ class MHResampler(torch.nn.Module):
 
     The typical use case of :class:`MHResampler` would be to convert weighed samples
     generated by :class:`WeighedPredictive` into equally weighed samples from the target distribution.
+    Each time an instance of :class:`MHResampler` is called, it returns a new set of samples. The samples
+    generated by the first call are distributed according to the ``guide``, and with each subsequent
+    call the distribution of the samples becomes closer to that of the posterior predictive
+    distribution. It might take some experimentation to find out how many calls to an instance of
+    :class:`MHResampler` are needed in each case in order to get close enough to the posterior
+    predictive distribution.
 
     Example::
 
@@ -497,7 +503,7 @@ def conditioned_model():
 
         # Create callable that returns weighed samples
         posterior_predictive = WeighedPredictive(model,
-                                                 guide=beta_guide,
+                                                 guide=guide,
                                                  num_samples=num_samples,
                                                  parallel=parallel,
                                                  return_sites=["_RETURN"])
@@ -511,7 +517,8 @@ def conditioned_model():
                                                     weighed_samples.log_weights)[0]
 
         resampler = MHResampler(posterior_predictive)
-        for resampling_count in range(10):
+        num_mh_steps = 10
+        for mh_step_count in range(num_mh_steps):
             resampled_weighed_samples = resampler(model_guide=conditioned_model)
             # Calculate quantile from resampled weighed samples (samples are equally weighed)
             resampled_weighed_samples_quantile = quantile(resampled_weighed_samples.samples["_RETURN"],
diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py
index b6f608343c..ca155ed2fd 100644
--- a/tests/infer/test_predictive.py
+++ b/tests/infer/test_predictive.py
@@ -82,7 +82,8 @@ def test_posterior_predictive_svi_manual_guide(
     assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
     # Resample weighed samples
     resampler = MHResampler(posterior_predictive)
-    for resampling_count in range(10):
+    num_mh_steps = 10
+    for mh_step_count in range(num_mh_steps):
         resampled_weighed_samples = resampler(
             num_trials, model_guide=conditioned_model
         )
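
The acceptance rule that ``MHResampler.forward`` applies above, namely accepting each candidate with
probability ``min(1, w_new / w_current)`` computed from the importance log-weights, can be studied on
its own outside of Pyro. The following is a minimal self-contained sketch of that independent-proposal
Metropolis-Hastings update; it is not part of the patch series, and the toy ``Normal`` target and
proposal, the helper name ``sample_from_proposal``, and the step count are assumptions made for the
example::

    import torch

    torch.manual_seed(0)
    normal = torch.distributions.Normal

    def sample_from_proposal(n):
        # Proposal g = Normal(0, 2), target p = Normal(1, 1); the importance
        # log-weight of each sample is log p(x) - log g(x).
        x = normal(0.0, 2.0).sample((n,))
        return x, normal(1.0, 1.0).log_prob(x) - normal(0.0, 2.0).log_prob(x)

    n = 10000
    samples, log_weights = sample_from_proposal(n)
    for _ in range(20):  # a few MH steps, as in the num_mh_steps loops above
        new_samples, new_log_weights = sample_from_proposal(n)
        # Accept each candidate with probability min(1, w_new / w_current),
        # computed in log space exactly as in MHResampler.forward.
        accept_prob = torch.clamp(new_log_weights - log_weights, max=0.0).exp()
        accept = torch.rand(n) <= accept_prob
        samples = torch.where(accept, new_samples, samples)
        log_weights = torch.where(accept, new_log_weights, log_weights)

    # The batch now holds approximately equally weighed samples from the
    # target, so plain unweighed statistics apply; expect mean and std near 1.
    print(samples.mean().item(), samples.std().item())

The batch-wide vectorization mirrors the patch: each of the ``n`` positions is an independent chain,
which is why a handful of steps over the whole batch already yields usable equally weighed samples.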
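The separable-guide note introduced in PATCH 4/6 can be checked numerically in the same style. The
sketch below is again illustrative only; the factorized toy guide, the eight poorly tracked
dimensions, and the helper name ``draw`` are assumptions chosen to make the effect visible. The
poorly tracking factor :math:`z_B` makes the effective sample size of the raw weighed samples tiny,
while quantiles of :math:`z_A` computed from the resampled batch stay accurate::

    import torch

    torch.manual_seed(0)
    normal = torch.distributions.Normal

    def draw(n):
        # z_A: guide matches the model exactly, contributing nothing to the
        # log-weight.  z_B: 8 independent dimensions where the guide
        # Normal(0, 3) tracks the model Normal(0, 1) poorly and dominates
        # the weights.
        z_a = torch.randn(n)
        z_b = 3.0 * torch.randn(n, 8)
        log_w = (normal(0.0, 1.0).log_prob(z_b) - normal(0.0, 3.0).log_prob(z_b)).sum(-1)
        return z_a, z_b, log_w

    n = 10000
    z_a, z_b, log_w = draw(n)

    # The effective sample size of the raw weighed samples collapses due to z_B.
    w = torch.softmax(log_w, dim=0)
    print(f"ESS = {(1.0 / (w ** 2).sum()).item():.0f} out of {n}")

    # The weighed 95% quantile of z_A leans on the few high-weight samples,
    # so it is a high-variance estimate.
    order = torch.argsort(z_a)
    cum_w = torch.cumsum(w[order], dim=0)
    print(z_a[order][torch.searchsorted(cum_w, torch.tensor(0.95))].item())

    # Independent-proposal MH steps on the joint sample.  Whether a candidate
    # is accepted or not, every kept z_A stays exactly model-distributed, so
    # the resampled batch gives a low-variance z_A quantile despite the bad
    # z_B guide.
    for _ in range(30):
        new_a, new_b, new_log_w = draw(n)
        accept = torch.rand(n) <= torch.clamp(new_log_w - log_w, max=0.0).exp()
        z_a = torch.where(accept, new_a, z_a)
        z_b = torch.where(accept.unsqueeze(-1), new_b, z_b)
        log_w = torch.where(accept, new_log_w, log_w)

    # True Normal(0, 1) 95% quantile is about 1.645.
    print(z_a.quantile(0.95).item())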