Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Scan logprob inference of non-pure RandomVariable outputs #6578

Merged
merged 1 commit into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,20 @@
from pytensor.scan.rewriting import scan_eqopt1, scan_eqopt2
from pytensor.scan.utils import ScanArgs
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
from pytensor.tensor.var import TensorVariable
from pytensor.updates import OrderedUpdates

from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
from pymc.logprob.abstract import (
MeasurableVariable,
_get_measurable_outputs,
_logprob,
get_measurable_outputs,
)
from pymc.logprob.joint_logprob import factorized_joint_logprob
from pymc.logprob.rewriting import (
PreserveRVMappings,
inc_subtensor_ops,
logprob_rewrites_db,
measurable_ir_rewrites_db,
Expand All @@ -66,6 +73,9 @@
class MeasurableScan(Scan):
"""A placeholder used to specify a log-likelihood for a scan sub-graph."""

def __str__(self):
return f"Measurable({super().__str__()})"


MeasurableVariable.register(MeasurableScan)

Expand Down Expand Up @@ -359,6 +369,12 @@ def find_measurable_scans(fgraph, node):
)
for n in local_fgraph_topo:
if isinstance(n.op, MeasurableVariable):
measurable_outputs = get_measurable_outputs(n.op, n)
# This variable's source of measure is used by another inner node,
# So we don't need it to be an output!
if not measurable_outputs:
continue

non_output_node_clients = [
c for c in clients[n] if c not in curr_scanargs.inner_outputs
]
Expand Down Expand Up @@ -494,6 +510,10 @@ def add_opts_to_inner_graphs(fgraph, node):
clone=True,
copy_inputs=False,
copy_orphans=False,
features=[
ShapeFeature(),
PreserveRVMappings({}),
],
)

logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])).rewrite(inner_fgraph)
Expand Down
22 changes: 21 additions & 1 deletion tests/logprob/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@
from pytensor import Mode
from pytensor.raise_op import assert_op
from pytensor.scan.utils import ScanArgs
from scipy import stats

from pymc.logprob.abstract import logprob
from pymc.logprob.joint_logprob import factorized_joint_logprob
from pymc.logprob.joint_logprob import factorized_joint_logprob, logp
from pymc.logprob.scan import (
construct_scan,
convert_outer_out_to_in,
Expand Down Expand Up @@ -458,3 +459,22 @@ def test_mode_is_kept(remove_asserts):
else:
with pytest.raises(AssertionError):
x_logp(x=x_test_val)


def test_scan_non_pure_rv_output():
grw, _ = pytensor.scan(
fn=lambda xtm1: at.random.normal() + xtm1,
outputs_info=[at.zeros(())],
n_steps=10,
name="grw",
)

grw_vv = grw.clone()
grw_logp = logp(grw, grw_vv)
assert_no_rvs(grw_logp)

grw_vv_test = np.arange(10) + 1
np.testing.assert_array_almost_equal(
grw_logp.eval({grw_vv: grw_vv_test}),
stats.norm.logpdf(np.ones(10)),
)