Measure model memory usage #3120

Merged
7 commits merged into vllm-project:main on Mar 7, 2024

Conversation

@mgoin (Member) commented Feb 29, 2024

There is already an indirect measure of KV cache memory usage, via the number of cache blocks allocated, but no direct measure of how much memory the model weights use. This PR tries to add that by wrapping the model loading with torch.cuda.max_memory_allocated() calls. I'm not sure how this will behave on non-NVIDIA devices, so I'm happy to disable it in that case.

Exposes a new ModelRunner.model_memory_usage member variable
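
For illustration, here is a minimal sketch of the measurement approach (not the exact code in this PR; it assumes a CUDA device and a load_model callable that builds the model on the GPU):

import torch

def load_model_with_tracking(load_model):
    # Reset the peak counter so it only reflects this load, then record the
    # peak allocation before and after loading; the delta approximates the
    # memory taken by the model weights.
    torch.cuda.reset_peak_memory_stats()
    before = torch.cuda.max_memory_allocated()
    model = load_model()
    torch.cuda.synchronize()
    after = torch.cuda.max_memory_allocated()
    model_memory_usage = after - before
    print(f"Loading model weights took {model_memory_usage / (1 << 30):.4f} GB")
    return model, model_memory_usage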

Example code:

from vllm import LLM
LLM("facebook/opt-125m")

Output:

INFO 02-29 19:21:40 llm_engine.py:79] Initializing an LLM engine with config: model='facebook/opt-125m', tokenizer='facebook/opt-125m', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, sparsity=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 02-29 19:21:44 weight_utils.py:176] Using model weights format ['*.bin']
>> INFO 02-29 19:21:45 model_runner.py:90] Loading model weights took 244.61 MB
INFO 02-29 19:21:45 llm_engine.py:338] # GPU blocks: 76240, # CPU blocks: 7281

@mgoin marked this pull request as ready for review on February 29, 2024 at 19:25
@mgoin (Member, Author) commented Mar 4, 2024

Hey @simon-mo, what do you think about this?

@mgoin (Member, Author) commented Mar 6, 2024

@WoosukKwon @zhuohan123 what do you think?

@zhuohan123 (Member) left a comment

LGTM! Left some small comments

@esmeetu (Collaborator) commented Mar 7, 2024

Is this correct when tensor_parallel_size > 1?

@mgoin (Member, Author) commented Mar 7, 2024

Thanks for the reviews @zhuohan123 and @esmeetu. For TP>1, my assumption is that it still makes sense to report per-worker model memory usage rather than trying to compute the total model memory and pipe it to every worker.

To be explicit, here is the output of running with TP=1 and TP=2:

TP=1

> CUDA_VISIBLE_DEVICES=7 python -c 'from vllm import LLM;LLM("facebook/opt-125m", tensor_parallel_size=1)'
INFO 03-07 15:15:10 llm_engine.py:88] Initializing an LLM engine (v0.3.3) with config: model='facebook/opt-125m', tokenizer='facebook/opt-125m', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 03-07 15:15:14 model_runner.py:96] Loading model weights took 0.2389 GB

TP=2

> CUDA_VISIBLE_DEVICES=6,7 python -c 'from vllm import LLM;LLM("facebook/opt-125m", tensor_parallel_size=2)'
2024-03-07 15:15:34,427 INFO worker.py:1724 -- Started a local Ray instance.
INFO 03-07 15:15:36 llm_engine.py:88] Initializing an LLM engine (v0.3.3) with config: model='facebook/opt-125m', tokenizer='facebook/opt-125m', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=2, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 03-07 15:15:47 model_runner.py:96] Loading model weights took 0.1189 GB
(RayWorkerVllm pid=798461) INFO 03-07 15:15:48 model_runner.py:96] Loading model weights took 0.1189 GB

Here you can see that the model weights are split roughly evenly between the workers.
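
As a rough back-of-the-envelope check (mine, not from the PR), the numbers line up with float16 weights:

# opt-125m has roughly 125M parameters at 2 bytes each in float16
total_gb = 125e6 * 2 / 1024**3   # ~0.23 GB, close to the 0.2389 GB reported with TP=1
per_worker_gb = total_gb / 2     # ~0.12 GB, roughly the 0.1189 GB reported by each TP=2 worker
print(f"{total_gb:.4f} GB total, {per_worker_gb:.4f} GB per worker")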

@zhuohan123 merged commit 385da2d into vllm-project:main on Mar 7, 2024
22 checks passed
AdrianAbeyta pushed a commit to AdrianAbeyta/vllm that referenced this pull request Mar 8, 2024
dtransposed pushed a commit to afeldman-nm/vllm that referenced this pull request Mar 26, 2024