Merge pull request #139 from YichengDWu/async
docs on pipelining
YichengDWu authored Apr 28, 2024
2 parents a0e232f + 91ba300 commit 80109bd
Showing 8 changed files with 181 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
@@ -25,6 +25,7 @@ makedocs(; modules=[MoYe],
# ]
"TiledCopy & TiledMMA" => "manual/tiled_matmul.md",
"Memcpy Async" => "manual/async.md",
"Pipeline" => "manual/pipeline.md",
"Tensor Cores" => "manual/tensor_core.md",
],
"API Reference" => [
8 changes: 8 additions & 0 deletions docs/src/assets/pipeline.svg
15 changes: 15 additions & 0 deletions docs/src/manual/async.md
@@ -101,6 +101,21 @@ function matmul(A, B, C)
B, sB_layout, copy_B,
C, mma_C)
end


function test()
A = CUDA.randn(Float32, 2048, 256)
B = CUDA.randn(Float32, 2048, 256)
C = CUDA.randn(Float32, 2048, 2048)
matmul(A, B, C)
CUDA.synchronize()
@test C == A * B'
CUDA.unsafe_free!(A)
CUDA.unsafe_free!(B)
CUDA.unsafe_free!(C)
end

test()
```

## Vectorized copy
145 changes: 144 additions & 1 deletion docs/src/manual/pipeline.md
@@ -1,6 +1,6 @@
## Overlap global-to-shared copies with mma compute

We can overlap global-to-shared memory copies with mma compute on registers.
We can overlap global-to-shared memory copies with mma compute.

![](https://developer-blogs.nvidia.com/wp-content/uploads/2020/09/sequence-asynchronous-copy-batches-1.png)
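
The idea is the same as in the figure above: while the mma compute works on the tile that is already resident, the asynchronous copy of the next tile is in flight. As a rough CPU-side analogy (not MoYe code; `pipelined_sum` and the task-based prefetch are purely illustrative), the loop below fetches the next tile on a task while the current one is being reduced:

```julia
# CPU-side analogy of the pipeline: prefetch tile k+1 on a task while tile k is
# being consumed. Purely illustrative; none of these names come from MoYe.
function pipelined_sum(tiles::Vector{Matrix{Float32}})
    buf = copy(tiles[1])          # plays the role of the shared-memory tile
    acc = 0.0f0
    for k in eachindex(tiles)
        # start "copying" the next tile while we work on the current one
        next = k < length(tiles) ? Threads.@spawn(copy(tiles[k + 1])) : nothing
        acc += sum(buf)           # stand-in for the mma work on the current tile
        next === nothing || (buf = fetch(next))
    end
    return acc
end

tiles = [rand(Float32, 4, 4) for _ in 1:8]
@assert pipelined_sum(tiles) ≈ sum(sum, tiles)
```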

@@ -77,4 +77,147 @@ function matmul_kernel(A, sA_layout, copy_A,
copyto!(tCgC, tCrC)
return nothing
end
```

## Double buffer

We can also overlap shared-to-register memory copies with mma compute.

To do this, we allocate two shared-memory buffers: one holds the tile the mma compute is currently consuming, while the other receives the next tile, which is prefetched from global memory asynchronously.

![matmul](../assets/pipeline.svg)
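
Before the full kernel, here is a minimal CPU-side sketch of the bookkeeping alone: the `smem_read`/`smem_write` swap and the prefetch of the next tile. The names and the `simulate_double_buffer` helper are illustrative, not part of MoYe:

```julia
# CPU-side sketch of the two-stage rotation used by the kernel below: the stage
# indexed by `smem_read` is consumed while the one indexed by `smem_write`
# receives the next tile. Purely illustrative; this is not MoYe API.
function simulate_double_buffer(k_tile_max = 4)
    tiles  = [fill(k, 8) for k in 1:k_tile_max]   # fake global-memory tiles
    stages = [zeros(Int, 8), zeros(Int, 8)]       # the two shared-memory stages

    smem_read, smem_write = 1, 2
    copyto!(stages[smem_read], tiles[1])          # prefetch tile 1 before the loop
    acc = 0
    for k_tile in 1:k_tile_max
        current = stages[smem_read]               # compute reads the current stage
        if k_tile < k_tile_max
            copyto!(stages[smem_write], tiles[k_tile + 1])  # prefetch the next tile
            smem_read, smem_write = smem_write, smem_read   # rotate the stages
        end
        acc += sum(current)                       # stand-in for the mma work
    end
    return acc == sum(sum, tiles)                 # totals match: every tile was consumed
end

@assert simulate_double_buffer()
```

The kernel below applies the same rotation to the shared-memory arrays `sA` and `sB`, with the extra wrinkle that the shared-to-register copy is itself pipelined over `k_block`.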

```julia
@views function matmul_kernel(A, sA_layout, copy_A,
B, sB_layout, copy_B,
C, mma_C)
M = size(A, 1)
N = size(B, 1)
K = size(A, 2)

bM = size(sA_layout, 1)
bN = size(sB_layout, 1)
bK = size(sB_layout, 2)

sA = MoYeSharedArray(eltype(A), sA_layout) # (bM, bK, 2)
sB = MoYeSharedArray(eltype(B), sB_layout) # (bN, bK, 2)

mA = MoYeArray(A, (M, K))
mB = MoYeArray(B, (N, K))
mC = MoYeArray(C, (M, N))

gA = @tile mA (bM, bK) (blockIdx().x, :)
gB = @tile mB (bN, bK) (blockIdx().y, :)
gC = @tile mC (bM, bN) (blockIdx().x, blockIdx().y)

# copy partition
thr_copy_a = get_slice(copy_A, threadIdx().x)
tAgA = partition_S(thr_copy_a, gA) # (CPY, CPY_M, CPY_K, k)
tAsA = partition_D(thr_copy_a, sA) # (CPY, CPY_M, CPY_K, 2)

thr_copy_b = get_slice(copy_B, threadIdx().x)
tBgB = partition_S(thr_copy_b, gB) # (CPY, CPY_N, CPY_K, k)
tBsB = partition_D(thr_copy_b, sB) # (CPY, CPY_N, CPY_K, 2)

# Copy gmem to smem for k_tile=1
copyto!(copy_A, tAsA[:, :, :, 1], tAgA[:, :, :, 1])
copyto!(copy_B, tBsB[:, :, :, 1], tBgB[:, :, :, 1])

# mma partition
thr_mma = get_slice(mma_C, threadIdx().x)
tCsA = partition_A(thr_mma, sA) # (MMA, MMA_M, MMA_K, 2)
tCsB = partition_B(thr_mma, sB) # (MMA, MMA_N, MMA_K, 2)
tCgC = partition_C(thr_mma, gC) # (MMA, MMA_M, MMA_N)

# mma registers
tCrA = make_fragment_A(thr_mma, tCsA[:, :, :, 1]) # (MMA, MMA_M, MMA_K)
tCrB = make_fragment_B(thr_mma, tCsB[:, :, :, 1]) # (MMA, MMA_N, MMA_K)
tCrC = make_fragment_C(thr_mma, tCgC) # (MMA, MMA_M, MMA_N)
zeros!(tCrC)

cp_async_wait()
sync_threads()

# Copy smem to rmem for k_block=1
smem_read = 1
smem_write = 2
tCsA_p = view(tCsA, :, :, :, smem_read)
tCsB_p = view(tCsB, :, :, :, smem_read)
copyto!(tCrA[:, :, 1], tCsA_p[:, :, 1])
copyto!(tCrB[:, :, 1], tCsB_p[:, :, 1])

k_tile_max = size(tAgA, 4)
k_block_max = static_size(tCrA, 3)
for k_tile in 1:k_tile_max
@loopinfo unroll for k_block in _1:k_block_max
k_block_next = k_block + 1
if k_block == k_block_max
cp_async_wait()
sync_threads()
tCsA_p = view(tCsA, :, :, :, smem_read)
tCsB_p = view(tCsB, :, :, :, smem_read)
k_block_next = 1
end

copyto!(tCrA[:, :, k_block_next], tCsA_p[:, :, k_block_next])
copyto!(tCrB[:, :, k_block_next], tCsB_p[:, :, k_block_next])

if k_block == _1 && k_tile < k_tile_max
copyto!(copy_A, tAsA[:, :, :, smem_write], tAgA[:, :, :, k_tile+1])
copyto!(copy_B, tBsB[:, :, :, smem_write], tBgB[:, :, :, k_tile+1])
smem_read, smem_write = smem_write, smem_read
end

@gc_preserve gemm!(mma_C, tCrA[:, :, k_block], tCrB[:, :, k_block], tCrC)
end
end

copyto!(tCgC, tCrC)
return nothing
end

function matmul(A, B, C)
bM = _128
bN = _128
bK = _8

sA_layout = make_layout((bM, bK, _2), (_1, bM + _1, (bM + _1) * bK))
sB_layout = make_layout((bN, bK, _2), (_1, bN + _1, (bN + _1) * bK))

TA = eltype(A)
TB = eltype(B)
TC = eltype(C)

copy_A = make_tiled_copy(CopyAtom{UniversalCopy{TA}, TA}(),
@Layout((32, 8)),
@Layout((4, 1)))
copy_B = make_tiled_copy(CopyAtom{UniversalCopy{TB}, TB}(),
@Layout((32, 8)),
@Layout((4, 1)))

mma_C = make_tiled_mma(UniversalFMA{TA, TB, TC}(), # MMA operation
@Layout((16,16))) # Atom layout

threads = Int(size(mma_C))
blocks = (cld(size(A, 1), bM), cld(size(B, 1), bN))

@cuda threads=threads blocks=blocks matmul_kernel(A, sA_layout, copy_A,
B, sB_layout, copy_B,
C, mma_C)
end

function test()
A = CUDA.randn(Float32, 2048, 256)
B = CUDA.randn(Float32, 2048, 256)
C = CUDA.randn(Float32, 2048, 2048)
matmul(A, B, C)
CUDA.synchronize()
@test C == A * B'
CUDA.unsafe_free!(A)
CUDA.unsafe_free!(B)
CUDA.unsafe_free!(C)
end

test()
```
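
To get a feel for the benefit of pipelining, a rough timing sketch is shown below. It assumes the `matmul` defined above is in scope and a CUDA-capable device is available; `CUDA.@elapsed` reports GPU time in seconds, and the sizes are just examples:

```julia
# Rough timing sketch, not a rigorous benchmark. Assumes the `matmul` defined
# above is in scope and that a CUDA device is available.
using CUDA

A = CUDA.randn(Float32, 2048, 256)
B = CUDA.randn(Float32, 2048, 256)
C = CUDA.zeros(Float32, 2048, 2048)

matmul(A, B, C)                    # warm-up run (includes compilation)
t = CUDA.@elapsed matmul(A, B, C)  # GPU time of one launch, in seconds
println("matmul: $(round(t * 1e3; digits = 3)) ms")

CUDA.unsafe_free!(A); CUDA.unsafe_free!(B); CUDA.unsafe_free!(C)
```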
4 changes: 2 additions & 2 deletions src/MoYe.jl
@@ -9,7 +9,7 @@ using StrideArraysCore: static_length, static_size, static_axes
using StrideArraysCore: @gc_preserve
using CUDA, BFloat16s, LLVM
using CUDA: @device_override
using LLVMLoopInfo
using LLVMLoopInfo: @loopinfo
using Core: LLVMPtr
import Adapt
using MacroTools: @capture
@@ -59,7 +59,7 @@ include("device/smem.jl")
include("print.jl")

# reexport
export static, @gc_preserve, static_size
export static, @gc_preserve, static_size, @loopinfo

# tuple algorithms
export flatten, unflatten
4 changes: 2 additions & 2 deletions src/arch/copy/copy_async.jl
@@ -42,10 +42,10 @@ end
end

"""
cp_async_wait(i::Int32)
cp_async_wait(N::Int32)
cp_async_wait()
`cp.async.wait.group` and `cp.async.wait.all`.
`cp_async_wait(N)` is equivalent to `cp.async.wait.group(N)` and `cp_async_wait()` is equivalent to `cp.async.wait.all` in CUDA.
"""
@inline cp_async_wait(i::Int32) = ccall("llvm.nvvm.cp.async.wait.group", llvmcall, Cvoid, (Int32,), i)
@inline cp_async_wait() = ccall("llvm.nvvm.cp.async.wait.all", llvmcall, Cvoid, ())
8 changes: 6 additions & 2 deletions src/array.jl
@@ -131,7 +131,7 @@ layout(::Type{<:StaticMoYeArray{T,N,E,L}}) where {T,N,E,L} = L

# static interface
@inline StaticArrayInterface.static_size(x::StaticMoYeArray) = map(capacity, shape(layout(x)))
@inline StaticArrayInterface.static_size(x::StaticMoYeArray, i::IntType) = size(layout(x), i)
@inline StaticArrayInterface.static_size(x::A, i::Union{Int, StaticInt}) where {A<:StaticMoYeArray} = size(layout(x), i)

@inline function StaticArrayInterface.static_axes(x::StaticMoYeArray{T,N,<:ViewEngine}) where {T,N}
return map(Base.oneto, static_size(x))
@@ -160,7 +160,7 @@ end
Return a pointer to the element at the logical index `i` in `A`, not the physical index.
"""
@inline function Base.pointer(x::MoYeArray{T}, i::IntType) where {T}
idx = x.layout(convert(Int, i))
idx = x.layout(i)
return pointer(x) + (idx-one(idx))*sizeof(T)
end
@inline function Base.pointer(x::MoYeArray{T}, coord::Tuple) where {T}
@@ -294,6 +294,10 @@ julia> x3 = recast(Int64, x)
@gc_preserve _recast(NewType, x)
end

@inline function recast(::Type{OldType}, x::MoYeArray{OldType}) where {OldType}
return x
end

function _recast(::Type{NewType}, x::MoYeArray{OldType}) where {NewType, OldType}
@inline
old_layout = layout(x)
6 changes: 3 additions & 3 deletions src/atom/mma.jl
@@ -116,7 +116,7 @@ end
return :($thr_layout_vmnk_)
end

function thrfrg_C(m::TiledMMA, C::Layout{2})
function thrfrg_C(m::TiledMMA, C::Layout)
thr_layout_vmnk = get_thr_layout_vmnk(m)
atom_mnk = shape_mnk(m.atom)
permutation_mnk = m.permutation_mnk
@@ -135,7 +135,7 @@ function thrfrg_C(m::TiledMMA, C::Layout{2})
return thr_array
end

function thrfrg_A(m::TiledMMA, A::Layout{2})
function thrfrg_A(m::TiledMMA, A::Layout)
thr_layout_vmnk = get_thr_layout_vmnk(m)
atom_mnk = shape_mnk(m.atom)
permutation_mnk = m.permutation_mnk
@@ -151,7 +151,7 @@ function thrfrg_A(m::TiledMMA, A::Layout{2})
return thr_array
end

function thrfrg_B(m::TiledMMA, B::Layout{2})
function thrfrg_B(m::TiledMMA, B::Layout)
thr_layout_vmnk = get_thr_layout_vmnk(m)
atom_mnk = shape_mnk(m.atom)
permutation_mnk = m.permutation_mnk
