[Bugfix][Core] Use torch.cuda.memory_stats() to profile peak memory usage #9352

Merged
11 commits merged on Oct 18, 2024
2 changes: 1 addition & 1 deletion tests/entrypoints/llm/test_lazy_outlines.py
@@ -29,7 +29,7 @@ def test_lazy_outlines(sample_regex):
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3)
gpu_memory_utilization=0.6)
@joerunde (Contributor, Author), Oct 15, 2024:
I think this test only worked before because the model's peak memory usage was over-estimated, which caused a smaller KV cache to be allocated. Both LLMs set gpu_memory_utilization=0.3, but once the first LLM uses the full 30% of the GPU, there is no room left to allocate the second one.

This setting is a bit confusing. As coded, it means "the total GPU allocation may not exceed x% of the GPU memory when loading this model", but the test assumed it meant "you may not allocate more than x% of the GPU memory for this model, regardless of how much of the GPU memory ends up being allocated overall." In other words, it assumed this was a per-model limit and not a global limit on GPU memory allocation.

Maybe that should be made clearer in the docs?

(Just a note for readers: I don't intend to make more docs changes within the scope of this PR.)
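For readers, a minimal sketch of the distinction (the values mirror the test above; this snippet is illustrative and not part of the PR):

    from vllm import LLM

    # gpu_memory_utilization caps the *total* GPU allocation; it is not a
    # per-model budget. The first engine may grow to 30% of the device.
    llm_a = LLM(model="facebook/opt-125m",
                enforce_eager=True,
                gpu_memory_utilization=0.3)

    # A second engine on the same GPU cannot also ask for 0.3: roughly 30% is
    # already allocated, so its cap has to be raised to cover both engines.
    llm_b = LLM(model="facebook/opt-125m",
                enforce_eager=True,
                gpu_memory_utilization=0.6)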

Collaborator:

Thanks for the explanation. Could you add a short comment, since the reader may find it odd that the second call sets gpu_memory_utilization differently from the first?

Alternatively, it looks like the first LLM doesn't need to be live when the second one is created, so we could try to force it to be garbage collected, but I don't think it's worth jumping through hoops for this.

Contributor Author:

> I don't think it's worth jumping through hoops for this

I agree :D

I added a small comment here for future readers.

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=[
37 changes: 31 additions & 6 deletions vllm/worker/worker.py
@@ -217,6 +217,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
@@ -228,29 +229,53 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
assert self.init_gpu_memory - free_gpu_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")

# Get the peak memory allocation recorded by torch
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

# Edge case: Check for any memory left around that may have been
# allocated on the gpu outside of `torch`
torch.cuda.empty_cache()
leftover_allocations = torch.cuda.mem_get_info(
)[0] - self.init_gpu_memory
Contributor:

self.init_gpu_memory is the free memory on the device before the model is even loaded, so the value of leftover_allocations will always be negative.
I think what you want here is to record the free memory right before running profile_run() and then subtract the current free memory from that.
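A minimal standalone sketch of the correction being suggested here (the helper name is illustrative, not code from this PR): snapshot the free memory right before the profile run, then compare against the free memory left after clearing the torch cache.

    import torch

    def measure_non_torch_allocations(run_profile) -> int:
        # Free memory just before profiling, i.e. after the model weights are loaded.
        free_before, _ = torch.cuda.mem_get_info()
        run_profile()
        torch.cuda.synchronize()
        # Release blocks cached by the torch allocator; anything still missing
        # afterwards was allocated outside torch (e.g. NCCL buffers).
        torch.cuda.empty_cache()
        free_after, _ = torch.cuda.mem_get_info()
        return max(free_before - free_after, 0)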

if leftover_allocations > 0:
logger.info(
"Found %.2f GB of allocated memory leftover after clearing "
"torch cache. Adding to peak memory usage.",
leftover_allocations / (1024**3))
peak_memory += leftover_allocations

available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)

logger.info("Initial memory usage before profile: %.2f GB",
(total_gpu_memory - self.init_gpu_memory) / (1024**3))
logger.info("Peak memory usage during profile: %.2f GB",
peak_memory / (1024**3))
logger.info(
"Available memory for KV cache with %.2f gpu utilization: %.2f GB",
self.cache_config.gpu_memory_utilization,
available_kv_cache_memory / (1024**3))
Contributor:

These messages are a little spammy when running with TP; maybe a one-message summary is good enough.

Something like this (with the right values populated):

        logger.info("Memory profiling results: initial_memory=%.2fGiB"
                    " peak_torch_memory=%.2fGiB non_torch_memory=%.2fGiB"
                    " kv_cache_size=%.2fGiB gpu_memory_utilization=%.2f"


cache_block_size = self.get_cache_block_size_bytes()
if cache_block_size == 0:
num_gpu_blocks = 0
num_cpu_blocks = 0
else:
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
Comment on lines 274 to 276
Member:

Probably this gc.collect() should also be moved up above torch.cuda.empty_cache()?

Also, I'm not sure what the reason for the remove_all_loras() is here; perhaps that should also be moved up to before the gc is done?

Contributor Author:

Yeah, I'm not sure about that either; I'm guessing we're just trying to clean up everything we did during profiling before the KV cache is allocated?

Re: moving the gc.collect(), I was trying to leave it later here in case something allocated outside torch hadn't been GC'ed yet and needed to be accounted for in the peak memory usage. If we run all the cleanup and then check the free memory, then the only reason it would be lower is a memory leak, right?

idk, I could go either way. I'm not 100% sold that we need the extra check for non-torch allocated memory, since it's pretty flaky to check for. Think we should just back that out and leave the torch.cuda.empty_cache() down here?

Contributor Author:

Update: @tjohnson31415 will give this a go to see if he can reproduce the NCCL allocations he was seeing that were blowing up VRAM usage. If this code catches it, we'll keep it in; if not, I'll back it out.

@tjohnson31415 (Contributor), Oct 16, 2024:

It needs a small fix, but the leftover_allocations check does pick up the extra memory allocated by NCCL during the profile run. When it is not accounted for, I got an OOM when allocating the KV cache for 405B w/ TP=8...

The call to remove_all_loras() seems like it would make more sense to have in profile_run(). Any tensors allocated for LoRA would be included in the torch allocated peak.

In my test with Llama-3.1-8B-Instruct w/ TP=8, moving gc.collect() and remove_all_loras() above the leftover_allocations check made no difference in the printed messages, so they clean up only a small amount of memory, if anything.

Contributor Author:

🌶️🌶️🌶️!

> The call to remove_all_loras() seems like it would make more sense to have in profile_run(). Any tensors allocated for LoRA would be included in the torch allocated peak.

I agree the cleanup in general could better live in the profile run executions, but I do want to limit the blast radius here to this file. I'll leave it as-is unless anybody feels strongly about refactoring into the individual model runners.

torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks

def initialize_cache(self, num_gpu_blocks: int,