Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrite for sum of normal RVs #239

Merged
merged 1 commit into from
Apr 19, 2023

Conversation

larryshamalama
Copy link
Contributor

@larryshamalama larryshamalama commented Mar 7, 2023

Closes 1/3 of #238, for now. As I'm opening this PR, the rewrite works for scalar-valued normal RVs:

import aesara.tensor as at
from aeppl.rewriting import construct_ir_fgraph

import aesara

srng = at.random.RandomStream(0)
X_rv = srng.normal(1, 3)
Y_rv = srng.normal(2, 4)
Z_rv = X_rv + Y_rv

fgraph, _, _ = construct_ir_fgraph({Z_rv: Z_rv.copy()})

aesara.dprint(fgraph.outputs[0])
# (below is subject to change...)
# ValuedVariable [id A]
#  |normal_rv{0, (0, 0), floatX, False}.1 [id B]
#  | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x12EB2D7E0>) [id C]
#  | |TensorConstant{[]} [id D]
#  | |TensorConstant{11} [id E]
#  | |Elemwise{add,no_inplace} [id F]
#  | | |TensorConstant{1} [id G]
#  | | |TensorConstant{2} [id H]
#  | |Elemwise{sqrt,no_inplace} [id I]
#  |   |Elemwise{add,no_inplace} [id J]
#  |     |Elemwise{pow,no_inplace} [id K]
#  |     | |TensorConstant{3} [id L]
#  |     | |TensorConstant{2} [id M]
#  |     |Elemwise{pow,no_inplace} [id N]
#  |       |TensorConstant{4} [id O]
#  |       |TensorConstant{2} [id P]
#  |normal_rv{0, (0, 0), floatX, False}.1 [id B]

I'm happy to receive comments and then address the list below. The main to-do is to extend to matrix-valued normal RVs.

Some questions:

  • Checking independence in AePPL. I believe that this identity would hold for conditional independence. I would have to think about this.
  • What to put for the rng argument in the make_node? For now, I left it as the rng of the first RV input, but this is incorrect.
  • I added the rewrite as an EquilibriumGraphRewriter. Also not sure about this.
  • I added this rewrite in a new file, math_stat.py. Would there be a better name for this file? Or should this be added to an existing file?

Happy to hear any comments!

@larryshamalama larryshamalama marked this pull request as draft March 7, 2023 00:03
@larryshamalama larryshamalama force-pushed the conv-norm-rv branch 2 times, most recently from 590044e to c8ceb6d Compare March 7, 2023 12:15
@codecov
Copy link

codecov bot commented Mar 7, 2023

Codecov Report

Patch coverage: 95.65% and project coverage change: -0.01 ⚠️

Comparison is base (906c10d) 95.76% compared to head (325afb8) 95.76%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #239      +/-   ##
==========================================
- Coverage   95.76%   95.76%   -0.01%     
==========================================
  Files          12       13       +1     
  Lines        2006     2029      +23     
  Branches      243      246       +3     
==========================================
+ Hits         1921     1943      +22     
  Misses         46       46              
- Partials       39       40       +1     
Impacted Files Coverage Δ
aeppl/convolutions.py 95.65% <95.65%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

@brandonwillard brandonwillard added enhancement New feature or request graph rewriting Involves the implementation of rewrites to Aesara graphs rv-transforms Involves transforms applied to random variables labels Mar 7, 2023
aeppl/math_stat.py Outdated Show resolved Hide resolved
Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good!

I've added some comments regarding the handling of RNG objects, but that's a more general concern that we're trying to address in all of our rewriting work, and not a statement about this specific approach/implementation.

mu_y, sigma_y = Y_rv.owner.inputs[-2:]

new_node = normal.make_node(
X_rv.owner.inputs[0], # temporary rng?
Copy link
Member

@brandonwillard brandonwillard Mar 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a tough one we've been trying to deal with more generally. The real complications come from any "updates" that are associated with the RNG objects.

More specifically, RNG objects are usually SharedVariables, and those can have SharedVariable.default_update attributes that hold onto other Variables (i.e. graphs representing the new value of the SharedVariable after each call to a compiled Aesara function with this update). In the case of RNG objects created by RandomStreams, those default updates are the RNG objects output after drawing a sample from a RandomVariable node. In other words, the updates mechanism replaces a shared RNG object with a copy of the RNG object after a sample has been drawn using it, so the Aesara updates mechanism is emulating an in-place update of the RNG (e.g. just as rng.normal() automatically updates the internal state of rng in NumPy).

When performing replacements like this, it's possible that rng = X_rv.owner.inputs[0] will have a rng.default_update containing the original X_rv graph, and, if someone attempted to aesara.compile the resulting new_node, Aesara would pick up that old graph from the re-used RNG's default_update attribute and add it to the compiled results. We definitely don't want to have to sample some unrelated graphs just to update the RNG objects, especially when those update graphs aren't consistent with the underlying sampling process.

Anyway, I'm pointing this out because it's a general design issue and usability complication of which I'm trying to make more people aware—mostly so we can fix it/improve the usability.

In this exact case, we can probably just clone the RNG SharedVariables and add our own default_updates to those (e.g. similar to how RandomStream.gen does). Also, the graphs produced here are only supposed to be used as an intermediate representation for obtaining log-probabilities, so, as long as we don't expect people to actually compile and sample these graphs, the default updates shouldn't matter, and we can probably just clone the RNGs and remove their default updates altogether.

Copy link
Contributor Author

@larryshamalama larryshamalama Mar 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, there's a lot of info here and I have the general picture. Thanks for the explanation.

When you say clone the RNG SharedVariable, do you mean:

new_rng = X_rv.owner.inputs[0].clone()
# something more needs to be done to `new_rng...

new_node = normal.make_node(rng, *other_inputs)
new_node.inputs[0].default_update = new_node

Something is certainly off because x_rv.owner.inputs[0].default_update == x_rv yields False.

I'm also confused because it feels like there are two RNGs here, one from X_rv and one from Y_rv... Or should I create a new RNG object akin to what's being done in RandomStream.gen?

Copy link
Member

@brandonwillard brandonwillard Mar 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you say clone the RNG SharedVariable, do you mean:

Yeah, with one important difference, though:

new_rng = X_rv.owner.inputs[0].clone()
# something more needs to be done to `new_rng...

new_node = normal.make_node(rng, *other_inputs)
new_rng.default_update = new_node.outputs[0]

In other words, it's the new_rng that needs to be updated with the values of the RNGs output by new_node.

I'm also confused because it feels like there are two RNGs here, one from X_rv and one from Y_rv

Which RNG we use is really just a choice. The only thing we need to consider is the user-level seeding, and, as long as we choose an RNG from one of the existing RandomVariables, we should maintain some consistency with the seeding. If we use RandomStream.gen, then we can't generate a connection with the user's seeding unless we have their RandomStream instance.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, because this whole RandomVariable + SharedVariable.default_update situation is such a mess, I've created this: aesara-devs/aesara#1478.

tests/test_math_stat.py Outdated Show resolved Hide resolved
tests/test_math_stat.py Outdated Show resolved Hide resolved
tests/test_math_stat.py Outdated Show resolved Hide resolved
aeppl/__init__.py Outdated Show resolved Hide resolved
@brandonwillard
Copy link
Member

brandonwillard commented Mar 15, 2023

This is an elaboration of #239 (comment).

  • What to put for the rng argument in the make_node? For now, I left it as the rng of the first RV input, but this is incorrect.

That's a very important question! Those inputs are usually SharedVariables with non-None SharedVariable.default_update values. More specifically, those values are the RNG objects output by a RandomVariable Op—i.e. they're the updated RNGs obtained after sampling a random variable.

Altogether, the SharedVariable.default_update value of an RNG is a Variable representing a graph that draws samples, and it's intended to serve as a way of specifying something like an in-place update on the RNG state.

Here's an illustration:

import aesara
import aesara.tensor as at


srng = at.random.RandomStream(23092)

X_rv = srng.normal(0, 1, name="X")

aesara.dprint(X_rv, print_default_updates=True)

# normal_rv{0, (0, 0), floatX, False}.1 [id A] 'X'
#  |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F44C55C1120>) [id B] <- [id A]
#  |TensorConstant{[]} [id C]
#  |TensorConstant{11} [id D]
#  |TensorConstant{0} [id E]
#  |TensorConstant{1} [id F]
#
# Default updates:
#
# normal_rv{0, (0, 0), floatX, False}.0 [id A]
#  |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F44C55C1120>) [id B] <- [id A]
#  |TensorConstant{[]} [id C]
#  |TensorConstant{11} [id D]
#  |TensorConstant{0} [id E]
#  |TensorConstant{1} [id F]

Notice the .1 at the end of the first normal_rv line; that indicates that the second output (i.e. the sample value) is returned from the RandomVariable node given by id A. The first input to that node is a SharedVariable RNG object (i.e. id B), and its debug print line ends in [id B] <- [id A], which indicates that the SharedVariable is updated by one of the outputs of id A. The default update graph printed at the end shows that the first output (i.e. the .0) of the same id A node is the "update value".

When this graph is aesara.function-compiled, the SharedVariable.default_update graph will be discovered and added as a new output to the entire compiled graph, then, when the compiled function is evaluated, the Linker that executes the code will replace the value in the SharedVariable with the value of that new update output that was added (i.e. the RNG value after sampling).

In pure Python, this whole situation is similar to the following:

from copy import copy
import numpy as np


rng = np.random.default_rng(2309)


def draw_normal(rng):
    new_rng = copy(rng)
    res = new_rng.normal(0, 1)
    return new_rng, res


# Replace the old `rng` with the returned RNG
rng, res = draw_normal(rng)

Now, problems arise when a SharedVariable RNG with a non-None .default_update is used as an input to a different RandomVariable Op, because then the old sampling graph will show up in the compiled results.

This is what happens when we re-use an RNG with default updates for another graph:

# The RNG from `X`
rng_X = X_rv.owner.inputs[0]

# Manually create a `RandomVariable` with a specific RNG
Y_rv = at.random.gamma(0.5, 0.5, name="Y", rng=rng_X)

# Compile the function
Y_rv_fn = aesara.function([], Y_rv)

# View the compiled graph
aesara.dprint(Y_rv_fn)

# gamma_rv{0, (0, 0), floatX, True}.1 [id A] 'Y' 1
#  |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F44C55C1120>) [id B]
#  |TensorConstant{[]} [id C]
#  |TensorConstant{11} [id D]
#  |TensorConstant{0.5} [id E]
#  |TensorConstant{2.0} [id F]
# normal_rv{0, (0, 0), floatX, False}.0 [id G] 0
#  |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F44C55C1120>) [id B]
#  |TensorConstant{[]} [id C]
#  |TensorConstant{11} [id D]
#  |TensorConstant{0} [id H]
#  |TensorConstant{1} [id I]

As you can see, the compiled graph is sampling a gamma and normal variable, but the normal sampling is only done in order to update the RNG SharedVariable used by the gamma!

If we want to reuse existing RNG SharedVariable objects we need to get rid of the default updates on them. Ultimately, this is a big design complication, because it adds "state" to RandomVariables and—more specifically—ties RNG objects to specific graphs. Some related issues/discussions: aesara-devs/aesara#898, aesara-devs/aesara#738, aesara-devs/aesara#543, aesara-devs/aesara#454, aesara-devs/aesara#1251.

In AePPL, we generally don't sample the IR we produce, so this isn't an immediate problem, but it could easily creep into graphs somewhere down the line and cause real issues.

As far as this rewrite is concerned, we're probably fine copying the SharedVariable and (re)specifying the .default_update graph.

Copy link
Contributor Author

@larryshamalama larryshamalama left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @brandonwillard for the thorough reply. It took me a while to have an idea of the problem that you are addressing. @rlouf I incorporated your suggestions to the unit tests

As far as this rewrite is concerned, we're probably fine copying the SharedVariable and (re)specifying the .default_update graph.

The only thing that I have yet to fully address is the RNG of the newly created Normal RV node. I gave some details as a comment in this code review

aeppl/convolutions.py Outdated Show resolved Hide resolved
@larryshamalama larryshamalama marked this pull request as ready for review April 9, 2023 01:40
@larryshamalama
Copy link
Contributor Author

@brandonwillard Is this PR close to the end? Should I get started on the remaining rewrites in #238 and build on top of this PR?

Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brandonwillard Is this PR close to the end? Should I get started on the remaining rewrites in #238 and build on top of this PR?

As far as the rewrite is concerned, it looks good. We can merge it as-is, but I don't think it will be used because of the rewrite ordering and existing transform-based approach.

We'll need to change the rewrite DB setup in order to allow rewrites like this (i.e. ones that provide more/better distribution information) to come first in place of the generic transforms approach. That concern is independent of this work, though.

aeppl/convolutions.py Outdated Show resolved Hide resolved
@brandonwillard
Copy link
Member

OK, I'm about to push a change that adds support for subtraction and changes the DB to which the rewrite is registered. I'll merge after that.

@brandonwillard
Copy link
Member

This should be good to merge once the tests pass.

Thanks again, @larryshamalama; this was a great addition!

@brandonwillard brandonwillard disabled auto-merge April 19, 2023 00:15
@brandonwillard brandonwillard merged commit 93c83e4 into aesara-devs:main Apr 19, 2023
@larryshamalama larryshamalama deleted the conv-norm-rv branch May 4, 2023 20:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request graph rewriting Involves the implementation of rewrites to Aesara graphs rv-transforms Involves transforms applied to random variables
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants