Skip to content

Commit

Permalink
simplify batching to evaluate whole segments
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Jul 18, 2023
1 parent 1d6d887 commit 8e498bf
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 421 deletions.
19 changes: 1 addition & 18 deletions src/QuadGK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,10 @@ end
# Convenience constructor: wrap the in-place integrand `f!`, the caller's result array `I`
# and a prototype integrand value `fx` in an `InplaceIntegrand`. Four scratch arrays are
# allocated with `similar(fx)` and one with `similar(I)`; `fx` and `I` themselves are
# handed to the inner constructor as passed.
function InplaceIntegrand(f!::F, I::RI, fx::R) where {F,RI,R}
    scratch1 = similar(fx)
    scratch2 = similar(fx)
    scratch3 = similar(fx)
    scratch4 = similar(fx)
    Iscratch = similar(I)
    return InplaceIntegrand{F,R,RI}(
        f!, scratch1, scratch2, scratch3, scratch4, fx, Iscratch, I
    )
end

# Wrapper marking an integrand that is evaluated on whole batches of points.
#
# Fields:
# - `f!`: in-place function `f!(y, x)` that takes an array of `x` values and outputs an
#   array of results in-place.
# - `y`, `x`: vectors stored alongside `f!` (element types `Y` and `X`).
# - `max_batch`: maximum number of `x` to supply in parallel (the keyword constructor
#   defaults it to `typemax(Int)`).
#
# The inner constructor accepts any `Integer` for `max_batch`; it is converted to `Int`
# when stored. An `ArgumentError` is thrown unless `max_batch` is positive.
struct BatchIntegrand{F,Y,X}
    f!::F
    y::Vector{Y}
    x::Vector{X}
    max_batch::Int
    function BatchIntegrand{F,Y,X}(
        f!::F, y::Vector{Y}, x::Vector{X}, max_batch::Integer
    ) where {F,Y,X}
        max_batch > 0 || throw(ArgumentError("maximum batch size must be positive"))
        # `new` converts `max_batch` to the declared `Int` field type
        return new{F,Y,X}(f!, y, x, max_batch)
    end
end

# Keyword constructor: infer the type parameters from `f!` and the two buffers, with
# `max_batch` defaulting to `typemax(Int)`.
function BatchIntegrand(
    f!::F, y::Vector{Y}, x::Vector{X}; max_batch::Integer=typemax(Int)
) where {F,Y,X}
    return BatchIntegrand{F,Y,X}(f!, y, x, max_batch)
end

# Type-based constructor: start from empty buffers of element types `Y` and `X`
# (`X` defaults to `Nothing`); keyword arguments are forwarded unchanged.
function BatchIntegrand(f!::F, ::Type{Y}, ::Type{X}=Nothing; kwargs...) where {F,Y,X}
    ybuf = Vector{Y}(undef, 0)
    xbuf = Vector{X}(undef, 0)
    return BatchIntegrand(f!, ybuf, xbuf; kwargs...)
end

include("gausskronrod.jl")
include("evalrule.jl")
include("evalsegs.jl")
include("adapt.jl")
include("weightedgauss.jl")
include("batch.jl")

end # module QuadGK
156 changes: 27 additions & 129 deletions src/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ function do_quadgk(f::F, s::NTuple{N,T}, n, atol, rtol, maxevals, nrm, segbuf) w

@assert N 2
if f isa BatchIntegrand
segs = evalsegs(f, ntuple(i -> (s[i], s[i+1]), Val{N-1}()), x,w,gw, nrm)
f.max_batch < (N-1)*(4n+2) && throw(ArgumentError("Batch buffer can't fit points"))
segs = evalrules(f, s, x,w,gw, nrm)
else
segs = ntuple(i -> evalrule(f, s[i],s[i+1], x,w,gw, nrm), Val{N-1}())
end
Expand Down Expand Up @@ -38,132 +39,66 @@ function do_quadgk(f::F, s::NTuple{N,T}, n, atol, rtol, maxevals, nrm, segbuf) w

segheap = segbuf === nothing ? collect(segs) : (resize!(segbuf, N-1) .= segs)
heapify!(segheap, Reverse)
new_segs = adapt(f, segheap, I, E, numevals, x,w,gw,n, atol_, rtol_, maxevals, nrm)
return resum(f, new_segs)
return resum(f, adapt(f, segheap, I, E, numevals, x,w,gw,n, atol_, rtol_, maxevals, nrm))
end

# internal routine to perform the h-adaptive refinement of the integration segments (segs)
function adapt(f::F, segs::Vector{T}, I, E, numevals, x,w,gw,n, atol, rtol, maxevals, nrm) where {F, T}
# Pop the biggest-error segment and subdivide (h-adaptation)
# until convergence is achieved or maxevals is exceeded.
while (tol = max(atol, rtol*nrm(I))) < E && numevals < maxevals
next = bisect(f, segs, I,E, numevals, tol, x,w,gw,n, atol, rtol, maxevals, nrm)
next isa Vector && return next
while E > atol && E > rtol * nrm(I) && numevals < maxevals
next = refine(f, segs, I, E, numevals, x,w,gw,n, atol, rtol, maxevals, nrm)
next isa Vector && return next # handle type-unstable functions
I, E, numevals = next
end
return segs
end

# Re-sum the integral and error estimates over all segments from scratch (paranoia about
# accumulated roundoff). For an `InplaceIntegrand` the sum is accumulated directly into
# `f.I`; otherwise a fresh value is built with `+`. Returns the tuple `(I, E)`.
function resum(f, segs)
    head = segs[1]
    if f isa InplaceIntegrand
        # write into the integrand's own result array
        I = f.I .= head.I
        E = head.E
        for k in 2:length(segs)
            seg = segs[k]
            I .+= seg.I
            E += seg.E
        end
        return (I, E)
    end
    I = head.I
    E = head.E
    for k in 2:length(segs)
        seg = segs[k]
        I += seg.I
        E += seg.E
    end
    return (I, E)
end

# internal routine to bisect the segments whose combined error surpasses the tolerance
function bisect(f::F, segs::Vector{T}, I,E, numevals, tol, x,w,gw,n, atol, rtol, maxevals, nrm) where {F, T}
# internal routine to refine the segment with largest error
function refine(f::F, segs::Vector{T}, I, E, numevals, x,w,gw,n, atol, rtol, maxevals, nrm) where {F, T}
s = heappop!(segs, Reverse)
mid = (s.a + s.b) / 2
s1 = evalrule(f, s.a, mid, x,w,gw, nrm)
s2 = evalrule(f, mid, s.b, x,w,gw, nrm)
numevals += 4n+2

if f isa InplaceIntegrand
I .= (I .- s.I) .+ s1.I .+ s2.I
else
I = (I - s.I) + s1.I + s2.I
end
E = (E - s.E) + s1.E + s2.E
numevals += 4n+2

# handle type-unstable functions by converting to a wider type if needed
Tj = promote_type(typeof(s1), promote_type(typeof(s2), T))
if Tj !== T
new_segs = Vector{Tj}(segs)
heappush!(heappush!(new_segs, s1, Reverse), s2, Reverse)
return adapt(f, new_segs, I, E, numevals, x,w,gw,n, atol, rtol, maxevals, nrm)
end

# continue bisecting if the remaining error surpasses current tolerance
if (tol += s1.E + s2.E) < E
next = bisect(f, segs, I,E, numevals, tol, x,w,gw,n, atol, rtol, maxevals, nrm)
if next isa Vector
heappush!(next, s1, Reverse)
heappush!(next, s2, Reverse)
return next
end
I, E, numevals = next
return adapt(f, heappush!(heappush!(Vector{Tj}(segs), s1, Reverse), s2, Reverse),
I, E, numevals, x,w,gw,n, atol, rtol, maxevals, nrm)
end

# add to heap after bisection since otherwise the relative tolerance wouldn't work
heappush!(segs, s1, Reverse)
heappush!(segs, s2, Reverse)

return I, E, numevals
end

# it would be nice to not rely on recursion, since this will lead to additional code
# generation, however the alternative is to give evalsegs an iterator instead of a tuple,
# and that would break since it is assumed the iterator is stateless and has known length,
# both of which are not true for the heap. There may be workarounds but this is easier
function bisect(f::BatchIntegrand{F}, segs::Vector{T}, I,E, numevals, tol, x,w,gw,n, atol, rtol, maxevals, nrm, ss::T...) where {F,T}
if ss isa Tuple{}
# special case this for inference
if tol < E
s = heappop!(segs, Reverse)
return bisect(f, segs, I,E, numevals, tol + s.E, x,w,gw,n, atol, rtol, maxevals, nrm, s, ss...)
else
throw(ArgumentError("no segments to bisect"))
# re-sum (paranoia about accumulated roundoff)
function resum(f, segs)
if f isa InplaceIntegrand
I = f.I .= segs[1].I
E = segs[1].E
for i in 2:length(segs)
I .+= segs[i].I
E += segs[i].E
end
elseif tol < E
s = heappop!(segs, Reverse)
return bisect(f, segs, I,E, numevals, tol + s.E, x,w,gw,n, atol, rtol, maxevals, nrm, s, ss...)
else
new = seg_to_bisect(ss...)
segitr = BatchedSegmentIterator(f, new, x,w,gw, nrm) # fewer allocations
# segitr = evalsegs(f, new, x,w,gw, nrm)
next_seg = iterate(segitr)
next_seg === nothing && throw(ArgumentError("exhausted segments before completion"))
I, E = update_segs!(segs, I, E, segitr, next_seg, ss...)
return I, E, numevals
I = segs[1].I
E = segs[1].E
for i in 2:length(segs)
I += segs[i].I
E += segs[i].E
end
end
end

# Base case: no segments left, nothing to bisect.
seg_to_bisect() = ()

# Split each segment (anything with fields `a` and `b`) at its midpoint and return the
# resulting endpoint pairs as one flat tuple. The halves of the *first* argument come
# last, because the recursion on the remaining segments is spliced in ahead of them.
function seg_to_bisect(seg, segs...)
    lo, hi = seg.a, seg.b
    mid = (lo + hi) / 2
    halves = ((lo, mid), (mid, hi))
    return (seg_to_bisect(segs...)..., halves...)
end

# matching the order of operations in the unbatched case
#
# `update_segs!` walks the segment iterator `segitr` two items at a time, once per
# remaining parent segment (`s, segs...`), and returns the updated `(I, E)` totals.
#
# Base case: the iterator is exhausted (`nothing`) and no parent segments remain, so the
# totals are returned unchanged.
update_segs!(segheap, I, E, segitr, ::Nothing) = I, E
# Recursive case: `next_seg` is the pending result of `iterate(segitr, ...)` and `s` is the
# parent segment whose contribution is being replaced. Throws an `ArgumentError` if the
# iterator runs out while parent segments are still pending.
function update_segs!(segheap, I, E, segitr, next_seg, s, segs...)
next_seg === nothing && throw(ArgumentError("exhausted segments before completion"))
# first item taken for this parent (presumably one half of `s` -- confirm against caller)
s1, segstate = next_seg
next_seg = iterate(segitr, segstate)
next_seg === nothing && throw(ArgumentError("exhausted segments before completion"))
# second item taken for this parent
s2, segstate = next_seg

# swap the parent's contribution for the two new ones
I = (I - s.I) + s1.I + s2.I
E = (E - s.E) + s1.E + s2.E

# handle the remaining parents first; the pushes below only run once the recursion returns
I, E = update_segs!(segheap, I, E, segitr, iterate(segitr, segstate), segs...)

heappush!(segheap, s1, Reverse)
heappush!(segheap, s2, Reverse)
return I, E
# NOTE(review): the `return` below is unreachable; it looks like residue from a
# neighbouring definition -- confirm against the repository before relying on this span.
return (I, E)
end

# Fallback method: an argument with no more specific `realone` method yields `false`.
function realone(x)
    return false
end
Expand Down Expand Up @@ -230,37 +165,6 @@ function handle_infinities(workfunc, f::InplaceIntegrand, s)
return workfunc(f, s, identity)
end

# `handle_infinities` for batched integrands.
#
# If the outermost endpoints `s[1]` and `s[end]` pass `realone` and at least one of them is
# infinite, the integrand is wrapped in a new `BatchIntegrand` that applies a change of
# variables, and `workfunc` is called with
#   1. the wrapped integrand,
#   2. the endpoints mapped to the new variable `t`,
#   3. the map from `t` back to `x`.
# Otherwise `workfunc(f, s, identity)` is called on the untouched problem.
#
# The wrapper reuses `f.x` and `f.y` as scratch buffers for the original integrand and
# forwards `f.max_batch`, so a user-supplied batch limit is kept after the transformation
# (previously it silently reverted to the default).
function handle_infinities(workfunc, f::BatchIntegrand, s)
    s1, s2 = s[1], s[end]
    if realone(s1) && realone(s2) # check for infinite or semi-infinite intervals
        inf1, inf2 = isinf(s1), isinf(s2)
        if inf1 || inf2
            xtmp = f.x # buffer to store evaluation points
            ftmp = f.y # original integrand may have different units
            # element types of the transformed problem: integrand values are multiplied
            # by `oneunit(s1)` (from the Jacobian); the new variable uses `one` of the x type
            Yt = typeof(oneunit(eltype(f.y)) * oneunit(s1))
            Xt = typeof(one(eltype(f.x)))
            if inf1 && inf2 # x = t/(1-t^2) coordinate transformation
                both! = (v, t) -> begin
                    resize!(xtmp, length(t))
                    resize!(ftmp, length(v))
                    f.f!(ftmp, xtmp .= oneunit(s1) .* t ./ (1 .- t .* t))
                    v .= ftmp .* (1 .+ t .* t) .* oneunit(s1) ./ (1 .- t .* t) .^ 2
                end
                return workfunc(
                    BatchIntegrand(both!, Yt, Xt; max_batch=f.max_batch),
                    map(x -> isinf(x) ? (signbit(x) ? -one(x) : one(x)) : 2x / (oneunit(x)+hypot(oneunit(x),2x)), s),
                    t -> oneunit(s1) * t / (1 - t^2))
            end
            let (s0,si) = inf1 ? (s2,s1) : (s1,s2) # let is needed for JuliaLang/julia#15276
                if si < zero(si) # x = s0 - t/(1-t)
                    below! = (v, t) -> begin
                        resize!(xtmp, length(t))
                        resize!(ftmp, length(v))
                        f.f!(ftmp, xtmp .= s0 .- oneunit(s1) .* t ./ (1 .- t))
                        v .= ftmp .* oneunit(s1) ./ (1 .- t) .^ 2
                    end
                    return workfunc(
                        BatchIntegrand(below!, Yt, Xt; max_batch=f.max_batch),
                        reverse(map(x -> 1 / (1 + oneunit(x) / (s0 - x)), s)),
                        t -> s0 - oneunit(s1)*t/(1-t))
                else # x = s0 + t/(1-t)
                    above! = (v, t) -> begin
                        resize!(xtmp, length(t))
                        resize!(ftmp, length(v))
                        f.f!(ftmp, xtmp .= s0 .+ oneunit(s1) .* t ./ (1 .- t))
                        v .= ftmp .* oneunit(s1) ./ (1 .- t) .^ 2
                    end
                    return workfunc(
                        BatchIntegrand(above!, Yt, Xt; max_batch=f.max_batch),
                        map(x -> 1 / (1 + oneunit(x) / (x - s0)), s),
                        t -> s0 + oneunit(s1)*t/(1-t))
                end
            end
        end
    end
    return workfunc(f, s, identity)
end

# Gauss-Kronrod quadrature of f from a to b to c...

"""
Expand Down Expand Up @@ -326,12 +230,6 @@ repeated allocation.
function quadgk(f, segs...; kws...)
    # endpoints of differing types: promote them to a common type and re-dispatch,
    # forwarding every keyword argument unchanged
    endpoints = promote(segs...)
    return quadgk(f, endpoints...; kws...)
end

# A `BatchIntegrand` whose x element type is `Nothing`: rebuild it with an empty x buffer
# of element type `float(T)` (the gk points are floating-point), keeping `f!`, `y` and
# `max_batch`, then re-dispatch with the same endpoints and keyword arguments.
function quadgk(f::BatchIntegrand{F,Y,Nothing}, segs::T...; kwargs...) where {F,Y,T}
    X = float(T)
    xbuf = Vector{X}(undef, 0)
    typed = BatchIntegrand{F,Y,X}(f.f!, f.y, xbuf, f.max_batch)
    return quadgk(typed, segs...; kwargs...)
end

function quadgk(f, segs::T...;
atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing) where {T}
handle_infinities(f, segs) do f, s, _
Expand Down
Loading

0 comments on commit 8e498bf

Please sign in to comment.