From 1f1da1c3e4a90e0f6f1ecc91573bf31545903292 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Fri, 30 Aug 2024 19:49:45 -0400 Subject: [PATCH] make quadgk thread-safe (#116) * make quadgk thread-safe * fix signature * whoops --- src/gausskronrod.jl | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/gausskronrod.jl b/src/gausskronrod.jl index 7c09e9a..c2042f3 100644 --- a/src/gausskronrod.jl +++ b/src/gausskronrod.jl @@ -598,13 +598,20 @@ const wgd7 = [1.2948496616886969327061143267787e-01, const rulecache = Dict{Type,Dict}( Float64 => Dict{Int,NTuple{3,Vector{Float64}}}(7 => (xd7,wd7,wgd7)), Float32 => Dict{Int,NTuple{3,Vector{Float32}}}(7 => (xd7,wd7,wgd7))) +const rulecache_lock = ReentrantLock() # thread-safety # for BigFloat rules, we need a separate cache keyed by (n,precision) const bigrulecache = Dict{Tuple{Int,Int}, NTuple{3,Vector{BigFloat}}}() +const bigrulecache_lock = ReentrantLock() # thread-safety function cachedrule(::Union{Type{BigFloat},Type{Complex{BigFloat}}}, n::Integer) key = (Int(n), precision(BigFloat)) - haskey(bigrulecache, key) ? bigrulecache[key] : (bigrulecache[key] = kronrod(BigFloat, Int(n))) + lock(bigrulecache_lock) + try + return haskey(bigrulecache, key) ? bigrulecache[key] : (bigrulecache[key] = kronrod(BigFloat, Int(n))) + finally + unlock(bigrulecache_lock) + end end # use a generated function to make this type-stable @@ -613,5 +620,22 @@ end :(haskey($cache, n) ? $cache[n] : ($cache[n] = kronrod($TF, n))) end -cachedrule(::Type{T}, n::Integer) where {T<:Number} = - _cachedrule(typeof(float(real(one(T)))), Int(n)) +function cachedrule(::Type{T}, n::Integer) where {T<:Number} + lock(rulecache_lock) + try + return _cachedrule(typeof(float(real(one(T)))), Int(n)) + finally + unlock(rulecache_lock) + end +end + +# fast path for common case of Float64 precision and default order +function cachedrule(::Union{Type{Float64},Type{ComplexF64}}, n::Integer) + n == 7 && return (xd7,wd7,wgd7) + lock(rulecache_lock) + try + return _cachedrule(Float64, Int(n)) + finally + unlock(rulecache_lock) + end +end