Skip to content

Commit

Permalink
Parallel MCMC in tutorials (#278)
Browse files Browse the repository at this point in the history
set host device count
  • Loading branch information
damonbayer authored Jul 18, 2024
1 parent 54d3ccb commit c66377a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
6 changes: 3 additions & 3 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import DistributionalRV
import pyrenew.transformation as t
npro.set_host_device_count(2)
```

## Architecture of `RtInfectionsRenewalModel`
Expand Down Expand Up @@ -214,9 +216,7 @@ model1.run(
num_samples=1000,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(
progress_bar=False, num_chains=2, chain_method="sequential"
),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
```

Expand Down
16 changes: 10 additions & 6 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ format: gfm
engine: jupyter
---

```{python}
# | label: numpyro setup
# | echo: false
import numpyro as npro
npro.set_host_device_count(2)
```

This document illustrates how a hospital admissions-only model can be fitted using data from the Pyrenew package, particularly the wastewater dataset. The CFA wastewater team created this dataset, which contains simulated data.

## Model definition
Expand Down Expand Up @@ -205,7 +213,6 @@ Let's simulate to check if the model is working:

```{python}
# | label: simulation
import numpyro as npro
import numpy as np
timeframe = 120
Expand Down Expand Up @@ -245,15 +252,12 @@ We can fit the model to the data. We will use the `run` method of the model obje
# | label: model-fit
import jax
npro.set_host_device_count(jax.local_device_count())
hosp_model.run(
num_samples=1000,
num_warmup=1000,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(
progress_bar=False, num_chains=2, chain_method="sequential"
),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
```

Expand Down Expand Up @@ -535,7 +539,7 @@ hosp_model_weekday.run(
num_warmup=2000,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
mcmc_args=dict(progress_bar=False, num_chains=2),
padding=pad_size,
)
```
Expand Down
5 changes: 4 additions & 1 deletion docs/source/tutorials/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from numpyro.handlers import seed
import numpyro.distributions as dist
numpyro.set_host_device_count(2)
```

```{python}
Expand Down Expand Up @@ -180,7 +183,7 @@ hospmodel.run(
num_samples=1000,
data_observed_hosp_admissions=x.observed_hosp_admissions,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
```

Expand Down

0 comments on commit c66377a

Please sign in to comment.