From 80047a386f380f815e9a04a5d7ad46fcf9ff7aee Mon Sep 17 00:00:00 2001 From: Stephan Walter Date: Sun, 16 Apr 2023 15:36:36 +0200 Subject: [PATCH] More AVX2 optimizations --- ggml.c | 151 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 78 insertions(+), 73 deletions(-) diff --git a/ggml.c b/ggml.c index 4bcdd5cd723ca..6a15d8b900613 100644 --- a/ggml.c +++ b/ggml.c @@ -2207,19 +2207,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); - for (int i = 0; i < nb; i += 2) { - __m256i bx = bytesFromCrumbs(x[i+1].qs, x[i].qs); + for (int i = 0; i < nb/2; i++) { + __m256i bx = bytesFromCrumbs(x[i*2+1].qs, x[i*2].qs); // Compute combined scale for the block - const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d); - const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d); - const __m256 scale = _mm256_set_m128(scale_hi, scale_lo); + const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+0].d)); + const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+1].d)); + __m256 scale = _mm256_set_m128(scale_hi, scale_lo); + scale = _mm256_mul_ps(scale, _mm256_broadcast_ss(&y[i].d)); const __m256i off = _mm256_set1_epi8(2); bx = _mm256_sub_epi8(bx, off); // Load y vector - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); // Get absolute values of x vectors const __m256i ax = _mm256_sign_epi8(bx, bx); @@ -2272,6 +2273,7 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK3_0 == 0); const int nb = n / QK3_0; + assert(nb % 2 == 0); const block_q3_0 * restrict x = vx; const block_q8_0 * restrict y = vy; @@ -2281,77 +2283,80 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * #if defined(__AVX2__) // Initialize accumulator with zeros __m128 acc = _mm_setzero_ps(); - for (int i = 0; i < nb; i++) { - // Compute combined scale for the block - const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d); - - const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0); - const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64); - - __m256i bxx = _mm256_set1_epi64x(x[i].qs); - - // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale - - // shift the copies to be able to reach all values - // 255 192 128 64 0 - // | | | | - // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in - // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left - // _______________________sssssfedcba98765432__________________________________________ shift right - // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out - // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ - // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0 - bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r)); - - // add to itself in masked places to shift some values left one bit - // 127 64 0 - // | | | | | | | | | | | | | | | | - // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in - // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask - // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked - // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum - // - // 255 192 128 - // | | | | | | | | | | | | | | | | - // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in - // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask - // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked - // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum - const __m256i doublemask = _mm256_set1_epi64x(0x078000078000); - bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx)); - - // collect 16 bytes from 256 into 128 bits - const __m256i shufmask = _mm256_set_epi8( - 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1, - -1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0); - bxx = _mm256_shuffle_epi8(bxx, shufmask); + for (int i = 0; i < nb/2; i++) { + const __m128 scale_y = _mm_set1_ps(y[i].d); + for (int u = 0; u < 2; u++) { // let the compiler unroll this + // Compute combined scale for the block + const __m128 scale_x = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+u].d)); + const __m128 scale = _mm_mul_ps(scale_x, scale_y); + + __m256i bxx = _mm256_set1_epi64x(x[i*2+u].qs); + + // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale + + // shift the copies to be able to reach all values + // 255 192 128 64 0 + // | | | | + // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in + // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left + // _______________________sssssfedcba98765432__________________________________________ shift right + // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out + // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ + // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0 + const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0); + const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64); + bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r)); + + // add to itself in masked places to shift some values left one bit + // 127 64 0 + // | | | | | | | | | | | | | | | | + // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in + // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask + // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked + // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum + // + // 255 192 128 + // | | | | | | | | | | | | | | | | + // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in + // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask + // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked + // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum + const __m256i doublemask = _mm256_set1_epi64x(0x078000078000); + bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx)); + + // collect 16 bytes from 256 into 128 bits + const __m256i shufmask = _mm256_set_epi8( + 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1, + -1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0); + bxx = _mm256_shuffle_epi8(bxx, shufmask); + + __m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1)); + + const __m128i mask = _mm_set1_epi8(7); + bx = _mm_and_si128(mask, bx); + + const __m128i off = _mm_set1_epi8(4); + bx = _mm_sub_epi8(bx, off); + + const __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + u*QK3_0)); - __m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1)); - - const __m128i mask = _mm_set1_epi8(7); - bx = _mm_and_si128(mask, bx); - - const __m128i off = _mm_set1_epi8(4); - bx = _mm_sub_epi8(bx, off); - - const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK3_0)); - - // Get absolute values of x vectors - const __m128i ax = _mm_sign_epi8(bx, bx); - // Sign the values of the y vectors - const __m128i sy = _mm_sign_epi8(by, bx); - // Perform multiplication and create 16-bit values - const __m128i dot = _mm_maddubs_epi16(ax, sy); + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(bx, bx); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(by, bx); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); - // Convert int16_t to int32_t by adding pairwise - const __m128i ones = _mm_set1_epi16(1); - __m128i i32 = _mm_madd_epi16(dot, ones); + // Convert int16_t to int32_t by adding pairwise + const __m128i ones = _mm_set1_epi16(1); + __m128i i32 = _mm_madd_epi16(dot, ones); - // Convert int32_t to float - const __m128 p = _mm_cvtepi32_ps(i32); + // Convert int32_t to float + const __m128 p = _mm_cvtepi32_ps(i32); - // Apply the scale, and accumulate - acc = _mm_fmadd_ps(scale, p, acc); + // Apply the scale, and accumulate + acc = _mm_fmadd_ps(scale, p, acc); + } } // Return horizontal sum of the acc vector