Add __call__ As Alias For RandomVariable sample Method (#253)
* add __call__ alias for sample

* non-tutorial sample() call replacements

* tutorial description sample() changes

* further non-tutorial replacements of sample(), where appropriate

* tutorial edit sample()

* catch a few more
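
A minimal sketch of what the alias amounts to, using a simplified stand-in for the base class (`ToyRandomVariable` and `ConstantVariable` are hypothetical names, not pyrenew classes). Since the `__call__` added in this commit accepts only `**kwargs`, arguments have to be passed by keyword when an instance is called directly, which is presumably why the updated tests use `duration=...`:

```python
from abc import ABCMeta, abstractmethod


class ToyRandomVariable(metaclass=ABCMeta):
    """Hypothetical stand-in for the base class; illustration only."""

    @abstractmethod
    def sample(self, **kwargs) -> tuple:
        """Subclasses implement the sampling logic and return a tuple."""

    def __call__(self, **kwargs):
        """Alias for `sample()`; forwards keyword arguments only."""
        return self.sample(**kwargs)


class ConstantVariable(ToyRandomVariable):
    """Toy subclass returning `value` repeated `duration` times."""

    def __init__(self, value: float) -> None:
        self.value = value

    def sample(self, duration: int = 1, **kwargs) -> tuple:
        return ([self.value] * duration,)


rv = ConstantVariable(0.25)
via_method, *_ = rv.sample(3)      # positional arguments work on sample()
via_call, *_ = rv(duration=3)      # the alias needs keyword arguments

try:
    rv(3)                          # positional call is rejected by __call__
    positional_accepted = True
except TypeError:
    positional_accepted = False
```

Both call styles return the same tuple; only the way arguments are passed differs.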

---------

Co-authored-by: damonbayer <xum8@cdc.gov>
AFg6K7h4fhy2 and damonbayer authored Jul 16, 2024
1 parent 7bd2250 commit fb8a8a4
Showing 32 changed files with 101 additions and 90 deletions.
10 changes: 5 additions & 5 deletions docs/source/tutorials/extending_pyrenew.qmd
@@ -197,15 +197,15 @@ class InfFeedback(RandomVariable):
I0_vec = I0[-gen_int_rev.size :]
# Sampling inf feedback strength and adjusting the shape
inf_feedback_strength, *_ = self.infection_feedback_strength.sample(
inf_feedback_strength, *_ = self.infection_feedback_strength(
**kwargs,
)
inf_feedback_strength = au.pad_x_to_match_y(
x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0]
)
# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs)
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)
# Generating the infections with feedback
@@ -230,13 +230,13 @@ class InfFeedback(RandomVariable):

The core of the class is implemented in the `sample()` method. Things to highlight from the above code:

1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback.sample()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable.sample()` calls are expected to include the `**kwargs` argument, even if unused.
1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable()` calls are expected to include the `**kwargs` argument, even if unused.

2. **Calls to `RandomVariable.sample()`**: All calls to `RandomVariable.sample()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength.sample()` and `infection_feedback_pmf.sample()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).
2. **Calls to `RandomVariable()`**: All calls to `RandomVariable()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength()` and `infection_feedback_pmf()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).

3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.

4. **Return type of `InfFeedback.sample()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.
4. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.
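
The tuple convention and the `*_` unpacking described in the list above can be sketched as follows. This is an illustration under assumptions: `ToySample` and `toy_sample` are hypothetical names, and the values are made up; only the two-field layout (`infections`, `rt`) comes from the text.

```python
from typing import NamedTuple


class ToySample(NamedTuple):
    """Hypothetical output container with two fields, mirroring the
    `infections` / `rt` layout described in the text."""

    infections: tuple
    rt: tuple


def toy_sample(**kwargs) -> ToySample:
    """Accept (and here ignore) **kwargs, then return a named tuple."""
    return ToySample(infections=(10.0, 12.0, 15.0), rt=(1.2, 1.1, 1.0))


# Keep only the first output and discard the rest with `*_`.
infections, *_ = toy_sample(unused_option=True)

# Named-tuple fields remain accessible by name or by position.
result = toy_sample()
```

Returning a named tuple keeps positional unpacking working while letting callers pick a field by name when that reads better.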

```{python}
# | label: simulation2
4 changes: 2 additions & 2 deletions docs/source/tutorials/periodic_effects.qmd
@@ -40,7 +40,7 @@ rt_proc = process.RtWeeklyDiffProcess(

```{python}
with npro.handlers.seed(rng_seed=20):
sim_data = rt_proc.sample(duration=30)
sim_data = rt_proc(duration=30)
# Plotting the Rt values
import matplotlib.pyplot as plt
@@ -84,7 +84,7 @@ Like before, we can use the `sample` method to generate samples from the day of

```{python}
with npro.handlers.seed(rng_seed=20):
sim_data = dayofweek.sample(duration=30)
sim_data = dayofweek(duration=30)
# Plotting the effect values
import matplotlib.pyplot as plt
4 changes: 2 additions & 2 deletions docs/source/tutorials/pyrenew_demo.qmd
@@ -37,15 +37,15 @@ import numpyro.distributions as dist
from pyrenew.process import SimpleRandomWalkProcess
```

To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the n_timepoints of the block that follows. Inside the `with` block, the `q_samp = q.sample(n_timepoints=100)` generates the sample instance over a n_timepoints of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance.
To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the n_timepoints of the block that follows. Inside the `with` block, the `q_samp = q(n_timepoints=100)` generates the sample instance over a n_timepoints of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance.

```{python}
# | label: fig-randwalk
# | fig-cap: Random walk example
np.random.seed(3312)
q = SimpleRandomWalkProcess(dist.Normal(0, 0.001))
with seed(rng_seed=np.random.randint(0, 1000)):
q_samp = q.sample(n_timepoints=100)
q_samp = q(n_timepoints=100)
plt.plot(np.exp(q_samp[0]))
```
4 changes: 2 additions & 2 deletions docs/source/tutorials/time.qmd
@@ -25,7 +25,7 @@ The tuple `(t_unit, t_start)` can encode different types of time series data. Fo

## How it relates to periodicity

The `PeriodicBroadcaster()` class provides a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect.sample()` and `RtPeriodicDiffProcess.sample()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.
The `PeriodicBroadcaster()` class provides a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect()` and `RtPeriodicDiffProcess()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.
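
As a rough illustration of tiling a periodic pattern from a given starting point, consider the pure-Python sketch below. It is not pyrenew's `PeriodicBroadcaster`; `broadcast_periodic`, its parameters, and the weekly values are all hypothetical. Note that, as in the paragraph above, only the period length and the offset are encoded, not the time unit.

```python
def broadcast_periodic(values, offset, n_timepoints):
    """Tile `values` cyclically, starting `offset` steps into the period.

    Hypothetical helper for illustration; the period length is len(values).
    """
    period = len(values)
    return [values[(offset + t) % period] for t in range(n_timepoints)]


# Illustrative weekly pattern (seven made-up values).
weekly_effect = [1.0, 1.1, 1.2, 1.3, 1.2, 0.8, 0.7]

# Model day 0 falls on the third entry of the pattern (offset=2).
daily = broadcast_periodic(weekly_effect, offset=2, n_timepoints=10)
```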

## Unimplemented features

@@ -37,7 +37,7 @@ With random variables possibly spanning different time scales, *e.g.*, weekly, d

### Array alignment

Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel.sample()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the seeding process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:
Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the seeding process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:

```python
Rt_aligned, infections_aligned = align([Rt, infections])
2 changes: 1 addition & 1 deletion model/src/pyrenew/arrayutils.py
@@ -98,7 +98,7 @@ def pad_x_to_match_y(

class PeriodicProcessSample(NamedTuple):
"""
A container for holding the output from `process.PeriodicProcess.sample()`.
A container for holding the output from `process.PeriodicProcess()`.
Attributes
----------
11 changes: 5 additions & 6 deletions model/src/pyrenew/latent/hospitaladmissions.py
@@ -14,7 +14,7 @@

class HospitalAdmissionsSample(NamedTuple):
"""
A container to hold the output of `latent.HospAdmissions.sample()`.
A container to hold the output of `latent.HospAdmissions()`.
Attributes
----------
@@ -173,14 +173,14 @@ def sample(
HospitalAdmissionsSample
"""

infection_hosp_rate, *_ = self.infect_hosp_rate_rv.sample(**kwargs)
infection_hosp_rate, *_ = self.infect_hosp_rate_rv(**kwargs)

infection_hosp_rate_t = infection_hosp_rate * latent_infections

(
infection_to_admission_interval,
*_,
) = self.infection_to_admission_interval_rv.sample(**kwargs)
) = self.infection_to_admission_interval_rv(**kwargs)

latent_hospital_admissions = jnp.convolve(
infection_hosp_rate_t,
@@ -191,13 +191,12 @@ def sample(
# Applying the day of the week effect
latent_hospital_admissions = (
latent_hospital_admissions
* self.day_of_week_effect_rv.sample(**kwargs)[0]
* self.day_of_week_effect_rv(**kwargs)[0]
)

# Applying probability of hospitalization effect
latent_hospital_admissions = (
latent_hospital_admissions
* self.hosp_report_prob_rv.sample(**kwargs)[0]
latent_hospital_admissions * self.hosp_report_prob_rv(**kwargs)[0]
)

npro.deterministic(
Original file line number Diff line number Diff line change
@@ -176,7 +176,7 @@ def seed_infections(self, I_pre_seed: ArrayLike):
raise ValueError(
f"I_pre_seed must be an array of size 1. Got size {I_pre_seed.size}."
)
(rate,) = self.rate.sample()
(rate,) = self.rate()
if rate.size != 1:
raise ValueError(
f"rate must be an array of size 1. Got size {rate.size}."
Original file line number Diff line number Diff line change
@@ -95,7 +95,7 @@ def sample(self) -> tuple:
tuple
a tuple where the only element is an array with the number of seeded infections at each time point.
"""
(I_pre_seed,) = self.I_pre_seed_rv.sample()
(I_pre_seed,) = self.I_pre_seed_rv()
infection_seeding = self.infection_seed_method(I_pre_seed)
npro.deterministic(self.name, infection_seeding)

2 changes: 1 addition & 1 deletion model/src/pyrenew/latent/infections.py
@@ -13,7 +13,7 @@

class InfectionsSample(NamedTuple):
"""
A container for holding the output from `latent.Infections.sample()`.
A container for holding the output from `latent.Infections()`.
Attributes
----------
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/infectionswithfeedback.py
@@ -156,7 +156,7 @@ def sample(
I0 = I0[-gen_int_rev.size :]

# Sampling inf feedback strength
inf_feedback_strength, *_ = self.infection_feedback_strength.sample(
inf_feedback_strength, *_ = self.infection_feedback_strength(
**kwargs,
)

@@ -175,7 +175,7 @@ def sample(
)

# Sampling inf feedback pmf
inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs)
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)

inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)

12 changes: 9 additions & 3 deletions model/src/pyrenew/metaclass.py
@@ -25,7 +25,7 @@ def _assert_sample_and_rtype(
"""
Return type-checking for RandomVariable's sample function
Objects passed as `RandomVariable` should (a) have a sample() method that
Objects passed as `RandomVariable` should (a) have a `sample()` method that
(b) returns either a tuple or a named tuple.
Parameters
@@ -103,7 +103,7 @@ class RandomVariable(metaclass=ABCMeta):
are expected to be used internally mostly for tasks including padding,
alignment of time series, and other time-aware operations.
Both attributes give information about the output of the sample() method,
Both attributes give information about the output of the `sample()` method,
in other words, the relative time units of the returning value.
Attributes
@@ -138,7 +138,7 @@ def set_timeseries(
t_start : int
The start of the time series relative to the
model time. It could be negative, indicating
that the sample() method returns timepoints
that the `sample()` method returns timepoints
that occur prior to the model t = 0.
t_unit : int
@@ -202,6 +202,12 @@ def validate(**kwargs) -> None:
"""
pass

def __call__(self, **kwargs):
"""
Alias for `sample()`.
"""
return self.sample(**kwargs)


class DistributionalRVSample(NamedTuple):
"""
4 changes: 2 additions & 2 deletions model/src/pyrenew/model/admissionsmodel.py
@@ -197,15 +197,15 @@ def sample(
infection_hosp_rate,
latent_hosp_admissions,
*_,
) = self.latent_hosp_admissions_rv.sample(
) = self.latent_hosp_admissions_rv(
latent_infections=basic_model.latent_infections,
**kwargs,
)

(
observed_hosp_admissions,
*_,
) = self.hosp_admission_obs_process_rv.sample(
) = self.hosp_admission_obs_process_rv(
mu=latent_hosp_admissions[-n_datapoints:],
obs=data_observed_hosp_admissions,
**kwargs,
10 changes: 5 additions & 5 deletions model/src/pyrenew/model/rtinfectionsrenewalmodel.py
@@ -192,28 +192,28 @@ def sample(
n_timepoints = n_timepoints_to_simulate + padding
# Sampling from Rt (possibly with a given Rt, depending on
# the Rt_process (RandomVariable) object.)
Rt, *_ = self.Rt_process_rv.sample(
Rt, *_ = self.Rt_process_rv(
n_timepoints=n_timepoints,
**kwargs,
)

# Getting the generation interval
gen_int, *_ = self.gen_int_rv.sample(**kwargs)
gen_int, *_ = self.gen_int_rv(**kwargs)

# Sampling initial infections
I0, *_ = self.I0_rv.sample(**kwargs)
I0, *_ = self.I0_rv(**kwargs)
# Sampling from the latent process
(
post_initialization_latent_infections,
*_,
) = self.latent_infections_rv.sample(
) = self.latent_infections_rv(
Rt=Rt,
gen_int=gen_int,
I0=I0,
**kwargs,
)

observed_infections, *_ = self.infection_obs_process_rv.sample(
observed_infections, *_ = self.infection_obs_process_rv(
mu=post_initialization_latent_infections[padding:],
obs=data_observed_infections,
**kwargs,
4 changes: 2 additions & 2 deletions model/src/pyrenew/process/firstdifferencear.py
@@ -51,13 +51,13 @@ def sample(
Parameters
----------
duration : int
Passed to ARProcess.sample().s
Passed to ARProcess()
init_val : ArrayLike, optional
Starting point of the AR process, by default None.
init_rate_of_change : ArrayLike, optional
Passed to ARProcess.sample, by default None.
name : str, optional
Passed to ARProcess.sample(), by default "trend_rw"
Passed to ARProcess(), by default "trend_rw"
**kwargs : dict, optional
Additional keyword arguments passed through to internal sample()
calls, should there be any.
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/periodiceffect.py
@@ -10,7 +10,7 @@
class PeriodicEffectSample(NamedTuple):
"""
A container for holding the output from
`process.PeriodicEffect.sample()`.
`process.PeriodicEffect()`.
Attributes
----------
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/rtperiodicdiff.py
@@ -10,7 +10,7 @@

class RtPeriodicDiffProcessSample(NamedTuple):
"""
A container for holding the output from `process.RtPeriodicDiffProcess.sample()`.
A container for holding the output from `process.RtPeriodicDiffProcess()`.
Attributes
----------
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/rtrandomwalk.py
@@ -114,7 +114,7 @@ def sample(

Rt0_trans = self.Rt_transform(Rt0)
Rt_trans_proc = SimpleRandomWalkProcess(self.Rt_rw_dist)
Rt_trans_ts, *_ = Rt_trans_proc.sample(
Rt_trans_ts, *_ = Rt_trans_proc(
n_timepoints=n_timepoints,
name="Rt_transformed_rw",
init=Rt0_trans,
6 changes: 6 additions & 0 deletions model/src/pyrenew/regression.py
@@ -148,5 +148,11 @@ def sample(self) -> dict:
coefficients=coefficients,
)

def __call__(self, **kwargs):
"""
Alias for `sample()`.
"""
return self.sample(**kwargs)

def __repr__(self):
return "GLMPrediction " + str(self.name)
10 changes: 5 additions & 5 deletions model/src/test/test_ar_process.py
@@ -14,14 +14,14 @@ def test_ar_can_be_sampled():
ar1 = ARProcess(5, jnp.array([0.95]), jnp.array([0.5]))
with numpyro.handlers.seed(rng_seed=62):
# can sample with and without inits
ar1.sample(3532, inits=jnp.array([50.0]))
ar1.sample(5023)
ar1(duration=3532, inits=jnp.array([50.0]))
ar1(duration=5023)

ar3 = ARProcess(5, jnp.array([0.05, 0.025, 0.025]), jnp.array([0.5]))
with numpyro.handlers.seed(rng_seed=62):
# can sample with and without inits
ar3.sample(1230)
ar3.sample(52, inits=jnp.array([50.0, 49.9, 48.2]))
ar3(duration=1230)
ar3(duration=52, inits=jnp.array([50.0, 49.9, 48.2]))


def test_ar_samples_correctly_distributed():
@@ -36,6 +36,6 @@ def test_ar_samples_correctly_distributed():
with numpyro.handlers.seed(rng_seed=62):
# check it regresses to mean
# when started away from it
long_ts, *_ = ar1.sample(10000, inits=ar_inits)
long_ts, *_ = ar1(duration=10000, inits=ar_inits)
assert_almost_equal(long_ts[0], ar_inits)
assert jnp.abs(long_ts[-1] - ar_mean) < 4 * noise_sd
12 changes: 6 additions & 6 deletions model/src/test/test_deterministic.py
@@ -31,24 +31,24 @@ def test_deterministic():
var5 = NullProcess()

testing.assert_array_equal(
var1.sample()[0],
var1()[0],
jnp.array(
[
1,
]
),
)
testing.assert_array_equal(
var2.sample()[0],
var2()[0],
jnp.array([0.25, 0.25, 0.2, 0.3]),
)
testing.assert_array_equal(
var3.sample(duration=5)[0],
var3(duration=5)[0],
jnp.array([1, 2, 3, 4, 4]),
)

testing.assert_array_equal(
var3.sample(duration=3)[0],
var3(duration=3)[0],
jnp.array(
[
1,
@@ -58,5 +58,5 @@ def test_deterministic():
),
)

testing.assert_equal(var4.sample()[0], None)
testing.assert_equal(var5.sample(duration=1)[0], None)
testing.assert_equal(var4()[0], None)
testing.assert_equal(var5(duration=1)[0], None)