diff --git a/pymc/data.py b/pymc/data.py
index 4f34dfacb3a..98f8a399704 100644
--- a/pymc/data.py
+++ b/pymc/data.py
@@ -423,6 +423,11 @@ def Data(
     # `convert_observed_data` takes care of parameter `value` and
     # transforms it to something digestible for PyTensor.
     arr = convert_observed_data(value)
+    if isinstance(arr, np.ma.MaskedArray):
+        raise NotImplementedError(
+            "Masked arrays or arrays with `nan` entries are not supported. "
+            "Pass them directly to `observed` if you want to trigger auto-imputation"
+        )
 
     if mutable is None:
         warnings.warn(
diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 6ea639d9f64..d35d339e8f2 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -25,13 +25,14 @@
 from pytensor import tensor as pt
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import node_rewriter
+from pytensor.graph import FunctionGraph, node_rewriter
 from pytensor.graph.basic import Node, Variable
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import MetaType
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.var import TensorVariable
 from typing_extensions import TypeAlias
 
@@ -49,6 +50,7 @@
 )
 from pymc.exceptions import BlockModelAccessError
 from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
+from pymc.logprob.basic import logp
 from pymc.logprob.rewriting import logprob_rewrites_db
 from pymc.model import BlockModelAccess
 from pymc.printing import str_for_dist
@@ -1148,3 +1150,145 @@ def logcdf(value, c):
             -np.inf,
             0,
         )
+
+
+class PartialObservedRV(SymbolicRandomVariable):
+    """RandomVariable with a partially observed subspace, as indicated by a boolean mask.
+
+    See `create_partial_observed_rv` for more details.
+    """
+
+
+def create_partial_observed_rv(
+    rv: TensorVariable,
+    mask: Union[np.ndarray, TensorVariable],
+) -> Tuple[
+    Tuple[TensorVariable, TensorVariable], Tuple[TensorVariable, TensorVariable], TensorVariable
+]:
+    """Separate observed and unobserved components of a RandomVariable.
+
+    This function may return two independent RandomVariables or, if not possible,
+    two variables from a common `PartialObservedRV` node.
+
+    Parameters
+    ----------
+    rv : TensorVariable
+    mask : tensor_like
+        Constant or variable boolean mask. True entries correspond to components of the variable that are not observed.
+
+    Returns
+    -------
+    observed_rv and mask : Tuple of TensorVariable
+        The observed component of the RV and the respective indexing mask.
+    unobserved_rv and mask : Tuple of TensorVariable
+        The unobserved component of the RV and the respective indexing mask.
+    joined_rv : TensorVariable
+        The symbolic join of the observed and unobserved components.
+ """ + if not mask.dtype == "bool": + raise ValueError( + f"mask must be an array or tensor of boolean dtype, got dtype: {mask.dtype}" + ) + + if mask.ndim > rv.ndim: + raise ValueError(f"mast can't have more dims than rv, got ndim: {mask.ndim}") + + antimask = ~mask + + can_rewrite = False + # Only pure RVs can be rewritten + if isinstance(rv.owner.op, RandomVariable): + ndim_supp = rv.owner.op.ndim_supp + + # All univariate RVs can be rewritten + if ndim_supp == 0: + can_rewrite = True + + # Multivariate RVs can be rewritten if masking does not split within support dimensions + else: + batch_dims = rv.type.ndim - ndim_supp + constant_mask = getattr(as_tensor_variable(mask), "data", None) + + # Indexing does not overlap with core dimensions + if mask.ndim <= batch_dims: + can_rewrite = True + + # Try to handle special case where mask is constant across support dimensions, + # TODO: This could be done by the rewrite itself + elif constant_mask is not None: + # We check if a constant_mask that only keeps the first entry of each support dim + # is equivalent to the original one after re-expanding. + trimmed_mask = constant_mask[(...,) + (0,) * ndim_supp] + expanded_mask = np.broadcast_to( + np.expand_dims(trimmed_mask, axis=tuple(range(-ndim_supp, 0))), + shape=constant_mask.shape, + ) + if np.array_equal(constant_mask, expanded_mask): + mask = trimmed_mask + antimask = ~trimmed_mask + can_rewrite = True + + if can_rewrite: + # Rewrite doesn't work with boolean masks. Should be fixed after https://github.com/pymc-devs/pytensor/pull/329 + mask, antimask = mask.nonzero(), antimask.nonzero() + + masked_rv = rv[mask] + fgraph = FunctionGraph(outputs=[masked_rv], clone=False) + [unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner) + + antimasked_rv = rv[antimask] + fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False) + [observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner) + + # Make a clone of the observedRV, with a distinct rng so that observed and + # unobserved are never treated as equivalent (and mergeable) nodes by pytensor. + _, size, _, *inps = observed_rv.owner.inputs + observed_rv = observed_rv.owner.op(*inps, size=size) + + # For all other cases use the more general PartialObservedRV + else: + # The symbolic graph simply splits the observed and unobserved components, + # so they can be given separate values. 
+        dist_, mask_ = rv.type(), as_tensor_variable(mask).type()
+        observed_rv_, unobserved_rv_ = dist_[~mask_], dist_[mask_]
+
+        observed_rv, unobserved_rv = PartialObservedRV(
+            inputs=[dist_, mask_],
+            outputs=[observed_rv_, unobserved_rv_],
+            ndim_supp=rv.owner.op.ndim_supp,
+        )(rv, mask)
+
+    joined_rv = pt.empty(rv.shape, dtype=rv.type.dtype)
+    joined_rv = pt.set_subtensor(joined_rv[mask], unobserved_rv)
+    joined_rv = pt.set_subtensor(joined_rv[antimask], observed_rv)
+
+    return (observed_rv, antimask), (unobserved_rv, mask), joined_rv
+
+
+@_logprob.register(PartialObservedRV)
+def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
+    # For the logp, simply join the values
+    [obs_value, unobs_value] = values
+    antimask = ~mask
+    joined_value = pt.empty_like(dist)
+    joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
+    joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
+    joined_logp = logp(dist, joined_value)
+
+    # If we have a univariate RV we can split apart the logp terms
+    if op.ndim_supp == 0:
+        return joined_logp[antimask], joined_logp[mask]
+    # Otherwise, we can't (always/easily) split apart logp terms.
+    # We return the full logp for the observed value, and an empty (0-length) array for the unobserved value
+    else:
+        return joined_logp.ravel(), pt.zeros((0,), dtype=joined_logp.type.dtype)
+
+
+@_moment.register(PartialObservedRV)
+def partial_observed_rv_moment(op, partial_obs_rv, rv, mask):
+    # Unobserved output
+    if partial_obs_rv.owner.outputs.index(partial_obs_rv) == 1:
+        return moment(rv)[mask]
+    # Observed output
+    else:
+        return moment(rv)[~mask]
diff --git a/pymc/model.py b/pymc/model.py
index ec94772192c..fdd87520320 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -44,11 +44,9 @@
 from pytensor.compile import DeepCopyOp, get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant, Variable, graph_inputs
-from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar import Cast
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.sharedvar import ScalarSharedVariable
 from pytensor.tensor.var import TensorConstant, TensorVariable
@@ -1409,67 +1407,27 @@ def make_obs_var(
             if total_size is not None:
                 raise ValueError("total_size is not compatible with imputed variables")
 
-            if not isinstance(rv_var.owner.op, RandomVariable):
-                raise NotImplementedError(
-                    "Automatic inputation is only supported for univariate RandomVariables."
-                    f" {rv_var} of type {type(rv_var.owner.op)} is not supported."
-                )
-
-            if rv_var.owner.op.ndim_supp > 0:
-                raise NotImplementedError(
-                    f"Automatic inputation is only supported for univariate "
-                    f"RandomVariables, but {rv_var} is multivariate"
-                )
+            from pymc.distributions.distribution import create_partial_observed_rv
 
-            # We can get a random variable comprised of only the unobserved
-            # entries by lifting the indices through the `RandomVariable` `Op`.
+            (
+                (observed_rv, observed_mask),
+                (unobserved_rv, _),
+                joined_rv,
+            ) = create_partial_observed_rv(rv_var, mask)
+            observed_data = pt.as_tensor(data.data[observed_mask])
 
-            masked_rv_var = rv_var[mask.nonzero()]
-
-            fgraph = FunctionGraph(
-                [i for i in graph_inputs((masked_rv_var,)) if not isinstance(i, Constant)],
-                [masked_rv_var],
-                clone=False,
-            )
+            # Register ObservedRV corresponding to observed component
+            observed_rv.name = f"{name}_observed"
+            self.create_value_var(observed_rv, transform=None, value_var=observed_data)
+            self.add_named_variable(observed_rv)
+            self.observed_RVs.append(observed_rv)
 
-            (missing_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
+            # Register FreeRV corresponding to unobserved components
+            self.register_rv(unobserved_rv, f"{name}_missing", transform=transform)
 
-            self.register_rv(missing_rv_var, f"{name}_missing", transform=transform)
-
-            # Now, we lift the non-missing observed values and produce a new
-            # `rv_var` that contains only those.
-            #
-            # The end result is two disjoint distributions: one for the missing
-            # values, and another for the non-missing values.
-
-            antimask_idx = (~mask).nonzero()
-            nonmissing_data = pt.as_tensor_variable(data[antimask_idx].data)
-            unmasked_rv_var = rv_var[antimask_idx]
-            unmasked_rv_var = unmasked_rv_var.owner.clone().default_output()
-
-            fgraph = FunctionGraph(
-                [i for i in graph_inputs((unmasked_rv_var,)) if not isinstance(i, Constant)],
-                [unmasked_rv_var],
-                clone=False,
-            )
-            (observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
-            # Make a clone of the RV, but let it create a new rng so that observed and
-            # missing are not treated as equivalent nodes by pytensor. This would happen
-            # if the size of the masked and unmasked array happened to coincide
-            _, size, _, *inps = observed_rv_var.owner.inputs
-            observed_rv_var = observed_rv_var.owner.op(*inps, size=size, name=f"{name}_observed")
-            observed_rv_var.tag.observations = nonmissing_data
-
-            self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
-            self.add_named_variable(observed_rv_var)
-            self.observed_RVs.append(observed_rv_var)
-
-            # Create deterministic that combines observed and missing
+            # Register Deterministic that combines observed and missing
             # Note: This can widely increase memory consumption during sampling for large datasets
-            rv_var = pt.empty(data.shape, dtype=observed_rv_var.type.dtype)
-            rv_var = pt.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
-            rv_var = pt.set_subtensor(rv_var[antimask_idx], observed_rv_var)
-            rv_var = Deterministic(name, rv_var, self, dims)
+            rv_var = Deterministic(name, joined_rv, self, dims)
 
         else:
             if sps.issparse(data):
diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py
index a16bf43db59..07207725651 100644
--- a/tests/backends/test_arviz.py
+++ b/tests/backends/test_arviz.py
@@ -352,7 +352,6 @@ def test_missing_data_model(self):
         # See https://github.com/pymc-devs/pymc/issues/5255
         assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)
 
-    @pytest.mark.xfail(reason="Multivariate partial observed RVs not implemented for V4")
     def test_mv_missing_data_model(self):
         data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)
 
@@ -361,19 +360,25 @@ def test_mv_missing_data_model(self):
             mu = pm.Normal("mu", 0, 1, size=2)
             sd_dist = pm.HalfNormal.dist(1.0, size=2)
             # pylint: disable=unpacking-non-sequence
-            chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True)
+            chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist)
             # pylint: enable=unpacking-non-sequence
             with pytest.warns(ImputationWarning):
                 y = pm.MvNormal("y", mu=mu, chol=chol, observed=data)
-            inference_data = pm.sample(100, chains=2, return_inferencedata=True)
+            inference_data = pm.sample(
+                tune=100,
+                draws=100,
+                chains=2,
+                step=pm.Metropolis(),
+                idata_kwargs=dict(log_likelihood=True),
+            )
 
         # make sure that data is really missing
-        assert isinstance(y.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
+        assert isinstance(y.owner.inputs[0].owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
 
         test_dict = {
             "posterior": ["mu", "chol_cov"],
-            "observed_data": ["y"],
-            "log_likelihood": ["y"],
+            "observed_data": ["y_observed"],
+            "log_likelihood": ["y_observed"],
         }
         fails = check_multiple_attrs(test_dict, inference_data)
         assert not fails
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index 5921b68fb16..d5c2b184061 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -23,7 +23,7 @@
 import pytest
 import scipy.stats as st
 
-from pytensor import scan
+from pytensor import scan, shared
 from pytensor.tensor import TensorVariable
 
 import pymc as pm
@@ -42,14 +42,16 @@
     CustomDist,
     CustomDistRV,
     CustomSymbolicDistRV,
+    PartialObservedRV,
     SymbolicRandomVariable,
     _moment,
+    create_partial_observed_rv,
     moment,
 )
 from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
 from pymc.distributions.transforms import log
 from pymc.exceptions import BlockModelAccessError
-from pymc.logprob.basic import logcdf, logp
+from pymc.logprob.basic import conditional_logp, logcdf, logp
 from pymc.model import Deterministic, Model
 from pymc.pytensorf import collect_default_updates
 from pymc.sampling import draw, sample
@@ -700,3 +702,217 @@ def test_dtype(self, floatX):
         assert pm.DiracDelta.dist(2**16).dtype == "int32"
         assert pm.DiracDelta.dist(2**32).dtype == "int64"
         assert pm.DiracDelta.dist(2.0).dtype == floatX
+
+
+class TestPartialObservedRV:
+    @pytest.mark.parametrize("symbolic_rv", (False, True))
+    def test_univariate(self, symbolic_rv):
+        data = np.array([0.25, 0.5, 0.25])
+        mask = np.array([False, False, True])
+
+        rv = pm.Normal.dist([1, 2, 3])
+        if symbolic_rv:
+            # We use a Censored Normal so that PartialObservedRV is needed,
+            # but don't use the bounds for testing the logp
+            rv = pm.Censored.dist(rv, lower=-100, upper=100)
+        (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
+
+        # Test types
+        if symbolic_rv:
+            assert isinstance(obs_rv.owner.op, PartialObservedRV)
+            assert isinstance(unobs_rv.owner.op, PartialObservedRV)
+        else:
+            assert isinstance(obs_rv.owner.op, Normal)
+            assert isinstance(unobs_rv.owner.op, Normal)
+
+        # Test shapes
+        assert tuple(obs_rv.shape.eval()) == (2,)
+        assert tuple(unobs_rv.shape.eval()) == (1,)
+        assert tuple(joined_rv.shape.eval()) == (3,)
+
+        # Test logp
+        logp = conditional_logp(
+            {obs_rv: pt.as_tensor(data[~mask]), unobs_rv: pt.as_tensor(data[mask])}
+        )
+        obs_logp, unobs_logp = pytensor.function([], list(logp.values()))()
+        np.testing.assert_allclose(obs_logp, st.norm([1, 2]).logpdf([0.25, 0.5]))
+        np.testing.assert_allclose(unobs_logp, st.norm([3]).logpdf([0.25]))
+
+    @pytest.mark.parametrize(
+        "mask",
+        [
+            pt.constant(np.array([[True, True, True, True]])),
+            pt.constant(np.array([[False, False, False, False]])),
+        ],
+    )
+    def test_multivariate_constant_mask_separable(self, mask):
+        obs_component_selected = not mask.data[0, 0]
+        obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
+        unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
+
+        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
+        (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
+
+        # Test types
+        assert isinstance(obs_rv.owner.op, pm.Dirichlet)
+        assert isinstance(unobs_rv.owner.op, pm.Dirichlet)
+
+        # Test shapes
+        if obs_component_selected:
+            expected_obs_shape = (1, 4)
+            expected_unobs_shape = (0, 4)
+        else:
+            expected_obs_shape = (0, 4)
+            expected_unobs_shape = (1, 4)
+        assert tuple(obs_rv.shape.eval()) == expected_obs_shape
+        assert tuple(unobs_rv.shape.eval()) == expected_unobs_shape
+        assert tuple(joined_rv.shape.eval()) == (1, 4)
+
+        # Test logp
+        logp = conditional_logp(
+            {
+                obs_rv: pt.as_tensor(obs_data)[obs_mask],
+                unobs_rv: pt.as_tensor(unobs_data)[unobs_mask],
+            }
+        )
+        obs_logp, unobs_logp = pytensor.function([], list(logp.values()))()
+        if obs_component_selected:
+            expected_obs_logp = pm.logp(rv, obs_data).eval()
+            expected_unobs_logp = []
+        else:
+            expected_obs_logp = []
+            expected_unobs_logp = pm.logp(rv, unobs_data).eval()
+        np.testing.assert_allclose(obs_logp, expected_obs_logp)
+        np.testing.assert_allclose(unobs_logp, expected_unobs_logp)
+
+    def test_multivariate_constant_mask_unseparable(self):
+        mask = pt.constant(np.array([[True, True, False, False]]))
+        obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
+        unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
+
+        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
+        (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
+
+        # Test types
+        assert isinstance(obs_rv.owner.op, PartialObservedRV)
+        assert isinstance(unobs_rv.owner.op, PartialObservedRV)
+
+        # Test shapes
+        assert tuple(obs_rv.shape.eval()) == (2,)
+        assert tuple(unobs_rv.shape.eval()) == (2,)
+        assert tuple(joined_rv.shape.eval()) == (1, 4)
+
+        # Test logp
+        logp = conditional_logp(
+            {
+                obs_rv: pt.as_tensor(obs_data)[obs_mask],
+                unobs_rv: pt.as_tensor(unobs_data)[unobs_mask],
+            }
+        )
+        obs_logp, unobs_logp = pytensor.function([], list(logp.values()))()
+
+        # For non-separable cases the logp always shows up in the observed variable
+        expected_logp = pm.logp(rv, [[0.1, 0.4, 0.4, 0.1]]).eval()
+        np.testing.assert_almost_equal(obs_logp, expected_logp)
+        np.testing.assert_array_equal(unobs_logp, [])
+
+    def test_multivariate_shared_mask_separable(self):
+        mask = shared(np.array([True]))
+        obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
+        unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
+
+        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
+        (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
+
+        # Test types
+        # Separable, because the shared mask indexes only the batch dim and never splits the support dim.
+        assert isinstance(obs_rv.owner.op, pm.Dirichlet)
+        assert isinstance(unobs_rv.owner.op, pm.Dirichlet)
+
+        # Test shapes
+        assert tuple(obs_rv.shape.eval()) == (0, 4)
+        assert tuple(unobs_rv.shape.eval()) == (1, 4)
+        assert tuple(joined_rv.shape.eval()) == (1, 4)
+
+        # Test logp
+        logp = conditional_logp(
+            {
+                obs_rv: pt.as_tensor(obs_data)[obs_mask],
+                unobs_rv: pt.as_tensor(unobs_data)[unobs_mask],
+            }
+        )
+        logp_fn = pytensor.function([], list(logp.values()))
+        obs_logp, unobs_logp = logp_fn()
+        expected_logp = pm.logp(rv, unobs_data).eval()
+        np.testing.assert_almost_equal(obs_logp, [])
+        np.testing.assert_array_equal(unobs_logp, expected_logp)
+
+        # Test that we can update a shared mask
+        mask.set_value(np.array([False]))
+
+        assert tuple(obs_rv.shape.eval()) == (1, 4)
+        assert tuple(unobs_rv.shape.eval()) == (0, 4)
+
+        new_expected_logp = pm.logp(rv, obs_data).eval()
+        assert not np.isclose(expected_logp, new_expected_logp)  # Otherwise test is weak
+        obs_logp, unobs_logp = logp_fn()
+        np.testing.assert_almost_equal(obs_logp, new_expected_logp)
+        np.testing.assert_array_equal(unobs_logp, [])
+
+    def test_multivariate_shared_mask_unseparable(self):
+        # Even if the mask is initially not mixing support dims,
+        # it could later be changed in a way that does!
+        mask = shared(np.array([[True, True, True, True]]))
+        obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
+        unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
+
+        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
+        (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
+
+        # Test types
+        # Multivariate RVs with shared masks on the last component are always unseparable.
+        assert isinstance(obs_rv.owner.op, PartialObservedRV)
+        assert isinstance(unobs_rv.owner.op, PartialObservedRV)
+
+        # Test shapes
+        assert tuple(obs_rv.shape.eval()) == (0,)
+        assert tuple(unobs_rv.shape.eval()) == (4,)
+        assert tuple(joined_rv.shape.eval()) == (1, 4)
+
+        # Test logp
+        logp = conditional_logp(
+            {
+                obs_rv: pt.as_tensor(obs_data)[obs_mask],
+                unobs_rv: pt.as_tensor(unobs_data)[unobs_mask],
+            }
+        )
+        logp_fn = pytensor.function([], list(logp.values()))
+        obs_logp, unobs_logp = logp_fn()
+        # For non-separable cases the logp always shows up in the observed variable,
+        # even in this case where all entries come from an unobserved component
+        expected_logp = pm.logp(rv, unobs_data).eval()
+        np.testing.assert_almost_equal(obs_logp, expected_logp)
+        np.testing.assert_array_equal(unobs_logp, [])
+
+        # Test that we can update a shared mask
+        mask.set_value(np.array([[False, False, True, True]]))
+
+        assert tuple(obs_rv.shape.eval()) == (2,)
+        assert tuple(unobs_rv.shape.eval()) == (2,)
+
+        new_expected_logp = pm.logp(rv, [0.1, 0.4, 0.4, 0.1]).eval()
+        assert not np.isclose(expected_logp, new_expected_logp)  # Otherwise test is weak
+        obs_logp, unobs_logp = logp_fn()
+        np.testing.assert_almost_equal(obs_logp, new_expected_logp)
+        np.testing.assert_array_equal(unobs_logp, [])
+
+    def test_moment(self):
+        x = pm.GaussianRandomWalk.dist(init_dist=pm.Normal.dist(-5), mu=1, steps=9)
+        ref_moment = moment(x).eval()
+        assert not np.allclose(ref_moment[::2], ref_moment[1::2])  # Otherwise test is weak
+
+        (obs_x, _), (unobs_x, _), _ = create_partial_observed_rv(
+            x, mask=np.array([False, True] * 5)
+        )
+        np.testing.assert_allclose(moment(obs_x).eval(), ref_moment[::2])
+        np.testing.assert_allclose(moment(unobs_x).eval(), ref_moment[1::2])
diff --git a/tests/test_model.py b/tests/test_model.py
index 10a509ae094..9306b129022 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -41,9 +41,10 @@
 from pymc import Deterministic, Potential
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.distributions import Normal, transforms
-from pymc.distributions.transforms import log
+from pymc.distributions.distribution import PartialObservedRV
+from pymc.distributions.transforms import log, simplex
 from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
-from pymc.logprob.basic import transformed_conditional_logp
+from pymc.logprob.basic import conditional_logp, transformed_conditional_logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Point, ValueGradFunction, modelcontext
 from pymc.testing import SeededTest
@@ -1398,32 +1399,43 @@ def test_missing_logp2(self):
 
         assert m_logp == m_missing_logp
 
-    def test_missing_multivariate(self):
-        """Test model with missing variables whose transform changes base shape still works"""
+    def test_missing_multivariate_separable(self):
+        with pm.Model() as m_miss:
+            with pytest.warns(ImputationWarning):
+                x = pm.Dirichlet(
+                    "x",
+                    a=[1, 2, 3],
+                    observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]]),
+                )
+        assert isinstance(m_miss["x_missing"].owner.op, pm.Dirichlet)
+        assert isinstance(m_miss["x_observed"].owner.op, pm.Dirichlet)
+        with pm.Model() as m_unobs:
+            x = pm.Dirichlet("x", a=[1, 2, 3], shape=(1, 3))
+
+        inp_vals = simplex.forward(np.array([[0.3, 0.3, 0.4]])).eval()
+        np.testing.assert_allclose(
+            m_miss.compile_logp(jacobian=False)({"x_missing_simplex__": inp_vals}),
+            m_unobs.compile_logp(jacobian=False)({"x_simplex__": inp_vals}) * 2,
+        )
+
+    def test_missing_multivariate_unseparable(self):
         with pm.Model() as m_miss:
-            with pytest.raises(
-                NotImplementedError,
-                match="Automatic inputation is only supported for univariate RandomVariables",
-            ):
-                with pytest.warns(ImputationWarning):
-                    x = pm.Dirichlet(
-                        "x",
-                        a=[1, 2, 3],
-                        observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]]),
-                    )
-
-        # TODO: Test can be used when local_subtensor_rv_lift supports multivariate distributions
-        # from pymc.distributions.transforms import simplex
-        #
-        # with pm.Model() as m_unobs:
-        #     x = pm.Dirichlet("x", a=[1, 2, 3])
-        #
-        # inp_vals = simplex.forward(np.array([0.3, 0.3, 0.4])).eval()
-        # assert np.isclose(
-        #     m_miss.compile_logp()({"x_missing_simplex__": inp_vals}),
-        #     m_unobs.compile_logp(jacobian=False)({"x_simplex__": inp_vals}) * 2,
-        # )
+            with pytest.warns(ImputationWarning):
+                x = pm.Dirichlet(
+                    "x",
+                    a=[1, 2, 3],
+                    observed=np.array([[0.3, 0.3, np.nan], [np.nan, np.nan, 0.4]]),
+                )
+
+        assert isinstance(m_miss["x_missing"].owner.op, PartialObservedRV)
+        assert isinstance(m_miss["x_observed"].owner.op, PartialObservedRV)
+
+        inp_values = np.array([0.3, 0.3, 0.4])
+        np.testing.assert_allclose(
+            m_miss.compile_logp()({"x_missing": [0.4, 0.3, 0.3]}),
+            st.dirichlet.logpdf(inp_values, [1, 2, 3]) * 2,
+        )
 
     def test_missing_vector_parameter(self):
         with pm.Model() as m:
@@ -1482,11 +1494,10 @@ def test_dims(self):
             x = pm.Normal("x", observed=data, dims=("observed",))
         assert model.named_vars_to_dims == {"x": ("observed",)}
 
-    def test_error_non_random_variable(self):
+    def test_symbolic_random_variable(self):
         data = np.array([np.nan] * 3 + [0] * 7)
         with pm.Model() as model:
-            msg = "x of type is not supported"
-            with pytest.raises(NotImplementedError, match=msg):
+            with pytest.warns(ImputationWarning):
                 x = pm.Censored(
                     "x",
                     pm.Normal.dist(),
@@ -1494,6 +1505,10 @@
                 upper=10,
                 observed=data,
             )
+            np.testing.assert_almost_equal(
+                model.compile_logp()({"x_missing": [0] * 3}),
+                st.norm.logcdf(0) * 10,
+            )
 
 
 class TestShared(SeededTest):
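
Illustration (not part of the patch): a minimal sketch of what the new `create_partial_observed_rv` helper returns, based on its docstring and the `test_univariate` case above. The helper is internal API, and the printed shapes are simply the expected values from that test.

```python
import numpy as np
import pymc as pm
from pymc.distributions.distribution import create_partial_observed_rv

rv = pm.Normal.dist([1, 2, 3])
mask = np.array([False, False, True])  # True marks entries that are *not* observed

(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# For univariate RVs the mask indexing is lifted through the RandomVariable,
# yielding two independent Normal RVs plus a symbolic join of both components.
print(tuple(obs_rv.shape.eval()))     # (2,) -- observed component
print(tuple(unobs_rv.shape.eval()))   # (1,) -- unobserved component
print(tuple(joined_rv.shape.eval()))  # (3,) -- join of observed and unobserved
```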
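
Likewise, a rough sketch of the user-facing effect of the `make_obs_var` changes, mirroring `test_missing_multivariate_separable`: passing multivariate data with missing entries to `observed` now triggers auto-imputation instead of raising `NotImplementedError`.

```python
import numpy as np
import pymc as pm

data = np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]])
with pm.Model() as m:
    # Emits an ImputationWarning and splits the variable into observed/missing parts.
    x = pm.Dirichlet("x", a=[1, 2, 3], observed=data)

# The model registers an observed RV, a free "missing" RV, and a Deterministic join.
print("x_observed" in m.named_vars, "x_missing" in m.named_vars, "x" in m.named_vars)
```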