Optimize AVX2 ggml_vec_dot_q4_0 #642

Merged · 1 commit merged into master on Mar 31, 2023

Conversation

slaren
Collaborator

@slaren slaren commented Mar 31, 2023

This is a port of @perserk's and @sw's AVX implementation of ggml_vec_dot_q4_0 (#617) to AVX2.

------------------------------------------------------------------------
Benchmark                              Time             CPU   Iterations
------------------------------------------------------------------------
BM_ggml_vec_dot_q4_0_avx             668 ns          668 ns      1055239
BM_ggml_vec_dot_q4_0_avx2            578 ns          578 ns      1209367
BM_ggml_vec_dot_q4_0_avx2_new        522 ns          522 ns      1346143

Before:

llama_print_timings: prompt eval time = 10113.34 ms /   116 tokens (   87.18 ms per token)
llama_print_timings:        eval time = 20360.13 ms /   127 runs   (  160.32 ms per run)

perplexity : calculating perplexity over 655 chunks
42.05 seconds per pass - ETA 7.65 hours
[1]4.6512,[2]5.2613,[3]6.0903,

After:

llama_print_timings: prompt eval time =  7627.11 ms /   116 tokens (   65.75 ms per token)
llama_print_timings:        eval time = 19477.24 ms /   127 runs   (  153.36 ms per run)

perplexity : calculating perplexity over 655 chunks
31.88 seconds per pass - ETA 5.80 hours
[1]4.5619,[2]5.1787,[3]6.0491,
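
For reference, here is a minimal scalar sketch of the dot product this PR optimizes, assuming the 20-byte q4_0 block layout visible in the AVX2 code discussed below (a float scale d followed by 16 bytes packing 32 4-bit quants, offset by 8). The struct and function names are simplified stand-ins rather than the actual ggml code, and the exact nibble order inside qs is an assumption; the result does not depend on it as long as x and y use the same packing.

    #include <stdint.h>

    #define QK 32  // quants per block

    // Simplified stand-in for ggml's q4_0 block: 4 + 16 = 20 bytes.
    typedef struct {
        float   d;           // per-block scale
        uint8_t qs[QK / 2];  // 32 quants packed two per byte
    } block_q4_0;

    // Scalar reference: what the AVX/AVX2 kernels compute with SIMD.
    static float vec_dot_q4_0_scalar(int n, const block_q4_0 * x, const block_q4_0 * y) {
        const int nb = n / QK;
        float sum = 0.0f;
        for (int i = 0; i < nb; i++) {
            int isum = 0;
            for (int j = 0; j < QK / 2; j++) {
                // Unpack two 4-bit values from each operand and recenter [0..15] to [-8..7].
                const int x0 = (x[i].qs[j] & 0x0F) - 8;
                const int x1 = (x[i].qs[j] >> 4)   - 8;
                const int y0 = (y[i].qs[j] & 0x0F) - 8;
                const int y1 = (y[i].qs[j] >> 4)   - 8;
                isum += x0 * y0 + x1 * y1;
            }
            // Apply the combined per-block scale.
            sum += x[i].d * y[i].d * (float) isum;
        }
        return sum;
    }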

@rabidcopy
Contributor

Very nice. I compiled on Windows and went from 230ms per token to 195ms per token on average, which sounds even more promising because I was getting 185ms per token on Linux as-is, without these optimizations.

@ggerganov ggerganov mentioned this pull request Mar 31, 2023
@sw
Collaborator

sw commented Mar 31, 2023

My run times weren't consistent enough to show an improvement, so thanks for looking at this more closely. I wonder how much of the speed-up is due to the loop unrolling, though.

@slaren
Collaborator Author

slaren commented Mar 31, 2023

I wonder how much of the speed-up is due to the loop unrolling, though.

Not much at all; I couldn't even measure a difference reliably in the Google Benchmark code. However, it seemed to consistently lower perplexity times by a couple of seconds, so I decided to leave it in.

@rabidcopy
Contributor

rabidcopy commented Mar 31, 2023

Out of curiosity, I checked out the non-AVX2 binaries on Windows now that #617 was merged. I'm getting 177ms per token now. Edit: back on Linux, the numbers comparing an AVX-only build against this PR with AVX2 are as follows:
AVX optimizations: 180ms per token (AVX2 without this PR was 186ms per token previously)
AVX2 optimizations: 138ms per token (down from 186ms)
These numbers are on 7B, for clarification. I'd imagine smaller gains on larger models.

@x02Sylvie

Here's an additional AVX2 ggml_vec_dot_q4_0 optimization you might be interested in, @slaren.

Before (with this pull):
llama_print_timings: eval time = 7756.83 ms / 31 runs ( 250.22 ms per run)

After (with this pull + extra optimization):
llama_print_timings: eval time = 7229.14 ms / 31 runs ( 233.20 ms per run)

    #define bs 20
    // Main loop
    const int unroll_count = 4;
    const int loop_count = nb / unroll_count;
    for (int j = 0; j < loop_count; ++j) {
        #pragma unroll
        for (int idx = 0; idx < unroll_count; ++idx) {
            // determine the actual index in the loop
            const int i = j * unroll_count + idx;
            const float * d0_0 = (const float *) ((const uint8_t *)x + i*bs);
            const float * d1_0 = (const float *) ((const uint8_t *)y + i*bs);

            const uint8_t * restrict p0 = (const uint8_t *)x + 4 + i*bs;
            const uint8_t * restrict p1 = (const uint8_t *)y + 4 + i*bs;

            // Prefetch data used later in the loop
            // TODO: these prefetch distances are device dependent and shouldn't be hard coded; derive them instead
            _mm_prefetch (d0_0 + 32*bs, 1);
            _mm_prefetch (d1_0 + 32*bs, 1);
            _mm_prefetch (p0 + 32*bs, 1);
            _mm_prefetch (p1 + 32*bs, 1);

            // Compute combined scale for the block
            const __m256 scale = _mm256_mul_ps( _mm256_broadcast_ss( d0_0 ), _mm256_broadcast_ss( d1_0 ) );

            // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
            __m256i bx = bytesFromNibbles( p0 );
            __m256i by = bytesFromNibbles( p1 );

            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
            const __m256i off = _mm256_set1_epi8( 8 );
            bx = _mm256_sub_epi8( bx, off );
            by = _mm256_sub_epi8( by, off );

            // Sign-extend first 16 signed bytes into int16_t
            __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
            __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
            // Compute products of int16_t integers, add pairwise
            __m256i i32 = _mm256_madd_epi16( x16, y16 );

            // Sign-extend last 16 signed bytes into int16_t vectors
            x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
            y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
            // Accumulate products of int16_t integers
            i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );

            // Convert int32_t to float
            __m256 p = _mm256_cvtepi32_ps( i32 );
            // Apply the scale, and accumulate
            acc = _mm256_fmadd_ps( scale, p, acc );
        }
    }
    // TODO: extract the loop body into a shared helper to eliminate the duplicated code
    for (int i = loop_count * unroll_count; i < nb; ++i) {
        // Compute combined scale for the block
        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );

        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
        __m256i bx = bytesFromNibbles( x[i].qs );
        __m256i by = bytesFromNibbles( y[i].qs );

        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
        const __m256i off = _mm256_set1_epi8( 8 );
        bx = _mm256_sub_epi8( bx, off );
        by = _mm256_sub_epi8( by, off );

        // Take the absolute values of the x bytes (unsigned operand for maddubs)
        const __m256i ax = _mm256_sign_epi8(bx, bx);

        // Apply the signs of the x bytes to the y bytes (signed operand for maddubs)
        const __m256i sy = _mm256_sign_epi8(by, bx);

        // Perform multiplication and create 16-bit values
        const __m256i dot = _mm256_maddubs_epi16(ax, sy);

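        // madd against a vector of ones horizontally adds adjacent 16-bit products into 32-bit lanes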
        const __m256i ones = _mm256_set1_epi16(1);
        const __m256i i32 = _mm256_madd_epi16(ones, dot);

        // Convert int32_t to float
        const __m256 p = _mm256_cvtepi32_ps( i32 );
        // Apply the scale, and accumulate
        acc = _mm256_fmadd_ps( d, p, acc );
    }

I'm on MSVC on Windows, so I'm not sure whether these improvements would translate to Linux.
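
One portability note on the snippet above: the second argument to _mm_prefetch is a locality hint, and while both MSVC and GCC define the named _MM_HINT_T0/T1/T2/NTA constants, the raw numeric values behind them differ between the two compilers' headers, so passing a literal 1 may select a different cache level depending on the toolchain. A minimal sketch of the same prefetches written with a named constant (the helper name is hypothetical; the 32*bs distance is carried over unchanged from the snippet and remains device dependent):

    #include <immintrin.h>
    #include <stdint.h>

    #define bs 20  // q4_0 block size in bytes: 4-byte float scale + 16 bytes of nibbles

    // Hypothetical helper, not part of ggml: issue the same four prefetches as
    // in the snippet above, but with a named hint so the intended cache level
    // does not depend on the compiler's numeric value for the hint.
    static inline void prefetch_q4_0_blocks(const float * d0_0, const float * d1_0,
                                            const uint8_t * p0, const uint8_t * p1) {
        _mm_prefetch((const char *)(d0_0 + 32*bs), _MM_HINT_T0);
        _mm_prefetch((const char *)(d1_0 + 32*bs), _MM_HINT_T0);
        _mm_prefetch((const char *)(p0   + 32*bs), _MM_HINT_T0);
        _mm_prefetch((const char *)(p1   + 32*bs), _MM_HINT_T0);
    }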

@slaren
Collaborator Author

slaren commented Mar 31, 2023

Is there any reason to unroll the loops manually rather than allowing the compiler to do it?

From what I can tell, #pragma unroll doesn't even work on MSVC. I have not been able to find any pragma in MSVC for unrolling loops other than maybe omp unroll, but that would require adding a dependency on OpenMP.
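
For what it's worth, a minimal sketch of how an unroll hint could be kept portable across compilers: GCC 8+ understands #pragma GCC unroll N and Clang has #pragma clang loop unroll_count(N), while MSVC has no equivalent pragma, so the hint would simply expand to nothing there. The GGML_UNROLL macro below is hypothetical, not an existing ggml macro:

    #include <stddef.h>

    // Hypothetical portability macro (assumption, not existing ggml code):
    // emit an unroll hint on compilers that support one, nothing on MSVC.
    #define GGML_DO_PRAGMA(x) _Pragma(#x)
    #if defined(__clang__)
    #define GGML_UNROLL(n) GGML_DO_PRAGMA(clang loop unroll_count(n))
    #elif defined(__GNUC__)
    #define GGML_UNROLL(n) GGML_DO_PRAGMA(GCC unroll n)
    #else
    #define GGML_UNROLL(n)
    #endif

    // Example usage: hint the compiler to unroll a simple reduction by 4.
    static float sum_f32(const float * v, size_t n) {
        float acc = 0.0f;
        GGML_UNROLL(4)
        for (size_t i = 0; i < n; ++i) {
            acc += v[i];
        }
        return acc;
    }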

@rabidcopy
Contributor

rabidcopy commented Mar 31, 2023

I'm on MSVC on Windows, so I'm not sure whether these improvements would translate to Linux.

Still getting more or less the same average of 138ms per token with my test prompt on Linux. Will try to compare on Windows later.

@slaren
Collaborator Author

slaren commented Mar 31, 2023

I don't think that we can reliably measure small performance differences just by looking at the eval time; the variance is too high. For now, I've found that the perplexity time per pass is a lot more stable (remember to disable BLAS when doing this, though).

@sw sw merged commit 1d08882 into ggerganov:master Mar 31, 2023
@jart
Contributor

jart commented Mar 31, 2023

This PR changes the behavior of inference on Intel. I monitor this project for deterministic responses by hard-coding the random seed parameter. After patching this change, token outputs became different and lower quality. There's probably a small bug that needs to be addressed. Please consider rolling back.

@sw
Collaborator

sw commented Mar 31, 2023

@jart : You're right. I noticed it before merging but thought it was due to changes in floating-point operation order. There are such differences between the various optimizations for different processors. But thinking about it, there shouldn't be a difference here, because only integer operations were modified. Let me look into this some more.

#617 might have the same problem, but on AVX instead of AVX2.

@jart
Contributor

jart commented Mar 31, 2023

Intel support was already in a somewhat broken state compared to Apple M1 support before this change; the differences were kind of minor, like a missing comma. This change has a bug that causes completely different answers that don't make sense to show up, although the output is still real spoken language. I've seen worse: sometimes when working on changes, I get something wrong and it speaks complete gibberish!

I care much more about reliability, determinism, predictability, and accuracy than I care about an 8% performance boost. Yes, differences in floating point exist between architectures, but aren't they usually minor? Stuff like NaN vs. -NaN? Whatever defined differences in behavior exist, I would ideally hope we could avoid triggering them, in the interest of better correctness.

@jart
Contributor

jart commented Mar 31, 2023

Or rather than correctness, I should say consistency.

@rabidcopy
Contributor

rabidcopy commented Mar 31, 2023

Hm, for some reason now that this has been merged, I'm not getting the same performance uplifts, but am getting different outputs. Does this somehow conflict with #617?
Left is latest master, right is this PR by itself before #617 was merged with it.
[screenshot]
Edit: It appears #654 outperforms this PR now, while still changing the outputs, but without the performance regression from being merged with #617.
#654 (comment)

@sw
Collaborator

sw commented Mar 31, 2023

@rabidcopy : the change in inference causes an EOT token to be generated, so I don't think you can compare the performance. Try a longer run with a different seed.

@jart : can you give an example of "answers that don't make sense"? Maybe with seed and model file checksums? I agree that the inference has changed, it just sounds like we're not seeing the same kind of degradation.

After some tests I'm quite confident that the integer operations haven't changed, so I'm back to my initial suspicion that it's just the order of how the floats are added.

Edit: here's what I mean, just to be clear...

echo 'int main(){printf("%.20f\n%.20f\n",.7+.2+.1,.7+.1+.2);}' | gcc -x c -include stdio.h - && ./a.out
0.99999999999999988898
1.00000000000000000000

@rabidcopy
Contributor

rabidcopy commented Mar 31, 2023

@rabidcopy : the change in inference causes an EOT token to be generated, so I don't think you can compare the performance. Try a longer run with a different seed.

I had a longer run with a pretty similar number of tokens, even though the two seeds don't match up.
Left is current master. Right is this PR before being merged into master alongside #617.

[screenshot]
Edit: And then this is the same run with #654.
[screenshot]

@sw
Collaborator

sw commented Mar 31, 2023

After more tests, I'm seeing no change in speed between cbef542 and 1d08882 (latest master), and (subjectively) no degradation in language.

Of course, your experience may be different; I don't want to discount that, and I'm open to reverting it. After all, different processor generations have different AVX throughput/latency.

@slaren : could you repeat and confirm your measurements on master?

@slaren
Collaborator Author

slaren commented Mar 31, 2023

On current master:

perplexity : calculating perplexity over 655 chunks
32.12 seconds per pass - ETA 5.84 hours

This is in line with what I expected from this PR.

This is on a 9900k under WSL2.

@rabidcopy
Contributor

rabidcopy commented Mar 31, 2023

It's so strange. This PR by itself shows those uplifts for me, but on master they seem to disappear when combined with #617, while this doesn't seem to be the case with #654 on top of #617.
#642 without #617: 329ms / 409 runs
#642 with #617 on master: 390ms / 398 runs (provided different output on same seed)
#654 merged on top of #617: 317ms / 409 runs
On Linux with a Ryzen 2600 with -t 6 set (any less or more is worse)

I really hate to propose this because of how it could come off (I don't want the effort put into this PR to feel overshadowed by a similar one), but would reverting 1d08882 in favor of #654 be something to consider? That is, if we want to move forward with performance gains at the cost of consistency of outputs, as @jart said.

@slaren
Collaborator Author

slaren commented Mar 31, 2023

Keep in mind that even in my tests, the biggest improvement is when evaluating in batches (prompt or perplexity). The performance difference in generation is within the margin of error, and when BLAS is used for the prompt, this function is not used at all.

@rabidcopy
Contributor

rabidcopy commented Mar 31, 2023

I just really can't get my head around why current master has this regression relative to the performance uplift I clearly get with this PR by itself. 🤷 It's consistently 60-70ms slower on master than with this PR, and I have no idea why.

@slaren
Collaborator Author

slaren commented Mar 31, 2023

@rabidcopy if you are on linux and using GCC, can you try lowering the unroll number in ggml.c:1967:

#pragma GCC unroll 16

I wonder if there is some issue with the instruction cache.

@SebastianApel
Contributor

@rabidcopy For me, there's an additional performance improvement when compiling with -march=native. I've not tried with #642, but it is there with #654.

@rabidcopy
Contributor

rabidcopy commented Mar 31, 2023

@rabidcopy if you are on linux and using GCC, can you try lowering the unroll number in ggml.c:1967:

#pragma GCC unroll 16

I wonder if there is some issue with the instruction cache.

I'm using CMake, so that gets ignored. But I do note that in #654 there's this: https://github.com/ggerganov/llama.cpp/pull/654/files#diff-6d9ce99fcb6f51ff76f59e479f6e6fc0bb62edef7442805d7a5bb15b23996b5dR2024 Going to test a theory out real quick.

@slaren
Collaborator Author

slaren commented Mar 31, 2023

You may still be using GCC even if you use cmake.

@rabidcopy
Contributor

rabidcopy commented Mar 31, 2023

You may still be using GCC even if you use cmake.

Ah, duh. Well, going to mess with that then. Edit: Setting the unroll to 16, 8, and 4 didn't seem to affect performance at all.

Nuked88 pushed a commit to Nuked88/llama.http that referenced this pull request Mar 31, 2023
@rabidcopy
Contributor

Oh man, I have egg on my face. It appears I somehow wasn't actually on the latest master. I only now noticed that sections were missing in ggml.c, which didn't make sense.
[screenshot]
Left is master, right is #654.
I am so deeply sorry. It appears I got my folders mixed up. Nothing is wrong with master and there's no phantom regression or instruction cache issue. Ugh.

@slaren slaren deleted the avx2-optim branch June 13, 2023 23:53
Deadsg pushed a commit to Deadsg/llama.cpp that referenced this pull request Dec 19, 2023
Fix issue of missing words due to buffer overflow for Issue ggerganov#642