Skip to content

Commit

Permalink
record TrandformedVariable sample (#387)
Browse files Browse the repository at this point in the history
* record TrandformedVariable sample

* make record optional and iterate

* change default record to False

* add suffix

* add tests

* adding another test

* update test for un-named base rv

* add test for un-named base rv of length 2

* make test more compact
  • Loading branch information
sbidari authored Aug 20, 2024
1 parent c96c8ee commit df621e0
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 8 deletions.
22 changes: 19 additions & 3 deletions src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,14 +860,17 @@ def __init__(
self.transforms = transforms
self.validate()

def sample(self, **kwargs) -> tuple:
def sample(self, record=False, **kwargs) -> tuple:
"""
Sample method. Call self.base_rv.sample()
and then apply the transforms specified
in self.transforms.
Parameters
----------
record : bool, optional
Whether to record the value of the deterministic
RandomVariable. Defaults to False.
**kwargs :
Keyword arguments passed to self.base_rv.sample()
Expand All @@ -878,8 +881,7 @@ def sample(self, **kwargs) -> tuple:
"""

untransformed_values = self.base_rv.sample(**kwargs)

return tuple(
transformed_values = tuple(
SampledValue(
t(uv.value),
t_start=self.t_start,
Expand All @@ -888,6 +890,20 @@ def sample(self, **kwargs) -> tuple:
for t, uv in zip(self.transforms, untransformed_values)
)

if record:
if len(untransformed_values) == 1:
numpyro.deterministic(self.name, transformed_values[0].value)
else:
suffixes = (
untransformed_values._fields
if hasattr(untransformed_values, "_fields")
else range(len(transformed_values))
)
for suffix, tv in zip(suffixes, transformed_values):
numpyro.deterministic(f"{self.name}_{suffix}", tv.value)

return transformed_values

def sample_length(self):
"""
Sample length for a transformed
Expand Down
116 changes: 111 additions & 5 deletions src/test/test_transformed_rv_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Tests for TransformedRandomVariable class
"""

from typing import NamedTuple

import jax
import numpyro
import numpyro.distributions as dist
import pytest
Expand All @@ -12,6 +15,7 @@
import pyrenew.transformation as t
from pyrenew.metaclass import (
DistributionalRV,
Model,
RandomVariable,
SampledValue,
TransformedRandomVariable,
Expand All @@ -22,22 +26,23 @@ class LengthTwoRV(RandomVariable):
"""
Class for a RandomVariable
with sample_length 2
and values 1 and 5
"""

def sample(self, **kwargs):
"""
Deterministic sampling method
Sampling method
that returns a length-2 tuple
Returns
-------
tuple
(SampledValue(1, t_start=self.t_start, t_unit=self.t_unit), SampledValue(5, t_start=self.t_start, t_unit=self.t_unit))
(SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
SampledValue(val, t_start=self.t_start, t_unit=self.t_unit))
"""
val = numpyro.sample("my_normal", dist.Normal(0, 1))
return (
SampledValue(1, t_start=self.t_start, t_unit=self.t_unit),
SampledValue(5, t_start=self.t_start, t_unit=self.t_unit),
SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
)

def sample_length(self):
Expand All @@ -62,6 +67,66 @@ def validate(self):
return None


class RVSamples(NamedTuple):
"""
A container to hold the output of `NamedBaseRV()`.
"""

rv1: SampledValue | None = None
rv2: SampledValue | None = None

def __repr__(self):
return f"RVSamples(rv1={self.rv1},rv2={self.rv2})"


class NamedBaseRV(RandomVariable):
"""
Class for a RandomVariable
returning NamedTuples "rv1", and "rv2"
"""

def sample(self, **kwargs):
"""
Sampling method that returns two named tuples
Returns
-------
tuple
(rv1= SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
rv2= SampledValue(val, t_start=self.t_start, t_unit=self.t_unit))
"""
val = numpyro.sample("my_normal", dist.Normal(0, 1))
return RVSamples(
rv1=SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
rv2=SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
)

def validate(self):
"""
No validation.
Returns
-------
None
"""
return None


class MyModel(Model):
"""
Model class to create and run variable name recording
"""

def __init__(self, rv): # numpydoc ignore=GL08
self.rv = rv

def validate(self): # numpydoc ignore=GL08
pass

def sample(self, **kwargs): # numpydoc ignore=GL08
return self.rv(record=True, **kwargs)


def test_transform_rv_validation():
"""
Test that a TransformedRandomVariable validation
Expand Down Expand Up @@ -148,3 +213,44 @@ def test_transforms_applied_at_sampling():
),
(l2_transformed_sample[0].value, l2_transformed_sample[1].value),
)


def test_transforms_variable_naming():
"""
Tests TransformedRandomVariable name
recording is as expected.
"""
transformed_dist_named_base_rv = TransformedRandomVariable(
"transformed_rv",
NamedBaseRV(),
(t.ExpTransform(), t.IdentityTransform()),
)

transformed_dist_unnamed_base_rv = TransformedRandomVariable(
"transformed_rv",
DistributionalRV(name="my_normal", distribution=dist.Normal(0, 1)),
(t.ExpTransform(), t.IdentityTransform()),
)

transformed_dist_unnamed_base_l2_rv = TransformedRandomVariable(
"transformed_rv",
LengthTwoRV(),
(t.ExpTransform(), t.IdentityTransform()),
)

mymodel1 = MyModel(transformed_dist_named_base_rv)
mymodel1.run(num_samples=1, num_warmup=10, rng_key=jax.random.key(4))

assert "transformed_rv_rv1" in mymodel1.mcmc.get_samples()
assert "transformed_rv_rv2" in mymodel1.mcmc.get_samples()

mymodel2 = MyModel(transformed_dist_unnamed_base_rv)
mymodel2.run(num_samples=1, num_warmup=10, rng_key=jax.random.key(5))

assert "transformed_rv" in mymodel2.mcmc.get_samples()

mymodel3 = MyModel(transformed_dist_unnamed_base_l2_rv)
mymodel3.run(num_samples=1, num_warmup=10, rng_key=jax.random.key(4))

assert "transformed_rv_0" in mymodel3.mcmc.get_samples()
assert "transformed_rv_1" in mymodel3.mcmc.get_samples()

0 comments on commit df621e0

Please sign in to comment.