diff --git a/ggml.c b/ggml.c index b6dd3f3cf74e32..160c91e669d721 100644 --- a/ggml.c +++ b/ggml.c @@ -2288,17 +2288,31 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest const uint8_t * restrict p0 = x[i].qs; const uint8_t * restrict p1 = y[i].qs; - for (int j = 0; j < QK/2; j++) { - const uint8_t v0 = p0[j]; - const uint8_t v1 = p1[j]; - - const float f0 = d0*(v0 & 0xf) + m0; - const float f1 = d0*(v0 >> 4) + m0; - - const float f2 = d1*(v1 & 0xf) + m1; - const float f3 = d1*(v1 >> 4) + m1; - - sumf += f0*f2 + f1*f3; + for (int j = 0; j < QK/4; j++) { + const uint32_t v0 = ((uint32_t *)p0)[j]; + const uint32_t v1 = ((uint32_t *)p1)[j]; + + const uint8_t v0_0 = (v0 >> 0) & 0xf; + const uint8_t v0_1 = (v0 >> 4) & 0xf; + const uint8_t v0_2 = (v0 >> 8) & 0xf; + const uint8_t v0_3 = (v0 >> 12) & 0xf; + + const uint8_t v1_0 = (v1 >> 0) & 0xf; + const uint8_t v1_1 = (v1 >> 4) & 0xf; + const uint8_t v1_2 = (v1 >> 8) & 0xf; + const uint8_t v1_3 = (v1 >> 12) & 0xf; + + const float f0 = d0 * v0_0 + m0; + const float f1 = d0 * v0_1 + m0; + const float f2 = d0 * v0_2 + m0; + const float f3 = d0 * v0_3 + m0; + + const float f4 = d1 * v1_0 + m1; + const float f5 = d1 * v1_1 + m1; + const float f6 = d1 * v1_2 + m1; + const float f7 = d1 * v1_3 + m1; + + sumf += f0 * f4 + f1 * f5 + f2 * f6 + f3 * f7; } } #endif