
Commit a9224b6
Refactor logic to reduce add batched logp dimensions
ricardoV94 committed Feb 16, 2024
1 parent 8b35cc6 commit a9224b6
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions pymc_experimental/model/marginal_model.py
@@ -26,7 +26,7 @@
     vectorize_graph,
 )
 from pytensor.scan import map as scan_map
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.special import log_softmax
@@ -381,41 +381,36 @@ def transform_input(inputs):
 
         rv_dict = {}
         rv_dims = {}
-        for seed, rv in zip(seeds, vars_to_recover):
+        for seed, marginalized_rv in zip(seeds, vars_to_recover):
             supported_dists = (Bernoulli, Categorical, DiscreteUniform)
-            if not isinstance(rv.owner.op, supported_dists):
+            if not isinstance(marginalized_rv.owner.op, supported_dists):
                 raise NotImplementedError(
-                    f"RV with distribution {rv.owner.op} cannot be recovered. "
+                    f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
                     f"Supported distribution include {supported_dists}"
                 )
 
             m = self.clone()
-            rv = m.vars_to_clone[rv]
-            m.unmarginalize([rv])
-            dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
-            joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
+            marginalized_rv = m.vars_to_clone[marginalized_rv]
+            m.unmarginalize([marginalized_rv])
+            dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
+            joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False)
 
-            marginalized_value = m.rvs_to_values[rv]
+            marginalized_value = m.rvs_to_values[marginalized_rv]
             other_values = [v for v in m.value_vars if v is not marginalized_value]
 
             # Handle batch dims for marginalized value and its dependent RVs
-            joint_logp = joint_logps[-1]
-            for dv in joint_logps[:-1]:
-                dbcast = dv.type.broadcastable
-                mbcast = marginalized_value.type.broadcastable
-                mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
-                values_axis_bcast = [
-                    i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
-                ]
-                joint_logp += dv.sum(values_axis_bcast)
+            marginalized_logp, *dependent_logps = joint_logps
+            joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
+                marginalized_rv.type, dependent_logps
+            )
 
-            rv_shape = constant_fold(tuple(rv.shape))
-            rv_domain = get_domain_of_finite_discrete_rv(rv)
+            rv_shape = constant_fold(tuple(marginalized_rv.shape))
+            rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
             rv_domain_tensor = pt.moveaxis(
                 pt.full(
                     (*rv_shape, len(rv_domain)),
                     rv_domain,
-                    dtype=rv.dtype,
+                    dtype=marginalized_rv.dtype,
                 ),
                 -1,
                 0,
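
For illustration, the `rv_domain_tensor` built above enumerates every value in the marginalized RV's domain, broadcast across the RV's batch shape, with the domain axis moved to the front. A minimal NumPy sketch under assumed shapes (a Bernoulli-like RV of batch shape (2, 3)):

```python
import numpy as np

# Assumed example values mirroring the pt.full / pt.moveaxis calls above
rv_shape = (2, 3)   # batch shape of the marginalized RV
rv_domain = (0, 1)  # finite domain, e.g. of a Bernoulli RV

domain_tensor = np.moveaxis(
    np.full((*rv_shape, len(rv_domain)), rv_domain, dtype="int64"),
    -1,  # move the trailing domain axis...
    0,   # ...to the front
)
assert domain_tensor.shape == (len(rv_domain), *rv_shape)
assert (domain_tensor[0] == 0).all() and (domain_tensor[1] == 1).all()
```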
@@ -431,7 +426,7 @@ def transform_input(inputs):
             joint_logps_norm = log_softmax(joint_logps, axis=-1)
             if return_samples:
                 sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
-                if isinstance(rv.owner.op, DiscreteUniform):
+                if isinstance(marginalized_rv.owner.op, DiscreteUniform):
                     sample_rv_outs += rv_domain[0]
 
             rv_loglike_fn = compile_pymc(
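
`log_softmax` normalizes the joint logps over the domain axis, turning them into posterior log-probabilities for each domain value. A small sketch with assumed values, using SciPy's `log_softmax`, which matches the semantics of `pytensor.tensor.special.log_softmax`:

```python
import numpy as np
from scipy.special import log_softmax

# Assumed joint logps: 2 batch entries, a domain of 2 values on the last axis
joint_logps = np.array([[-1.0, -2.0], [-0.5, -3.0]])

posterior_probs = np.exp(log_softmax(joint_logps, axis=-1))
assert np.allclose(posterior_probs.sum(axis=-1), 1.0)  # normalized per batch entry
```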
@@ -456,18 +451,20 @@ def transform_input(inputs):
                 logps, samples = zip(*logvs)
                 logps = np.array(logps)
                 samples = np.array(samples)
-                rv_dict[rv.name] = samples.reshape(
+                rv_dict[marginalized_rv.name] = samples.reshape(
                     tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
                 )
             else:
                 logps = np.array(logvs)
 
-            rv_dict["lp_" + rv.name] = logps.reshape(
+            rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
                 tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
             )
-            if rv.name in m.named_vars_to_dims:
-                rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
-                rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
+            if marginalized_rv.name in m.named_vars_to_dims:
+                rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
+                rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
+                    "lp_" + marginalized_rv.name + "_dim"
+                ]
 
         coords, dims = coords_and_dims_for_inferencedata(self)
         dims.update(rv_dims)
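
The reshapes above fold the flat list of per-sample results back into the posterior's stacked dimensions (typically chain and draw). A sketch with assumed sizes:

```python
import numpy as np

# Assumed posterior layout: 4 chains x 250 draws, RV batch shape (3,)
stacked_dims = {"chain": range(4), "draw": range(250)}
flat_logps = np.zeros((4 * 250, 3))  # one row per flattened posterior sample

logps = flat_logps.reshape(
    tuple(len(coord) for coord in stacked_dims.values()) + flat_logps.shape[1:]
)
assert logps.shape == (4, 250, 3)
```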
@@ -647,6 +644,22 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
 
+def _add_reduce_batch_dependent_logps(
+    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
+):
+    """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`."""
+
+    mbcast = marginalized_type.broadcastable
+    reduced_logps = []
+    for dependent_logp in dependent_logps:
+        dbcast = dependent_logp.type.broadcastable
+        dim_diff = len(dbcast) - len(mbcast)
+        mbcast_aligned = (True,) * dim_diff + mbcast
+        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
+        reduced_logps.append(dependent_logp.sum(vbcast_axis))
+    return pt.add(*reduced_logps)
+
+
 @_logprob.register(FiniteDiscreteMarginalRV)
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
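
A hedged sketch of what the new `_add_reduce_batch_dependent_logps` helper does, assuming it is importable from `pymc_experimental.model.marginal_model` and using assumed shapes: a dependent logp carrying one extra leading batch dim is summed down to the marginalized RV's batch shape before the logps are added.

```python
import pytensor.tensor as pt

from pymc_experimental.model.marginal_model import _add_reduce_batch_dependent_logps

marginalized = pt.vector("marginalized")      # broadcastable (False,)
dependent_logp = pt.matrix("dependent_logp")  # broadcastable (False, False)

# Right-aligning the broadcast patterns gives (True, False) vs (False, False):
# axis 0 is broadcastable for the marginalized RV but not for the dependent
# logp, so it is an extra batch dim of the dependent RV and is summed away.
reduced = _add_reduce_batch_dependent_logps(marginalized.type, [dependent_logp])
print(reduced.type.broadcastable)  # (False,) -- batch shape of the marginalized RV
```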
@@ -662,17 +675,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
 
     # Reduce logp dimensions corresponding to broadcasted variables
-    joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
-    for inner_rv, inner_value in inner_rvs_to_values.items():
-        if inner_rv is marginalized_rv:
-            continue
-        vbcast = inner_value.type.broadcastable
-        mbcast = marginalized_rv.type.broadcastable
-        mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
-        values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
-        joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
-
-    # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
+    marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
+    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
+        marginalized_rv.type, logps_dict.values()
+    )
+
+    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
     # values of the marginalized RV
     # Some inputs are not root inputs (such as transformed projections of value variables)
     # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
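
For context, a toy sketch (assumed names, not the commit's actual graph) of how `OpFromGraph` wraps a logp graph so it can be re-applied at each point of the marginalized domain:

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

value = pt.scalar("value")
toy_logp = -0.5 * value**2  # stand-in for the joint_logp graph

toy_logp_op = OpFromGraph([value], [toy_logp])

# The wrapped op can be applied once per value in the marginalized domain
evaluated = [toy_logp_op(pt.constant(float(v))) for v in (0, 1, 2)]
```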
Expand Down Expand Up @@ -700,6 +708,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
)

# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
# TODO: Try vectorize here
if len(marginalized_rv_domain) <= 10:
joint_logps = [
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
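
A hedged sketch of the two evaluation strategies weighed here, using a toy function in place of `joint_logp_op`: unrolling produces one subgraph per domain value, while `scan_map` keeps the graph a constant size however large the domain is.

```python
import pytensor.tensor as pt
from pytensor.scan import map as scan_map

x = pt.vector("x")
domain = pt.arange(3)

# Unrolled: one subgraph per domain value; cheap for small domains
unrolled = [pt.exp(x + i) for i in range(3)]

# Scan: a single looping node, so the graph does not grow with the domain
scanned, _updates = scan_map(lambda d: pt.exp(x + d), sequences=[domain])
```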
