Simon Frost (@sdwfrost), 2024-06-03
While Julia is a high-level language, it is possible to define the vector field for an ordinary differential equation (ODE) in another language and call it from Julia. This can be useful for performance reasons (if the calculation of the vector field in Julia happens to be slow for some reason), or if the vector field is already defined, for example, in another codebase. Julia's ccall makes it easy to call a compiled function in a shared library created by any language that can generate C-compatible shared libraries, such as Zig.
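As a minimal illustration of the ccall syntax, the sketch below calls strlen from the C standard library. The syntax is: the function name (optionally paired with a library), the return type, a tuple of argument types, and then the arguments. It assumes a Unix-like system, where libc symbols are already visible in the Julia process, so no library name is needed.

```julia
# Call strlen from the C standard library. On Linux and macOS, libc is already
# loaded into the Julia process, so the function can be looked up by name alone.
# C signature: size_t strlen(const char *s);
n = ccall(:strlen, Csize_t, (Cstring,), "hello")
println(n)  # 5
```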
using OrdinaryDiffEq
using Libdl
using Plots
using BenchmarkTools
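We define the vector field in Zig. It is easiest for this function to be in-place, writing the derivatives into an array du that has already been allocated by Julia, so that we do not have to do any memory management on the Zig side. This approach is also more efficient, as it avoids allocating a new output array every time the solver evaluates the vector field.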
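The allocation argument can be checked in pure Julia. The sketch below compares an out-of-place version of the SIR vector field, which returns a new vector on each call, with an in-place version that fills a preallocated du. The names sir_oop, sir_ip! and allocation_per_call are hypothetical and used only for illustration. The in-place call is expected to allocate nothing once compiled.

```julia
# Hypothetical out-of-place version: builds and returns a new vector on every call.
function sir_oop(u, p, t)
    S, I, R = u
    β, c, γ = p
    N = S + I + R
    return [-β*c*S*I/N, β*c*S*I/N - γ*I, γ*I]
end

# Hypothetical in-place version: overwrites the caller's du instead.
function sir_ip!(du, u, p, t)
    S, I, R = u
    β, c, γ = p
    N = S + I + R
    du[1] = -β*c*S*I/N
    du[2] = β*c*S*I/N - γ*I
    du[3] = γ*I
    return nothing
end

# Bytes allocated by one call of each version, measured after a warm-up call
# so that compilation is not counted.
function allocation_per_call()
    u = [990.0, 10.0, 0.0]
    p = [0.05, 10.0, 0.25]
    du = zeros(3)
    sir_oop(u, p, 0.0)
    sir_ip!(du, u, p, 0.0)
    oop = @allocated sir_oop(u, p, 0.0)
    ip = @allocated sir_ip!(du, u, p, 0.0)
    return oop, ip
end

oop_bytes, ip_bytes = allocation_per_call()
println("out-of-place: $oop_bytes bytes, in-place: $ip_bytes bytes")
```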
Zig_code = """
// export gives the function C linkage; [*c] types are C pointers.
export fn sir_ode(du: [*c]f64, u: [*c]const f64, p: [*c]const f64, t: [*c]const f64) void {
    const beta: f64 = p[0];
    const c: f64 = p[1];
    const gamma: f64 = p[2];
    const S: f64 = u[0];
    const I: f64 = u[1];
    const R: f64 = u[2];
    const N: f64 = S + I + R;
    // t is not used (the model is autonomous); Zig rejects unused parameters,
    // so it is discarded explicitly.
    _ = t;
    du[0] = -beta * c * S * I / N;
    du[1] = beta * c * S * I / N - gamma * I;
    du[2] = gamma * I;
}
""";
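For reference, the Zig function computes the right-hand side of the SIR model below, with parameters passed in the order p = [β, c, γ] and state u = [S, I, R]:

```latex
\begin{aligned}
\frac{dS}{dt} &= -\frac{\beta c S I}{N},\\
\frac{dI}{dt} &= \frac{\beta c S I}{N} - \gamma I,\\
\frac{dR}{dt} &= \gamma I, \qquad N = S + I + R.
\end{aligned}
```

The three derivatives sum to zero, so the total population N stays constant along a solution.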
We then write the code to a temporary file and compile it into a shared library. This step requires the zig compiler to be available on the PATH.
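# Ziglib is a temporary path without an extension
const Ziglib = tempname();
# Write the Zig source to Ziglib.zig
open(Ziglib * "." * "zig", "w") do f
    write(f, Zig_code)
end
# Build a dynamic library (e.g. Ziglib.so on Linux) from the source file
run(`zig build-lib -dynamic -O ReleaseSafe -fPIC -femit-bin=$(Ziglib * "." * Libdl.dlext) $(Ziglib * "." * "zig")`);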
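Before calling into the library, it can be useful to confirm that the build produced a loadable library that exports the expected symbol. The helper below is a hypothetical sketch built on the Libdl standard library (dlopen, dlsym and dlclose, with throw_error=false so that failures return nothing instead of raising).

```julia
using Libdl

# Hypothetical helper: true if the shared library at libpath can be opened
# and exports the symbol sym, false otherwise.
function has_symbol(libpath::AbstractString, sym::Symbol)
    handle = Libdl.dlopen(libpath; throw_error=false)
    handle === nothing && return false
    try
        return Libdl.dlsym(handle, sym; throw_error=false) !== nothing
    finally
        Libdl.dlclose(handle)
    end
end

println(has_symbol("/nonexistent/libmissing.so", :sir_ode))  # false
```

After the build step, has_symbol(Ziglib * "." * Libdl.dlext, :sir_ode) should return true if the Zig function was exported.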
We can then define the ODE function in Julia, which calls the Zig function using ccall. du, u and p are arrays of Float64, which are passed as pointers to their data. t is a scalar, so it is wrapped in a Ref and passed as a pointer to a Float64 value.
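The same calling pattern can be tried without a Zig compiler. This sketch uses a pure-Julia function with the same C signature, turns it into a C function pointer with @cfunction, and then calls that pointer with ccall exactly as sir_ode_jl! does. The names sir_c, sir_c_ptr and sir_via_ptr! are hypothetical stand-ins, not part of the original code.

```julia
# Hypothetical pure-Julia stand-in with the C signature of the Zig sir_ode:
# void sir_ode(double *du, const double *u, const double *p, const double *t)
function sir_c(du::Ptr{Float64}, u::Ptr{Float64}, p::Ptr{Float64}, t::Ptr{Float64})
    β, c, γ = unsafe_load(p, 1), unsafe_load(p, 2), unsafe_load(p, 3)
    S, I, R = unsafe_load(u, 1), unsafe_load(u, 2), unsafe_load(u, 3)
    N = S + I + R
    unsafe_store!(du, -β*c*S*I/N, 1)
    unsafe_store!(du, β*c*S*I/N - γ*I, 2)
    unsafe_store!(du, γ*I, 3)
    return nothing
end

# C-callable function pointer to sir_c
const sir_c_ptr = @cfunction(sir_c, Cvoid, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}))

# Same argument handling as sir_ode_jl!: arrays are converted to pointers to
# their data, and Ref(t) lets the scalar t be passed as a Ptr{Float64}.
function sir_via_ptr!(du, u, p, t)
    ccall(sir_c_ptr, Cvoid,
          (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
          du, u, p, Ref(t))
    return nothing
end

du = zeros(3)
sir_via_ptr!(du, [990.0, 10.0, 0.0], [0.05, 10.0, 0.25], 0.0)
println(du)
```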
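function sir_ode_jl!(du,u,p,t)
    # Call sir_ode from the Zig library (the platform extension is appended to
    # Ziglib when the library is loaded). Zig's void maps to Cvoid; du, u and p
    # are passed as pointers, and Ref(t) passes the scalar t by pointer.
    ccall((:sir_ode,Ziglib,), Cvoid,
          (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}), du, u, p, Ref(t))
end;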
δt = 0.1 # initial step size; Tsit5 adapts the step size as it integrates
tmax = 40.0
tspan = (0.0,tmax)
u0 = [990.0,10.0,0.0] # S,I,R
p = [0.05,10.0,0.25]; # β,c,γ
prob_ode = ODEProblem{true}(sir_ode_jl!, u0, tspan, p) # {true}: in-place vector field
sol_ode = solve(prob_ode, Tsit5(), dt = δt);
plot(sol_ode)
@benchmark solve(prob_ode, Tsit5(), dt = δt)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  10.458 μs …  3.240 ms  ┊ GC (min … max): 0.00% … 98.67%
 Time  (median):     12.125 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   13.351 μs ± 55.407 μs  ┊ GC (mean ± σ):  7.11% ±  1.71%

        ▁▃▁▅▆█▃▆▆▄▃
  ▁▁▂▂▂▅▆███████████▆▇▇▅▄▄▃▄▃▂▃▂▂▂▂▂▂▂▁▁▂▂▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  10.5 μs         Histogram: frequency by time         17.3 μs <

 Memory estimate: 15.08 KiB, allocs estimate: 173.
For comparison, we define the same vector field in pure Julia and benchmark it in the same way.
function sir_ode!(du,u,p,t)
(S,I,R) = u
(β,c,γ) = p
N = S+I+R
@inbounds begin
du[1] = -β*c*I/N*S
du[2] = β*c*I/N*S - γ*I
du[3] = γ*I
end
nothing
end
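A timing comparison is only meaningful if both vector fields compute the same thing. The hypothetical helper below evaluates two in-place vector fields at the same (u, p, t) and returns the largest absolute difference between their outputs. For the two SIR implementations, max_field_difference(sir_ode_jl!, sir_ode!, u0, p, 0.0) should be at or near zero. It need not be exactly zero, because the Zig and Julia expressions multiply and divide in a different order, which can change the rounding.

```julia
# Hypothetical helper: largest absolute difference between the derivatives
# computed by two in-place vector fields f! and g! at the same (u, p, t).
function max_field_difference(f!, g!, u, p, t)
    du_f = similar(u)
    du_g = similar(u)
    f!(du_f, u, p, t)
    g!(du_g, u, p, t)
    return maximum(abs.(du_f .- du_g))
end
```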
prob_ode2 = ODEProblem(sir_ode!, u0, tspan, p)
sol_ode2 = solve(prob_ode2, Tsit5(), dt = δt)
@benchmark solve(prob_ode2, Tsit5(), dt = δt)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  10.500 μs …  3.309 ms  ┊ GC (min … max): 0.00% … 98.76%
 Time  (median):     12.209 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   13.455 μs ± 55.732 μs  ┊ GC (mean ± σ):  7.10% ±  1.71%

         ▃ ▆▇██▂▆▅▃▂
  ▁▁▁▂▂▄▆▇███████████▆█▆▆▅▃▄▄▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  10.5 μs         Histogram: frequency by time         17.4 μs <

 Memory estimate: 15.08 KiB, allocs estimate: 173.
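Note that the performance of the Zig-based vector field is essentially the same as that of the one defined in Julia. The median run times (12.125 μs versus 12.209 μs) differ by less than 1%, and both runs report the same memory estimate (15.08 KiB, 173 allocations). This suggests that the allocations come from the solver rather than from the vector field, and that the cost of calling across the ccall boundary is negligible in this example.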