Skip to content

Commit

Permalink
IQ1_M: 1.75 bpw quantization (#6302)
Browse files Browse the repository at this point in the history
* iq1_m: basics

* iq1_m: basics-2

* iq1_m: CUDA dequantize works

Very 1st shot I get PPL = 9.76 for LLaMA-v2-7B.

* iq1_m: separate shifts for each group of 8 in a block

We get
PPL(LLaMA-v2-7B ) = 9.2810
PPL(LLaMA-v2-13B) = 6.8105

Not bad, but slightly higher than
  sqrt(PPL(IQ1_S) * PPL(IQ2_XXS))
which is the expected outcome given that IQ1_M is
halfway between IQ1_S and IQ2_XXS in terms of bpw.
From this, we would expect
 PPL = 9.14 for LLaMA-v2-7B
 PPL = 6.63 for LLaMA-v2-13B

* iq1_m: go to 3-bit scales

There is slight increase in PPL, but the 0.0625 bpw reduction
in size is totally worth it.

We now have
PPL(LLaMA-v2-7B ) = 9.4469 at 1.96 bpw
PPL(LLaMA-v2-13B) = 6.8717 at 1.93 bpw
PPL(LLaMA-v2-70B) = 4.8568 at 1.85 bpw

* iq1_m: scalar dot product

* iq1_m: AVX2 dot product

* iq1_m: very slightly faster AVX2 dot product

* iq1_m: ARM_NEON dot product

Works, but very slow (10.5 t/s)

* iq1_m: Metal - dequantize works, dot product does not

* iq1_m: Metal now works

About the same performance as iq1_s.

* iq1_m: minor

* iq1_m: checking pure iq1_m quantization

It is pretty bad: PPL(LLaMA-v2-7B) = 34 if we quantize output.weight
with Q4_K.

* iiq1_m: slightly faster ARM_NEON dot product

10.5 t/s -> 11.65 t/s

* iq1_m: faster ARM_NEON dot product

11.65 t/s -> 14.9 t/s

* iq1_m: another minor ARM_NEON dot product improvement

14.9 -> 15.0 t/s

* iq1_m: small PPL improvement via super-block scale adjustment

After quantizing block scales redo the super-block scale fit.

PPL(LLaMA-v2-7B ) = 9.3346
PPL(LLaMA-v2-13B) = 6.8419
PPL(LLaMA-v2-70B) = 4.8294
PPL(Mistral-7B  ) = 8.1624

* iq1_m: adapt to CUDA refactoring

* iq1_m: remove unused variable

We have progressed to warnings being errors.

* iq1_m: add to backend-ops tests

* iq1_m: fix Windows ARM

* iq1_m: use common definition of iq1m_scale_t

* cuda: assert -> NO_DEVICE_CODE

* iq1_M: PR comments

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
  • Loading branch information
ikawrakow and Kawrakow authored Mar 26, 2024
1 parent e097633 commit 55c1b2a
Show file tree
Hide file tree
Showing 16 changed files with 1,006 additions and 125 deletions.
11 changes: 7 additions & 4 deletions examples/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S, " 2.5 bpw quantization", },
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
Expand Down Expand Up @@ -370,10 +371,12 @@ int main(int argc, char ** argv) {

if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS ||
params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S ||
params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) {
fprintf(stderr, "\n===============================================================================================\n");
fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
fprintf(stderr, "===============================================================================================\n\n\n");
params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S ||
params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && imatrix_data.empty()) {
fprintf(stderr, "\n==========================================================================================================\n");
fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
fprintf(stderr, "==========================================================================================================\n\n\n");
return 1;
}

Expand Down
15 changes: 15 additions & 0 deletions ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,20 @@ typedef struct {
} block_iq1_s;
static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");

// 1.8125 bpw
typedef struct {
uint8_t qs[QK_K/8]; // grid index, low 8 bits
uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8)
uint8_t scales[QK_K/32]; // 4-bit block scales
} block_iq1_m;
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");

// Used by IQ1_M quants
typedef union {
ggml_half f16;
uint16_t u16;
} iq1m_scale_t;

// Non-linear quants
#define QK4_NL 32
typedef struct {
Expand Down Expand Up @@ -1050,6 +1064,7 @@ GGML_TABLE_END()

#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
#if defined(GGML_COMMON_IMPL_C)
GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)
0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,
Expand Down
4 changes: 3 additions & 1 deletion ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
Expand Down Expand Up @@ -643,6 +644,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
Expand Down Expand Up @@ -2560,7 +2562,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
return false;
}
Expand Down
53 changes: 47 additions & 6 deletions ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
assert(false);
NO_DEVICE_CODE;
#endif

}
Expand All @@ -395,7 +395,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
assert(false);
NO_DEVICE_CODE;
#endif

}
Expand All @@ -416,7 +416,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
assert(false);
NO_DEVICE_CODE;
#endif

}
Expand Down Expand Up @@ -444,7 +444,7 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
#else
assert(false);
NO_DEVICE_CODE;
#endif

}
Expand All @@ -470,7 +470,7 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
#else
assert(false);
NO_DEVICE_CODE;
#endif

}
Expand All @@ -496,11 +496,42 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
y[j] = d * (q[j] + delta);
}
#else
assert(false);
NO_DEVICE_CODE;
#endif

}

template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {

const int i = blockIdx.x;
const block_iq1_m * x = (const block_iq1_m *) vx;

const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * sc = (const uint16_t *)x[i].scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = d * (q[j] + delta);
}
#else
NO_DEVICE_CODE;
#endif

}


template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {

Expand Down Expand Up @@ -658,6 +689,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
Expand Down Expand Up @@ -724,6 +761,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
Expand Down Expand Up @@ -769,6 +808,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
Expand Down
11 changes: 11 additions & 0 deletions ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_m_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_nl_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
Expand Down Expand Up @@ -373,6 +381,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ1_S:
mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ1_M:
mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
Expand Down
Loading

0 comments on commit 55c1b2a

Please sign in to comment.