Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Usage Of npro In Favor Of Just numpyro #316

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Before we start, let's simulate the model with the original `InfectionsWithFeedb
import jax
import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro
import numpyro.distributions as dist
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import InfectionsWithFeedback
Expand Down Expand Up @@ -90,7 +90,7 @@ And simulate from it:
# | label: simulate1
# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model0_samp = model0.sample(n_datapoints=30)
```

Expand Down Expand Up @@ -222,7 +222,7 @@ class InfFeedback(RandomVariable):
)

# Storing adjusted Rt for future use
npro.deterministic("Rt_adjusted", Rt_adj)
numpyro.deterministic("Rt_adjusted", Rt_adj)

# Preparing theoutput

Expand Down Expand Up @@ -259,7 +259,7 @@ model1 = RtInfectionsRenewalModel(

# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_datapoints=30)
```

Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The `RtPeriodicDiff` and `RtWeeklyDiff` classes use `PeriodicBroadcaster` to rep
# | warning: false
import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro
from pyrenew import process, deterministic
```

Expand All @@ -40,7 +40,7 @@ rt_proc = process.RtWeeklyDiffProcess(
```

```{python}
with npro.handlers.seed(rng_seed=20):
with numpyro.handlers.seed(rng_seed=20):
sim_data = rt_proc(duration=30)

# Plotting the Rt values
Expand Down Expand Up @@ -84,7 +84,7 @@ dayofweek = process.DayOfWeekEffect(
Like before, we can use the `sample` method to generate samples from the day of week effect:

```{python}
with npro.handlers.seed(rng_seed=20):
with numpyro.handlers.seed(rng_seed=20):
sim_data = dayofweek(duration=30)

# Plotting the effect values
Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

import jax.numpy as jnp
import numpyro as npro
import numpyro
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable

Expand Down Expand Up @@ -86,5 +86,5 @@ def sample(
Containing the stored values during construction.
"""
if record:
npro.deterministic(self.name, self.vars)
numpyro.deterministic(self.name, self.vars)
return (self.vars,)
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, NamedTuple

import jax.numpy as jnp
import numpyro as npro
import numpyro
from jax.typing import ArrayLike
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable
Expand Down Expand Up @@ -191,7 +191,7 @@ def sample(
latent_hospital_admissions * self.hosp_report_prob_rv(**kwargs)[0]
)

npro.deterministic(
numpyro.deterministic(
"latent_hospital_admissions", latent_hospital_admissions
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08
import numpyro as npro
import numpyro
from pyrenew.latent.infection_initialization_method import (
InfectionInitializationMethod,
)
Expand Down Expand Up @@ -97,6 +97,6 @@ def sample(self) -> tuple:
"""
(I_pre_init,) = self.I_pre_init_rv()
infection_initialization = self.infection_init_method(I_pre_init)
npro.deterministic(self.name, infection_initialization)
numpyro.deterministic(self.name, infection_initialization)

return (infection_initialization,)
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import NamedTuple

import jax.numpy as jnp
import numpyro as npro
import numpyro
import pyrenew.arrayutils as au
import pyrenew.latent.infection_functions as inf
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -192,7 +192,7 @@ def sample(

# Appending initial infections to the infections

npro.deterministic("Rt_adjusted", Rt_adj)
numpyro.deterministic("Rt_adjusted", Rt_adj)

return InfectionsRtFeedbackSample(
post_initialization_infections=post_initialization_infections,
Expand Down
6 changes: 3 additions & 3 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def posterior_predictive(
Random key for the Predictive function call. Defaults to None.
numpyro_predictive_args : dict, optional
Dictionary of arguments to be passed to the
:class:`numpyro.inference.Predictive` constructor.
:class:`numpyro.infer.Predictive` constructor.
**kwargs
Additional named arguments passed to the
`__call__()` method of :class:`numpyro.infer.Predictive`
Expand Down Expand Up @@ -559,9 +559,9 @@ def prior_predictive(
rng_key : ArrayLike, optional
Random key for the Predictive function call. Defaults to None.
numpyro_predictive_args : dict, optional
Dictionary of arguments to be passed to the numpyro.inference.Predictive constructor.
Dictionary of arguments to be passed to the numpyro.infer.Predictive constructor.
**kwargs
Additional named arguments passed to the `__call__()` method of numpyro.inference.Predictive
Additional named arguments passed to the `__call__()` method of numpyro.infer.Predictive

Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro as npro
import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
from numpy.testing import assert_array_equal
Expand Down Expand Up @@ -51,7 +51,7 @@ def test_forecast():

n_datapoints = 30
n_forecast_points = 10
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model_sample = model.sample(n_datapoints=n_datapoints)

model.run(
Expand Down
4 changes: 2 additions & 2 deletions model/src/test/test_infection_seeding_process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# numpydoc ignore=GL08
import jax.numpy as jnp
import numpyro as npro
import numpyro
import numpyro.distributions as dist
import pytest
from pyrenew.deterministic import DeterministicVariable
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_infection_initialization_process():
)

for model in [zero_pad_model, exp_model, vec_model]:
with npro.handlers.seed(rng_seed=1):
with numpyro.handlers.seed(rng_seed=1):
model()

# Check that the InfectionInitializationProcess class raises an error when the wrong type of I0 is passed
Expand Down
6 changes: 3 additions & 3 deletions model/src/test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro
import pyrenew.latent as latent
from jax.typing import ArrayLike
from numpy.testing import assert_array_almost_equal, assert_array_equal
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_infectionsrtfeedback():

infections = latent.Infections()

with npro.handlers.seed(rng_seed=0):
with numpyro.handlers.seed(rng_seed=0):
samp1 = InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_infectionsrtfeedback_feedback():

infections = latent.Infections()

with npro.handlers.seed(rng_seed=0):
with numpyro.handlers.seed(rng_seed=0):
samp1 = InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
Expand Down
8 changes: 4 additions & 4 deletions model/src/test/test_latent_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np
import numpy.testing as testing
import numpyro as npro
import numpyro
import numpyro.distributions as dist
from pyrenew import transformation as t
from pyrenew.deterministic import DeterministicPMF
Expand Down Expand Up @@ -32,15 +32,15 @@ def test_admissions_sample():
transforms=t.ExpTransform(),
)

with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_rt, *_ = rt(n_steps=30)

gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1])
i0 = 10 * jnp.ones_like(gen_int)

inf1 = Infections()

with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
inf_sampled1 = inf1(Rt=sim_rt, gen_int=gen_int, I0=i0)

# Testing the hospital admissions
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_admissions_sample():
),
)

with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0])

testing.assert_array_less(
Expand Down
8 changes: 4 additions & 4 deletions model/src/test/test_latent_infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np
import numpy.testing as testing
import numpyro as npro
import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
import pytest
Expand All @@ -30,7 +30,7 @@ def test_infections_as_deterministic():
transforms=t.ExpTransform(),
)

with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_rt, *_ = rt(n_steps=30)

gen_int = jnp.array([0.25, 0.25, 0.25, 0.25])
Expand All @@ -42,7 +42,7 @@ def test_infections_as_deterministic():
I0=jnp.zeros(gen_int.size),
gen_int=gen_int,
)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
inf_sampled1 = inf1(**obs)
inf_sampled2 = inf1(**obs)

Expand All @@ -52,7 +52,7 @@ def test_infections_as_deterministic():
)

# Check that Initial infections vector must be at least as long as the generation interval.
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with pytest.raises(ValueError):
obs["I0"] = jnp.array([1])
inf1(**obs)
14 changes: 7 additions & 7 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro as npro
import numpyro
import numpyro.distributions as dist
import polars as pl
import pyrenew.transformation as t
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_model_basicrenewal_no_timepoints_or_observations():
)

np.random.seed(2203)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with pytest.raises(ValueError, match="Either"):
model1.sample(n_datapoints=None, data_observed_infections=None)

Expand Down Expand Up @@ -103,7 +103,7 @@ def test_model_basicrenewal_both_timepoints_and_observations():
)

np.random.seed(2203)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with pytest.raises(ValueError, match="Cannot pass both"):
model1.sample(
n_datapoints=30,
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_model_basicrenewal_no_obs_model():

# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model0_samp = model0.sample(n_datapoints=30)
model0_samp.Rt
model0_samp.latent_infections
Expand All @@ -155,7 +155,7 @@ def test_model_basicrenewal_no_obs_model():
# Generating
model0.infection_obs_process_rv = NullObservation()
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model0.sample(n_datapoints=30)

np.testing.assert_array_equal(model0_samp.Rt, model1_samp.Rt)
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_model_basicrenewal_with_obs_model():

# Sampling and fitting model 1 (with obs infections)
np.random.seed(2203)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_datapoints=30)

model1.run(
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08
# Sampling and fitting model 1 (with obs infections)
np.random.seed(2203)
pad_size = 5
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_datapoints=30, padding=pad_size)

model1.run(
Expand Down
6 changes: 3 additions & 3 deletions model/src/test/test_observation_negativebinom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import numpy.testing as testing
import numpyro as npro
import numpyro
from jax.typing import ArrayLike
from pyrenew.deterministic import DeterministicVariable
from pyrenew.observation import NegativeBinomialObservation
Expand All @@ -21,7 +21,7 @@ def test_negativebinom_deterministic_obs():

np.random.seed(223)
rates = np.random.randint(1, 5, size=10)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_nb1 = negb(mu=rates, obs=rates)
sim_nb2 = negb(mu=rates, obs=rates)

Expand All @@ -48,7 +48,7 @@ def test_negativebinom_random_obs():

np.random.seed(223)
rates = np.repeat(5, 20000)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_nb1 = negb(mu=rates)
sim_nb2 = negb(mu=rates)
assert isinstance(sim_nb1, tuple)
Expand Down
4 changes: 2 additions & 2 deletions model/src/test/test_observation_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np
import numpy.testing as testing
import numpyro as npro
import numpyro
from pyrenew.observation import PoissonObservation


Expand All @@ -17,7 +17,7 @@ def test_poisson_obs():

np.random.seed(223)
rates = np.random.randint(1, 5, size=10)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_pois, *_ = pois(mu=rates)

testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois))
Loading