Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __call__ As Alias For RandomVariable sample Method #253

6 changes: 3 additions & 3 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ class InfFeedback(RandomVariable):

The core of the class is implemented in the `sample()` method. Things to highlight from the above code:

1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback.sample()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable.sample()` calls are expected to include the `**kwargs` argument, even if unused.
1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable()` calls are expected to include the `**kwargs` argument, even if unused.

2. **Calls to `RandomVariable.sample()`**: All calls to `RandomVariable.sample()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength.sample()` and `infection_feedback_pmf.sample()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).
2. **Calls to `RandomVariable()`**: All calls to `RandomVariable()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength()` and `infection_feedback_pmf()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).

3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.

4. **Return type of `InfFeedback.sample()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.
4. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.

```{python}
# | label: simulation2
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/time.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The tuple `(t_unit, t_start)` can encode different types of time series data. Fo

## How it relates to periodicity

The `PeriodicBroadcaster()` class provides a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect.sample()` and `RtPeriodicDiffProcess.sample()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.
The `PeriodicBroadcaster()` class provides a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect()` and `RtPeriodicDiffProcess()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.

## Unimplemented features

Expand All @@ -37,7 +37,7 @@ With random variables possibly spanning different time scales, *e.g.*, weekly, d

### Array alignment

Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel.sample()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the seeding process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:
Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the seeding process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:

```python
Rt_aligned, infections_aligned = align([Rt, infections])
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def pad_x_to_match_y(

class PeriodicProcessSample(NamedTuple):
"""
A container for holding the output from `process.PeriodicProcess.sample()`.
A container for holding the output from `process.PeriodicProcess()`.

Attributes
----------
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class HospitalAdmissionsSample(NamedTuple):
"""
A container to hold the output of `latent.HospAdmissions.sample()`.
A container to hold the output of `latent.HospAdmissions()`.

Attributes
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def seed_infections(self, I_pre_seed: ArrayLike):
raise ValueError(
f"I_pre_seed must be an array of size 1. Got size {I_pre_seed.size}."
)
(rate,) = self.rate.sample()
(rate,) = self.rate()
if rate.size != 1:
raise ValueError(
f"rate must be an array of size 1. Got size {rate.size}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def sample(self) -> tuple:
tuple
a tuple where the only element is an array with the number of seeded infections at each time point.
"""
(I_pre_seed,) = self.I_pre_seed_rv.sample()
(I_pre_seed,) = self.I_pre_seed_rv()
infection_seeding = self.infection_seed_method(I_pre_seed)
npro.deterministic(self.name, infection_seeding)

Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class InfectionsSample(NamedTuple):
"""
A container for holding the output from `latent.Infections.sample()`.
A container for holding the output from `latent.Infections()`.

Attributes
----------
Expand Down
12 changes: 9 additions & 3 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _assert_sample_and_rtype(
"""
Return type-checking for RandomVariable's sample function

Objects passed as `RandomVariable` should (a) have a sample() method that
Objects passed as `RandomVariable` should (a) have a `sample()` method that
(b) returns either a tuple or a named tuple.

Parameters
Expand Down Expand Up @@ -103,7 +103,7 @@ class RandomVariable(metaclass=ABCMeta):
are expected to be used internally mostly for tasks including padding,
alignment of time series, and other time-aware operations.

Both attributes give information about the output of the sample() method,
Both attributes give information about the output of the `sample()` method,
in other words, the relative time units of the returning value.

Attributes
Expand Down Expand Up @@ -138,7 +138,7 @@ def set_timeseries(
t_start : int
The start of the time series relative to the
model time. It could be negative, indicating
that the sample() method returns timepoints
that the `sample()` method returns timepoints
that occur prior to the model t = 0.

t_unit : int
Expand Down Expand Up @@ -202,6 +202,12 @@ def validate(**kwargs) -> None:
"""
pass

def __call__(self, **kwargs):
"""
Alias for `sample()`.
"""
return self.sample(**kwargs)


class DistributionalRVSample(NamedTuple):
"""
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/model/admissionsmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class HospModelSample(NamedTuple):
"""
A container for holding the output from `model.HospitalAdmissionsModel.sample()`.
A container for holding the output from `model.HospitalAdmissionsModel()`.

Attributes
----------
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/model/rtinfectionsrenewalmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Output class of the RtInfectionsRenewalModel
class RtInfectionsRenewalSample(NamedTuple):
"""
A container for holding the output from `model.RtInfectionsRenewalModel.sample()`.
A container for holding the output from `model.RtInfectionsRenewalModel()`.

Attributes
----------
Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/process/firstdifferencear.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def sample(
Parameters
----------
duration : int
Passed to ARProcess.sample().s
Passed to ARProcess()
init_val : ArrayLike, optional
Starting point of the AR process, by default None.
init_rate_of_change : ArrayLike, optional
Passed to ARProcess.sample, by default None.
name : str, optional
Passed to ARProcess.sample(), by default "trend_rw"
Passed to ARProcess(), by default "trend_rw"
**kwargs : dict, optional
Additional keyword arguments passed through to internal sample()
calls, should there be any.
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/periodiceffect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class PeriodicEffectSample(NamedTuple):
"""
A container for holding the output from
`process.PeriodicEffect.sample()`.
`process.PeriodicEffect()`.

Attributes
----------
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/rtperiodicdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class RtPeriodicDiffProcessSample(NamedTuple):
"""
A container for holding the output from `process.RtPeriodicDiffProcess.sample()`.
A container for holding the output from `process.RtPeriodicDiffProcess()`.

Attributes
----------
Expand Down
6 changes: 6 additions & 0 deletions model/src/pyrenew/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,11 @@ def sample(self) -> dict:
coefficients=coefficients,
)

def __call__(self, **kwargs):
"""
Alias for `sample()`.
"""
return self.sample(**kwargs)

def __repr__(self):
return "GLMPrediction " + str(self.name)
8 changes: 4 additions & 4 deletions model/src/test/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ def test_deterministic():
var5 = NullProcess()

testing.assert_array_equal(
var1.sample()[0],
var1()[0],
jnp.array(
[
1,
]
),
)
testing.assert_array_equal(
var2.sample()[0],
var2()[0],
jnp.array([0.25, 0.25, 0.2, 0.3]),
)
testing.assert_array_equal(
Expand All @@ -58,5 +58,5 @@ def test_deterministic():
),
)

testing.assert_equal(var4.sample()[0], None)
testing.assert_equal(var5.sample(duration=1)[0], None)
testing.assert_equal(var4()[0], None)
testing.assert_equal(var5(duration=1)[0], None)
10 changes: 5 additions & 5 deletions model/src/test/test_infection_seeding_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def test_seed_infections_exponential():
I_pre_seed_RV = DeterministicVariable(10.0, name="I_pre_seed_RV")
default_t_pre_seed = n_timepoints - 1

(I_pre_seed,) = I_pre_seed_RV.sample()
(rate,) = rate_RV.sample()
(I_pre_seed,) = I_pre_seed_RV()
(rate,) = rate_RV()

infections_default_t_pre_seed = InitializeInfectionsExponentialGrowth(
n_timepoints, rate=rate_RV
Expand All @@ -44,7 +44,7 @@ def test_seed_infections_exponential():
I_pre_seed_RV_2 = DeterministicVariable(
np.array([10.0, 10.0]), name="I_pre_seed_RV"
)
(I_pre_seed_2,) = I_pre_seed_RV_2.sample()
(I_pre_seed_2,) = I_pre_seed_RV_2()

with pytest.raises(ValueError):
InitializeInfectionsExponentialGrowth(
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_seed_infections_zero_pad():

n_timepoints = 10
I_pre_seed_RV = DeterministicVariable(10.0, name="I_pre_seed_RV")
(I_pre_seed,) = I_pre_seed_RV.sample()
(I_pre_seed,) = I_pre_seed_RV()

infections = InitializeInfectionsZeroPad(n_timepoints).seed_infections(
I_pre_seed
Expand All @@ -85,7 +85,7 @@ def test_seed_infections_zero_pad():
I_pre_seed_RV_2 = DeterministicVariable(
np.array([10.0, 10.0]), name="I_pre_seed_RV"
)
(I_pre_seed_2,) = I_pre_seed_RV_2.sample()
(I_pre_seed_2,) = I_pre_seed_RV_2()

infections_2 = InitializeInfectionsZeroPad(n_timepoints).seed_infections(
I_pre_seed_2
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_infection_seeding_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_infection_initialization_process():

for model in [zero_pad_model, exp_model, vec_model]:
with npro.handlers.seed(rng_seed=1):
model.sample()
model()

# Check that the InfectionInitializationProcess class raises an error when the wrong type of I0 is passed
with pytest.raises(TypeError):
Expand Down
4 changes: 2 additions & 2 deletions model/src/test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def test_infectionsrtfeedback_feedback():
gen_int=gen_int,
Rt=Rt,
I0=I0,
inf_feedback_strength=inf_feed_strength.sample()[0],
inf_feedback_pmf=inf_feedback_pmf.sample()[0],
inf_feedback_strength=inf_feed_strength()[0],
inf_feedback_pmf=inf_feedback_pmf()[0],
)

assert not jnp.array_equal(
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_glm_prediction():

# sampling should work
with numpyro.handlers.seed(rng_seed=5):
preds = glm_pred.sample()
preds = glm_pred()

assert isinstance(preds, dict)

Expand Down
4 changes: 2 additions & 2 deletions model/src/test/test_rtperiodicdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ def test_rtweeklydiff_manual_reconstruction() -> None:

_, ans0 = lax.scan(
f=rtwd.autoreg_process,
init=np.hstack([params["log_rt_prior"].sample()[0], b]),
init=np.hstack([params["log_rt_prior"]()[0], b]),
xs=noise,
)

ans1 = _manual_rt_weekly_diff(
log_seed=params["log_rt_prior"].sample()[0], sd=noise, b=b
log_seed=params["log_rt_prior"]()[0], sd=noise, b=b
)

assert_array_equal(ans0, ans1)
Expand Down