This repository was archived by the owner on Oct 11, 2024. It is now read-only.

[Bugfix] Fix marlin 2:4 kernel crash on H100 #243

Merged: 1 commit, May 16, 2024

[Bugfix] Fix marlin 2:4 kernel crash on H100
mgoin authored May 15, 2024
commit 2460ca3b90beec28127083d1fa61dec131c1ab0c

csrc/quantization/marlin/sparse/common/mem.h (5 additions, 11 deletions)
@@ -45,19 +45,13 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
   );
 }
 
-// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for
-// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
-// for inputs A and outputs C.
-__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
   const int BYTES = 16;
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-      "{\n"
-      "   .reg .b64 p;\n"
-      "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
-      "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
-      "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
-  );
+  asm volatile("{\n"
+               "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+               "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
 }
 
 // Async copy fence.
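
For context: the deleted helper built an L2 evict_first eviction policy with createpolicy.fractional and attached it to the copy via L2::cache_hint; it is this cache-hint sequence that the PR removes to stop the H100 crash, which per the deleted comment means the once-read B weights may now occupy L2. The same 16-byte global->shared async copy can also be expressed without inline PTX through CUDA's pipeline primitives; the sketch below is purely illustrative (the helper name cp_async4_portable is invented here) and is not part of the change:

#include <cuda_pipeline.h>

// Purely illustrative, not from this PR: the same 16-byte asynchronous
// global->shared copy via CUDA's pipeline primitives (CUDA 11+), avoiding
// hand-written cp.async PTX. The helper name is invented.
__device__ inline void cp_async4_portable(void *smem_ptr, const void *glob_ptr) {
  // 16 bytes, matching the BYTES constant in cp_async4 above; both pointers
  // must be 16-byte aligned, as with the PTX version. Completion still
  // requires __pipeline_commit()/__pipeline_wait_prior() (or the kernel's
  // own fence/wait helpers).
  __pipeline_memcpy_async(smem_ptr, glob_ptr, 16);
}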

csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu (5 additions, 5 deletions)
@@ -392,7 +392,7 @@ __global__ void Marlin_24(
   for (int i = 0; i < b_sh_wr_iters; i++) {
 #pragma unroll
     for (int j = 0; j < b_thread_vecs; j++) {
-      cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
+      cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
                 B_ptr[i] + j);
     }
     B_ptr[i] += b_gl_rd_delta_o;
@@ -401,15 +401,15 @@ __global__ void Marlin_24(
 #pragma unroll
   for (int i = 0; i < m_sh_iters; i++) {
     if (m_sh_wr_pred)
-      cp_async4_stream(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
+      cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
                 meta_ptr[i]);
     meta_ptr[i] += m_gl_rd_delta_o;
   }
   // Only fetch scales if this tile starts a new group
   if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
     int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
     if (s_sh_wr_pred)
-      cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+      cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
     s_gl_rd += s_gl_rd_delta;
   }
 }
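
To make the scale-fetch gating concrete with hypothetical numbers (not taken from this kernel's launch configurations): if group_blocks were 8 and thread_k_blocks were 4, then group_blocks / thread_k_blocks is 2, so scales are fetched only on stages where pipe % 2 == 0 (pipe = 0, 2, 4, ...); the intermediate stages reuse the scales already staged in shared memory.
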
@@ -763,12 +763,12 @@ __global__ void Marlin_24(
   if constexpr (group_blocks == -1) {
     if constexpr (num_bits == 8) {
       if (s_sh_wr_pred)
-        cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+        cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
       cp_async_fence();
     } else {
       if (last) {
         if (s_sh_wr_pred)
-          cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+          cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
         cp_async_fence();
       }
     }
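
Every call site above follows the same staging discipline: issue the cp_async4 copies for one pipeline stage, then publish them as a group with cp_async_fence() (the fence helper visible at the end of the mem.h diff). A minimal sketch of that pattern, assuming a cp_async_wait<N> wrapper over cp.async.wait_group as in the Marlin sources; the STAGES constant, function names, and parameters here are hypothetical:

// Minimal sketch of the fetch/commit/wait pipeline; the identifiers below
// are assumptions for illustration, not code from this PR.
constexpr int STAGES = 4;  // hypothetical pipeline depth

__device__ void fetch_stage(int4 *sh_stage, const int4 *gmem, int vecs) {
  for (int v = threadIdx.x; v < vecs; v += blockDim.x)
    cp_async4(&sh_stage[v], &gmem[v]);  // async 16B copies, as in the diff
  cp_async_fence();                     // commit this stage as one group
}

__device__ void wait_for_stage() {
  // Block until at most STAGES - 2 copy groups remain in flight, i.e. the
  // oldest committed stage has landed in shared memory; cp_async_wait<N> is
  // assumed to wrap "cp.async.wait_group N" as in the Marlin sources.
  cp_async_wait<STAGES - 2>();
  __syncthreads();
}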