Separate kernel for GemmStridedBatchedEx (#163)
2D blocking works reasonably quickly for the other kernels, but caused a
slowdown with GemmStridedBatchedEx because the work was not split
optimally. Changing the constants BM, BN, BK helps solve this.
ahgamut authored Jan 3, 2024
1 parent e0127f7 commit c2bc6e6
Showing 1 changed file with 72 additions and 1 deletion: llamafile/tinyblas.cu
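
For orientation: the new kernel below gives each thread block one BM x BN tile of the output and reduces over the shared dimension in BK-wide chunks, so these three constants decide how finely the batched work is split across blocks. A minimal sketch of that arithmetic (illustrative only; ceil_div is a hypothetical helper, not code from this commit):

// Tile constants introduced by this commit for the strided-batched kernel.
constexpr int BM = 32, BN = 4, BK = 32;

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Independent output tiles for one GemmStridedBatchedEx call:
//   tiles = ceil_div(m, BM) * ceil_div(n, BN) * batchCount
// e.g. m = 128, n = 8, batchCount = 32 gives 4 * 2 * 32 = 256 tiles,
// each one reduced over ceil_div(k, BK) steps of the inner k loop.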
@@ -346,6 +346,77 @@ cublasStatus_t tinyblasGemmBatchedEx(cudaStream_t stream,
}

// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
#undef BM
#undef BN
#undef BK
#define BM 32
#define BN 4
#define BK 32
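// Per-block scratch implied by these tile sizes (assuming As and Bs live
// in shared memory, as the __syncthreads() barriers below suggest):
//   As: BM * BK floats = 32 * 32 * 4 bytes = 4 KiB
//   Bs: BK * BN floats = 32 *  4 * 4 bytes = 512 bytes
// plus a per-thread accumulator Cs of BN = 4 floats.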

static __device__ void matmul_block2d_sb(int m, int n, int k, int x, int y,
                                         const half *A, int lda, float *As,
                                         const half *B, int ldb, float *Bs,
                                         void *C, cudaDataType_t Ctype, int ldc,
                                         float *Cs) {
    assert(blockDim.x == BM);
    static_assert(BK <= BM);
    const int i = threadIdx.x;
    int j, l, blob;
    // within each block
    // we first zero out Cs
    for (j = 0; j < BN; ++j) Cs[j] = 0;

    for (blob = 0; blob < k; blob += BK) {
        // we copy into As from A
        if (i < BM) {
            if ((x + i) < m) {
                for (j = 0; j < BK && blob + j < k; ++j) {
                    As[(i * BK) + j] =
                        READ16(A, CUBLAS_OP_T, lda, x + i, blob + j);
                }
                for (; j < BK; ++j) As[(i * BK) + j] = 0;
            } else {  // UNLIKELY
                for (j = 0; j < BK; ++j) As[(i * BK) + j] = 0;
            }
        }

        // we copy into Bs from B
        if (i < BK) {
            if ((blob + i) < k) {
                for (j = 0; j < BN && y + j < n; ++j) {
                    Bs[(i * BN) + j] =
                        READ16(B, CUBLAS_OP_N, ldb, blob + i, y + j);
                }
                for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
            } else {  // UNLIKELY
                for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
            }
        }
        __syncthreads();

        // We matmul the blobs, basically Cs += matmul(As, Bs)
        for (j = 0; j < BN; ++j) {
            for (l = 0; l < BK; ++l) {
                Cs[j] += As[(i * BK) + l] * Bs[(l * BN) + j];
            }
        }
        __syncthreads();
    }

    // We write Cs out into C
    if (x + i < m) {
        if (Ctype == CUDA_R_16F) {
            for (j = 0; j < BN && y + j < n; ++j) {
                *((half *)C + (x + i) + (y + j) * ldc) = __float2half(Cs[j]);
            }
        } else {
            for (j = 0; j < BN && y + j < n; ++j) {
                *((float *)C + (x + i) + (y + j) * ldc) = Cs[j];
            }
        }
    }
    __syncthreads();
}

static __global__ void tinyblasGSBE_entry(int m, int n, int k,
                                          const half *A,
@@ -376,7 +447,7 @@ static __global__ void tinyblasGSBE_entry(int m, int n, int k,
     for (z = blockIdx.z; z < batchCount; z += jump3) {
         for (x = blockIdx.x * BM; x < m; x += jump1) {
             for (y = blockIdx.y * BN; y < n; y += jump2) {
-                matmul_block2d(m, n, k, x, y,  //
+                matmul_block2d_sb(m, n, k, x, y,  //
                                   A + z * strideA, lda, As,  //
                                   B + z * strideB, ldb, Bs,  //
                                   (Ctype == CUDA_R_16F
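
The call-site hunk above only swaps matmul_block2d for the new matmul_block2d_sb. For readers without the full file, here is a hypothetical launch shape consistent with the kernel's assumptions (assert(blockDim.x == BM) plus the x/y/z grid-stride loops); the grid caps and the elided argument list are illustrative, not the actual code in tinyblas.cu:

dim3 threads(BM, 1, 1);  // kernel asserts blockDim.x == BM
dim3 blocks(std::min((m + BM - 1) / BM, 64),  // tiles along m (cap is made up)
            std::min((n + BN - 1) / BN, 16),  // tiles along n (cap is made up)
            std::min(batchCount, 32));        // batches       (cap is made up)
// tinyblasGSBE_entry<<<blocks, threads, 0, stream>>>(...);
//   (argument list elided; see the full tinyblas.cu)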
