Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) #6611

Merged
merged 13 commits into from
Jul 26, 2024

Conversation

mgoin
Copy link
Collaborator

@mgoin mgoin commented Jul 20, 2024

FIX #5722

Based off huggingface/transformers#31699 - Nemotron-3 loads and produces reasonable output. Nemotron-4 and the most recently released Minitron works and evals can be reproduced.

For CI, a Minitron-4B-Base GSM8k eval has been added to the lm-eval test suite.

The architecture is pretty similar to Llama, with these changes:

  • There is no gate_proj, just up_proj
  • Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
  • Squared ReLU instead of SwiGLU
  • Adds a rotary_percent to RoPE

Collection of checkpoints (Nemotron-3, Nemotron-4, Minitron): https://huggingface.co/collections/mgoin/nemotron-in-vllm-66a151b4240bcd9c28735ec5

Loading nvidia/Minitron-4B-Base:

>>> from vllm import LLM
>>> model = LLM("nvidia/Minitron-4B-Base")
INFO 07-23 14:18:22 llm_engine.py:175] Initializing an LLM engine (v0.5.2) with config: model='nvidia/Minitron-4B-Base', speculative_config=None, tokenizer='nvidia/Minitron-4B-Base', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=nvidia/Minitron-4B-Base, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 07-23 14:18:26 weight_utils.py:219] Using model weights format ['*.bin']
pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8.38G/8.38G [02:13<00:00, 62.9MB/s]
INFO 07-23 14:20:44 model_runner.py:563] Loading model weights took 7.8059 GB
INFO 07-23 14:20:44 gpu_executor.py:102] # GPU blocks: 31261, # CPU blocks: 2048
INFO 07-23 14:20:45 model_runner.py:851] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 07-23 14:20:45 model_runner.py:855] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 07-23 14:20:49 model_runner.py:1052] Graph capturing finished in 3 secs.
>>> model.generate("I love you!")
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.51it/s, est. speed input: 30.07 toks/s, output: 120.28 toks/s]
[RequestOutput(request_id=0, prompt='I love you!', prompt_token_ids=[1317, 2624, 1346, 252397], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' I just wish things were about more than one of us going to be tragically injured', token_ids=(1317, 1722, 5962, 2644, 1671, 1632, 1602, 1742, 1610, 1299, 1512, 2433, 1294, 1354, 133186, 13956), cumulative_logprob=-56.41404923796654, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1721744492.323585, last_token_time=1721744492.323585, first_scheduled_time=1721744492.3276443, first_token_time=1721744492.3389344, time_in_queue=0.004059314727783203, finished_time=1721744492.4604568), lora_request=None)]

Loading nemotron3-8b-base:

>>> from vllm import LLM, SamplingParams
>>> model = LLM("thhaus/nemotron3-8b")
INFO 07-21 00:55:19 llm_engine.py:175] Initializing an LLM engine (v0.5.2) with config: model='thhaus/nemotron3-8b', speculative_config=None, tokenizer='thhaus/nemotron3-8b', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=thhaus/nemotron3-8b, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 07-21 00:55:23 weight_utils.py:219] Using model weights format ['*.safetensors']
INFO 07-21 00:55:26 model_runner.py:563] Loading model weights took 15.9077 GB
INFO 07-21 00:55:27 gpu_executor.py:102] # GPU blocks: 6770, # CPU blocks: 512
INFO 07-21 00:55:28 model_runner.py:851] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 07-21 00:55:28 model_runner.py:855] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 07-21 00:55:32 model_runner.py:1052] Graph capturing finished in 4 secs.
>>> output = model.generate("1, 2, 3, ", SamplingParams(temperature=0, max_tokens=100))
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.26s/it, est. speed input: 8.75 toks/s, output: 79.50 toks/s]
>>> print(output)
[RequestOutput(request_id=0, prompt='1, 2, 3, ', prompt_token_ids=[2, 251490, 251525, 251514, 251490, 251527, 251514, 251490, 251556, 251514, 251490], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30', token_ids=(251564, 251514, 251490, 251557, 251514, 251490, 251574, 251514, 251490, 251581, 251514, 251490, 251577, 251514, 251490, 251561, 251514, 251490, 251525, 251521, 251514, 251490, 251525, 251525, 251514, 251490, 251525, 251527, 251514, 251490, 251525, 251556, 251514, 251490, 251525, 251564, 251514, 251490, 251525, 251557, 251514, 251490, 251525, 251574, 251514, 251490, 251525, 251581, 251514, 251490, 251525, 251577, 251514, 251490, 251525, 251561, 251514, 251490, 251527, 251521, 251514, 251490, 251527, 251525, 251514, 251490, 251527, 251527, 251514, 251490, 251527, 251556, 251514, 251490, 251527, 251564, 251514, 251490, 251527, 251557, 251514, 251490, 251527, 251574, 251514, 251490, 251527, 251581, 251514, 251490, 251527, 251577, 251514, 251490, 251527, 251561, 251514, 251490, 251556, 251521), cumulative_logprob=-4.905927623214666, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1721523332.4632323, last_token_time=1721523332.4632323, first_scheduled_time=1721523332.4670587, first_token_time=1721523332.4824336, time_in_queue=0.0038263797760009766, finished_time=1721523333.724725), lora_request=None)]

Loading mgoin/Nemotron-4-340B-Instruct-FP8-Dynamic on 8xA100:

>>> from vllm import LLM, SamplingParams
>>> model = LLM("mgoin/Nemotron-4-340B-Instruct-FP8-Dynamic", tensor_parallel_size=8, distributed_executor_backend="ray", enforce_eager=True)
2024-07-21 03:09:56,222 INFO worker.py:1749 -- Started a local Ray instance.
INFO 07-21 03:09:57 llm_engine.py:175] Initializing an LLM engine (v0.5.2) with config: model='mgoin/Nemotron-4-340B-Instruct-FP8-Dynamic', speculative_config=None, tokenizer='mgoin/Nemotron-4-340B-Instruct-FP8-Dynamic', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/home/mgoin/models/Nemotron-4-340B-Instruct-FP8, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2
INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-21 03:10:17 shm_broadcast.py:240] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1, 2, 3, 4, 5, 6, 7], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7fdf89ac73d0>, local_subscribe_port=56965, local_sync_port=40185, remote_subscribe_port=None, remote_sync_port=None)
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
(RayWorkerWrapper pid=1233230) WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
WARNING 07-21 03:10:35 utils.py:569] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
INFO 07-21 03:10:36 model_runner.py:563] Loading model weights took 40.8975 GB
(RayWorkerWrapper pid=1233718) WARNING 07-21 03:10:45 utils.py:569] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2 [repeated 6x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5 [repeated 6x across cluster]
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json [repeated 6x across cluster]
(RayWorkerWrapper pid=1234002) WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change. [repeated 6x across cluster]
(RayWorkerWrapper pid=1233718) INFO 07-21 03:10:46 model_runner.py:563] Loading model weights took 40.8975 GB
INFO 07-21 03:10:54 distributed_gpu_executor.py:56] # GPU blocks: 22421, # CPU blocks: 3640
>>> model.generate("Hello!")
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.19s/it, est. speed input: 1.69 toks/s, output: 13.49 toks/s]
[RequestOutput(request_id=0, prompt='Hello!', prompt_token_ids=[14716, 252397], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='8493045638869919', token_ids=(252377, 252372, 252370, 252365, 252334, 252372, 252366, 252376, 252365, 252377, 252377, 252376, 252370, 252370, 252338, 252370), cumulative_logprob=-38.44283938407898, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1721531506.653993, last_token_time=1721531506.653993, first_scheduled_time=1721531506.6587515, first_token_time=1721531506.8336384, time_in_queue=0.004758596420288086, finished_time=1721531507.8438315), lora_request=None)]

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mgoin
Copy link
Collaborator Author

mgoin commented Jul 21, 2024

Able to load Nemotron-4-340B-Instruct (had to make a lot of edits to this checkpoint, will upload) with cpu offloading:

>>> from vllm import LLM, SamplingParams
>>> model = LLM("/home/mgoin/models/Nemotron-4-340B-Instruct", tensor_parallel_size=8, distributed_executor_backend="ray", cpu_offload_gb=20, enforce_eager=True)
2024-07-21 02:32:18,002 INFO worker.py:1749 -- Started a local Ray instance.
INFO 07-21 02:32:19 llm_engine.py:175] Initializing an LLM engine (v0.5.2) with config: model='/home/mgoin/models/Nemotron-4-340B-Instruct', speculative_config=None, tokenizer='/home/mgoin/models/Nemotron-4-340B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/home/mgoin/models/Nemotron-4-340B-Instruct, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 07-21 02:32:36 utils.py:779] Found nccl from library libnccl.so.2
INFO 07-21 02:32:36 pynccl.py:63] vLLM is using nccl==2.20.5
(RayWorkerWrapper pid=1139686) INFO 07-21 02:32:36 utils.py:779] Found nccl from library libnccl.so.2
(RayWorkerWrapper pid=1139686) INFO 07-21 02:32:36 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 07-21 02:32:39 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-21 02:32:39 shm_broadcast.py:240] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1, 2, 3, 4, 5, 6, 7], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7fda1b3b33d0>, local_subscribe_port=55417, local_sync_port=53103, remote_subscribe_port=None, remote_sync_port=None)
(RayWorkerWrapper pid=1139686) INFO 07-21 02:32:39 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-21 02:33:14 model_runner.py:563] Loading model weights took 59.3774 GB
(RayWorkerWrapper pid=1140131) INFO 07-21 02:33:34 model_runner.py:563] Loading model weights took 59.3774 GB
(RayWorkerWrapper pid=1140410) INFO 07-21 02:32:36 utils.py:779] Found nccl from library libnccl.so.2 [repeated 6x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayWorkerWrapper pid=1140410) INFO 07-21 02:32:36 pynccl.py:63] vLLM is using nccl==2.20.5 [repeated 6x across cluster]
(RayWorkerWrapper pid=1140410) INFO 07-21 02:32:39 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json [repeated 6x across cluster]
INFO 07-21 02:33:42 distributed_gpu_executor.py:56] # GPU blocks: 4867, # CPU blocks: 3640
>>> model.generate("Hello!")
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:16<00:00, 16.29s/it, est. speed input: 0.12 toks/s, output: 0.98 toks/s]
[RequestOutput(request_id=0, prompt='Hello!', prompt_token_ids=[14716, 252397], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='8493045638869919', token_ids=(252377, 252372, 252370, 252365, 252334, 252372, 252366, 252376, 252365, 252377, 252377, 252376, 252370, 252370, 252338, 252370), cumulative_logprob=-38.362853050231934, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1721529255.4465594, last_token_time=1721529255.4465594, first_scheduled_time=1721529255.4538574, first_token_time=1721529256.6029613, time_in_queue=0.007297992706298828, finished_time=1721529271.7377276), lora_request=None)]

@mgoin
Copy link
Collaborator Author

mgoin commented Jul 21, 2024

I made an FP8 W8 quantized checkpoint based on the above and it produces the same output. The tokenizer is not right for this model so that is the next step.

>>> from vllm import LLM, SamplingParams
>>> model = LLM("/home/mgoin/models/Nemotron-4-340B-Instruct-FP8", tensor_parallel_size=8, distributed_executor_backend="ray", enforce_eager=True)
2024-07-21 03:09:56,222 INFO worker.py:1749 -- Started a local Ray instance.
INFO 07-21 03:09:57 llm_engine.py:175] Initializing an LLM engine (v0.5.2) with config: model='/home/mgoin/models/Nemotron-4-340B-Instruct-FP8', speculative_config=None, tokenizer='/home/mgoin/models/Nemotron-4-340B-Instruct-FP8', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/home/mgoin/models/Nemotron-4-340B-Instruct-FP8, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2
INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-21 03:10:17 shm_broadcast.py:240] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1, 2, 3, 4, 5, 6, 7], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7fdf89ac73d0>, local_subscribe_port=56965, local_sync_port=40185, remote_subscribe_port=None, remote_sync_port=None)
(RayWorkerWrapper pid=1233230) INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
(RayWorkerWrapper pid=1233230) WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
WARNING 07-21 03:10:35 utils.py:569] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
INFO 07-21 03:10:36 model_runner.py:563] Loading model weights took 40.8975 GB
(RayWorkerWrapper pid=1233718) WARNING 07-21 03:10:45 utils.py:569] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:14 utils.py:779] Found nccl from library libnccl.so.2 [repeated 6x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:14 pynccl.py:63] vLLM is using nccl==2.20.5 [repeated 6x across cluster]
(RayWorkerWrapper pid=1234002) INFO 07-21 03:10:17 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json [repeated 6x across cluster]
(RayWorkerWrapper pid=1234002) WARNING 07-21 03:10:17 fp8.py:38] Detected fp8 checkpoint. Please note that the format is experimental and subject to change. [repeated 6x across cluster]
(RayWorkerWrapper pid=1233718) INFO 07-21 03:10:46 model_runner.py:563] Loading model weights took 40.8975 GB
INFO 07-21 03:10:54 distributed_gpu_executor.py:56] # GPU blocks: 22421, # CPU blocks: 3640
>>> model.generate("Hello!")
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.19s/it, est. speed input: 1.69 toks/s, output: 13.49 toks/s]
[RequestOutput(request_id=0, prompt='Hello!', prompt_token_ids=[14716, 252397], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='8493045638869919', token_ids=(252377, 252372, 252370, 252365, 252334, 252372, 252366, 252376, 252365, 252377, 252377, 252376, 252370, 252370, 252338, 252370), cumulative_logprob=-38.44283938407898, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1721531506.653993, last_token_time=1721531506.653993, first_scheduled_time=1721531506.6587515, first_token_time=1721531506.8336384, time_in_queue=0.004758596420288086, finished_time=1721531507.8438315), lora_request=None)]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: how is the architecture different from Llama?

Copy link
Collaborator Author

@mgoin mgoin Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fairly certain it could be implemented inside our Llama implementation. I'm not sure how to deal with the absence of gate_proj though.

The key differences are:

  1. There is no gate_proj, just up_proj
  2. Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
  3. Squared ReLU instead of SwiGLU
  4. Adds a rotary_percentage to RoPE

This is a good overview of main changes: https://twitter.com/danielhanchen/status/1801671106266599770

@mgoin mgoin changed the title [WIP][Model] Support Nemotron [Model] Support Nemotron Jul 23, 2024
@mgoin mgoin marked this pull request as ready for review July 23, 2024 14:39
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 23, 2024
@mgoin mgoin changed the title [Model] Support Nemotron [Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) Jul 26, 2024
Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update documentation as well?

Comment on lines +162 to +165
class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
Copy link
Member

@ywang96 ywang96 Jul 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might also need to add this op to CPU similar to what I did here previously with QuickGELU.

Per offline discussion, no CPU op should need to be added since it's just calling torch API. @bigPYJ1151 It would be great if you can confirm that, thanks!

@mgoin mgoin merged commit 07278c3 into main Jul 26, 2024
72 checks passed
cadedaniel pushed a commit to cadedaniel/vllm-public that referenced this pull request Jul 27, 2024
@mgoin mgoin deleted the nemotron-support branch July 27, 2024 15:55
@zhangzx-uiuc
Copy link

This looks great! Just wondering has anyone tried serving the bf16 version on 2 8xA100 nodes, with ray?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[New Model]: Support Nemotron-4-340B
6 participants