Skip to content

Commit

Permalink
mmq_type_traits
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Jun 5, 2024
1 parent fe1c4bb commit fd65ff3
Showing 1 changed file with 89 additions and 59 deletions.
148 changes: 89 additions & 59 deletions ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#include <climits>
#include <cstdint>

typedef void (*load_tiles_cuda_t)(
typedef void (*load_tiles_mmq_t)(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
typedef void (*vec_dot_q_mul_mat_cuda_t)(
typedef void (*vec_dot_mmq_t)(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0);

Expand Down Expand Up @@ -959,57 +959,88 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(

// -------------------------------------------------------------------------------------------------------------------------------------

static constexpr __device__ int get_need_sum(ggml_type type) {
return type == GGML_TYPE_Q4_0 ||
type == GGML_TYPE_Q4_1 ||
type == GGML_TYPE_Q5_1 ||
type == GGML_TYPE_Q4_K ||
type == GGML_TYPE_Q5_K;
}
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
struct mmq_type_traits;

template <int mmq_y, int nwarps, bool need_check>
static constexpr __device__ load_tiles_cuda_t get_load_tiles(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? load_tiles_q4_0<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q4_1 ? load_tiles_q4_1<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q5_0 ? load_tiles_q5_0<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q5_1 ? load_tiles_q5_1<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q8_0 ? load_tiles_q8_0<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q2_K ? load_tiles_q2_K<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q3_K ? load_tiles_q3_K<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q4_K ? load_tiles_q4_K<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q5_K ? load_tiles_q5_K<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q6_K ? load_tiles_q6_K<mmq_y, nwarps, need_check> :
nullptr;
}
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
static constexpr bool need_sum = true;
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

static constexpr __device__ int get_vdr_mmq(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMQ :
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMQ :
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMQ :
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMQ :
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMQ :
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMQ :
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMQ :
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMQ :
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMQ :
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMQ :
0;
}
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
static constexpr bool need_sum = true;
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps>
static constexpr __device__ vec_dot_q_mul_mat_cuda_t get_vec_dot_mmq(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
nullptr;
}
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
static constexpr bool need_sum = true;
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
static constexpr bool need_sum = true;
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
static constexpr bool need_sum = true;
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
Expand All @@ -1033,12 +1064,14 @@ static __global__ void mul_mat_q(
return;
}

constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qr = ggml_cuda_type_traits<type>::qr;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int mmq_y = get_mmq_y_device(mmq_x);
constexpr bool need_sum = get_need_sum(type);
constexpr int vdr = get_vdr_mmq(type);
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qr = ggml_cuda_type_traits<type>::qr;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int mmq_y = get_mmq_y_device(mmq_x);
constexpr bool need_sum = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;

constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);

Expand All @@ -1050,9 +1083,6 @@ static __global__ void mul_mat_q(
int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE]
half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];

constexpr load_tiles_cuda_t load_tiles = get_load_tiles<mmq_y, nwarps, need_check>(type);
constexpr vec_dot_q_mul_mat_cuda_t vec_dot = get_vec_dot_mmq<mmq_x, mmq_y, nwarps>(type);

const block_q8_1 * y = (const block_q8_1 *) yc;

const int blocks_per_row_x = ne00 / qk;
Expand Down

0 comments on commit fd65ff3

Please sign in to comment.