Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

record TrandformedVariable sample #387

Merged
merged 13 commits into from
Aug 20, 2024
21 changes: 18 additions & 3 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,14 +859,17 @@ def __init__(
self.transforms = transforms
self.validate()

def sample(self, **kwargs) -> tuple:
def sample(self, record=False, **kwargs) -> tuple:
"""
Sample method. Call self.base_rv.sample()
and then apply the transforms specified
in self.transforms.

Parameters
----------
record : bool, optional
Whether to record the value of the deterministic
RandomVariable. Defaults to False.
**kwargs :
Keyword arguments passed to self.base_rv.sample()

Expand All @@ -877,8 +880,7 @@ def sample(self, **kwargs) -> tuple:
"""

untransformed_values = self.base_rv.sample(**kwargs)

return tuple(
transformed_values = tuple(
SampledValue(
t(uv.value),
t_start=self.t_start,
Expand All @@ -887,6 +889,19 @@ def sample(self, **kwargs) -> tuple:
for t, uv in zip(self.transforms, untransformed_values)
)

if record:
if hasattr(untransformed_values, "_fields"):
for i, tv in enumerate(transformed_values):
suffix = untransformed_values._fields[i]
numpyro.deterministic(f"{self.name}_{suffix}", tv.value)
elif len(untransformed_values) == 1:
numpyro.deterministic(self.name, transformed_values[0].value)
else:
for i, tv in enumerate(transformed_values):
numpyro.deterministic(f"{self.name}_{i}", tv.value)

return transformed_values

def sample_length(self):
"""
Sample length for a transformed
Expand Down
116 changes: 111 additions & 5 deletions model/src/test/test_transformed_rv_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
Tests for TransformedRandomVariable class
"""

from typing import NamedTuple

import jax
import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
import pytest
from numpy.testing import assert_almost_equal
from pyrenew.metaclass import (
DistributionalRV,
Model,
RandomVariable,
SampledValue,
TransformedRandomVariable,
Expand All @@ -21,22 +25,23 @@ class LengthTwoRV(RandomVariable):
"""
Class for a RandomVariable
with sample_length 2
and values 1 and 5
"""

def sample(self, **kwargs):
"""
Deterministic sampling method
Sampling method
that returns a length-2 tuple

Returns
-------
tuple
(SampledValue(1, t_start=self.t_start, t_unit=self.t_unit), SampledValue(5, t_start=self.t_start, t_unit=self.t_unit))
(SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
SampledValue(val, t_start=self.t_start, t_unit=self.t_unit))
"""
val = numpyro.sample("my_normal", dist.Normal(0, 1))
return (
SampledValue(1, t_start=self.t_start, t_unit=self.t_unit),
SampledValue(5, t_start=self.t_start, t_unit=self.t_unit),
SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
)

def sample_length(self):
Expand All @@ -61,6 +66,66 @@ def validate(self):
return None


class RVSamples(NamedTuple):
"""
A container to hold the output of `NamedBaseRV()`.
"""

rv1: SampledValue | None = None
rv2: SampledValue | None = None

def __repr__(self):
return f"RVSamples(rv1={self.rv1},rv2={self.rv2})"


class NamedBaseRV(RandomVariable):
"""
Class for a RandomVariable
returning NamedTuples "rv1", and "rv2"
"""

def sample(self, **kwargs):
"""
Sampling method that returns two named tuples

Returns
-------
tuple
(rv1= SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
rv2= SampledValue(val, t_start=self.t_start, t_unit=self.t_unit))
"""
val = numpyro.sample("my_normal", dist.Normal(0, 1))
return RVSamples(
rv1=SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
rv2=SampledValue(val, t_start=self.t_start, t_unit=self.t_unit),
)

def validate(self):
"""
No validation.

Returns
-------
None
"""
return None


class MyModel(Model):
"""
Model class to create and run variable name recording
"""

def __init__(self, rv): # numpydoc ignore=GL08
self.rv = rv

def validate(self): # numpydoc ignore=GL08
pass

def sample(self, **kwargs): # numpydoc ignore=GL08
return self.rv(record=True, **kwargs)


def test_transform_rv_validation():
"""
Test that a TransformedRandomVariable validation
Expand Down Expand Up @@ -147,3 +212,44 @@ def test_transforms_applied_at_sampling():
),
(l2_transformed_sample[0].value, l2_transformed_sample[1].value),
)


def test_transforms_variable_naming():
"""
Tests TransformedRandomVariable name
recording is as expected.
"""
transformed_dist_named_base_rv = TransformedRandomVariable(
"transformed_rv",
NamedBaseRV(),
(t.ExpTransform(), t.IdentityTransform()),
)

transformed_dist_unnamed_base_rv = TransformedRandomVariable(
"transformed_rv",
DistributionalRV(name="my_normal", distribution=dist.Normal(0, 1)),
(t.ExpTransform(), t.IdentityTransform()),
)

transformed_dist_unnamed_base_l2_rv = TransformedRandomVariable(
"transformed_rv",
LengthTwoRV(),
(t.ExpTransform(), t.IdentityTransform()),
)

mymodel1 = MyModel(transformed_dist_named_base_rv)
mymodel1.run(num_samples=1, num_warmup=10, rng_key=jax.random.key(4))

assert "transformed_rv_rv1" in mymodel1.mcmc.get_samples()
assert "transformed_rv_rv2" in mymodel1.mcmc.get_samples()

mymodel2 = MyModel(transformed_dist_unnamed_base_rv)
mymodel2.run(num_samples=1, num_warmup=10, rng_key=jax.random.key(5))

assert "transformed_rv" in mymodel2.mcmc.get_samples()

mymodel3 = MyModel(transformed_dist_unnamed_base_l2_rv)
mymodel3.run(num_samples=1, num_warmup=10, rng_key=jax.random.key(4))

assert "transformed_rv_0" in mymodel3.mcmc.get_samples()
assert "transformed_rv_1" in mymodel3.mcmc.get_samples()