From 57ecb5010ec71a7c1af6cb860f4181c94f7fd80c Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 8 Mar 2021 20:25:52 +0900 Subject: [PATCH] add test for special analysis passes for multithreading code --- test/interactive_utils.jl | 4 +- test/test_abstractinterpretation.jl | 110 ++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/test/interactive_utils.jl b/test/interactive_utils.jl index c25aa59a9..5f220b16d 100644 --- a/test/interactive_utils.jl +++ b/test/interactive_utils.jl @@ -36,7 +36,9 @@ macro def(ex) @assert isexpr(ex, :block) return quote let vmod = $(gen_virtual_module)() - Core.eval(vmod, $(QuoteNode(ex))) + for x in $(ex.args) + Core.eval(vmod, x) + end vmod # return virtual module end end end diff --git a/test/test_abstractinterpretation.jl b/test/test_abstractinterpretation.jl index 6318b2b49..f821f7fab 100644 --- a/test/test_abstractinterpretation.jl +++ b/test/test_abstractinterpretation.jl @@ -682,3 +682,113 @@ end @test !isempty(interp.reports) @test !any(r->isa(r, InvalidInvokeErrorReport), interp.reports) end + +@testset "additional analysis pass for task parallelism code" begin + # general case with `schedule(::Task)` pattern + interp, frame = profile_call() do + t = Task() do + sum("julia") + end + schedule(t) + fetch(t) + end + test_sum_over_string(interp) + + # handle `Threads.@spawn` (https://github.com/aviatesk/JET.jl/issues/114) + interp, frame = profile_call() do + fetch(Threads.@spawn 1 + "foo") + end + @test length(interp.reports) == 1 + let r = first(interp.reports) + @test isa(r, NoMethodErrorReport) + @test r.atype === Tuple{typeof(+), Int, String} + end + + # handle `Threads.@threads` + interp, frame = profile_call((Int,)) do n + a = String[] + Threads.@threads for i in 1:n + push!(a, i) + end + return a + end + @test !isempty(interp.reports) + @test any(interp.reports) do r + isa(r, NoMethodErrorReport) && + r.atype === Tuple{typeof(convert), Type{String}, Int} + end + + # multiple tasks in the same frame + interp, frame = profile_call() do + t1 = Threads.@spawn 1 + "foo" + t2 = Threads.@spawn "foo" + 1 + fetch(t1), fetch(t2) + end + @test length(interp.reports) == 2 + let r = interp.reports[1] + @test isa(r, NoMethodErrorReport) + @test r.atype === Tuple{typeof(+), Int, String} + end + let r = interp.reports[2] + @test isa(r, NoMethodErrorReport) + @test r.atype === Tuple{typeof(+), String, Int} + end + + # nested tasks + interp, frame = profile_call() do + t0 = Task() do + t = Threads.@spawn sum("julia") + fetch(t) + end + schedule(t0) + fetch(t0) + end + test_sum_over_string(interp) + + # don't fail into infinite loop (rather, don't spoil inference termination) + m = @def begin + # adapated from https://julialang.org/blog/2019/07/multithreading/ + import Base.Threads.@spawn + + # sort the elements of `v` in place, from indices `lo` to `hi` inclusive + function psort!(v, lo::Int=1, hi::Int=length(v)) + if lo >= hi # 1 or 0 elements; nothing to do + return v + end + if hi - lo < 100000 # below some cutoff, run in serial + sort!(view(v, lo:hi), alg = MergeSort) + return v + end + + mid = (lo+hi)>>>1 # find the midpoint + + half = @spawn psort!(v, lo, mid) # task to sort the lower half; will run + psort!(v, mid+1, hi) # in parallel with the current call sorting + # the upper half + wait(half) # wait for the lower half to finish + + temp = v[lo:mid] # workspace for merging + + i, k, j = 1, lo, mid+1 # merge the two sorted sub-arrays + @inbounds while k < j <= hi + if v[j] < temp[i] + v[k] = v[j] + j += 1 + else + v[k] = temp[i] + i += 1 + end + k += 1 + end + @inbounds while k < j + v[k] = temp[i] + k += 1 + i += 1 + end + + return v + end + end + interp, frame = profile_call(m.psort!, (Vector{Int},)) + @test true +end