Skip to content

Commit

Permalink
Fix precompilation
Browse files (browse the repository at this point in the history)
  • Loading branch information
michel2323 committed Mar 11, 2024
1 parent 051c50b commit 6d12657
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
8 changes: 4 additions & 4 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ module JACCAMDGPU

using JACC, AMDGPU

-function JACC.parallel_for(N::I, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) where {I<:Integer,F<:Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@roc groupsize = threads gridsize = threads * blocks _parallel_for_amdgpu(f, x...)
# AMDGPU.synchronize()
end

-function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) where {I<:Integer,F<:Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Expand All @@ -20,7 +20,7 @@ function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:
# AMDGPU.synchronize()
end

-function JACC.parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) where {I<:Integer,F<:Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
Expand All @@ -34,7 +34,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}

end

-function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) where {I<:Integer,F<:Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Expand Down
8 changes: 4 additions & 4 deletions ext/JACCCUDA/JACCCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ module JACCCUDA

using JACC, CUDA

-function JACC.parallel_for(N::I, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:CuArray}}) where {I<:Integer,F<:Function}
maxPossibleThreads = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X)
threads = min(N, maxPossibleThreads)
blocks = ceil(Int, N / threads)
CUDA.@sync @cuda threads = threads blocks = blocks _parallel_for_cuda(f, x...)
end

-function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:CuArray}}) where {I<:Integer,F<:Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Expand All @@ -18,7 +18,7 @@ function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:
CUDA.@sync @cuda threads = (Mthreads, Nthreads) blocks = (Mblocks, Nblocks) _parallel_for_cuda_MN(f, x...)
end

-function JACC.parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:CuArray}}) where {I<:Integer,F<:Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
Expand All @@ -30,7 +30,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}
end


-function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:CuArray}}) where {I<:Integer,F<:Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Expand Down
8 changes: 4 additions & 4 deletions ext/JACCONEAPI/JACCONEAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ module JACCONEAPI

using JACC, oneAPI

-function JACC.parallel_for(N::I, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:oneArray}}) where {I<:Integer,F<:Function}
#maxPossibleItems = oneAPI.oneL0.compute_properties(device().maxTotalGroupSize)
maxPossibleItems = 256
items = min(N, maxPossibleItems)
groups = ceil(Int, N / items)
oneAPI.@sync @oneapi items = items groups = groups _parallel_for_oneapi(f, x...)
end

-function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:oneArray}}) where {I<:Integer,F<:Function}
maxPossibleItems = 16
Mitems = min(M, maxPossibleItems)
Nitems = min(N, maxPossibleItems)
Expand All @@ -20,7 +20,7 @@ function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:
oneAPI.@sync @oneapi items = (Mitems, Nitems) groups = (Mgroups, Ngroups) _parallel_for_oneapi_MN(f, x...)
end

-function JACC.parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:oneArray}}) where {I<:Integer,F<:Function}
numItems = 256
items = min(N, numItems)
groups = ceil(Int, N / items)
Expand All @@ -31,7 +31,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}
return rret
end

-function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:oneArray}}) where {I<:Integer,F<:Function}
numItems = 16
Mitems = min(M, numItems)
Nitems = min(N, numItems)
Expand Down
8 changes: 4 additions & 4 deletions src/JACC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@ export parallel_for

global Array

-function parallel_for(N::I, f::F, x...) where {I<:Integer,F<:Function}
+function parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) where {I<:Integer,F<:Function}
Threads.@threads :static for i in 1:N
f(i, x...)
end
end

-function parallel_for((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+function parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) where {I<:Integer,F<:Function}
Threads.@threads :static for j in 1:N
for i in 1:M
f(i, j, x...)
end
end
end

-function parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}
+function parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) where {I<:Integer,F<:Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
Threads.@threads :static for i in 1:N
Expand All @@ -34,7 +34,7 @@ function parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function}
return ret
end

-function parallel_reduce((M, N)::Tuple{I,I}, f::F, x...) where {I<:Integer,F<:Function}
+function parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:Base.Array}}) where {I<:Integer,F<:Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
Threads.@threads :static for j in 1:N
Expand Down

0 comments on commit 6d12657

Please sign in to comment.