Inference for a deterministic ODE model using Gen.jl

Simon Frost (@sdwfrost), 2024-07-11

Introduction

Gen.jl is a probabilistic programming language that allows for the definition of generative models and the inference of parameters from data. In this notebook, we will use Gen.jl to infer the parameters of an SIR model (as an ordinary differential equation) from simulated data using Importance Sampling (IS), Markov Chain Monte Carlo (MCMC), and Sequential Monte Carlo (SMC). The problem specification is similar to that of the Turing.jl example.

Libraries

using OrdinaryDiffEq
using SciMLSensitivity
using Distributions
using Random
using Gen
using GenDistributions # to define custom distributions
using GenParticleFilters # for SMC
using Plots;

We set a fixed seed for reproducibility.

Random.seed!(1234);

Transitions

We define the ODE system for the SIR model, including a variable for the total number of cases, C, which will be used for inference.

function sir_ode!(du,u,p,t)
        (S,I,R,C) = u
        (β,c,γ) = p
        N = S+I+R
        infection = β*c*I/N*S
        recovery = γ*I
        @inbounds begin
            du[1] = -infection
            du[2] = infection - recovery
            du[3] = recovery
            du[4] = infection
        end
        nothing
end;
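For reference, the system above corresponds to the standard SIR equations with a force of infection βcI/N, plus the cumulative-case variable C:

$$
\frac{dS}{dt} = -\beta c \frac{I}{N} S, \qquad
\frac{dI}{dt} = \beta c \frac{I}{N} S - \gamma I, \qquad
\frac{dR}{dt} = \gamma I, \qquad
\frac{dC}{dt} = \beta c \frac{I}{N} S.
$$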

Model definition

We define a Gen model that simulates the SIR model; to simplify the code, we define a convenience function that takes the output of the ODE model (as an ODESolution) and returns the number of new cases per day.

function cases_from_solution(sol)
    sol_C = Array(sol)[4, :] # Cumulative cases
    sol_X = abs.(sol_C[2:end] - sol_C[1:(end-1)]) # New cases
    return sol_X
end;

The first argument is the number of timesteps for which to generate data, in the form of new cases per day; this convention is important for running Sequential Monte Carlo later. Random variables are declared using ~; for array variables, we use the syntax (:y, i) to declare the address of the variable y at index i. The model function returns the solution of the ODE; this is not needed for inference, but is useful for plotting. The underlying model is a deterministic ODE; we generate random observations by assuming independent Poisson draws for the number of new cases per day.

@gen function sir_ode_model(l::Int=40,
                            N::Float64=1000.0,
                            c::Float64=10.0,
                            γ::Float64=0.25,
                            δt::Float64=1.0)
    i₀ ~ uniform(0.001, 0.1)
    β ~ uniform(0.01, 0.1)
    I = i₀ * N
    u0 = [N - I, I, 0.0, 0.0]
    p = [β, c, γ]
    tspan = (0.0, float(l))
    prob = ODEProblem(sir_ode!, u0, tspan, p)
    sol = solve(prob, Tsit5(), saveat = δt)
    sol_X = cases_from_solution(sol)
    for i in 1:l
        {(:y, i)} ~ poisson(sol_X[i])
    end
    return sol
end;
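Before conditioning on data, the model can be sanity-checked by drawing an unconstrained trace from the prior. The following is a minimal sketch (not part of the inference workflow below); prior_trace is an illustrative name.

# Optional: sample a trace from the prior to check that the model runs
prior_trace = Gen.simulate(sir_ode_model, (40, 1000.0, 10.0, 0.25, 1.0))
prior_trace[:β], prior_trace[:i₀]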

Simulating data

We simulate data from the model by constraining the two stochastic nodes representing the parameters. The following generates 40 timepoints of data. Note the use of (l,fixed_args...); this avoids relying on the default arguments within the function, and splitting the arguments in two will allow the use of SMC later on.

p = Gen.choicemap()
p[:β] = 0.05
p[:i₀] = 0.01
fixed_args = (1000.0, 10.0, 0.25, 1.0)
l = 40
(sol, _) = Gen.generate(sir_ode_model, (l,fixed_args...), p);

We can verify the arguments passed to the model, as well as extract the values of the parameters, as follows.

Gen.get_args(sol)
(40, 1000.0, 10.0, 0.25, 1.0)
sol[:β]
0.05
sol[:i₀]
0.01

The [] operator can be used to extract the return value of the Gen model - in this case, an ODESolution.

ode_sol=sol[]
plot(ode_sol, labels=["S" "I" "R" "C"], xlabel="Time", ylabel="Number", title="Simulated SIR model")

The following plots the solution of the ODE, as well as the simulated data, for the number of new cases.

ts = collect(range(1,l))
Yp = cases_from_solution(ode_sol)
Y = [sol[(:y, i)] for i=1:l]
plot(ts,Yp,label="Solution",xlabel="Time",ylabel="Number")
scatter!(ts,Y,label="Observations")

Inference

In order to perform inference, we constrain the observations (given by the addresses (:y, i)) to the simulated data, Y.

observations = Gen.choicemap()
for (i, y) in enumerate(Y)
    observations[(:y, i)] = y
end;

Importance sampling

For importance sampling, we use Gen.importance_resampling, a function that takes the model, the model arguments, the observations, and the number of particles, and returns a trace along with an estimate of the log marginal likelihood. We can then extract the parameters of interest from the trace, which we store in a Vector.

num_particles = 1000
num_replicates = 1000
β_is = Vector{Real}(undef, num_replicates)
i₀_is = Vector{Real}(undef, num_replicates)
for i in 1:num_replicates
    (trace, lml_est) = Gen.importance_resampling(sir_ode_model, (l,fixed_args...), observations, num_particles)
    β_is[i] = trace[:β]
    i₀_is[i] = trace[:i₀]
end;
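An alternative, sketched below, is a single run of Gen.importance_sampling, which returns all traces together with their log normalized weights and a log marginal likelihood estimate; the posterior can then be summarized using the weights directly rather than by repeated resampling. The names w_is and β_wmean are illustrative.

# Weighted importance sampling summary (sketch)
(traces_is, log_norm_weights, lml_est) = Gen.importance_sampling(sir_ode_model, (l,fixed_args...), observations, num_particles)
w_is = exp.(log_norm_weights)                        # normalized importance weights
β_wmean = sum(w_is .* [tr[:β] for tr in traces_is])  # weighted posterior mean of β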

We can then calculate the mean and standard deviation of the inferred parameters, as well as plot out marginal histograms.

mean(β_is),sqrt(var(β_is))
(0.04959096291583562, 0.001589240733892294)
mean(i₀_is),sqrt(var(i₀_is))
(0.010138383061774704, 0.0019685575134365044)
pl_β_is = histogram(β_is, label=false, title="β", ylabel="Density", density=true, xrotation=45)
vline!([sol[:β]], label="True value")
pl_i₀_is = histogram(i₀_is, label=false, title="i₀", ylabel="Density", density=true, xrotation=45)
vline!([sol[:i₀]], label="True value")
plot(pl_β_is, pl_i₀_is, layout=(1,2), plot_title="Importance sampling")

The parameter estimates are distributed evenly around the true values.

Metropolis-Hastings Markov Chain Monte Carlo

To guide the Metropolis-Hastings algorithm, we define a proposal distribution. To keep the parameter estimates positive, we use truncated normal distributions; these are not defined in Gen.jl, so we define a new distribution type using DistributionsBacked from GenDistributions.jl.

const truncated_normal = DistributionsBacked((mu,std,lb,ub) -> Distributions.Truncated(Normal(mu, std), lb, ub),
                                             (true,true,false,false),
                                             true,
                                             Float64)


@gen function sir_proposal(current_trace)
    β ~ truncated_normal(current_trace[:β], 0.001, 0.0, Inf)
    i₀ ~ truncated_normal(current_trace[:i₀], 0.002, 0.0, Inf)
end;

We can then run the Metropolis-Hastings algorithm, storing the parameter estimates in β_mh and i₀_mh, and the scores in scores. If we wanted to omit the targeted proposal, we could use Gen.mh(tr, select(:β, :i₀)) instead. The use of global in the loop is only necessary when running in a script; ideally, this would be wrapped in a function (a sketch of such a wrapper is given after the loop).

n_iter = 100000
β_mh = Vector{Real}(undef, n_iter)
i₀_mh = Vector{Real}(undef, n_iter)
scores = Vector{Float64}(undef, n_iter)
(tr,) = Gen.generate(sir_ode_model, (l,fixed_args...), merge(observations, p))
n_accept = 0
for i in 1:n_iter
    global (tr, did_accept) = Gen.mh(tr, sir_proposal, ()) # Gen.mh(tr, select(:β, :i₀)) for untargeted
    β_mh[i] = tr[:β]
    i₀_mh[i] = tr[:i₀]
    scores[i] = Gen.get_score(tr)
    if did_accept
        global n_accept += 1
    end
end;
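As noted above, running the sampler at the top level of a script requires global variables; a minimal sketch of wrapping the loop in a function (run_mh is an illustrative name, not part of the original notebook) is as follows.

# Sketch: wrap the MH loop in a function to avoid global variables
function run_mh(tr, n_iter)
    β = Vector{Real}(undef, n_iter)
    i₀ = Vector{Real}(undef, n_iter)
    n_accept = 0
    for i in 1:n_iter
        (tr, did_accept) = Gen.mh(tr, sir_proposal, ())
        β[i] = tr[:β]
        i₀[i] = tr[:i₀]
        n_accept += did_accept
    end
    return β, i₀, n_accept/n_iter
end;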

We aim for about a 30% acceptance rate, which is a good rule of thumb for Metropolis-Hastings; after tweaking the proposal variances above, we arrive at a reasonable acceptance rate.

acceptance_rate = n_accept/n_iter
0.28457
pl_β_mh = histogram(β_mh, label=false, title="β", ylabel="Density", density=true, xrotation=45)
vline!([sol[:β]], label="True value")
pl_i₀_mh = histogram(i₀_mh, label=false, title="i₀", ylabel="Density", density=true, xrotation=45)
vline!([sol[:i₀]], label="True value")
plot(pl_β_mh, pl_i₀_mh, layout=(1,2), plot_title="Metropolis-Hastings")

These too are distributed evenly around the true values.
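The stored scores can also serve as a simple mixing diagnostic; a trace plot of the log joint score over iterations (an optional check, not used elsewhere in this notebook) can reveal an initial transient that could be discarded as burn-in.

plot(scores, label=false, xlabel="Iteration", ylabel="Log score", title="Metropolis-Hastings score trace")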

Sequential Monte Carlo

To run Sequential Monte Carlo, we define a proposal kernel that will be used to rejuvenate the particles. This is a Metropolis-Hastings kernel, which we define as a function kern that takes a trace as an argument and returns a new trace and whether the move was accepted; in this case, we re-use the proposal from the Metropolis-Hastings MCMC above.

kern(tr) = Gen.mh(tr, sir_proposal, ());

GenParticleFilters.jl defines a number of useful functions with which to build a particle filter. We define a function particle_filter that takes the observations, the number of particles, and an optional threshold for the effective sample size. We initialize the particle filter with the first observation, and then iterate across the remaining observations. If the effective sample size falls below the threshold, we resample the particles and rejuvenate them. We then update the filter state with the new observation at timestep t. Internally, Gen.jl re-runs the model with an increasing number of timesteps (hence the use of the number of steps as the first argument of the model), keeping the previously sampled choices and conditioning on the observations up to that timepoint.

function particle_filter(observations, n_particles, ess_thresh=0.5)
    # Initialize particle filter with first observation
    n_obs = length(observations)
    obs_choices = [choicemap((:y, t) => observations[t]) for t=1:n_obs]
    state = pf_initialize(sir_ode_model, (1,fixed_args...), obs_choices[1], n_particles)
    # Iterate across timesteps
    for t=2:n_obs
        # Resample if the effective sample size is too low
        if effective_sample_size(state) < ess_thresh * n_particles
            # Perform residual resampling, pruning low-weight particles
            pf_resample!(state, :residual)
            # Rejuvenate particles
            pf_rejuvenate!(state, kern, ())
        end
        # Update filter state with new observation at timestep t
        # The following code explicitly allows for the number of timesteps to change
        # while keeping the other arguments fixed
        new_args = (t, fixed_args...)
        argdiffs = (UnknownChange(), map(x -> NoChange(), new_args)...)
        pf_update!(state, new_args, argdiffs, obs_choices[t])
    end
    return state
end;
n_particles = 10000
state = particle_filter(Y, n_particles);

The state contains the traces of the particles, which we can use to calculate the effective sample size, as well as the (weighted) mean and standard deviation of the parameter estimates.

effective_sample_size(state)
4889.765243371621
mean(state, :β), sqrt(var(state, :β))
(0.04527937526296185, 0.0131884923267931)
mean(state, :i₀), sqrt(var(state, :i₀))
(0.026999614295607016, 0.027400975963326527)
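The filter also provides an estimate of the log marginal likelihood, which can be useful for comparing models; the following one-liner assumes GenParticleFilters exports get_lml_est.

get_lml_est(state)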

We can extract the parameter estimates and weights from the state as follows.

β_smc = getindex.(state.traces, :β)
i₀_smc = getindex.(state.traces, :i₀)
w = get_norm_weights(state);

As before, the marginal histograms (using the weights w) show that the parameter estimates are distributed evenly around the true values.

pl_β_smc = histogram(β_smc, label=false, title="β", ylabel="Density", density=true, xrotation=45, weights=w, xlim=(0.045,0.055))
vline!([sol[:β]], label="True value")
pl_i₀_smc = histogram(i₀_smc, label=false, title="i₀", ylabel="Density", density=true, xrotation=45, weights=w, xlim=(0.005,0.015))
vline!([sol[:i₀]], label="True value")
plot(pl_β_smc, pl_i₀_smc, layout=(1,2), plot_title="Sequential Monte Carlo")
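As a final check, we can compare the posterior means from the three approaches against the true values used to simulate the data; for SMC, the means are weighted by w. This summary is a sketch using the variables defined above.

# Posterior means from the three approaches (illustrative summary)
println("True: β = ", sol[:β], ", i₀ = ", sol[:i₀])
println("IS:   β = ", mean(β_is), ", i₀ = ", mean(i₀_is))
println("MCMC: β = ", mean(β_mh), ", i₀ = ", mean(i₀_mh))
println("SMC:  β = ", sum(w .* β_smc), ", i₀ = ", sum(w .* i₀_smc))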