
Support HMM via marginalization of DiscreteMarkovChain #257

Merged
4 commits merged on Feb 16, 2024

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Nov 2, 2023

The following example defines a 2-state HMM, with a 0.9 transition probability of staying in the same state, and a Normal emission centered around -1 for state 0 and 1 for state 1.

import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import pymc as pm
from pymc_experimental import MarginalModel
from pymc_experimental.distributions import DiscreteMarkovChain

with MarginalModel() as m:
    P = [[0.9, 0.1], [0.1, 0.9]]
    init_dist = pm.Categorical.dist(p=[1, 0])
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=10)
    emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=0.5)

    m.marginalize([chain])
    
    with m:
        idata = pm.sample(100)

plt.plot(az.extract(idata)["emission"].values, color="k", alpha=0.03)
plt.yticks([-1, 1])
plt.ylabel("Emission")
plt.xlabel("Step");

[Figure: posterior draws of "emission" plotted over steps, clustering around -1 and 1]

Not implemented

Higher-order lags and batch P matrices are not supported, due to complexity (and me not grokking the exact API)

Closes #167

@junpenglao
Member

Is it using the Viterbi algorithm?

@jessegrabowski
Member

jessegrabowski commented Nov 21, 2023

Currently just the forward algorithm* to compute the logp

*It's not pure forward because we are computing and storing p(data | state) for all data-state pairs outside the scan over state transition probabilities. We should be O(N^2*T) on compute, but we're not maximally efficient on memory.

If I understand well, viterbi just gives the most probable sequence of hidden states in a maximum likelihood setting? We should be able to back that out of the posterior pretty easily. You'll need to school me if I'm over simplifying.
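To make the Viterbi comparison concrete, here's a minimal sketch (my own illustration with made-up numbers, not code from this PR): Viterbi is the forward recursion with logsumexp replaced by max plus backpointers, so it returns the single most probable hidden path rather than the marginal likelihood.

```python
import numpy as np

def viterbi(log_pi, log_P, log_B, obs):
    """Most probable hidden-state path for a discrete-emission HMM.

    log_pi: (S,) initial log-probs; log_P: (S, S) transition log-probs
    (row = from, col = to); log_B: (S, O) emission log-probs; obs: length-T
    list of observed symbols. All names here are hypothetical.
    """
    T, S = len(obs), len(log_pi)
    delta = np.zeros((T, S))  # best log-prob of any path ending in each state
    backptr = np.zeros((T, S), dtype=int)
    delta[0] = log_pi + log_B[:, obs[0]]
    for t in range(1, T):
        # scores[k, s]: best path ending in k at t-1, then transitioning to s
        scores = delta[t - 1][:, None] + log_P
        backptr[t] = np.argmax(scores, axis=0)
        delta[t] = np.max(scores, axis=0) + log_B[:, obs[t]]
    # backtrack from the best final state
    path = np.empty(T, dtype=int)
    path[-1] = np.argmax(delta[-1])
    for t in range(T - 2, -1, -1):
        path[t] = backptr[t + 1, path[t + 1]]
    return path

log_pi = np.log([0.5, 0.5])
log_P = np.log([[0.9, 0.1], [0.1, 0.9]])   # sticky 2-state chain
log_B = np.log([[0.8, 0.2], [0.2, 0.8]])   # rows: states, cols: symbols
print(viterbi(log_pi, log_P, log_B, [0, 0, 1]))  # → [0 0 0]
```

With a sticky chain and only one final observation favoring state 1, the MAP path stays in state 0 throughout, which is exactly the kind of answer you could also back out of the marginalized posterior.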

@ricardoV94 ricardoV94 added the enhancements New feature or request label Nov 21, 2023
@junpenglao
Member

If I understand well, viterbi just gives the most probable sequence of hidden states in a maximum likelihood setting? We should be able to back that out of the posterior pretty easily. You'll need to school me if I'm over simplifying.

yes Viterbi gives the posterior mode - but you are marginalizing the state to compute the likelihood here right?

@ricardoV94
Member Author

but you are marginalizing the state to compute the likelihood here right?

Yes, but to be precise: to compute the logp of any dependent variables, which may be observed/unobserved or a mix.

@ricardoV94
Member Author

ricardoV94 commented Nov 21, 2023

Seems like our "clever" approach is not correct. We need to combine the emission probabilities as we compute the state probabilities iteratively. I thought we could factor them out but it doesn't seem to be the case.

@jessegrabowski
Member

jessegrabowski commented Nov 21, 2023

Seems like our "clever" approach is not correct. We need to combine the emission probabilities as we compute the state probabilities iteratively. I thought we could factor them out but it doesn't seem to be the case.

I added the example from this YouTube video as a test case, so we can get to a solution.

I'm in the process of refactoring the logp function to compute alpha correctly, but it's typically written as a nested loop. Here's NumPy code:

import numpy as np
from scipy import stats
from scipy.special import logsumexp

transition_probs = np.array([[0.5, 0.5],
                             [0.3, 0.7]])
initial_probs = np.array([0.375, 0.625])

T = 3
data = [0, 0, 1]
log_alpha = np.zeros((T, 2))
x_dists = [stats.bernoulli(p=0.2), stats.bernoulli(p=0.6)]

def eval_logp(x, dists):
    return np.array([d.logpmf(x) for d in dists])

log_alpha[0, :] = np.log(initial_probs) + eval_logp(data[0], x_dists)
for t in range(1, T):
    obs = data[t]
    for s in range(transition_probs.shape[0]):
        step_log_prob = x_dists[s].logpmf(obs) + np.log(transition_probs[:, s]) + log_alpha[t-1, :]
        log_alpha[t, s] = logsumexp(step_log_prob)

I'm trying to think how we can vectorize the inner loop, open to suggestions.

Nvm figured this out, it looks like:

for t in range(1, T):
    obs = data[t]
    step_log_prob = np.log(transition_probs) + log_alpha[t-1, :, None]    
    log_alpha[t, :] = eval_logp(obs, x_dists) + logsumexp(step_log_prob, axis=0)
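As a self-contained sanity check (my own sketch, using the same numbers as above, not part of the PR): the logsumexp over the final row of log_alpha is the marginal log-likelihood, and it should agree with brute-force enumeration of all 2**T hidden paths.

```python
import itertools
import numpy as np
from scipy import stats
from scipy.special import logsumexp

P = np.array([[0.5, 0.5], [0.3, 0.7]])
pi = np.array([0.375, 0.625])
data = [0, 0, 1]
T = len(data)
x_dists = [stats.bernoulli(p=0.2), stats.bernoulli(p=0.6)]

# vectorized forward recursion, as in the snippet above
log_alpha = np.zeros((T, 2))
log_alpha[0] = np.log(pi) + np.array([d.logpmf(data[0]) for d in x_dists])
for t in range(1, T):
    step = np.log(P) + log_alpha[t - 1][:, None]
    log_alpha[t] = np.array([d.logpmf(data[t]) for d in x_dists]) + logsumexp(step, axis=0)
forward_logp = logsumexp(log_alpha[-1])  # marginal log-likelihood of the data

# brute force: sum the joint probability over all 2**T hidden paths
total = 0.0
for path in itertools.product([0, 1], repeat=T):
    p = pi[path[0]] * x_dists[path[0]].pmf(data[0])
    for t in range(1, T):
        p *= P[path[t - 1], path[t]] * x_dists[path[t]].pmf(data[t])
    total += p

assert np.isclose(forward_logp, np.log(total))
```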

@ricardoV94 ricardoV94 marked this pull request as ready for review January 16, 2024 15:31
@ricardoV94 ricardoV94 changed the title Marginalize hmm Support HMM via marginalization of DiscreteMarkovChain Jan 16, 2024
@ricardoV94 ricardoV94 force-pushed the marginalize_hmm branch 3 times, most recently from 958ada4 to f97dd6d on January 16, 2024 17:21
Member

@jessegrabowski jessegrabowski left a comment


I'm happy with where it's landed. I wish we had solved the two major problems (categorical emissions and lags), but I'd rather have it merged and getting used than keep it in limbo while we wait for free time to make it perfect. I'll open an issue about those points after it's merged.

    )
    if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
        raise NotImplementedError(
            "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
        )
Member


Can a markov chain have a non-matrix transition probability?

Member Author


Should be valid for batch dims


# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
# the initial distribution. This is robust to everything the user can throw at it.
batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
Member


No way to avoid this lambda here with vectorize_graph? I recall this used to be a little function.

Member Author


The way to avoid it would be to make it a little function again, but this seems like a fine use for a lambda?

@ricardoV94
Member Author

ricardoV94 commented Feb 5, 2024

Categorical is one of the goals I have with #300

I think it's already working there, but I need to rebase and check once we merge this

@ricardoV94
Member Author

ricardoV94 commented Feb 5, 2024

Lags are a nice follow-up. The current distribution doesn't have a clear API for lags and batch dims, which further stopped me from addressing it here.

We just need to agree on this, and then it should be straightforward to support both.

The design question is: how do you specify a Markov chain with 2 lags and an extra batch dimension? Say, something with shape (5, 100) with two lags but a different transition matrix for each of the five batched chains.

@jessegrabowski
Member

jessegrabowski commented Feb 5, 2024

Yeah, good questions. You're right, it's not clear. I guess the distribution has to store the n_lags variable and marginalize will have to ask it? Not sure. In general, the way lags are handled is not good -- at higher orders the transition matrix is almost certainly going to be sparse, so it makes more sense to build one huge (k**n, k**n) sparse matrix and store a hash table to index into it for specific lag tuples.

We could let the user declare the lagged matrices as a tensor (since it's a bit more natural IMO at least) then internally flatten it down and build the index table, then rebuild the tensors after sampling.

But this is all for another PR, I 100% agree.
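To sketch the flattening idea above (my own illustration; the shapes, seed, and helper names are made up, and I use a dense array where a real implementation would use a sparse one): an order-n chain over k states becomes a first-order chain over lag tuples, with a mixed-radix encoding playing the role of the index table.

```python
import numpy as np

k, n = 2, 2  # k states, order n = 2; hypothetical numbers
rng = np.random.default_rng(0)

# user-facing tensor form: P[z_{t-2}, z_{t-1}, z_t] = transition probability
P = rng.dirichlet(np.ones(k), size=(k, k))  # shape (k, k, k); last axis sums to 1

def encode(lags):
    """Mixed-radix (base-k) encoding of a lag tuple into a flat index."""
    idx = 0
    for z in lags:
        idx = idx * k + z
    return idx

# flattened first-order chain over lag tuples: a (k**n, k**n) matrix that is
# mostly zero, because (z_{t-2}, z_{t-1}) can only move to tuples (z_{t-1}, z_t)
big = np.zeros((k**n, k**n))
for z2 in range(k):
    for z1 in range(k):
        for z0 in range(k):
            big[encode((z2, z1)), encode((z1, z0))] = P[z2, z1, z0]

assert np.allclose(big.sum(axis=1), 1.0)      # each row is a valid distribution
assert np.count_nonzero(big) <= k ** (n + 1)  # k**(n+1) of k**(2*n) entries used
```

The sparsity assertion is the point: only k**(n+1) entries of the flattened matrix can ever be nonzero, which is why a sparse representation plus an index table scales better than materializing the dense tuple-to-tuple matrix.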

Contributor

@zaxtax zaxtax left a comment


I think issues should be opened for some of the marginalizations that are not implemented yet as well. But otherwise, looks good to merge.

    raise NotImplementedError(
-       f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
-       f"Supported distribution include {supported_dists}"
+       f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
    )
Contributor


I thought the old error message was more helpful

Member Author


It was, but it's not gonna scale.

Member


maybe link to the docs where it lists all the supported distributions then?

Member Author


The notes state that functionality is restricted: only finite discrete RVs are supported, which is kind of true. Although we don't yet support Truncated/Censored versions of infinite discrete RVs, which thus become finite: #95

We also don't support Multinomial, which in theory is finite... So I think the disclaimer that functionality is restricted, plus this error message indicating the type of RV that could not be marginalized, is fair game?

ricardoV94 and others added 4 commits February 16, 2024 13:44
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
Labels
enhancements, marginalization

Successfully merging this pull request may close these issues.

Cover DiscreteMarkovChain distributions with the marginal models
4 participants