diff --git a/aeppl/__init__.py b/aeppl/__init__.py index e0b39c5c..17ccf61f 100644 --- a/aeppl/__init__.py +++ b/aeppl/__init__.py @@ -12,6 +12,7 @@ # isort: off # Add rewrites to the DBs import aeppl.censoring +import aeppl.convolutions import aeppl.cumsum import aeppl.mixture import aeppl.scan diff --git a/aeppl/convolutions.py b/aeppl/convolutions.py new file mode 100644 index 00000000..2d66f660 --- /dev/null +++ b/aeppl/convolutions.py @@ -0,0 +1,53 @@ +import aesara +import aesara.tensor as at +from aesara.graph.rewriting.basic import EquilibriumGraphRewriter, node_rewriter +from aesara.scalar.basic import Add +from aesara.tensor.elemwise import Elemwise +from aesara.tensor.random.basic import NormalRV, normal + +from aeppl.rewriting import logprob_rewrites_db + + +@node_rewriter((Elemwise,)) +def add_independent_normals(fgraph, node): + if not isinstance(node.op.scalar_op, Add): + return None + + X_rv, Y_rv = node.inputs + + if not (X_rv.owner and Y_rv.owner) or not ( + isinstance(X_rv.owner.op, NormalRV) and isinstance(Y_rv.owner.op, NormalRV) + ): + return None + + old_rv = node.outputs[0] + + mu_x, sigma_x, mu_y, sigma_y, _ = at.broadcast_arrays( + *(X_rv.owner.inputs[-2:] + Y_rv.owner.inputs[-2:] + [old_rv]) + ) + + rng = X_rv.owner.inputs[0] + + new_node = normal.make_node( + rng, # temporary rng? + old_rv.shape, + old_rv.dtype, + mu_x + mu_y, + at.sqrt(sigma_x**2 + sigma_y**2), + ) + new_normal_rv = new_node.default_output() + + if old_rv.name: + new_normal_rv.name = old_rv.name + + return [new_normal_rv] + + +logprob_rewrites_db.register( + "add_independent_normals", + EquilibriumGraphRewriter( + [add_independent_normals], + max_use_ratio=aesara.config.optdb__max_use_ratio, + ), + "basic", +) diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py new file mode 100644 index 00000000..45b3266d --- /dev/null +++ b/tests/test_convolutions.py @@ -0,0 +1,96 @@ +import aesara.tensor as at +import numpy as np +import pytest +from aesara.tensor.random.basic import NormalRV + +from aeppl.rewriting import construct_ir_fgraph + + +@pytest.mark.parametrize( + "mu_x, mu_y, x_shape, y_shape, new_rv_mu", + [ + ( + np.array([1, 10, 100]), + np.array(2), + None, + None, + np.array([3, 12, 102]), + ), + (np.array([1, 10, 100]), np.array(2), (), (), np.array([3, 12, 102])), + ( + np.array([1, 10, 100]), + np.array(2), + (5, 3), + (), + np.broadcast_to(np.array([3, 12, 102]), (5, 3)), + ), + ( + np.array([[1, 10, 100]]), + np.array([[0.2], [2], [20], [200], [2000]]), + None, + None, + np.array( + [ + [1.2, 10.2, 100.2], + [3, 12, 102], + [21, 30, 120], + [201, 210, 300], + [2001, 2010, 2100], + ] + ), + ), + ( + np.broadcast_to(np.array([1, 10, 100]), (5, 3)), + np.array([2, 20, 200]), + (2, 5, 3), + None, + np.broadcast_to(np.array([3, 30, 300]), (2, 5, 3)), + ), + ( + np.array([[1, 10, 100]]), + np.array([[0.2], [2], [20], [200], [2000]]), + (2, 5, 3), + None, + np.broadcast_to( + np.array( + [ + [1.2, 10.2, 100.2], + [3, 12, 102], + [21, 30, 120], + [201, 210, 300], + [2001, 2010, 2100], + ] + ), + (2, 5, 3), + ), + ), + ( + np.array(1), + np.array(6), + (5, 1), + (1,), + np.full((5, 3), 7), + ), + ], +) +def test_add_independent_normals(mu_x, mu_y, x_shape, y_shape, new_rv_mu): + srng = at.random.RandomStream(29833) + + X_rv = srng.normal(mu_x, 0.03, size=x_shape) + X_rv.name = "X" + + Y_rv = srng.normal(mu_y, 0.04, size=y_shape) + Y_rv.name = "Y" + + Z_rv = X_rv + Y_rv + Z_rv.name = "Z" + z_vv = Z_rv.clone() + + fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv}) + + new_rv = fgraph.outputs[0].owner.inputs[0] + + assert isinstance(new_rv.owner.op, NormalRV) + assert np.allclose(new_rv.owner.inputs[3].eval(), new_rv_mu) + assert np.allclose(new_rv.owner.inputs[4].eval(), 0.05) + assert new_rv.name == "Z" diff --git a/tests/test_transforms.py b/tests/test_transforms.py index a7ff1fcd..650dd269 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -649,15 +649,6 @@ def test_transformed_rv_and_value(): ) -def test_loc_transform_multiple_rvs_fails1(): - x_rv1 = at.random.normal(name="x_rv1") - x_rv2 = at.random.normal(name="x_rv2") - y_rv = x_rv1 + x_rv2 - - with pytest.raises(DensityNotFound): - joint_logprob(y_rv) - - def test_nested_loc_transform_multiple_rvs_fails2(): x_rv1 = at.random.normal(name="x_rv1") x_rv2 = at.cos(at.random.normal(name="x_rv2"))