Commit

Rework Initial Infections and Support Exponential Growth Initializations (#124)

* Working on Rt docs

* Starting off module

* Adding test and starting tutorial

* Removing extra file

* Starting work on InfectionSeedingProccess and InfectionsSeedingMethod

* ignore lack of docs while developing

* clarify problems with pad

* rework so n_timepoints is an attribute of SeedInfectionsMethod

* use exceptions instead of assertions

* fixing issues with sampling for infection_seeding_process

* start to rework getting-started.qmd

* Make rate a RandomVariable and fix SeedInfectionsZeroPad

* fix TODO tag

* Tutorial code working (text  not updated)

* Adding a test and extending documentation

* Adding test for datautils

* Adding more content to the tutorial

* Adding a test checking for the calculations of the double conv (expected to fail)

* Update model/src/test/test_infectionsrtfeedback.py

* Clarify variable names in sample_infections_with_feedback()

* Sign consistency and documentation

* Add basic test that infections with feedback reduces to regular infections with zero-strength feedback, and returns the R(t) timeseries correctly

* Add pytest-mpl to pyproject.toml given use of @pytest.mark.mpl_image_compare in tests

* Run precommit

* Clarify pmf input format for sample_infections_with_feedback in documentation

* Update lockfile, style infections.py

* clarify use of reversed PMFs, remove padding operations that should not occur

* Harmonize indexing and remove autopadding; fix manual renewal process implementation. This will break other tests but is a worthwhile breaking change

* Update latent admissions test

* Fixing broken tests

* Cleanup infection seeding

* Cleanup infection seeding

* Adding needed padding for convo

* fix for I0 size

* update tutorials

* trying something

* report all latent infections, including latent ones

* Adding direction of padding to the docs

* use jnp.atleast_1d

* Change default tol in DeterministicPMF to work on macOS

* Namespaces

* Adding missing note

* Update model/src/test/test_model_basic_renewal.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Splitting tests and using pytest.raises

* Removing unnecessary call to test_*

* Update model/src/test/test_datautils.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Fixing tests and adding test for Infections.sample() raise error

* Fixing tutorial using old name of sample_infections

* Fixing docstring (missing r""")

* Deleting call to test

* inf_feedback: either len 1 or len Rt (otherwise error)

* Replacing I0 for DistributionalRVSample

* Replacing I0 for DistributionalRVSample (vis)

* Ensuring DeterministicRV/DistributionalRV.sample returns at least 1d array

* Apply suggestions from code review @damonbayer

Co-authored-by: Damon Bayer <xum8@cdc.gov>

* Fixing pre-commit

* Cherry pick when ghost was fixed

* Cherry pick data utils

* New metaclass.DistributionalRV (#138)

* New metaclass.DistributionalRV

* Update model/src/pyrenew/metaclass.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Update model/src/pyrenew/metaclass.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Update model/src/pyrenew/metaclass.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Update model/src/pyrenew/latent/i0.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Update model/src/pyrenew/latent/i0.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Removing latent.Infections0

* Replacing the InfectHospRate class with DistributionalRV

---------

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Removing residual comments from conflicts

* update example-with-dataset to use exponential initialization

* update default t_I_pre_seed for SeedInfectionsExponential

* update example-with-datasets to use default t_I_pre_seed

* docs

* Put infection functions back in website documentation

* add testing from InfectionSeedMethods and InfectionSeedProcess

* Clean up qmd's

* fix tests in test_infection_seeding_method

* fix test in test_infection_seeding_process

* fix test in test_infectionsrtfeedback

* typo in admissionsmodel

* starting to fix model test

* starting to fix tutorials

* formatting quarto docs

* fix test_model_basic_renewal

* most tests working

* fix extending_pyrenew

* all tests passing

* revise test_model_basicrenewal_plot.png

* cleanup test_model_hospitalizations

* relax test_seed_infections_exponential

* relax test_infection_seeding_method further

* trying plot from ubuntu to see if it fixes CI

* typo

* Apply suggestions from code review

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Updating image

* Making precommit happy

* correct latent.rst

* Apply suggestions from code review

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* format arrayutils

* Apply suggestions from code review

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* fix styling

* add name to infection_seeding_process

* clean up infections.py

* fix docstring for I0 on infectionswithfeedback

* add test for padding a non-scalar array

* additional tests for exponential seeding method

* fix quarto docs

* remove datautils

* rename test_datautils.py to test_arrayutils.py

* add tests for arrayutils

* add tests to make codecov happy

* fix tests

* fix test

* attempt to make codecov happy

* fix math rendering

---------

Co-authored-by: George G. Vega Yon <xrd4@cdc.gov>
Co-authored-by: George G. Vega Yon <g.vegayon@gmail.com>
Co-authored-by: Dylan H. Morris <dzl1@cdc.gov>
Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>
5 people authored Jun 7, 2024
1 parent 600c963 commit 4eda387
Showing 23 changed files with 880 additions and 207 deletions.
16 changes: 16 additions & 0 deletions docs/source/msei_reference/latent.rst
@@ -25,4 +25,20 @@ Infection Functions
:undoc-members:
:show-inheritance:

Infection Seeding Process
-------------------------

.. automodule:: pyrenew.latent.infection_seeding_process
:members:
:undoc-members:
:show-inheritance:

Infection Seeding Method
------------------------

.. automodule:: pyrenew.latent.infection_seeding_method
:members:
:undoc-members:
:show-inheritance:

.. todo:: Determine the naming and order of these modules.
114 changes: 60 additions & 54 deletions model/docs/example-with-datasets.qmd
@@ -50,7 +50,7 @@ $$
We start by loading the data and inspecting the first five rows.

```{python}
#| label: data-inspect
# | label: data-inspect
import polars as pl
from pyrenew import datasets
@@ -61,7 +61,7 @@ dat.head(5)
The data shows one entry per site, but the way it was simulated, the number of admissions is the same across sites. Thus, we will only keep the first observation per day.

```{python}
#| label: aggregation
# | label: aggregation
# Keeping the first observation of each date
dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"])
@@ -77,8 +77,8 @@ dat.head(5)
Let's take a look at the daily prevalence of hospital admissions.

```{python}
#| label: fig-plot-hospital-admissions
#| fig-cap: Daily hospital admissions from the simulated data
# | label: fig-plot-hospital-admissions
# | fig-cap: Daily hospital admissions from the simulated data
import matplotlib.pyplot as plt
# Rotating the x-axis labels, and only showing ~10 labels
@@ -96,13 +96,14 @@ plt.show()
First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospitalization interval.

```{python}
#| label: fig-data-extract
#| fig-cap: Generation interval and infection to hospitalization interval
# | label: fig-data-extract
# | fig-cap: Generation interval and infection to hospitalization interval
gen_int = datasets.load_generation_interval()
inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int = gen_int["probability_mass"].to_numpy()
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
# Taking a peek at the first 5 elements of each
@@ -121,7 +122,7 @@ plt.show()
With these two in hand, we can start building the model. First, we will define the latent hospital admissions:

```{python}
#| label: latent-hosp
# | label: latent-hosp
from pyrenew import latent, deterministic, metaclass
import jax.numpy as jnp
import numpyro.distributions as dist
@@ -136,27 +137,32 @@ hosp_rate = metaclass.DistributionalRV(
latent_hosp = latent.HospitalAdmissions(
infection_to_admission_interval=inf_hosp_int,
infect_hosp_rate_dist=hosp_rate,
)
)
```

The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospitalization interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospitalization rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospitalization rate. Now, we can define the remaining components:

```{python}
#| label: initializing-rest-of-model
# | label: initializing-rest-of-model
from pyrenew import model, process, observation, metaclass
from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponential
# Infection process
latent_inf = latent.Infections()
I0 = metaclass.DistributionalRV(
dist=dist.LogNormal(loc=jnp.log(80/.05), scale=1.5),
name="I0"
)
I0 = InfectionSeedingProcess(
"I0_seeding",
metaclass.DistributionalRV(
dist=dist.LogNormal(loc=jnp.log(100), scale=0.5), name="I0"
),
SeedInfectionsExponential(
gen_int_array.size,
deterministic.DeterministicVariable(0.5, name="rate"),
),
)
# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int")
rtproc = process.RtRandomWalkProcess(
Rt_rw_dist=dist.Normal(0, 0.1)
)
rtproc = process.RtRandomWalkProcess(Rt_rw_dist=dist.Normal(0, 0.1))
# The observation model
obs = observation.NegativeBinomialObservation(concentration_prior=1.0)
@@ -165,7 +171,7 @@ obs = observation.NegativeBinomialObservation(concentration_prior=1.0)
Notice all the components are `RandomVariable` instances. We can now build the model:

```{python}
#| label: init-model
# | label: init-model
hosp_model = model.HospitalAdmissionsModel(
latent_infections=latent_inf,
latent_admissions=latent_hosp,
@@ -179,35 +185,35 @@ hosp_model = model.HospitalAdmissionsModel(
Let's simulate to check if the model is working:

```{python}
#| label: simulation
# | label: simulation
import numpyro as npro
import numpy as np
timeframe = 120
np.random.seed(223)
with npro.handlers.seed(rng_seed = np.random.randint(1, timeframe)):
with npro.handlers.seed(rng_seed=np.random.randint(1, timeframe)):
sim_data = hosp_model.sample(n_timepoints_to_simulate=timeframe)
```
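
As a rough illustration of what exponential seeding does, the sketch below grows an initial infection count at a fixed rate over the seeding window. The function name `seed_infections_exponential`, its parameters, the formula `I(t) = I_p * exp(rate * (t - t_p))`, and the default anchor at the last seeded time point are all assumptions made for this example; it is not the implementation behind `SeedInfectionsExponential`.

```python
import math


def seed_infections_exponential(i_pre_seed, rate, n_timepoints, t_pre_seed=None):
    """Hypothetical helper: I(t) = I_p * exp(rate * (t - t_p)) for t = 0..n-1."""
    if n_timepoints < 1:
        raise ValueError("n_timepoints must be a positive integer")
    if t_pre_seed is None:
        # Assumed default for this sketch: anchor at the last seeded time point
        t_pre_seed = n_timepoints - 1
    return [
        i_pre_seed * math.exp(rate * (t - t_pre_seed))
        for t in range(n_timepoints)
    ]


# Five seeded time points growing at rate 0.5 and reaching 100 infections
seed = seed_infections_exponential(i_pre_seed=100.0, rate=0.5, n_timepoints=5)
```

Under these assumptions the seeded series passes through `i_pre_seed` at `t_pre_seed`, and consecutive values differ by a constant factor of `exp(rate)`.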

```{python}
#| label: fig-basic
#| fig-cap: Rt and Infections
# | label: fig-basic
# | fig-cap: Rt and Infections
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(range(0, timeframe), sim_data.Rt)
axs[0].set_ylabel('Rt')
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(range(0, timeframe), sim_data.sampled_admissions)
axs[1].set_ylabel('Infections')
axs[1].set_yscale('log')
axs[1].plot(sim_data.sampled_admissions)
axs[1].set_ylabel("Infections")
axs[1].set_yscale("log")
fig.suptitle('Basic renewal model')
fig.supxlabel('Time')
fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()
```
@@ -218,7 +224,7 @@ We can fit the model to the data. We will use the `run` method of the model obje


```{python}
#| label: model-fit
# | label: model-fit
import jax
hosp_model.run(
@@ -235,8 +241,8 @@ We can use the `plot_posterior` method to visualize the results[^capture]:
[^capture]: The output is captured to prevent `quarto` from displaying it twice.

```{python}
#| label: fig-output-hospital-admissions
#| fig-cap: Hospital Admissions posterior distribution
# | label: fig-output-hospital-admissions
# | fig-cap: Hospital Admissions posterior distribution
out = hosp_model.plot_posterior(
var="predicted_admissions",
ylab="Hospital Admissions",
@@ -251,7 +257,7 @@ The first half of the model is not looking good. The reason is that the infectio
We can use the padding argument to solve the overestimation of hospital admissions in the first half of the model. By setting `padding > 0`, the model assumes that the first `padding` observations are missing; thus, only observations after `padding` will count towards the likelihood of the model. In practice, the model will extend the estimated Rt and latent infections by `padding` days, giving the model time to adjust to the observed data. The following code adds 21 days of missing data at the beginning of the data and re-estimates the model with `padding = 21`:

```{python}
#| label: model-fit-padding
# | label: model-fit-padding
days_to_impute = 21
dat_w_padding = dat["daily_hosp_admits"].to_numpy()
@@ -265,15 +271,15 @@ hosp_model.run(
observed_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
padding=days_to_impute, # Padding the model
padding=days_to_impute, # Padding the model
)
```
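
The idea behind padding can be sketched without the model: prepend missing values to the observed series, then let only the non-missing positions count towards the likelihood. The helpers `pad_observations` and `observed_indices` below are hypothetical, and representing missing data as `NaN` is an assumption of this sketch rather than a statement about how `hosp_model.run` handles it internally.

```python
import math


def pad_observations(observed, padding):
    """Prepend `padding` missing values (NaN) to the observed series."""
    if padding < 0:
        raise ValueError("padding must be non-negative")
    return [math.nan] * padding + list(observed)


def observed_indices(series):
    """Positions that would contribute to the likelihood (non-missing values)."""
    return [i for i, y in enumerate(series) if not math.isnan(y)]


# Toy data: three observed days preceded by two padded (missing) days
admits = [12.0, 15.0, 11.0]
padded = pad_observations(admits, padding=2)
idx = observed_indices(padded)
```

In this toy case the padded series has five entries, and only the last three positions remain available for comparison with the model's predictions.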

And plotting the results:

```{python}
#| label: fig-output-admissions-with-padding
#| fig-cap: Hospital Admissions posterior distribution
# | label: fig-output-admissions-with-padding
# | fig-cap: Hospital Admissions posterior distribution
out = hosp_model.plot_posterior(
var="predicted_admissions",
ylab="Hospital Admissions",
@@ -284,33 +290,32 @@ out = hosp_model.plot_posterior(
We can also take a look at the latent infections:

```{python}
#| label: fig-output-infections-with-padding
#| fig-cap: Hospital Admissions posterior distribution
out2 = hosp_model.plot_posterior(
var="latent_infections",
ylab="Latent Infections"
)
# | label: fig-output-infections-with-padding
# | fig-cap: Hospital Admissions posterior distribution
out2 = hosp_model.plot_posterior(var="latent_infections", ylab="Latent Infections")
```

## Round 2: Incorporating weekday effects

We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a weekday effect distribution. To do this, we will create a new instance of `RandomVariable` to model the weekday effect. The weekday effect will be a truncated normal distribution with a mean of 1.0 and a standard deviation of 0.5. The distribution will be truncated between 0.1 and 10.0. The weekday effect will be repeated for the number of weeks in the dataset.
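
The repetition step can be sketched in isolation. The hypothetical helper `tile_weekday_effect` below is a plain-Python counterpart of tiling a length-7 vector and truncating it to the number of observations; the effect values are made up for illustration and are not draws from the truncated normal distribution.

```python
import math


def tile_weekday_effect(effect, n_obs):
    """Repeat a 7-element weekday effect and truncate it to n_obs days."""
    if len(effect) != 7:
        raise ValueError("effect must have exactly 7 elements, one per weekday")
    n_weeks = math.ceil(n_obs / 7)  # enough whole weeks to cover n_obs days
    return (list(effect) * n_weeks)[:n_obs]


# Illustrative values only (one multiplier per weekday)
effect = [1.0, 1.1, 0.9, 1.0, 1.2, 0.7, 0.6]
daily = tile_weekday_effect(effect, n_obs=10)
```

Rounding the number of weeks up before truncating ensures that the result always has exactly `n_obs` entries, even when the series does not end on a week boundary.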

```{python}
#| label: weekly-effect
# | label: weekly-effect
from pyrenew import metaclass
import numpyro as npro
class WeekdayEffect(metaclass.RandomVariable):
"""Weekday effect distribution"""
def __init__(self, len: int):
""" Initialize the weekday effect distribution
"""Initialize the weekday effect distribution
Parameters
----------
len : int
The number of observations
"""
self.nweeks = int(jnp.ceil(len/7))
self.nweeks = int(jnp.ceil(len / 7))
self.len = len
@staticmethod
@@ -321,12 +326,13 @@ class WeekdayEffect(metaclass.RandomVariable):
ans = npro.sample(
name="weekday_effect",
fn=npro.distributions.TruncatedNormal(
loc=1.0, scale=.5, low=0.1, high=10.0
),
sample_shape=(7,)
loc=1.0, scale=0.5, low=0.1, high=10.0
),
sample_shape=(7,),
)
return jnp.tile(ans, self.nweeks)[:self.len]
return jnp.tile(ans, self.nweeks)[: self.len]
# Initializing the weekday effect
weekday_effect = WeekdayEffect(dat.shape[0])
@@ -335,12 +341,12 @@ weekday_effect = WeekdayEffect(dat.shape[0])
Notice that the instance's `nweeks` and `len` members are passed during construction. Trying to compute the number of weeks and the length of the dataset in the `validate` method will raise a `jit` error in `jax` as the shape and size of elements are not known during the validation step, which happens before the model is run. With the new weekday effect, we can rebuild the latent hospitalization model:

```{python}
#| label: latent-hosp-weekday
# | label: latent-hosp-weekday
latent_hosp_wday_effect = latent.HospitalAdmissions(
infection_to_admission_interval=inf_hosp_int,
infect_hosp_rate_dist=hosp_rate,
weekday_effect_dist=weekday_effect,
)
)
hosp_model_weekday = model.HospitalAdmissionsModel(
latent_infections=latent_inf,
@@ -355,7 +361,7 @@ hosp_model_weekday = model.HospitalAdmissionsModel(
Running the model (with the same padding as before):

```{python}
#| label: model-2-run
# | label: model-2-run
hosp_model_weekday.run(
num_samples=2000,
num_warmup=2000,
@@ -369,8 +375,8 @@ hosp_model_weekday.run(
And plotting the results:

```{python}
#| label: fig-output-admissions-padding-and-weekday
#| fig-cap: Hospital Admissions posterior distribution
# | label: fig-output-admissions-padding-and-weekday
# | fig-cap: Hospital Admissions posterior distribution
out = hosp_model_weekday.plot_posterior(
var="predicted_admissions",
ylab="Hospital Admissions",