diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index d9eb85a838..6eed6a1b35 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -50,13 +50,20 @@ from pytensor.scan.rewriting import scan_eqopt1, scan_eqopt2 from pytensor.scan.utils import ScanArgs from pytensor.tensor.random.type import RandomType +from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor from pytensor.tensor.var import TensorVariable from pytensor.updates import OrderedUpdates -from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob +from pymc.logprob.abstract import ( + MeasurableVariable, + _get_measurable_outputs, + _logprob, + get_measurable_outputs, +) from pymc.logprob.joint_logprob import factorized_joint_logprob from pymc.logprob.rewriting import ( + PreserveRVMappings, inc_subtensor_ops, logprob_rewrites_db, measurable_ir_rewrites_db, @@ -66,6 +73,9 @@ class MeasurableScan(Scan): """A placeholder used to specify a log-likelihood for a scan sub-graph.""" + def __str__(self): + return f"Measurable({super().__str__()})" + MeasurableVariable.register(MeasurableScan) @@ -359,6 +369,12 @@ def find_measurable_scans(fgraph, node): ) for n in local_fgraph_topo: if isinstance(n.op, MeasurableVariable): + measurable_outputs = get_measurable_outputs(n.op, n) + # This variable's source of measure is used by another inner node, + # So we don't need it to be an output! + if not measurable_outputs: + continue + non_output_node_clients = [ c for c in clients[n] if c not in curr_scanargs.inner_outputs ] @@ -494,6 +510,10 @@ def add_opts_to_inner_graphs(fgraph, node): clone=True, copy_inputs=False, copy_orphans=False, + features=[ + ShapeFeature(), + PreserveRVMappings({}), + ], ) logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])).rewrite(inner_fgraph) diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 0b4577b1a9..551bb51471 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -42,9 +42,10 @@ from pytensor import Mode from pytensor.raise_op import assert_op from pytensor.scan.utils import ScanArgs +from scipy import stats from pymc.logprob.abstract import logprob -from pymc.logprob.joint_logprob import factorized_joint_logprob +from pymc.logprob.joint_logprob import factorized_joint_logprob, logp from pymc.logprob.scan import ( construct_scan, convert_outer_out_to_in, @@ -458,3 +459,22 @@ def test_mode_is_kept(remove_asserts): else: with pytest.raises(AssertionError): x_logp(x=x_test_val) + + +def test_scan_non_pure_rv_output(): + grw, _ = pytensor.scan( + fn=lambda xtm1: at.random.normal() + xtm1, + outputs_info=[at.zeros(())], + n_steps=10, + name="grw", + ) + + grw_vv = grw.clone() + grw_logp = logp(grw, grw_vv) + assert_no_rvs(grw_logp) + + grw_vv_test = np.arange(10) + 1 + np.testing.assert_array_almost_equal( + grw_logp.eval({grw_vv: grw_vv_test}), + stats.norm.logpdf(np.ones(10)), + )