I think the multiple-output/default-output issue should be treated separately from the seeding issue. It doesn't make sense to refactor the way seeding works just because we hide one of the outputs by default and that causes problems (we could always decide not to hide it).
I believe this was mentioned before, but the basic interface and functionality of … Unless that's not completely true and I've missed something important, this proposal is almost entirely about adopting JAX's high-level interface and refactoring … Speaking of copying the RNG state: one of the reasons …
@rlouf How would seeding updates happen across function calls (or within Scan iterations) with your approach? Edit: Also, thanks for the detailed write-up!
Following up after a while... This redesign is definitely worthwhile as a means of decoupling RNG state evolution from the process of drawing samples from RNG states, which is valuable at an interface/representation level, so I'm willing to go forward with it for that reason alone.
Summary
I propose a new approach to represent `RandomVariable`s and PRNG states in Aesara's IR, based on the design of splittable PRNGs. The representation introduces minimal changes to the existing `RandomVariable` interface while being more expressive. It should be easy to transpile to Aesara's current compilation targets, and it is compatible with the higher-level `RandomStream` interface.

References
On counter-based PRNGs:
On splittable PRNG design for functional programs:
On the NumPy side:
Desiderata
A good PRNG design satisfies the following conditions:
Motivation behind this proposal
This proposal is motivated by two issues that illustrate the shortcomings of the current representation of `RandomVariable`s and PRNG states in Aesara:

In #1036, the use of `default_output` to hide the PRNG state from the user is causing multiple headaches in the etuplization of `RandomVariable`s and in the unification/reification of expressions with `RandomVariable`s. This is the only `Op` in Aesara that makes use of this property, and “special-casing” the etuplization logic for `RandomVariable`s often appeared to be the easiest solution.

In #66 in AeMCMC, expressing expansions like the following convolution of two normal variables is overly complex:
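To make the ambiguity concrete, here is a NumPy sketch of such an expansion. It is not Aesara code: `z_direct` and `z_expanded` are hypothetical helpers, and `numpy.random.SeedSequence.spawn` merely stands in for "deriving two independent states from one". The point is that the expanded form needs two PRNG states where the original graph had one, and that the naive choice of reusing the same state for both draws gives the wrong distribution.

```python
import numpy as np

def z_direct(rng, mu_x, sd_x, mu_y, sd_y, size):
    # Original graph: Z ~ N(mu_x + mu_y, sqrt(sd_x**2 + sd_y**2)), one draw from one state.
    return rng.normal(mu_x + mu_y, np.sqrt(sd_x**2 + sd_y**2), size)

def z_expanded(rng_x, rng_y, mu_x, sd_x, mu_y, sd_y, size):
    # Expanded graph: Z = X + Y, where each normal draw needs its own PRNG state.
    return rng_x.normal(mu_x, sd_x, size) + rng_y.normal(mu_y, sd_y, size)

n = 200_000
z = z_direct(np.random.default_rng(0), 0.0, 1.0, 0.0, 1.0, n)

# Naive choice: rng_x and rng_y are copies of the same state, so X and Y are identical.
z_same = z_expanded(np.random.default_rng(0), np.random.default_rng(0), 0.0, 1.0, 0.0, 1.0, n)

# Sound choice: two independent states derived from the original seed.
ss_x, ss_y = np.random.SeedSequence(0).spawn(2)
z_indep = z_expanded(np.random.default_rng(ss_x), np.random.default_rng(ss_y), 0.0, 1.0, 0.0, 1.0, n)
```

With unit variances, the sample variances of `z`, `z_same` and `z_indep` come out close to 2, 4 and 2 respectively: copying the state doubles the standard deviation instead of adding the variances.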
It is indeed not clear what the values of `rng_x` and `rng_y` should be, given the value of `rng`. A few other application-related shortcomings of the current representation will be given below.

Proposal
In the following, we focus on the symbolic representation of random variables and PRNG states in Aesara’s IR. We leave the discussion of compilation targets and of solutions to the previous issues for the end.
If we represent the internal state of the PRNG by the type `RandState` (short for `RandomStateType`), the current design of `RandomVariable`s can be summarized by the following simplified signature: `RandomVariable: RandState -> (RandState, Variable)`.

In other words, `RandomVariable`s are responsible for both advancing the state of the PRNG and producing a random value. This double responsibility is what creates graph dependencies between nodes that otherwise have no data dependency. The following snippet illustrates this:

As we can see in the graph representation, `rng_x` (id C) is used as an input to `y`, and `rng_y` (id I) is used as an input to `z`. There is, however, no data dependency between `x`, `y` and `z`. The intuition that they should not be linked is probably what led to “hiding” these PRNG state outputs so that they are not re-used, and to the `RandomStream` interface.

Creating spurious sequential dependencies by threading PRNG states is indeed unsatisfactory from a representation perspective, and it unnecessarily complicates the rewrites. It is also problematic for two other reasons:
A natural idea is to simplify the design of `RandomVariable`s so that they are only responsible for one thing: generating a random value from a PRNG state. The `Op` thus creates an `Apply` node that takes a `RandState` (using the above notation) as input and outputs a (random) `Variable`: `RandomVariable: RandState -> Variable`.

Providing a `RandState` to a `RandomVariable` needs to be intentional, and this must be reflected in the user interface. We thus make `rng` an explicit input of the `RandomVariable`’s `__call__` method. This way a user can write:

Or, if they want the PRNG state to be shared (a silly example, but a legitimate need):
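Both usages, and the threaded design they replace, can be mimicked with a small plain-Python model. This is a hypothetical sketch rather than Aesara code: `RandState`, `normal_threaded` and `normal` are illustrative names, and each draw is made deterministic by seeding Python's `random.Random` with the state.

```python
import random
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class RandState:
    """Hypothetical immutable PRNG state: a seed and a position in its stream."""
    seed: int
    counter: int = 0

def _draw(rng: RandState, mu: float, sigma: float) -> float:
    # Deterministic: the same state always yields the same number.
    return random.Random(f"{rng.seed}:{rng.counter}").gauss(mu, sigma)

def normal_threaded(rng: RandState, mu: float = 0.0, sigma: float = 1.0):
    """Current design: advance the state AND produce a value."""
    return replace(rng, counter=rng.counter + 1), _draw(rng, mu, sigma)

def normal(rng: RandState, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Proposed design: only produce a value from the given state."""
    return _draw(rng, mu, sigma)

# Current design: y depends on x's output state and z on y's, although the values are unrelated.
rng = RandState(seed=42)
rng_x, x = normal_threaded(rng)
rng_y, y = normal_threaded(rng_x)
rng_z, z = normal_threaded(rng_y)

# Proposed design: the state is an explicit argument...
a = normal(RandState(seed=42, counter=0))
b = normal(RandState(seed=42, counter=1))
# ...and sharing it between two variables is a deliberate choice.
a_shared = normal(RandState(seed=42, counter=0))
```

In the threaded version, `y` can only be drawn once `rng_x` exists. In the explicit version, each draw only reads its own state, and passing the same state twice reproduces `a` exactly.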
This interface presupposes the existence of two operators. First, to build reproducible programs, we need an operator that creates a `RandState` from a seed, which can be the constructor of `RandState` itself: `RandState: Seed -> RandState`.

Then we need another operator that creates an updated `RandState` from a `RandState`, so that `RandomVariable`s created with these two different states output different numbers. Let’s call it `next`: `next: RandState -> RandState`.

We can thus fill in the blanks in the previous code examples:
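Below is a self-contained toy version of the filled-in example, assuming a plain-Python model instead of Aesara's API. `RandState(seed)` plays the role of the constructor, the hypothetical `next_state` plays the role of `next` (renamed so that it does not shadow Python's built-in `next`), and `normal` stands for a `RandomVariable`. The PRNG updates come first and the draws second, so that the two graphs are visually separate.

```python
import random
from dataclasses import dataclass

@dataclass(frozen=True)
class RandState:
    """Hypothetical PRNG state, built directly from a user-provided seed."""
    seed: int
    counter: int = 0

def next_state(rng: RandState) -> RandState:
    """Stand-in for `next`: a new state that yields different numbers."""
    return RandState(rng.seed, rng.counter + 1)

def normal(rng: RandState, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Mock RandomVariable: a pure function of its PRNG state."""
    return random.Random(f"{rng.seed}:{rng.counter}").gauss(mu, sigma)

# PRNG state graph: a chain of updates.
rng_x = RandState(1234)
rng_y = next_state(rng_x)
rng_z = next_state(rng_y)

# Random variable graph: no dependency between x_rv, y_rv and z_rv.
x_rv = normal(rng_x)
y_rv = normal(rng_y, 5.0)
z_rv = normal(rng_z, sigma=2.0)
```

Because `normal` reads its state and never returns a new one, the three draws can be evaluated in any order.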
The code has been specifically formatted to illustrate what we gain from this approach. `x_rv`, `y_rv` and `z_rv` have lost their direct dependency; we could easily execute these three statements in parallel. What we have done implicitly is to create two graphs: the graph between random variables, which reflects their dependencies (or lack thereof) on each other’s values, and the graph of the updates of the PRNG states. These two graphs evolve almost in parallel.

This is similar to what I understand the `RandomStream` interface does: moving the updates of the PRNG states to the `update` graphs generated by Aesara’s shared variables.

The `next` operator is, however, not completely satisfactory. Let us consider a more complex situation, where `call` is a function that requires a `RandState`:

We can easily find an implementation of `call` that makes the previous code generate a random state collision:

To avoid this kind of issue, we must require user-defined functions to return the last PRNG state along with the result:
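The collision, and the fix that requires returning the last state, can be reproduced with a hypothetical plain-Python toy model. `RandState`, `next_state`, `normal`, `call` and `call_threaded` are illustrative names, not Aesara's API. Here `call` silently consumes two consecutive states, so starting the second call from `next_state(rng_x)` reuses one of them.

```python
import random
from dataclasses import dataclass

@dataclass(frozen=True)
class RandState:
    """Hypothetical PRNG state: a seed and a position in its stream."""
    seed: int
    counter: int = 0

def next_state(rng: RandState) -> RandState:
    return RandState(rng.seed, rng.counter + 1)

def normal(rng: RandState) -> float:
    # Mock RandomVariable: deterministic function of the state.
    return random.Random(f"{rng.seed}:{rng.counter}").gauss(0.0, 1.0)

def call(rng: RandState):
    """User function that silently uses two states: rng and next(rng)."""
    return normal(rng), normal(next_state(rng))

def call_threaded(rng: RandState):
    """Same function, but it also returns the last state it used."""
    last = next_state(rng)
    return last, (normal(rng), normal(last))

# Collision: the second call's first draw reuses the first call's second state.
rng_x = RandState(0)
x = call(rng_x)
y = call(next_state(rng_x))

# Fix: continue from the state returned by the first call.
last, x2 = call_threaded(RandState(0))
_, y2 = call_threaded(next_state(last))
```

The fix restores correctness, but the second `call_threaded` now has to wait for the first one to return `last`.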
Threading the PRNG state is still necessary to guarantee correctness, and the two calls to `call` cannot be executed in parallel. The issue arises because, even though we have separated PRNG state updates from random value generation, our symbolic structure is still sequential: each `RandState` has one and only one ancestor. We can of course circumvent this issue, knowing how many times `next` is called within the function, by “jumping” the same number of times to obtain `rng_y`, but this can quickly become complex (what if `call` is imported from somewhere else?).

It would make things easier if a `RandState` could have several children, and if each of these children led to a separate stream of random numbers. Let us define the following `split` operator: `split: RandState -> (RandState, RandState)`.

We require that we can never get the same `RandState` by calling `split` any number of times on either the left or the right returned state. In other words, `split` should implicitly define a binary tree in which all the nodes are unique. This can easily be represented by letting `RandState` hold a number in binary format. The left child state is obtained by appending `0` to the parent’s state, and the right child state by appending `1`:

If the generator called by `RandomVariable` can be made a deterministic function of this binary value, the computations are fully reproducible. We added a `key` attribute that can be specified by the user at initialization to seed the PRNG state. The tree structure is of course explicit in our graph representation, since `l` and `r` depend on `rng` via the `split` operator. Nevertheless, we can increment this internal state when building the graph in a way that allows us to compile without traversing the graph.

The `next` operator we previously defined becomes redundant within this representation. Since its interaction with the `split` operator would require careful thought, we leave it aside in the following. Using the new operator, our toy example becomes:

Note that the “main” sub-graph that contains the random variables and the PRNG sub-graph are still minimally connected.
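A toy model of this binary-tree representation, again hypothetical and in plain Python, follows. Each state stores the integer whose binary digits are `1` followed by the path of left (`0`) and right (`1`) moves from the root. The leading `1` is an assumption made here so that paths such as `0` and `00` map to different integers. The `key` attribute seeds the whole tree.

```python
import random
from dataclasses import dataclass

@dataclass(frozen=True)
class RandState:
    """Hypothetical splittable state: a user seed and a node of the binary tree."""
    key: int
    node: int = 1  # binary digits: "1" followed by the path from the root

def split(rng: RandState):
    # Append 0 for the left child and 1 for the right child.
    return RandState(rng.key, 2 * rng.node), RandState(rng.key, 2 * rng.node + 1)

def normal(rng: RandState, mu: float = 0.0, sigma: float = 1.0) -> float:
    # Mock RandomVariable: deterministic function of (key, node).
    return random.Random(f"{rng.key}:{rng.node}").gauss(mu, sigma)

# PRNG sub-graph: a tree of splits.
rng = RandState(key=1234)
rng_x, rest = split(rng)
rng_y, rng_z = split(rest)

# Random variable sub-graph: each variable only reads its own state.
x_rv = normal(rng_x)
y_rv = normal(rng_y)
z_rv = normal(rng_z)

print(bin(rng_x.node), bin(rng_y.node), bin(rng_z.node))  # → 0b10 0b110 0b111
```

This is the usual heap numbering of a binary tree, so every node is reached by exactly one path. The catch is that the integer grows without bound as the tree deepens.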
Finally, it is also natural to implement the `splitn` operator, which returns `n` child states, `splitn: (RandState, int) -> tuple[RandState, ...]`, so that we can write the following code:
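One simple way to obtain `splitn` from `split` is to split repeatedly, keeping every left child and the final right child. The sketch below does this in a hypothetical plain-Python toy model that encodes each node as `1` followed by its path. It is only one possible construction, not the proposal's definition.

```python
import random
from dataclasses import dataclass

@dataclass(frozen=True)
class RandState:
    """Hypothetical splittable state: node's binary digits are "1" + path from the root."""
    key: int
    node: int = 1

def split(rng: RandState):
    return RandState(rng.key, 2 * rng.node), RandState(rng.key, 2 * rng.node + 1)

def splitn(rng: RandState, n: int):
    """Return n distinct states by splitting repeatedly and keeping each left child."""
    if n < 1:
        raise ValueError("n must be at least 1")
    states = []
    for _ in range(n - 1):
        left, rng = split(rng)
        states.append(left)
    states.append(rng)  # the last right child closes the sequence
    return tuple(states)

def normal(rng: RandState, mu: float = 0.0, sigma: float = 1.0) -> float:
    # Mock RandomVariable: deterministic function of (key, node).
    return random.Random(f"{rng.key}:{rng.node}").gauss(mu, sigma)

rng_x, rng_y, rng_z = splitn(RandState(key=1234), 3)
x_rv, y_rv, z_rv = normal(rng_x), normal(rng_y), normal(rng_z)
```

With this construction, `splitn(rng, 3)` yields the same three states as splitting `rng` and then splitting its right child.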
Implementation
When it comes to practical implementations, this representation is only convenient for counter-based PRNGs like `Philox`, implemented in NumPy: we generate a pair of `(key, counter)` from our `RandState`s and pass these as an input to the generator.

`RandState` and `split` implementation

The mock implementation of `RandState` and `split` above is naive in the sense that the counter space $\mathcal{S}$ of real PRNGs does not usually extend indefinitely. In practice, we will need to compress the state using a hashing function that also increments the `key`. To be immediately compatible with NumPy in the `perform` function, we can use Philox’s hash function to update the state as we build the graph. Since the hash is deterministic, we can still walk the `RandState` tree in our representation and cheaply recompute the states should we need to.

`Op` and `Variable` implementations to come.
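Here is a sketch of the compression idea, under stated assumptions. Instead of storing an ever-growing path, each state holds a fixed-size 128-bit key, and `split` derives the children's keys by hashing the parent key together with the branch bit. The sketch uses Python's `hashlib.blake2b` as a stand-in for the Philox-based hash mentioned above, since no particular NumPy function is assumed for that step. `RandState`, `split` and `philox_inputs` are hypothetical names. Unlike the binary-tree encoding, uniqueness now holds only up to hash collisions: these are astronomically unlikely, but not ruled out by construction.

```python
from __future__ import annotations

import hashlib
from dataclasses import dataclass


def _hash128(tag: bytes, value: int) -> int:
    """Deterministic 128-bit digest of a tag and a non-negative integer below 2**128."""
    data = tag + value.to_bytes(16, "little")
    return int.from_bytes(hashlib.blake2b(data, digest_size=16).digest(), "little")


@dataclass(frozen=True)
class RandState:
    """Hypothetical fixed-size state: a 128-bit key, whatever its depth in the tree."""
    key: int

    @classmethod
    def from_seed(cls, seed: int) -> RandState:
        return cls(_hash128(b"seed", seed))


def split(rng: RandState) -> tuple[RandState, RandState]:
    # Each child's key hashes the parent's key together with the branch bit.
    return RandState(_hash128(b"0", rng.key)), RandState(_hash128(b"1", rng.key))


def philox_inputs(rng: RandState) -> tuple[int, int]:
    # (key, counter) pair for a counter-based generator such as Philox; counter starts at 0.
    return rng.key, 0


root = RandState.from_seed(1234)
left, right = split(root)
left_again, _ = split(RandState.from_seed(1234))  # recomputing the tree gives the same states
```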
`RandomVariable`

The modifications to `RandomVariable` `Op`s are minimal: `__call__` now takes a `RandState` as a positional argument, and `make_node` only returns `out_var`. The `default_output` attribute is not needed anymore.

`RandomStream`
We can keep the `RandomStream` API, use a shared variable to hold the `RandState`, and handle the splitting internally. The RNG sub-graphs are now found in the updates’ graph.

As a second step, we may consider instantiating `RandState`s as shared variables by default, to decouple the random variable graph from the PRNG state graph. I am not sure of the trade-offs here, but it may alleviate concerns related to graph rewrites.

Compilation
It is essential that our representation of PRNG states and `RandomVariable`s in the graph can be easily transpiled to the existing targets (C, Numba, JAX) and to future targets. In the following, I outline the transpilation process for the current targets.

Numba
After #1245, Aesara will support NumPy’s `Generator` API. Furthermore, NumPy supports Philox as a `BitGenerator`, a counter-based PRNG that can easily accommodate splittable PRNG representations. Assuming we can map each path in the PRNG graph to a `(key, counter)` tuple, the transpilation of `RandomStream`s using the Philox `BitGenerator` should be straightforward. For the explicit splitting interface, we can directly translate the `RandomVariable`s to NumPy `Generator`s and seed these generators at compile time. So that:

Becomes:
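Here is a minimal sketch of the compile-time step with NumPy. It assumes that each `RandomVariable` node has already been assigned a 128-bit key while the graph was built (for instance by hashing its `RandState` path). Every node then gets its own `Generator` backed by a `Philox` bit generator, created once from that key. `compile_rng` and the `graph` dictionary are hypothetical. `numpy.random.Philox(key=..., counter=...)` and `numpy.random.Generator` are NumPy's own API.

```python
import numpy as np

def compile_rng(key: int, counter: int = 0) -> np.random.Generator:
    """Hypothetical compile-time step: one Philox-backed Generator per RandomVariable node."""
    return np.random.Generator(np.random.Philox(key=key, counter=counter))

# Keys assumed to have been computed while building the graph (the values here are arbitrary).
graph = {"x_rv": 0x1234, "y_rv": 0x5678, "z_rv": 0x9ABC}

generators = {name: compile_rng(key) for name, key in graph.items()}
samples = {name: rng.normal(0.0, 1.0) for name, rng in generators.items()}
```

Each generator depends only on its own key, so the generated code does not need to thread any state between the three draws, and recompiling with the same keys reproduces the same numbers.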
JAX
Transpilation to JAX would be straightforward, as JAX uses a splittable PRNG representation. We will simply need to perform the following substitutions:
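The substitutions could look roughly as follows. This is a sketch of the correspondence: the comments show the proposal's hypothetical operators, and the code uses real `jax.random` calls. The `RandState` constructor maps to `jax.random.PRNGKey`, `split` maps to `jax.random.split`, `splitn` maps to `jax.random.split` with `num`, and a normal `RandomVariable` maps to `jax.random.normal` called with the key. JAX's sampler is standardized, so location and scale are applied afterwards.

```python
import jax

rng = jax.random.PRNGKey(1234)              # RandState(1234)
rng_x, rest = jax.random.split(rng)          # split(rng)
rng_y, rng_z = jax.random.split(rest)        # split(rest)
rngs = jax.random.split(rng, num=3)          # splitn(rng, 3)

x_rv = jax.random.normal(rng_x)              # normal(rng_x)
y_rv = jax.random.normal(rng_y)              # normal(rng_y)
z_rv = 1.0 + 2.0 * jax.random.normal(rng_z)  # normal(rng_z) with mean 1 and std 2
```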
Back to the motivating issues
The problems linked to the existence of the `default_output` attribute disappear, since `RandomVariable`s do not return PRNG states anymore. The one-to-many difficulty we are facing with the relations between etuplized graphs also disappears with a `split` operator. Using the example from the beginning, we can for instance write:

This is guaranteed to be collision-free by construction of the `split` operator, as long as the PRNG state used by the original normal distribution isn’t passed to a `split` operator somewhere else in the original graph (todo: specify API requirements to guarantee uniqueness of the random numbers).

Related
Related discussions across the Ae* projects (not necessarily a justification for this approach):

`OpFromGraph` and updates