diff --git a/llamafile/tinyblas.cu b/llamafile/tinyblas.cu
index 19779f6427..c320be954c 100644
--- a/llamafile/tinyblas.cu
+++ b/llamafile/tinyblas.cu
@@ -23,88 +23,118 @@
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
-template <int BM, int BN, int BK>
-static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,
-                                        const float *A, int lda, float *As,
-                                        const float *B, int ldb, float *Bs,
-                                        void *C, int ldc, float *Cs) {
-    assert(blockDim.x == BK);
-    static_assert(BK == BM, "");
-    static_assert(BN <= BM, "");
-    const int i = threadIdx.x;
-    int j, l, blob;
+template <typename SRC, typename DST>
+__device__ __forceinline__ DST typechange(SRC x) {
+    static_assert(std::is_same<SRC, DST>::value, "write a specialization");
+    if (std::is_same<SRC, DST>::value) {
+        return x;
+    } else {
+        return (DST)(0);
+    }
+}
+
+template<>
+__device__ __forceinline__ float typechange(half x) {
+    return __half2float(x);
+}
+
+template<>
+__device__ __forceinline__ half typechange(float x) {
+    return __float2half(x);
+}
+
+template <int BM, int BN, int BK, int TM, int TN, typename SRC, typename DST>
+static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
+                                      const SRC *A, int lda,
+                                      const SRC *B, int ldb,
+                                      DST *C, int ldc) {
+    static_assert((BM % TM == 0) && (BN % TN == 0),
+                  "can't divide work for threads");
+    static_assert(BK == ((BM * BN) / (TM * TN)), // optional
+                  "threads can't load memory properly");
+    static_assert((BM * BN) <= (BM * BK) + (BK * BN),
+                  "didn't allocate enough shared mem for threads");
+    const int ii0 = threadIdx.x / (BN / TN); /* {0, ..., (BM/TM) - 1} */
+    const int ii1 = threadIdx.x % (BN / TN); /* {0, ..., (BN/TN) - 1} */
+    extern __shared__ float svals[]; // shared across all threads in a block
+    SRC *As = (SRC *) svals;
+    SRC *Bs = (SRC *) svals + BM * BK;
+
+    SRC Cs[TM * TN];
+    SRC At[TM];
+    SRC Bt[TN];
+    int i, h, j, l, blob;
     // within each block
     // we first zero out Cs
-    for (j = 0; j < BN; ++j) Cs[j] = 0;
+    for (j = 0; j < TM * TN; ++j) Cs[j] = 0;
 
     for (blob = 0; blob < k; blob += BK) {
-        if (i < BK) {
-            if ((blob + i) < k) {
-                // we copy into As from A
-                for (j = 0; j < BM && x + j < m; ++j) {
-                    As[(j * BK) + i] =
-                        READ(A, TINYBLAS_OP_T, lda, x + j, blob + i);
-                }
-                for (; j < BM; ++j) As[(j * BK) + i] = 0;
-                // we copy into Bs from B
-                for (j = 0; j < BN && y + j < n; ++j) {
-                    Bs[(i * BN) + j] =
-                        READ(B, TINYBLAS_OP_N, ldb, blob + i, y + j);
-                }
-                for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
-            } else { // UNLIKELY
-                for (j = 0; j < BM; ++j) As[(j * BK) + i] = 0;
-                for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
+        for (i = threadIdx.x; i < BK; i += blockDim.x) {
+            for (j = 0; j < BM + BN; ++j) {
+                As[(j * BK) + i] = 0;
             }
         }
         __syncthreads();
+        for (i = threadIdx.x; i < BK && blob + i < k; i += blockDim.x) {
+            // we copy into As from A
+            for (j = 0; j < BM && x + j < m; ++j) {
+                As[(i * BM) + j] = READ(A, TINYBLAS_OP_T, lda, x + j, blob + i);
+            }
+        }
+        for (i = threadIdx.x; i < BK && blob + i < k; i += blockDim.x) {
+            // we copy into Bs from B
+            for (j = 0; j < BN && y + j < n; ++j) {
+                Bs[(i * BN) + j] = READ(B, TINYBLAS_OP_N, ldb, blob + i, y + j);
+            }
+        }
+        __syncthreads();
 
         // We matmul the blobs, basically Cs += matmul(As, Bs)
-        for (j = 0; j < BN; ++j) {
-            for (l = 0; l < BK; ++l) {
-                Cs[j] += As[(i * BK) + l] * Bs[(l * BN) + j];
+        for (l = 0; l < BK; ++l) {
+            for (j = 0; j < TM; ++j) At[j] = As[(l * BM) + ii0 * TM + j];
+            for (h = 0; h < TN; ++h) Bt[h] = Bs[(l * BN) + ii1 * TN + h];
+            for (j = 0; j < TM; ++j) {
+                for (h = 0; h < TN; ++h) {
+                    Cs[j * TN + h] += At[j] * Bt[h];
+                }
             }
         }
         __syncthreads();
     }
-
-    for (j = 0; j < BN; ++j) {
-        As[(i*BN) + j] = Cs[j];
-    }
+    __syncthreads();
 
     // We write Cs out into C
-    if (y + i < n && i < BN) {
-        for (j = 0; j < BM && x + j < m; ++j) {
-            *((float *)C + (x + j) + (y + i) * ldc) = As[j*BN + i];
+    x += ii0 * TM;
+    y += ii1 * TN;
+    for (j = 0; j < TM && x + j < m; ++j) {
+        for (l = 0; l < TN && y + l < n; ++l) {
+            *(C + (x + j) + (y + l) * ldc) = typechange<SRC, DST>(Cs[j * TN + l]);
         }
     }
     __syncthreads();
 }
 
-template <int BM, int BN, int BK>
+template <int BM, int BN, int BK, int TM, int TN>
 static __global__ void tinyblasS_entry(int m, int n, int k,
                                        const float *A, int lda,
                                        const float *B, int ldb,
                                        float *C, int ldc) {
+    assert(blockDim.x == BK);
     int x = blockIdx.x * BM;
     const int jump1 = gridDim.x * BM;
     int y = blockIdx.y * BN;
     const int jump2 = gridDim.y * BN;
 
-    extern __shared__ float svals[]; // shared across all threads in a block
-    float *As = svals;
-    float *Bs = svals + BM * BK;
-    float Cs[BN]; // only within a particular thread
-
     // each block handles a sub-matrix of C, of size BM * BN
-    // each thread handles a sub-row of size BN
+    // each thread handles a sub-matrix of size TM * TN
    for (x = blockIdx.x * BM; x < m; x += jump1) {
         for (y = blockIdx.y * BN; y < n; y += jump2) {
-            matmul32_block2d<BM, BN, BK>(m, n, k, x, y,  //
-                                         A, lda, As,     //
-                                         B, ldb, Bs,     //
-                                         C, ldc, Cs);
+            matmul_block2d<BM, BN, BK, TM, TN, float, float>(
+                m, n, k, x, y,  //
+                A, lda,         //
+                B, ldb,         //
+                C, ldc);
         }
     }
 }
@@ -127,14 +157,14 @@ static bool check_args(tinyblasOperation_t transa, tinyblasOperation_t transb,
                 *(float *)pBeta == 0.0f)));
 }
 
-template <int BM, int BN, int BK>
+template <int BM, int BN, int BK, int TM, int TN>
 static void tinyblasS_wrapper(tinyblasHandle_t stream, int m, int n, int k,
                               const float *A, int lda, const float *B, int ldb,
                               float *C, int ldc) {
     dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
-    int maxthreads = BK;
+    int maxthreads = ((BM * BN) / (TM * TN));
 
-    tinyblasS_entry<BM, BN, BK>
+    tinyblasS_entry<BM, BN, BK, TM, TN>
         <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
            stream>>>(m, n, k, A, lda, B, ldb, C, ldc);
 }
@@ -153,113 +183,48 @@ tinyblasStatus_t tinyblasSgemm(tinyblasHandle_t stream,
         return TINYBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    tinyblasS_wrapper<48, 12, 48>(stream, m, n, k, A, lda, B, ldb, C, ldc);
+    tinyblasS_wrapper<48, 24, 64, 6, 3>(stream, m, n, k, A, lda, B, ldb, C, ldc);
     return TINYBLAS_STATUS_SUCCESS;
 }
 
-template <int BM, int BN, int BK>
-static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
-                                      const half *A, int lda, float *As,
-                                      const half *B, int ldb, float *Bs,
-                                      void *C, cudaDataType_t Ctype, int ldc,
-                                      float *Cs) {
-    assert(blockDim.x == BK);
-    static_assert(BK == BM, "");
-    static_assert(BN <= BM, "");
-    const int i = threadIdx.x;
-    int j, l, blob;
-    // within each block
-    // we first zero out Cs
-    for (j = 0; j < BN; ++j) Cs[j] = 0;
-
-    for (blob = 0; blob < k; blob += BK) {
-        if (i < BK) {
-            if ((blob + i) < k) {
-                // we copy into As from A
-                for (j = 0; j < BM && x + j < m; ++j) {
-                    As[(j * BK) + i] =
-                        READ16(A, TINYBLAS_OP_T, lda, x + j, blob + i);
-                }
-                for (; j < BM; ++j) As[(j * BK) + i] = 0;
-                // we copy into Bs from B
-                for (j = 0; j < BN && y + j < n; ++j) {
-                    Bs[(i * BN) + j] =
-                        READ16(B, TINYBLAS_OP_N, ldb, blob + i, y + j);
-                }
-                for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
-            } else { // UNLIKELY
-                for (j = 0; j < BM; ++j) As[(j * BK) + i] = 0;
-                for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
-            }
-        }
-        __syncthreads();
-
-        // We matmul the blobs, basically Cs += matmul(As, Bs)
-        for (j = 0; j < BN; ++j) {
-            for (l = 0; l < BK; ++l) {
-                Cs[j] += As[(i * BK) + l] * Bs[(l * BN) + j];
-            }
-        }
-        __syncthreads();
-    }
-
-    for (j = 0; j < BN; ++j) {
-        As[(i*BN) + j] = Cs[j];
-    }
-
-    // We write Cs out into C
-    if (y + i < n && i < BN) {
-        if (Ctype == CUDA_R_16F) {
-            for (j = 0; j < BM && x + j < m; ++j) {
-                *((half *)C + (x + j) + (y + i) * ldc) = __float2half(As[j*BN + i]);
-            }
-        } else {
-            for (j = 0; j < BM && x + j < m; ++j) {
-                *((float *)C + (x + j) + (y + i) * ldc) = As[j*BN + i];
-            }
-        }
-    }
-    __syncthreads();
-}
-
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
-template <int BM, int BN, int BK>
+template <int BM, int BN, int BK, int TM, int TN, typename DST>
 static __global__ void tinyblasGE_entry(int m, int n, int k,
                                         const half *A, int lda,
                                         const half *B, int ldb,
-                                        void *C, cudaDataType_t Ctype,
-                                        int ldc) {
+                                        void *C, int ldc) {
     int x = blockIdx.x * BM;
     const int jump1 = gridDim.x * BM;
     int y = blockIdx.y * BN;
     const int jump2 = gridDim.y * BN;
 
-    extern __shared__ float svals[]; // shared across all threads in a block
-    float *As = svals;
-    float *Bs = svals + BM * BK;
-    float Cs[BN]; // only within a particular thread
-
-    // each block handles a sub-matrix of C, of size BM * BN
-    // each thread handles a sub-row of size BN
     for (x = blockIdx.x * BM; x < m; x += jump1) {
         for (y = blockIdx.y * BN; y < n; y += jump2) {
-            matmul_block2d<BM, BN, BK>(m, n, k, x, y,  //
-                                       A, lda, As,     //
-                                       B, ldb, Bs,     //
-                                       C, Ctype, ldc, Cs);
+            matmul_block2d<BM, BN, BK, TM, TN, half, DST>(m, n, k, x, y,  //
+                                                          A, lda,         //
+                                                          B, ldb,         //
+                                                          (DST *)C, ldc);
         }
     }
 }
 
-template <int BM, int BN, int BK>
+template <int BM, int BN, int BK, int TM, int TN>
 static void tinyblasGE_wrapper(tinyblasHandle_t stream, int m, int n, int k,
                                const half *A, int lda, const half *B, int ldb,
                                void *C, cudaDataType_t Ctype, int ldc) {
     dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
-    int maxthreads = BK;
-
-    tinyblasGE_entry<BM, BN, BK>
-        <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
-           stream>>>(m, n, k, A, lda, B, ldb, C, Ctype, ldc);
+    int maxthreads = ((BM * BN) / (TM * TN));
+
+    if (Ctype == CUDA_R_16F) {
+        tinyblasGE_entry<BM, BN, BK, TM, TN, half>
+            <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+               stream>>>(m, n, k, A, lda, B, ldb, C, ldc);
+    } else {
+        tinyblasGE_entry<BM, BN, BK, TM, TN, float>
+            <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+               stream>>>(m, n, k, A, lda, B, ldb, C, ldc);
+    }
 }
 
 tinyblasStatus_t tinyblasGemmEx(tinyblasHandle_t stream,
@@ -286,19 +251,18 @@ tinyblasStatus_t tinyblasGemmEx(tinyblasHandle_t stream,
         return TINYBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    tinyblasGE_wrapper<48, 12, 48>(stream, m, n, k, (const half *)A, lda,
-                                   (const half *)B, ldb, C, Ctype, ldc);
+    tinyblasGE_wrapper<48, 32, 64, 3, 8>(stream, m, n, k, (const half *)A, lda,
+                                         (const half *)B, ldb, C, Ctype, ldc);
     return TINYBLAS_STATUS_SUCCESS;
 }
 
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmbatchedex
-template <int BM, int BN, int BK>
+template <int BM, int BN, int BK, int TM, int TN, typename DST>
 static __global__ void tinyblasGBE_entry(int m, int n, int k,
                                          const half *const Aarray[], int lda,
                                          const half *const Barray[], int ldb,
-                                         void *const Carray[],
-                                         cudaDataType_t Ctype, int ldc,
+                                         void *const Carray[], int ldc,
                                          int batchCount) {
     int x = blockIdx.x * BM;
     const int jump1 = gridDim.x * BM;
@@ -307,38 +271,39 @@ static __global__ void tinyblasGBE_entry(int m, int n, int k,
     int z = blockIdx.z;
     const int jump3 = gridDim.z;
 
-    extern __shared__ float svals[]; // shared across all threads in a block
-    float *As = svals;
-    float *Bs = svals + BM * BK;
-    float Cs[BN]; // only within a particular thread
-
-    // each block handles a sub-matrix of C, of size BM * BN
-    // each thread handles a sub-row of size BN
     for (z = blockIdx.z; z < batchCount; z += jump3) {
         for (x = blockIdx.x * BM; x < m; x += jump1) {
             for (y = blockIdx.y * BN; y < n; y += jump2) {
-                matmul_block2d<BM, BN, BK>(m, n, k, x, y,       //
-                                           Aarray[z], lda, As,  //
-                                           Barray[z], ldb, Bs,  //
-                                           Carray[z], Ctype, ldc, Cs);
+                matmul_block2d<BM, BN, BK, TM, TN, half, DST>(m, n, k, x, y,   //
+                                                              Aarray[z], lda,  //
+                                                              Barray[z], ldb,  //
+                                                              (DST *)(Carray[z]), ldc);
             }
         }
     }
 }
 
-template <int BM, int BN, int BK>
+template <int BM, int BN, int BK, int TM, int TN>
 static void tinyblasGBE_wrapper(tinyblasHandle_t stream, int m, int n, int k,
                                 const half *const Aarray[], int lda,
                                 const half *const Barray[], int ldb,
                                 void *const Carray[], cudaDataType_t Ctype,
                                 int ldc, int batchCount) {
     dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32);
-    int maxthreads = BK;
-
-    tinyblasGBE_entry<BM, BN, BK>
-        <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
-           stream>>>(m, n, k, Aarray, lda, Barray,
-                     ldb, Carray, Ctype, ldc, batchCount);
+    int maxthreads = ((BM * BN) / (TM * TN));
+
+    if (Ctype == CUDA_R_16F) {
+        tinyblasGBE_entry<BM, BN, BK, TM, TN, half>
+            <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+               stream>>>(m, n, k, Aarray, lda, Barray, ldb, Carray, ldc,
+                         batchCount);
+    } else {
+        tinyblasGBE_entry<BM, BN, BK, TM, TN, float>
+            <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+               stream>>>(m, n, k, Aarray, lda, Barray, ldb, Carray, ldc,
+                         batchCount);
+    }
 }
 
 tinyblasStatus_t tinyblasGemmBatchedEx(tinyblasHandle_t stream,
@@ -366,14 +331,14 @@ tinyblasStatus_t tinyblasGemmBatchedEx(tinyblasHandle_t stream,
         return TINYBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    tinyblasGBE_wrapper<48, 12, 48>(stream, m, n, k, (const half **)Aarray, lda,
+    tinyblasGBE_wrapper<48, 32, 64, 3, 8>(stream, m, n, k, (const half **)Aarray, lda,
                                     (const half **)Barray, ldb, Carray, Ctype,
                                     ldc, batchCount);
     return TINYBLAS_STATUS_SUCCESS;
 }
 
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
-template <int BM, int BN, int BK>
+template <int BM, int BN, int BK, int TM, int TN, typename DST>
 static __global__ void tinyblasGSBE_entry(int m, int n, int k,
                                           const half *A,
                                           int lda,
@@ -382,7 +347,6 @@ static __global__ void tinyblasGSBE_entry(int m, int n, int k,
                                           int ldb,
                                           long long int strideB,
                                           void *C,
-                                          cudaDataType_t Ctype,
                                           int ldc,
                                           long long int strideC,
                                           int batchCount) {
@@ -393,45 +357,46 @@ static __global__ void tinyblasGSBE_entry(int m, int n, int k,
     int z = blockIdx.z;
     const int jump3 = gridDim.z;
 
-    extern __shared__ float svals[]; // shared across all threads in a block
-    float *As = svals;
-    float *Bs = svals + BM * BK;
-    float Cs[BN]; // only within a particular thread
-
-    // each block handles a sub-matrix of C, of size BM * BN
-    // each thread handles a sub-row of size BN
     for (z = blockIdx.z; z < batchCount; z += jump3) {
         for (x = blockIdx.x * BM; x < m; x += jump1) {
             for (y = blockIdx.y * BN; y < n; y += jump2) {
-                matmul_block2d<BM, BN, BK>(
-                    m, n, k, x, y,             //
-                    A + z * strideA, lda, As,  //
-                    B + z * strideB, ldb, Bs,  //
-                    (Ctype == CUDA_R_16F ? (void *)((half *)C + z * strideC)
-                                         : (void *)((float *)C + z * strideC)),
-                    Ctype, ldc, Cs);
+                matmul_block2d<BM, BN, BK, TM, TN, half, DST>(
+                    m, n, k, x, y,         //
+                    A + z * strideA, lda,  //
+                    B + z * strideB, ldb,  //
+                    ((DST *)C + z * strideC), ldc);
             }
         }
     }
 }
 
-template <int BM, int BN, int BK>
+template <int BM, int BN, int BK, int TM, int TN>
 static void tinyblasGSBE_wrapper(tinyblasHandle_t stream, int m, int n, int k,
                                  const half *A, int lda, long long int strideA,
                                  const half *B, int ldb, long long int strideB,
                                  void *C, cudaDataType_t Ctype, int ldc,
                                  long long int strideC, int batchCount) {
-    // call the entry function
     dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32);
-    int maxthreads = BK;
-
-    tinyblasGSBE_entry<BM, BN, BK>
-        <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
-           stream>>>(m, n, k,                //
-                     A, lda, strideA,        //
-                     B, ldb, strideB,        //
-                     C, Ctype, ldc, strideC, //
-                     batchCount);
+    int maxthreads = ((BM * BN) / (TM * TN));
+
+    if (Ctype == CUDA_R_16F) {
+        tinyblasGSBE_entry<BM, BN, BK, TM, TN, half>
+            <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+               stream>>>(m, n, k,          //
+                         A, lda, strideA,  //
+                         B, ldb, strideB,  //
+                         C, ldc, strideC,  //
+                         batchCount);
+    } else {
+        tinyblasGSBE_entry<BM, BN, BK, TM, TN, float>
+            <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+               stream>>>(m, n, k,          //
+                         A, lda, strideA,  //
+                         B, ldb, strideB,  //
+                         C, ldc, strideC,  //
+                         batchCount);
+    }
 }
 
 tinyblasStatus_t tinyblasGemmStridedBatchedEx(tinyblasHandle_t stream,
@@ -460,7 +425,7 @@ tinyblasStatus_t tinyblasGemmStridedBatchedEx(tinyblasHandle_t stream,
         return TINYBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    tinyblasGSBE_wrapper<64, 4, 64>(stream, m, n, k, (const half *)A, lda, strideA,
+    tinyblasGSBE_wrapper<32, 4, 64, 1, 2>(stream, m, n, k, (const half *)A, lda, strideA,
                                     (const half *)B, ldb, strideB, C, Ctype,
                                     ldc, strideC, batchCount);
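
Note (not part of the patch): the new kernels require that each <BM, BN, BK, TM, TN> configuration satisfies the static_asserts in matmul_block2d, and the wrappers launch (BM * BN) / (TM * TN) threads per block. The following host-only sketch is purely illustrative; the helper name check_tile_config is hypothetical and does not exist in tinyblas.cu. It simply re-states those constraints for the three configurations chosen above and prints the resulting block sizes.

// Illustrative host-side sanity check (assumption: mirrors the kernel's
// static_asserts; not part of llamafile). Compile with any C++11 compiler.
#include <cstdio>

template <int BM, int BN, int BK, int TM, int TN>
static void check_tile_config(const char *name) {
    // Same constraints as matmul_block2d in the patch above.
    static_assert(BM % TM == 0 && BN % TN == 0, "can't divide work for threads");
    static_assert(BK == (BM * BN) / (TM * TN), "threads can't load memory properly");
    static_assert(BM * BN <= BM * BK + BK * BN, "didn't allocate enough shared mem");
    int maxthreads = (BM * BN) / (TM * TN);   // threads per block, as in the wrappers
    int tile_elems = BM * BK + BK * BN;       // elements staged in shared memory (As + Bs)
    printf("%-6s threads/block=%d shared tile elements=%d\n", name, maxthreads, tile_elems);
}

int main() {
    check_tile_config<48, 24, 64, 6, 3>("SGEMM"); // tinyblasSgemm
    check_tile_config<48, 32, 64, 3, 8>("GEMMEX"); // tinyblasGemmEx / GemmBatchedEx
    check_tile_config<32, 4, 64, 1, 2>("GSBE");   // tinyblasGemmStridedBatchedEx
    return 0;
}

All three configurations yield 64 threads per block, matching the assert(blockDim.x == BK) added to tinyblasS_entry.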