Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create and use default rt across tests #372

Merged
merged 5 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/deptry.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ jobs:

- name: run deptry
run: |
poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP001=pytest,DEP003=pytest"
poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP001=pytest,DEP003=pytest,DEP001=test"
20 changes: 4 additions & 16 deletions model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# numpydoc ignore=GL08

from test.utils import simple_rt

import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
from numpy.testing import assert_array_equal
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import (
InfectionInitializationProcess,
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess


def test_forecast():
Expand All @@ -30,19 +30,7 @@ def test_forecast():
)
latent_infections = Infections()
observed_infections = PoissonObservation(name="poisson_rv")
rt = TransformedRandomVariable(
name="Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
rt = simple_rt()

model = RtInfectionsRenewalModel(
I0_rv=I0,
Expand Down
20 changes: 4 additions & 16 deletions model/src/test/test_latent_admissions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

from test.utils import simple_rt

import jax.numpy as jnp
import numpy.testing as testing
import numpyro
import numpyro.distributions as dist
from pyrenew import transformation as t
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import HospitalAdmissions, Infections
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.process import SimpleRandomWalkProcess
from pyrenew.metaclass import DistributionalRV


def test_admissions_sample():
Expand All @@ -20,19 +20,7 @@ def test_admissions_sample():

# Generating Rt and Infections to compute the hospital admissions

rt = TransformedRandomVariable(
name="Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
rt = simple_rt()

with numpyro.handlers.seed(rng_seed=223):
sim_rt = rt(n_steps=30)[0].value
Expand Down
20 changes: 3 additions & 17 deletions model/src/test/test_latent_infections.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

from test.utils import simple_rt

import jax.numpy as jnp
import numpy.testing as testing
import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
import pytest
from pyrenew.latent import Infections
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.process import SimpleRandomWalkProcess


def test_infections_as_deterministic():
Expand All @@ -18,19 +16,7 @@ def test_infections_as_deterministic():
the same seed is used.
"""

rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
rt = simple_rt()

with numpyro.handlers.seed(rng_seed=223):
sim_rt, *_ = rt(n_steps=30)
Expand Down
42 changes: 8 additions & 34 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,24 @@
# numpydoc ignore=GL08


from test.utils import simple_rt

import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import numpyro.distributions as dist
import polars as pl
import pyrenew.transformation as t
import pytest
from pyrenew.deterministic import DeterministicPMF, NullObservation
from pyrenew.latent import (
InfectionInitializationProcess,
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess


def get_default_rt():
"""
Helper function to create a default Rt
RandomVariable for this testing session.

Returns
-------
TransformedRandomVariable :
A log-scale random walk with fixed
init value and step size priors
"""
return TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)


def test_model_basicrenewal_no_timepoints_or_observations():
Expand All @@ -65,7 +39,7 @@ def test_model_basicrenewal_no_timepoints_or_observations():

observed_infections = PoissonObservation("poisson_rv")

rt = get_default_rt()
rt = simple_rt()

model1 = RtInfectionsRenewalModel(
I0_rv=I0,
Expand Down Expand Up @@ -96,7 +70,7 @@ def test_model_basicrenewal_both_timepoints_and_observations():

observed_infections = PoissonObservation("possion_rv")

rt = get_default_rt()
rt = simple_rt()

model1 = RtInfectionsRenewalModel(
I0_rv=I0,
Expand Down Expand Up @@ -137,7 +111,7 @@ def test_model_basicrenewal_no_obs_model():

latent_infections = Infections()

rt = get_default_rt()
rt = simple_rt()

model0 = RtInfectionsRenewalModel(
gen_int_rv=gen_int,
Expand Down Expand Up @@ -210,7 +184,7 @@ def test_model_basicrenewal_with_obs_model():

observed_infections = PoissonObservation("poisson_rv")

rt = get_default_rt()
rt = simple_rt()

model1 = RtInfectionsRenewalModel(
I0_rv=I0,
Expand Down Expand Up @@ -259,7 +233,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08

observed_infections = PoissonObservation("poisson_rv")

rt = get_default_rt()
rt = simple_rt()

model1 = RtInfectionsRenewalModel(
I0_rv=I0,
Expand Down
49 changes: 9 additions & 40 deletions model/src/test/test_model_hosp_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# numpydoc ignore=GL08


from test.utils import simple_rt

import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import numpyro.distributions as dist
import polars as pl
import pytest
from pyrenew import transformation as t
from pyrenew.deterministic import (
DeterministicPMF,
DeterministicVariable,
Expand All @@ -21,41 +22,9 @@
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import (
DistributionalRV,
RandomVariable,
SampledValue,
TransformedRandomVariable,
)
from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue
from pyrenew.model import HospitalAdmissionsModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess


def get_default_rt():
"""
Helper function to create a default Rt
RandomVariable for this testing session.

Returns
-------
TransformedRandomVariable :
A log-scale random walk with fixed
init value and step size priors
"""
return TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)


class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08
Expand Down Expand Up @@ -91,7 +60,7 @@ def test_model_hosp_no_timepoints_or_observations():
I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1))

latent_infections = Infections()
Rt_process = get_default_rt()
Rt_process = simple_rt()

observed_admissions = PoissonObservation("poisson_rv")

Expand Down Expand Up @@ -156,7 +125,7 @@ def test_model_hosp_both_timepoints_and_observations():
I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1))

latent_infections = Infections()
Rt_process = get_default_rt()
Rt_process = simple_rt()

observed_admissions = PoissonObservation("poisson_rv")

Expand Down Expand Up @@ -229,7 +198,7 @@ def test_model_hosp_no_obs_model():
)

latent_infections = Infections()
Rt_process = get_default_rt()
Rt_process = simple_rt()

inf_hosp = DeterministicPMF(
name="inf_hosp",
Expand Down Expand Up @@ -339,7 +308,7 @@ def test_model_hosp_with_obs_model():
)

latent_infections = Infections()
Rt_process = get_default_rt()
Rt_process = simple_rt()
observed_admissions = PoissonObservation("poisson_rv")

inf_hosp = DeterministicPMF(
Expand Down Expand Up @@ -426,7 +395,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2():
)

latent_infections = Infections()
Rt_process = get_default_rt()
Rt_process = simple_rt()
observed_admissions = PoissonObservation("poisson_rv")

inf_hosp = DeterministicPMF(
Expand Down Expand Up @@ -525,7 +494,7 @@ def test_model_hosp_with_obs_model_weekday_phosp():
)

latent_infections = Infections()
Rt_process = get_default_rt()
Rt_process = simple_rt()

observed_admissions = PoissonObservation("poisson_rv")

Expand Down
18 changes: 4 additions & 14 deletions model/src/test/test_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
Ensures that posterior predictive samples are not generated when no posterior samples are available.
"""

from test.utils import simple_rt

import jax.numpy as jnp
import numpyro.distributions as dist
import pyrenew.transformation as t
import pytest
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import (
InfectionInitializationProcess,
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess

pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45])
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
Expand All @@ -29,17 +29,7 @@
)
latent_infections = Infections()
observed_infections = PoissonObservation("poisson_rv")
rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
),
transforms=t.ExpTransform(),
)
rt = simple_rt()

model = RtInfectionsRenewalModel(
I0_rv=I0,
Expand Down
Loading