Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __call__ As Alias For RandomVariable sample Method #253

10 changes: 5 additions & 5 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,15 @@ class InfFeedback(RandomVariable):
I0_vec = I0[-gen_int_rev.size :]

# Sampling inf feedback strength and adjusting the shape
inf_feedback_strength, *_ = self.infection_feedback_strength.sample(
inf_feedback_strength, *_ = self.infection_feedback_strength(
**kwargs,
)
inf_feedback_strength = au.pad_x_to_match_y(
x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0]
)

# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs)
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)

# Generating the infections with feedback
Expand All @@ -230,13 +230,13 @@ class InfFeedback(RandomVariable):

The core of the class is implemented in the `sample()` method. Things to highlight from the above code:

1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback.sample()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable.sample()` calls are expected to include the `**kwargs` argument, even if unused.
1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable()` calls are expected to include the `**kwargs` argument, even if unused.

2. **Calls to `RandomVariable.sample()`**: All calls to `RandomVariable.sample()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength.sample()` and `infection_feedback_pmf.sample()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).
2. **Calls to `RandomVariable()`**: All calls to `RandomVariable()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength()` and `infection_feedback_pmf()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).

3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.

4. **Return type of `InfFeedback.sample()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.
4. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.

```{python}
# | label: simulation2
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ rt_proc = process.RtWeeklyDiffProcess(

```{python}
with npro.handlers.seed(rng_seed=20):
sim_data = rt_proc.sample(duration=30)
sim_data = rt_proc(duration=30)

# Plotting the Rt values
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -84,7 +84,7 @@ Like before, we can use the `sample` method to generate samples from the day of

```{python}
with npro.handlers.seed(rng_seed=20):
sim_data = dayofweek.sample(duration=30)
sim_data = dayofweek(duration=30)

# Plotting the effect values
import matplotlib.pyplot as plt
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ import numpyro.distributions as dist
from pyrenew.process import SimpleRandomWalkProcess
```

To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the n_timepoints of the block that follows. Inside the `with` block, the `q_samp = q.sample(n_timepoints=100)` generates the sample instance over a n_timepoints of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance.
To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the n_timepoints of the block that follows. Inside the `with` block, the `q_samp = q(n_timepoints=100)` generates the sample instance over a n_timepoints of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance.

```{python}
# | label: fig-randwalk
# | fig-cap: Random walk example
np.random.seed(3312)
q = SimpleRandomWalkProcess(dist.Normal(0, 0.001))
with seed(rng_seed=np.random.randint(0, 1000)):
q_samp = q.sample(n_timepoints=100)
q_samp = q(n_timepoints=100)

plt.plot(np.exp(q_samp[0]))
```
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/time.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The tuple `(t_unit, t_start)` can encode different types of time series data. Fo

## How it relates to periodicity

The `PeriodicBroadcaster()` class provides a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect.sample()` and `RtPeriodicDiffProcess.sample()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.
The `PeriodicBroadcaster()` class provides a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect()` and `RtPeriodicDiffProcess()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.

## Unimplemented features

Expand All @@ -37,7 +37,7 @@ With random variables possibly spanning different time scales, *e.g.*, weekly, d

### Array alignment

Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel.sample()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the seeding process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:
Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the seeding process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:

```python
Rt_aligned, infections_aligned = align([Rt, infections])
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def pad_x_to_match_y(

class PeriodicProcessSample(NamedTuple):
"""
A container for holding the output from `process.PeriodicProcess.sample()`.
A container for holding the output from `process.PeriodicProcess()`.

Attributes
----------
Expand Down
11 changes: 5 additions & 6 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class HospitalAdmissionsSample(NamedTuple):
"""
A container to hold the output of `latent.HospAdmissions.sample()`.
A container to hold the output of `latent.HospAdmissions()`.

Attributes
----------
Expand Down Expand Up @@ -173,14 +173,14 @@ def sample(
HospitalAdmissionsSample
"""

infection_hosp_rate, *_ = self.infect_hosp_rate_rv.sample(**kwargs)
infection_hosp_rate, *_ = self.infect_hosp_rate_rv(**kwargs)

infection_hosp_rate_t = infection_hosp_rate * latent_infections

(
infection_to_admission_interval,
*_,
) = self.infection_to_admission_interval_rv.sample(**kwargs)
) = self.infection_to_admission_interval_rv(**kwargs)

latent_hospital_admissions = jnp.convolve(
infection_hosp_rate_t,
Expand All @@ -191,13 +191,12 @@ def sample(
# Applying the day of the week effect
latent_hospital_admissions = (
latent_hospital_admissions
* self.day_of_week_effect_rv.sample(**kwargs)[0]
* self.day_of_week_effect_rv(**kwargs)[0]
)

# Applying probability of hospitalization effect
latent_hospital_admissions = (
latent_hospital_admissions
* self.hosp_report_prob_rv.sample(**kwargs)[0]
latent_hospital_admissions * self.hosp_report_prob_rv(**kwargs)[0]
)

npro.deterministic(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def seed_infections(self, I_pre_seed: ArrayLike):
raise ValueError(
f"I_pre_seed must be an array of size 1. Got size {I_pre_seed.size}."
)
(rate,) = self.rate.sample()
(rate,) = self.rate()
if rate.size != 1:
raise ValueError(
f"rate must be an array of size 1. Got size {rate.size}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def sample(self) -> tuple:
tuple
a tuple where the only element is an array with the number of seeded infections at each time point.
"""
(I_pre_seed,) = self.I_pre_seed_rv.sample()
(I_pre_seed,) = self.I_pre_seed_rv()
infection_seeding = self.infection_seed_method(I_pre_seed)
npro.deterministic(self.name, infection_seeding)

Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class InfectionsSample(NamedTuple):
"""
A container for holding the output from `latent.Infections.sample()`.
A container for holding the output from `latent.Infections()`.

Attributes
----------
Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def sample(
I0 = I0[-gen_int_rev.size :]

# Sampling inf feedback strength
inf_feedback_strength, *_ = self.infection_feedback_strength.sample(
inf_feedback_strength, *_ = self.infection_feedback_strength(
**kwargs,
)

Expand All @@ -175,7 +175,7 @@ def sample(
)

# Sampling inf feedback pmf
inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs)
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)

inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)

Expand Down
12 changes: 9 additions & 3 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _assert_sample_and_rtype(
"""
Return type-checking for RandomVariable's sample function

Objects passed as `RandomVariable` should (a) have a sample() method that
Objects passed as `RandomVariable` should (a) have a `sample()` method that
(b) returns either a tuple or a named tuple.

Parameters
Expand Down Expand Up @@ -103,7 +103,7 @@ class RandomVariable(metaclass=ABCMeta):
are expected to be used internally mostly for tasks including padding,
alignment of time series, and other time-aware operations.

Both attributes give information about the output of the sample() method,
Both attributes give information about the output of the `sample()` method,
in other words, the relative time units of the returning value.

Attributes
Expand Down Expand Up @@ -138,7 +138,7 @@ def set_timeseries(
t_start : int
The start of the time series relative to the
model time. It could be negative, indicating
that the sample() method returns timepoints
that the `sample()` method returns timepoints
that occur prior to the model t = 0.

t_unit : int
Expand Down Expand Up @@ -202,6 +202,12 @@ def validate(**kwargs) -> None:
"""
pass

def __call__(self, **kwargs):
"""
Alias for `sample()`.
"""
return self.sample(**kwargs)


class DistributionalRVSample(NamedTuple):
"""
Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/model/admissionsmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,15 @@ def sample(
infection_hosp_rate,
latent_hosp_admissions,
*_,
) = self.latent_hosp_admissions_rv.sample(
) = self.latent_hosp_admissions_rv(
latent_infections=basic_model.latent_infections,
**kwargs,
)

(
observed_hosp_admissions,
*_,
) = self.hosp_admission_obs_process_rv.sample(
) = self.hosp_admission_obs_process_rv(
mu=latent_hosp_admissions[-n_datapoints:],
obs=data_observed_hosp_admissions,
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions model/src/pyrenew/model/rtinfectionsrenewalmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,28 +192,28 @@ def sample(
n_timepoints = n_timepoints_to_simulate + padding
# Sampling from Rt (possibly with a given Rt, depending on
# the Rt_process (RandomVariable) object.)
Rt, *_ = self.Rt_process_rv.sample(
Rt, *_ = self.Rt_process_rv(
n_timepoints=n_timepoints,
**kwargs,
)

# Getting the generation interval
gen_int, *_ = self.gen_int_rv.sample(**kwargs)
gen_int, *_ = self.gen_int_rv(**kwargs)

# Sampling initial infections
I0, *_ = self.I0_rv.sample(**kwargs)
I0, *_ = self.I0_rv(**kwargs)
# Sampling from the latent process
(
post_initialization_latent_infections,
*_,
) = self.latent_infections_rv.sample(
) = self.latent_infections_rv(
Rt=Rt,
gen_int=gen_int,
I0=I0,
**kwargs,
)

observed_infections, *_ = self.infection_obs_process_rv.sample(
observed_infections, *_ = self.infection_obs_process_rv(
mu=post_initialization_latent_infections[padding:],
obs=data_observed_infections,
**kwargs,
Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/process/firstdifferencear.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def sample(
Parameters
----------
duration : int
Passed to ARProcess.sample().s
Passed to ARProcess()
init_val : ArrayLike, optional
Starting point of the AR process, by default None.
init_rate_of_change : ArrayLike, optional
Passed to ARProcess.sample, by default None.
name : str, optional
Passed to ARProcess.sample(), by default "trend_rw"
Passed to ARProcess(), by default "trend_rw"
**kwargs : dict, optional
Additional keyword arguments passed through to internal sample()
calls, should there be any.
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/periodiceffect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class PeriodicEffectSample(NamedTuple):
"""
A container for holding the output from
`process.PeriodicEffect.sample()`.
`process.PeriodicEffect()`.

Attributes
----------
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/rtperiodicdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class RtPeriodicDiffProcessSample(NamedTuple):
"""
A container for holding the output from `process.RtPeriodicDiffProcess.sample()`.
A container for holding the output from `process.RtPeriodicDiffProcess()`.

Attributes
----------
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/rtrandomwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def sample(

Rt0_trans = self.Rt_transform(Rt0)
Rt_trans_proc = SimpleRandomWalkProcess(self.Rt_rw_dist)
Rt_trans_ts, *_ = Rt_trans_proc.sample(
Rt_trans_ts, *_ = Rt_trans_proc(
n_timepoints=n_timepoints,
name="Rt_transformed_rw",
init=Rt0_trans,
Expand Down
6 changes: 6 additions & 0 deletions model/src/pyrenew/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,11 @@ def sample(self) -> dict:
coefficients=coefficients,
)

def __call__(self, **kwargs):
"""
Alias for `sample()`.
"""
return self.sample(**kwargs)

def __repr__(self):
return "GLMPrediction " + str(self.name)
10 changes: 5 additions & 5 deletions model/src/test/test_ar_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ def test_ar_can_be_sampled():
ar1 = ARProcess(5, jnp.array([0.95]), jnp.array([0.5]))
with numpyro.handlers.seed(rng_seed=62):
# can sample with and without inits
ar1.sample(3532, inits=jnp.array([50.0]))
ar1.sample(5023)
ar1(duration=3532, inits=jnp.array([50.0]))
ar1(duration=5023)

ar3 = ARProcess(5, jnp.array([0.05, 0.025, 0.025]), jnp.array([0.5]))
with numpyro.handlers.seed(rng_seed=62):
# can sample with and without inits
ar3.sample(1230)
ar3.sample(52, inits=jnp.array([50.0, 49.9, 48.2]))
ar3(duration=1230)
ar3(duration=52, inits=jnp.array([50.0, 49.9, 48.2]))


def test_ar_samples_correctly_distributed():
Expand All @@ -36,6 +36,6 @@ def test_ar_samples_correctly_distributed():
with numpyro.handlers.seed(rng_seed=62):
# check it regresses to mean
# when started away from it
long_ts, *_ = ar1.sample(10000, inits=ar_inits)
long_ts, *_ = ar1(duration=10000, inits=ar_inits)
assert_almost_equal(long_ts[0], ar_inits)
assert jnp.abs(long_ts[-1] - ar_mean) < 4 * noise_sd
12 changes: 6 additions & 6 deletions model/src/test/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,24 @@ def test_deterministic():
var5 = NullProcess()

testing.assert_array_equal(
var1.sample()[0],
var1()[0],
jnp.array(
[
1,
]
),
)
testing.assert_array_equal(
var2.sample()[0],
var2()[0],
jnp.array([0.25, 0.25, 0.2, 0.3]),
)
testing.assert_array_equal(
var3.sample(duration=5)[0],
var3(duration=5)[0],
jnp.array([1, 2, 3, 4, 4]),
)

testing.assert_array_equal(
var3.sample(duration=3)[0],
var3(duration=3)[0],
jnp.array(
[
1,
Expand All @@ -58,5 +58,5 @@ def test_deterministic():
),
)

testing.assert_equal(var4.sample()[0], None)
testing.assert_equal(var5.sample(duration=1)[0], None)
testing.assert_equal(var4()[0], None)
testing.assert_equal(var5(duration=1)[0], None)
Loading