diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 0317310a708ee2..90a0a81ead7896 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -484,57 +484,147 @@ static __device__ __forceinline__ float get_alibi_slope( return powf(base, exph); } -static constexpr __device__ int ggml_blck_size_device(ggml_type type) { - return type == GGML_TYPE_F16 ? 1 : - type == GGML_TYPE_Q4_0 ? QK4_0 : - type == GGML_TYPE_Q4_1 ? QK4_1 : - type == GGML_TYPE_Q5_0 ? QK5_0 : - type == GGML_TYPE_Q5_1 ? QK5_1 : - type == GGML_TYPE_Q8_0 ? QK8_0 : - type == GGML_TYPE_Q2_K ? QK_K : - type == GGML_TYPE_Q3_K ? QK_K : - type == GGML_TYPE_Q4_K ? QK_K : - type == GGML_TYPE_Q5_K ? QK_K : - type == GGML_TYPE_Q6_K ? QK_K : - type == GGML_TYPE_IQ2_XXS ? QK_K : - type == GGML_TYPE_IQ2_XS ? QK_K : - type == GGML_TYPE_IQ2_S ? QK_K : - type == GGML_TYPE_IQ3_XXS ? QK_K : - type == GGML_TYPE_IQ1_S ? QK_K : - type == GGML_TYPE_IQ1_M ? QK_K : - type == GGML_TYPE_IQ4_NL ? QK4_NL : - type == GGML_TYPE_IQ4_XS ? QK_K : - type == GGML_TYPE_IQ3_S ? QK_K : - 0; -} +template +struct ggml_cuda_type_traits; -static constexpr __device__ int get_qr_device(ggml_type type) { - return type == GGML_TYPE_F16 ? 1 : - type == GGML_TYPE_Q4_0 ? QR4_0 : - type == GGML_TYPE_Q4_1 ? QR4_1 : - type == GGML_TYPE_Q5_0 ? QR5_0 : - type == GGML_TYPE_Q5_1 ? QR5_1 : - type == GGML_TYPE_Q8_0 ? QR8_0 : - type == GGML_TYPE_Q2_K ? QR2_K : - type == GGML_TYPE_Q3_K ? QR3_K : - type == GGML_TYPE_Q4_K ? QR4_K : - type == GGML_TYPE_Q5_K ? QR5_K : - type == GGML_TYPE_Q6_K ? QR6_K : - type == GGML_TYPE_IQ2_XXS ? QR2_XXS : - type == GGML_TYPE_IQ2_XS ? QR2_XS : - type == GGML_TYPE_IQ2_S ? QR2_S : - type == GGML_TYPE_IQ3_XXS ? QR3_XXS : - type == GGML_TYPE_IQ1_S ? QR1_S : - type == GGML_TYPE_IQ1_M ? QR1_M : - type == GGML_TYPE_IQ4_NL ? QR4_NL : - type == GGML_TYPE_IQ4_XS ? QR4_XS : - type == GGML_TYPE_IQ3_S ? QR3_S : - 0; -} +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = 1; + static constexpr int qr = 1; +}; -static constexpr __device__ int get_qi_device(ggml_type type) { - return ggml_blck_size_device(type) / (sizeof(int)*get_qr_device(type)); -} +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_0; + static constexpr int qr = QR4_0; + static constexpr int qi = QI4_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_1; + static constexpr int qr = QR4_1; + static constexpr int qi = QI4_1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK5_0; + static constexpr int qr = QR5_0; + static constexpr int qi = QI5_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK5_1; + static constexpr int qr = QR5_1; + static constexpr int qi = QI5_1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK8_0; + static constexpr int qr = QR8_0; + static constexpr int qi = QI8_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_K; + static constexpr int qi = QI2_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_K; + static constexpr int qi = QI3_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_K; + static constexpr int qi = QI4_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_K; + static constexpr int qi = QI5_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR6_K; + static constexpr int qi = QI6_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XXS; + static constexpr int qi = QI2_XXS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XS; + static constexpr int qi = QI2_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_S; + static constexpr int qi = QI2_S; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_XXS; + static constexpr int qi = QI3_XXS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_S; + static constexpr int qi = QI1_S; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_M; + static constexpr int qi = QI1_M; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_NL; + static constexpr int qr = QR4_NL; + static constexpr int qi = QI4_NL; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_S; + static constexpr int qi = QI3_S; +}; static int get_mmq_x_max_host(const int cc) { #ifdef CUDA_USE_TENSOR_CORES diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu index 601d44e3ae2823..174489e0665d38 100644 --- a/ggml-cuda/dmmv.cu +++ b/ggml-cuda/dmmv.cu @@ -434,8 +434,8 @@ static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type template static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { - constexpr int qk = ggml_blck_size_device(type); // quantized weights per x block - constexpr int qr = get_qr_device(type); // number of quantized weights per data value in x block + constexpr int qk = ggml_cuda_type_traits::qk; // quantized weights per x block + constexpr int qr = ggml_cuda_type_traits::qr; // number of quantized weights per data value in x block constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type); const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index dda1e341ed5660..c9a6ced71d6310 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -1033,10 +1033,10 @@ static __global__ void mul_mat_q( return; } + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qr = ggml_cuda_type_traits::qr; + constexpr int qi = ggml_cuda_type_traits::qi; constexpr int mmq_y = get_mmq_y_device(mmq_x); - constexpr int qk = ggml_blck_size_device(type); - constexpr int qr = get_qr_device(type); - constexpr int qi = get_qi_device(type); constexpr bool need_sum = get_need_sum(type); constexpr int vdr = get_vdr_mmq(type); diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu index e198a6ca7f2997..5f056e91e54606 100644 --- a/ggml-cuda/mmvq.cu +++ b/ggml-cuda/mmvq.cu @@ -50,8 +50,8 @@ static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { - constexpr int qk = ggml_blck_size_device(type); - constexpr int qi = get_qi_device(type); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; constexpr int vdr = get_vdr_mmvq(type); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);