Track and free temporary ggml_tensor_extra_gpu struct #2195

Closed
bullno1 wants to merge 1 commit

Conversation

bullno1 (Contributor) commented Jul 12, 2023

Fix #2145.

Temporary allocations during eval are tracked and freed at the end.
I decided to go with the implicit context idea from #2146 (comment), since the code change is minimal.

Pooling could be added later if needed, with the pool freed in llama_backend_free.

Tested on my machine with --ignore-eos to keep generation running; RAM usage no longer increases.
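
For illustration, a minimal sketch of what this track-and-free approach could look like (this is not the actual diff; the names below are placeholders, and ggml_tensor_extra_gpu is the existing struct in ggml-cuda.cu):

```cpp
// Sketch only: remember every temporary extra allocated during eval and free
// them all once eval finishes. Container and helper names are illustrative.
#include <vector>

static std::vector<ggml_tensor_extra_gpu *> g_tracked_extras;

static ggml_tensor_extra_gpu * alloc_tracked_extra() {
    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
    g_tracked_extras.push_back(extra); // track it so it can be freed later
    return extra;
}

static void free_tracked_extras() { // called once at the end of eval
    for (ggml_tensor_extra_gpu * extra : g_tracked_extras) {
        delete extra;
    }
    g_tracked_extras.clear();
}
```

The point is simply that every temporary extra is remembered in an implicit context, so a single call at the end of eval can release them all.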

JohannesGaessler (Collaborator)

I think this solution is overengineered. How about this instead: allocate a small buffer at the beginning and re-use it to hold the ggml_tensor_extra_gpu structs. This would guarantee that there is no additional memory allocation during evaluation, and only a single buffer needs to be freed at the end. I'm not sure what the size of ggml_tensor_extra_gpu is, but that multiplied by the define GGML_MAX_NODES == 4096 should be enough and still relatively small. You wouldn't even need to explicitly specify when the evaluation starts or ends, since you could simply keep a counter for the position in the buffer (in units of sizeof(ggml_tensor_extra_gpu)) and wrap around once you reach GGML_MAX_NODES.
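
A rough sketch of that fixed-buffer idea, assuming the existing ggml_tensor_extra_gpu struct and GGML_MAX_NODES define; the pool and counter names here are made up:

```cpp
// Sketch only: one statically allocated pool of extras, reused round-robin.
static ggml_tensor_extra_gpu g_extra_pool[GGML_MAX_NODES];
static size_t g_extra_pool_index = 0;

static ggml_tensor_extra_gpu * alloc_pooled_extra() {
    ggml_tensor_extra_gpu * extra = &g_extra_pool[g_extra_pool_index];
    g_extra_pool_index = (g_extra_pool_index + 1) % GGML_MAX_NODES; // wrap around
    *extra = ggml_tensor_extra_gpu{};                               // clear before reuse
    return extra;
}
```

With a single counter that wraps at GGML_MAX_NODES there is no per-eval bookkeeping, and at most one buffer has to be released at shutdown.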

bullno1 (Contributor, Author) commented Jul 13, 2023

That works, but where would this pool live? Global?

There are 4 cases here: https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda.cu#L3335
I don't know for sure which of them are temporary.

So how do I know when to allocate from the pool?

The begin/end calls were there because I couldn't decide which ones are temporary.

JohannesGaessler (Collaborator)

> That works, but where would this pool live? Global?

Since right now everything in ggml-cuda.cu is global, I would make the pool global too. However, I believe @slaren is currently working on changing that, so maybe he has some input?

> So how do I know when to allocate from the pool?

The inplace and cpy cases are for tensors that don't actually change the data, so they reuse the data pointers of the tensors that hold the actual data. The scratch case is for tensors that hold only temporary results and are therefore okay to overwrite at a later point. The last case is needed for the KV cache, whose data should not be overwritten, which is why it's not on the scratch buffer. Currently there is a lot of overlap and confusion between ggml_cuda_transform_tensor (which is used for weights) and ggml_cuda_assign_buffers (which is used for all other tensors); at some point the code should probably be refactored.

In any case, the pool should be used for the inplace, cpy, and scratch cases; the KV cache data should not be overwritten, and it is already freed by the llama_kv_cache destructor.
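
To make the split concrete, here is a hypothetical dispatch along those lines; the helper and its flags are invented for illustration and don't match the actual ggml_cuda_assign_buffers signature:

```cpp
// Sketch only: temporary extras (inplace / cpy / scratch) come from the reusable
// pool, while the KV cache keeps a long-lived extra freed with llama_kv_cache.
static void assign_extra_sketch(struct ggml_tensor * tensor, bool inplace, bool scratch) {
    const bool is_temporary = inplace || scratch || tensor->op == GGML_OP_CPY;
    if (is_temporary) {
        tensor->extra = alloc_pooled_extra();        // safe to recycle next eval
    } else {
        tensor->extra = new ggml_tensor_extra_gpu{}; // KV cache: must persist
    }
}
```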

bullno1 (Contributor, Author) commented Jul 13, 2023

In llama_apply_lora_from_file_internal there are assign-buffers calls, though.

What should be done about those?

Edit: Wait, those are temporary? lora_ctx is freed every iteration.
I thought those were for weights.

slaren (Collaborator) commented Jul 13, 2023

> However, I believe @slaren is currently working on changing that, so maybe he has some input?

What I am working on is going to change significantly how resources are managed. I will open a draft PR in the next few days that will clarify some of these things, but it's going to take a while until it's ready; multi-GPU is not even supported in my branch yet. So if you need to fix this now, just do it in whatever way is most convenient for you, and don't worry too much about making the design future-proof.

> Edit: Wait, those are temporary? lora_ctx is freed every iteration.

The LoRAs are merged into the model weights, so whatever resources are needed to apply them aren't used afterwards.

JohannesGaessler (Collaborator)

> In llama_apply_lora_from_file_internal there are assign-buffers calls, though.
> What should be done about those?

When a LoRA is applied, a small graph that modifies the weights is executed. The final node is pre-allocated, but there are some temporary tensors in between. There is no practical difference compared to the larger graphs during eval.

bullno1 (Contributor, Author) commented Jul 14, 2023

Closing in favor of #2220.

bullno1 closed this Jul 14, 2023

Successfully merging this pull request may close the following issue:

[User] ggml_tensor->extra(s) are not freed at the end of llama_eval, causing a memory leak
3 participants