CUDA backend: handle GGML_TYPE_BF16 next to the existing GGML_TYPE_F16 paths.

  * ggml-cuda/convert.cu
      ggml_get_to_fp32_cuda() returns convert_unary_cuda<__nv_bfloat16> for
      GGML_TYPE_BF16.
  * ggml-cuda/dmmv.cu
      convert_bf16(): device helper that reads the two consecutive
      __nv_bfloat16 values at x[ib + iqs + 0] and x[ib + iqs + 1] into v.x / v.y.
      convert_mul_mat_vec_bf16_cuda(): host launcher. Asserts that ncols is a
      multiple of GGML_CUDA_DMMV_X, then launches
      dequantize_mul_mat_vec<1, 1, convert_bf16> on `stream` with
      ceil(nrows / GGML_CUDA_MMV_Y) blocks of WARP_SIZE x GGML_CUDA_MMV_Y threads.
      ggml_cuda_op_dequantize_mul_mat_vec() dispatches GGML_TYPE_BF16 to it.
  * ggml-cuda/common.cuh
      One #include is added ahead of the __HIP_PLATFORM_AMD__ block and one
      ahead of the CUDART_VERSION check.

NOTE(review): this copy of the patch appears to have had every "<name ...>"
span stripped. Only the launch configuration of the new bf16 launcher has been
restored here, because its operands (block_nums, block_dims, stream) are all
present in the added lines. The following spans are still incomplete and must
be restored from the original patch before it can be applied or compiled:
  - the header names on all eight #include lines in ggml-cuda/common.cuh
  - the template argument of convert_unary_cuda in the GGML_TYPE_F16 context line
  - the parameter list after `template` ahead of dequantize_mul_mat_vec
  - the launch configuration in the convert_mul_mat_vec_f16_cuda context line
Indentation of all lines was reconstructed as 4-space steps - TODO confirm
against the target tree.

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 8f6fd71cfea35e..8c1beb32e9b0f1 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -25,6 +25,7 @@
 #include
 #include
 #include
+#include
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
@@ -123,6 +124,7 @@
 #include
 #include
 #include
+#include
 
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index c0a4447075c6ef..736b7091162fb6 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<__nv_bfloat16>;
         default:
             return nullptr;
     }
diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu
index 47d4d5d9e91da5..33c4e5ed16ea6f 100644
--- a/ggml-cuda/dmmv.cu
+++ b/ggml-cuda/dmmv.cu
@@ -422,6 +422,14 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
     v.y = x[ib + iqs + 1];
 }
 
+static __device__ void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const __nv_bfloat16 * x = (const __nv_bfloat16 *) vx;
+
+    // automatic __nv_bfloat16 -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
+
 template
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
@@ -584,6 +592,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
         <<>>(vx, y, dst, ncols, nrows);
 }
 
+static void convert_mul_mat_vec_bf16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_bf16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
 void ggml_cuda_op_dequantize_mul_mat_vec(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -649,6 +666,9 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
         case GGML_TYPE_F16:
             convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_BF16:
+            convert_mul_mat_vec_bf16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;