
[Bugfix][Core] Use torch.cuda.memory_stats() to profile peak memory usage #9352

Merged: 11 commits into vllm-project:main on Oct 18, 2024

Conversation

joerunde (Contributor)

I've noticed odd behavior with the peak memory profiling for the new multi-modal mllama models with cross attention. It's known that we need to decrease the max_num_seqs parameter because cross attention will create very large tensors to handle multi-modal inputs. However, when playing around with the max_num_seqs settings I noticed that

vllm serve meta-llama/Llama-3.2-11B-Vision-Instruct --enforce-eager --max-num-seqs 69

fails on an 80GB A100, reporting that the peak memory usage is 83464552448 bytes, and there is no room to create a KV cache. However:

vllm serve meta-llama/Llama-3.2-11B-Vision-Instruct --enforce-eager --max-num-seqs 70

happily serves, reporting a peak memory usage of 50562334720 bytes.

The memory profile for each looks about the same, as you would expect.

[Memory profile screenshot: --max-num-seqs 70]

[Memory profile screenshot: --max-num-seqs 69]

Notably, the 69 case vastly overestimates peak memory, while the 70 case underestimates it by about 8 GB.

From looking at the cached segment timeline, it appears that --max-num-seqs 70 triggers GC during the forward pass, while --max-num-seqs 69 does not:

[Cached segment timeline screenshot: --max-num-seqs 70, GC occurs during the forward pass]

[Cached segment timeline screenshot: --max-num-seqs 69, no GC during the forward pass]

It looks like the current strategy for determining peak memory usage is to measure the free GPU memory after running a forward pass. As shown here, that is sensitive to garbage collection: it overestimates when GC doesn't happen, because large tensors that were allocated and freed at completely different times all get counted, and it can underestimate when GC does happen, because space that was used by large tensors during the forward pass is freed and no longer accounted for.
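For illustration, here is a rough sketch of this kind of free-memory-based estimate (not the actual vLLM code; the profiling function is a hypothetical stand-in):

import torch

def run_profile_forward_pass():
    # Hypothetical stand-in for the profiling forward pass.
    ...

free_before_profile, total_gpu_memory = torch.cuda.mem_get_info()

run_profile_forward_pass()

free_after_profile, _ = torch.cuda.mem_get_info()
# Everything that is no longer "free" is counted as peak usage, including
# cached blocks for tensors that were allocated and freed at different times.
# Whether GC ran during the forward pass changes how large this number is.
estimated_peak = free_before_profile - free_after_profile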

It seems to me that this problem is only now exacerbated by the cross attention models which need to allocate a large amount of memory to track the cross attention states.

This change instead uses torch.cuda.memory_stats()["allocated_bytes.all.peak"] (see https://pytorch.org/docs/stable/generated/torch.cuda.memory_stats.html) to measure the actual peak memory allocation during the forward pass.
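As a minimal sketch of this approach (illustrative only; the profiling function is again a hypothetical stand-in), resetting the peak counter before the run and reading it afterwards gives the allocation high-water mark regardless of GC timing:

import torch

def run_profile_forward_pass():
    # Hypothetical stand-in for the profiling forward pass.
    ...

torch.cuda.reset_peak_memory_stats()

run_profile_forward_pass()

# High-water mark of memory allocated through torch during the run, counted
# even if those tensors have since been freed or garbage collected.
peak_bytes = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
print(f"Peak torch-allocated memory: {peak_bytes / 1024**3:.2f} GiB")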

After this change, both --max-num-seqs 69 and --max-num-seqs 70 properly fail in this case, as there isn't enough RAM to build a KV cache for 130k tokens. As a bonus, vllm serve meta-llama/Llama-3.2-11B-Vision-Instruct --enforce-eager --max-num-seqs 48 now works, as it correctly determines that the KV cache can fit. (Our guidance here says to use 16; previously I think the maximum you could set and boot on an A100 was 28, see #8826.)

I'm assuming there is historical context I'm missing about why this method wasn't used originally, so I'm happy to be told why this won't work.

FIX #7256

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain code quality and improves the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] for changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented so that future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies user-facing behavior of vLLM. This helps vLLM users understand and make use of the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual.
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in Python so that dynamic dims can be handled automatically. See the above documents for a description of meta-functions.
  • Use torch.library.opcheck() to test the function registration and meta-function for any registered ops (a minimal sketch follows this list). See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.
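A minimal sketch of the registration and opcheck pattern referenced in the list above (the op name "myvllm::scale" and its implementation are made up purely for illustration; this is not a vLLM kernel and assumes a recent PyTorch with the torch.library custom-op API):

import torch

# Hypothetical custom op used only to illustrate the registration pattern.
@torch.library.custom_op("myvllm::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# Meta-function ("fake" implementation) registered in Python so that dynamic
# dims can be handled automatically.
@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

# Test the registration, schema, and meta-function on a sample input.
torch.library.opcheck(scale, (torch.randn(4), 2.0))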

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC, excluding kernel/data/config/test code), we expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and may not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient, and to make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, it will be assigned to a reviewer. Reviewers pick up PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if changes are required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

njhill (Member) commented on Oct 15, 2024

Thanks @joerunde this is awesome!

njhill added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 15, 2024
@@ -29,7 +29,7 @@ def test_lazy_outlines(sample_regex):
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               guided_decoding_backend="lm-format-enforcer",
-              gpu_memory_utilization=0.3)
+              gpu_memory_utilization=0.6)
joerunde (Contributor, Author) commented on Oct 15, 2024

I think this test was working before due to the over-estimation of peak memory usage of the model, which caused a smaller KV cache to be allocated. Both LLMs set gpu_memory_utilization=0.3, but once the first LLM uses the full 30% of the GPU, there's no space left to allocate room for the second one.

This setting is a bit confusing. As it's coded, it means "the total GPU allocation may not exceed x% of GPU memory when loading this model," but the test seems to have assumed it meant "this model may not allocate more than x% of GPU memory, regardless of how much GPU memory ends up being allocated overall." In other words, it assumed a per-model limit rather than a global limit on GPU memory allocation.

Maybe that should be made more clear in the docs?

(Just a comment for readers; I don't intend to make more docs changes in the scope of this PR.)
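As an illustration of the distinction (the model and values mirror the test change above; this is a behavior sketch, not vLLM documentation):

from vllm import LLM

# gpu_memory_utilization caps the *total* GPU allocation observed while this
# engine initializes, not a per-engine share. With two engines in one process,
# the first engine's weights and KV cache already occupy ~30% of the GPU, so a
# second engine also capped at 0.3 finds no headroom left.
llm_a = LLM(model="facebook/opt-125m",
            enforce_eager=True,
            gpu_memory_utilization=0.3)

# Raising the cap for the second engine leaves room above what llm_a holds.
llm_b = LLM(model="facebook/opt-125m",
            enforce_eager=True,
            gpu_memory_utilization=0.6)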

Collaborator

Thanks for the explanation. Could you add a short comment, since the reader may find it odd that the second call sets gpu_memory_utilization differently from the first?

Alternatively, looks like the first llm doesn't need to be live when the second one is created, so we could try to force it to be garbage collected but I don't think it's worth jumping through hoops for this

joerunde (Contributor, Author)

> I don't think it's worth jumping through hoops for this

I agree :D

I added a small comment here for future readers

Comment on lines 276 to 278:

if self.model_runner.lora_manager:
    self.model_runner.remove_all_loras()
gc.collect()
Member

Probably this gc.collect() should also be moved up above torch.cuda.empty_cache()?

Also I'm not sure what the reason for the remove_all_loras() is here... perhaps that should also be moved up to before the gc is done?

joerunde (Contributor, Author)

yeah I'm not sure on that either, I'm guessing we're just trying to clean up everything we did during profiling, before the KV cache is allocated?

re: moving the gc.collect(), I was trying to leave it later here in case there was something allocated outside torch that hadn't been GC'ed yet that we may need to account for in the peak memory usage. If we run all the cleanup and then check the free memory, then the only reason it would be lower is if there's a memory leak, right?

idk, I could go either way. I'm not 100% sold we need the extra check for non-torch allocated memory since it's pretty flaky to try to check for. Think we should just back that out and leave the torch.cuda.empty_cache() down here?

joerunde (Contributor, Author)

update: @tjohnson31415 will give this a go to see if he can reproduce the NCCL allocations he was seeing that were blowing up vram usage. If this code catches it we'll keep it in, if not I'll back it out

tjohnson31415 (Contributor) commented on Oct 16, 2024

It needs a small fix, but the leftover_allocations does pick up on the extra memory allocated by NCCL during the profile run. When it is not accounted for, I got an OOM when allocating the KV Cache for 405B w/ TP=8...

The call to remove_all_loras() seems like it would make more sense to have in profile_run(). Any tensors allocated for LoRA would be included in the torch allocated peak.

In my test with Llama-3.1-8B-Instruct w/ TP=8, moving gc.collect() and remove_all_loras() above leftover_allocations made no difference in the printed messages, so they only clean up a small amount of memory if anything.

joerunde (Contributor, Author)

🌶️🌶️🌶️!

> The call to remove_all_loras() seems like it would make more sense to have in profile_run(). Any tensors allocated for LoRA would be included in the torch allocated peak.

I agree the cleanup in general could better live in the profile run executions, but I do want to limit the blast radius here to this file. I'll leave as-is unless anybody feels strongly about refactoring into the individual model runners

Comment on lines 244 to 245:

leftover_allocations = torch.cuda.mem_get_info(
)[0] - self.init_gpu_memory
Contributor

self.init_gpu_memory is the free memory on the device before the model is even loaded, so the value of leftover_allocations will always be negative.
I think what you want here is the free memory before running profile_run(), and then to subtract the current free memory from that.
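A rough sketch of the suggested computation (illustrative only; profile_run() here is a hypothetical stand-in): snapshot free memory immediately before profiling and subtract the free memory afterwards.

import torch

def profile_run():
    # Hypothetical stand-in for the worker's profile_run().
    ...

# Snapshot free memory right before profiling, rather than relying on the
# free memory recorded before the model was even loaded.
free_before_profile, _ = torch.cuda.mem_get_info()

profile_run()

free_after_profile, _ = torch.cuda.mem_get_info()
# Memory still held after the profile run (by torch's cache or by non-torch
# allocations such as NCCL buffers) that was not held before it started.
leftover_allocations = free_before_profile - free_after_profile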

Comment on lines 257 to 264:

logger.info("Initial memory usage before profile: %.2f GB",
            (total_gpu_memory - self.init_gpu_memory) / (1024**3))
logger.info("Peak memory usage during profile: %.2f GB",
            peak_memory / (1024**3))
logger.info(
    "Available memory for KV cache with %.2f gpu utilization: %.2f GB",
    self.cache_config.gpu_memory_utilization,
    available_kv_cache_memory / (1024**3))
Contributor

These messages are a little spammy when running with TP; maybe a one-message summary is good enough.

Something like this (with the right values populated):

        logger.info("Memory profiling results: initial_memory=%.2fGiB"
                    " peak_torch_memory=%.2fGiB non_torch_memory=%.2fGiB"
                    " kv_cache_size=%.2fGiB gpu_memory_utilization=%.2f"

tlrmchlsmth (Collaborator) left a comment

Left a couple of inline comments but I think this is really good


engine_config.cache_config, engine_config.model_config,
engine_config.parallel_config)

assert gpu_blocks == (8.2843 * 1024**3) // block_size
Collaborator

Should we add a tolerance here? Since the profile run actually runs the model, this can be influenced by subtle changes in model execution. For example, any workspace used in GEMMs, or anything to do with torch.compile, can change the memory footprint. I wouldn't expect this to be stable across GPU architectures or versions of vLLM.

return num_gpu_blocks, num_cpu_blocks

def assert_no_other_gpu_processes(self):
Collaborator

I find this function name a little misleading, since IIUC it doesn't actually assert that there are no other GPU processes; rather, it catches the case where another process frees memory during the profile run, which would invalidate the profiling result.

Maybe assert_memory_footprint_increased_during_profiling?

joerunde (Contributor, Author)

Yeah that sounds way better 👍

joerunde (Contributor, Author)

Thanks for the feedback @tlrmchlsmth!

tlrmchlsmth (Collaborator) left a comment

Thanks for the great work!

Looks like it would make sense to try the same approach for xpu_worker as well? (cpu_worker, neuron_worker, and openvino_worker don't do profile runs, and for the tpu_worker there's a similar PR #9438 in progress.)

# Check within a small tolerance for portability
# Hardware, kernel, or dependency changes could all affect memory
# utilization
assert abs(gpu_blocks - expected_blocks) < 5
Collaborator

Not sure this is a large enough tolerance TBH, but I'm good with setting it to something and adjusting in the future.

joerunde (Contributor, Author)

Heh, well I also don't have too much context personally for how much this could swing with any hardware or software differences. At least this currently works on the A100s I tested on, and whatever worker nodes the CI runs have landed on today 😉

I also don't want to go too wide on the tolerance and end up having this test pass if some changes are accidentally made to the profiling code. This test should catch about 8MB of memory allocated outside of torch, and 5 blocks should be about 3MB in this configuration. I can bump it up to 10 so there's 6MB of wiggle room if that sounds alright.

I will also happily accept people yelling at me if this test becomes super flaky

Collaborator

Sounds good, I'll keep an eye on it

tlrmchlsmth merged commit de4008e into vllm-project:main on Oct 18, 2024. 54 checks passed.
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)

Successfully merging this pull request may close this issue: [Bug]: profile_run Inaccurate estimation leads to gpu OutOfMemoryError

4 participants