
metal: minor q4 optimization and reduce code size #2248

Merged: 3 commits merged into ggerganov:master from mps-q4-optimize on Jul 20, 2023

Conversation

@lshzh-ww (Collaborator) commented on Jul 17, 2023:

The first commit uses uint16_t instead of uint8_t in the q4_0 and q4_1 kernels. Apple GPUs don't like 8-bit types: for any operation on an 8-bit number, the GPU will first copy it to an empty register and only then do the calculation. This brings a ~3% improvement for the 33B model on M1 Max.
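
For illustration, here is a minimal sketch (a hypothetical helper, not the PR's actual kernel code) of reading the packed q4_0 quants through 16-bit loads instead of byte loads; it just sums the dequantized values of one block, which is enough to show the access pattern:

```metal
#include <metal_stdlib>
using namespace metal;

#define QK4_0 32

typedef struct {
    half    d;              // delta (scale), as defined in ggml.c
    uint8_t qs[QK4_0 / 2];  // packed 4-bit quants, as defined in ggml.c
} block_q4_0;

// Sums the dequantized values of one q4_0 block, reading the quants through
// 16-bit loads: 8 wide loads instead of 16 byte loads, and no widening of
// individual uint8_t values into 16-bit registers.
static inline float block_q4_0_sum(device const block_q4_0 * b) {
    device const uint16_t * qs = (device const uint16_t *) b->qs;

    int acc = 0;
    for (int i = 0; i < QK4_0/4; ++i) {
        const uint16_t w = qs[i];
        acc += (w >>  0) & 0xF;   // each 16-bit word carries four 4-bit quants
        acc += (w >>  4) & 0xF;
        acc += (w >>  8) & 0xF;
        acc += (w >> 12) & 0xF;
    }
    // q4_0 stores quants with an implicit -8 offset
    return (float) b->d * (float)(acc - 8*QK4_0);
}
```

Because a plain sum is order-independent, the nibble extraction order does not matter in this sketch; the real kernels have to respect the q4_0 nibble layout and combine the quants with the activations.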

After this commit, we can achieve a 340-350 GB/s memory read speed on M1 Max for the 33B model if we only run the q4_0-f32 MAT_MUL kernel and skip all other operations. This is very close to the reported hardware limit.
[Screenshot, 2023-07-17: memory read speed when running only the matrix-vector multiplications and skipping all other operations.]

The second commit updates the RMS_NORM kernel to minimize the use of threadgroup memory barriers, bringing a ~2% improvement for the 33B model on M1 Max.
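
A heavily simplified sketch of that pattern (assumed names, one row per threadgroup, row offsets omitted): each thread accumulates a private sum of squares, simd_sum() reduces within a simdgroup without any barrier, and only the small cross-simdgroup step touches threadgroup memory:

```metal
#include <metal_stdlib>
using namespace metal;

// Assumes the Apple SIMD width of 32 and at most 1024 threads per threadgroup.
kernel void kernel_rms_norm_sketch(
        device const float   * x,
        device       float   * dst,
        constant     int64_t & ne00,
        constant     float   & eps,
        uint tpitg [[thread_position_in_threadgroup]],
        uint ntg   [[threads_per_threadgroup]],
        uint sgitg [[simdgroup_index_in_threadgroup]],
        uint tiisg [[thread_index_in_simdgroup]]) {
    threadgroup float buf[32];               // one partial sum per simdgroup

    // each thread accumulates a private partial sum of squares
    float sumf = 0.0f;
    for (int64_t i = tpitg; i < ne00; i += ntg) {
        sumf += x[i]*x[i];
    }

    // reduce within each simdgroup: no threadgroup barrier needed here
    sumf = simd_sum(sumf);
    if (tiisg == 0) {
        buf[sgitg] = sumf;
    }

    // a single barrier before combining the per-simdgroup partials
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        const uint nsg = (ntg + 31)/32;      // number of simdgroups
        float s = tiisg < nsg ? buf[tiisg] : 0.0f;
        s = simd_sum(s);
        if (tiisg == 0) {
            buf[0] = s;                      // total sum of squares
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // scale the row by 1/sqrt(mean(x^2) + eps)
    const float scale = 1.0f/sqrt(buf[0]/float(ne00) + eps);
    for (int64_t i = tpitg; i < ne00; i += ntg) {
        dst[i] = x[i]*scale;
    }
}
```

The actual kernel also has to offset x and dst for the row being processed; only the reduction structure is sketched here.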

The third commit uses a template to reduce code size. q5_0 and q5_1 support can be added quite easily using the new template. The new template also improves the behavior when nb is not divisible by 32 (thanks to the discussion with @ikawrakow!).
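
To illustrate the template idea (with a toy row-sum kernel rather than the PR's actual mat-vec kernels, and hypothetical names): the shared loop bookkeeping lives in one templated function, and only the per-block math is specialized per quantization type:

```metal
#include <metal_stdlib>
using namespace metal;

#define QK4_0 32
#define QK4_1 32

typedef struct {
    half    d;                 // delta
    uint8_t qs[QK4_0 / 2];     // nibbles / quants
} block_q4_0;

typedef struct {
    half    d;                 // delta
    half    m;                 // min
    uint8_t qs[QK4_1 / 2];     // nibbles / quants
} block_q4_1;

// only the per-block math differs between quantization types
static inline float block_sum(device const block_q4_0 * b) {
    int acc = 0;
    for (int i = 0; i < QK4_0/2; ++i) {
        acc += (b->qs[i] & 0xF) + (b->qs[i] >> 4);   // 16-bit loads omitted for brevity
    }
    return (float) b->d * (acc - 8*QK4_0);           // q4_0: x = d*(q - 8)
}

static inline float block_sum(device const block_q4_1 * b) {
    int acc = 0;
    for (int i = 0; i < QK4_1/2; ++i) {
        acc += (b->qs[i] & 0xF) + (b->qs[i] >> 4);
    }
    return (float) b->d * acc + (float) b->m * QK4_1; // q4_1: x = d*q + m
}

// the row/loop bookkeeping is written once, for all block types
template <typename block_t>
static void sum_rows_impl(device const block_t * x, device float * dst,
                          int nb_per_row, int row) {
    float sumf = 0.0f;
    for (int ib = 0; ib < nb_per_row; ++ib) {
        sumf += block_sum(x + row*nb_per_row + ib);
    }
    dst[row] = sumf;
}

// each quantization type becomes a thin wrapper around the same body
kernel void kernel_sum_rows_q4_0(device const block_q4_0 * x, device float * dst,
                                 constant int & nb_per_row,
                                 uint row [[thread_position_in_grid]]) {
    sum_rows_impl<block_q4_0>(x, dst, nb_per_row, (int) row);
}

kernel void kernel_sum_rows_q4_1(device const block_q4_1 * x, device float * dst,
                                 constant int & nb_per_row,
                                 uint row [[thread_position_in_grid]]) {
    sum_rows_impl<block_q4_1>(x, dst, nb_per_row, (int) row);
}
```

The PR applies this idea to the q4 mul_mat kernels; the point is that adding q5_0/q5_1 support then means writing one more block-level routine rather than another full kernel.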

Overall speed-up (M1 Max, updated as of f3f2e8e):

model                   master        this PR       speed up
33B q4_0, 256 tokens    82.1 ms/tok   72.5 ms/tok   ~13%
7B q4_0, 256 tokens     21.1 ms/tok   19.6 ms/tok   ~7.6%

./main -m model -n 256 -c 512 -s 123 -p "I believe the meaning of life is" -ngl 1 --no-mmap -t 8

Apple GPUs don't like uint8_t. For every operation on a uint8_t, the GPU needs
to copy the uint8_t to an empty 16-bit register before it can issue other
instructions.

For the matrix-vector multiplication kernel only, we observed a
340~350 GB/s memory read speed on M1 Max after this commit, which is
very close to the reported hardware limit.
This commit doubles the speed of the rms_norm operation by using 512 threads
per threadgroup, combined with SIMD primitives to minimize the need for
threadgroup barriers.
@lshzh-ww requested a review from ikawrakow on July 17, 2023 05:27
@lshzh-ww (Collaborator, Author):

@ggerganov I plan to write a matrix-matrix multiplication kernel for the Metal backend, so that we can finally run prompt evaluation on the GPU. I know that you want to avoid having different kernels for matrix-matrix and matrix-vector, but I found it a bit hard to achieve high performance in both scenarios with one kernel.

This is part of the reason why I came up with this template. After this I can prepare a new template for q4_0, q4_1, q5_0 and q5_1 matrix-matrix multiplication. I guess we still need two more templates for the k-quants (they have a larger block size that can't be kept in registers, so the logic will be a little different). Overall there may be 4 templates, and we can still keep the code size reasonable.

Let me know your opinions.

@lshzh-ww (Collaborator, Author):

@ikawrakow BTW, you are right about the prefetch test. Using the disassembler from here, it looks like the prefetch code is optimized out by the compiler. I was testing multiple optimizations at the time and wrongly attributed the speed-up to prefetch.

@ggerganov (Owner):

I observe similar perf gains on M1 Pro:

model      master      PR
7B Q4_0    32.4 ms/t   31.1 ms/t
7B Q4_1    34.6 ms/t   33.5 ms/t
13B Q4_0   56.5 ms/t   54.6 ms/t
13B Q4_1   60.9 ms/t   59.4 ms/t

@ggerganov (Owner):

> @ggerganov I plan to write a matrix-matrix multiplication kernel for the Metal backend, so that we can finally run prompt evaluation on the GPU. I know that you want to avoid having different kernels for matrix-matrix and matrix-vector, but I found it a bit hard to achieve high performance in both scenarios with one kernel.

Please, go ahead. I haven't started working on this so I have no meaningful points to make. Reusing the dot-product code would be nice, but if it's not enough to obtain the best performance then we can do something else, for example what you propose.

There is some similar work going on in #2160.
Not sure if insights can be shared between the two backends, but it might still be useful to keep an eye on it.

Btw, there is one strange thing observed with the Metal implementation that it would be nice to have an explanation for:

Do the following 2 runs:

  • Put a return on the first line of the matrix multiplication kernels
  • Do not enqueue the matrix multiplication kernels at all

I would expect these 2 runs to have very similar performance, but for some reason the former is noticeably slower.
At least this was the case the last time I checked, some time ago.

The difference was quite significant. I'm not a Metal expert, but it looked like a big cost for queuing an empty kernel.
Maybe it's worth looking into this to see if the overhead can be reduced in some way. It might bring some extra performance gains.

@lshzh-ww (Collaborator, Author):

> The difference was quite significant. I'm not a Metal expert, but it looked like a big cost for queuing an empty kernel.

Yes, there is latency even if you launch an empty kernel. I think that latency comes from the GPU having to prepare threadgroup memory and copy parameters to the constant buffer. On contemporary NVIDIA hardware, people have reported that a kernel launch takes at least 5 us and starting each threadgroup takes ~0.5 us. Those numbers may be higher for Metal.

There is quite a lot we can do to optimize performance, like making the element-wise add/mul kernels use fewer threadgroups, removing unneeded parameters passed to kernels, etc. I am confident that we can easily speed up token generation by 10-15%.

Like you said, some ideas from the CUDA backend may be useful, and I am keeping an eye on it.

ggml-metal.metal (outdated):
@@ -400,20 +436,15 @@ kernel void kernel_mul_mat_q4_0_f32(
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
// we don't right shift packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate for this.
// this design is q4_0 and q4_1 centered, but I think most people use these two quantizations.
Contributor:
Many people do use quantizations other than Q4_0 and Q4_1.

@lshzh-ww (Collaborator, Author) replied on Jul 18, 2023:

Sorry for the misunderstanding; I meant that among Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0, people mostly use Q4_0 and Q4_1. I think Q5 can also benefit a little from this design, and only for Q8 do we need to do 4 more multiplications per block.
I know that quite a lot of people use k-quants, and I mostly use k-quant models myself because they provide better perplexity. This template mainly serves the non-k quants, because k-quants use a much larger block size and may need a different load strategy. Once I finish optimizing this template, I will try to apply the same design rules to k-quants and see whether we can merge all quantization methods together or need a separate template for k-quants.

ggml-metal.metal (outdated):
@@ -8,14 +8,14 @@ using namespace metal;
#define QR4_0 2
typedef struct {
half d; // delta
-    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+    uint16_t qs[QK4_0 / 4]; // nibbles / quants
Contributor:
I personally do not like having the Metal block_q4_0/1 definition differ from the definition in ggml.c. One can get the same effect by just casting qs to uint16_t * where needed.

@lshzh-ww (Collaborator, Author):

I agree that it's better to keep a uniform design across the library. However, changing uint16_t back to uint8_t and casting qs to uint16_t * almost kills all the performance gain.

The disassembler shows that the compiler is just not clever enough to load 16 uint8_t values in a reasonable way. It first loads 2 bytes, then 8 bytes, then all the remaining bytes, taking 3~4 load instructions. It does load 8 uint16_t values in 2 instructions, with each instruction loading 8 bytes.

The mov instructions saved by using uint16_t should benefit the future matrix-matrix multiplication kernel even more, since that one is more compute bound.

@ggerganov (Owner):

> However, changing uint16_t back to uint8_t and casting qs to uint16_t * almost kills all the performance gain.
>
> The disassembler shows that the compiler is just not clever enough to load 16 uint8_t values in a reasonable way. It first loads 2 bytes, then 8 bytes, then all the remaining bytes, taking 3~4 load instructions. It does load 8 uint16_t values in 2 instructions, with each instruction loading 8 bytes.

Surprising. I would also have preferred to keep the structs the same, but if there is no way to work around this, I guess we can make the change.

Can you give a very short tutorial for beginners on how to generate and look at the disassembler output?

@lshzh-ww (Collaborator, Author):

Hi @ggerganov, a quick guide is posted at #2279.

Revert modifications on block_q4_0 and block_q4_1.
@lshzh-ww (Collaborator, Author) commented on Jul 20, 2023:

The new commit reverts the modifications of block_q4_0 and block_q4_1 (thanks for the inspiration from @ikawrakow!).

The new template also improves the behavior when nb is not divisible by 32: each SIMD group first processes 32 blocks from the same row per iteration, until fewer than 32 blocks are left in each row. At that point, each SIMD group loads 16 blocks from the first row and 16 blocks from the next row. With this design, we no longer waste resources as long as nb % 16 == 0.

The only exception is the 7B model, in which some of the matrices have ne00 = 11008, so nb % 32 = 24. The last iteration for these matrices leaves half of the threads idle, hence the performance improvement for 7B won't be as big as for the others.
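
One possible index mapping consistent with the scheme described above (a hypothetical helper, not the PR's code) is sketched below; the per-half-simdgroup reduction that the tail case then requires is omitted:

```metal
#include <metal_stdlib>
using namespace metal;

// Given a lane index within a simdgroup (0..31), the current iteration and the
// number of blocks per row nb, return (row, block) for this lane to load.
// Full iterations take 32 blocks from the same row; the tail iteration splits
// the simdgroup so 16 lanes finish this row and 16 lanes start the next one.
static inline int2 q4_block_index(int row, int iter, int nb, int lane) {
    const int done = 32*iter;        // blocks of this row already consumed
    const int left = nb - done;      // blocks remaining in this row

    if (left >= 32) {
        // regular iteration: all 32 lanes stay on the same row
        return int2(row, done + lane);
    }

    // tail iteration, assuming exactly 16 blocks remain (i.e. nb % 32 == 16);
    // other remainders, like the 7B case with nb % 32 == 24, leave lanes idle
    if (lane < 16) {
        return int2(row, done + lane);       // last 16 blocks of this row
    }
    return int2(row + 1, lane - 16);         // first 16 blocks of the next row
}
```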

Updated benchmark:

model                   master        this PR       speed up
33B q4_0, 256 tokens    82.1 ms/tok   72.5 ms/tok   ~13%
7B q4_0, 256 tokens     21.1 ms/tok   19.6 ms/tok   ~7.6%

@ikawrakow merged commit 417a85a into ggerganov:master on Jul 20, 2023 (4 checks passed).
@ikawrakow mentioned this pull request on Jul 20, 2023.
@lshzh-ww deleted the mps-q4-optimize branch on July 20, 2023 at 14:27.