CUDA: optimize MMQ int8 tensor core performance (#8062)
* CUDA: optimize MMQ int8 tensor core performance

* only a single get_mma_tile_x_k function

* simplify code, make functions constexpr
JohannesGaessler authored Jun 24, 2024
1 parent 52fc870 commit 9a590c8
Showing 3 changed files with 879 additions and 547 deletions.
4 changes: 2 additions & 2 deletions ggml-cuda/common.cuh
@@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S;
};

-static int get_mmq_x_max_host(const int cc) {
+static constexpr int get_mmq_x_max_host(int cc) {
#ifdef CUDA_USE_TENSOR_CORES
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
#else
@@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
}

// Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc) {
+static constexpr int get_mmq_y_host(int cc) {
return cc >= CC_VOLTA ? 128 : 64;
}
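
For context: making these helpers constexpr means the tile sizes can be evaluated at compile time whenever the compute-capability argument is itself a constant. A minimal illustration of what that enables (hypothetical usage, not code from this commit):

// Hypothetical usage sketch, assuming the header above is included so that
// get_mmq_y_host and CC_VOLTA are visible; with the functions now constexpr,
// their results can feed compile-time contexts such as static_assert.
#include "common.cuh"

static_assert(get_mmq_y_host(CC_VOLTA)     == 128, "rows should round to 128 on Volta and newer");
static_assert(get_mmq_y_host(CC_VOLTA - 1) ==  64, "and to 64 on older GPUs");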

56 changes: 56 additions & 0 deletions ggml-cuda/mma.cuh
@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
GGML_CUDA_ASSUME(ret < K);
return ret;
}

__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
};
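
The inline-PTX path above is the heart of the A-tile optimization: instead of each thread issuing its own scalar loads, the warp executes one ldmatrix.sync.aligned instruction. Each lane passes an address derived from its lane index (threadIdx.x%I picks a row of the tile, threadIdx.x/I an offset within it), and the instruction distributes the loaded data across the x[] registers of all 32 lanes; the x4 variant in mma_int_A_I16K8 below does the same for the wider 16x8 tile. A hedged usage sketch follows (the kernel, buffer layout, and names are illustrative, not taken from mmq.cu):

// Illustrative only: load one 16x4 tile of packed int8 data (4 values per int)
// from a row-major buffer into an A fragment. All 32 threads of the warp
// must call load() together.
#include "mma.cuh"

static __global__ void load_a_fragment_example(const int * __restrict__ tile, const int stride) {
    mma_int_A_I16K4 A;
    A.load(tile, stride); // uses ldmatrix when INT8_MMA_AVAILABLE, scalar loads otherwise
    // A.x[] now holds this lane's part of the fragment, ready for the int8 MMA.
}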

struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
GGML_CUDA_ASSUME(ret < K);
return ret;
}

__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
};

struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
GGML_CUDA_ASSUME(ret < K);
return ret;
}

__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(x[0])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
};

struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
GGML_CUDA_ASSUME(ret < K);
return ret;
}

__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
};
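
Note the "&& false" on both B fragments: the ldmatrix path is written out but deliberately compiled out because, per the in-code comment, plain 4-byte loads are faster here, so the generic per-element loop is what actually runs. A matching illustrative sketch for the B side (again hypothetical, not code from mmq.cu):

// Illustrative only: load one 8-row B tile of packed int8 data (8 ints per row)
// into a B fragment. With the ldmatrix path disabled, each lane performs its
// own 32-bit loads at the offsets given by get_j()/get_k().
#include "mma.cuh"

static __global__ void load_b_fragment_example(const int * __restrict__ tile, const int stride) {
    mma_int_B_J8K8 B;
    B.load(tile, stride); // per-lane 4-byte loads; the ldmatrix branch is compiled out
    // B.x[] now holds this lane's part of the fragment.
}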

struct mma_int_C_I16J8 {