Skip to content

Commit

Permalink
Add rewrite for sum of normal RVs
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama committed Mar 10, 2023
1 parent 1f341ce commit cb2bf14
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 9 deletions.
1 change: 1 addition & 0 deletions aeppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# isort: off
# Add rewrites to the DBs
import aeppl.censoring
import aeppl.convolutions
import aeppl.cumsum
import aeppl.mixture
import aeppl.scan
Expand Down
53 changes: 53 additions & 0 deletions aeppl/convolutions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import aesara
import aesara.tensor as at
from aesara.graph.rewriting.basic import EquilibriumGraphRewriter, node_rewriter
from aesara.scalar.basic import Add
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.basic import NormalRV, normal

from aeppl.rewriting import logprob_rewrites_db


@node_rewriter((Elemwise,))
def add_independent_normals(fgraph, node):
if not isinstance(node.op.scalar_op, Add):
return None

X_rv, Y_rv = node.inputs

if not (X_rv.owner and Y_rv.owner) or not (
isinstance(X_rv.owner.op, NormalRV) and isinstance(Y_rv.owner.op, NormalRV)
):
return None

old_rv = node.outputs[0]

mu_x, sigma_x, mu_y, sigma_y, _ = at.broadcast_arrays(
*(X_rv.owner.inputs[-2:] + Y_rv.owner.inputs[-2:] + [old_rv])
)

rng = X_rv.owner.inputs[0]

new_node = normal.make_node(
rng, # temporary rng?
old_rv.shape,
old_rv.dtype,
mu_x + mu_y,
at.sqrt(sigma_x**2 + sigma_y**2),
)
new_normal_rv = new_node.default_output()

if old_rv.name:
new_normal_rv.name = old_rv.name

return [new_normal_rv]


logprob_rewrites_db.register(
"add_independent_normals",
EquilibriumGraphRewriter(
[add_independent_normals],
max_use_ratio=aesara.config.optdb__max_use_ratio,
),
"basic",
)
96 changes: 96 additions & 0 deletions tests/test_convolutions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import aesara.tensor as at
import numpy as np
import pytest
from aesara.tensor.random.basic import NormalRV

from aeppl.rewriting import construct_ir_fgraph


@pytest.mark.parametrize(
"mu_x, mu_y, x_shape, y_shape, new_rv_mu",
[
(
np.array([1, 10, 100]),
np.array(2),
None,
None,
np.array([3, 12, 102]),
),
(np.array([1, 10, 100]), np.array(2), (), (), np.array([3, 12, 102])),
(
np.array([1, 10, 100]),
np.array(2),
(5, 3),
(),
np.broadcast_to(np.array([3, 12, 102]), (5, 3)),
),
(
np.array([[1, 10, 100]]),
np.array([[0.2], [2], [20], [200], [2000]]),
None,
None,
np.array(
[
[1.2, 10.2, 100.2],
[3, 12, 102],
[21, 30, 120],
[201, 210, 300],
[2001, 2010, 2100],
]
),
),
(
np.broadcast_to(np.array([1, 10, 100]), (5, 3)),
np.array([2, 20, 200]),
(2, 5, 3),
None,
np.broadcast_to(np.array([3, 30, 300]), (2, 5, 3)),
),
(
np.array([[1, 10, 100]]),
np.array([[0.2], [2], [20], [200], [2000]]),
(2, 5, 3),
None,
np.broadcast_to(
np.array(
[
[1.2, 10.2, 100.2],
[3, 12, 102],
[21, 30, 120],
[201, 210, 300],
[2001, 2010, 2100],
]
),
(2, 5, 3),
),
),
(
np.array(1),
np.array(6),
(5, 1),
(1,),
np.full((5, 3), 7),
),
],
)
def test_add_independent_normals(mu_x, mu_y, x_shape, y_shape, new_rv_mu):
srng = at.random.RandomStream(29833)

X_rv = srng.normal(mu_x, 0.03, size=x_shape)
X_rv.name = "X"

Y_rv = srng.normal(mu_y, 0.04, size=y_shape)
Y_rv.name = "Y"

Z_rv = X_rv + Y_rv
Z_rv.name = "Z"
z_vv = Z_rv.clone()

fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv})

new_rv = fgraph.outputs[0].owner.inputs[0]

assert isinstance(new_rv.owner.op, NormalRV)
assert np.allclose(new_rv.owner.inputs[3].eval(), new_rv_mu)
assert np.allclose(new_rv.owner.inputs[4].eval(), 0.05)
assert new_rv.name == "Z"
9 changes: 0 additions & 9 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,15 +649,6 @@ def test_transformed_rv_and_value():
)


def test_loc_transform_multiple_rvs_fails1():
x_rv1 = at.random.normal(name="x_rv1")
x_rv2 = at.random.normal(name="x_rv2")
y_rv = x_rv1 + x_rv2

with pytest.raises(DensityNotFound):
joint_logprob(y_rv)


def test_nested_loc_transform_multiple_rvs_fails2():
x_rv1 = at.random.normal(name="x_rv1")
x_rv2 = at.cos(at.random.normal(name="x_rv2"))
Expand Down

0 comments on commit cb2bf14

Please sign in to comment.