diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 2587fa5f..7737f68a 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -127,7 +127,9 @@ from pyrenew import latent, deterministic, metaclass import jax.numpy as jnp import numpyro.distributions as dist -inf_hosp_int = deterministic.DeterministicPMF(inf_hosp_int, name="inf_hosp_int") +inf_hosp_int = deterministic.DeterministicPMF( + inf_hosp_int, name="inf_hosp_int" +) hosp_rate = metaclass.DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.1), @@ -135,8 +137,8 @@ hosp_rate = metaclass.DistributionalRV( ) latent_hosp = latent.HospitalAdmissions( - infection_to_admission_interval=inf_hosp_int, - infect_hosp_rate_dist=hosp_rate, + infection_to_admission_interval_rv=inf_hosp_int, + infect_hosp_rate_rv=hosp_rate, ) ``` @@ -173,12 +175,12 @@ Notice all the components are `RandomVariable` instances. We can now build the m ```{python} # | label: init-model hosp_model = model.HospitalAdmissionsModel( - latent_infections=latent_inf, - latent_admissions=latent_hosp, - I0=I0, - gen_int=gen_int, - Rt_process=rtproc, - observation_process=obs, + latent_infections_rv=latent_inf, + latent_hosp_admissions_rv=latent_hosp, + I0_rv=I0, + gen_int_rv=gen_int, + Rt_process_rv=rtproc, + hosp_admission_obs_process_rv=obs, ) ``` @@ -208,7 +210,7 @@ axs[0].plot(sim_data.Rt) axs[0].set_ylabel("Rt") # Infections plot -axs[1].plot(sim_data.sampled_admissions) +axs[1].plot(sim_data.sampled_observed_hosp_admissions) axs[1].set_ylabel("Infections") axs[1].set_yscale("log") @@ -230,7 +232,7 @@ import jax hosp_model.run( num_samples=2000, num_warmup=2000, - observed_admissions=dat["daily_hosp_admits"].to_numpy(), + observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), ) @@ -244,7 +246,7 @@ We can use the `plot_posterior` method to visualize the results[^capture]: # | label: fig-output-hospital-admissions # | fig-cap: Hospital Admissions posterior distribution out = hosp_model.plot_posterior( - var="predicted_admissions", + var="observed_hosp_admissions", ylab="Hospital Admissions", obs_signal=dat["daily_hosp_admits"].to_numpy(), ) @@ -268,7 +270,7 @@ dat_w_padding = np.hstack((np.repeat(np.nan, days_to_impute), dat_w_padding)) hosp_model.run( num_samples=2000, num_warmup=2000, - observed_admissions=dat_w_padding, + observed_hosp_admissions=dat_w_padding, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), padding=days_to_impute, # Padding the model @@ -281,7 +283,7 @@ And plotting the results: # | label: fig-output-admissions-with-padding # | fig-cap: Hospital Admissions posterior distribution out = hosp_model.plot_posterior( - var="predicted_admissions", + var="observed_hosp_admissions", ylab="Hospital Admissions", obs_signal=dat_w_padding, ) @@ -343,18 +345,18 @@ Notice that the instance's `nweeks` and `len` members are passed during construc ```{python} # | label: latent-hosp-weekday latent_hosp_wday_effect = latent.HospitalAdmissions( - infection_to_admission_interval=inf_hosp_int, - infect_hosp_rate_dist=hosp_rate, - weekday_effect_dist=weekday_effect, + infection_to_admission_interval_rv=inf_hosp_int, + infect_hosp_rate_rv=hosp_rate, + weekday_effect_rv=weekday_effect, ) hosp_model_weekday = model.HospitalAdmissionsModel( - latent_infections=latent_inf, - latent_admissions=latent_hosp_wday_effect, - I0=I0, - gen_int=gen_int, - Rt_process=rtproc, - observation_process=obs, + latent_infections_rv=latent_inf, + latent_hosp_admissions_rv=latent_hosp_wday_effect, + I0_rv=I0, + gen_int_rv=gen_int, + Rt_process_rv=rtproc, + hosp_admission_obs_process_rv=obs, ) ``` @@ -365,7 +367,7 @@ Running the model (with the same padding as before): hosp_model_weekday.run( num_samples=2000, num_warmup=2000, - observed_admissions=dat_w_padding, + observed_hosp_admissions=dat_w_padding, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), padding=days_to_impute, @@ -378,7 +380,7 @@ And plotting the results: # | label: fig-output-admissions-padding-and-weekday # | fig-cap: Hospital Admissions posterior distribution out = hosp_model_weekday.plot_posterior( - var="predicted_admissions", + var="observed_hosp_admissions", ylab="Hospital Admissions", obs_signal=dat_w_padding, ) diff --git a/model/docs/extending_pyrenew.qmd b/model/docs/extending_pyrenew.qmd index 509c44ea..a6306a77 100644 --- a/model/docs/extending_pyrenew.qmd +++ b/model/docs/extending_pyrenew.qmd @@ -63,11 +63,11 @@ With all the components defined, we can build the model: ```{python} # | label: build1 model0 = RtInfectionsRenewalModel( - gen_int=gen_int, - I0=I0, - latent_infections=latent_infections, - Rt_process=rt, - observation_process=None, + gen_int_rv=gen_int, + I0_rv=I0, + latent_infections_rv=latent_infections, + Rt_process_rv=rt, + infection_obs_process_rv=None, ) ``` @@ -239,11 +239,11 @@ latent_infections2 = InfFeedback( ) model1 = RtInfectionsRenewalModel( - gen_int=gen_int, - I0=I0, - latent_infections=latent_infections2, - Rt_process=rt, - observation_process=None, + gen_int_rv=gen_int, + I0_rv=I0, + latent_infections_rv=latent_infections2, + Rt_process_rv=rt, + infection_obs_process_rv=None, ) # Sampling and fitting model 0 (with no obs for infections) diff --git a/model/docs/getting_started.qmd b/model/docs/getting_started.qmd index 0b02c98d..4693eff1 100644 --- a/model/docs/getting_started.qmd +++ b/model/docs/getting_started.qmd @@ -113,11 +113,11 @@ With these five pieces, we can build the basic renewal model as an instance of t ```{python} # | label: model-creation model1 = RtInfectionsRenewalModel( - gen_int=gen_int, - I0=I0, - Rt_process=rt_proc, - latent_infections=latent_infections, - observation_process=observation_process, + gen_int_rv=gen_int, + I0_rv=I0, + Rt_process_rv=rt_proc, + latent_infections_rv=latent_infections, + infection_obs_process_rv=observation_process, ) ``` @@ -167,7 +167,7 @@ axs[0].plot(sim_data.Rt) axs[0].set_ylabel("Rt") # Infections plot -axs[1].plot(sim_data.sampled_infections) +axs[1].plot(sim_data.sampled_observed_infections) axs[1].set_ylabel("Infections") fig.suptitle("Basic renewal model") @@ -185,7 +185,7 @@ import jax model1.run( num_warmup=2000, num_samples=1000, - observed_infections=sim_data.sampled_infections, + observed_infections=sim_data.sampled_observed_infections, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), ) diff --git a/model/docs/pyrenew_demo.qmd b/model/docs/pyrenew_demo.qmd index ea8ff864..a704eeb8 100644 --- a/model/docs/pyrenew_demo.qmd +++ b/model/docs/pyrenew_demo.qmd @@ -37,7 +37,7 @@ import numpyro.distributions as dist from pyrenew.process import SimpleRandomWalkProcess ``` -To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the duration of the block that follows. Inside the `with` block, the `q_samp = q.sample(duration=100)` generates the sample instance over a duration of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. +To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the n_timepoints of the block that follows. Inside the `with` block, the `q_samp = q.sample(n_timepoints=100)` generates the sample instance over a n_timepoints of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. ```{python} # | label: fig-randwalk @@ -45,7 +45,7 @@ To understand the simple random walk process underlying the sampling within the np.random.seed(3312) q = SimpleRandomWalkProcess(dist.Normal(0, 0.001)) with seed(rng_seed=np.random.randint(0, 1000)): - q_samp = q.sample(duration=100) + q_samp = q.sample(n_timepoints=100) plt.plot(np.exp(q_samp[0])) ``` @@ -112,8 +112,8 @@ inf_hosp_int = DeterministicPMF( ) latent_admissions = HospitalAdmissions( - infection_to_admission_interval=inf_hosp_int, - infect_hosp_rate_dist=DistributionalRV( + infection_to_admission_interval_rv=inf_hosp_int, + infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), ) @@ -131,12 +131,12 @@ The `HospitalAdmissionsModel` is then initialized using the initial conditions j ```{python} # Initializing the model hospmodel = HospitalAdmissionsModel( - gen_int=gen_int, - I0=I0, - latent_admissions=latent_admissions, - observation_process=admissions_process, - latent_infections=latent_infections, - Rt_process=Rt_process, + gen_int_rv=gen_int, + I0_rv=I0, + latent_hosp_admissions_rv=latent_admissions, + hosp_admission_obs_process_rv=admissions_process, + latent_infections_rv=latent_infections, + Rt_process_rv=Rt_process, ) ``` @@ -151,13 +151,13 @@ x Visualizations of the single model output show (top) infections over the 30 time steps, (middle) hospital admissions over the 30 time steps, and (bottom) ```{python} -#| label: fig-hosp -#| fig-cap: Infections +# | label: fig-hosp +# | fig-cap: Infections fig, ax = plt.subplots(nrows=3, sharex=True) ax[0].plot(x.latent_infections) -ax[0].set_ylim([1/5, 5]) -ax[1].plot(x.latent_admissions) -ax[2].plot(x.sampled_admissions, 'o') +ax[0].set_ylim([1 / 5, 5]) +ax[1].plot(x.latent_hosp_admissions) +ax[2].plot(x.sampled_observed_hosp_admissions, "o") for axis in ax[:-1]: axis.set_yscale("log") ``` @@ -169,7 +169,7 @@ To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC hospmodel.run( num_warmup=1000, num_samples=1000, - observed_admissions=x.sampled_admissions, + observed_hosp_admissions=x.sampled_observed_hosp_admissions, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), ) diff --git a/model/pyproject.toml b/model/pyproject.toml index 9ab64318..7d44a340 100755 --- a/model/pyproject.toml +++ b/model/pyproject.toml @@ -32,6 +32,28 @@ pytest-cov = "^5.0.0" pytest-mpl = "^0.17.0" numpydoc = "^1.7.0" +[tool.numpydoc_validation] +checks = [ + "GL03", + "GL08", + "SS01", + "PR03", + "PR04", + "PR07", + "RT01" +] +ignore = [ + "ES01", + "SA01", + "EX01", + "SS06", + "RT05" +] +exclude = [ # don't report on objects that match any of these regex + '\.undocumented_method$', + '\.__repr__$', + '\.__call__$' +] [build-system] requires = ["poetry-core"] diff --git a/model/src/pyrenew/deterministic/nullrv.py b/model/src/pyrenew/deterministic/nullrv.py index 5640f1d8..435c68b6 100644 --- a/model/src/pyrenew/deterministic/nullrv.py +++ b/model/src/pyrenew/deterministic/nullrv.py @@ -129,7 +129,7 @@ def validate() -> None: def sample( self, - predicted: ArrayLike, + mu: ArrayLike, obs: ArrayLike | None = None, name: str | None = None, **kwargs, @@ -139,8 +139,8 @@ def sample( Parameters ---------- - predicted : ArrayLike - Rate parameter of the Poisson distribution. + mu : ArrayLike + Unused parameter, represents mean of non-null distributions obs : ArrayLike, optional Observed data. Defaults to None. name : str, optional diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 217108e9..a51cedaf 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple import jax.numpy as jnp import numpyro as npro @@ -20,15 +20,15 @@ class HospAdmissionsSample(NamedTuple): ---------- infection_hosp_rate : float, optional The infection-to-hospitalization rate. Defaults to None. - predicted : ArrayLike or None - The predicted number of hospital admissions. Defaults to None. + observed_hosp_admissions : ArrayLike or None + The observed number of hospital admissions. Defaults to None. """ infection_hosp_rate: float | None = None - predicted: ArrayLike | None = None + observed_hosp_admissions: ArrayLike | None = None def __repr__(self): - return f"HospAdmissionsSample(infection_hosp_rate={self.IRH}, predicted={self.predicted})" + return f"HospAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, observed_hosp_admissions={self.observed_hosp_admissions})" class HospitalAdmissions(RandomVariable): @@ -63,29 +63,29 @@ class HospitalAdmissions(RandomVariable): def __init__( self, - infection_to_admission_interval: RandomVariable, - infect_hosp_rate_dist: RandomVariable, - admissions_predicted_varname: str = "predicted_admissions", - weekday_effect_dist: Optional[RandomVariable] = None, - hosp_report_prob_dist: Optional[RandomVariable] = None, + infection_to_admission_interval_rv: RandomVariable, + infect_hosp_rate_rv: RandomVariable, + observed_hosp_admissions_varname: str = "observed_hosp_admissions", + weekday_effect_rv: RandomVariable | None = None, + hosp_report_prob_rv: RandomVariable | None = None, ) -> None: """ Default constructor Parameters ---------- - infection_to_admission_interval : RandomVariable + infection_to_admission_interval_rv : RandomVariable pmf for reporting (informing) hospital admissions (see pyrenew.observations.Deterministic). - infect_hosp_rate_dist : RandomVariable - Infection to hospitalization rate distribution. - admissions_predicted_varname : str + infect_hosp_rate_rv : RandomVariable + Infection to hospitalization rate random variable. + observed_hosp_admissions_varname : str Name to assign to the deterministic component in numpyro of - predicted hospital admissions. - weekday_effect_dist : RandomVariable, optional + observed hospital admissions. + weekday_effect_rv : RandomVariable, optional Weekday effect. - hosp_report_prob_dist : RandomVariable, optional - Distribution or fixed value for the hospital admission reporting + hosp_report_prob_rv : RandomVariable, optional + Random variable for the hospital admission reporting probability. Defaults to 1 (full reporting). Returns @@ -93,31 +93,34 @@ def __init__( None """ - if weekday_effect_dist is None: - weekday_effect_dist = DeterministicVariable(1, "weekday_effect") - if hosp_report_prob_dist is None: - hosp_report_prob_dist = DeterministicVariable( - 1, "hosp_report_prob" - ) + if weekday_effect_rv is None: + weekday_effect_rv = DeterministicVariable(1, "weekday_effect") + if hosp_report_prob_rv is None: + hosp_report_prob_rv = DeterministicVariable(1, "hosp_report_prob") HospitalAdmissions.validate( - infect_hosp_rate_dist, - weekday_effect_dist, - hosp_report_prob_dist, + infect_hosp_rate_rv, + weekday_effect_rv, + hosp_report_prob_rv, ) - self.admissions_predicted_varname = admissions_predicted_varname + self.observed_hosp_admissions_varname = ( + observed_hosp_admissions_varname + ) - self.infect_hosp_rate_dist = infect_hosp_rate_dist - self.weekday_effect_dist = weekday_effect_dist - self.hosp_report_prob_dist = hosp_report_prob_dist - self.infection_to_admission_interval = infection_to_admission_interval + self.infect_hosp_rate_rv = infect_hosp_rate_rv + self.weekday_effect_rv = weekday_effect_rv + self.hosp_report_prob_rv = hosp_report_prob_rv + self.infection_to_admission_interval_rv = ( + infection_to_admission_interval_rv + ) + # Why isn't infection_to_admission_interval_rv validated? @staticmethod def validate( - infect_hosp_rate_dist: Any, - weekday_effect_dist: Any, - hosp_report_prob_dist: Any, + infect_hosp_rate_rv: Any, + weekday_effect_rv: Any, + hosp_report_prob_rv: Any, ) -> None: """ Validates that the IHR, weekday effects, and probability of being @@ -125,11 +128,11 @@ def validate( Parameters ---------- - infect_hosp_rate_dist : Any + infect_hosp_rate_rv : Any Possibly incorrect input for infection to hospitalization rate distribution. - weekday_effect_dist : Any + weekday_effect_rv : Any Possibly incorrect input for weekday effect. - hosp_report_prob_dist : Any + hosp_report_prob_rv : Any Possibly incorrect input for distribution or fixed value for the hospital admission reporting probability. @@ -143,15 +146,15 @@ def validate( If the object `distr` is not an instance of `dist.Distribution`, indicating that the validation has failed. """ - assert isinstance(infect_hosp_rate_dist, RandomVariable) - assert isinstance(weekday_effect_dist, RandomVariable) - assert isinstance(hosp_report_prob_dist, RandomVariable) + assert isinstance(infect_hosp_rate_rv, RandomVariable) + assert isinstance(weekday_effect_rv, RandomVariable) + assert isinstance(hosp_report_prob_rv, RandomVariable) return None def sample( self, - latent: ArrayLike, + latent_infections: ArrayLike, **kwargs, ) -> HospAdmissionsSample: """ @@ -170,32 +173,37 @@ def sample( HospAdmissionsSample """ - infection_hosp_rate, *_ = self.infect_hosp_rate_dist.sample(**kwargs) + infection_hosp_rate, *_ = self.infect_hosp_rate_rv.sample(**kwargs) - infection_hosp_rate_t = infection_hosp_rate * latent + infection_hosp_rate_t = infection_hosp_rate * latent_infections ( - infection_to_admission_interval, + infection_to_admission_interval_rv, *_, - ) = self.infection_to_admission_interval.sample(**kwargs) + ) = self.infection_to_admission_interval_rv.sample(**kwargs) - predicted_admissions = jnp.convolve( - infection_hosp_rate_t, infection_to_admission_interval, mode="full" + observed_hosp_admissions = jnp.convolve( + infection_hosp_rate_t, + infection_to_admission_interval_rv, + mode="full", )[: infection_hosp_rate_t.shape[0]] # Applying weekday effect - predicted_admissions = ( - predicted_admissions * self.weekday_effect_dist.sample(**kwargs)[0] + observed_hosp_admissions = ( + observed_hosp_admissions + * self.weekday_effect_rv.sample(**kwargs)[0] ) # Applying probability of hospitalization effect - predicted_admissions = ( - predicted_admissions - * self.hosp_report_prob_dist.sample(**kwargs)[0] + observed_hosp_admissions = ( + observed_hosp_admissions + * self.hosp_report_prob_rv.sample(**kwargs)[0] ) npro.deterministic( - self.admissions_predicted_varname, predicted_admissions + self.observed_hosp_admissions_varname, observed_hosp_admissions ) - return HospAdmissionsSample(infection_hosp_rate, predicted_admissions) + return HospAdmissionsSample( + infection_hosp_rate, observed_hosp_admissions + ) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index 872c625f..99ccc4fb 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -22,22 +22,28 @@ class HospModelSample(NamedTuple): The reproduction number over time. Defaults to None. latent_infections : ArrayLike | None, optional The estimated number of new infections over time. Defaults to None. - IHR : float | None, optional + infection_hosp_rate : float | None, optional The infected hospitalization rate. Defaults to None. - latent_admissions : ArrayLike | None, optional + latent_hosp_admissions : ArrayLike | None, optional The estimated latent hospitalizations. Defaults to None. - sampled_admissions : ArrayLike | None, optional + sampled_observed_hosp_admissions : ArrayLike | None, optional The sampled or observed hospital admissions. Defaults to None. """ Rt: float | None = None latent_infections: ArrayLike | None = None - IHR: float | None = None - latent_admissions: ArrayLike | None = None - sampled_admissions: ArrayLike | None = None + infection_hosp_rate: float | None = None + latent_hosp_admissions: ArrayLike | None = None + sampled_observed_hosp_admissions: ArrayLike | None = None def __repr__(self): - return f"HospModelSample(Rt={self.Rt}, latent_infections={self.latent_infections}, IHR={self.IHR}, latent_admissions={self.latent_admissions}, sampled_admissions={self.sampled_admissions})" + return ( + f"HospModelSample(Rt={self.Rt}, " + f"latent_infections={self.latent_infections}, " + f"infection_hosp_rate={self.infection_hosp_rate}, " + f"latent_hosp_admissions={self.latent_hosp_admissions}, " + f"sampled_observed_hosp_admissions={self.sampled_observed_hosp_admissions}" + ) class HospitalAdmissionsModel(Model): @@ -51,29 +57,29 @@ class HospitalAdmissionsModel(Model): def __init__( self, - latent_admissions: RandomVariable, - latent_infections: RandomVariable, - gen_int: RandomVariable, - I0: RandomVariable, - Rt_process: RandomVariable, - observation_process: RandomVariable, + latent_hosp_admissions_rv: RandomVariable, + latent_infections_rv: RandomVariable, + gen_int_rv: RandomVariable, + I0_rv: RandomVariable, + Rt_process_rv: RandomVariable, + hosp_admission_obs_process_rv: RandomVariable, ) -> None: # numpydoc ignore=PR04 """ Default constructor Parameters ---------- - latent_admissions : RandomVariable + latent_hosp_admissions_rv : RandomVariable Latent process for the hospital admissions. - latent_infections : RandomVariable + latent_infections_rv : RandomVariable The infections latent process (passed to RtInfectionsRenewalModel). - gen_int : RandomVariable + gen_int_rv : RandomVariable Generation time (passed to RtInfectionsRenewalModel) - I0 : RandomVariable + I0_rv : RandomVariable Initial infections (passed to RtInfectionsRenewalModel) - Rt_process : RandomVariable + Rt_process_rv : RandomVariable Rt process (passed to RtInfectionsRenewalModel). - observation_process : RandomVariable, optional + hosp_admission_obs_process_rv : RandomVariable, optional Observation process for the hospital admissions. Returns @@ -81,30 +87,32 @@ def __init__( None """ self.basic_renewal = RtInfectionsRenewalModel( - gen_int=gen_int, - I0=I0, - latent_infections=latent_infections, - observation_process=None, - Rt_process=Rt_process, + gen_int_rv=gen_int_rv, + I0_rv=I0_rv, + latent_infections_rv=latent_infections_rv, + infection_obs_process_rv=None, # why is this None? + Rt_process_rv=Rt_process_rv, ) HospitalAdmissionsModel.validate( - latent_admissions, observation_process + latent_hosp_admissions_rv, hosp_admission_obs_process_rv ) - self.latent_admissions = latent_admissions - self.observation_process = observation_process + self.latent_hosp_admissions_rv = latent_hosp_admissions_rv + self.hosp_admission_obs_process_rv = hosp_admission_obs_process_rv @staticmethod - def validate(latent_admissions, observation_process) -> None: + def validate( + latent_hosp_admissions_rv, hosp_admission_obs_process_rv + ) -> None: """ Verifies types and status (RV) of latent and observed hospital admissions Parameters ---------- - latent_admissions : ArrayLike + latent_hosp_admissions_rv : RandomVariable The latent process for the hospital admissions. - observation_process : ArrayLike + hosp_admission_obs_process_rv : RandomVariable The observed hospital admissions. Returns @@ -115,13 +123,15 @@ def validate(latent_admissions, observation_process) -> None: -------- _assert_sample_and_rtype : Perform type-checking and verify RV """ - _assert_sample_and_rtype(latent_admissions, skip_if_none=False) - _assert_sample_and_rtype(observation_process, skip_if_none=True) + _assert_sample_and_rtype(latent_hosp_admissions_rv, skip_if_none=False) + _assert_sample_and_rtype( + hosp_admission_obs_process_rv, skip_if_none=True + ) return None - def sample_latent_admissions( + def sample_latent_hosp_admissions( self, - infections: ArrayLike, + latent_infections: ArrayLike, **kwargs, ) -> tuple: """ @@ -141,18 +151,18 @@ def sample_latent_admissions( See Also -------- - latent_admissions.sample : For sampling latent hospital admissions + latent_hosp_admissions.sample : For sampling latent hospital admissions """ - return self.latent_admissions.sample( - latent=infections, + return self.latent_hosp_admissions_rv.sample( + latent_infections=latent_infections, **kwargs, ) def sample_admissions_process( self, - predicted: ArrayLike, - observed_admissions: ArrayLike, + observed_hosp_admissions_mean: ArrayLike, + observed_hosp_admissions: ArrayLike, name: str | None = None, **kwargs, ) -> tuple: @@ -161,8 +171,8 @@ def sample_admissions_process( Parameters ---------- - predicted : ArrayLike - The predicted hospital admissions. + observed_hosp_admissions_mean : ArrayLike + The mean of the predictive distribution for observed hospital admissions. obs : ArrayLike The observed hospitalization data (to fit). name : str, optional @@ -176,9 +186,9 @@ def sample_admissions_process( tuple """ - return self.observation_process.sample( - predicted=predicted, - obs=observed_admissions, + return self.hosp_admission_obs_process_rv.sample( + mu=observed_hosp_admissions_mean, + obs=observed_hosp_admissions, name=name, **kwargs, ) @@ -186,7 +196,7 @@ def sample_admissions_process( def sample( self, n_timepoints_to_simulate: int | None = None, - observed_admissions: ArrayLike | None = None, + observed_hosp_admissions: ArrayLike | None = None, padding: int = 0, **kwargs, ) -> HospModelSample: @@ -197,7 +207,7 @@ def sample( ---------- n_timepoints_to_simulate : int, optional Number of timepoints to sample (passed to the basic renewal model). - observed_admissions : ArrayLike, optional + observed_hosp_admissions : ArrayLike, optional The observed hospitalization data (passed to the basic renewal model). Defaults to None (simulation, rather than fit). padding : int, optional @@ -214,23 +224,26 @@ def sample( See Also -------- basic_renewal.sample : For sampling the basic renewal model - sample_latent_admissions : To sample latent hospitalization process + sample_latent_hosp_admissions : To sample latent hospitalization process sample_observed_admissions : For sampling observed hospital admissions """ - if n_timepoints_to_simulate is None and observed_admissions is None: + if ( + n_timepoints_to_simulate is None + and observed_hosp_admissions is None + ): raise ValueError( - "Either n_timepoints_to_simulate or observed_admissions " + "Either n_timepoints_to_simulate or observed_hosp_admissions " "must be passed." ) elif ( n_timepoints_to_simulate is not None - and observed_admissions is not None + and observed_hosp_admissions is not None ): raise ValueError( - "Cannot pass both n_timepoints_to_simulate and observed_admissions." + "Cannot pass both n_timepoints_to_simulate and observed_hosp_admissions." ) elif n_timepoints_to_simulate is None: - n_timepoints = len(observed_admissions) + n_timepoints = len(observed_hosp_admissions) else: n_timepoints = n_timepoints_to_simulate @@ -245,46 +258,50 @@ def sample( # Sampling the latent hospital admissions ( infection_hosp_rate, - latent, + latent_hosp_admissions, *_, - ) = self.sample_latent_admissions( - infections=basic_model.latent_infections, + ) = self.sample_latent_hosp_admissions( + latent_infections=basic_model.latent_infections, **kwargs, ) - i0_size = len(latent) - n_timepoints - if self.observation_process is None: - sampled = None + i0_size = len(latent_hosp_admissions) - n_timepoints + if self.hosp_admission_obs_process_rv is None: + sampled_observed_hosp_admissions = None else: - if observed_admissions is None: - sampled_obs, *_ = self.sample_admissions_process( - predicted=latent, - observed_admissions=observed_admissions, + if observed_hosp_admissions is None: + ( + sampled_observed_hosp_admissions, + *_, + ) = self.sample_admissions_process( + observed_hosp_admissions_mean=latent_hosp_admissions, + observed_hosp_admissions=observed_hosp_admissions, **kwargs, ) else: - observed_admissions = au.pad_x_to_match_y( - observed_admissions, latent, jnp.nan, pad_direction="start" + observed_hosp_admissions = au.pad_x_to_match_y( + observed_hosp_admissions, + latent_hosp_admissions, + jnp.nan, + pad_direction="start", ) - sampled_obs, *_ = self.sample_admissions_process( - predicted=latent[i0_size + padding :], - observed_admissions=observed_admissions[ + ( + sampled_observed_hosp_admissions, + *_, + ) = self.sample_admissions_process( + observed_hosp_admissions_mean=latent_hosp_admissions[ + i0_size + padding : + ], + observed_hosp_admissions=observed_hosp_admissions[ i0_size + padding : ], **kwargs, ) - # this is to accommodate the current version of test_model_hosp_no_obs_model. Not sure if we want this behavior - if sampled_obs is None: - sampled = None - else: - sampled = au.pad_x_to_match_y( - sampled_obs, latent, jnp.nan, pad_direction="start" - ) return HospModelSample( Rt=basic_model.Rt, latent_infections=basic_model.latent_infections, - IHR=infection_hosp_rate, - latent_admissions=latent, - sampled_admissions=sampled, + infection_hosp_rate=infection_hosp_rate, + latent_hosp_admissions=latent_hosp_admissions, + sampled_observed_hosp_admissions=sampled_observed_hosp_admissions, ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 897e1d22..3e395c03 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -23,19 +23,19 @@ class RtInfectionsRenewalSample(NamedTuple): The reproduction number over time. Defaults to None. latent_infections : ArrayLike | None, optional The estimated latent infections. Defaults to None. - sampled_infections : ArrayLike | None, optional + sampled_observed_infections : ArrayLike | None, optional The sampled infections. Defaults to None. """ Rt: float | None = None latent_infections: ArrayLike | None = None - sampled_infections: ArrayLike | None = None + sampled_observed_infections: ArrayLike | None = None def __repr__(self): return ( f"RtInfectionsRenewalSample(Rt={self.Rt}, " f"latent_infections={self.latent_infections}, " - f"sampled_infections={self.sampled_infections})" + f"sampled_observed_infections={self.sampled_observed_infections})" ) @@ -49,28 +49,28 @@ class RtInfectionsRenewalModel(Model): def __init__( self, - latent_infections: RandomVariable, - gen_int: RandomVariable, - I0: RandomVariable, - Rt_process: RandomVariable, - observation_process: RandomVariable = None, + latent_infections_rv: RandomVariable, + gen_int_rv: RandomVariable, + I0_rv: RandomVariable, + Rt_process_rv: RandomVariable, + infection_obs_process_rv: RandomVariable = None, ) -> None: """ Default constructor Parameters ---------- - latent_infections : RandomVariable + latent_infections_rv : RandomVariable Infections latent process (e.g., pyrenew.latent.Infections.). - gen_int : RandomVariable + gen_int_rv : RandomVariable The generation interval. - I0 : RandomVariable + I0_rv : RandomVariable The initial infections. - Rt_process : RandomVariable + Rt_process_rv : RandomVariable The sample function of the process should return a tuple where the first element is the drawn Rt. - observation_process : RandomVariable + infection_obs_process_rv : RandomVariable Infections observation process (e.g., pyrenew.observations.Poisson.). @@ -79,30 +79,30 @@ def __init__( None """ - if observation_process is None: - observation_process = NullObservation() + if infection_obs_process_rv is None: + infection_obs_process_rv = NullObservation() RtInfectionsRenewalModel.validate( - gen_int=gen_int, - i0=I0, - latent_infections=latent_infections, - observation_process=observation_process, - Rt_process=Rt_process, + gen_int_rv=gen_int_rv, + I0_rv=I0_rv, + latent_infections_rv=latent_infections_rv, + infection_obs_process_rv=infection_obs_process_rv, + Rt_process_rv=Rt_process_rv, ) - self.gen_int = gen_int - self.i0 = I0 - self.latent_infections = latent_infections - self.observation_process = observation_process - self.Rt_process = Rt_process + self.gen_int_rv = gen_int_rv + self.I0_rv = I0_rv + self.latent_infections_rv = latent_infections_rv + self.infection_obs_process_rv = infection_obs_process_rv + self.Rt_process_rv = Rt_process_rv @staticmethod def validate( - gen_int: any, - i0: any, - latent_infections: any, - observation_process: any, - Rt_process: any, + gen_int_rv: any, + I0_rv: any, + latent_infections_rv: any, + infection_obs_process_rv: any, + Rt_process_rv: any, ) -> None: """ Verifies types and status (RV) of the generation interval, initial @@ -110,15 +110,15 @@ def validate( Parameters ---------- - gen_int : any + gen_int_rv : any The generation interval. Expects RandomVariable. - i0 : any + I0_rv : any The initial infections. Expects RandomVariable. - latent_infections : any + latent_infections_rv : any Infections latent process. Expects RandomVariable. - observation_process : any + infection_obs_process_rv : any Infections observation process. Expects RandomVariable. - Rt_process : any + Rt_process_rv : any The sample function of the process should return a tuple where the first element is the drawn Rt. Expects RandomVariable. @@ -130,11 +130,11 @@ def validate( -------- _assert_sample_and_rtype : Perform type-checking and verify RV """ - _assert_sample_and_rtype(gen_int, skip_if_none=False) - _assert_sample_and_rtype(i0, skip_if_none=False) - _assert_sample_and_rtype(latent_infections, skip_if_none=False) - _assert_sample_and_rtype(observation_process, skip_if_none=False) - _assert_sample_and_rtype(Rt_process, skip_if_none=False) + _assert_sample_and_rtype(gen_int_rv, skip_if_none=False) + _assert_sample_and_rtype(I0_rv, skip_if_none=False) + _assert_sample_and_rtype(latent_infections_rv, skip_if_none=False) + _assert_sample_and_rtype(infection_obs_process_rv, skip_if_none=False) + _assert_sample_and_rtype(Rt_process_rv, skip_if_none=False) return None def sample_rt( @@ -154,7 +154,7 @@ def sample_rt( ------- tuple """ - return self.Rt_process.sample(**kwargs) + return self.Rt_process_rv.sample(**kwargs) def sample_gen_int( self, @@ -173,9 +173,9 @@ def sample_gen_int( ------- tuple """ - return self.gen_int.sample(**kwargs) + return self.gen_int_rv.sample(**kwargs) - def sample_i0( + def sample_I0( self, **kwargs, ) -> tuple: @@ -186,13 +186,13 @@ def sample_i0( ---------- **kwargs : dict, optional Additional keyword arguments passed through to internal - sample_i0 calls, should there be any. + sample_I0 calls, should there be any. Returns ------- tuple """ - return self.i0.sample(**kwargs) + return self.I0_rv.sample(**kwargs) def sample_infections_latent( self, @@ -211,11 +211,11 @@ def sample_infections_latent( ------- tuple """ - return self.latent_infections.sample(**kwargs) + return self.latent_infections_rv.sample(**kwargs) - def sample_infections_obs( + def sample_infection_obs_process( self, - predicted: ArrayLike, + observed_infections_mean: ArrayLike, observed_infections: ArrayLike | None = None, name: str | None = None, **kwargs, @@ -227,8 +227,8 @@ def sample_infections_obs( Parameters ---------- - predicted : ArrayLike - The predicted infecteds. + observed_infections_mean : ArrayLike + The mean of the observed infections distribution. observed_infections : ArrayLike | None, optional The observed infection values, if any, for inference. Defaults to None. name : str | None, optional @@ -241,8 +241,8 @@ def sample_infections_obs( ------- tuple """ - return self.observation_process.sample( - predicted=predicted, + return self.infection_obs_process_rv.sample( + mu=observed_infections_mean, obs=observed_infections, name=name, **kwargs, @@ -299,7 +299,7 @@ def sample( # Sampling from Rt (possibly with a given Rt, depending on # the Rt_process (RandomVariable) object.) Rt, *_ = self.sample_rt( - duration=n_timepoints, + n_timepoints=n_timepoints, **kwargs, ) @@ -307,40 +307,56 @@ def sample( gen_int, *_ = self.sample_gen_int(**kwargs) # Sampling initial infections - i0, *_ = self.sample_i0(**kwargs) - i0_size = i0.size + I0, *_ = self.sample_I0(**kwargs) + I0_size = I0.size # Sampling from the latent process - latent, *_ = self.sample_infections_latent( + latent_infections, *_ = self.sample_infections_latent( Rt=Rt, gen_int=gen_int, - I0=i0, + I0=I0, **kwargs, ) if observed_infections is None: - sampled_obs, *_ = self.sample_infections_obs( - predicted=latent, + ( + sampled_observed_infections, + *_, + ) = self.sample_infection_obs_process( + observed_infections_mean=latent_infections, observed_infections=observed_infections, **kwargs, ) else: observed_infections = au.pad_x_to_match_y( - observed_infections, latent, jnp.nan, pad_direction="start" + observed_infections, + latent_infections, + jnp.nan, + pad_direction="start", ) - sampled_obs, *_ = self.sample_infections_obs( - predicted=latent[i0_size + padding :], - observed_infections=observed_infections[i0_size + padding :], + ( + sampled_observed_infections, + *_, + ) = self.sample_infection_obs_process( + observed_infections_mean=latent_infections[ + I0_size + padding : + ], + observed_infections=observed_infections[I0_size + padding :], **kwargs, ) - sampled = au.pad_x_to_match_y( - sampled_obs, latent, jnp.nan, pad_direction="start" + sampled_observed_infections = au.pad_x_to_match_y( + sampled_observed_infections, + latent_infections, + jnp.nan, + pad_direction="start", ) - Rt = au.pad_x_to_match_y(Rt, latent, jnp.nan, pad_direction="start") + Rt = au.pad_x_to_match_y( + Rt, latent_infections, jnp.nan, pad_direction="start" + ) return RtInfectionsRenewalSample( Rt=Rt, - latent_infections=latent, - sampled_infections=sampled, + latent_infections=latent_infections, + sampled_observed_infections=sampled_observed_infections, ) diff --git a/model/src/pyrenew/observation/negativebinomial.py b/model/src/pyrenew/observation/negativebinomial.py index 670f56bd..48710592 100644 --- a/model/src/pyrenew/observation/negativebinomial.py +++ b/model/src/pyrenew/observation/negativebinomial.py @@ -62,7 +62,7 @@ def __init__( def sample( self, - predicted: ArrayLike, + mu: ArrayLike, obs: ArrayLike | None = None, name: str | None = None, **kwargs, @@ -72,7 +72,7 @@ def sample( Parameters ---------- - predicted : ArrayLike + mu : ArrayLike Mean parameter of the negative binomial distribution. obs : ArrayLike, optional Observed data, by default None. @@ -95,7 +95,7 @@ def sample( numpyro.sample( name=name, fn=dist.NegativeBinomial2( - mean=predicted + self.eps, + mean=mu + self.eps, concentration=concentration, ), obs=obs, diff --git a/model/src/pyrenew/observation/poisson.py b/model/src/pyrenew/observation/poisson.py index 400acd9f..c641cf76 100644 --- a/model/src/pyrenew/observation/poisson.py +++ b/model/src/pyrenew/observation/poisson.py @@ -42,7 +42,7 @@ def __init__( def sample( self, - predicted: ArrayLike, + mu: ArrayLike, obs: ArrayLike | None = None, name: str | None = None, **kwargs, @@ -52,7 +52,7 @@ def sample( Parameters ---------- - predicted : ArrayLike + mu : ArrayLike Rate parameter of the Poisson distribution. obs : ArrayLike | None, optional Observed data. Defaults to None. @@ -72,7 +72,7 @@ def sample( return ( numpyro.sample( name=name, - fn=dist.Poisson(rate=predicted + self.eps), + fn=dist.Poisson(rate=mu + self.eps), obs=obs, ), ) diff --git a/model/src/pyrenew/process/rtrandomwalk.py b/model/src/pyrenew/process/rtrandomwalk.py index 4fa70931..95d9cf8c 100644 --- a/model/src/pyrenew/process/rtrandomwalk.py +++ b/model/src/pyrenew/process/rtrandomwalk.py @@ -95,7 +95,7 @@ def validate( def sample( self, - duration: int, + n_timepoints: int, **kwargs, ) -> tuple: """ @@ -103,7 +103,7 @@ def sample( Parameters ---------- - duration : int + n_timepoints : int Number of timepoints to sample. **kwargs : dict, optional Additional keyword arguments passed through to internal sample() @@ -112,7 +112,7 @@ def sample( Returns ------- tuple - With a single array of shape (duration,). + With a single array of shape (n_timepoints,). """ Rt0 = npro.sample("Rt0", self.Rt0_dist) @@ -120,7 +120,7 @@ def sample( Rt0_trans = self.Rt_transform(Rt0) Rt_trans_proc = SimpleRandomWalkProcess(self.Rt_rw_dist) Rt_trans_ts, *_ = Rt_trans_proc.sample( - duration=duration, + n_timepoints=n_timepoints, name="Rt_transformed_rw", init=Rt0_trans, ) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index ef506237..d2c233b3 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -34,7 +34,7 @@ def __init__( def sample( self, - duration: int, + n_timepoints: int, name: str = "randomwalk", init: float = None, **kwargs, @@ -44,7 +44,7 @@ def sample( Parameters ---------- - duration : int + n_timepoints : int Length of the walk. name : str, optional Passed to numpyro.sample, by default "randomwalk" @@ -57,13 +57,14 @@ def sample( Returns ------- tuple - With a single array of shape (duration,). + With a single array of shape (n_timepoints,). """ if init is None: init = npro.sample(name + "_init", self.error_distribution) diffs = npro.sample( - name + "_diffs", self.error_distribution.expand((duration - 1,)) + name + "_diffs", + self.error_distribution.expand((n_timepoints - 1,)), ) return (init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0)),) diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index cc970be5..486e130c 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -22,7 +22,7 @@ def test_admissions_sample(): np.random.seed(223) rt = RtRandomWalkProcess() with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_rt, *_ = rt.sample(duration=30) + sim_rt, *_ = rt.sample(n_timepoints=30) gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1]) i0 = 10 * jnp.ones_like(gen_int) @@ -60,16 +60,16 @@ def test_admissions_sample(): ) hosp1 = HospitalAdmissions( - infection_to_admission_interval=inf_hosp, - infect_hosp_rate_dist=DistributionalRV( + infection_to_admission_interval_rv=inf_hosp, + infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), ) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_hosp_1 = hosp1.sample(latent=inf_sampled1[0]) + sim_hosp_1 = hosp1.sample(latent_infections=inf_sampled1[0]) testing.assert_array_less( - sim_hosp_1.predicted, + sim_hosp_1.observed_hosp_admissions, inf_sampled1[0], ) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index 49db045a..58a06989 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -19,7 +19,7 @@ def test_infections_as_deterministic(): np.random.seed(223) rt = RtRandomWalkProcess() with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_rt, *_ = rt.sample(duration=30) + sim_rt, *_ = rt.sample(n_timepoints=30) gen_int = jnp.array([0.25, 0.25, 0.25, 0.25]) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index c6aea454..3b38cdb6 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -39,11 +39,11 @@ def test_model_basicrenewal_no_timepoints_or_observations(): rt = RtRandomWalkProcess() model1 = RtInfectionsRenewalModel( - I0=I0, - gen_int=gen_int, - latent_infections=latent_infections, - observation_process=observed_infections, - Rt_process=rt, + I0_rv=I0, + gen_int_rv=gen_int, + latent_infections_rv=latent_infections, + infection_obs_process_rv=observed_infections, + Rt_process_rv=rt, ) np.random.seed(2203) @@ -72,11 +72,11 @@ def test_model_basicrenewal_both_timepoints_and_observations(): rt = RtRandomWalkProcess() model1 = RtInfectionsRenewalModel( - I0=I0, - gen_int=gen_int, - latent_infections=latent_infections, - observation_process=observed_infections, - Rt_process=rt, + I0_rv=I0, + gen_int_rv=gen_int, + latent_infections_rv=latent_infections, + infection_obs_process_rv=observed_infections, + Rt_process_rv=rt, ) np.random.seed(2203) @@ -112,12 +112,12 @@ def test_model_basicrenewal_no_obs_model(): rt = RtRandomWalkProcess() model0 = RtInfectionsRenewalModel( - gen_int=gen_int, - I0=I0, - latent_infections=latent_infections, - Rt_process=rt, + gen_int_rv=gen_int, + I0_rv=I0, + latent_infections_rv=latent_infections, + Rt_process_rv=rt, # Explicitly use None, this should call the NullObservation - observation_process=None, + infection_obs_process_rv=None, ) # Sampling and fitting model 0 (with no obs for infections) @@ -126,10 +126,10 @@ def test_model_basicrenewal_no_obs_model(): model0_samp = model0.sample(n_timepoints_to_simulate=30) model0_samp.Rt model0_samp.latent_infections - model0_samp.sampled_infections + model0_samp.sampled_observed_infections # Generating - model0.observation_process = NullObservation() + model0.infection_obs_process_rv = NullObservation() np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model0.sample(n_timepoints_to_simulate=30) @@ -139,7 +139,8 @@ def test_model_basicrenewal_no_obs_model(): model0_samp.latent_infections, model1_samp.latent_infections ) np.testing.assert_array_equal( - model0_samp.sampled_infections, model1_samp.sampled_infections + model0_samp.sampled_observed_infections, + model1_samp.sampled_observed_infections, ) model0.run( @@ -184,11 +185,11 @@ def test_model_basicrenewal_with_obs_model(): rt = RtRandomWalkProcess() model1 = RtInfectionsRenewalModel( - I0=I0, - gen_int=gen_int, - latent_infections=latent_infections, - observation_process=observed_infections, - Rt_process=rt, + I0_rv=I0, + gen_int_rv=gen_int, + latent_infections_rv=latent_infections, + infection_obs_process_rv=observed_infections, + Rt_process_rv=rt, ) # Sampling and fitting model 1 (with obs infections) @@ -200,7 +201,7 @@ def test_model_basicrenewal_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jax.random.PRNGKey(22), - observed_infections=model1_samp.sampled_infections, + observed_infections=model1_samp.sampled_observed_infections, ) inf = model1.spread_draws(["latent_infections"]) @@ -253,11 +254,11 @@ def test_model_basicrenewal_plot() -> plt.Figure: rt = RtRandomWalkProcess() model1 = RtInfectionsRenewalModel( - I0=I0, - gen_int=gen_int, - latent_infections=latent_infections, - observation_process=observed_infections, - Rt_process=rt, + I0_rv=I0, + gen_int_rv=gen_int, + latent_infections_rv=latent_infections, + infection_obs_process_rv=observed_infections, + Rt_process_rv=rt, ) # Sampling and fitting model 1 (with obs infections) @@ -269,12 +270,12 @@ def test_model_basicrenewal_plot() -> plt.Figure: num_warmup=500, num_samples=500, rng_key=jax.random.PRNGKey(22), - observed_infections=model1_samp.sampled_infections, + observed_infections=model1_samp.sampled_observed_infections, ) return model1.plot_posterior( var="latent_infections", - obs_signal=model1_samp.sampled_infections, + obs_signal=model1_samp.sampled_observed_infections, ) @@ -296,11 +297,11 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 rt = RtRandomWalkProcess() model1 = RtInfectionsRenewalModel( - I0=I0, - gen_int=gen_int, - latent_infections=latent_infections, - observation_process=observed_infections, - Rt_process=rt, + I0_rv=I0, + gen_int_rv=gen_int, + latent_infections_rv=latent_infections, + infection_obs_process_rv=observed_infections, + Rt_process_rv=rt, ) # Sampling and fitting model 1 (with obs infections) @@ -309,7 +310,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 model1_samp = model1.sample(n_timepoints_to_simulate=30) new_obs = jnp.hstack( - [jnp.repeat(jnp.nan, 5), model1_samp.sampled_infections[5:]], + [jnp.repeat(jnp.nan, 5), model1_samp.sampled_observed_infections[5:]], ) model1.run( diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 173dd1ca..77b5ce29 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -83,19 +83,19 @@ def test_model_hosp_no_timepoints_or_observations(): ) latent_admissions = HospitalAdmissions( - infection_to_admission_interval=inf_hosp, - infect_hosp_rate_dist=DistributionalRV( + infection_to_admission_interval_rv=inf_hosp, + infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), ) model1 = HospitalAdmissionsModel( - gen_int=gen_int, - I0=I0, - Rt_process=Rt_process, - latent_infections=latent_infections, - latent_admissions=latent_admissions, - observation_process=observed_admissions, + gen_int_rv=gen_int, + I0_rv=I0, + Rt_process_rv=Rt_process, + latent_infections_rv=latent_infections, + latent_hosp_admissions_rv=latent_admissions, + hosp_admission_obs_process_rv=observed_admissions, ) np.random.seed(223) @@ -148,19 +148,19 @@ def test_model_hosp_both_timepoints_and_observations(): ) latent_admissions = HospitalAdmissions( - infection_to_admission_interval=inf_hosp, - infect_hosp_rate_dist=DistributionalRV( + infection_to_admission_interval_rv=inf_hosp, + infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), ) model1 = HospitalAdmissionsModel( - gen_int=gen_int, - I0=I0, - Rt_process=Rt_process, - latent_infections=latent_infections, - latent_admissions=latent_admissions, - observation_process=observed_admissions, + gen_int_rv=gen_int, + I0_rv=I0, + Rt_process_rv=Rt_process, + latent_infections_rv=latent_infections, + latent_hosp_admissions_rv=latent_admissions, + hosp_admission_obs_process_rv=observed_admissions, ) np.random.seed(223) @@ -168,7 +168,7 @@ def test_model_hosp_both_timepoints_and_observations(): with pytest.raises(ValueError, match="Cannot pass both"): model1.sample( n_timepoints_to_simulate=30, - observed_admissions=jnp.repeat(jnp.nan, 30), + observed_hosp_admissions=jnp.repeat(jnp.nan, 30), ) @@ -217,20 +217,20 @@ def test_model_hosp_no_obs_model(): ) latent_admissions = HospitalAdmissions( - infection_to_admission_interval=inf_hosp, - admissions_predicted_varname="observed_admissions", - infect_hosp_rate_dist=DistributionalRV( + infection_to_admission_interval_rv=inf_hosp, + observed_hosp_admissions_varname="observed_admissions", + infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), ) model0 = HospitalAdmissionsModel( - gen_int=gen_int, - I0=I0, - Rt_process=Rt_process, - latent_infections=latent_infections, - latent_admissions=latent_admissions, - observation_process=None, + gen_int_rv=gen_int, + I0_rv=I0, + Rt_process_rv=Rt_process, + latent_infections_rv=latent_infections, + latent_hosp_admissions_rv=latent_admissions, + hosp_admission_obs_process_rv=None, ) # Sampling and fitting model 0 (with no obs for infections) @@ -248,19 +248,22 @@ def test_model_hosp_no_obs_model(): np.testing.assert_array_equal( model0_samp.latent_infections, model1_samp.latent_infections ) - np.testing.assert_array_equal(model0_samp.IHR, model1_samp.IHR) np.testing.assert_array_equal( - model0_samp.latent_admissions, model1_samp.latent_admissions + model0_samp.infection_hosp_rate, model1_samp.infection_hosp_rate ) np.testing.assert_array_equal( - model0_samp.sampled_admissions, model1_samp.sampled_admissions + model0_samp.latent_hosp_admissions, model1_samp.latent_hosp_admissions + ) + np.testing.assert_array_equal( + model0_samp.sampled_observed_hosp_admissions, + model1_samp.sampled_observed_hosp_admissions, ) model0.run( num_warmup=500, num_samples=500, rng_key=jax.random.PRNGKey(272), - observed_admissions=model0_samp.latent_admissions, + observed_hosp_admissions=model0_samp.latent_hosp_admissions, ) inf = model0.spread_draws(["observed_admissions"]) @@ -321,19 +324,19 @@ def test_model_hosp_with_obs_model(): ) latent_admissions = HospitalAdmissions( - infection_to_admission_interval=inf_hosp, - infect_hosp_rate_dist=DistributionalRV( + infection_to_admission_interval_rv=inf_hosp, + infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), ) model1 = HospitalAdmissionsModel( - gen_int=gen_int, - I0=I0, - Rt_process=Rt_process, - latent_infections=latent_infections, - latent_admissions=latent_admissions, - observation_process=observed_admissions, + gen_int_rv=gen_int, + I0_rv=I0, + Rt_process_rv=Rt_process, + latent_infections_rv=latent_infections, + latent_hosp_admissions_rv=latent_admissions, + hosp_admission_obs_process_rv=observed_admissions, ) # Sampling and fitting model 0 (with no obs for infections) @@ -345,13 +348,13 @@ def test_model_hosp_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jax.random.PRNGKey(272), - observed_admissions=model1_samp.sampled_admissions, + observed_hosp_admissions=model1_samp.sampled_observed_hosp_admissions, ) - inf = model1.spread_draws(["predicted_admissions"]) + inf = model1.spread_draws(["observed_hosp_admissions"]) inf_mean = ( inf.group_by("draw") - .agg(pl.col("predicted_admissions").mean()) + .agg(pl.col("observed_hosp_admissions").mean()) .sort(pl.col("draw")) ) @@ -415,21 +418,21 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): weekday = UniformProbForTest("weekday") latent_admissions = HospitalAdmissions( - infection_to_admission_interval=inf_hosp, - weekday_effect_dist=weekday, - hosp_report_prob_dist=hosp_report_prob_dist, - infect_hosp_rate_dist=DistributionalRV( + infection_to_admission_interval_rv=inf_hosp, + weekday_effect_rv=weekday, + hosp_report_prob_rv=hosp_report_prob_dist, + infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), ) model1 = HospitalAdmissionsModel( - I0=I0, - gen_int=gen_int, - Rt_process=Rt_process, - latent_infections=latent_infections, - latent_admissions=latent_admissions, - observation_process=observed_admissions, + I0_rv=I0, + gen_int_rv=gen_int, + Rt_process_rv=Rt_process, + latent_infections_rv=latent_infections, + latent_hosp_admissions_rv=latent_admissions, + hosp_admission_obs_process_rv=observed_admissions, ) # Sampling and fitting model 0 (with no obs for infections) @@ -441,13 +444,13 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): num_warmup=500, num_samples=500, rng_key=jax.random.PRNGKey(272), - observed_admissions=model1_samp.sampled_admissions, + observed_hosp_admissions=model1_samp.sampled_observed_hosp_admissions, ) - inf = model1.spread_draws(["predicted_admissions"]) + inf = model1.spread_draws(["observed_hosp_admissions"]) inf_mean = ( inf.group_by("draw") - .agg(pl.col("predicted_admissions").mean()) + .agg(pl.col("observed_hosp_admissions").mean()) .sort(pl.col("draw")) ) @@ -521,21 +524,21 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) latent_admissions = HospitalAdmissions( - infection_to_admission_interval=inf_hosp, - weekday_effect_dist=weekday, - hosp_report_prob_dist=hosp_report_prob_dist, - infect_hosp_rate_dist=DistributionalRV( + infection_to_admission_interval_rv=inf_hosp, + weekday_effect_rv=weekday, + hosp_report_prob_rv=hosp_report_prob_dist, + infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), ) model1 = HospitalAdmissionsModel( - I0=I0, - gen_int=gen_int, - Rt_process=Rt_process, - latent_infections=latent_infections, - latent_admissions=latent_admissions, - observation_process=observed_admissions, + I0_rv=I0, + gen_int_rv=gen_int, + Rt_process_rv=Rt_process, + latent_infections_rv=latent_infections, + latent_hosp_admissions_rv=latent_admissions, + hosp_admission_obs_process_rv=observed_admissions, ) # Sampling and fitting model 0 (with no obs for infections) @@ -546,7 +549,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): obs = jnp.hstack( [ jnp.repeat(jnp.nan, 5), - model1_samp.sampled_admissions[5 + gen_int.size() :], + model1_samp.sampled_observed_hosp_admissions[5 + gen_int.size() :], ] ) # Running with padding @@ -554,14 +557,14 @@ def test_model_hosp_with_obs_model_weekday_phosp(): num_warmup=500, num_samples=500, rng_key=jax.random.PRNGKey(272), - observed_admissions=obs, + observed_hosp_admissions=obs, padding=5, ) - inf = model1.spread_draws(["predicted_admissions"]) + inf = model1.spread_draws(["observed_hosp_admissions"]) inf_mean = ( inf.group_by("draw") - .agg(pl.col("predicted_admissions").mean()) + .agg(pl.col("observed_hosp_admissions").mean()) .sort(pl.col("draw")) ) diff --git a/model/src/test/test_observation_negativebinom.py b/model/src/test/test_observation_negativebinom.py index d2b40e61..a90408b5 100644 --- a/model/src/test/test_observation_negativebinom.py +++ b/model/src/test/test_observation_negativebinom.py @@ -17,8 +17,8 @@ def test_negativebinom_deterministic_obs(): np.random.seed(223) rates = np.random.randint(1, 5, size=10) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_pois1 = negb.sample(predicted=rates, obs=rates) - sim_pois2 = negb.sample(predicted=rates, obs=rates) + sim_pois1 = negb.sample(mu=rates, obs=rates) + sim_pois2 = negb.sample(mu=rates, obs=rates) testing.assert_array_equal( sim_pois1, @@ -36,8 +36,8 @@ def test_negativebinom_random_obs(): np.random.seed(223) rates = np.repeat(5, 20000) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_pois1 = negb.sample(predicted=rates) - sim_pois2 = negb.sample(predicted=rates) + sim_pois1 = negb.sample(mu=rates) + sim_pois2 = negb.sample(mu=rates) testing.assert_array_almost_equal( np.mean(sim_pois1), diff --git a/model/src/test/test_observation_poisson.py b/model/src/test/test_observation_poisson.py index f7b5cc39..fee5bbac 100644 --- a/model/src/test/test_observation_poisson.py +++ b/model/src/test/test_observation_poisson.py @@ -18,6 +18,6 @@ def test_poisson_obs(): np.random.seed(223) rates = np.random.randint(1, 5, size=10) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_pois, *_ = pois.sample(predicted=rates) + sim_pois, *_ = pois.sample(mu=rates) testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois))