From d7eeba7704fdee09e0412bedce59c3b443e76337 Mon Sep 17 00:00:00 2001
From: Philip Fackler
Date: Fri, 11 Oct 2024 10:54:11 -0500
Subject: [PATCH] Initial working example

---
 ext/JACCCUDA/JACCCUDA.jl |  59 ++++++++---------
 src/JACC.jl              |  14 +++--
 test/tests_cuda.jl       | 132 +++++++++++++++++++++------------------
 3 files changed, 110 insertions(+), 95 deletions(-)

diff --git a/ext/JACCCUDA/JACCCUDA.jl b/ext/JACCCUDA/JACCCUDA.jl
index 20aa6a1..2503a91 100644
--- a/ext/JACCCUDA/JACCCUDA.jl
+++ b/ext/JACCCUDA/JACCCUDA.jl
@@ -57,18 +57,21 @@ function JACC.parallel_for(
     CUDA.@sync @cuda threads = (Lthreads, Mthreads, Nthreads) blocks = (Lblocks, Mblocks, Nblocks) shmem = shmem_size _parallel_for_cuda_LMN(f, x...)
 end
 
-function JACC.parallel_reduce(
-        N::I, f::F, x...) where {I <: Integer, F <: Function}
+function JACC.parallel_reduce(N::Integer, op, f::Function, x...; init)
     numThreads = 512
     threads = min(N, numThreads)
     blocks = ceil(Int, N / threads)
-    ret = CUDA.zeros(Float64, blocks)
-    rret = CUDA.zeros(Float64, 1)
+    ret = fill!(CUDA.CuArray{typeof(init)}(undef, blocks), init)
+    rret = CUDA.CuArray([init])
     CUDA.@sync @cuda threads=threads blocks=blocks shmem=512 * sizeof(Float64) _parallel_reduce_cuda(
-        N, ret, f, x...)
+        N, op, ret, f, x...)
     CUDA.@sync @cuda threads=threads blocks=1 shmem=512 * sizeof(Float64) reduce_kernel_cuda(
-        blocks, ret, rret)
-    return rret
+        blocks, op, ret, rret)
+    return Base.Array(rret)[]
+end
+
+function JACC.parallel_reduce(N::Integer, f::Function, x...)
+    return JACC.parallel_reduce(N, +, f, x...; init = zero(Float64))
 end
 
 function JACC.parallel_reduce(
@@ -113,7 +116,7 @@ function _parallel_for_cuda_LMN(f, x...)
     return nothing
 end
 
-function _parallel_reduce_cuda(N, ret, f, x...)
+function _parallel_reduce_cuda(N, op, ret, f, x...)
     shared_mem = @cuDynamicSharedMem(Float64, 512)
     i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
     ti = threadIdx().x
@@ -126,52 +129,52 @@ function _parallel_reduce_cuda(N, ret, f, x...)
     end
     sync_threads()
     if (ti <= 256)
-        shared_mem[ti] += shared_mem[ti + 256]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 256])
     end
     sync_threads()
     if (ti <= 128)
-        shared_mem[ti] += shared_mem[ti + 128]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 128])
     end
     sync_threads()
    if (ti <= 64)
-        shared_mem[ti] += shared_mem[ti + 64]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 64])
     end
     sync_threads()
     if (ti <= 32)
-        shared_mem[ti] += shared_mem[ti + 32]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 32])
     end
     sync_threads()
     if (ti <= 16)
-        shared_mem[ti] += shared_mem[ti + 16]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 16])
     end
     sync_threads()
     if (ti <= 8)
-        shared_mem[ti] += shared_mem[ti + 8]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 8])
     end
     sync_threads()
     if (ti <= 4)
-        shared_mem[ti] += shared_mem[ti + 4]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 4])
     end
     sync_threads()
     if (ti <= 2)
-        shared_mem[ti] += shared_mem[ti + 2]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 2])
     end
     sync_threads()
     if (ti == 1)
-        shared_mem[ti] += shared_mem[ti + 1]
+        shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 1])
         ret[blockIdx().x] = shared_mem[ti]
     end
     return nothing
 end
 
-function reduce_kernel_cuda(N, red, ret)
+function reduce_kernel_cuda(N, op, red, ret)
     shared_mem = @cuDynamicSharedMem(Float64, 512)
     i = threadIdx().x
     ii = i
     tmp::Float64 = 0.0
     if N > 512
         while ii <= N
-            tmp += @inbounds red[ii]
+            tmp = op(tmp, @inbounds red[ii])
             ii += 512
         end
     elseif (i <= N)
@@ -180,39 +183,39 @@ function reduce_kernel_cuda(N, red, ret)
     shared_mem[threadIdx().x] = tmp
     sync_threads()
     if (i <= 256)
-        shared_mem[i] += shared_mem[i + 256]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 256])
     end
     sync_threads()
     if (i <= 128)
-        shared_mem[i] += shared_mem[i + 128]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 128])
     end
     sync_threads()
     if (i <= 64)
-        shared_mem[i] += shared_mem[i + 64]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 64])
     end
     sync_threads()
     if (i <= 32)
-        shared_mem[i] += shared_mem[i + 32]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 32])
     end
     sync_threads()
     if (i <= 16)
-        shared_mem[i] += shared_mem[i + 16]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 16])
     end
     sync_threads()
     if (i <= 8)
-        shared_mem[i] += shared_mem[i + 8]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 8])
     end
     sync_threads()
     if (i <= 4)
-        shared_mem[i] += shared_mem[i + 4]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 4])
     end
     sync_threads()
     if (i <= 2)
-        shared_mem[i] += shared_mem[i + 2]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 2])
     end
     sync_threads()
     if (i == 1)
-        shared_mem[i] += shared_mem[i + 1]
+        shared_mem[i] = op(shared_mem[i], shared_mem[i + 1])
         ret[1] = shared_mem[1]
     end
     return nothing
diff --git a/src/JACC.jl b/src/JACC.jl
index 333d2c0..1b347b6 100644
--- a/src/JACC.jl
+++ b/src/JACC.jl
@@ -50,18 +50,22 @@ function parallel_for(
     end
 end
 
-function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
-    tmp = zeros(Threads.nthreads())
-    ret = zeros(1)
+function parallel_reduce(N::Integer, op, f::Function, x...; init)
+    ret = init
+    tmp = fill(init, Threads.nthreads())
     @maybe_threaded for i in 1:N
-        tmp[Threads.threadid()] = tmp[Threads.threadid()] .+ f(i, x...)
+        tmp[Threads.threadid()] = op.(tmp[Threads.threadid()], f(i, x...))
     end
     for i in 1:Threads.nthreads()
-        ret = ret .+ tmp[i]
+        ret = op.(ret, tmp[i])
     end
     return ret
 end
 
+function parallel_reduce(N::Integer, f::Function, x...)
+    return parallel_reduce(N, +, f, x...; init = zeros(1))
+end
+
 function parallel_reduce(
         (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
     tmp = zeros(Threads.nthreads())
diff --git a/test/tests_cuda.jl b/test/tests_cuda.jl
index 0dab756..800608c 100644
--- a/test/tests_cuda.jl
+++ b/test/tests_cuda.jl
@@ -97,69 +97,77 @@ end
     @test zeros(N)≈Array(x) rtol=1e-5
 end
 
-# @testset "CG" begin
-
-#     function matvecmul(i, a1, a2, a3, x, y, SIZE)
-#         if i == 1
-#             y[i] = a2[i] * x[i] + a1[i] * x[i+1]
-#         elseif i == SIZE
-#             y[i] = a3[i] * x[i-1] + a2[i] * x[i]
-#         elseif i > 1 && i < SIZE
-#             y[i] = a3[i] * x[i-1] + a1[i] * +x[i] + a1[i] * +x[i+1]
-#         end
-#     end
-
-#     function dot(i, x, y)
-#         @inbounds return x[i] * y[i]
-#     end
-
-#     function axpy(i, alpha, x, y)
-#         @inbounds x[i] += alpha[1, 1] * y[i]
-#     end
-
-#     SIZE = 10
-#     a0 = JACC.ones(Float64, SIZE)
-#     a1 = JACC.ones(Float64, SIZE)
-#     a2 = JACC.ones(Float64, SIZE)
-#     r = JACC.ones(Float64, SIZE)
-#     p = JACC.ones(Float64, SIZE)
-#     s = JACC.zeros(Float64, SIZE)
-#     x = JACC.zeros(Float64, SIZE)
-#     r_old = JACC.zeros(Float64, SIZE)
-#     r_aux = JACC.zeros(Float64, SIZE)
-#     a1 = a1 * 4
-#     r = r * 0.5
-#     p = p * 0.5
-#     global cond = one(Float64)
-
-#     while cond[1, 1] >= 1e-14
-
-#         r_old = copy(r)
-
-#         JACC.parallel_for(SIZE, matvecmul, a0, a1, a2, p, s, SIZE)
-
-#         alpha0 = JACC.parallel_reduce(SIZE, dot, r, r)
-#         alpha1 = JACC.parallel_reduce(SIZE, dot, p, s)
-
-#         alpha = alpha0 / alpha1
-#         negative_alpha = alpha * (-1.0)
-
-#         JACC.parallel_for(SIZE, axpy, negative_alpha, r, s)
-#         JACC.parallel_for(SIZE, axpy, alpha, x, p)
-
-#         beta0 = JACC.parallel_reduce(SIZE, dot, r, r)
-#         beta1 = JACC.parallel_reduce(SIZE, dot, r_old, r_old)
-#         beta = beta0 / beta1
-
-#         r_aux = copy(r)
+@testset "CG" begin
+
+    function matvecmul(i, a1, a2, a3, x, y, SIZE)
+        if i == 1
+            y[i] = a2[i] * x[i] + a1[i] * x[i+1]
+        elseif i == SIZE
+            y[i] = a3[i] * x[i-1] + a2[i] * x[i]
+        elseif i > 1 && i < SIZE
+            y[i] = a3[i] * x[i-1] + a1[i] * +x[i] + a1[i] * +x[i+1]
+        end
+    end
+
+    function dot(i, x, y)
+        @inbounds return x[i] * y[i]
+    end
+
+    function axpy(i, alpha, x, y)
+        @inbounds x[i] += alpha[1, 1] * y[i]
+    end
+
+    SIZE = 10
+    a0 = JACC.ones(Float64, SIZE)
+    a1 = JACC.ones(Float64, SIZE)
+    a2 = JACC.ones(Float64, SIZE)
+    r = JACC.ones(Float64, SIZE)
+    p = JACC.ones(Float64, SIZE)
+    s = JACC.zeros(Float64, SIZE)
+    x = JACC.zeros(Float64, SIZE)
+    r_old = JACC.zeros(Float64, SIZE)
+    r_aux = JACC.zeros(Float64, SIZE)
+    a1 = a1 * 4
+    r = r * 0.5
+    p = p * 0.5
+    cond = one(Float64)
+
+    while cond[1, 1] >= 1e-14
+
+        r_old = copy(r)
+
+        JACC.parallel_for(SIZE, matvecmul, a0, a1, a2, p, s, SIZE)
+
+        alpha0 = JACC.parallel_reduce(SIZE, dot, r, r)
+        alpha1 = JACC.parallel_reduce(SIZE, dot, p, s)
+
+        alpha = alpha0 / alpha1
+        negative_alpha = alpha * (-1.0)
+
+        JACC.parallel_for(SIZE, axpy, negative_alpha, r, s)
+        JACC.parallel_for(SIZE, axpy, alpha, x, p)
+
+        beta0 = JACC.parallel_reduce(SIZE, dot, r, r)
+        beta1 = JACC.parallel_reduce(SIZE, dot, r_old, r_old)
+        beta = beta0 / beta1
+
+        r_aux = copy(r)
+
+        JACC.parallel_for(SIZE, axpy, beta, r_aux, p)
+        ccond = JACC.parallel_reduce(SIZE, dot, r, r)
+        cond = ccond
+        p = copy(r_aux)
+    end
+    @test cond[1, 1] <= 1e-14
+end
 
-#         JACC.parallel_for(SIZE, axpy, beta, r_aux, p)
-#         ccond = JACC.parallel_reduce(SIZE, dot, r, r)
-#         global cond = ccond
-#         p = copy(r_aux)
-#     end
-#     @test cond[1, 1] <= 1e-14
-# end
+@testset "reduce" begin
+    SIZE = 100
+    ah = randn(SIZE)
+    ad = JACC.Array(ah)
+    mxd = JACC.parallel_reduce(SIZE, max, (i,a)->a[i],
+        ad; init = -Inf)
+    @test mxd == maximum(ah)
+end
 
 # @testset "LBM" begin
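
Usage sketch (not part of the patch): the reworked parallel_reduce now takes a binary operator plus an init keyword giving that operator's starting value, while the existing parallel_reduce(N, f, x...) form still defaults to +. The snippet below is illustrative only; the array a and the index closures are placeholders, and the max call mirrors the new "reduce" testset above.

using JACC

N = 100
a = JACC.Array(randn(N))

# Sum reduction with the operator and its neutral element spelled out explicitly.
total = JACC.parallel_reduce(N, +, (i, v) -> v[i], a; init = zero(Float64))

# Max reduction, as exercised by the new "reduce" testset.
mx = JACC.parallel_reduce(N, max, (i, v) -> v[i], a; init = -Inf)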