Skip to content

Commit

Permalink
Add basic bf16 support to ggml-cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed May 23, 2024
1 parent 152da28 commit caf0dcb
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
Expand Down Expand Up @@ -123,6 +124,7 @@
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
Expand Down
2 changes: 2 additions & 0 deletions ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
case GGML_TYPE_BF16:
return convert_unary_cuda<__nv_bfloat16>;
default:
return nullptr;
}
Expand Down
20 changes: 20 additions & 0 deletions ggml-cuda/dmmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,14 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
v.y = x[ib + iqs + 1];
}

static __device__ void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const __nv_bfloat16 * x = (const __nv_bfloat16 *) vx;

// automatic __nv_bfloat16 -> float type cast if dfloat == float
v.x = x[ib + iqs + 0];
v.y = x[ib + iqs + 1];
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
Expand Down Expand Up @@ -584,6 +592,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void convert_mul_mat_vec_bf16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<1, 1, convert_bf16>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

void ggml_cuda_op_dequantize_mul_mat_vec(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
Expand Down Expand Up @@ -649,6 +666,9 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
case GGML_TYPE_F16:
convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_BF16:
convert_mul_mat_vec_bf16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
default:
GGML_ASSERT(false);
break;
Expand Down

0 comments on commit caf0dcb

Please sign in to comment.