Ordinary differential equation model with the vector field defined in Zig

Simon Frost (@sdwfrost), 2024-06-03

Introduction

While Julia is a high-level language, it is possible to define the vector field for an ordinary differential equation (ODE) in another language and call it from Julia. This can be useful for performance reasons (if the Julia implementation of the vector field is slow), or if the vector field is already defined in another codebase. Julia's ccall can call a compiled function in a shared library produced by any language that can generate C-compatible shared libraries, such as Zig.

Libraries

using OrdinaryDiffEq
using Libdl
using Plots
using BenchmarkTools

Transitions

We define the vector field in Zig; it is easiest for this function to be in-place, so that we do not have to do any memory management on the Zig side. This approach is also more efficient, as it reduces the number of allocations needed.

Zig_code = """
export fn sir_ode(du: [*c]f64, u: [*c]const f64, p: [*c]const f64, t: [*c]const f64) void {
    const beta: f64 = p[0];
    const c: f64 = p[1];
    const gamma: f64 = p[2];
    const S: f64 = u[0];
    const I: f64 = u[1];
    const R: f64 = u[2];
    const N: f64 = S + I + R;
    _ = t;

    du[0] = -beta * c * S * I / N;
    du[1] = beta * c * S * I / N - gamma * I;
    du[2] = gamma * I;
}
""";
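To illustrate the in-place convention that the Zig function above follows, here is a minimal pure-Julia sketch contrasting an out-of-place and an in-place form of the same vector field. The names `sir_oop`, `sir_ip!`, `u_demo`, `p_demo` and `du_demo` are hypothetical and are not used elsewhere in this document.

```julia
# Out-of-place form: allocates and returns a new vector on every call.
function sir_oop(u, p, t)
    β, c, γ = p
    S, I, R = u
    N = S + I + R
    infection = β * c * S * I / N
    return [-infection, infection - γ * I, γ * I]
end

# In-place form: writes into a caller-supplied buffer, mirroring the Zig signature.
function sir_ip!(du, u, p, t)
    β, c, γ = p
    S, I, R = u
    N = S + I + R
    infection = β * c * S * I / N
    du[1] = -infection
    du[2] = infection - γ * I
    du[3] = γ * I
    return nothing
end

u_demo = [990.0, 10.0, 0.0]   # S, I, R
p_demo = [0.05, 10.0, 0.25]   # β, c, γ
du_demo = similar(u_demo)     # buffer allocated once and reused by the caller
sir_ip!(du_demo, u_demo, p_demo, 0.0)
println(du_demo == sir_oop(u_demo, p_demo, 0.0))
```

With the in-place form, the caller owns the output buffer, so the foreign function never has to allocate or free memory.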

We then compile the code into a shared library.

const Ziglib = tempname();
open(Ziglib * "." * "zig", "w") do f
    write(f, Zig_code)
end
run(`zig build-lib -dynamic -O ReleaseSafe -fPIC -femit-bin=$(Ziglib * "." * Libdl.dlext) $(Ziglib * "." * "zig")`);
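The compile step assumes that the zig executable is available on the PATH. As a hedged sketch, the command can be built first and only reported if the compiler is found. The helper `zig_build_command` is hypothetical; `Sys.which` and `Libdl.dlext` are standard Julia.

```julia
using Libdl

# Hypothetical helper: given a base path, return the source path, the library
# path, and the compile command, without running anything.
function zig_build_command(base::AbstractString)
    src = base * ".zig"
    lib = base * "." * Libdl.dlext   # "so", "dylib" or "dll", depending on the platform
    cmd = `zig build-lib -dynamic -O ReleaseSafe -fPIC -femit-bin=$lib $src`
    return src, lib, cmd
end

src, lib, cmd = zig_build_command(tempname())

if Sys.which("zig") === nothing
    @warn "zig was not found on the PATH; the library cannot be compiled"
else
    println("Compile command: ", cmd)
end
```

Keeping the command separate from `run` makes it possible to inspect the exact arguments before invoking the compiler.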

We can then define the ODE function in Julia, which calls the Zig function using ccall. du, u, and p are arrays of Float64, which ccall passes as pointers to their data. t is a scalar, so it is wrapped in a Ref, which ccall passes as a pointer to a Float64 value.

function sir_ode_jl!(du,u,p,t)
    ccall((:sir_ode,Ziglib,), Cvoid,
          (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}), du, u, p, Ref(t))
end;
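The same two conversions (array to pointer, Ref to pointer) can be tried without compiling anything, by calling the C library's memcpy. This is only a sketch: it assumes that the `memcpy` symbol can be resolved in the running Julia process, and the helpers `c_copy!` and `c_store!` are hypothetical.

```julia
# Copy the contents of src into dst with C's memcpy.
# Both arrays are passed as Ptr{Float64}, as du, u and p are in the wrapper above.
function c_copy!(dst::Vector{Float64}, src::Vector{Float64})
    n = min(length(dst), length(src))
    ccall(:memcpy, Ptr{Cvoid},
          (Ptr{Float64}, Ptr{Float64}, Csize_t),
          dst, src, n * sizeof(Float64))
    return dst
end

# Store a scalar into the first element of dst.
# The scalar is wrapped in a Ref so that it can be passed by pointer, as t is above.
function c_store!(dst::Vector{Float64}, t::Float64)
    ccall(:memcpy, Ptr{Cvoid},
          (Ptr{Float64}, Ptr{Float64}, Csize_t),
          dst, Ref(t), sizeof(Float64))
    return dst
end

println(c_copy!(zeros(3), [1.0, 2.0, 3.0]))
println(c_store!(zeros(2), 2.5))
```

In both helpers the byte count is computed from `sizeof(Float64)`, since memcpy works in bytes rather than elements.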

Time domain and parameters

δt = 0.1
tmax = 40.0
tspan = (0.0,tmax)
u0 = [990.0,10.0,0.0] # S,I,R
p = [0.05,10.0,0.25]; # β,c,γ

Solving the ODE

prob_ode = ODEProblem{true}(sir_ode_jl!, u0, tspan, p)
sol_ode = solve(prob_ode, Tsit5(), dt = δt);

Plotting

plot(sol_ode)

Benchmarking

@benchmark solve(prob_ode, Tsit5(), dt = δt)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  10.458 μs …  3.240 ms  ┊ GC (min … max): 0.00% … 98.67%
 Time  (median):     12.125 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   13.351 μs ± 55.407 μs  ┊ GC (mean ± σ):  7.11% ±  1.71%

         ▁▃▁▅▆█▃▆▆▄▃                                           
  ▁▁▂▂▂▅▆███████████▆▇▇▅▄▄▃▄▃▂▃▂▂▂▂▂▂▂▁▁▂▂▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  10.5 μs         Histogram: frequency by time        17.3 μs <

 Memory estimate: 15.08 KiB, allocs estimate: 173.

We can compare the performance of the Zig-based ODE with the Julia-based ODE.

function sir_ode!(du,u,p,t)
    (S,I,R) = u
    (β,c,γ) = p
    N = S+I+R
    @inbounds begin
        du[1] = -β*c*I/N*S
        du[2] = β*c*I/N*S - γ*I
        du[3] = γ*I
    end
    nothing
end
prob_ode2 = ODEProblem(sir_ode!, u0, tspan, p)
sol_ode2 = solve(prob_ode2, Tsit5(), dt = δt)
@benchmark solve(prob_ode2, Tsit5(), dt = δt)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  10.500 μs …  3.309 ms  ┊ GC (min … max): 0.00% … 98.76%
 Time  (median):     12.209 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   13.455 μs ± 55.732 μs  ┊ GC (mean ± σ):  7.10% ±  1.71%

          ▃ ▆▇██▂▆▅▃▂                                          
  ▁▁▁▂▂▄▆▇███████████▆█▆▆▅▃▄▄▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  10.5 μs         Histogram: frequency by time        17.4 μs <

 Memory estimate: 15.08 KiB, allocs estimate: 173.

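Note that the performance of the Zig-based vector field is similar to that of the one defined in Julia: the median times are about 12.1 μs and 12.2 μs respectively, and the memory estimates are identical (15.08 KiB, 173 allocations). This suggests that the overhead of ccall is negligible for this model.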
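A timing comparison is only meaningful if both vector fields compute the same derivatives. The sketch below evaluates an in-place vector field at the initial state and checks it against values worked out by hand: with N = 1000, the infection term is 0.05 × 10 × 10 / 1000 × 990 = 4.95 and the recovery term is 0.25 × 10 = 2.5. The helper `eval_field` and the names `sir_rhs!`, `u_chk`, `p_chk` and `du_ref` are hypothetical; `sir_rhs!` is a copy of the Julia function above so that the block runs on its own.

```julia
# Copy of the Julia vector field from the text, renamed so the sketch is standalone.
function sir_rhs!(du, u, p, t)
    (S, I, R) = u
    (β, c, γ) = p
    N = S + I + R
    du[1] = -β * c * I / N * S
    du[2] = β * c * I / N * S - γ * I
    du[3] = γ * I
    return nothing
end

# Hypothetical helper: evaluate any in-place vector field f! at (u, p, t).
function eval_field(f!, u, p, t)
    du = similar(u)
    f!(du, u, p, t)
    return du
end

u_chk = [990.0, 10.0, 0.0]   # S, I, R
p_chk = [0.05, 10.0, 0.25]   # β, c, γ
du_ref = eval_field(sir_rhs!, u_chk, p_chk, 0.0)

# Compare with the hand-computed derivatives, and check that the population
# size is conserved (the derivatives sum to zero).
println(isapprox(du_ref, [-4.95, 2.45, 2.5]; atol = 1e-12))
println(abs(sum(du_ref)) < 1e-12)

# Once the shared library has been compiled, the Zig wrapper could be checked
# in the same way (not run here):
# eval_field(sir_ode_jl!, u_chk, p_chk, 0.0) ≈ du_ref
```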