Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving model to top level #388

Merged
merged 9 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/container_image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ jobs:
make image-build
podman tag pyrenew:latest ghcr.io/cdcgov/pyrenew:latest
podman push ghcr.io/cdcgov/pyrenew:latest
working-directory: model
2 changes: 1 addition & 1 deletion .github/workflows/test_model.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
run: |
poetry run pytest \
--mpl --mpl-default-tolerance=10 \
--cov=pyrenew --cov-report term --cov-report xml model
--cov=pyrenew --cov-report term --cov-report xml .

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
Expand Down
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ repos:
"--baseline",
".secrets.baseline",
"--exclude-files",
"model/docs/*_cache",
]
exclude: package.lock.json
####
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
60 changes: 51 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,21 +1,63 @@
# Multisignal Renewal Project
# PyRenew: A Package for Bayesian Renewal Modeling with JAX and NumPyro.

⚠️ This is a work in progress ⚠️

[![Pre-commit](https://github.com/CDCgov/multisignal-epi-inference/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/CDCgov/multisignal-epi-inference/actions/workflows/pre-commit.yaml)
[![installation and testing model](https://github.com/CDCgov/multisignal-epi-inference/actions/workflows/model.yaml/badge.svg)](https://github.com/CDCgov/multisignal-epi-inference/actions/workflows/model.yaml)
[![installation and testing pipeline](https://github.com/CDCgov/multisignal-epi-inference/actions/workflows/pipeline.yaml/badge.svg)](https://github.com/CDCgov/multisignal-epi-inference/actions/workflows/pipeline.yaml)
[![Docs: model](https://github.com/CDCgov/multisignal-epi-inference/actions/workflows/website.yaml/badge.svg)](https://github.com/CDCgov/multisignal-epi-inference/actions/workflows/website.yaml)
[![codecov (model)](https://codecov.io/gh/CDCgov/multisignal-epi-inference/graph/badge.svg?token=7Z06HOMYR1)](https://codecov.io/gh/CDCgov/multisignal-epi-inference)
`pyrenew` is a flexible tool for simulation and statistical inference of epidemiological models, emphasizing renewal models. Built on top of the [`numpyro`](https://num.pyro.ai/) Python library, `pyrenew` provides core components for model building, including pre-defined models for processing various types of observational processes. To start, visit the tutorials section on the project's website [here](https://cdcgov.github.io/multisignal-epi-inference/tutorials/index.html).

## Overview
The following diagram illustrates the composition of the `HospitalAdmissionsModel` class. Notably, all components are modular and can be replaced with custom implementations.

The **Multisignal Renewal Project** aims to develop a modeling framework that leverages multiple data sources to enhance CDC's epidemiological modeling capabilities. The project's goal is twofold: (a) **create a Python library** that provides a flexible renewal modeling framework and (b) **develop a pipeline** that leverages this framework to estimate epidemiological parameters from multiple data sources and produce forecasts. The library and pipeline are located in the [**model/**](https://github.com/CDCgov/multisignal-epi-inference/tree/main/model) and [**pipeline/**](https://github.com/CDCgov/multisignal-epi-inference/tree/main/pipeline/) directories of the GitHub repository, respectively.
```mermaid
flowchart LR

%% Elements
rt_proc["Random Walk Rt\nProcess (latent)"];
latent_inf["Latent Infections"]
latent_ihr["Infection to Hosp.\nrate (latent)"]
neg_binom["Observation process\n(hospitalizations)"]
latent_hosp["Latent Hospitalizations"];
i0["Initial infections\n(latent)"];
gen_int["Generation\ninterval (fixed)"];
hosp_int["Hospitalization\ninterval (fixed)"];

%% Models
basic_model(("Infections\nModel"));
admin_model(("Hospital Admissions\nModel"));

%% Latent infections
rt_proc --> latent_inf;
i0 --> latent_inf;
gen_int --> latent_inf;
latent_inf --> basic_model

%% Hospitalizations
hosp_int --> latent_hosp

neg_binom --> admin_model;
latent_ihr --> latent_hosp;
basic_model --> admin_model;
latent_hosp --> admin_model;
```

## Installation

Install via pip with

```bash
pip install git+https://github.com/CDCgov/multisignal-epi-inference@main
```

## Container image

A container image is available at `ghcr.io/CDCgov/pyrenew:latest`. You can pull it with

```bash
docker pull ghcr.io/CDCgov/pyrenew:latest
```

## Resources

* [The MSR Website](https://cdcgov.github.io/multisignal-epi-inference/tutorials/index.html) provides general documentation and tutorials on using MSR.
* [The Model Equations Sheet](https://github.com/CDCgov/multisignal-epi-inference/blob/main/model/equations.md) describe the mathematics of the renewal processes and models MSR supports.
* [The Model Equations Sheet](https://github.com/CDCgov/multisignal-epi-inference/blob/main/equations.md) describe the mathematics of the renewal processes and models MSR supports.
* Additional reading on renewal processes in epidemiology
* [_Semi-mechanistic Bayesian modelling of COVID-19 with renewal processes_](https://academic.oup.com/jrsssa/article-pdf/186/4/601/54770289/qnad030.pdf)
* [_Unifying incidence and prevalence under a time-varying general branching process_](https://link.springer.com/content/pdf/10.1007/s00285-023-01958-w.pdf)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import sys

sys.path.insert(0, os.path.abspath("../../model/src"))
sys.path.insert(0, os.path.abspath("../../src"))


# -- Project information -----------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorials/index.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Tutorials
=========

This section contains tutorials that demonstrate how to use the `pyrenew` package. The source code for the tutorials can be found in the project repository: https://github.com/CDCgov/multisignal-epi-inference/tree/main/model/docs/.
This section contains tutorials that demonstrate how to use the `pyrenew` package. The source code for the tutorials can be found in the project repository: https://github.com/CDCgov/multisignal-epi-inference/tree/main/docs/source/tutorials.

.. toctree::
:maxdepth: 2
Expand Down
File renamed without changes.
2 changes: 0 additions & 2 deletions model/.gitignore

This file was deleted.

53 changes: 0 additions & 53 deletions model/README.md

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ authors = ["CFA"]
license = "Apache 2.0"
readme = "README.md"
packages = [
{include = "pyrenew", from = "model/src"},
{include = "pyrenew", from = "src"},
]
include = [{path = "datasets/*.tsv"}]
exclude = [{path = "datasets/*.rds"}]
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike

from pyrenew.metaclass import RandomVariable, SampledValue


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

from jax.typing import ArrayLike

from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.distutil import validate_discrete_dist_vector
from pyrenew.metaclass import RandomVariable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from jax.typing import ArrayLike

from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.metaclass import SampledValue

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# numpydoc ignore=GL08

import jax.numpy as jnp

from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.metaclass import SampledValue

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike

from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable, SampledValue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from pyrenew.convolve import new_convolve_scanner, new_double_convolve_scanner
from pyrenew.transformation import ExpTransform, IdentityTransform

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax.numpy as jnp
from jax.typing import ArrayLike

from pyrenew.metaclass import RandomVariable


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08
import numpyro

from pyrenew.latent.infection_initialization_method import (
InfectionInitializationMethod,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from typing import NamedTuple

import jax.numpy as jnp
import pyrenew.latent.infection_functions as inf
from jax.typing import ArrayLike

import pyrenew.latent.infection_functions as inf
from pyrenew.metaclass import RandomVariable, SampledValue


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import NamedTuple

import jax.numpy as jnp
from numpy.typing import ArrayLike

import pyrenew.arrayutils as au
import pyrenew.latent.infection_functions as inf
from numpy.typing import ArrayLike
from pyrenew.metaclass import (
RandomVariable,
SampledValue,
Expand Down
1 change: 1 addition & 0 deletions model/src/pyrenew/math.py → src/pyrenew/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import jax.numpy as jnp
from jax.typing import ArrayLike

from pyrenew.distutil import validate_discrete_dist_vector


Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from jax.typing import ArrayLike
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import Reparam

from pyrenew.mcmcutils import plot_posterior, spread_draws
from pyrenew.transformation import Transform

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import NamedTuple

from jax.typing import ArrayLike

from pyrenew.deterministic import NullObservation
from pyrenew.metaclass import (
Model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

import jax.numpy as jnp
import numpyro
import pyrenew.arrayutils as au
from numpy.typing import ArrayLike

import pyrenew.arrayutils as au
from pyrenew.deterministic import NullObservation
from pyrenew.metaclass import (
Model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike

from pyrenew.metaclass import RandomVariable, SampledValue


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike

from pyrenew.metaclass import RandomVariable, SampledValue


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpyro.distributions as dist
from jax import lax
from jax.typing import ArrayLike

from pyrenew.metaclass import RandomVariable, SampledValue


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import jax.numpy as jnp
from jax.typing import ArrayLike

from pyrenew.metaclass import RandomVariable, SampledValue
from pyrenew.process import ARProcess

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import NamedTuple

import jax.numpy as jnp
import pyrenew.arrayutils as au
from jax.typing import ArrayLike

import pyrenew.arrayutils as au
from pyrenew.metaclass import (
RandomVariable,
SampledValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax.numpy as jnp
from numpyro.contrib.control_flow import scan

from pyrenew.metaclass import RandomVariable, SampledValue


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
from jax.typing import ArrayLike

import pyrenew.transformation as t
from pyrenew.metaclass import SampledValue


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from numpyro.distributions.transforms import (
__all__ as numpyro_public_transforms,
)

from pyrenew.transformation.builtin import ScaledLogitTransform

__all__ = ["ScaledLogitTransform"] + numpyro_public_transforms
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax.numpy as jnp
import numpyro
from numpy.testing import assert_almost_equal

from pyrenew.process import ARProcess


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
"""

import jax.numpy as jnp
import pyrenew.arrayutils as au
import pytest

import pyrenew.arrayutils as au


def test_arrayutils_pad_to_match():
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpyro.distributions as dist
import pytest
from numpy.testing import assert_equal

from pyrenew.deterministic import DeterministicVariable, NullObservation
from pyrenew.metaclass import (
DistributionalRV,
Expand Down
Loading