Marginalize DiscreteMarkovChain
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
ricardoV94 and jessegrabowski committed Jan 16, 2024
1 parent c6cd151 commit 958ada4
Showing 2 changed files with 185 additions and 15 deletions.
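For context, a minimal usage sketch of what this commit enables, adapted from the new tests added below (the transition matrix, emission distribution, and test values mirror test_marginalized_hmm_normal_emission; nothing here goes beyond the API exercised in those tests):

import numpy as np
import pymc as pm

from pymc_experimental.distributions import DiscreteMarkovChain
from pymc_experimental.model.marginal_model import MarginalModel

with MarginalModel() as m:
    # Two-state chain that deterministically alternates 0, 1, 0, 1
    P = [[0, 1], [1, 0]]
    init_dist = pm.Categorical.dist(p=[1, 0])
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
    emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=1e-1)

    # Sum the chain out of the model; the compiled logp marginalizes over all hidden state sequences
    m.marginalize([chain])

logp_fn = m.compile_logp()
logp_fn({"emission": np.array([-1.0, 1.0, -1.0, 1.0])})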
113 changes: 99 additions & 14 deletions pymc_experimental/model/marginal_model.py
@@ -10,21 +10,16 @@
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.transforms import Chain
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
from pytensor import Mode
from pytensor import Mode, scan
from pytensor.compile import SharedVariable
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import (
Constant,
FunctionGraph,
ancestors,
clone_replace,
vectorize_graph,
)
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
from pytensor.graph.replace import vectorize_graph
from pytensor.scan import map as scan_map
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.elemwise import Elemwise
@@ -33,6 +28,8 @@

__all__ = ["MarginalModel"]

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalModel(Model):
"""Subclass of PyMC Model that implements functionality for automatic
@@ -245,16 +242,25 @@ def marginalize(
self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
]

supported_dists = (Bernoulli, Categorical, DiscreteUniform)
for rv_to_marginalize in rvs_to_marginalize:
if rv_to_marginalize not in self.free_RVs:
raise ValueError(
f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
)
if not isinstance(rv_to_marginalize.owner.op, supported_dists):

rv_op = rv_to_marginalize.owner.op
if isinstance(rv_op, DiscreteMarkovChain):
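# The forward-algorithm logp implemented below assumes a lag-1 chain with a plain 2-D transition matrix (no batch dimensions)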
if rv_op.n_lags > 1:
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
)
if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
)
elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
raise NotImplementedError(
f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
f"Supported distribution include {supported_dists}"
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
)

if rv_to_marginalize.name in self.named_vars_to_dims:
@@ -490,6 +496,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
"""Base class for Finite Discrete Marginalized RVs"""


class DiscreteMarginalMarkovChainRV(MarginalRV):
"""Base class for Discrete Marginal Markov Chain RVs"""


def static_shape_ancestors(vars):
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
return [
@@ -618,11 +628,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
cloned_outputs = clone_replace(outputs, replace=replace_inputs)

marginalization_op = FiniteDiscreteMarginalRV(
if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
marginalize_constructor = DiscreteMarginalMarkovChainRV
else:
marginalize_constructor = FiniteDiscreteMarginalRV

marginalization_op = marginalize_constructor(
inputs=list(replace_inputs.values()),
outputs=cloned_outputs,
ndim_supp=ndim_supp,
)

marginalized_rvs = marginalization_op(*replace_inputs.keys())
fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
return rvs_to_marginalize, marginalized_rvs
@@ -638,6 +654,9 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
elif isinstance(op, DiscreteUniform):
lower, upper = constant_fold(rv.owner.inputs[3:])
return tuple(range(lower, upper + 1))
elif isinstance(op, DiscreteMarkovChain):
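# The chain's domain is its set of states: one value per column of the transition matrix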
p = rv.owner.inputs[0]
return tuple(range(pt.get_vector_length(p[-1])))

raise NotImplementedError(f"Cannot compute domain for op {op}")

@@ -728,3 +747,69 @@ def logp_fn(marginalized_rv_const, *non_sequences):

# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
return joint_logps, *(pt.constant(0),) * (len(values) - 1)


@_logprob.register(DiscreteMarginalMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):

marginalized_rvs_node = op.make_node(*inputs)
inner_rvs = clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
)

chain_rv, *dependent_rvs = inner_rvs
P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
domain = pt.arange(P.shape[-1], dtype="int32")

# Construct logp in two steps
# Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)

# First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
# around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
# PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
chain_value = chain_rv.clone()
dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

# Reduce and add the batch dims beyond the chain dimension
reduced_logp_emissions = _add_reduce_batch_dependent_logps(
chain_rv.type, logp_emissions_dict.values()
)

# Add a batch dimension for the domain of the chain
chain_shape = constant_fold(tuple(chain_rv.shape))
batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
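# batch_logp_emissions gains a leading axis of length n_states: entry s is the emission logp evaluated
# as if the chain sat in state s at every step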

# Step 2: Compute the transition probabilities
# This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
# We do it entirely in logs, though.
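# In log space the recursion reads
#   log_alpha_t(s) = logp(y_t | s) + logsumexp_{s'}[ log_alpha_{t-1}(s') + log P(s' -> s) ]
# which is what step_alpha below computes; the logsumexp runs over the previous-state axis.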

# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
# the initial distribution. This is robust to everything the user can throw at it.
batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
batch_chain_value[..., 0]
)
log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

def step_alpha(logp_emission, log_alpha, log_P):
step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
return logp_emission + step_log_prob

P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
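# log_P is padded on the right so its two state axes broadcast against any batch dimensions of the chain carried by log_alpha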
log_alpha_seq, _ = scan(
step_alpha,
non_sequences=[log_P],
outputs_info=[log_alpha_init],
# Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
)
# Final logp is just the sum of the last scan state
joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
dummy_logps = (pt.constant(0),) * (len(values) - 1)
return joint_logp, *dummy_logps
87 changes: 86 additions & 1 deletion pymc_experimental/tests/model/test_marginal_model.py
Expand Up @@ -14,6 +14,7 @@
from scipy.special import log_softmax, logsumexp
from scipy.stats import halfnorm, norm

from pymc_experimental.distributions import DiscreteMarkovChain
from pymc_experimental.model.marginal_model import (
FiniteDiscreteMarginalRV,
MarginalModel,
@@ -467,7 +468,7 @@ def test_not_supported_marginalized():
y = pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10]))
with pytest.raises(
NotImplementedError,
match="Marginalization of withe dependent Multivariate RVs not implemented",
match="Marginalization with dependent Multivariate RVs not implemented",
):
m.marginalize(x)

@@ -642,3 +643,87 @@ def dist(idx, size):
):
pt = {"norm": test_value}
np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))


@pytest.mark.parametrize("batch_chain", (True,), ids=lambda x: f"batch_chain={x}")
@pytest.mark.parametrize("batch_emission", (True,), ids=lambda x: f"batch_emission={x}")
def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
if batch_chain and not batch_emission:
pytest.skip("Redundant implicit combination")

with MarginalModel() as m:
P = [[0, 1], [1, 0]]
init_dist = pm.Categorical.dist(p=[1, 0])
chain = DiscreteMarkovChain(
"chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None
)
emission = pm.Normal(
"emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
)

m.marginalize([chain])
logp_fn = m.compile_logp()

test_value = np.array([-1, 1, -1, 1])
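# With P = [[0, 1], [1, 0]] and an init_dist concentrated on state 0, the chain is deterministically
# [0, 1, 0, 1], so mu = chain * 2 - 1 equals test_value and every Normal residual is zero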
expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
if batch_emission:
test_value = np.broadcast_to(test_value, (3, 4))
expected_logp *= 3
np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp)


@pytest.mark.parametrize(
"categorical_emission",
[
False,
# Categorical has a core vector parameter,
# so it is not possible to build a graph that uses elemwise operations exclusively
pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError)),
],
)
def test_marginalized_hmm_categorical_emission(categorical_emission):
"""Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
with MarginalModel() as m:
P = np.array([[0.5, 0.5], [0.3, 0.7]])
init_dist = pm.Categorical.dist(p=[0.375, 0.625])
chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
if categorical_emission:
emission = pm.Categorical(
"emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6])
)
else:
emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
m.marginalize([chain])

test_value = np.array([0, 0, 1])
expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video
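# Hand-computed forward pass for the Bernoulli-emission branch, confirming the constant:
#   alpha_0 = [0.375 * 0.8, 0.625 * 0.4] = [0.300, 0.250]
#   alpha_1 = [(0.300 * 0.5 + 0.250 * 0.3) * 0.8, (0.300 * 0.5 + 0.250 * 0.7) * 0.4] = [0.180, 0.130]
#   alpha_2 = [(0.180 * 0.5 + 0.130 * 0.3) * 0.2, (0.180 * 0.5 + 0.130 * 0.7) * 0.6] = [0.0258, 0.1086]
#   p(y = [0, 0, 1]) = 0.0258 + 0.1086 = 0.1344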
logp_fn = m.compile_logp()
np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp)


@pytest.mark.parametrize("batch_emission1", (False, True))
@pytest.mark.parametrize("batch_emission2", (False, True))
def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
emission1_shape = (2, 4) if batch_emission1 else (4,)
emission2_shape = (2, 4) if batch_emission2 else (4,)
with MarginalModel() as m:
P = [[0, 1], [1, 0]]
init_dist = pm.Categorical.dist(p=[1, 0])
chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape)
emission_2 = pm.Normal(
"emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape
)

with pytest.warns(UserWarning, match="multiple dependent variables"):
m.marginalize([chain])

logp_fn = m.compile_logp()

test_value = np.array([-1, 1, -1, 1])
multiplier = 2 + batch_emission1 + batch_emission2
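# Each emission stream contributes one sum of four zero-residual Normal logps; a batched stream
# (shape (2, 4)) contributes that sum twice, hence 2 + batch_emission1 + batch_emission2 copies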
expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
test_value_emission1 = np.broadcast_to(test_value, emission1_shape)
test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
np.testing.assert_allclose(logp_fn(test_point), expected_logp)
