Skip to content

Commit

Permalink
add test for special analysis passes for multithreading code
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 8, 2021
1 parent c99b7ec commit 57ecb50
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
4 changes: 3 additions & 1 deletion test/interactive_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ macro def(ex)
@assert isexpr(ex, :block)
return quote let
vmod = $(gen_virtual_module)()
Core.eval(vmod, $(QuoteNode(ex)))
for x in $(ex.args)
Core.eval(vmod, x)
end
vmod # return virtual module
end end
end
Expand Down
110 changes: 110 additions & 0 deletions test/test_abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,3 +682,113 @@ end
@test !isempty(interp.reports)
@test !any(r->isa(r, InvalidInvokeErrorReport), interp.reports)
end

@testset "additional analysis pass for task parallelism code" begin
# general case with `schedule(::Task)` pattern
interp, frame = profile_call() do
t = Task() do
sum("julia")
end
schedule(t)
fetch(t)
end
test_sum_over_string(interp)

# handle `Threads.@spawn` (https://github.com/aviatesk/JET.jl/issues/114)
interp, frame = profile_call() do
fetch(Threads.@spawn 1 + "foo")
end
@test length(interp.reports) == 1
let r = first(interp.reports)
@test isa(r, NoMethodErrorReport)
@test r.atype === Tuple{typeof(+), Int, String}
end

# handle `Threads.@threads`
interp, frame = profile_call((Int,)) do n
a = String[]
Threads.@threads for i in 1:n
push!(a, i)
end
return a
end
@test !isempty(interp.reports)
@test any(interp.reports) do r
isa(r, NoMethodErrorReport) &&
r.atype === Tuple{typeof(convert), Type{String}, Int}
end

# multiple tasks in the same frame
interp, frame = profile_call() do
t1 = Threads.@spawn 1 + "foo"
t2 = Threads.@spawn "foo" + 1
fetch(t1), fetch(t2)
end
@test length(interp.reports) == 2
let r = interp.reports[1]
@test isa(r, NoMethodErrorReport)
@test r.atype === Tuple{typeof(+), Int, String}
end
let r = interp.reports[2]
@test isa(r, NoMethodErrorReport)
@test r.atype === Tuple{typeof(+), String, Int}
end

# nested tasks
interp, frame = profile_call() do
t0 = Task() do
t = Threads.@spawn sum("julia")
fetch(t)
end
schedule(t0)
fetch(t0)
end
test_sum_over_string(interp)

# don't fail into infinite loop (rather, don't spoil inference termination)
m = @def begin
# adapated from https://julialang.org/blog/2019/07/multithreading/
import Base.Threads.@spawn

# sort the elements of `v` in place, from indices `lo` to `hi` inclusive
function psort!(v, lo::Int=1, hi::Int=length(v))
if lo >= hi # 1 or 0 elements; nothing to do
return v
end
if hi - lo < 100000 # below some cutoff, run in serial
sort!(view(v, lo:hi), alg = MergeSort)
return v
end

mid = (lo+hi)>>>1 # find the midpoint

half = @spawn psort!(v, lo, mid) # task to sort the lower half; will run
psort!(v, mid+1, hi) # in parallel with the current call sorting
# the upper half
wait(half) # wait for the lower half to finish

temp = v[lo:mid] # workspace for merging

i, k, j = 1, lo, mid+1 # merge the two sorted sub-arrays
@inbounds while k < j <= hi
if v[j] < temp[i]
v[k] = v[j]
j += 1
else
v[k] = temp[i]
i += 1
end
k += 1
end
@inbounds while k < j
v[k] = temp[i]
k += 1
i += 1
end

return v
end
end
interp, frame = profile_call(m.psort!, (Vector{Int},))
@test true
end

0 comments on commit 57ecb50

Please sign in to comment.