Adding target option
michel2323 committed Apr 16, 2024
1 parent a2a0804 commit e002e63
Showing 13 changed files with 135 additions and 112 deletions.
13 changes: 8 additions & 5 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
@@ -2,15 +2,15 @@ module JACCAMDGPU

using JACC, AMDGPU

-function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_for(::ROCBackend, N::I, f::F, x...) where {I<:Integer,F<:Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@roc groupsize = threads gridsize = blocks _parallel_for_amdgpu(f, x...)
AMDGPU.synchronize()
end

-function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_for(::ROCBackend, (M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -20,7 +20,7 @@ function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:
AMDGPU.synchronize()
end

-function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(::ROCBackend, N::I, f::F, x...) where {I<:Integer,F<:Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@@ -34,7 +34,7 @@ function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:ROCArray}})

end

-function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(::ROCBackend, (M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -300,7 +300,10 @@ function reduce_kernel_amdgpu_MN((M, N), red, ret)
end

function __init__()
-const JACC.Array = AMDGPU.ROCArray{T,N} where {T,N}
+if JACC.JACCPreferences.backend == "amdgpu"
+const JACC.default_backend = ROCBackend()
+@info "Set default backend to $(JACC.default_backend)"
+end
end

end # module JACCAMDGPU
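A minimal usage sketch (not part of the commit) of what this change enables, assuming AMDGPU is functional and the "amdgpu" preference is active so that JACC.default_backend === ROCBackend():

```julia
using JACC, AMDGPU

# Hypothetical example: each index adds 5 to one element. The plain call
# forwards through JACC.default_backend, which the __init__ above sets
# to ROCBackend().
a = ROCArray(ones(Float32, 1_000))
JACC.parallel_for(1_000, (i, a) -> (@inbounds a[i] += 5.0f0), a)

# The backend can also be selected explicitly per call:
JACC.parallel_for(ROCBackend(), 1_000, (i, a) -> (@inbounds a[i] += 5.0f0), a)
```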
13 changes: 8 additions & 5 deletions ext/JACCCUDA/JACCCUDA.jl
@@ -2,14 +2,14 @@ module JACCCUDA

using JACC, CUDA

-function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:CuArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_for(::CUDABackend, N::I, f::F, x...) where {I<:Integer,F<:Function}
maxPossibleThreads = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X)
threads = min(N, maxPossibleThreads)
blocks = ceil(Int, N / threads)
CUDA.@sync @cuda threads = threads blocks = blocks _parallel_for_cuda(f, x...)
end

-function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:CuArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_for(::CUDABackend, (M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -18,7 +18,7 @@ function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:
CUDA.@sync @cuda threads = (Mthreads, Nthreads) blocks = (Mblocks, Nblocks) _parallel_for_cuda_MN(f, x...)
end

-function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:CuArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(::CUDABackend, N::I, f::F, x...) where {I<:Integer,F<:Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@@ -30,7 +30,7 @@ function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:CuArray}})
end


-function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:CuArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(::CUDABackend, (M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -294,7 +294,10 @@ function reduce_kernel_cuda_MN((M, N), red, ret)
end

function __init__()
-const JACC.Array = CUDA.CuArray{T,N} where {T,N}
+if JACC.JACCPreferences.backend == "cuda"
+const JACC.default_backend = CUDABackend()
+@info "Set default backend to $(JACC.default_backend)"
+end
end

end # module JACCCUDA
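The reduction entry point follows the same shape. A hedged sketch of a dot product via parallel_reduce (hypothetical usage, assuming the "cuda" preference is active; in this version the result comes back as a one-element array rather than a scalar, per the oneAPI implementation below):

```julia
using JACC, CUDA

x = CuArray(ones(Float32, 1_000))
y = CuArray(ones(Float32, 1_000))

# Per-index contribution x[i] * y[i]; the backend performs the two-stage
# (per-block, then final) tree reduction.
dot_device = JACC.parallel_reduce(1_000, (i, x, y) -> (@inbounds x[i] * y[i]), x, y)
```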
13 changes: 8 additions & 5 deletions ext/JACCONEAPI/JACCONEAPI.jl
@@ -3,15 +3,15 @@ module JACCONEAPI

using JACC, oneAPI

-function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:oneArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_for(::oneAPIBackend, N::I, f::F, x...) where {I<:Integer,F<:Function}
#maxPossibleItems = oneAPI.oneL0.compute_properties(device().maxTotalGroupSize)
maxPossibleItems = 256
items = min(N, maxPossibleItems)
groups = ceil(Int, N / items)
oneAPI.@sync @oneapi items = items groups = groups _parallel_for_oneapi(f, x...)
end

-function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:oneArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_for(::oneAPIBackend, (M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
maxPossibleItems = 16
Mitems = min(M, maxPossibleItems)
Nitems = min(N, maxPossibleItems)
@@ -20,7 +20,7 @@ function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:
oneAPI.@sync @oneapi items = (Mitems, Nitems) groups = (Mgroups, Ngroups) _parallel_for_oneapi_MN(f, x...)
end

-function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:oneArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(::oneAPIBackend, N::I, f::F, x...) where {I<:Integer,F<:Function}
numItems = 256
items = min(N, numItems)
groups = ceil(Int, N / items)
@@ -31,7 +31,7 @@ function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:oneArray}})
return rret
end

-function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:oneArray}}) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(::oneAPIBackend, (M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
numItems = 16
Mitems = min(M, numItems)
Nitems = min(N, numItems)
@@ -294,7 +294,10 @@ function reduce_kernel_oneapi_MN((M, N), red, ret)
end

function __init__()
-const JACC.Array = oneAPI.oneArray{T,N} where {T,N}
+if JACC.JACCPreferences.backend == "oneapi"
+const JACC.default_backend = oneAPIBackend()
+@info "Set default backend to $(JACC.default_backend)"
+end
end

end # module JACCONEAPI
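And a sketch of the 2D tuple methods (hypothetical usage, assuming the "oneapi" preference is active):

```julia
using JACC, oneAPI

# An (M, N) tuple selects the 2D method; the kernel receives both indices.
A = oneArray(ones(Float32, 16, 16))
JACC.parallel_for((16, 16), (i, j, A) -> (@inbounds A[i, j] *= 2.0f0), A)
```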
45 changes: 34 additions & 11 deletions src/JACC.jl
@@ -1,29 +1,49 @@
module JACC

# module to set back end preferences
include("JACCPreferences.jl")
include("helper.jl")

-export Array
-export parallel_for
+export parallel_for, parallel_reduce, ThreadsBackend, print_default_backend

-global Array
+struct ThreadsBackend end

-function parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) where {I<:Integer,F<:Function}
+export default_backend
+
+global default_backend = ThreadsBackend()
+
+# default backend API
+function parallel_for(N::I, f::F, x...) where {I<:Integer,F<:Function}
+parallel_for(default_backend, N, f, x...)
+end
+
+function parallel_for((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+parallel_for(default_backend, (M, N), f, x...)
+end
+
+function parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}
+parallel_reduce(default_backend, N, f, x...)
+end
+
+function parallel_reduce((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+parallel_reduce(default_backend, (M, N), f, x...)
+end
+
+function parallel_for(::ThreadsBackend, N::I, f::F, x...) where {I<:Integer,F<:Function}
@maybe_threaded for i in 1:N
f(i, x...)
end
end

-function parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) where {I<:Integer,F<:Function}
+function parallel_for(::ThreadsBackend, (M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
@maybe_threaded for j in 1:N
for i in 1:M
f(i, j, x...)
end
end
end

-function parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) where {I<:Integer,F<:Function}
+function parallel_reduce(::ThreadsBackend, N::I, f::F, x...) where {I<:Integer,F<:Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
@maybe_threaded for i in 1:N
@@ -35,7 +55,7 @@ function parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) wh
return ret
end

-function parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) where {I<:Integer,F<:Function}
+function parallel_reduce(::ThreadsBackend, (M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
@maybe_threaded for j in 1:N
@@ -50,12 +70,15 @@ function parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:Ba
end

function __init__()
@info("Using JACC backend: $(JACCPreferences.backend)")

if JACCPreferences.backend == "threads"
-const JACC.Array = Base.Array{T,N} where {T,N}
+const JACC.default_backend = ThreadsBackend()
+@info "Set default backend to $(JACC.default_backend)"
end
end

+function print_default_backend()
+println("Default backend is $default_backend")
+end


end # module JACC
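The scheme here is ordinary multiple dispatch on a backend tag type: the unqualified parallel_for/parallel_reduce entry points forward to methods specialized on their first argument. A minimal sketch of what a hypothetical extra backend would define (illustrative names, not part of the commit):

```julia
using JACC

struct MySerialBackend end  # hypothetical tag type

# Same signature shape as the ThreadsBackend method above, but a plain
# serial loop.
function JACC.parallel_for(::MySerialBackend, N::Integer, f::Function, x...)
    for i in 1:N
        f(i, x...)
    end
end

# Explicit-backend call; the extensions above instead install themselves as
# JACC.default_backend from their __init__ when their preference is selected.
v = zeros(Int, 4)
JACC.parallel_for(MySerialBackend(), 4, (i, v) -> (v[i] = i), v)
```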
6 changes: 1 addition & 5 deletions test/Project.toml
@@ -1,9 +1,5 @@
[deps]
-Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-Preferences = "21216c6a-2e73-6563-6e65-726566657250"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-
-[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
47 changes: 27 additions & 20 deletions test/runtests.jl
@@ -1,26 +1,33 @@
-import JACC
+using JACC
+using CUDA
+using AMDGPU
+using oneAPI
using Test

-using Pkg
-
-const backend = JACC.JACCPreferences.backend
-
-@static if backend == "cuda"
-Pkg.add(name="CUDA", version="v5.1.1")
-@show "CUDA backend loaded"
-include("tests_cuda.jl")
+@testset "JACC Tests" begin
+if CUDA.functional()
+@testset "CUDA" begin
+println("CUDA backend")
+include("tests_cuda.jl")
+end
+end

elseif backend == "amdgpu"
Pkg.add(name="AMDGPU", version="v0.8.6")
@show "AMDGPU backend loaded"
include("tests_amdgpu.jl")
if AMDGPU.functional()
@testset "AMDGPU" begin
println("AMDGPU backend")
include("tests_amdgpu.jl")
end
end

elseif backend == "oneapi"
Pkg.add("oneAPI")
@show "OneAPI backend loaded"
include("tests_oneapi.jl")
if oneAPI.functional()
@testset "oneAPI" begin
println("OneAPI backend")
include("tests_oneapi.jl")
end
end

elseif backend == "threads"
@show "Threads backend loaded"
@testset "ThreadsBackend" begin
println("Threads backend")
include("tests_threads.jl")

end
end
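The whole suite keys off JACC.JACCPreferences.backend, a Preferences.jl-backed setting that is read at package load time. A sketch of switching it (assumes the standard Preferences.jl API; a fresh Julia session is needed for the change to take effect):

```julia
using Preferences, JACC

# Persists backend = "amdgpu" into LocalPreferences.toml for the JACC package.
set_preferences!(JACC, "backend" => "amdgpu"; force=true)
# Restart Julia; JACC's __init__ will then report the selected backend.
```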
4 changes: 2 additions & 2 deletions test/tests_amdgpu.jl
@@ -17,8 +17,8 @@ end
dims = (N)
a = round.(rand(Float32, dims) * 100)

-a_device = JACC.Array(a)
-JACC.parallel_for(N, f, a_device)
+a_device = ROCArray(a)
+JACC.parallel_for(ROCBackend(), N, f, a_device)

a_expected = a .+ 5.0
@test Array(a_device) ≈ a_expected rtol = 1e-5
12 changes: 6 additions & 6 deletions test/tests_amdgpu_perf.jl
@@ -11,7 +11,7 @@ using Test
end

function axpy_amdgpu(SIZE,alpha,x,y)
maxPossibleThreads = 512
threads = min(SIZE, maxPossibleThreads)
blocks = ceil(Int, SIZE/threads)
@roc groupsize=threads gridsize=threads*blocks axpy_amdgpu_kernel(alpha,x,y)
@@ -37,13 +37,13 @@ using Test
x = ones(SIZE)
y = ones(SIZE)
alpha = 2.0
-jx = JACC.Array(x)
-jy = JACC.Array(y)
-JACC.parallel_for(10, axpy, alpha, jx, jy)
+jx = ROCArray(x)
+jy = ROCArray(y)
+
+JACC.parallel_for(ROCBackend(), 10, axpy, alpha, jx, jy)
for i in [10, 100, 1_000, 10_000, 100_000, 1_000_000, 10_000_000, 100_000_000]
@time begin
-JACC.parallel_for(i, axpy, alpha, jx, jy)
+JACC.parallel_for(ROCBackend(), i, axpy, alpha, jx, jy)
end
end

14 changes: 5 additions & 9 deletions test/tests_cuda.jl
@@ -3,10 +3,6 @@ import JACC
using Test


@testset "TestBackend" begin
@test JACC.JACCPreferences.backend == "cuda"
end

@testset "VectorAddLambda" begin

function f(i, a)
@@ -17,8 +13,8 @@ end
dims = (N)
a = round.(rand(Float32, dims) * 100)

-a_device = JACC.Array(a)
-JACC.parallel_for(N, f, a_device)
+a_device = CuArray(a)
+JACC.parallel_for(CUDABackend(), N, f, a_device)

a_expected = a .+ 5.0
@test Array(a_device) ≈ a_expected rtol = 1e-5
@@ -43,9 +39,9 @@ end
y = round.(rand(Float32, N) * 100)
alpha = 2.5

-x_device = JACC.Array(x)
-y_device = JACC.Array(y)
-JACC.parallel_for(N, axpy, alpha, x_device, y_device)
+x_device = CuArray(x)
+y_device = CuArray(y)
+JACC.parallel_for(CUDABackend(), N, axpy, alpha, x_device, y_device)

x_expected = x
seq_axpy(N, alpha, x_expected, y)
12 changes: 6 additions & 6 deletions test/tests_cuda_perf.jl
@@ -26,24 +26,24 @@ using Test
end


-x_device = CUDA.CuArray(x)
-y_device = CUDA.CuArray(y)
+x_device = CuArray(x)
+y_device = CuArray(y)

for i in 1:11
@time axpy_cuda(N, alpha, x_device, y_device)
end

# JACCCUDA version
function axpy(i, alpha, x, y)
if i <= length(x)
@inbounds x[i] += alpha * y[i]
end
end

-x_device_JACC = JACC.Array(x)
-y_device_JACC = JACC.Array(y)
+x_device_JACC = CuArray(x)
+y_device_JACC = CuArray(y)

for i in 1:11
-@time JACC.parallel_for(N, axpy, alpha, x_device_JACC, y_device_JACC)
+@time JACC.parallel_for(CUDABackend(), N, axpy, alpha, x_device_JACC, y_device_JACC)
end
end
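The 11 timed iterations let the first, compilation-heavy run be discarded by eye. As an aside (not part of the commit), BenchmarkTools would give steadier numbers; a sketch reusing the names from the test above:

```julia
using BenchmarkTools

# $-interpolation keeps global-variable access out of the measurement.
@btime JACC.parallel_for(CUDABackend(), $N, $axpy, $alpha, $x_device_JACC, $y_device_JACC)
```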
… (remaining changed files not shown)
