From bf754a637db33d891eb765a1015e5141826fa6e3 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 21 Mar 2022 11:10:21 -0400 Subject: [PATCH 1/3] Implement DiscreteHMM.sample() --- pyro/distributions/hmm.py | 27 +++++++++++++++++++++++++++ tests/distributions/test_hmm.py | 14 ++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py index 28e065429a..241dc650ed 100644 --- a/pyro/distributions/hmm.py +++ b/pyro/distributions/hmm.py @@ -15,6 +15,7 @@ matrix_and_mvn_to_gaussian, mvn_to_gaussian, ) +from pyro.ops.indexing import Vindex from pyro.ops.special import safe_log from pyro.ops.tensor_utils import cholesky, cholesky_solve @@ -423,6 +424,32 @@ def filter(self, value): # Convert to a distribution. return Categorical(logits=logp, validate_args=self._validate_args) + @torch.no_grad() + def sample(self, sample_shape=torch.Size()): + assert self.duration is not None + + # Sample initial state. + S = self.initial_logits.size(-1) # state space size + init_shape = torch.Size(sample_shape) + self.batch_shape + (S,) + init_logits = self.initial_logits.expand(init_shape) + x = Categorical(logits=init_logits).sample() + + # Sample hidden states over time. + trans_shape = self.batch_shape + (self.duration, S, S) + trans_logits = self.transition_logits.expand(trans_shape) + xs = [] + for t in range(self.duration): + x = Categorical(logits=Vindex(trans_logits)[..., t, x, :]).sample() + xs.append(x) + x = torch.stack(xs, dim=-1) + + # Sample observations conditioned on hidden states. 
+ obs_shape = self.batch_shape + (self.duration, S) + obs_dist = self.observation_dist.expand(obs_shape) + y = obs_dist.sample(sample_shape) + y = Vindex(y)[(Ellipsis, x) + (slice(None),) * obs_dist.event_dim] + return y + class GaussianHMM(HiddenMarkovModel): """ diff --git a/tests/distributions/test_hmm.py b/tests/distributions/test_hmm.py index ac0e67919d..97329c826b 100644 --- a/tests/distributions/test_hmm.py +++ b/tests/distributions/test_hmm.py @@ -55,6 +55,15 @@ def check_expand(old_dist, old_data): assert new_dist.log_prob(new_data).shape == new_batch_shape +def check_sample_shape(d): + if d.duration is None: + return + for sample_shape in [(), (2,), (3,)]: + sample = d.sample(sample_shape) + expected_shape = torch.Size(sample_shape + d.shape()) + assert sample.shape == expected_shape + + @pytest.mark.parametrize("num_steps", list(range(1, 20))) @pytest.mark.parametrize("state_dim", [2, 3]) @pytest.mark.parametrize("batch_shape", [(), (5,), (2, 4)], ids=str) @@ -179,6 +188,7 @@ def test_discrete_hmm_shape( expected_shape = broadcast_shape(init_shape, trans_shape[:-1], obs_shape[:-1]) assert actual.shape == expected_shape check_expand(d, data) + check_sample_shape(d) final = d.filter(data) assert isinstance(final, dist.Categorical) @@ -243,6 +253,7 @@ def test_discrete_hmm_categorical(num_steps): actual = d.log_prob(data) assert actual.shape == d.batch_shape check_expand(d, data) + check_sample_shape(d) # Check loss against TraceEnum_ELBO. @config_enumerate @@ -278,6 +289,7 @@ def test_discrete_hmm_diag_normal(num_steps): actual = d.log_prob(data) assert actual.shape == d.batch_shape check_expand(d, data) + check_sample_shape(d) # Check loss against TraceEnum_ELBO. 
@config_enumerate @@ -365,6 +377,7 @@ def test_gaussian_hmm_shape( actual = d.log_prob(data) assert actual.shape == expected_batch_shape check_expand(d, data) + check_sample_shape(d) x = d.rsample() assert x.shape == d.shape() @@ -1019,6 +1032,7 @@ def test_independent_hmm_shape( actual = d.log_prob(data) assert actual.shape == expected_batch_shape check_expand(d, data) + check_sample_shape(d) x = d.rsample() assert x.shape == d.shape() From 1fe8d7a02039cb831894c04cf913db90ce1b6f12 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 21 Mar 2022 21:07:01 -0400 Subject: [PATCH 2/3] Add comments --- pyro/distributions/hmm.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py index 241dc650ed..83d4ba979f 100644 --- a/pyro/distributions/hmm.py +++ b/pyro/distributions/hmm.py @@ -273,8 +273,10 @@ def _validate_sample(self, value): class DiscreteHMM(HiddenMarkovModel): """ Hidden Markov Model with discrete latent state and arbitrary observation - distribution. This uses [1] to parallelize over time, achieving - O(log(time)) parallel complexity. + distribution. + + This uses [1] to parallelize over time, achieving O(log(time)) parallel + complexity for computing :meth:`log_prob` and :meth:`filter`. The event_shape of this distribution includes time on the left:: @@ -290,6 +292,10 @@ class DiscreteHMM(HiddenMarkovModel): # homogeneous + homogeneous case: event_shape = (1,) + observation_dist.event_shape + The :meth:`sample` method is sequential (not parallelized), slow, and memory + inefficient. It is intended for data generation only and is not recommended + during inference. + **References:** [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019) @@ -444,6 +450,14 @@ def sample(self, sample_shape=torch.Size()): x = torch.stack(xs, dim=-1) # Sample observations conditioned on hidden states. 
+ # Note the simple sample-then-slice approach here generalizes to all + # distributions, but is inefficient. To implement a general optimal + # slice-then-sample strategy would require distributions to support + # slicing https://github.com/pyro-ppl/pyro/issues/3052. A simpler + # implementation might register a few slicing operators as is done with + # pyro.contrib.forecast.util.reshape_batch(). If you as a user need + # this function to be cheaper, feel free to submit a PR implementing + # one of these approaches. obs_shape = self.batch_shape + (self.duration, S) obs_dist = self.observation_dist.expand(obs_shape) y = obs_dist.sample(sample_shape) From f71c60746e9fad5b4c8416efe752c08bf4043812 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 21 Mar 2022 21:19:41 -0400 Subject: [PATCH 3/3] Add a test of distribution --- tests/distributions/test_hmm.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/distributions/test_hmm.py b/tests/distributions/test_hmm.py index 97329c826b..1bd88a3d1f 100644 --- a/tests/distributions/test_hmm.py +++ b/tests/distributions/test_hmm.py @@ -313,6 +313,21 @@ def model(data): assert_close(actual_loss, expected_loss) +def test_discrete_hmm_distribution(): + init_probs = torch.tensor([0.9, 0.1]) + trans_probs = torch.tensor( + [ + [[0.9, 0.1], [0.1, 0.9]], # noisy identity + [[0.1, 0.9], [0.9, 0.1]], # noisy flip + ] + ) + obs_dist = dist.Normal(torch.tensor([0.0, 1.0]), 0.1) + hmm = dist.DiscreteHMM(init_probs.log(), trans_probs.log(), obs_dist) + actual = hmm.sample([1000000]).mean(0) + expected = torch.tensor([0.1 * 0.9 + 0.9 * 0.1, 0.9**3 + 3 * 0.9 * 0.1**2]) + assert_close(actual, expected, atol=1e-3) + + @pytest.mark.parametrize("obs_dim", [1, 2]) @pytest.mark.parametrize("hidden_dim", [1, 3]) @pytest.mark.parametrize(