
VRAM optimization + matrix multiplication discussion #1935

Closed

Conversation

JohannesGaessler
Collaborator

@JohannesGaessler commented Jun 19, 2023

Currently the CUDA code allocates temporary buffers during prompt processing via ggml_cuda_pool_malloc to hold the weight matrices dequantized to f32 for cuBLAS. As it turns out, the amount of VRAM used for these temporary buffers is quite substantial, more than 1 GiB for 33b:

| Model | Max. buffer size [MiB] | Pool VRAM master [MiB] | Pool VRAM PR [MiB] |
|---|---|---|---|
| 7b q4_0 | 500.00 | 736.16 | 813.41 |
| 13b q4_0 | 625.00 | 995.17 | 813.42 |
| 33b q4_0 | 812.50 | 1436.68 | 813.43 |

One of the problems is that the buffers are allocated in a bad order: there are three relevant sizes for the dequantized matrices and they are allocated from smallest to largest. As a consequence, the VRAM allocated for the first two matrices is essentially wasted. In this PR I did a hacky patch that just allocates and frees a single 813 MiB buffer during initialization, which can then be reused for all three sizes. For 33b this reduces VRAM usage by ~600 MiB, but 813 MiB is still a lot for temporary buffers that are only used during prompt processing. So I think that a proper solution would be to implement a matrix multiplication that does dequantization on the fly, the same way the dequantize_mul_mat_vec kernel does. The problem is that efficiently parallelizing general matrix multiplication is very hard.
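To illustrate the idea behind the patch, here is a minimal sketch (hypothetical and simplified, not the actual ggml_cuda_pool_malloc code): allocate one scratch buffer sized to the largest dequantization target up front and hand out that same allocation for every cuBLAS call.

```cpp
// Minimal sketch (hypothetical, simplified): one preallocated scratch buffer
// that is reused for all dequantized matrices instead of letting the pool
// grow from the smallest to the largest request.
#include <cuda_runtime.h>
#include <cstddef>
#include <cstdio>

static void * g_scratch      = nullptr;
static size_t g_scratch_size = 0;

// Called once during initialization with the largest size that will ever be
// requested (813 MiB for the models in the table above).
void scratch_init(size_t max_size) {
    cudaMalloc(&g_scratch, max_size);
    g_scratch_size = max_size;
}

// Every dequantization target smaller than the maximum reuses the same block,
// so no additional VRAM is allocated during prompt processing.
void * scratch_get(size_t size) {
    if (size > g_scratch_size) {
        fprintf(stderr, "scratch buffer too small: %zu > %zu\n", size, g_scratch_size);
        return nullptr;
    }
    return g_scratch;
}

void scratch_free(void) {
    cudaFree(g_scratch);
    g_scratch      = nullptr;
    g_scratch_size = 0;
}
```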

This brings me to the topic at hand: does anyone have a good idea for how to fuse dequantization and general matrix multiplication in ggml? I think I could at the very least do a basic implementation that greatly reduces VRAM usage, but it may perform significantly worse for prompt processing, especially for GPUs without tensor cores. Ideally I would want to implement a kernel that is efficient both in terms of VRAM and speed.
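To make the trade-off concrete, below is a minimal sketch (not ggml code) of what such a basic fused kernel could look like for a q4_0-style block format, with the scale kept as a plain f32 to stay self-contained: each thread computes one output element and dequantizes the weight blocks on the fly, so no f32 copy of the weight matrix is ever materialized. It is deliberately naive and would land in the performance class of the dequantize_mul_mat_vec path mentioned above rather than that of cuBLAS.

```cpp
// Sketch of a naive fused dequantize + matmul kernel (assumption: simplified
// q4_0-style blocks with an f32 scale instead of a half-precision one).
#include <cuda_runtime.h>
#include <cstdint>

#define QK 32  // values per quantized block

struct block_q4 {
    float   d;          // scale
    uint8_t qs[QK / 2]; // 4-bit quants, two per byte
};

// C[row, col] = dot(A[row, :], B[:, col])
// A: nrows x ncols, quantized, row-major
// B: ncols x ncols_b, f32, column-major; C: nrows x ncols_b, f32, column-major
__global__ void mul_mat_q4_f32(const block_q4 * A, const float * B, float * C,
                               const int ncols, const int nrows, const int ncols_b) {
    const int row = blockIdx.y*blockDim.y + threadIdx.y;
    const int col = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= nrows || col >= ncols_b) {
        return;
    }

    const int nblocks = ncols / QK;
    float sum = 0.0f;

    for (int ib = 0; ib < nblocks; ++ib) {
        const block_q4 b = A[row*nblocks + ib];
        const float * Bcol = B + col*ncols + ib*QK;
        for (int j = 0; j < QK/2; ++j) {
            // low nibble -> element j, high nibble -> element j + QK/2
            const int q0 = (b.qs[j] & 0x0F) - 8;
            const int q1 = (b.qs[j] >>   4) - 8;
            sum += b.d * (q0*Bcol[j] + q1*Bcol[j + QK/2]);
        }
    }
    C[col*nrows + row] = sum;
}
```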

@JohannesGaessler added the performance (Speed related topics) and hardware (Hardware related) labels on Jun 19, 2023
@ggerganov
Owner

> does anyone have a good idea for how to fuse dequantization and general matrix multiplication in ggml?

I briefly wrote on this topic here: #1867 (comment) (ignore the Metal-specific parts). It is something that we definitely want to implement in ggml, potentially dropping all BLAS dependencies. I will try to formulate a strategy soon, but I think it will most likely start with a CPU-only implementation that utilizes some of the common GEMM tricks for speed (e.g. block-based multiplication), combining them with the quantization methods that we have. After that, we will think of ways to translate this to the GPU kernels. Ideally, we would find a way to reuse the dot product implementations during GEMM in order to reduce code duplication.
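As a rough illustration of that structure (plain f32 and arbitrary tile sizes, not ggml code), a block-based GEMM whose inner routine has the same shape as the existing vec_dot functions could look like this, so that a quantized dot product could later be substituted for the f32 one:

```cpp
// Illustration only: cache-blocked GEMM with a vec_dot-shaped inner routine.
// Tile sizes are arbitrary placeholders and would need tuning.
#include <algorithm>

static float vec_dot_f32(const int n, const float * x, const float * y) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += x[i]*y[i];
    }
    return sum;
}

// C (m x n, row-major) += A (m x k, row-major) * B, where B is stored
// transposed as Bt (n x k, row-major) so both dot product operands are
// contiguous rows.
void gemm_blocked(const int m, const int n, const int k,
                  const float * A, const float * Bt, float * C) {
    const int MB = 64, NB = 64; // tiles small enough to stay in cache
    for (int i0 = 0; i0 < m; i0 += MB) {
        for (int j0 = 0; j0 < n; j0 += NB) {
            const int i1 = std::min(i0 + MB, m);
            const int j1 = std::min(j0 + NB, n);
            // while this tile of A is hot in cache it is reused for NB columns
            for (int i = i0; i < i1; ++i) {
                for (int j = j0; j < j1; ++j) {
                    C[i*n + j] += vec_dot_f32(k, A + i*k, Bt + j*k);
                }
            }
        }
    }
}
```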

@slaren
Collaborator

slaren commented Jun 19, 2023

Regarding the allocation order: yes, it is bad. A while ago, I fixed that in a (hopelessly outdated) branch by calculating the maximum required memory during the initial pass in ggml_graph_compute:
https://github.com/slaren/llama.cpp/blob/420617867f79e71e95cb876e3131793f9bb2723b/ggml.c#L11658-L11660
https://github.com/slaren/llama.cpp/blob/420617867f79e71e95cb876e3131793f9bb2723b/ggml.c#L11823-L11834

The logic is a lot more complicated now, so that won't be as easy, but it should still be possible using a similar approach. To also avoid wasting memory when the batch size increases, you could free the entire pool and reallocate it when that happens, or just do an initial dry run with the maximum batch size.
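A minimal sketch of that measuring pass, with a hypothetical node structure standing in for the actual graph types:

```cpp
// Sketch (hypothetical types): a first pass over the graph only measures how
// much temporary VRAM each node would request; the pool is then allocated once
// at the maximum size and reused by the real compute pass.
#include <algorithm>
#include <cstddef>
#include <vector>

struct node_t {
    bool   needs_dequant_buffer; // true for quantized mul_mat going through cuBLAS
    size_t dequant_buffer_size;  // bytes of f32 scratch it would request
};

size_t measure_max_pool_size(const std::vector<node_t> & graph) {
    size_t max_size = 0;
    for (const node_t & node : graph) {
        if (node.needs_dequant_buffer) {
            max_size = std::max(max_size, node.dequant_buffer_size);
        }
    }
    return max_size; // allocate the pool once with this size before computing
}
```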

I agree that eventually the best solution will be to implement our own matmul kernels. The best way to learn how to do that may be to look into how CUTLASS does it. CUTLASS is an open-source library that contains many of the kernels used in cuBLAS. It would also be possible to use CUTLASS directly, but it is a heavy dependency that we are probably not interested in adding to ggml.

This article may also be interesting: https://siboehm.com/articles/22/CUDA-MMM

Contributor

@ikawrakow left a comment


I think I prefer to let @slaren and @ggerganov decide whether to merge. The change reduces VRAM usage for the larger models but increases it for the 7B model (and I would guess even more relative to master for e.g. OpenLLaMA 3B), which means that people with low-end devices will be able to fit even fewer layers on the GPU.

@JohannesGaessler
Collaborator Author

I don't mean for this to get merged; I just thought that the effects of a simple attempt at fixing it would be relevant to the broader discussion of what to do about VRAM usage, so I made a PR that combines discussion with code.

@ikawrakow
Contributor

> This brings me to the topic at hand: does anyone have a good idea for how to fuse dequantization and general matrix multiplication in ggml?

Isn't the simplest possible implementation of the product between a quantized matrix and a float32 matrix simply a series of the dot products we already have? Before even trying to optimize it with some of the tricks used in the various BLAS implementations, doesn't it make sense to first try this simplest possible version and see how its performance compares against what we currently do (dequantize plus cuBLAS)?

@ikawrakow
Contributor

As a follow-up to my previous comment, here is an interesting data point: on a Ryzen 7950X CPU, perplexity calculation is faster if I disable OpenBLAS. Which means that on that CPU, the simplest possible matrix multiplication logic as implemented in ggml_compute_forward_mul_mat_q_f32 outperforms the OpenBLAS matrix multiplication.

@JohannesGaessler
Collaborator Author

> Isn't the simplest possible implementation of the product between a quantized matrix and a float32 matrix simply a series of the dot products we already have?

I'm primarily thinking about the CUDA implementation since I'm trying to optimize VRAM. When I profiled it, dequantization + cuBLAS on average takes 5.5 µs per call and token for 33b q4_0. The runtime of just applying dot products would be comparable to that of a dequantize_mul_mat_vec kernel, which sits at 84.7 µs per call and token. So doing it like that would literally be 15 times slower, because you would be processing the prompt at the same speed that you generate tokens at.

The profiling data also suggests that the actual matrix multiplication takes ~4 times as long as the dequantization. And since the dequantization is essentially just I/O that could be optimized away by tensor fusion, a kernel that reaches 80% of the performance of cuBLAS should make the program faster overall (the actual threshold is lower because you also save some I/O for the matrix multiplication itself).
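To spell out the break-even arithmetic: write $t_\text{dq}$ for the dequantization time and $t_\text{mm} \approx 4\,t_\text{dq}$ for the cuBLAS matmul time, so a fused kernel running at 80% of cuBLAS speed already matches the current dequantize + cuBLAS path:

$$
t_\text{master} = t_\text{dq} + t_\text{mm} \approx 5\,t_\text{dq}, \qquad
t_\text{fused} \approx \frac{t_\text{mm}}{0.8} = 1.25 \cdot 4\,t_\text{dq} = 5\,t_\text{dq}
$$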

@JohannesGaessler
Collaborator Author

> As a follow-up to my previous comment, here is an interesting data point: on a Ryzen 7950X CPU, perplexity calculation is faster if I disable OpenBLAS. Which means that on that CPU, the simplest possible matrix multiplication logic as implemented in ggml_compute_forward_mul_mat_q_f32 outperforms the OpenBLAS matrix multiplication.

On my system the performance with OpenBLAS is also unexpectedly bad, but it's still faster than the ggml implementation. However, to me this suggests that something is wrong with OpenBLAS or how it's used in ggml, rather than that the simple algorithm is good. Either that, or the matrices used are too small to be sensitive to the difference. If you just do one dot product per entry of the matrix that you want to calculate, you end up with a lot of potential cache misses, which kills your performance once the matrices get too big to fit into cache.

@ikawrakow
Contributor

I'm not sure I can follow the math. Let's look at token prediction. On my RTX 4080 this takes 9 ms/token for Q4_0 with the 7B model. If I put return true in case GGML_OP_MUL_MAT: in ggml_cuda_compute_forward(), the run time is 1.5 ms/token. So the dot products take 7.5 ms/token. To do the dot products, the GPU needs to read the entire model from VRAM (the model is way too big, so it does not fit/stay in the cache). The model is 3.8 GB, so we have 1000/7.5 * 3.8 = 507 GB/s data rate. If we were somehow achieving 15X that via cuBLAS, that would mean 7.6 TB/s, or 10X the theoretical memory bandwidth of this GPU (768 GB/s).

If I do the same with the 13B model, I measure 15.8 ms/token, 2 ms/token without dot products, so 1000 / (15.8 - 2) * 7.37 GB = 534 GB/s. This is better simply because the synchronization time, where we gather the results of a thread group when performing the dot products, is a smaller fraction of the actual dot product computation, which now involves bigger vectors. I cannot meaningfully benchmark the 33B model as it does not fully fit on this GPU.

So, overall, the best performance improvement one can achieve with even the most magical matrix multiplication implementation compared to just doing a series of dot products is limited by memory bandwidth, so it cannot be better than ~1.5X (768 / 507) compared to just dot products. My bet is that if we start doing matrix multiplications via dot products and take care to minimize the synchronization overhead, we will come pretty close to the theoretical limit, so we will be on par with or outperform cuBLAS.

@JohannesGaessler
Collaborator Author

> If we were somehow achieving 15X that via cuBLAS, that would mean 7.6 TB/s, or 10X the theoretical memory bandwidth of this GPU (768 GB/s).

This is where your logic is going wrong. When you multiply a square matrix of size $N$ with a vector you are doing $O(N^2)$ operations on $O(N^2)$ data values. However, if you multiply two square matrices you instead do $O(N^3)$ operations on $O(N^2)$ data values, so you are much less I/O bound. In other words, the higher the minimum dimension of the matrices that you multiply, the fewer memory accesses per operation you need. In practical terms this means that instead of loading and dequantizing the weight matrices for each individual token you can reuse them for many tokens, and you therefore need far fewer memory accesses per token. The simple implementation of just doing one dot product per entry of the final matrix does not take advantage of this and requires $O(N^3)$ memory accesses for two square matrices.
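In terms of arithmetic intensity (operations per value loaded from memory) the difference is

$$
\text{matrix-vector: } \frac{O(N^2)\ \text{ops}}{O(N^2)\ \text{loads}} = O(1)
\qquad\text{vs.}\qquad
\text{matrix-matrix: } \frac{O(N^3)\ \text{ops}}{O(N^2)\ \text{loads}} = O(N),
$$

so a well-blocked matrix-matrix multiplication can stay compute bound where a matrix-vector product is inherently memory bound.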

Caveat: there are algorithms that need less than $O(N^3)$ operations but let's not overcomplicate things.

@cmp-nct
Contributor

cmp-nct commented Jun 19, 2023

@JohannesGaessler
I did a quick fix for that already here:
cmp-nct@7c8249f

  1. It uses a "best fit" method to find the best free buffer (that's just 2-3 for loops, no performance cost); a minimal sketch of the idea follows at the end of this comment.
  2. I added a function that frees up all unused CUDA buffers on a device. I call that function right before n_batch processing and again after it. I'm not sure if it's the same situation in llama.cpp, but for Falcon 100% of the eval loop runs as vector multiplications except for multi-token processing, which goes through cuBLAS, so that first set of buffers is way too large for later use and wastes VRAM.

That combined frees up about 1 GB of VRAM that can be used for quantized offloading instead.
I plan to run a second offload after n_batch processing, which makes it possible to offload additional layers into the VRAM freed from those buffers, since they won't be needed anymore.

  3. In an earlier test I converted the second operand to 16 bit and used cuBLAS in partial 16-bit mode instead of 32 bit. The performance was exactly the same as with 32 bit, but memory consumption was halved.
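A minimal sketch of what such a best-fit lookup over a buffer pool could look like (the pool structure here is hypothetical, not the actual ggml-cuda one):

```cpp
// Sketch (hypothetical pool structure): pick the smallest free buffer that is
// large enough, rather than the first one that fits, to reduce wasted VRAM.
#include <cstddef>
#include <vector>

struct cuda_buffer_t {
    void * ptr    = nullptr;
    size_t size   = 0;
    bool   in_use = false;
};

void * pool_best_fit(std::vector<cuda_buffer_t> & pool, const size_t size) {
    int best = -1;
    for (int i = 0; i < (int) pool.size(); ++i) {
        if (pool[i].in_use || pool[i].size < size) {
            continue;
        }
        if (best < 0 || pool[i].size < pool[best].size) {
            best = i; // tighter fit than the previous candidate
        }
    }
    if (best < 0) {
        return nullptr; // no free buffer is big enough, fall back to cudaMalloc
    }
    pool[best].in_use = true;
    return pool[best].ptr;
}
```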

@JohannesGaessler
Collaborator Author

@cmp-nct There definitely are ways to optimize the current method for prompt processing. I'm just thinking that a custom matrix multiplication kernel would make such optimizations unnecessary, so that's what I wanted to talk about.

@0cc4m
Collaborator

0cc4m commented Jun 20, 2023

A similar optimization to the one proposed by @cmp-nct has already been implemented for OpenCL in #1675 and shown good results.

@ikawrakow
Contributor

> This is where your logic is going wrong. When you multiply a square matrix of size $N$ with a vector you are doing $O(N^2)$ operations on $O(N^2)$ data values. However, if you multiply two square matrices you instead do $O(N^3)$ operations on $O(N^2)$ data values, so you are much less I/O bound.

That would only be true if the two matrices fit completely in fast cache, no? If we want to do big-O analysis, let's look at a processor where exactly one row from the left matrix ($A$) and one column from the right matrix ($B$) fit in cache (and we don't worry about cache-to-processor throughput). We fetch a row from $A$ and a column from $B$ and do the dot product. This is $O(N)$ data fetches and $O(N)$ operations. Now we need to evict either the row or the column from cache to get the next piece of data. Let's decide to evict the column. We fetch the next column and do the math. We need to repeat this $N$ times, each time fetching $N$ elements, so $O(N^2)$ data fetches and $O(N^2)$ operations. We need to repeat this $N$ times for the $N$ rows of $A$, so $O(N^3)$ data fetches and $O(N^3)$ operations. In reality modern processors are better than that, so we can hold $L$ rows/columns (or blocks of size $L \times N$) from $A$ and $B$. That reduces memory fetches into cache to $O(N^3/L)$ (but if we are doing big-O analysis, $1/L$ is just an irrelevant constant factor). Hence, the matrix multiplication eventually also becomes memory bound.
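Summarizing the data fetched into cache under this model:

$$
\underbrace{O(N^3)}_{\text{one row/column cached}} \;\longrightarrow\; \underbrace{O\!\left(N^3/L\right)}_{L\ \text{rows/columns cached}} \;\longrightarrow\; \underbrace{O(N^2)}_{\text{everything cached}}
$$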

When I suggested that the simplest possible matrix multiplication is just a series of dot products, I did not mean literally individual dot products. Instead, a relatively simple change is to make the dot product function multiply against $L$ columns instead of just 1. That should be enough to match dequantize + cuBLAS performance.
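A rough sketch of that change (again with a simplified q4_0-style block using an f32 scale, not the actual ggml types): each block of the weight row is dequantized once and then reused for $L$ columns, instead of being re-read and dequantized once per column.

```cpp
// Sketch (simplified q4_0-style block with an f32 scale): dequantize each
// weight block once and reuse the values for L activation columns.
#include <cstdint>

#define QK 32 // values per quantized block

struct block_q4 {
    float   d;          // scale
    uint8_t qs[QK / 2]; // 4-bit quants, two per byte
};

// x: n/QK quantized blocks forming one matrix row
// y: L columns of length n, stored contiguously one after another
// s: L output dot products
void vec_dot_q4_multi(const int n, const int L, const block_q4 * x, const float * y, float * s) {
    for (int c = 0; c < L; ++c) {
        s[c] = 0.0f;
    }

    for (int ib = 0; ib < n/QK; ++ib) {
        float v[QK];
        for (int j = 0; j < QK/2; ++j) {
            v[j]        = x[ib].d * ((x[ib].qs[j] & 0x0F) - 8);
            v[j + QK/2] = x[ib].d * ((x[ib].qs[j] >>   4) - 8);
        }
        // the dequantized block is now reused for all L columns
        for (int c = 0; c < L; ++c) {
            const float * yc = y + c*n + ib*QK;
            for (int j = 0; j < QK; ++j) {
                s[c] += v[j]*yc[j];
            }
        }
    }
}
```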

@JohannesGaessler
Collaborator Author

> That would only be true if the two matrices fit completely in fast cache, no?

It is true regardless of how much cache you have. In practice, however, the amount of cache/CUDA shared memory is a limitation: for tiling algorithms, for example, you get better performance with larger tiles, but those tiles have to fit into cache/shared memory to properly reduce memory accesses.

> When I suggested that the simplest possible matrix multiplication is just a series of dot products, I did not mean literally individual dot products. Instead, a relatively simple change is to make the dot product function multiply against $L$ columns instead of just 1. That should be enough to match dequantize + cuBLAS performance.

You're free to try but I highly doubt it.

@cmp-nct
Contributor

cmp-nct commented Jun 28, 2023

I have a half-hacky solution in ggllm.cpp until a full mul-mat integer kernel is available.
Two commits would probably transfer over quite well to llama.cpp:

  1. I added an "access counter" to the buffers, and after each eval I clean up buffers that were not accessed; after a couple of runs without cleanups the activity stops (stable). That saves up to 1.5 gigabytes of GPU VRAM depending on model size.

  2. The bigger impact is that I added 16-bit dequantization kernels (ggml_get_to_fp16_cuda) for all quantizers and a ggml_cuda_op_mul_mat_cublas_f16_f32 function.
    I did not entirely replace the 32-bit cuBLAS path, but I think it could be a full replacement; I tested a lot of generations and it appears to work reliably.
    Some of the q->f32 tensors were 2 GB in size, which is a lot of VRAM. All of them are now half that size, and the savings are used for more offloaded layers.

I integrated the f16 multiplication using a 32-bit wrapper and by exchanging the function pointer right in ggml_cuda_op(). That's the hacky part, but anything else would have changed my ggml-cuda implementation too much to maintain later.

I'm not super happy that I duplicated all the kernels; in hindsight, a 32->16 wrapper around them might also have been possible. Overall those two changes saved gigabytes of VRAM (depending on the parameters used and the model size).

I'm not sure, but wouldn't it make sense to use 16 bit for everything in ggml-cuda? 32 bit seems so wasteful to me.

Maybe some of it is usable:
cmp-nct@4ca3961#diff-66b17223e8ba54054fb2600ecbd31107f8b917bac36c7f3789811b0f0e9802a1
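For reference, a hedged sketch of what an f16-input/f32-output GEMM through cublasGemmEx can look like (this is just the general call pattern, not the ggml_cuda_op_mul_mat_cublas_f16_f32 code, and the supported type combinations depend on the cuBLAS version):

```cpp
// Sketch of an f16 x f16 -> f32 GEMM via cublasGemmEx: the inputs are half
// precision (half the VRAM of f32 staging buffers) while accumulation and the
// output stay in f32. Error handling omitted for brevity.
#include <cublas_v2.h>
#include <cuda_fp16.h>

// C (m x n, f32) = A (m x k, f16) * B (k x n, f16), all column-major
void gemm_f16_f32(cublasHandle_t handle, int m, int n, int k,
                  const half * dA, const half * dB, float * dC) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    cublasGemmEx(handle,
                 CUBLAS_OP_N, CUBLAS_OP_N,
                 m, n, k,
                 &alpha,
                 dA, CUDA_R_16F, m,
                 dB, CUDA_R_16F, k,
                 &beta,
                 dC, CUDA_R_32F, m,
                 CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT);
}
```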

@JianbangZ

@JohannesGaessler Is this related to what I asked? #2118
I am not a CUDA expert by any means, but it might be helpful to take a look at https://github.com/ztxz16/fastllm/tree/master/src/devices/cuda
fastllm is something I have been testing for the past 2 days. I found that their inference speed on GPU with a quantized model is not as good as the current version of ggml (65 t/s for fastllm int4 vs 106 t/s for ggml q4_0 on an RTX A6000), and fp16 inference is about the same for both (44 t/s), but their solution uses much less memory.
From my testing, for the same vicuna_7b_v1.3 model:
llama.cpp/ggml q4_0: CPU RAM = 3600 MB, VRAM = 4800 MB
fastllm int4: CPU RAM = 1200 MB, VRAM = 3600 MB

@JohannesGaessler
Collaborator Author

I don't understand Chinese so I have no idea what that repository is doing. Even if I did, that project is very likely not directly comparable to ggml-based projects and I don't want to dig through the source code to find out the differences.

@JohannesGaessler
Collaborator Author

Superseded by #2160.
