Q8: use int8_t, AVX/AVX2 optimizations #972

Merged 2 commits on Apr 14, 2023.
ggml.c: 215 changes (189 additions, 26 deletions)
@@ -583,7 +583,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 bloc

typedef struct {
float d; // delta
- uint8_t qs[QK]; // nibbles / quants
+ int8_t qs[QK]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");
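The new block stores signed quants next to the scale, so dequantization is a plain multiply with no -128 bias. A minimal sketch of reading one block back (dequantize_block_q8_0 is a hypothetical helper, not from the PR; assumes QK == 32, as ggml.c used at the time):

#include <stdint.h>

#define QK 32

typedef struct {
    float  d;       // delta (scale)
    int8_t qs[QK];  // signed quants
} block_q8_0;

// reconstruct the QK values of one block: x[l] ~= qs[l] * d
static void dequantize_block_q8_0(const block_q8_0 * b, float * out) {
    for (int l = 0; l < QK; ++l) {
        out[l] = b->qs[l] * b->d;
    }
}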

@@ -1060,9 +1060,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r

for (int l = 0; l < QK; ++l) {
const float v = x[i*QK + l]*id;
- const uint8_t vi = (int8_t)roundf(v) + 128;
-
- y[i].qs[l] = vi;
+ y[i].qs[l] = roundf(v);
}
}
}
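With signed storage, the reference quantizer reduces to scale-and-round. A self-contained sketch of the scalar path (a reconstruction, not the PR's exact code: the max-abs/127 scaling mirrors the AVX path further down; assumes QK == 32):

#include <math.h>
#include <stdint.h>

#define QK 32

typedef struct { float d; int8_t qs[QK]; } block_q8_0;

static void quantize_row_q8_0_sketch(const float * x, block_q8_0 * y, int k) {
    const int nb = k / QK;
    for (int i = 0; i < nb; i++) {
        float amax = 0.0f;                      // max(abs(e)) over the block
        for (int l = 0; l < QK; l++) {
            amax = fmaxf(amax, fabsf(x[i*QK + l]));
        }
        const float d  = amax / 127.0f;         // quants span [-127, 127]
        const float id = d != 0.0f ? 1.0f/d : 0.0f;
        y[i].d = d;
        for (int l = 0; l < QK; l++) {
            y[i].qs[l] = (int8_t) roundf(x[i*QK + l]*id);
        }
    }
}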
@@ -1095,15 +1093,98 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
- const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f));
- const int32x4_t vi = vcvtq_s32_f32(vf);
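// note: vcvtnq_s32_f32 rounds to nearest (ties to even), so the old
// bias-by-128.5f-then-truncate trick is no longer needed for signed quants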
+ const int32x4_t vi = vcvtnq_s32_f32(v);

y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
__m256 v2 = _mm256_loadu_ps( x + 16 );
__m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32;

// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );

__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );

// Quantize these floats
const float d = maxScalar / 127.f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );

// Apply the multiplier
v0 = _mm256_mul_ps( v0, mul );
v1 = _mm256_mul_ps( v1, mul );
v2 = _mm256_mul_ps( v2, mul );
v3 = _mm256_mul_ps( v3, mul );

// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );

// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );

#if defined(__AVX2__)
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31

// We now have our signed bytes, but their order is wrong: the AVX2 pack
// instructions process the two 16-byte lanes independently.
// The following permute fixes the order.
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );

_mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// AVX lacks some of the required integer instructions, so we split the
// registers in half and use the SSE analogs of the AVX2 operations
__m128i ni0 = _mm256_castsi256_si128( i0 );
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
__m128i ni2 = _mm256_castsi256_si128( i1 );
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
__m128i ni4 = _mm256_castsi256_si128( i2 );
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
__m128i ni6 = _mm256_castsi256_si128( i3 );
__m128i ni7 = _mm256_extractf128_si256( i3, 1);

// Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 );
ni2 = _mm_packs_epi32( ni2, ni3 );
ni4 = _mm_packs_epi32( ni4, ni5 );
ni6 = _mm_packs_epi32( ni6, ni7 );
// Convert int16 to int8
ni0 = _mm_packs_epi16( ni0, ni2 );
ni4 = _mm_packs_epi16( ni4, ni6 );

_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
#endif
}
#else
// scalar
quantize_row_q8_0_reference(x, y, k);
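A note on the byte-order fix in the AVX2 path above: _mm256_packs_epi32 and _mm256_packs_epi16 pack the two 128-bit lanes independently, so after both packs the register holds the source bytes grouped, per the comments in the diff, as dwords {0,2,4,6,1,3,5,7} (four consecutive source bytes per dword). The {0,4,1,5,2,6,3,7} pattern fed to _mm256_permutevar8x32_epi32 is exactly the inverse of that grouping. A small scalar model checking the index math (illustrative only, not from the PR):

#include <stdio.h>

int main(void) {
    // after packs_epi32 + packs_epi16, dword j of the register holds
    // source bytes 4*g[j] .. 4*g[j]+3, with g as below
    const int g[8]    = { 0, 2, 4, 6, 1, 3, 5, 7 };
    const int perm[8] = { 0, 4, 1, 5, 2, 6, 3, 7 }; // permutevar8x32 pattern
    for (int j = 0; j < 8; j++) {
        // g[perm[j]] == j, so dword j ends up holding bytes 4*j .. 4*j+3
        printf("dword %d holds source bytes %2d..%2d\n",
               j, 4*g[perm[j]], 4*g[perm[j]] + 3);
    }
    return 0;
}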
@@ -2508,7 +2589,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
- const uint8x16_t u128b = vdupq_n_u8(128);

const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2526,21 +2606,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

// load y
- const uint8x16_t v1_0l = vld1q_u8(y0->qs);
- const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16);
- const uint8x16_t v1_1l = vld1q_u8(y1->qs);
- const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16);
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

// interleave
- const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h);
- const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h);
- const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h);
- const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h);
-
- const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b));
- const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b));
- const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b));
- const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b));
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
@@ -2578,14 +2653,102 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
}

sumf = sum0 + sum1;
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

// Main loop
for (int i = 0; i < nb; ++i) {
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );

__m256i bx = bytesFromNibbles(x[i].qs);

// Now we have a vector with bytes in the [ 0 .. 15 ] interval. Offset them into the [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
bx = _mm256_sub_epi8( bx, off );

__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(bx, bx);

// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);

// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

const __m256i ones = _mm256_set1_epi16(1);
__m256i xy_q = _mm256_madd_epi16(ones, dot);

/* Convert the vector of 8 int32_t to 8 floats */
__m256 q = _mm256_cvtepi32_ps( xy_q );

/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps( d, q, acc );
}

// Return horizontal sum of the acc vector
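// (extract the high 128 bits and add them to the low 128: 8 floats -> 4;
//  movehl_ps then folds 4 -> 2 and movehdup_ps folds 2 -> 1)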
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );

sumf = _mm_cvtss_f32( res );
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

// Main loop
for (int i = 0; i < nb; ++i) {
// Compute combined scale for the block
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );

__m128i i32[2];
for (int j = 0; j < 2; ++j) {
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));

// Now we have a vector with bytes in the [ 0 .. 15 ] interval. Offset them into the [ -8 .. +7 ] interval.
const __m128i off = _mm_set1_epi8( 8 );
bx = _mm_sub_epi8( bx, off );

// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(bx, bx);

// Sign the values of the y vectors
const __m128i sy = _mm_sign_epi8(by, bx);

// Perform multiplication and create 16-bit values
const __m128i dot = _mm_maddubs_epi16(ax, sy);

const __m128i ones = _mm_set1_epi16(1);
i32[j] = _mm_madd_epi16(ones, dot);
}

// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
// Apply the scale, and accumulate
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
}

// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );

sumf = _mm_cvtss_f32( res );
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = x[i].d;
const float d1 = y[i].d;

const uint8_t * restrict p0 = x[i].qs;
- const uint8_t * restrict p1 = y[i].qs;
+ const int8_t * restrict p1 = y[i].qs;

int sumi = 0;
for (int j = 0; j < QK/2; j++) {
@@ -2594,10 +2757,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const int i0 = (int8_t) (v0 & 0xf) - 8;
const int i1 = (int8_t) (v0 >> 4) - 8;

- const int i2 = (int) p1[2*j + 0] - 128;
- const int i3 = (int) p1[2*j + 1] - 128;
-
- /*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/
+ const int i2 = p1[2*j + 0];
+ const int i3 = p1[2*j + 1];

sumi += i0*i2 + i1*i3;
}
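Why the AVX2 dot product above can use _mm256_maddubs_epi16, which requires an unsigned left operand: the code exploits the identity x*y == |x| * (sign(x)*y), with _mm256_sign_epi8 producing both |x| and the re-signed y. A scalar check of that identity over the full q4_0 x q8_0 range (illustrative, not from the PR; sign8 is a hypothetical helper mimicking the per-element semantics of _mm_sign_epi8):

#include <stdio.h>
#include <stdlib.h>

// per-element semantics of _mm_sign_epi8(b, a):
// negate b where a < 0, zero it where a == 0, keep it where a > 0
static int sign8(int b, int a) {
    return a < 0 ? -b : (a == 0 ? 0 : b);
}

int main(void) {
    for (int x = -8; x <= 7; x++) {          // q4_0 range after the -8 offset
        for (int y = -128; y <= 127; y++) {  // q8_0 range
            const int ax = abs(x);           // unsigned operand for maddubs
            const int sy = sign8(y, x);      // y carrying x's sign
            if (ax * sy != x * y) {
                printf("mismatch at x=%d y=%d\n", x, y);
                return 1;
            }
        }
    }
    // |x| <= 8 also keeps each pairwise int16 sum in maddubs far from saturating
    printf("identity holds for all q4_0 * q8_0 pairs\n");
    return 0;
}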
@@ -9923,7 +10084,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
} else
#endif
- cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+ {
+     cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+ }
} else {
GGML_ASSERT(false);
}
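For scale, the work-buffer size computed in this last hunk amounts to one block_q8_0 per QK elements of src1, i.e. 36 bytes per 32 floats given the static_assert near the top of the diff. A sketch of the arithmetic (q8_0_work_size is a hypothetical helper, not from the PR; assumes QK == 32):

#include <stddef.h>
#include <stdint.h>

#define QK 32

typedef struct { float d; int8_t qs[QK]; } block_q8_0;

static size_t q8_0_work_size(size_t nelements) {
    // GGML_TYPE_SIZE[Q8_0] * nelements / GGML_BLCK_SIZE[Q8_0]
    // == sizeof(block_q8_0) * (nelements / QK) == 36 * nelements / 32
    return sizeof(block_q8_0) * (nelements / QK);
}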