diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index 200e0d7cee3ef..b557e8a4b1d5e 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -18,9 +18,6 @@ on `threadid()`. """ nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint))) -# Only read/written by the main thread -const in_threaded_loop = Ref(false) - function _threadsfor(iter,lbody) lidx = iter.args[1] # index range = iter.args[2] @@ -65,15 +62,11 @@ function _threadsfor(iter,lbody) end end end - # Hack to make nested threaded loops kinda work - if threadid() != 1 || in_threaded_loop[] - # We are in a nested threaded loop + if threadid() != 1 + # only thread 1 can enter/exit _threadedregion Base.invokelatest(threadsfor_fun, true) else - in_threaded_loop[] = true - # the ccall is not expected to throw ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun) - in_threaded_loop[] = false end nothing end diff --git a/test/threads_exec.jl b/test/threads_exec.jl index 700114c379826..eb5ee45eb0972 100644 --- a/test/threads_exec.jl +++ b/test/threads_exec.jl @@ -670,3 +670,19 @@ let ch = Channel{Char}(0), t schedule(t) @test String(collect(ch)) == "hello" end + +# errors inside @threads +function _atthreads_with_error(a, err) + Threads.@threads for i in eachindex(a) + if err + error("failed") + end + a[i] = Threads.threadid() + end + a +end +@test_throws TaskFailedException _atthreads_with_error(zeros(nthreads()), true) +let a = zeros(nthreads()) + _atthreads_with_error(a, false) + @test a == [1:nthreads();] +end