Skip to content

Commit

Permalink
Merge pull request #497 from IanButterworth/ib/task_local
Browse files Browse the repository at this point in the history
move from bad thread-local to task-local
  • Loading branch information
ToucheSir authored Jun 17, 2023
2 parents a3cdee6 + f30b3dd commit a5fbf95
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 50 deletions.
54 changes: 32 additions & 22 deletions src/impl/conv_im2col.jl
Original file line number Diff line number Diff line change
"""
    conv_im2col!(y, x, w, cdims; col, alpha=1, beta=0, ntasks=nthreads())

Perform a convolution via the im2col + GEMM strategy, writing into `y` and
returning it.  The batch dimension is partitioned into at most `ntasks`
contiguous chunks, each processed by a spawned task.  Each task uses its own
slice of the `col` workspace, indexed by task number rather than `threadid()`,
so the buffer remains valid even if the task migrates between threads.

# NOTE(review): assumes `size(col, 3)` is at least the number of partitions
# (≤ `ntasks`) — verify `im2col_dims` sizes the workspace accordingly.
"""
function conv_im2col!(
                y::AbstractArray{T,5}, x::AbstractArray{T,5},
                w::AbstractArray{T,5}, cdims::DenseConvDims;
                col::AbstractArray{T,3}=similar(x, im2col_dims(cdims)),
                alpha::T=T(1), beta::T=T(0),
                ntasks::Int=nthreads()) where {T}
    check_dims(size(x), size(w), size(y), cdims)

    # COL * W -> Y
    # [M x K] * [K x N] -> [M x N]
    M = prod(output_size(cdims))
    N = channels_out(cdims)
    K = prod(kernel_size(cdims))*channels_in(cdims)

    # Chunk size must be >= 1: `Iterators.partition` throws on 0, which the
    # old `ceil(Int, size(x, 5) / ntasks)` produced for an empty batch.
    parts = Iterators.partition(axes(x, 5), max(1, cld(size(x, 5), ntasks)))

    @sync for (task_n, part) in enumerate(parts)
        Threads.@spawn begin
            # col_slice is a task-local workspace (fixed duplicated
            # `col_slice = col_slice =` assignment from the original).
            col_slice = view(col, :, :, task_n)
            for batch_idx in part
                im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
                GC.@preserve col_slice w y begin
                    col_ptr = pointer(col_slice)
                    w_ptr = pointer(w)
                    y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
                    gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
                end
            end
        end
    end
    return y
end
"""
    ∇conv_data_im2col!(dx, dy, w, cdims; col, alpha=1, beta=0, ntasks=nthreads())

Compute the gradient of a convolution with respect to its data input,
writing into `dx` and returning it.  Batches are partitioned across up to
`ntasks` spawned tasks; each task works in its own slice of `col`, indexed
by task number (not `threadid()`) so it is safe under task migration.

# NOTE(review): the `beta` keyword is accepted but not forwarded — the
# scratch GEMM deliberately uses `T(0)` because `col_slice` is overwritten
# each iteration; confirm callers do not expect `beta`-accumulation into `dx`.
"""
function ∇conv_data_im2col!(
                dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
                w::AbstractArray{T,5}, cdims::DenseConvDims;
                col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),
                alpha::T=T(1), beta::T=T(0),
                ntasks::Int=nthreads()) where {T}
    check_dims(size(dx), size(w), size(dy), cdims)

    # dY W' -> dX
    # [M x K] * [K x N] -> [M x N]
    M = prod(output_size(cdims))
    N = prod(kernel_size(cdims))*channels_in(cdims)
    K = channels_out(cdims)

    # Guard against a zero chunk size when the batch dimension is empty.
    parts = Iterators.partition(axes(dx, 5), max(1, cld(size(dx, 5), ntasks)))

    @sync for (task_n, part) in enumerate(parts)
        Threads.@spawn begin
            # col_slice is a task-local workspace (fixed duplicated
            # `col_slice = col_slice =` assignment from the original).
            col_slice = view(col, :, :, task_n)
            for batch_idx in part
                GC.@preserve col_slice w dy begin
                    dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
                    w_ptr = pointer(w)
                    col_ptr = pointer(col_slice)
                    gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
                end
                # Scatter the column buffer back into the data gradient.
                col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
            end
        end
    end
    return dx
end
Expand Down
67 changes: 39 additions & 28 deletions src/impl/depthwiseconv_im2col.jl
Original file line number Diff line number Diff line change
"""
    depthwiseconv_im2col!(y, x, w, cdims; col, alpha=1, beta=0, ntasks=nthreads())

Perform a depthwise convolution via im2col + GEMM, writing into `y` and
returning it.  This works like `conv_im2col!`, except the GEMM is sharded
across input channels: one small GEMM per channel per batch element.
Batches are split into at most `ntasks` chunks, each handled by a spawned
task with its own task-local slice of `col`.
"""
function depthwiseconv_im2col!(
                y::AbstractArray{T,5}, x::AbstractArray{T,5},
                w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
                col::AbstractArray{T,3} = similar(x, im2col_dims(cdims)),
                alpha::T=T(1), beta::T=T(0),
                ntasks::Int=nthreads()) where T
    check_dims(size(x), size(w), size(y), cdims)

    # Per-channel GEMM dimensions.
    M = prod(output_size(cdims))
    N = channel_multiplier(cdims)
    K = prod(kernel_size(cdims))

    # Partition the batch dimension (last axis of `y`, shared with `x`);
    # `max(1, ...)` avoids a zero chunk size for an empty batch.
    parts = Iterators.partition(axes(y)[end], max(1, cld(size(y, 5), ntasks)))

    dcdims = DenseConvDims(cdims)

    @sync for (task_n, part) in enumerate(parts)
        Threads.@spawn begin
            # col_slice is a task-local workspace (fixed duplicated
            # `col_slice = col_slice =` assignment from the original).
            col_slice = view(col, :, :, task_n)
            for batch_idx in part
                im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)

                # One independent convolution per input channel, walking each
                # pointer forward as we process the channels.
                for c_in in 1:channels_in(cdims)
                    GC.@preserve col_slice w y begin
                        col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
                        w_ptr = pointer(w, (c_in-1)*K*N+1)
                        y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
                        gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
                    end
                end
            end
        end
    end
    # NOTE(review): the function tail was collapsed in the diff view; the
    # return mirrors conv_im2col! — confirm against the full file.
    return y
end
"""
    ∇depthwiseconv_data_im2col!(dx, dy, w, cdims; col, alpha=1, beta=0, ntasks=nthreads())

Compute the gradient of a depthwise convolution with respect to its data
input, writing into `dx` and returning it.  One GEMM is performed per input
channel per batch element; batches are split across up to `ntasks` spawned
tasks, each with its own task-local slice of `col`.

# NOTE(review): the `beta` keyword is accepted but not forwarded — the
# scratch GEMM deliberately uses `T(0)` because `col_slice` is rebuilt each
# iteration; confirm callers do not expect `beta`-accumulation into `dx`.
"""
function ∇depthwiseconv_data_im2col!(
                dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
                w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
                col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),
                alpha::T=T(1), beta::T=T(0),
                ntasks::Int=nthreads()) where T
    check_dims(size(dx), size(w), size(dy), cdims)

    # Per-channel GEMM dimensions.
    M = prod(output_size(cdims))
    N = prod(kernel_size(cdims))
    K = channel_multiplier(cdims)

    # Guard against a zero chunk size when the batch dimension is empty.
    parts = Iterators.partition(axes(dx)[end], max(1, cld(size(dx, 5), ntasks)))

    @sync for (task_n, part) in enumerate(parts)
        Threads.@spawn begin
            # col_slice is a task-local workspace (fixed duplicated
            # `col_slice = col_slice =` assignment from the original).
            col_slice = view(col, :, :, task_n)
            for batch_idx in part
                # One independent backward GEMM per input channel, walking
                # each pointer forward as we process the channels.
                for cidx in 1:channels_in(cdims)
                    GC.@preserve col_slice w dy begin
                        dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
                        w_ptr = pointer(w, (cidx - 1)*K*N + 1)
                        col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
                        gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
                    end
                end
                # Scatter the column buffer back into the data gradient.
                col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims)
            end
        end
    end
    return dx
end

0 comments on commit a5fbf95

Please sign in to comment.