diff --git a/docs/source/msei_reference/latent.rst b/docs/source/msei_reference/latent.rst index 18bbbfae..e6aa0859 100644 --- a/docs/source/msei_reference/latent.rst +++ b/docs/source/msei_reference/latent.rst @@ -25,4 +25,20 @@ Infection Functions :undoc-members: :show-inheritance: +Infection Seeding Process +------------------------- + +.. automodule:: pyrenew.latent.infection_seeding_process + :members: + :undoc-members: + :show-inheritance: + +Infection Seeding Method +------------------------ + +.. automodule:: pyrenew.latent.infection_seeding_method + :members: + :undoc-members: + :show-inheritance: + .. todo:: Determine and naming order of these modules. diff --git a/model/docs/example-with-datasets.qmd b/model/docs/example-with-datasets.qmd index 1df25ad7..2587fa5f 100644 --- a/model/docs/example-with-datasets.qmd +++ b/model/docs/example-with-datasets.qmd @@ -50,7 +50,7 @@ $$ We start by loading the data and inspecting the first five rows. ```{python} -#| label: data-inspect +# | label: data-inspect import polars as pl from pyrenew import datasets @@ -61,7 +61,7 @@ dat.head(5) The data shows one entry per site, but the way it was simulated, the number of admissions is the same across sites. Thus, we will only keep the first observation per day. ```{python} -#| label: aggregation +# | label: aggregation # Keeping the first observation of each date dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"]) @@ -77,8 +77,8 @@ dat.head(5) Let's take a look at the daily prevalence of hospital admissions. ```{python} -#| label: fig-plot-hospital-admissions -#| fig-cap: Daily hospital admissions from the simulated data +# | label: fig-plot-hospital-admissions +# | fig-cap: Daily hospital admissions from the simulated data import matplotlib.pyplot as plt # Rotating the x-axis labels, and only showing ~10 labels @@ -96,13 +96,14 @@ plt.show() First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospitalization interval. ```{python} -#| label: fig-data-extract -#| fig-cap: Generation interval and infection to hospitalization interval +# | label: fig-data-extract +# | fig-cap: Generation interval and infection to hospitalization interval gen_int = datasets.load_generation_interval() inf_hosp_int = datasets.load_infection_admission_interval() # We only need the probability_mass column of each dataset -gen_int = gen_int["probability_mass"].to_numpy() +gen_int_array = gen_int["probability_mass"].to_numpy() +gen_int = gen_int_array inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy() # Taking a pick at the first 5 elements of each @@ -121,7 +122,7 @@ plt.show() With these two in hand, we can start building the model. First, we will define the latent hospital admissions: ```{python} -#| label: latent-hosp +# | label: latent-hosp from pyrenew import latent, deterministic, metaclass import jax.numpy as jnp import numpyro.distributions as dist @@ -136,27 +137,32 @@ hosp_rate = metaclass.DistributionalRV( latent_hosp = latent.HospitalAdmissions( infection_to_admission_interval=inf_hosp_int, infect_hosp_rate_dist=hosp_rate, - ) +) ``` The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospitalization interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospitalization rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospitalization rate. Now, we can define the rest of the other components: ```{python} -#| label: initializing-rest-of-model +# | label: initializing-rest-of-model from pyrenew import model, process, observation, metaclass +from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponential # Infection process latent_inf = latent.Infections() -I0 = metaclass.DistributionalRV( - dist=dist.LogNormal(loc=jnp.log(80/.05), scale=1.5), - name="I0" - ) +I0 = InfectionSeedingProcess( + "I0_seeding", + metaclass.DistributionalRV( + dist=dist.LogNormal(loc=jnp.log(100), scale=0.5), name="I0" + ), + SeedInfectionsExponential( + gen_int_array.size, + deterministic.DeterministicVariable(0.5, name="rate"), + ), +) # Generation interval and Rt gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int") -rtproc = process.RtRandomWalkProcess( - Rt_rw_dist=dist.Normal(0, 0.1) -) +rtproc = process.RtRandomWalkProcess(Rt_rw_dist=dist.Normal(0, 0.1)) # The observation model obs = observation.NegativeBinomialObservation(concentration_prior=1.0) @@ -165,7 +171,7 @@ obs = observation.NegativeBinomialObservation(concentration_prior=1.0) Notice all the components are `RandomVariable` instances. We can now build the model: ```{python} -#| label: init-model +# | label: init-model hosp_model = model.HospitalAdmissionsModel( latent_infections=latent_inf, latent_admissions=latent_hosp, @@ -179,35 +185,35 @@ hosp_model = model.HospitalAdmissionsModel( Let's simulate to check if the model is working: ```{python} -#| label: simulation +# | label: simulation import numpyro as npro import numpy as np timeframe = 120 np.random.seed(223) -with npro.handlers.seed(rng_seed = np.random.randint(1, timeframe)): +with npro.handlers.seed(rng_seed=np.random.randint(1, timeframe)): sim_data = hosp_model.sample(n_timepoints_to_simulate=timeframe) ``` ```{python} -#| label: fig-basic -#| fig-cap: Rt and Infections +# | label: fig-basic +# | fig-cap: Rt and Infections import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(range(0, timeframe), sim_data.Rt) -axs[0].set_ylabel('Rt') +axs[0].plot(sim_data.Rt) +axs[0].set_ylabel("Rt") # Infections plot -axs[1].plot(range(0, timeframe), sim_data.sampled_admissions) -axs[1].set_ylabel('Infections') -axs[1].set_yscale('log') +axs[1].plot(sim_data.sampled_admissions) +axs[1].set_ylabel("Infections") +axs[1].set_yscale("log") -fig.suptitle('Basic renewal model') -fig.supxlabel('Time') +fig.suptitle("Basic renewal model") +fig.supxlabel("Time") plt.tight_layout() plt.show() ``` @@ -218,7 +224,7 @@ We can fit the model to the data. We will use the `run` method of the model obje ```{python} -#| label: model-fit +# | label: model-fit import jax hosp_model.run( @@ -235,8 +241,8 @@ We can use the `plot_posterior` method to visualize the results[^capture]: [^capture]: The output is captured to avoid `quarto` from displaying the output twice. ```{python} -#| label: fig-output-hospital-admissions -#| fig-cap: Hospital Admissions posterior distribution +# | label: fig-output-hospital-admissions +# | fig-cap: Hospital Admissions posterior distribution out = hosp_model.plot_posterior( var="predicted_admissions", ylab="Hospital Admissions", @@ -251,7 +257,7 @@ The first half of the model is not looking good. The reason is that the infectio We can use the padding argument to solve the overestimation of hospital admissions in the first half of the model. By setting `padding > 0`, the model then assumes that the first `padding` observations are missing; thus, only observations after `padding` will count towards the likelihood of the model. In practice, the model will extend the estimated Rt and latent infections by `padding` days, given time to adjust to the observed data. The following code will add 21 days of missing data at the beginning of the model and re-estimate it with `padding = 21`: ```{python} -#| label: model-fit-padding +# | label: model-fit-padding days_to_impute = 21 dat_w_padding = dat["daily_hosp_admits"].to_numpy() @@ -265,15 +271,15 @@ hosp_model.run( observed_admissions=dat_w_padding, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), - padding=days_to_impute, # Padding the model + padding=days_to_impute, # Padding the model ) ``` And plotting the results: ```{python} -#| label: fig-output-admissions-with-padding -#| fig-cap: Hospital Admissions posterior distribution +# | label: fig-output-admissions-with-padding +# | fig-cap: Hospital Admissions posterior distribution out = hosp_model.plot_posterior( var="predicted_admissions", ylab="Hospital Admissions", @@ -284,12 +290,9 @@ out = hosp_model.plot_posterior( We can also take a look at the latent infections: ```{python} -#| label: fig-output-infections-with-padding -#| fig-cap: Hospital Admissions posterior distribution -out2 = hosp_model.plot_posterior( - var="latent_infections", - ylab="Latent Infections" -) +# | label: fig-output-infections-with-padding +# | fig-cap: Hospital Admissions posterior distribution +out2 = hosp_model.plot_posterior(var="latent_infections", ylab="Latent Infections") ``` ## Round 2: Incorporating weekday effects @@ -297,20 +300,22 @@ out2 = hosp_model.plot_posterior( We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a weekday effect distribution. To do this, we will create a new instance of `RandomVariable` to model the weekday effect. The weekday effect will be a truncated normal distribution with a mean of 1.0 and a standard deviation of 0.5. The distribution will be truncated between 0.1 and 10.0. The weekday effect will be repeated for the number of weeks in the dataset. ```{python} -#| label: weekly-effect +# | label: weekly-effect from pyrenew import metaclass import numpyro as npro + class WeekdayEffect(metaclass.RandomVariable): """Weekday effect distribution""" + def __init__(self, len: int): - """ Initialize the weekday effect distribution + """Initialize the weekday effect distribution Parameters ---------- len : int The number of observations """ - self.nweeks = int(jnp.ceil(len/7)) + self.nweeks = int(jnp.ceil(len / 7)) self.len = len @staticmethod @@ -321,12 +326,13 @@ class WeekdayEffect(metaclass.RandomVariable): ans = npro.sample( name="weekday_effect", fn=npro.distributions.TruncatedNormal( - loc=1.0, scale=.5, low=0.1, high=10.0 - ), - sample_shape=(7,) + loc=1.0, scale=0.5, low=0.1, high=10.0 + ), + sample_shape=(7,), ) - return jnp.tile(ans, self.nweeks)[:self.len] + return jnp.tile(ans, self.nweeks)[: self.len] + # Initializing the weekday effect weekday_effect = WeekdayEffect(dat.shape[0]) @@ -335,12 +341,12 @@ weekday_effect = WeekdayEffect(dat.shape[0]) Notice that the instance's `nweeks` and `len` members are passed during construction. Trying to compute the number of weeks and the length of the dataset in the `validate` method will raise a `jit` error in `jax` as the shape and size of elements are not known during the validation step, which happens before the model is run. With the new weekday effect, we can rebuild the latent hospitalization model: ```{python} -#| label: latent-hosp-weekday +# | label: latent-hosp-weekday latent_hosp_wday_effect = latent.HospitalAdmissions( infection_to_admission_interval=inf_hosp_int, infect_hosp_rate_dist=hosp_rate, weekday_effect_dist=weekday_effect, - ) +) hosp_model_weekday = model.HospitalAdmissionsModel( latent_infections=latent_inf, @@ -355,7 +361,7 @@ hosp_model_weekday = model.HospitalAdmissionsModel( Running the model (with the same padding as before): ```{python} -#| label: model-2-run +# | label: model-2-run hosp_model_weekday.run( num_samples=2000, num_warmup=2000, @@ -369,8 +375,8 @@ hosp_model_weekday.run( And plotting the results: ```{python} -#| label: fig-output-admissions-padding-and-weekday -#| fig-cap: Hospital Admissions posterior distribution +# | label: fig-output-admissions-padding-and-weekday +# | fig-cap: Hospital Admissions posterior distribution out = hosp_model_weekday.plot_posterior( var="predicted_admissions", ylab="Hospital Admissions", diff --git a/model/docs/extending_pyrenew.qmd b/model/docs/extending_pyrenew.qmd index 9f77be81..509c44ea 100644 --- a/model/docs/extending_pyrenew.qmd +++ b/model/docs/extending_pyrenew.qmd @@ -19,7 +19,7 @@ Where $\mathcal{R}^u(t)$ is the unadjusted reproduction number, $g(t)$ is the ge Before we start, let's simulate the model with the original `InfectionsWithFeedback` class. To keep it simple, we will simulate the model with no observation process, in other words, only with latent infections. The following code-chunk loads the required libraries and defines the model components: ```{python} -#| label: setup +# | label: setup import jax import jax.numpy as jnp import numpy as np @@ -30,20 +30,29 @@ from pyrenew.latent import InfectionsWithFeedback from pyrenew.model import RtInfectionsRenewalModel from pyrenew.process import RtRandomWalkProcess from pyrenew.metaclass import DistributionalRV +from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponential ``` The following code-chunk defines the model components. Notice that for both the generation interval and the infection feedback, we use a deterministic PMF with equal probabilities: ```{python} -#| label: model-components -gen_int = DeterministicPMF(jnp.array([0.25, 0.5, 0.15, 0.1]), name="gen_int") +# | label: model-components +gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1]) +gen_int = DeterministicPMF(gen_int_array, name="gen_int") feedback_strength = DeterministicVariable(0.05, name="feedback_strength") -I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") +I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsExponential( + gen_int_array.size, + DeterministicVariable(0.5, name="rate"), + ), +) latent_infections = InfectionsWithFeedback( - infection_feedback_strength = feedback_strength, - infection_feedback_pmf = gen_int, + infection_feedback_strength=feedback_strength, + infection_feedback_pmf=gen_int, ) rt = RtRandomWalkProcess() @@ -52,7 +61,7 @@ rt = RtRandomWalkProcess() With all the components defined, we can build the model: ```{python} -#| label: build1 +# | label: build1 model0 = RtInfectionsRenewalModel( gen_int=gen_int, I0=I0, @@ -62,10 +71,10 @@ model0 = RtInfectionsRenewalModel( ) ``` -And simulate it from: +And simulate from it: ```{python} -#| label: simulate1 +# | label: simulate1 # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): @@ -73,9 +82,10 @@ with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): ``` ```{python} -#| label: fig-simulate1 -#| fig-cap: Simulated infections with no observation process +# | label: fig-simulate1 +# | fig-cap: Simulated infections with no observation process import matplotlib.pyplot as plt + fig, ax = plt.subplots() ax.plot(model0_samp.latent_infections) ax.set_xlabel("Time") @@ -89,9 +99,10 @@ plt.show() All instances of PyRenew's `RandomVariable` should have at least three functions: `__init__()`, `validate()`, and `sample()`. The `__init__()` function is the constructor and initializes the class. The `validate()` function checks if the class is correctly initialized. Finally, the `sample()` method contains the core of the class; it should return a tuple or named tuple. The following is a minimal example of a `RandomVariable` class based on `numpyro.distributions.Normal`: -```python +```{python} from pyrenew.metaclass import RandomVariable + class MyNormal(RandomVariable): def __init__(self, loc, scale): self.validate(scale) @@ -116,22 +127,22 @@ The `@staticmethod` decorator exposes the `validate` function to be used outside Although returning namedtuples is not strictly required, they are the recommended return type, as they make the code more readable. The following code-chunk shows how to create a named tuple for the `InfectionsWithFeedback` class: ```{python} -#| label: data-class +# | label: data-class from collections import namedtuple # Creating a tuple to store the output InfFeedbackSample = namedtuple( - typename='InfFeedbackSample', - field_names=['infections', 'rt'], - defaults=(None, None) + typename="InfFeedbackSample", + field_names=["infections", "rt"], + defaults=(None, None), ) ``` The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: ```{python} -#| label: new-model-def -#| code-line-numbers: true +# | label: new-model-def +# | code-line-numbers: true # Creating the class from pyrenew.metaclass import RandomVariable from pyrenew.latent import compute_infections_from_rt_with_feedback @@ -139,6 +150,7 @@ from pyrenew import arrayutils as au from jax.typing import ArrayLike import jax.numpy as jnp + class InfFeedback(RandomVariable): """Latent infections""" @@ -206,7 +218,7 @@ class InfFeedback(RandomVariable): return InfFeedbackSample( infections=all_infections, rt=Rt_adj, - ) + ) ``` The core of the class is implemented in the `sample()` method. Things to highlight from the above code: @@ -220,10 +232,10 @@ The core of the class is implemented in the `sample()` method. Things to highlig 4. **Return type of `InfFeedback.sample()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`. ```{python} -#| label: simulation2 +# | label: simulation2 latent_infections2 = InfFeedback( - infection_feedback_strength = feedback_strength, - infection_feedback_pmf = gen_int, + infection_feedback_strength=feedback_strength, + infection_feedback_pmf=gen_int, ) model1 = RtInfectionsRenewalModel( @@ -243,9 +255,10 @@ with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): Comparing `model0` with `model1`, these two should match: ```{python} -#| label: fig-model0-vs-model1 -#| fig-cap: Comparing latent infections from model 0 and model 1 +# | label: fig-model0-vs-model1 +# | fig-cap: Comparing latent infections from model 0 and model 1 import matplotlib.pyplot as plt + fig, ax = plt.subplots(ncols=2) ax[0].plot(model0_samp.latent_infections) ax[1].plot(model1_samp.latent_infections) diff --git a/model/docs/getting-started.qmd b/model/docs/getting-started.qmd index 747d707a..0b02c98d 100644 --- a/model/docs/getting-started.qmd +++ b/model/docs/getting-started.qmd @@ -42,15 +42,19 @@ This section will show the steps to build a simple renewal model featuring a lat We start by loading the needed components to build a basic renewal model: ```{python} -#| label: loading-pkgs -#| output: false -#| warning: false +# | label: loading-pkgs +# | output: false +# | warning: false import jax.numpy as jnp import numpy as np import numpyro as npro import numpyro.distributions as dist from pyrenew.process import RtRandomWalkProcess -from pyrenew.latent import Infections +from pyrenew.latent import ( + Infections, + InfectionSeedingProcess, + SeedInfectionsZeroPad, +) from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF from pyrenew.model import RtInfectionsRenewalModel @@ -73,7 +77,7 @@ To initialize these five components within the renewal modeling framework, we es (1) In this example, the generation interval is not estimated but passed as a deterministic instance of `RandomVariable` -(2) an instance of the `DistributionalRV` class, with a log-normal distribution with mean = 0 and standard deviation = 1 as input. +(2) an instance of the `InfectionSeedingProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `SeedInfectionsZeroPad`, the latent infections before this time are assumed to be 0. (3) an instance of the `RtRandomWalkProcess` class with default values @@ -82,12 +86,17 @@ To initialize these five components within the renewal modeling framework, we es (5) an instance of the `PoissonObservation` class with default values ```{python} -#| label: creating-elements +# | label: creating-elements # (1) The generation interval (deterministic) -gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int") +pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) +gen_int = DeterministicPMF(pmf_array, name="gen_int") # (2) Initial infections (inferred with a prior) -I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") +I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(pmf_array.size), +) # (3) The random process for Rt rt_proc = RtRandomWalkProcess() @@ -102,14 +111,14 @@ observation_process = PoissonObservation() With these five pieces, we can build the basic renewal model as an instance of the `RtInfectionsRenewalModel` class: ```{python} -#| label: model-creation +# | label: model-creation model1 = RtInfectionsRenewalModel( - gen_int = gen_int, - I0 = I0, - Rt_process = rt_proc, - latent_infections = latent_infections, - observation_process = observation_process, - ) + gen_int=gen_int, + I0=I0, + Rt_process=rt_proc, + latent_infections=latent_infections, + observation_process=observation_process, +) ``` The following diagram summarizes how the modules interact via composition; notably, `gen_int`, `I0`, `rt_proc`, `latent_infections`, and `observed_infections` are instances of `RandomVariable`, which means these can be easily replaced to generate a different instance of the `RtInfectionsRenewalModel` class: @@ -119,7 +128,7 @@ The following diagram summarizes how the modules interact via composition; notab %%| include: true flowchart TB genint["(1) gen_int\n(DetermnisticPMF)"] - i0["(2) I0\n(DistributionalRV)"] + i0["(2) I0\n(InfectionSeedingProcess)"] rt["(3) rt_proc\n(RtRandomWalkProcess)"] inf["(4) latent_infections\n(Infections)"] obs["(5) observation_process\n(PoissonObservation)"] @@ -136,9 +145,9 @@ flowchart TB Using `numpyro`, we can simulate data using the `sample()` member function of `RtInfectionsRenewalModel`. The `sample()` method of the `RtInfectionsRenewalModel` class returns a list composed of the `Rt` and `infections` sequences, called `sim_data`: ```{python} -#| label: simulate +# | label: simulate np.random.seed(223) -with npro.handlers.seed(rng_seed = np.random.randint(1, 60)): +with npro.handlers.seed(rng_seed=np.random.randint(1, 60)): sim_data = model1.sample(n_timepoints_to_simulate=30) sim_data @@ -147,22 +156,22 @@ sim_data To understand what has been accomplished here, visualize an $R_t$ sample path (left panel) and infections over time (right panel): ```{python} -#| label: fig-basic -#| fig-cap: Rt and Infections +# | label: fig-basic +# | fig-cap: Rt and Infections import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(range(0, len(sim_data.Rt)), sim_data.Rt) -axs[0].set_ylabel('Rt') +axs[0].plot(sim_data.Rt) +axs[0].set_ylabel("Rt") # Infections plot -axs[1].plot(range(0, len(sim_data.Rt)), sim_data.sampled_infections) -axs[1].set_ylabel('Infections') +axs[1].plot(sim_data.sampled_infections) +axs[1].set_ylabel("Infections") -fig.suptitle('Basic renewal model') -fig.supxlabel('Time') +fig.suptitle("Basic renewal model") +fig.supxlabel("Time") plt.tight_layout() plt.show() ``` @@ -170,7 +179,7 @@ plt.show() To fit the model, we can use the `run()` method of the `RtInfectionsRenewalModel` class (an inherited method from the metaclass `Model`). `model1.run()` will call the `run` method of the `model1` object, which will generate an instance of model MCMC simulation, with 2000 warm-up iterations for the MCMC algorithm, used to tune the parameters of the MCMC algorithm to improve efficiency of the sampling process. From the posterior distribution of the model parameters, 1000 samples will be drawn and used to estimate the posterior distributions and compute summary statistics. Observed data is provided to the model using the `sim_data` object previously generated. `mcmc_args` provides additional arguments for the MCMC algorithm. ```{python} -#| label: model-fit +# | label: model-fit import jax model1.run( @@ -179,7 +188,7 @@ model1.run( observed_infections=sim_data.sampled_infections, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), - ) +) ``` Now, let's investigate the output, particularly the posterior distribution of the $R_t$ estimates: diff --git a/model/docs/pyrenew_demo.qmd b/model/docs/pyrenew_demo.qmd index 205dbeb5..ea8ff864 100644 --- a/model/docs/pyrenew_demo.qmd +++ b/model/docs/pyrenew_demo.qmd @@ -21,9 +21,9 @@ pip install git+https://github.com/CDCgov/multisignal-epi-inference@main#subdire To begin, run the following import section to call external modules and functions necessary to run the `pyrenew` demo. The `import` statement imports the module and the `as` statement renames the module for use within this script. The `from` statement imports a specific function from a module (named after the `.`) within a package (named before the `.`). ```{python} -#| output: false -#| label: loading-pkgs -#| warning: false +# | output: false +# | label: loading-pkgs +# | warning: false import matplotlib as mpl import matplotlib.pyplot as plt import jax @@ -40,11 +40,11 @@ from pyrenew.process import SimpleRandomWalkProcess To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the duration of the block that follows. Inside the `with` block, the `q_samp = q.sample(duration=100)` generates the sample instance over a duration of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. ```{python} -#| label: fig-randwalk -#| fig-cap: Random walk example +# | label: fig-randwalk +# | fig-cap: Random walk example np.random.seed(3312) q = SimpleRandomWalkProcess(dist.Normal(0, 0.001)) -with seed(rng_seed=np.random.randint(0,1000)): +with seed(rng_seed=np.random.randint(0, 1000)): q_samp = q.sample(duration=100) plt.plot(np.exp(q_samp[0])) @@ -54,7 +54,8 @@ Next, import several additional functions from the `latent` module of the `pyren ```{python} from pyrenew.latent import ( - Infections, HospitalAdmissions, + Infections, + HospitalAdmissions, ) from pyrenew.metaclass import DistributionalRV ``` @@ -66,13 +67,14 @@ from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.model import HospitalAdmissionsModel from pyrenew.process import RtRandomWalkProcess +from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsZeroPad ``` To initialize the model, we first define initial conditions, including: 1) deterministic generation time, defined as an instance of the `DeterministicPMF` class, which gives the probability of each possible outcome for a discrete random variable given as a JAX NumPy array of four possible outcomes -2) initial infections at the start of simulation as a log-normal distribution with mean = 0 and standard deviation = 1 +2) initial infections at the start of the renewal process as a log-normal distribution with mean = 0 and standard deviation = 1. Infections before this time are assumed to be 0. 3) latent infections as an instance of the `Infections` class with default settings @@ -86,10 +88,15 @@ To initialize the model, we first define initial conditions, including: # Initializing model components: # 1) A deterministic generation time -gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int") +pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) +gen_int = DeterministicPMF(pmf_array, name="gen_int") # 2) Initial infections -I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") +I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(pmf_array.size), +) # 3) The latent infections process latent_infections = Infections() @@ -98,17 +105,18 @@ latent_infections = Infections() # First, define a deterministic infection to hosp pmf inf_hosp_int = DeterministicPMF( - jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]), - name="inf_hosp_int" - ) + jnp.array( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.25, 0.5, 0.1, 0.1, 0.05] + ), + name="inf_hosp_int", +) latent_admissions = HospitalAdmissions( infection_to_admission_interval=inf_hosp_int, - infect_hosp_rate_dist = DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), - name="IHR" - ), - ) + infect_hosp_rate_dist=DistributionalRV( + dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" + ), +) # 5) An observation process for the hospital admissions admissions_process = PoissonObservation() @@ -128,8 +136,8 @@ hospmodel = HospitalAdmissionsModel( latent_admissions=latent_admissions, observation_process=admissions_process, latent_infections=latent_infections, - Rt_process=Rt_process - ) + Rt_process=Rt_process, +) ``` Next, we sample from the `hospmodel` for 30 time steps and view the output of a single run: @@ -164,7 +172,7 @@ hospmodel.run( observed_admissions=x.sampled_admissions, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), - ) +) ``` Print a summary of the model: @@ -177,25 +185,31 @@ Next, we will use the `spread_draws` function from the `pyrenew.mcmcutils` modul ```{python} from pyrenew.mcmcutils import spread_draws + samps = spread_draws(hospmodel.mcmc.get_samples(), [("Rt", "time")]) ``` We visualize these samples below, with individual possible Rt estimates over time shown in light blue, and the overall mean estimate Rt shown in dark blue. ```{python} -#| label: fig-sampled-rt -#| fig-cap: Posterior Rt +# | label: fig-sampled-rt +# | fig-cap: Posterior Rt import numpy as np import polars as pl + fig, ax = plt.subplots(figsize=[4, 5]) ax.plot(x[0]) samp_ids = np.random.randint(size=25, low=0, high=999) for samp_id in samp_ids: - sub_samps = samps.filter(pl.col("draw") == samp_id).sort(pl.col('time')) - ax.plot(sub_samps.select("time").to_numpy(), - sub_samps.select("Rt").to_numpy(), color="darkblue", alpha=0.1) -ax.set_ylim([0.4, 1/.4]) + sub_samps = samps.filter(pl.col("draw") == samp_id).sort(pl.col("time")) + ax.plot( + sub_samps.select("time").to_numpy(), + sub_samps.select("Rt").to_numpy(), + color="darkblue", + alpha=0.1, + ) +ax.set_ylim([0.4, 1 / 0.4]) ax.set_yticks([0.5, 1, 2]) ax.set_yscale("log") ``` diff --git a/model/src/pyrenew/arrayutils.py b/model/src/pyrenew/arrayutils.py index 552183af..595b84b5 100644 --- a/model/src/pyrenew/arrayutils.py +++ b/model/src/pyrenew/arrayutils.py @@ -10,10 +10,11 @@ def pad_to_match( x: ArrayLike, y: ArrayLike, fill_value: float = 0.0, + pad_direction: str = "end", fix_y: bool = False, ) -> tuple[ArrayLike, ArrayLike]: """ - Pad the shorter array at the end to match the length of the longer array. + Pad the shorter array at the start or end to match the length of the longer array. Parameters ---------- @@ -23,6 +24,8 @@ def pad_to_match( Second array. fill_value : float, optional Value to use for padding, by default 0.0. + pad_direction : str, optional + Direction to pad the shorter array, either "start" or "end", by default "end". fix_y : bool, optional If True, raise an error when `y` is shorter than `x`, by default False. @@ -31,23 +34,32 @@ def pad_to_match( tuple[ArrayLike, ArrayLike] Tuple of the two arrays with the same length. """ - x = jnp.atleast_1d(x) y = jnp.atleast_1d(y) - x_len = x.size y_len = y.size + pad_size = abs(x_len - y_len) + + pad_width = {"start": (pad_size, 0), "end": (0, pad_size)}.get( + pad_direction, None + ) + + if pad_width is None: + raise ValueError( + "pad_direction must be either 'start' or 'end'." + f" Got {pad_direction}." + ) + if x_len > y_len: if fix_y: raise ValueError( "Cannot fix y when x is longer than y." - + f" x_len: {x_len}, y_len: {y_len}." + f" x_len: {x_len}, y_len: {y_len}." ) - - y = jnp.pad(y, (0, x_len - y_len), constant_values=fill_value) + y = jnp.pad(y, pad_width, constant_values=fill_value) elif y_len > x_len: - x = jnp.pad(x, (0, y_len - x_len), constant_values=fill_value) + x = jnp.pad(x, pad_width, constant_values=fill_value) return x, y @@ -56,9 +68,10 @@ def pad_x_to_match_y( x: ArrayLike, y: ArrayLike, fill_value: float = 0.0, + pad_direction: str = "end", ) -> ArrayLike: """ - Pad the `x` array at the end to match the length of the `y` array. + Pad the `x` array at the start or end to match the length of the `y` array. Parameters ---------- @@ -66,10 +79,16 @@ def pad_x_to_match_y( First array. y : ArrayLike Second array. + fill_value : float, optional + Value to use for padding, by default 0.0. + pad_direction : str, optional + Direction to pad the shorter array, either "start" or "end", by default "end". Returns ------- Array Padded array. """ - return pad_to_match(x, y, fill_value=fill_value, fix_y=True)[0] + return pad_to_match( + x, y, fill_value=fill_value, pad_direction=pad_direction, fix_y=True + )[0] diff --git a/model/src/pyrenew/deterministic/deterministicpmf.py b/model/src/pyrenew/deterministic/deterministicpmf.py index 0d62208b..f0f2aa0c 100644 --- a/model/src/pyrenew/deterministic/deterministicpmf.py +++ b/model/src/pyrenew/deterministic/deterministicpmf.py @@ -86,3 +86,15 @@ def sample( """ return self.basevar.sample(**kwargs) + + def size(self) -> int: + """ + Returns the size of the PMF + + Returns + ------- + int + The size of the PMF + """ + + return self.basevar.vars.size diff --git a/model/src/pyrenew/latent/__init__.py b/model/src/pyrenew/latent/__init__.py index 61adc9c0..5f588c97 100644 --- a/model/src/pyrenew/latent/__init__.py +++ b/model/src/pyrenew/latent/__init__.py @@ -8,6 +8,13 @@ compute_infections_from_rt_with_feedback, logistic_susceptibility_adjustment, ) +from pyrenew.latent.infection_seeding_method import ( + InfectionSeedMethod, + SeedInfectionsExponential, + SeedInfectionsFromVec, + SeedInfectionsZeroPad, +) +from pyrenew.latent.infection_seeding_process import InfectionSeedingProcess from pyrenew.latent.infections import Infections from pyrenew.latent.infectionswithfeedback import InfectionsWithFeedback @@ -17,5 +24,10 @@ "logistic_susceptibility_adjustment", "compute_infections_from_rt", "compute_infections_from_rt_with_feedback", + "InfectionSeedMethod", + "SeedInfectionsExponential", + "SeedInfectionsFromVec", + "SeedInfectionsZeroPad", + "InfectionSeedingProcess", "InfectionsWithFeedback", ] diff --git a/model/src/pyrenew/latent/infection_seeding_method.py b/model/src/pyrenew/latent/infection_seeding_method.py new file mode 100644 index 00000000..2bf3db01 --- /dev/null +++ b/model/src/pyrenew/latent/infection_seeding_method.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +# numpydoc ignore=GL08 +from abc import ABCMeta, abstractmethod + +import jax.numpy as jnp +from jax.typing import ArrayLike +from pyrenew.metaclass import RandomVariable + + +class InfectionSeedMethod(metaclass=ABCMeta): + """Method for seeding initial infections in a renewal process.""" + + def __init__(self, n_timepoints: int): + """Default constructor for the ``InfectionSeedMethod`` class. + + Parameters + ---------- + n_timepoints : int + the number of time points to generate seed infections for + + Returns + ------- + None + """ + self.validate(n_timepoints) + self.n_timepoints = n_timepoints + + @staticmethod + def validate(n_timepoints: int) -> None: + """Validate inputs for the ``InfectionSeedMethod`` class constructor + + Parameters + ---------- + n_timepoints : int + the number of time points to generate seed infections for + + Returns + ------- + None + """ + if not isinstance(n_timepoints, int): + raise TypeError( + f"n_timepoints must be an integer. Got {type(n_timepoints)}" + ) + if n_timepoints <= 0: + raise ValueError( + f"n_timepoints must be positive. Got {n_timepoints}" + ) + + @abstractmethod + def seed_infections(self, I_pre_seed: ArrayLike): + """Generate the number of seeded infections at each time point. + + Parameters + ---------- + I_pre_seed : ArrayLike + An array representing some number of latent infections to be used with the specified ``InfectionSeedMethod``. + + Returns + ------- + ArrayLike + An array of length ``n_timepoints`` with the number of seeded infections at each time point. + """ + + def __call__(self, I_pre_seed: ArrayLike): + return self.seed_infections(I_pre_seed) + + +class SeedInfectionsZeroPad(InfectionSeedMethod): + """ + Create a seed infection vector of specified length by + padding a shorter vector with an appropriate number of + zeros at the beginning of the time series. + """ + + def seed_infections(self, I_pre_seed: ArrayLike): + """Pad the seed infections with zeros at the beginning of the time series. + + Parameters + ---------- + I_pre_seed : ArrayLike + An array with seeded infections to be padded with zeros. + + Returns + ------- + ArrayLike + An array of length ``n_timepoints`` with the number of seeded infections at each time point. + """ + if self.n_timepoints < I_pre_seed.size: + raise ValueError( + "I_pre_seed must be no longer than n_timepoints. " + f"Got I_pre_seed of size {I_pre_seed.size} and " + f" n_timepoints of size {self.n_timepoints}." + ) + return jnp.pad(I_pre_seed, (self.n_timepoints - I_pre_seed.size, 0)) + + +class SeedInfectionsFromVec(InfectionSeedMethod): + """Create seed infections from a vector of infections.""" + + def seed_infections(self, I_pre_seed: ArrayLike): + """Create seed infections from a vector of infections. + + Parameters + ---------- + I_pre_seed : ArrayLike + An array with the same length as ``n_timepoints`` to be used as the seed infections. + + Returns + ------- + ArrayLike + An array of length ``n_timepoints`` with the number of seeded infections at each time point. + """ + if I_pre_seed.size != self.n_timepoints: + raise ValueError( + "I_pre_seed must have the same size as n_timepoints. " + f"Got I_pre_seed of size {I_pre_seed.size} " + f"and n_timepoints of size {self.n_timepoints}." + ) + return jnp.array(I_pre_seed) + + +class SeedInfectionsExponential(InfectionSeedMethod): + r"""Generate seed infections according to exponential growth. + + Notes + ----- + The number of incident infections at time `t` is given by: + + .. math:: I(t) = I_p \exp \left( r (t - t_p) \right) + + Where :math:`I_p` is ``I_pre_seed``, :math:`r` is ``rate``, and :math:`t_p` is ``t_pre_seed``. + This ensures that :math:`I(t_p) = I_p`. + We default to ``t_pre_seed = n_timepoints - 1``, so that + ``I_pre_seed`` represents the number of incident infections immediately + before the renewal process begins. + """ + + def __init__( + self, + n_timepoints: int, + rate: RandomVariable, + t_pre_seed: int | None = None, + ): + """Default constructor for the ``SeedInfectionsExponential`` class. + + Parameters + ---------- + n_timepoints : int + the number of time points to generate seed infections for + rate : RandomVariable + A random variable representing the rate of exponential growth + t_pre_seed : int | None, optional + The time point whose number of infections is described by ``I_pre_seed``. Defaults to ``n_timepoints - 1``. + """ + super().__init__(n_timepoints) + self.rate = rate + if t_pre_seed is None: + t_pre_seed = n_timepoints - 1 + self.t_pre_seed = t_pre_seed + + def seed_infections(self, I_pre_seed: ArrayLike): + """Generate seed infections according to exponential growth. + + Parameters + ---------- + I_pre_seed : ArrayLike + An array of size 1 representing the number of infections at time ``t_pre_seed``. + + Returns + ------- + ArrayLike + An array of length ``n_timepoints`` with the number of seeded infections at each time point. + """ + if I_pre_seed.size != 1: + raise ValueError( + f"I_pre_seed must be an array of size 1. Got size {I_pre_seed.size}." + ) + (rate,) = self.rate.sample() + if rate.size != 1: + raise ValueError( + f"rate must be an array of size 1. Got size {rate.size}." + ) + return I_pre_seed * jnp.exp( + rate * (jnp.arange(self.n_timepoints) - self.t_pre_seed) + ) diff --git a/model/src/pyrenew/latent/infection_seeding_process.py b/model/src/pyrenew/latent/infection_seeding_process.py new file mode 100644 index 00000000..eac040c5 --- /dev/null +++ b/model/src/pyrenew/latent/infection_seeding_process.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# numpydoc ignore=GL08 +import numpyro as npro +from pyrenew.latent.infection_seeding_method import InfectionSeedMethod +from pyrenew.metaclass import RandomVariable + + +class InfectionSeedingProcess(RandomVariable): + """Generate an initial infection history""" + + def __init__( + self, + name, + I_pre_seed_rv: RandomVariable, + infection_seed_method: InfectionSeedMethod, + ) -> None: + """Default class constructor for InfectionSeedingProcess + + Parameters + ---------- + name : str + A name to assign to the RandomVariable. + I_pre_seed_rv : RandomVariable + A RandomVariable representing the number of infections that occur at some time before the renewal process begins. Each `infection_seed_method` uses this random variable in different ways. + infection_seed_method : InfectionSeedMethod + An `InfectionSeedMethod` that generates the seed infections for the renewal process. + + Returns + ------- + None + """ + InfectionSeedingProcess.validate(I_pre_seed_rv, infection_seed_method) + + self.I_pre_seed_rv = I_pre_seed_rv + self.infection_seed_method = infection_seed_method + self.name = name + + @staticmethod + def validate( + I_pre_seed_rv: RandomVariable, + infection_seed_method: InfectionSeedMethod, + ) -> None: + """Validate the input arguments to the InfectionSeedingProcess class constructor + + Parameters + ---------- + I_pre_seed_rv : RandomVariable + A random variable representing the number of infections that occur at some time before the renewal process begins. + infection_seed_method : InfectionSeedMethod + An method to generate the seed infections. + + Returns + ------- + None + """ + if not isinstance(I_pre_seed_rv, RandomVariable): + raise TypeError( + "I_pre_seed_rv must be an instance of RandomVariable" + f"Got {type(I_pre_seed_rv)}" + ) + if not isinstance(infection_seed_method, InfectionSeedMethod): + raise TypeError( + "infection_seed_method must be an instance of InfectionSeedMethod" + f"Got {type(infection_seed_method)}" + ) + + def sample(self) -> tuple: + """Sample the infection seeding process. + + Returns + ------- + tuple + a tuple where the only element is an array with the number of seeded infections at each time point. + """ + (I_pre_seed,) = self.I_pre_seed_rv.sample() + infection_seeding = self.infection_seed_method(I_pre_seed) + npro.deterministic(self.name, infection_seeding) + + return (infection_seeding,) diff --git a/model/src/pyrenew/latent/infections.py b/model/src/pyrenew/latent/infections.py index ac58e3c3..26e6fdfd 100644 --- a/model/src/pyrenew/latent/infections.py +++ b/model/src/pyrenew/latent/infections.py @@ -103,25 +103,25 @@ def sample( InfectionsSample Named tuple with "infections". """ - - gen_int_rev = jnp.flip(gen_int) - - if I0.size < gen_int_rev.size: + if I0.size < gen_int.size: raise ValueError( "Initial infections vector must be at least as long as " "the generation interval. " f"Initial infections vector length: {I0.size}, " - f"generation interval length: {gen_int_rev.size}." + f"generation interval length: {gen_int.size}." ) - else: - I0_vec = I0[-gen_int_rev.size :] + + gen_int_rev = jnp.flip(gen_int) + recent_I0 = I0[-gen_int_rev.size :] all_infections = inf.compute_infections_from_rt( - I0=I0_vec, + I0=recent_I0, Rt=Rt, reversed_generation_interval_pmf=gen_int_rev, ) + all_infections = jnp.hstack([I0, all_infections]) + npro.deterministic(self.infections_mean_varname, all_infections) return InfectionsSample(all_infections) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index d1947a3a..784c9001 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -37,6 +37,16 @@ class InfectionsWithFeedback(RandomVariable): This class computes infections, given Rt, initial infections, and generation interval. + Parameters + ---------- + infection_feedback_strength : RandomVariable + Infection feedback strength. + infection_feedback_pmf : RandomVariable + Infection feedback pmf. + infections_mean_varname : str, optional + Name to be assigned to the deterministic variable in the model. + Defaults to "latent_infections". + Notes ----- This function implements the following renewal process (reproduced from @@ -129,7 +139,7 @@ def sample( Reproduction number. I0 : ArrayLike Initial infections, as an array - at least as long as the + at least as long as the generation interval PMF. gen_int : ArrayLike Generation interval PMF. @@ -185,6 +195,9 @@ def sample( reversed_infection_feedback_pmf=inf_fb_pmf_rev, ) + # Appending initial infections to the infections + all_infections = jnp.hstack([I0, all_infections]) + npro.deterministic("Rt_adjusted", Rt_adj) return InfectionsRtFeedbackSample( diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index c3d9380f..872c625f 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -6,6 +6,7 @@ from typing import NamedTuple import jax.numpy as jnp +import pyrenew.arrayutils as au from jax.typing import ArrayLike from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel @@ -68,7 +69,7 @@ def __init__( The infections latent process (passed to RtInfectionsRenewalModel). gen_int : RandomVariable Generation time (passed to RtInfectionsRenewalModel) - IO : RandomVariable + I0 : RandomVariable Initial infections (passed to RtInfectionsRenewalModel) Rt_process : RandomVariable Rt process (passed to RtInfectionsRenewalModel). @@ -250,28 +251,35 @@ def sample( infections=basic_model.latent_infections, **kwargs, ) - - # Sampling the hospital admissions - if self.observation_process is not None: - if (observed_admissions is not None) and (padding > 0): - sampled_na = jnp.repeat(jnp.nan, padding) - - sampled_observed, *_ = self.sample_admissions_process( - predicted=latent[padding:], - observed_admissions=observed_admissions[padding:], + i0_size = len(latent) - n_timepoints + if self.observation_process is None: + sampled = None + else: + if observed_admissions is None: + sampled_obs, *_ = self.sample_admissions_process( + predicted=latent, + observed_admissions=observed_admissions, **kwargs, ) - - sampled = jnp.hstack([sampled_na, sampled_observed]) - else: - sampled, *_ = self.sample_admissions_process( - predicted=latent, - observed_admissions=observed_admissions, + observed_admissions = au.pad_x_to_match_y( + observed_admissions, latent, jnp.nan, pad_direction="start" + ) + + sampled_obs, *_ = self.sample_admissions_process( + predicted=latent[i0_size + padding :], + observed_admissions=observed_admissions[ + i0_size + padding : + ], **kwargs, ) - else: - sampled = None + # this is to accommodate the current version of test_model_hosp_no_obs_model. Not sure if we want this behavior + if sampled_obs is None: + sampled = None + else: + sampled = au.pad_x_to_match_y( + sampled_obs, latent, jnp.nan, pad_direction="start" + ) return HospModelSample( Rt=basic_model.Rt, diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 668e6ab8..897e1d22 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -308,15 +308,7 @@ def sample( # Sampling initial infections i0, *_ = self.sample_i0(**kwargs) - - # Padding i0 to match gen_int - # PADDING SHOULD BE REMOVED ONCE - # https://github.com/CDCgov/multisignal-epi-inference/pull/124 - # is merged. - # SEE ALSO: - # https://github.com/CDCgov/multisignal-epi-inference/pull/123#discussion_r1612337288 - i0 = au.pad_x_to_match_y(x=i0, y=gen_int, fill_value=0.0) - + i0_size = i0.size # Sampling from the latent process latent, *_ = self.sample_infections_latent( Rt=Rt, @@ -325,25 +317,28 @@ def sample( **kwargs, ) - # Using the predicted infections to sample from the observation process - if (observed_infections is not None) and (padding > 0): - sampled_pad = jnp.repeat(jnp.nan, padding) - + if observed_infections is None: sampled_obs, *_ = self.sample_infections_obs( - predicted=latent[padding:], - observed_infections=observed_infections[padding:], + predicted=latent, + observed_infections=observed_infections, **kwargs, ) - - sampled = jnp.hstack([sampled_pad, sampled_obs]) - else: - sampled, *_ = self.sample_infections_obs( - predicted=latent, - observed_infections=observed_infections, + observed_infections = au.pad_x_to_match_y( + observed_infections, latent, jnp.nan, pad_direction="start" + ) + + sampled_obs, *_ = self.sample_infections_obs( + predicted=latent[i0_size + padding :], + observed_infections=observed_infections[i0_size + padding :], **kwargs, ) + sampled = au.pad_x_to_match_y( + sampled_obs, latent, jnp.nan, pad_direction="start" + ) + + Rt = au.pad_x_to_match_y(Rt, latent, jnp.nan, pad_direction="start") return RtInfectionsRenewalSample( Rt=Rt, latent_infections=latent, diff --git a/model/src/test/baseline/test_model_basicrenewal_plot.png b/model/src/test/baseline/test_model_basicrenewal_plot.png index fe12bff6..34396449 100644 Binary files a/model/src/test/baseline/test_model_basicrenewal_plot.png and b/model/src/test/baseline/test_model_basicrenewal_plot.png differ diff --git a/model/src/test/test_datautils.py b/model/src/test/test_arrayutils.py similarity index 72% rename from model/src/test/test_datautils.py rename to model/src/test/test_arrayutils.py index 8ee35b68..ab1995c2 100644 --- a/model/src/test/test_datautils.py +++ b/model/src/test/test_arrayutils.py @@ -35,6 +35,16 @@ def test_arrayutils_pad_to_match(): with pytest.raises(ValueError): x_pad, y_pad = au.pad_to_match(x, y, fix_y=True) + # Verify function works with both padding directions + x_pad, y_pad = au.pad_to_match(x, y, pad_direction="start") + + assert x_pad.size == y_pad.size + assert x_pad.size == 3 + + # Verify function raises an error when pad_direction is not "start" or "end" + with pytest.raises(ValueError): + x_pad, y_pad = au.pad_to_match(x, y, pad_direction="middle") + def test_arrayutils_pad_x_to_match_y(): """ diff --git a/model/src/test/test_infection_seeding_method.py b/model/src/test/test_infection_seeding_method.py new file mode 100644 index 00000000..0f232de1 --- /dev/null +++ b/model/src/test/test_infection_seeding_method.py @@ -0,0 +1,123 @@ +# numpydoc ignore=GL08 +import numpy as np +import numpy.testing as testing +import pytest +from pyrenew.deterministic import DeterministicVariable +from pyrenew.latent import ( + SeedInfectionsExponential, + SeedInfectionsFromVec, + SeedInfectionsZeroPad, +) + + +def test_seed_infections_exponential(): + """Check that the SeedInfectionsExponential class generates the correct number of infections at each time point.""" + n_timepoints = 10 + rate_RV = DeterministicVariable(0.5, name="rate_RV") + I_pre_seed_RV = DeterministicVariable(10.0, name="I_pre_seed_RV") + default_t_pre_seed = n_timepoints - 1 + + (I_pre_seed,) = I_pre_seed_RV.sample() + (rate,) = rate_RV.sample() + + infections_default_t_pre_seed = SeedInfectionsExponential( + n_timepoints, rate=rate_RV + ).seed_infections(I_pre_seed) + infections_default_t_pre_seed_manual = I_pre_seed * np.exp( + rate * (np.arange(n_timepoints) - default_t_pre_seed) + ) + + testing.assert_array_almost_equal( + infections_default_t_pre_seed, infections_default_t_pre_seed_manual + ) + + # assert that infections at default t_pre_seed is I_pre_seed + assert infections_default_t_pre_seed[default_t_pre_seed] == I_pre_seed + + # test for failure with non-scalar rate or I_pre_seed + rate_RV_2 = DeterministicVariable(np.array([0.5, 0.5]), name="rate_RV") + with pytest.raises(ValueError): + SeedInfectionsExponential( + n_timepoints, rate=rate_RV_2 + ).seed_infections(I_pre_seed) + + I_pre_seed_RV_2 = DeterministicVariable( + np.array([10.0, 10.0]), name="I_pre_seed_RV" + ) + (I_pre_seed_2,) = I_pre_seed_RV_2.sample() + + with pytest.raises(ValueError): + SeedInfectionsExponential(n_timepoints, rate=rate_RV).seed_infections( + I_pre_seed_2 + ) + + # test non-default t_pre_seed + t_pre_seed = 6 + infections_custom_t_pre_seed = SeedInfectionsExponential( + n_timepoints, rate=rate_RV, t_pre_seed=t_pre_seed + ).seed_infections(I_pre_seed) + infections_custom_t_pre_seed_manual = I_pre_seed * np.exp( + rate * (np.arange(n_timepoints) - t_pre_seed) + ) + testing.assert_array_almost_equal( + infections_custom_t_pre_seed, + infections_custom_t_pre_seed_manual, + decimal=5, + ) + + assert infections_custom_t_pre_seed[t_pre_seed] == I_pre_seed + + +def test_seed_infections_zero_pad(): + """Check that the SeedInfectionsZeroPad class generates the correct number of infections at each time point.""" + + n_timepoints = 10 + I_pre_seed_RV = DeterministicVariable(10.0, name="I_pre_seed_RV") + (I_pre_seed,) = I_pre_seed_RV.sample() + + infections = SeedInfectionsZeroPad(n_timepoints).seed_infections( + I_pre_seed + ) + testing.assert_array_equal( + infections, np.pad(I_pre_seed, (n_timepoints - I_pre_seed.size, 0)) + ) + + I_pre_seed_RV_2 = DeterministicVariable( + np.array([10.0, 10.0]), name="I_pre_seed_RV" + ) + (I_pre_seed_2,) = I_pre_seed_RV_2.sample() + + infections_2 = SeedInfectionsZeroPad(n_timepoints).seed_infections( + I_pre_seed_2 + ) + testing.assert_array_equal( + infections_2, + np.pad(I_pre_seed_2, (n_timepoints - I_pre_seed_2.size, 0)), + ) + + # Check that the SeedInfectionsZeroPad class raises an error when the length of I_pre_seed is greater than n_timepoints. + with pytest.raises(ValueError): + SeedInfectionsZeroPad(1).seed_infections(I_pre_seed_2) + + +def test_seed_infections_from_vec(): + """Check that the SeedInfectionsFromVec class generates the correct number of infections at each time point.""" + n_timepoints = 10 + I_pre_seed = np.arange(n_timepoints) + + infections = SeedInfectionsFromVec(n_timepoints).seed_infections( + I_pre_seed + ) + testing.assert_array_equal(infections, I_pre_seed) + + I_pre_seed_2 = np.arange(n_timepoints - 1) + with pytest.raises(ValueError): + SeedInfectionsFromVec(n_timepoints).seed_infections(I_pre_seed_2) + + n_timepoints_float = 10.0 + with pytest.raises(TypeError): + SeedInfectionsFromVec(n_timepoints_float).seed_infections(I_pre_seed) + + n_timepoints_neg = -10 + with pytest.raises(ValueError): + SeedInfectionsFromVec(n_timepoints_neg).seed_infections(I_pre_seed) diff --git a/model/src/test/test_infection_seeding_process.py b/model/src/test/test_infection_seeding_process.py new file mode 100644 index 00000000..f307dcbe --- /dev/null +++ b/model/src/test/test_infection_seeding_process.py @@ -0,0 +1,57 @@ +# numpydoc ignore=GL08 +import jax.numpy as jnp +import numpyro as npro +import numpyro.distributions as dist +import pytest +from pyrenew.deterministic import DeterministicVariable +from pyrenew.latent import ( + InfectionSeedingProcess, + SeedInfectionsExponential, + SeedInfectionsFromVec, + SeedInfectionsZeroPad, +) +from pyrenew.metaclass import DistributionalRV + + +def test_infection_seeding_process(): + """Check that the InfectionSeedingProcess class generates can be sampled from with all InfectionSeedMethods.""" + n_timepoints = 10 + + zero_pad_model = InfectionSeedingProcess( + "zero_pad_model", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints), + ) + + exp_model = InfectionSeedingProcess( + "exp_model", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsExponential( + n_timepoints, DeterministicVariable(0.5, name="rate") + ), + ) + + vec_model = InfectionSeedingProcess( + "vec_model", + DeterministicVariable(jnp.arange(n_timepoints), name="I0"), + SeedInfectionsFromVec(n_timepoints), + ) + + for model in [zero_pad_model, exp_model, vec_model]: + with npro.handlers.seed(rng_seed=1): + model.sample() + + # Check that the InfectionSeedingProcess class raises an error when the wrong type of I0 is passed + with pytest.raises(TypeError): + InfectionSeedingProcess( + "vec_model", + jnp.arange(n_timepoints), + SeedInfectionsFromVec(n_timepoints), + ) + + with pytest.raises(TypeError): + InfectionSeedingProcess( + "vec_model", + DeterministicVariable(jnp.arange(n_timepoints), name="I0"), + 3, + ) diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index 856f17f5..325cf530 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -54,7 +54,7 @@ def _infection_w_feedback_alt( I_vec[t : t + len_gen], np.flip(gen_int) ) - return {"infections": I_vec[-T:], "rt": Rt_adj} + return {"infections": I_vec, "rt": Rt_adj} def test_infectionsrtfeedback(): diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index 87207a61..c6aea454 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -10,7 +10,11 @@ import polars as pl import pytest from pyrenew.deterministic import DeterministicPMF, NullObservation -from pyrenew.latent import Infections +from pyrenew.latent import ( + Infections, + InfectionSeedingProcess, + SeedInfectionsZeroPad, +) from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation @@ -97,7 +101,11 @@ def test_model_basicrenewal_no_obs_model(): with pytest.raises(ValueError): I0 = DistributionalRV(dist=1, name="I0") - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + ) latent_infections = Infections() @@ -116,6 +124,9 @@ def test_model_basicrenewal_no_obs_model(): np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model0_samp = model0.sample(n_timepoints_to_simulate=30) + model0_samp.Rt + model0_samp.latent_infections + model0_samp.sampled_infections # Generating model0.observation_process = NullObservation() @@ -160,7 +171,11 @@ def test_model_basicrenewal_with_obs_model(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + ) latent_infections = Infections() @@ -225,7 +240,11 @@ def test_model_basicrenewal_plot() -> plt.Figure: jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + ) latent_infections = Infections() @@ -264,7 +283,11 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + ) latent_infections = Infections() diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 814ca0dd..173dd1ca 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -13,7 +13,12 @@ DeterministicVariable, NullObservation, ) -from pyrenew.latent import HospitalAdmissions, Infections +from pyrenew.latent import ( + HospitalAdmissions, + Infections, + InfectionSeedingProcess, + SeedInfectionsZeroPad, +) from pyrenew.metaclass import DistributionalRV, RandomVariable from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation @@ -177,7 +182,11 @@ def test_model_hosp_no_obs_model(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + ) latent_infections = Infections() Rt_process = RtRandomWalkProcess() @@ -275,7 +284,11 @@ def test_model_hosp_with_obs_model(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + ) latent_infections = Infections() Rt_process = RtRandomWalkProcess() @@ -356,7 +369,11 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + ) latent_infections = Infections() Rt_process = RtRandomWalkProcess() @@ -449,7 +466,11 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) n_obs_to_generate = 30 - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + ) latent_infections = Infections() Rt_process = RtRandomWalkProcess() @@ -483,18 +504,18 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # Other random components weekday = jnp.array([1, 1, 1, 1, 2, 2]) - weekday = jnp.tile(weekday, 10) weekday = weekday / weekday.sum() - weekday = weekday[:n_obs_to_generate] + weekday = jnp.tile(weekday, 10) + # weekday = weekday[:n_obs_to_generate] + weekday = weekday[:34] weekday = DeterministicVariable(weekday, name="weekday") hosp_report_prob_dist = jnp.array([0.9, 0.8, 0.7, 0.7, 0.6, 0.4]) hosp_report_prob_dist = jnp.tile(hosp_report_prob_dist, 10) + hosp_report_prob_dist = hosp_report_prob_dist[:34] hosp_report_prob_dist = hosp_report_prob_dist / hosp_report_prob_dist.sum() - hosp_report_prob_dist = hosp_report_prob_dist[:n_obs_to_generate] - hosp_report_prob_dist = DeterministicVariable( vars=hosp_report_prob_dist, name="hosp_report_prob_dist" ) @@ -523,9 +544,11 @@ def test_model_hosp_with_obs_model_weekday_phosp(): model1_samp = model1.sample(n_timepoints_to_simulate=n_obs_to_generate) obs = jnp.hstack( - [jnp.repeat(jnp.nan, 5), model1_samp.sampled_admissions[5:]] + [ + jnp.repeat(jnp.nan, 5), + model1_samp.sampled_admissions[5 + gen_int.size() :], + ] ) - # Running with padding model1.run( num_warmup=500, diff --git a/scratch/InfectionSeedingProcess_demo.py b/scratch/InfectionSeedingProcess_demo.py new file mode 100644 index 00000000..963a2697 --- /dev/null +++ b/scratch/InfectionSeedingProcess_demo.py @@ -0,0 +1,55 @@ +"""This is a demo""" +import jax.numpy as jnp +import numpyro as npro +import numpyro.distributions as dist +from pyrenew.deterministic import DeterministicVariable +from pyrenew.latent import ( + InfectionSeedingProcess, + SeedInfectionsExponential, + SeedInfectionsFromVec, + SeedInfectionsZeroPad, +) +from pyrenew.metaclass import DistributionalRV + +n_timepoints = 10 +rng_seed = 1 + +I0 = jnp.array([4]) +I0_long = jnp.arange(10) + +# Testing SeedInfections functions with __call__ method +SeedInfectionsZeroPad(n_timepoints).seed_infections(I0) +SeedInfectionsExponential( + n_timepoints, rate=DeterministicVariable(0.5) +).seed_infections(I0) +SeedInfectionsFromVec(n_timepoints).seed_infections(I0_long) + +# Testing SeedInfections functions within InfectionSeedingProcess +zero_pad_model = InfectionSeedingProcess( + DistributionalRV( + dist=dist.LogNormal(loc=jnp.log(80 / 0.05), scale=1.5), name="I0" + ), + SeedInfectionsZeroPad(n_timepoints), +) +with npro.handlers.seed(rng_seed=rng_seed): + zero_pad_dat = zero_pad_model.sample() +zero_pad_dat + +exp_model = InfectionSeedingProcess( + DistributionalRV( + dist=dist.LogNormal(loc=jnp.log(80 / 0.05), scale=1.5), name="I0" + ), + SeedInfectionsExponential( + n_timepoints, DeterministicVariable(0.5), t_pre_seed=0 + ), +) +with npro.handlers.seed(rng_seed=rng_seed): + exp_dat = exp_model.sample() +exp_dat + +vec_model = InfectionSeedingProcess( + DeterministicVariable(I0_long), SeedInfectionsFromVec(n_timepoints) +) +with npro.handlers.seed(rng_seed=rng_seed): + vec_dat = vec_model.sample() +vec_dat diff --git a/src/test/baseline/test_model_basicrenewal_plot.png b/src/test/baseline/test_model_basicrenewal_plot.png new file mode 100644 index 00000000..a80c3069 Binary files /dev/null and b/src/test/baseline/test_model_basicrenewal_plot.png differ