Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize AVX2 ggml_vec_dot_q4_0 #642

Merged
merged 1 commit into from
Mar 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
const block_q4_0 * restrict x = vx;
const block_q4_0 * restrict y = vy;

ggml_float sumf = 0.0;
float sumf = 0.0;

#if defined(__ARM_NEON)
float sum0 = 0.0f;
Expand Down Expand Up @@ -1821,7 +1821,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
#endif
}

sumf = (ggml_float)(sum0 + sum1);
sumf = sum0 + sum1;
#elif defined(__AVX512F__)
// Initialize accumulator with zeros
__m512 acc0 = _mm512_setzero_ps();
Expand Down Expand Up @@ -1855,6 +1855,10 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
__m256 acc = _mm256_setzero_ps();

// Main loop
// TODO: figure out a way to do this in a portable way
#ifdef __GNUC__
#pragma GCC unroll 16
#endif
for (int i = 0; i < nb; ++i) {
// Compute combined scale for the block
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
Expand All @@ -1868,20 +1872,21 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
bx = _mm256_sub_epi8( bx, off );
by = _mm256_sub_epi8( by, off );

// Sign-extend first 16 signed bytes into int16_t
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
// Compute products of int16_t integers, add pairwise
__m256i i32 = _mm256_madd_epi16( x16, y16 );
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(bx, bx);

// Sign-extend last 16 signed bytes into int16_t vectors
x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
// Accumulate products of int16_t integers
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
// Apply the signs of the x values to the y vectors (x*y == |x| * (sign(x)*y))
const __m256i sy = _mm256_sign_epi8(by, bx);

// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

const __m256i ones = _mm256_set1_epi16(1);
const __m256i i32 = _mm256_madd_epi16(ones, dot);

// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( i32 );
const __m256 p = _mm256_cvtepi32_ps( i32 );

// Apply the scale, and accumulate
acc = _mm256_fmadd_ps( d, p, acc );
}
Expand Down