From dbf6ee22c3efb562fc3ed8f4d8cc7441f36f9464 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 21 Aug 2024 21:53:05 +0530 Subject: [PATCH 01/32] changes for internvl pipeline parallelism --- vllm/model_executor/models/internlm2.py | 46 ++++++++++++++++++------- vllm/model_executor/models/internvl.py | 6 ++-- vllm/model_executor/models/utils.py | 17 +++++++++ 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 499cdb43fc8b..e82a849d981a 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -23,6 +23,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class InternLM2MLP(nn.Module): @@ -213,6 +216,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -222,11 +226,13 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: InternLMDecoderLayer(config, cache_config, quant_config),prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) @@ -240,20 +246,30 @@ def forward( intermediate_tensors: IntermediateTensors = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.tok_embeddings(input_ids) + residual = None else: - hidden_states = self.tok_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -277,6 +293,8 @@ def __init__( self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -287,7 +305,7 @@ def forward( intermediate_tensors: IntermediateTensors, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -324,6 +342,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -332,6 +352,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b379c86c1912..513242295024 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -340,6 +340,8 @@ def __init__(self, nn.Linear(llm_hidden_size, llm_hidden_size)) self.img_context_token_id = None + self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors + def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() @@ -450,7 +452,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[SamplerOutput, IntermediateTensors]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: inputs_embeds = self.language_model.model.get_input_embeddings( @@ -467,7 +469,7 @@ def forward( positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 91b414b1fd91..a82b12499426 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -12,6 +12,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.multimodal import BatchedTensors from vllm.utils import is_pin_memory_available +from vllm.sequence import IntermediateTensors def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): @@ -227,3 +228,19 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: if name.startswith(missing_layer_name): return True return False + + +def make_empty_intermediate_tensors_factory(keys: List[str], + hidden_size: int) -> Callable: + + def make_empty_intermediate_tensors( + batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + key: torch.zeros((batch_size, hidden_size), + dtype=dtype, + device=device) + for key in keys + }) + + return make_empty_intermediate_tensors \ No newline at end of file From 08b853807db02c7fda798f6f3bcc0d837713ddad Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Thu, 22 Aug 2024 12:31:04 +0530 Subject: [PATCH 02/32] Inclusion of InternVLChatModel in PP_SUPPORTED_MODELS --- vllm/config.py | 1 + vllm/model_executor/models/internlm2.py | 2 +- vllm/model_executor/models/utils.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 7e62a727115e..ea5bab07c681 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -47,6 +47,7 @@ "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", + "InternVLChatModel" ] diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index e82a849d981a..dfba919e01f8 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -7,7 +7,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index a82b12499426..ddcbce052540 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, List, Optional, Protocol, Tuple +from typing import Dict, Iterable, List, Optional, Protocol, Tuple, Callable import torch import torch.nn as nn From d7f2d545a6b652d070b5fb9649bc2e31c54b6074 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Mon, 26 Aug 2024 10:50:06 +0530 Subject: [PATCH 03/32] refactor --- vllm/model_executor/models/internlm2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index dfba919e01f8..250359b98aa3 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn From caaca9a0be34ca08f3807b1d4caf26555e38a56c Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Mon, 26 Aug 2024 11:37:05 +0530 Subject: [PATCH 04/32] refactor --- vllm/model_executor/models/internlm2.py | 3 ++- vllm/model_executor/models/internvl.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 250359b98aa3..99c56f5e597a 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -228,7 +228,8 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: InternLMDecoderLayer(config, cache_config, quant_config),prefix=f"{prefix}.layers") + lambda prefix: InternLMDecoderLayer(config, + cache_config, quant_config),prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index dbf210f1a109..03f6b97ffbbb 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -340,7 +340,8 @@ def __init__(self, nn.Linear(llm_hidden_size, llm_hidden_size)) self.img_context_token_id = None - self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) def pixel_shuffle(self, x, scale_factor=0.5): From 21376316d0ba772083f55e8f022e55ac6d9440b9 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Mon, 26 Aug 2024 11:43:07 +0530 Subject: [PATCH 05/32] refactor --- vllm/model_executor/models/internlm2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 99c56f5e597a..bdeb6a26631b 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -228,8 +228,10 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: InternLMDecoderLayer(config, - cache_config, quant_config),prefix=f"{prefix}.layers") + lambda prefix: InternLMDecoderLayer( + config, cache_config, quant_config + ), prefix=f"{prefix}.layers" + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( From 927e3f8a566985c4caaa424f556983af4fe20935 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Mon, 26 Aug 2024 11:53:06 +0530 Subject: [PATCH 06/32] refactor --- vllm/model_executor/models/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index ddcbce052540..d5f431e97f29 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, List, Optional, Protocol, Tuple, Callable +from typing import Callable, Dict, Iterable, List, Optional, Protocol, Tuple import torch import torch.nn as nn @@ -11,8 +11,8 @@ from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry from vllm.multimodal import BatchedTensors -from vllm.utils import is_pin_memory_available from vllm.sequence import IntermediateTensors +from vllm.utils import is_pin_memory_available def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): From 7e8ef5cfe6dad71c9240c1a2d96e3fd31f89c44e Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Mon, 26 Aug 2024 11:54:37 +0530 Subject: [PATCH 07/32] refactor --- vllm/model_executor/models/internlm2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index bdeb6a26631b..91ce2455a95f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -23,11 +23,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput + from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) - class InternLM2MLP(nn.Module): def __init__( From c891114fed30f7e0e6d0de0c913865637185c6f1 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Mon, 26 Aug 2024 15:53:26 +0530 Subject: [PATCH 08/32] refactor --- vllm/config.py | 20 +++++--------------- vllm/model_executor/models/internlm2.py | 7 +++---- vllm/model_executor/models/internvl.py | 1 - 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 946dcd80caa7..7afe8adfbc79 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -34,21 +34,11 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _PP_SUPPORTED_MODELS = [ - "AquilaModel", - "AquilaForCausalLM", - "DeepseekV2ForCausalLM", - "InternLMForCausalLM", - "JAISLMHeadModel", - "LlamaForCausalLM", - "LLaMAForCausalLM", - "MistralForCausalLM", - "Phi3ForCausalLM", - "GPT2LMHeadModel", - "MixtralForCausalLM", - "NemotronForCausalLM", - "Qwen2ForCausalLM", - "Qwen2MoeForCausalLM", - "QWenLMHeadModel", + "AquilaModel", "AquilaForCausalLM", "DeepseekV2ForCausalLM", + "InternLMForCausalLM", "JAISLMHeadModel", "LlamaForCausalLM", + "LLaMAForCausalLM", "MistralForCausalLM", "Phi3ForCausalLM", + "GPT2LMHeadModel", "MixtralForCausalLM", "NemotronForCausalLM", + "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", "InternVLChatModel" ] diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 91ce2455a95f..3482d941cb89 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -228,10 +228,9 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: InternLMDecoderLayer( - config, cache_config, quant_config - ), prefix=f"{prefix}.layers" - ) + lambda prefix: InternLMDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 03f6b97ffbbb..7b04ff6117c7 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -343,7 +343,6 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale From 654248c9e7845d5288726521ffc37aefaadcdc56 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 27 Aug 2024 11:27:29 +0530 Subject: [PATCH 09/32] Added the InternVL2-8B for testing the pipeline parallelism in test_pipeline_parallel.py --- tests/distributed/test_pipeline_parallel.py | 3 +++ tests/utils.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 4d54e43d5788..04ee5945e640 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -31,6 +31,7 @@ (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 2, 1, 1, "OpenGVLab/InternVL2-8B", "ray"), ]) @fork_new_process_for_each_test def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, @@ -72,6 +73,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, pp_args.append("--enforce-eager") tp_args.append("--enforce-eager") pp_env = None + pp_args.append("--trust-remote-code") + tp_args.append("--trust-remote-code") if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2 and CHUNKED_PREFILL): # Test Ray ADAG for a subset of the tests diff --git a/tests/utils.py b/tests/utils.py index b73a05b5fe67..d9dfd0243084 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -175,7 +175,7 @@ def compare_two_settings(model: str, env2: The second set of environment variables to pass to the API server. """ - tokenizer = AutoTokenizer.from_pretrained(model) + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] From cf155c2b93e296e226c380ed0b5e838a49fbf87c Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 28 Aug 2024 11:36:35 +0530 Subject: [PATCH 10/32] updating branch --- .buildkite/run-amd-test.sh | 1 + .buildkite/test-pipeline.yaml | 3 +- .github/workflows/mypy.yaml | 1 - CMakeLists.txt | 5 + benchmarks/launch_tgi_server.sh | 2 +- csrc/core/scalar_type.hpp | 3 +- csrc/moe/marlin_moe_ops.cu | 1740 +++++++++++++++++ csrc/moe/marlin_moe_ops.h | 12 + csrc/moe/torch_bindings.cpp | 12 + .../dev/multimodal/multimodal_index.rst | 2 - format.sh | 1 - pyproject.toml | 1 + requirements-rocm.txt | 1 + tests/conftest.py | 21 +- tests/models/test_llava_next.py | 93 +- tests/multimodal/test_base.py | 83 + tests/multimodal/test_utils.py | 35 +- tests/quantization/test_compressed_tensors.py | 2 +- tests/weight_loading/models.txt | 2 + vllm/_core_ext.py | 134 +- vllm/_custom_ops.py | 14 + vllm/block.py | 9 +- vllm/core/block/cpu_gpu_block_allocator.py | 2 +- vllm/core/block_manager_v1.py | 7 +- vllm/core/block_manager_v2.py | 8 +- vllm/core/embedding_model_block_manager.py | 4 +- vllm/core/scheduler.py | 21 +- vllm/engine/async_llm_engine.py | 5 + vllm/engine/llm_engine.py | 7 + vllm/executor/multiproc_gpu_executor.py | 38 +- vllm/executor/multiproc_xpu_executor.py | 26 + .../layers/fused_moe/__init__.py | 14 +- .../layers/fused_moe/fused_moe.py | 134 +- vllm/model_executor/layers/fused_moe/layer.py | 208 +- .../compressed_tensors/compressed_tensors.py | 5 + .../compressed_tensors_moe.py | 283 +++ .../model_executor/layers/quantization/fp8.py | 29 +- vllm/model_executor/model_loader/utils.py | 4 +- vllm/model_executor/models/blip2.py | 7 + vllm/model_executor/models/chameleon.py | 3 + vllm/model_executor/models/clip.py | 8 +- vllm/model_executor/models/fuyu.py | 3 + vllm/model_executor/models/internvl.py | 9 + vllm/model_executor/models/jamba.py | 2 +- vllm/model_executor/models/llava.py | 8 + vllm/model_executor/models/llava_next.py | 25 +- vllm/model_executor/models/minicpmv.py | 11 +- vllm/model_executor/models/mixtral.py | 1 + vllm/model_executor/models/paligemma.py | 8 + vllm/model_executor/models/phi3v.py | 12 +- vllm/model_executor/models/siglip.py | 4 +- vllm/model_executor/models/ultravox.py | 9 + vllm/model_executor/models/utils.py | 61 +- vllm/multimodal/__init__.py | 3 +- vllm/multimodal/base.py | 49 +- vllm/multimodal/utils.py | 51 +- vllm/platforms/cuda.py | 3 + vllm/platforms/rocm.py | 11 + vllm/utils.py | 9 + vllm/worker/model_runner.py | 5 +- vllm/worker/xpu_model_runner.py | 19 +- vllm/worker/xpu_worker.py | 6 + 62 files changed, 2992 insertions(+), 307 deletions(-) create mode 100644 csrc/moe/marlin_moe_ops.cu create mode 100644 csrc/moe/marlin_moe_ops.h create mode 100644 tests/multimodal/test_base.py create mode 100644 vllm/executor/multiproc_xpu_executor.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index ccc2f090565e..5548071390af 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -75,6 +75,7 @@ docker run \ --network host \ --shm-size=16gb \ --rm \ + -e HIP_VISIBLE_DEVICES=0 \ -e HF_TOKEN \ -v ${HF_CACHE}:${HF_MOUNT} \ -e HF_HOME=${HF_MOUNT} \ diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e40693864747..9f449ff650b9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -233,12 +233,13 @@ steps: parallelism: 4 - label: Tensorizer Test # 11min + mirror_hardwares: [amd] soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader - tests/tensorizer_loader commands: - - apt-get install -y curl libsodium23 + - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s tensorizer_loader diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 3474bd386159..ea767f4c3e26 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -35,7 +35,6 @@ jobs: mypy mypy tests --follow-imports skip mypy vllm/attention --follow-imports skip - mypy vllm/core --follow-imports skip mypy vllm/distributed --follow-imports skip mypy vllm/engine --follow-imports skip mypy vllm/executor --follow-imports skip diff --git a/CMakeLists.txt b/CMakeLists.txt index ab91b86426cd..5b0d0ba904c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -296,6 +296,11 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/topk_softmax_kernels.cu") +if(VLLM_GPU_LANG STREQUAL "CUDA") + list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/marlin_moe_ops.cu") +endif() + define_gpu_extension_target( _moe_C DESTINATION vllm diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index f491c90d0683..8c5cd454fbbe 100755 --- a/benchmarks/launch_tgi_server.sh +++ b/benchmarks/launch_tgi_server.sh @@ -6,7 +6,7 @@ TOKENS=$2 docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ -v $PWD/data:/data \ - ghcr.io/huggingface/text-generation-inference:1.4.0 \ + ghcr.io/huggingface/text-generation-inference:2.2.0 \ --model-id $MODEL \ --sharded false \ --max-input-length 1024 \ diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index b1e10fecb6b5..0e1f360d74bd 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -387,7 +387,8 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { // This needs to be implemented and throw a TypeError in order for // PyTorch's opcheck to work on ops that use ScalarTypes. int64_t len() const { - throw c10::TypeError("__len__ not implemented"); + throw c10::TypeError({__func__, __FILE__, static_cast(__LINE__)}, + "__len__ not implemented"); return 0; } diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu new file mode 100644 index 000000000000..1e170e80d2f7 --- /dev/null +++ b/csrc/moe/marlin_moe_ops.cu @@ -0,0 +1,1740 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include + +#include + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / blockDim.x; + int rest = size_k % blockDim.x; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += blockDim.x; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; + __syncthreads(); + + if (threadIdx.x == 0) { + int tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; + expert_offsets[i + 1] = tot_offset; + } + } + __syncthreads(); +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ inline void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + constexpr int sorted_sh_stride = threads; + constexpr int sorted_gl_stride = threads; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (group_blocks == -1 || group_blocks == 0) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int* sh_sorted = (int*)(sh_s + shs_size); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // TODO we are currently hitting illegal memory accesses when fetching + // sorted_ids to shared data: fix this + auto fetch_sorted_ids_to_shared = [&]() { + const int mpt = ceildiv(prob_m, threads); + for (int i = 0; i < mpt; i++) { + if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { + sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = + sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; + } + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + + FragB frag_b0 = dequant(b_quant); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + FragB frag_b1 = dequant(b_quant_shift); + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here + if constexpr (!has_act_order && group_blocks == -1) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + __syncthreads(); + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par // maximum parallelism +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / 64; + par = min((16 * max_block - pad) / 64, max_par); + prob_m = 64 * par; + m_block_ctr += 4 * (par - 1); + max_block = 4; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par // maximum parallelism +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +// const int SHARED_MEM = +// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +#define __CALL_IF_MOE(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + + " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + + str(prob_n) + "]"); + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int tot_m = prob_m; + + const int* topk_ids_ptr = (const int*)topk_ids; + int* expert_offsets_ptr = (int*)expert_offsets; + compute_expert_offsets<<<1, num_experts, 0, stream>>>( + topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); + + bool do_permute_a = has_act_order; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const int4* A_ptr = (const int4*)A; + int4* a_tmp_ptr = (int4*)a_tmp; + const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + int4* C_ptr = (int4*)C; + const float* topk_weights_ptr = (const float*)topk_weights; + const int* sorted_ids_ptr = (const int*)sorted_ids; + const int4* s_ptr = + (const int4*)s + + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * + prob_n / 8) * + expert_idx; + const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; + const int* perm_ptr = (const int*)perm + prob_k * expert_idx; + int* locks = (int*)workspace; + + if (do_permute_a) { + // Permute A columns + int topk_rows = replicate_input ? tot_m : tot_m * topk; + int block_rows = ceildiv(topk_rows, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + int max_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { + // Define kernel configurations + + // make it max possible value + int thread_m_blocks = 4; + + if (false) { + } + CALL_IF_MOE(16, 4, 256) + CALL_IF_MOE(8, 8, 256) + CALL_IF_MOE(8, 4, 128) + CALL_IF_MOE(4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + } + } +} + +} // namespace marlin_moe + +torch::Tensor marlin_gemm_moe( + const torch::Tensor& a, const torch::Tensor& b_q_weights, + const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, + const torch::Tensor& g_idx, const torch::Tensor& perm, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights) { + int max_par = 4; + + int dev = a.get_device(); + + auto options_dtype = + torch::TensorOptions().dtype(a.dtype()).device(a.device()); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(a.device()); + torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); + torch::Tensor a_tmp = + replicate_input ? torch::zeros({size_m, size_k}, options_dtype) + : torch::zeros({size_m, topk, size_k}, options_dtype); + torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(1) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + marlin_moe::marlin_mm_moe_f16i4( + a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), + topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), + has_act_order, is_k_full, num_groups, group_size, num_experts, topk, + moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, max_par, replicate_input, apply_weights); + return c; +} \ No newline at end of file diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h new file mode 100644 index 000000000000..01ba8ff69850 --- /dev/null +++ b/csrc/moe/marlin_moe_ops.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +torch::Tensor marlin_gemm_moe( + const torch::Tensor& a, const torch::Tensor& b_q_weights, + const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, + const torch::Tensor& g_idx, const torch::Tensor& perm, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights); \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 86e42af44df1..d4d43e2c601b 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,5 +1,6 @@ #include "core/registration.h" #include "moe_ops.h" +#include "marlin_moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. @@ -7,6 +8,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + +#ifndef USE_ROCM + m.def( + "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " + "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " + "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " + "bool replicate_input, bool apply_weights) -> Tensor"); + + m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index a45bc885dc12..241b2ccd0991 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -45,8 +45,6 @@ Base Classes .. autodata:: vllm.multimodal.NestedTensors -.. autodata:: vllm.multimodal.BatchedTensors - .. autodata:: vllm.multimodal.BatchedTensorInputs .. autoclass:: vllm.multimodal.MultiModalDataBuiltins diff --git a/format.sh b/format.sh index 9e0780870303..2204b3ba5949 100755 --- a/format.sh +++ b/format.sh @@ -99,7 +99,6 @@ echo 'vLLM mypy:' mypy --follow-imports skip # Note that this is less strict than CI mypy tests --follow-imports skip mypy vllm/attention --follow-imports skip -mypy vllm/core --follow-imports skip mypy vllm/distributed --follow-imports skip mypy vllm/engine --follow-imports skip mypy vllm/executor --follow-imports skip diff --git a/pyproject.toml b/pyproject.toml index bcedbb53ab88..22a25d9cf32e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ files = [ "vllm/adapter_commons", "vllm/assets", "vllm/entrypoints", + "vllm/core", "vllm/inputs", "vllm/logging", "vllm/multimodal", diff --git a/requirements-rocm.txt b/requirements-rocm.txt index cc955e279a84..121123611d2d 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -8,3 +8,4 @@ botocore ray >= 2.10.0 peft pytest-asyncio +tensorizer>=2.9.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index ae362b228d9d..d8264f65b614 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,6 +41,10 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] +PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]] +PromptAudioInput = Union[List[Tuple[np.ndarray, int]], + List[List[Tuple[np.ndarray, int]]]] + def _read_prompts(filename: str) -> List[str]: with open(filename, "r") as f: @@ -161,7 +165,7 @@ def example_encoder_decoder_prompts( decoder prompt) tuple. Returns: - + * Encoder prompt list * Decoder prompt list (reverse of encoder prompt list) ''' @@ -578,8 +582,7 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[Union[List[Image.Image], - List[List[Image.Image]]]] = None, + images: Optional[PromptImageInput] = None, ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: assert len(prompts) == len(images) @@ -623,10 +626,8 @@ def generate_w_logprobs( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[Union[List[Image.Image], - List[List[Image.Image]]]] = None, - audios: Optional[Union[List[Tuple[np.ndarray, int]], - List[List[Tuple[np.ndarray, int]]]]] = None + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: assert sampling_params.logprobs is not None @@ -676,10 +677,8 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, - images: Optional[Union[List[Image.Image], - List[List[Image.Image]]]] = None, - audios: Optional[Union[List[Tuple[np.ndarray, int]], - List[List[Tuple[np.ndarray, int]]]]] = None, + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, stop_token_ids: Optional[List[int]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 9cf55c0858df..d5fe0cbe3288 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -6,24 +6,22 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) from .utils import check_logprobs_close pytestmark = pytest.mark.vlm -_PREFACE = ( - "A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's " - "questions.") +_LIMIT_IMAGE_PER_PROMPT = 4 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": - f"{_PREFACE} USER: \nWhat's the content of the image? ASSISTANT:", + "[INST] \nWhat's the content of the image? [/INST]", "cherry_blossom": - f"{_PREFACE} USER: \nWhat is the season? ASSISTANT:", + "[INST] \nWhat is the season? [/INST]", }) -models = ["llava-hf/llava-v1.6-vicuna-7b-hf"] +models = ["llava-hf/llava-v1.6-mistral-7b-hf"] def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -114,19 +112,43 @@ def run_test( else: raise ValueError("You must provide either `size_factors` or `sizes`") + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): # max_model_len should be greater than image_feature_size with vllm_runner(model, dtype=dtype, - max_model_len=4096, + max_model_len=10240, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: vllm_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] with hf_runner(model, dtype=dtype, @@ -136,7 +158,7 @@ def run_test( max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets, + model, dtype, max_tokens, + num_logprobs) -> None: + stop_sign = image_assets[0].pil_image + cherry_blossom = image_assets[1].pil_image + + inputs = [( + [ + "[INST] \nDescribe 2 images. [/INST]", + "[INST] \nDescribe 2 images. [/INST]", + "[INST] \nDescribe 4 images. [/INST]", + "[INST] \nWhat is the season? [/INST]" + ], + [ + [stop_sign, cherry_blossom], + # Images with different sizes and aspect-ratios + [ + rescale_image_size(stop_sign, 0.1), + stop_sign, + ], + [ + stop_sign, + rescale_image_size(stop_sign, 0.25), + cherry_blossom.resize((183, 488)), + cherry_blossom.resize((488, 183)) + ], + cherry_blossom, + ])] + + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py new file mode 100644 index 000000000000..f19a0f33fe06 --- /dev/null +++ b/tests/multimodal/test_base.py @@ -0,0 +1,83 @@ +import torch + +from vllm.multimodal.base import MultiModalInputs, NestedTensors + + +def assert_nested_tensors_equal(expected: NestedTensors, + actual: NestedTensors): + assert type(expected) == type(actual) + if isinstance(expected, torch.Tensor): + assert torch.equal(expected, actual) + else: + for expected_item, actual_item in zip(expected, actual): + assert_nested_tensors_equal(expected_item, actual_item) + + +def assert_multimodal_inputs_equal(expected: MultiModalInputs, + actual: MultiModalInputs): + assert set(expected.keys()) == set(actual.keys()) + for key in expected: + assert_nested_tensors_equal(expected[key], actual[key]) + + +def test_multimodal_input_batch_single_tensor(): + t = torch.rand([1, 2]) + result = MultiModalInputs.batch([{"image": t}]) + assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)}) + + +def test_multimodal_input_batch_multiple_tensors(): + a = torch.rand([1, 1, 2]) + b = torch.rand([1, 1, 2]) + c = torch.rand([1, 1, 2]) + result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])}) + + +def test_multimodal_input_batch_multiple_heterogeneous_tensors(): + a = torch.rand([1, 2, 2]) + b = torch.rand([1, 3, 2]) + c = torch.rand([1, 4, 2]) + result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + assert_multimodal_inputs_equal(result, {"image": [a, b, c]}) + + +def test_multimodal_input_batch_nested_tensors(): + a = torch.rand([2, 3]) + b = torch.rand([2, 3]) + c = torch.rand([2, 3]) + result = MultiModalInputs.batch([{ + "image": [a] + }, { + "image": [b] + }, { + "image": [c] + }]) + assert_multimodal_inputs_equal(result, { + "image": + torch.stack([a.unsqueeze(0), + b.unsqueeze(0), + c.unsqueeze(0)]) + }) + + +def test_multimodal_input_batch_heterogeneous_lists(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 2, 3]) + c = torch.rand([1, 2, 3]) + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + assert_multimodal_inputs_equal( + result, + {"image": [torch.stack([a, b]), c.unsqueeze(0)]}) + + +def test_multimodal_input_batch_multiple_batchable_lists(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 2, 3]) + c = torch.rand([1, 2, 3]) + d = torch.rand([1, 2, 3]) + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}]) + assert_multimodal_inputs_equal( + result, + {"image": torch.stack([torch.stack([a, b]), + torch.stack([c, d])])}) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index cd1fc91c2937..38cd48629f90 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -6,8 +6,10 @@ import numpy as np import pytest from PIL import Image +from transformers import AutoConfig, AutoTokenizer -from vllm.multimodal.utils import async_fetch_image, fetch_image +from vllm.multimodal.utils import (async_fetch_image, fetch_image, + repeat_and_pad_placeholder_tokens) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], data_image_async = await async_fetch_image(data_url) assert _image_equals(data_image_sync, data_image_async) + + +@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +def test_repeat_and_pad_placeholder_tokens(model): + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + + tokenizer = AutoTokenizer.from_pretrained(model) + + test_cases = [ + ("", 2, "", [32000, 32000]), + ("", 2, "", [32000, 32000, 32000]), + ("", [3, 2], "", + [32000, 32000, 32000, 32000, 32000]), + ("Image:Image:!", [3, 2], + "Image:Image:!", + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), + ("", [3, 2], "", [32000, 32000, 32000]), + ] + + for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer=tokenizer, + prompt=prompt, + prompt_token_ids=tokenizer.encode(prompt, + add_special_tokens=False), + placeholder_token_id=image_token_id, + repeat_count=repeat_count, + ) + assert new_prompt == expected_prompt + assert new_token_ids == expected_token_ids diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 2ea340779b81..7dd20636c892 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -160,4 +160,4 @@ def test_compressed_tensors_kv_cache(vllm_runner): model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: output = llm.generate_greedy("Hello world!", max_tokens=20) - assert output + assert output \ No newline at end of file diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 70d6ffc70367..cbe30305c14f 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -13,6 +13,8 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py index aa520e1eafba..a27b8648bee4 100644 --- a/vllm/_core_ext.py +++ b/vllm/_core_ext.py @@ -181,92 +181,98 @@ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, ScalarType = torch.classes._core_C.ScalarType - # Needed for dynamo support of ScalarType. - @torch._library.register_fake_class("_core_C::ScalarType") - class FakeScalarType: + if (hasattr(torch, "_library") + and hasattr(torch._library, "register_fake_class")): + # Needed for dynamo support of ScalarType. + @torch._library.register_fake_class("_core_C::ScalarType") + class FakeScalarType: - def __init__(self, scalar_type): - self.ScalarType = scalar_type + def __init__(self, scalar_type): + self.ScalarType = scalar_type - def bias_getter(self) -> int: - return self.ScalarType.bias + def bias_getter(self) -> int: + return self.ScalarType.bias - def exponent_getter(self) -> int: - return self.ScalarType.exponent + def exponent_getter(self) -> int: + return self.ScalarType.exponent - def mantissa_getter(self) -> int: - return self.ScalarType.mantissa + def mantissa_getter(self) -> int: + return self.ScalarType.mantissa - def signed_getter(self) -> bool: - return self.ScalarType.signed + def signed_getter(self) -> bool: + return self.ScalarType.signed - def size_bits_getter(self) -> int: - return self.ScalarType.size_bits + def size_bits_getter(self) -> int: + return self.ScalarType.size_bits - @property - def size_bits(self) -> int: - return self.ScalarType.size_bits + @property + def size_bits(self) -> int: + return self.ScalarType.size_bits - def min(self) -> Union[int, float]: - return self.ScalarType.min() + def min(self) -> Union[int, float]: + return self.ScalarType.min() - def max(self) -> Union[int, float]: - return self.ScalarType.max() + def max(self) -> Union[int, float]: + return self.ScalarType.max() - def is_signed(self) -> bool: - return self.ScalarType.is_signed() + def is_signed(self) -> bool: + return self.ScalarType.is_signed() - def is_floating_point(self) -> bool: - return self.ScalarType.is_floating_point() + def is_floating_point(self) -> bool: + return self.ScalarType.is_floating_point() - def is_integer(self) -> bool: - return self.ScalarType.is_integer() + def is_integer(self) -> bool: + return self.ScalarType.is_integer() - def has_bias(self) -> bool: - return self.ScalarType.has_bias() + def has_bias(self) -> bool: + return self.ScalarType.has_bias() - def has_infs(self) -> bool: - return self.ScalarType.has_infs() + def has_infs(self) -> bool: + return self.ScalarType.has_infs() - def has_nans(self) -> bool: - return self.ScalarType.has_nans() + def has_nans(self) -> bool: + return self.ScalarType.has_nans() - def is_ieee_754(self) -> bool: - return self.ScalarType.is_ieee_754() + def is_ieee_754(self) -> bool: + return self.ScalarType.is_ieee_754() - def __str__(self) -> str: - return self.ScalarType.__str__() + def __str__(self) -> str: + return self.ScalarType.__str__() - def __repr__(self) -> str: - return self.ScalarType.__repr__() + def __repr__(self) -> str: + return self.ScalarType.__repr__() - def __len__(self) -> int: - return self.ScalarType.__len__() + def __len__(self) -> int: + return self.ScalarType.__len__() - def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]: - return torch.classes._core_C.ScalarType.__obj_flatten__( - self.ScalarType) + def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]: + return torch.classes._core_C.ScalarType.__obj_flatten__( + self.ScalarType) - @classmethod - def __obj_unflatten__( - cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType': - return cls( - torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type)) + @classmethod + def __obj_unflatten__( + cls, flat_type: Tuple[Tuple[str, Any], + ...]) -> 'ScalarType': + return cls( + torch.classes._core_C.ScalarType.__obj_unflatten__( + flat_type)) - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - return ScalarType.int_(size_bits, bias) + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + return ScalarType.int_(size_bits, bias) - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - return ScalarType.uint(size_bits, bias) + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + return ScalarType.uint(size_bits, bias) - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - return ScalarType.float_IEEE754(exponent, mantissa) + @classmethod + def float_IEEE754(cls, exponent: int, + mantissa: int) -> 'ScalarType': + return ScalarType.float_IEEE754(exponent, mantissa) - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: int) -> 'ScalarType': - return ScalarType.float_(exponent, mantissa, finite_values_only, - nan_repr) + @classmethod + def float_(cls, exponent: int, mantissa: int, + finite_values_only: bool, + nan_repr: int) -> 'ScalarType': + return ScalarType.float_(exponent, mantissa, + finite_values_only, nan_repr) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b89a90ef0f70..ae90af563c0c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -300,6 +300,20 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) +def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * 2), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], + size_k, size_n, num_bits) + return output + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, diff --git a/vllm/block.py b/vllm/block.py index 95286048d911..47c381c19383 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -1,9 +1,9 @@ """Token blocks.""" -from typing import List, Optional +from typing import TYPE_CHECKING, Iterator, List, Optional from vllm.utils import Device -DEFAULT_LAST_ACCESSED_TIME = -1 +DEFAULT_LAST_ACCESSED_TIME: float = -1 class PhysicalTokenBlock: @@ -59,6 +59,11 @@ def __len__(self) -> int: def __getitem__(self, key): return self._blocks[key] + if TYPE_CHECKING: + + def __iter__(self) -> Iterator[PhysicalTokenBlock]: + raise RuntimeError("Method should be automatically generated") + def __setitem__(self, key, value): if isinstance(key, slice): blocks = value diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index c6330df2a485..c87246c1c6d6 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -132,7 +132,7 @@ def allocate_mutable_block(self, prev_block: Optional[Block], def allocate_immutable_blocks(self, prev_block: Optional[Block], block_token_ids: List[List[int]], - device: Optional[Device]) -> List[Block]: + device: Device) -> List[Block]: """Allocates a new group of immutable blocks with the provided block token IDs on the specified device. diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 0af04399a4b3..666723313c82 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -278,7 +278,7 @@ def __init__( # request ID self.cross_block_tables: Dict[str, BlockTable] = {} - def _get_seq_num_required_blocks(self, seq: Sequence) -> int: + def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: return 0 if seq is None else seq.n_blocks def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: @@ -310,13 +310,14 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: return AllocStatus.LATER def _allocate_sequence(self, \ - seq: Sequence, \ + seq: Optional[Sequence], \ ref_count: int, \ is_encoder_decoder: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. - num_prompt_blocks = seq.n_blocks + num_prompt_blocks = self._get_seq_num_required_blocks(seq) block_table: BlockTable = BlockTable() + assert seq is not None for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 7d4919a0d94a..7d2db43cb460 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -120,8 +120,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: ) if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None num_required_blocks += BlockTable.get_num_required_blocks( - seq_group.get_encoder_seq().get_token_ids(), + encoder_seq.get_token_ids(), block_size=self.block_size, ) @@ -189,7 +191,9 @@ def allocate(self, seq_group: SequenceGroup) -> None: check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) if seq_group.is_encoder_decoder(): - block_table = self._allocate_sequence(seq_group.get_encoder_seq()) + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + block_table = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table def can_append_slots(self, seq_group: SequenceGroup, diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index 3d864a73f91d..f16f66e99e7f 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -77,8 +77,8 @@ def access_all_blocks_in_seq( pass def get_common_computed_block_ids(self, - seq_group: SequenceGroup) -> List[int]: - return None # type: ignore + seq_group: List[Sequence]) -> List[int]: + return [] def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 280d7b7e61e2..a4a4285cdf3a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -221,10 +221,10 @@ class SchedulerSwappedInOutputs: """ # Selected sequences that are going to be swapped in and is in a # decoding phase. - decode_seq_groups: List[SequenceGroup] + decode_seq_groups: List[ScheduledSequenceGroup] # Selected sequences that are going to be swapped in and in a prefill # phase. I.e., it means the prefill has been chunked. - prefill_seq_groups: List[SequenceGroup] + prefill_seq_groups: List[ScheduledSequenceGroup] # The blocks to swap in. blocks_to_swap_in: List[Tuple[int, int]] # The blocks to copy. @@ -254,7 +254,7 @@ class SchedulerPrefillOutputs: to be recomputed from scratch. """ # Selected sequences for prefill. - seq_groups: List[SequenceGroup] + seq_groups: List[ScheduledSequenceGroup] # Ignored sequence groups. ignored_seq_groups: List[SequenceGroup] num_lookahead_slots: int @@ -289,7 +289,9 @@ def scheduler_running_outputs_builder(): def scheduled_seq_group_builder(): - return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + return ScheduledSequenceGroup(SequenceGroup("", [], -1), + token_chunk_size=0) + # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) class Scheduler: @@ -791,7 +793,7 @@ def _schedule_prefills( SchedulerPrefillOutputs. """ ignored_seq_groups: List[SequenceGroup] = [] - seq_groups: List[SequenceGroup] = [] + seq_groups: List[ScheduledSequenceGroup] = [] waiting_queue = self.waiting @@ -1086,8 +1088,9 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: ) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: - no_beam_search = (seq_group.sampling_params.best_of == 1 - and not seq_group.sampling_params.use_beam_search) + no_beam_search = seq_group.sampling_params is None or ( + seq_group.sampling_params.best_of == 1 + and not seq_group.sampling_params.use_beam_search) return no_beam_search @@ -1130,7 +1133,9 @@ def schedule( if seq_group.is_encoder_decoder(): # Encoder associated with SequenceGroup - encoder_seq_data = seq_group.get_encoder_seq().data + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + encoder_seq_data = encoder_seq.data # Block table for cross-attention # Also managed at SequenceGroup level cross_block_table = self.block_manager.get_cross_block_table( diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3445b7084bbc..10e14ff996f3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -666,6 +666,11 @@ def _get_executor_cls( initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync executor_class = RayXPUExecutorAsync + elif distributed_executor_backend == "mp": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.multiproc_xpu_executor import ( + MultiprocessingXPUExecutorAsync) + executor_class = MultiprocessingXPUExecutorAsync else: raise RuntimeError( "Not supported distributed execution model on XPU device.") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7356c1abbfa8..addde032f263 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -472,6 +472,13 @@ def _get_executor_cls(cls, initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_xpu_executor import RayXPUExecutor executor_class = RayXPUExecutor + elif distributed_executor_backend == "mp": + # FIXME(kunshang): + # spawn needs calling `if __name__ == '__main__':`` + # fork is not supported for xpu start new process. + logger.error( + "Both start methods (spawn and fork) have issue " + "on XPU if you use mp backend, Please try ray instead.") else: from vllm.executor.xpu_executor import XPUExecutor executor_class = XPUExecutor diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 08a35a074b37..7b98fbea5cd0 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -30,16 +30,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): uses_ray: bool = False def _init_executor(self) -> None: + self._check_executor_parameters() + # Create the parallel GPU workers. world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size - # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers - if "CUDA_VISIBLE_DEVICES" not in os.environ: - update_environment_variables({ - "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) - }) - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() @@ -68,16 +64,6 @@ def _init_executor(self) -> None: if world_size > 1: maybe_set_triton_cache_manager() - cuda_device_count = cuda_device_count_stateless() - # Use confusing message for more common TP-only case. - assert tensor_parallel_size <= cuda_device_count, ( - f"please set tensor_parallel_size ({tensor_parallel_size}) " - f"to less than max local gpu count ({cuda_device_count})") - - assert world_size <= cuda_device_count, ( - f"please ensure that world_size ({world_size}) " - f"is less than than max local gpu count ({cuda_device_count})") - # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. @@ -139,6 +125,26 @@ def shutdown(signum, frame): max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) + def _check_executor_parameters(self): + world_size = self.parallel_config.tensor_parallel_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + update_environment_variables({ + "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) + }) + + cuda_device_count = cuda_device_count_stateless() + # Use confusing message for more common TP-only case. + assert tensor_parallel_size <= cuda_device_count, ( + f"please set tensor_parallel_size ({tensor_parallel_size}) " + f"to less than max local gpu count ({cuda_device_count})") + + assert world_size <= cuda_device_count, ( + f"please ensure that world_size ({world_size}) " + f"is less than than max local gpu count ({cuda_device_count})") + def shutdown(self): if (worker_monitor := getattr(self, "worker_monitor", None)) is not None: diff --git a/vllm/executor/multiproc_xpu_executor.py b/vllm/executor/multiproc_xpu_executor.py new file mode 100644 index 000000000000..a66afbf939ef --- /dev/null +++ b/vllm/executor/multiproc_xpu_executor.py @@ -0,0 +1,26 @@ +import vllm.envs as envs +from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync) +from vllm.executor.xpu_executor import XPUExecutor +from vllm.logger import init_logger +from vllm.utils import make_async + +logger = init_logger(__name__) + + +class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor): + """Python multiprocessing-based multi-XPU executor""" + + def _check_executor_parameters(self): + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD + if mp_method != "spawn": + raise RuntimeError( + "XPU multiprocess executor only support spawn as mp method") + + +class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor, + MultiprocessingGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_model = make_async(self.driver_worker.execute_model) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3e0767c7d266..fd6f41b90042 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,19 +1,17 @@ -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON -__all__ = [ - "FusedMoE", - "FusedMoEMethodBase", -] +__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + fused_experts, fused_marlin_moe, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ + "fused_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bcf25d263104..d2b152320e11 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -323,21 +323,16 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config( - M: int, - E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], -) -> Dict[str, int]: +def get_default_config(M: int, E: int, N: int, K: int, topk: int, + dtype: Optional[str], + is_marlin: bool) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } - if M <= E: + if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, @@ -347,14 +342,14 @@ def get_default_config( return config -def try_get_optimal_moe_config( - w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, Any]] = None, -): +def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, + Any]] = None, + is_marlin: bool = False): if override_config: config = override_config else: @@ -368,7 +363,8 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, + is_marlin) return config @@ -441,6 +437,108 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids +def fused_marlin_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + #TODO fp8 is not implemented yet + assert not use_fp8 + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + get_config_func = functools.partial(try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, + g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, + block_size_m, True, False) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, + w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, + block_size_m, False, True) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) + + def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4e29ab701b93..61ebef5e11f4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from enum import Enum from typing import List, Optional, Tuple import torch @@ -15,6 +16,12 @@ logger = init_logger(__name__) +class FusedMoeWeightScaleSupported(Enum): + TENSOR = "tensor" + CHANNEL = "channel" + GROUP = "group" + + class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod @@ -199,55 +206,182 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader) + def _load_per_tensor_weight_scale(self, shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int): + param_data = param.data + # for per tensor weight quantization + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + elif shard_id == "w2": + param_data[expert_id] = loaded_weight + + def _load_model_weight_or_group_weight_scale(self, shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int): + # Load grouped weight scales for group quantization + # or model weights + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, + shard_dim: int, shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int): + # for per channel weight quantization + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + shard_size = expert_data.shape[shard_dim] + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. + expert_data.copy_(loaded_weight) + + def _load_single_value(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + param_data[expert_id] = loaded_weight + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") - # Special case for fp8 scales. - if getattr(param, "is_fp8_scale", False): - self._load_fp8_scale(param.data, loaded_weight, weight_name, - shard_id, expert_id) - return + WEIGHT_SCALE_SUPPORTED = [ + e.value for e in FusedMoeWeightScaleSupported + ] + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size is used. + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - # If transposed, weight is saved as [input_dim, output_dim] - # Otherwise, weight is saved as [output_dim, input_dim] - # Default is not transposed/input dim is dim 1 - input_dim = getattr(param, "input_dim", 1) - output_dim = getattr(param, "output_dim", 0) + # is_transposed: whether or not the parameter is transposed on disk + # If transposed, the loaded weight will be transposed and the dim + # to shard the loaded weight will be flipped. + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + loaded_weight = loaded_weight.t().contiguous() + shard_dim = ~shard_dim + + # Case weight_scales + if "weight_scale" in weight_name: + # load the weight scaling based on the quantization scheme + # supported weight scales can be found in + # FusedMoeWeightScaleSupported + # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # specific to each case + quant_method = getattr(param, "quant_method", None) + if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + else: + raise ValueError( + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + return - # Index the loaded weight for tp sharding. - # down_proj: "RowParallel" so tp sharding on input_dim - if shard_id == "w2": - shard_dim = input_dim - shard_size = expert_data.shape[shard_dim] - # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - elif shard_id in ("w1", "w3"): - shard_dim = output_dim - shard_size = expert_data.shape[output_dim] // 2 - offset = shard_size * tp_rank - loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) + if "weight_shape" in weight_name: + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return - # Narrow parameter and load. - # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) - expert_data.copy_(loaded_weight) - # w3, up_proj: Load into second logical weight of w13. - elif shard_id == "w3": - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - expert_data.copy_(loaded_weight) - # w2, down_proj: Load into only logical weight of w2. - elif shard_id == "w2": - expert_data.copy_(loaded_weight) - else: - raise ValueError( - f"Expected shard_id w1,w2 or w3 but got {shard_id}") + # Case input scale + if "input_scale" in weight_name: + # Note: input_scale loading is only supported for fp8 + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. {loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return @staticmethod def select_experts(hidden_states: torch.Tensor, @@ -342,4 +476,4 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight + param_data[expert_id] = loaded_weight \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index f0e0b9db8088..0768b37044aa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -3,10 +3,13 @@ import torch from pydantic import BaseModel +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 + CompressedTensorsMoEMethod) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, @@ -69,6 +72,8 @@ def get_quant_method( return CompressedTensorsLinearMethod(self) if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return CompressedTensorsMoEMethod(self) return None @classmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py new file mode 100644 index 000000000000..0e0ab9ce9169 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -0,0 +1,283 @@ +import enum +from enum import Enum +from typing import List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + CompressionFormat) +from vllm.model_executor.utils import set_weight_attrs + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +__all__ = ["CompressedTensorsMoEMethod"] + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy.value + self.group_size = config.group_size + assert config.symmetric, ( + "Only symmetric quantization is supported for MoE") + + if not (self.quant_config.quant_format + == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS): + raise ValueError("For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": self.strategy + }) + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size // + self.packed_factor, + 2 * intermediate_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + intermediate_size // + self.packed_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = intermediate_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + + w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + def get_scale_perms(num_bits: int): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int, num_bits: int): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, + scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + return s + + def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, + size_n: int, group_size: int, + num_bits: int): + num_experts = s.shape[0] + output = torch.empty((num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, + group_size, num_bits) + return output + + size_k2 = layer.w2_weight_packed.shape[2] + size_k13 = layer.w13_weight_packed.shape[2] + + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_weight_scale, + size_k13, + layer.w13_weight_scale.shape[2], + self.group_size, + self.num_bits, + ) + replace_tensor("w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_weight_scale, + layer.w2_weight_scale.shape[1] * self.packed_factor, + size_k2, + self.group_size, + self.num_bits, + ) + replace_tensor("w2_weight_scale", marlin_w2_scales) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_marlin_moe) + + return fused_marlin_moe(x, + layer.w13_weight_packed, + layer.w2_weight_packed, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + top_k, + renormalize=renormalize, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b10988b992ae..1817dbcb023a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,7 +7,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -332,19 +333,16 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale) - + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) - set_weight_attrs(w2_weight_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.quant_config.activation_scheme == "static": @@ -357,19 +355,14 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w13_input_scale, extra_weight_attrs) w2_input_scale = torch.nn.Parameter(torch.ones( num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: layer.w13_input_scale = None layer.w2_input_scale = None diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 331b859d2ade..4bb943ab3afe 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,11 +23,11 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. + mixtral_supported = ["fp8", "compressed-tensors"] if (model_config.quantization is not None - and model_config.quantization != "fp8" + and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - return ModelRegistry.resolve_model_cls(architectures) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 20dda2a67820..7c9123079c44 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -555,6 +555,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return Blip2ImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -564,6 +567,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return Blip2ImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index a335e1766b2a..2d4f172ce0be 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -946,6 +946,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return ChameleonImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 093396605533..69bb9f6f3afe 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,7 +1,7 @@ -"""Minimal implementation of CLIPVisionModel intended to be only used +"""Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" from array import array -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -84,7 +84,7 @@ def input_processor_for_clip( llm_inputs: LLMInputs, *, image_token_id: int, - image_feature_size_override: Optional[int] = None, + image_feature_size_override: Optional[Union[int, List[int]]] = None, ): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: @@ -217,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class CLIPEncoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` self + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`CLIPEncoderLayer`]. Args: diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index cfc2a5288a37..6cdf331fed8b 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -249,6 +249,9 @@ def _parse_and_validate_image_input( image_patches = kwargs.pop("image_patches", None) if isinstance(image_patches, torch.Tensor): + # Remove the N dimension until multiple images are supported. + image_patches = image_patches.squeeze(1) + expected_feature_size = self.image_feature_size if image_patches.size(-1) != expected_feature_size: raise ValueError( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 7b04ff6117c7..0e88d055e14f 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): min_num, max_num, use_thumbnail=use_thumbnail) + # Add an N dimension for number of images per prompt (currently 1). + data = data.unsqueeze(0) model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) @@ -412,6 +414,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Flatten the B and N dimensions + image_embeds = image_embeds.flatten(0, 2) + return InternVLImageEmbeddingInputs( type="image_embeds", data=image_embeds, @@ -424,6 +430,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Flatten the B and N dimensions + pixel_values = pixel_values.flatten(0, 2) + return InternVLImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b82eb14fb5f2..caeda4e42d8a 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -920,7 +920,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 6433ea380cbf..03a0abf1db48 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -232,6 +232,10 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, torch.Tensor): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return LlavaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -241,6 +245,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return LlavaImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c7cb243fa84d..3a8724295411 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -19,6 +19,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_image_feature_size, @@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): input_height=height, input_width=width, ) + elif is_list_of(image_data, Image.Image): + image_feature_size = [ + get_llava_next_image_feature_size(hf_config, + input_height=img.height, + input_width=img.width) + for img in image_data + ] elif isinstance(image_data, torch.Tensor): image_feature_size = image_data.shape[0] else: @@ -353,6 +361,14 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") + # Remove the N dimension until multiple images are supported. + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.squeeze(1) + else: + pixel_values = [t.squeeze(0) for t in pixel_values] + + image_sizes = image_sizes.squeeze(1) + return LlavaNextImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -364,6 +380,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image embeds. " f"Got type: {type(image_embeds)}") + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return LlavaNextImageEmbeddingInputs( type="image_embeds", data=image_embeds, @@ -425,7 +444,10 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) - other_patch_embeds = other_patch_embeds \ + num_patches = num_patch_height * num_patch_width + + # Image patches might be padded for batch processing + other_patch_embeds = other_patch_embeds[:num_patches] \ .view(num_patch_height, num_patch_width, height, width, -1) if "unpad" in strategy: @@ -496,7 +518,6 @@ def _process_image_input( self, image_input: LlavaNextImageInputs, ) -> Union[torch.Tensor, List[torch.Tensor]]: - if image_input["type"] == "image_embeds": return [image_input["data"]] diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 29f3640e2458..6a3d5422e0ce 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -594,9 +594,14 @@ def _parse_and_validate_inputs( pixel_values_flat: List[torch.Tensor] = [] tgt_sizes_flat: List[torch.Tensor] = [] - for b in range(len(pixel_values)): - pixel_values_flat += pixel_values[b] - tgt_sizes_flat += tgt_sizes[b] + for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): + if len(pixel_b) != len(tgt_b): + raise ValueError("Inconsistent N lengths, found: " + f"{len(pixel_b)} vs {len(tgt_b)}") + + for pixel_n, tgt_n in zip(pixel_b, tgt_b): + pixel_values_flat += pixel_n + tgt_sizes_flat += tgt_n # NOTE: Input IDs does not contain image tokens during memory profiling, # so we allow it to be empty diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 34f581ac7858..413783ba4b25 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -73,6 +73,7 @@ def __init__(self, self.hidden_size = hidden_size # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8cb5065ed79e..0700f0c29d70 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -185,6 +185,10 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, torch.Tensor): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return PaliGemmaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -194,6 +198,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return PaliGemmaImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4872929ec36c..61f1d7397637 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -422,7 +422,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): prompt = llm_inputs.get("prompt") if prompt is None: - image_idx = [] + # for async server request, we assume prompt and its token_ids is always + # in correct format. And num_image_tags == len(image_data) always True. + image_idx = range(1, len(image_data) + 1) new_prompt = None else: image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt))) @@ -558,6 +560,14 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") + # Merge the B and N dimensions. + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.flatten(0, 1) + else: + pixel_values = torch.cat(pixel_values) + + image_sizes = image_sizes.flatten(0, 1) + return Phi3VImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 7f6186fa010a..073f60bb3a05 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -3,7 +3,7 @@ import math from array import array -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from PIL import Image @@ -93,7 +93,7 @@ def input_processor_for_siglip( llm_inputs: LLMInputs, *, image_token_id: int, - image_feature_size_override: Optional[int] = None, + image_feature_size_override: Optional[Union[int, List[int]]] = None, ): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 842264f76586..c81c2fd114eb 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -333,6 +333,12 @@ def _parse_and_validate_audio_input( raise ValueError("Incorrect type of audio features. " f"Got type: {type(audio_features)}") + # Remove the N dimension until multiple audios are supported. + if isinstance(audio_features, torch.Tensor): + audio_features = audio_features.squeeze(1) + else: + audio_features = [t.squeeze(0) for t in audio_features] + return UltravoxAudioFeatureInputs(type="audio_features", data=audio_features) @@ -341,6 +347,9 @@ def _parse_and_validate_audio_input( raise ValueError("Incorrect type of audio embeds. " f"Got type: {type(audio_embeds)}") + # Remove the N dimension until multiple audios are supported. + audio_embeds = audio_embeds.squeeze(1) + return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index d5f431e97f29..0acc843cbebf 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,6 @@ from typing import Callable, Dict, Iterable, List, Optional, Protocol, Tuple +import numpy as np import torch import torch.nn as nn from torch.func import functional_call @@ -10,8 +11,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry -from vllm.multimodal import BatchedTensors -from vllm.sequence import IntermediateTensors +from vllm.multimodal.base import NestedTensors from vllm.utils import is_pin_memory_available @@ -55,9 +55,34 @@ def init_vllm_registered_model( ) +def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: + """ + Recursively concatenates NestedTensors along any heterogeneously sized + dimensions. + """ + + if isinstance(embeddings, torch.Tensor): + return embeddings + + return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) + + +def _embedding_count_expression(embeddings: NestedTensors) -> str: + """ + Constructs a debugging representation of the number of embeddings in the + NestedTensors. + """ + + if isinstance(embeddings, torch.Tensor): + return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) + + return " + ".join( + _embedding_count_expression(inner) for inner in embeddings) + + def merge_multimodal_embeddings(input_ids: torch.Tensor, inputs_embeds: torch.Tensor, - multimodal_embeddings: BatchedTensors, + multimodal_embeddings: NestedTensors, placeholder_token_id: int) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the @@ -70,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, mask = (input_ids == placeholder_token_id) num_expected_tokens = mask.sum() - if isinstance(multimodal_embeddings, torch.Tensor): - batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape - total_tokens = batch_size * batch_tokens - if num_expected_tokens != total_tokens: - expr = f"{batch_size} x {batch_tokens}" - raise ValueError( - f"Attempted to assign {expr} = {total_tokens} " - f"multimodal tokens to {num_expected_tokens} placeholders") - - inputs_embeds[mask] = multimodal_embeddings.view( - total_tokens, embed_dim) - else: - size_per_batch = [t.shape[0] for t in multimodal_embeddings] - total_tokens = sum(size_per_batch) - if num_expected_tokens != total_tokens: - expr = ' + '.join(map(str, size_per_batch)) - raise ValueError( - f"Attempted to assign {expr} = {total_tokens} " - f"multimodal tokens to {num_expected_tokens} placeholders") - - inputs_embeds[mask] = torch.cat(multimodal_embeddings) + flattened = _flatten_embeddings(multimodal_embeddings) + *dims, embed_dim = flattened.shape + num_multimodal_embeddings = np.prod(dims) + if num_multimodal_embeddings != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( + f"Attempted to assign {expr} = {num_multimodal_embeddings} " + f"multimodal tokens to {num_expected_tokens} placeholders") + inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim) return inputs_embeds diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 456e41ebfad0..489e1e51f05c 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,4 +1,4 @@ -from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins, +from .base import (BatchedTensorInputs, MultiModalDataBuiltins, MultiModalDataDict, MultiModalInputs, MultiModalPlugin, NestedTensors) from .registry import MultiModalRegistry @@ -14,7 +14,6 @@ __all__ = [ "BatchedTensorInputs", - "BatchedTensors", "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalInputs", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 8ada60c8fd6a..5b00117c64e5 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,9 +1,8 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import Callable, Dict, List, Mapping, Optional -from typing import Sequence as GenericSequence -from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final +from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type, + TypedDict, TypeVar, Union, cast, final) import numpy as np import torch @@ -15,23 +14,16 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import JSONTree, json_map_leaves +from vllm.utils import json_map_leaves logger = init_logger(__name__) -NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], torch.Tensor] """ -Use a list instead of a tensor if the dimensions of each element do not match. -Currently only supports up to singly nested list of tensors. +Uses a list instead of a tensor if the dimensions of each element do not match. """ -BatchedTensors: TypeAlias = JSONTree[torch.Tensor] -""" -A nested JSON structure of tensors which have been batched via -:meth:`MultiModalInputs.batch`. -""" - -BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]] +BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via :meth:`MultiModalInputs.batch`. @@ -54,26 +46,23 @@ class MultiModalInputs(_MultiModalInputsBase): """ @staticmethod - def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors: + def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ - If each input tensor in the batch has the same shape, return a single - batched tensor; otherwise, return a list of :class:`NestedTensors` with - one element per item in the batch. + Recursively stacks lists of tensors when they all have the same shape. """ - # may be list rather than tensors - if isinstance(tensors[0], list): - return [[t for t in tensor[0]] - for tensor in cast(List[List[torch.Tensor]], tensors)] - - tensors_ = cast(List[torch.Tensor], tensors) + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors - unbatched_shape = tensors_[0].shape[1:] + stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] + if any(isinstance(t, list) for t in stacked): + return stacked - for tensor in tensors_: - if tensor.shape[1:] != unbatched_shape: - return [tensor.squeeze(0) for tensor in tensors_] + tensors_ = cast(List[torch.Tensor], stacked) + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. + return tensors_ - return torch.cat(tensors_, dim=0) + return torch.stack(tensors_) @staticmethod def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: @@ -102,7 +91,7 @@ def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: item_lists[k].append(v) return { - k: MultiModalInputs._try_concat(item_list) + k: MultiModalInputs._try_stack(item_list) for k, item_list in item_lists.items() } diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3bf430235462..989b2e1a814c 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -189,10 +189,13 @@ def repeat_and_pad_placeholder_tokens( prompt_token_ids: List[int], *, placeholder_token_id: int, - repeat_count: int = 1, + repeat_count: Union[int, List[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, ) -> Tuple[Optional[str], List[int]]: + if isinstance(repeat_count, int): + repeat_count = [repeat_count] + if prompt is None: new_prompt = None else: @@ -201,13 +204,6 @@ def repeat_and_pad_placeholder_tokens( tokenizer.decode(pad_token_left)) pad_token_str_right = (None if pad_token_right is None else tokenizer.decode(pad_token_right)) - replacement_str = "".join( - repeat_and_pad_token( - placeholder_token_str, - repeat_count=repeat_count, - pad_token_left=pad_token_str_left, - pad_token_right=pad_token_str_right, - )) placeholder_token_count = prompt.count(placeholder_token_str) # This is an arbitrary number to distinguish between the two cases @@ -216,28 +212,45 @@ def repeat_and_pad_placeholder_tokens( "Please follow the prompt format that is " "documented on HuggingFace which does not involve " "repeating %s tokens.", placeholder_token_str) - elif placeholder_token_count > 1: - logger.warning("Multiple multi-modal input is not supported yet, " - "so any extra placeholder tokens will be treated " - "as plain text.") - - # The image tokens are removed to be consistent with HuggingFace - new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1) + if placeholder_token_count < len(repeat_count): + logger.warning( + "The number of multi-modal placeholder tokens in the prompt " + "is less than the number of multi-modal inputs. Extra " + "placeholder tokens will be treated as plain text") + repeat_count = repeat_count[:placeholder_token_count] + + prompt_parts = prompt.split(placeholder_token_str, + maxsplit=len(repeat_count)) + new_prompt = "" + for i, repeat_count_item in enumerate(repeat_count): + replacement_str = "".join( + repeat_and_pad_token( + placeholder_token_str, + repeat_count=repeat_count_item, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + # The image tokens are removed to be consistent with HuggingFace + new_prompt += prompt_parts[i] + replacement_str + new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] + placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: replacement_ids = repeat_and_pad_token( placeholder_token_id, - repeat_count=repeat_count, + repeat_count=repeat_count[placeholder_token_idx], pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) new_token_ids.extend(replacement_ids) + placeholder_token_idx += 1 - # No need to further scan the list since we only replace once - new_token_ids.extend(prompt_token_ids[i + 1:]) - break + # No need to further scan the list since we replaced all tokens + if placeholder_token_idx >= len(repeat_count): + new_token_ids.extend(prompt_token_ids[i + 1:]) + break else: new_token_ids.append(token) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bda82d3712f0..8d18527e7c97 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -84,6 +84,9 @@ def warn_if_different_devices(): def device_id_to_physical_device_id(device_id: int) -> int: if "CUDA_VISIBLE_DEVICES" in os.environ: device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if device_ids == [""]: + raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string," + " which means GPU support is disabled.") physical_device_id = device_ids[device_id] return int(physical_device_id) else: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3f6f5adee5a5..28525e8ff881 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,10 +1,21 @@ +import os from functools import lru_cache from typing import Tuple import torch +from vllm.logger import init_logger + from .interface import Platform, PlatformEnum +logger = init_logger(__name__) + +if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: + logger.warning("`fork` method is not supported by ROCm. " + "VLLM_WORKER_MULTIPROC_METHOD is overridden to" + " `spawn` instead.") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + class RocmPlatform(Platform): _enum = PlatformEnum.ROCM diff --git a/vllm/utils.py b/vllm/utils.py index 0b7457a70b36..dab8e5fe0435 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -25,6 +25,7 @@ import psutil import torch import torch.types +from packaging.version import Version from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs @@ -1114,3 +1115,11 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) + + +# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. +def supports_dynamo() -> bool: + base_torch_version = Version(Version(torch.__version__).base_version) + return base_torch_version >= Version("2.4.0") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a81b89299223..607381096276 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -44,7 +44,8 @@ from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available) + flatten_2d_lists, is_hip, is_pin_memory_available, + supports_dynamo) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -946,7 +947,7 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE: + if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): self.model = torch.compile(self.model, fullgraph=True, backend="eager") diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 0335bbcd091e..3894658a095f 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -12,6 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) +from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -439,9 +440,11 @@ def profile_run(self) -> None: "Setting it to the minimum value of 1.", expr) max_num_seqs = 1 + batch_size = 0 for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len seq_data, dummy_multi_modal_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, @@ -465,7 +468,13 @@ def profile_run(self) -> None: finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) - self.execute_model(model_input, kv_caches) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.xpu.synchronize() return @@ -537,7 +546,7 @@ def execute_model( and self.observability_config.collect_model_forward_time): model_forward_start_time = time.time() - hidden_states = model_executable( + hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, @@ -545,12 +554,16 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device)) + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end_time = time.time() # Compute the logits. - logits = self.model.compute_logits(hidden_states, + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) # Only perform sampling in the driver worker. diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index b00d1889f8d4..9ad070d042a3 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -14,6 +14,7 @@ SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.utils import is_xpu @@ -198,3 +199,8 @@ def init_worker_distributed_environment(self) -> None: ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + + if parallel_config.pipeline_parallel_size > 1: + # torch-ccl xpu need a collective API warm up + # before calling send/recv API + get_pp_group().all_reduce(torch.zeros(1).xpu()) From 4ae7573cdad26b63d89774519c71f81e43ffbc8a Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 28 Aug 2024 18:22:06 +0530 Subject: [PATCH 11/32] updation --- vllm/model_executor/models/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 0acc843cbebf..1c83349fcc7e 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -248,8 +248,8 @@ def make_empty_intermediate_tensors_factory(keys: List[str], def make_empty_intermediate_tensors( batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ + device: torch.device) -> NestedTensors: + return NestedTensors({ key: torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) From 7aefcbf4b18b350be6e248f45015e82e7782d486 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 28 Aug 2024 18:26:34 +0530 Subject: [PATCH 12/32] updation --- vllm/model_executor/models/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 1c83349fcc7e..8b3a17c68438 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -246,9 +246,8 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int) -> Callable: - def make_empty_intermediate_tensors( - batch_size: int, dtype: torch.dtype, - device: torch.device) -> NestedTensors: + def make_empty_intermediate_tensors(batch_size: int, dtype: torch.dtype, + device: torch.device) -> NestedTensors: return NestedTensors({ key: torch.zeros((batch_size, hidden_size), dtype=dtype, From 530c8a3aab900f452b5e3aaa913fbee8529aadcf Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Mon, 2 Sep 2024 21:31:19 +0530 Subject: [PATCH 13/32] Updating Branch --- .../lm-eval-harness/configs/models-small.txt | 1 - .buildkite/run-tpu-test.sh | 2 +- .buildkite/test-pipeline.yaml | 2 + CMakeLists.txt | 2 + Dockerfile | 25 +- Dockerfile.tpu | 2 +- benchmarks/benchmark_serving.py | 132 +++- csrc/mamba/causal_conv1d/causal_conv1d.cu | 700 ++++++++++++++++++ csrc/mamba/causal_conv1d/causal_conv1d.h | 144 ++++ csrc/mamba/causal_conv1d/static_switch.h | 28 + csrc/mamba/mamba_ssm/selective_scan.h | 276 +++++++ csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 593 +++++++++++++++ csrc/mamba/mamba_ssm/static_switch.h | 28 + csrc/ops.h | 22 + csrc/torch_bindings.cpp | 25 + docs/requirements-docs.txt | 3 +- .../getting_started/tpu-installation.rst | 7 +- docs/source/models/supported_models.rst | 8 + .../performance_benchmark/benchmarks.rst | 2 +- .../serving/openai_compatible_server.md | 26 + examples/offline_inference_neuron.py | 11 +- examples/openai_vision_api_client.py | 39 + requirements-common.txt | 3 +- requirements-mamba.txt | 3 - requirements-test.txt | 6 +- requirements-tpu.txt | 2 +- setup.py | 1 + tests/compile/test_wrapper.py | 59 ++ tests/conftest.py | 6 + tests/data/test_config.yaml | 2 + tests/entrypoints/openai/test_serving_chat.py | 2 + tests/entrypoints/openai/test_vision.py | 71 +- tests/entrypoints/test_chat_utils.py | 305 ++++++++ tests/kernels/test_awq_triton.py | 169 +++++ tests/kernels/test_causal_conv1d.py | 205 +++++ tests/kernels/test_flashinfer.py | 228 +++++- tests/kernels/test_mamba_ssm.py | 324 ++++++++ tests/models/test_fp8kv_flashinfer.py | 96 +++ tests/models/test_granite.py | 49 ++ tests/models/test_intern_vit.py | 3 +- tests/models/test_internvl.py | 63 +- tests/models/test_llava.py | 17 + tests/models/test_phimoe.py | 111 +++ tests/models/test_ultravox.py | 4 +- tests/models/utils.py | 43 +- .../multi_step/test_correctness_async_llm.py | 105 ++- tests/multi_step/test_correctness_llm.py | 95 ++- tests/multimodal/test_base.py | 12 + tests/quantization/test_bitsandbytes.py | 166 +++-- tests/samplers/test_rejection_sampler.py | 116 ++- .../test_typical_acceptance_sampler.py | 50 +- tests/spec_decode/test_multi_step_worker.py | 3 +- tests/spec_decode/test_spec_decode_worker.py | 8 +- tests/spec_decode/test_utils.py | 21 +- tests/spec_decode/utils.py | 4 +- tests/test_sequence.py | 5 +- tests/test_utils.py | 44 ++ tests/tpu/__init__.py | 0 tests/tpu/test_compilation.py | 35 +- tests/tpu/test_custom_dispatcher.py | 9 + tests/utils.py | 60 ++ vllm/_custom_ops.py | 39 + vllm/attention/backends/flashinfer.py | 40 +- vllm/attention/backends/pallas.py | 8 +- vllm/attention/selector.py | 4 + vllm/compilation/__init__.py | 0 vllm/compilation/wrapper.py | 81 ++ vllm/config.py | 52 +- vllm/core/scheduler.py | 5 +- .../device_communicators/tpu_communicator.py | 27 +- vllm/engine/arg_utils.py | 4 +- vllm/engine/async_llm_engine.py | 68 +- vllm/engine/llm_engine.py | 144 +++- vllm/engine/output_processor/multi_step.py | 23 +- vllm/engine/output_processor/single_step.py | 65 +- vllm/engine/output_processor/util.py | 3 +- vllm/engine/protocol.py | 2 +- vllm/entrypoints/chat_utils.py | 228 +++--- vllm/entrypoints/openai/rpc/client.py | 81 +- vllm/entrypoints/openai/rpc/server.py | 48 +- vllm/entrypoints/openai/serving_chat.py | 10 +- .../openai/serving_tokenization.py | 4 +- vllm/envs.py | 11 +- vllm/executor/cpu_executor.py | 3 +- vllm/executor/distributed_gpu_executor.py | 3 +- vllm/executor/executor_base.py | 3 +- vllm/executor/gpu_executor.py | 3 +- vllm/executor/multiproc_gpu_executor.py | 3 +- vllm/executor/neuron_executor.py | 3 +- vllm/executor/openvino_executor.py | 3 +- vllm/executor/ray_gpu_executor.py | 3 +- vllm/executor/ray_tpu_executor.py | 46 +- vllm/executor/ray_utils.py | 29 + vllm/executor/tpu_executor.py | 3 +- vllm/executor/xpu_executor.py | 3 +- vllm/lora/ops/bgmv_expand.py | 9 +- vllm/lora/ops/bgmv_expand_slice.py | 9 +- vllm/lora/ops/bgmv_shrink.py | 9 +- vllm/lora/ops/sgmv_expand.py | 9 +- vllm/lora/ops/sgmv_expand_slice.py | 9 +- vllm/lora/ops/sgmv_shrink.py | 9 +- vllm/lora/punica.py | 4 +- ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 ++++ ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 ++++ ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 ++++ .../layers/fused_moe/fused_moe.py | 19 +- vllm/model_executor/layers/fused_moe/layer.py | 90 ++- vllm/model_executor/layers/linear.py | 21 +- vllm/model_executor/layers/mamba/__init__.py | 0 .../layers/mamba/ops/__init__.py | 0 .../layers/mamba/ops/causal_conv1d.py | 86 +++ .../layers/mamba/ops/mamba_ssm.py | 346 +++++++++ .../layers/quantization/awq_triton.py | 304 ++++++++ .../layers/quantization/bitsandbytes.py | 231 +++++- .../compressed_tensors_moe.py | 24 +- .../layers/quantization/experts_int8.py | 26 +- .../model_executor/layers/quantization/fp8.py | 26 +- .../layers/quantization/tpu_int8.py | 21 +- .../layers/rejection_sampler.py | 184 +++-- .../model_executor/layers/rotary_embedding.py | 26 +- vllm/model_executor/layers/sampler.py | 290 ++++++-- .../layers/spec_decode_base_sampler.py | 43 +- .../layers/typical_acceptance_sampler.py | 7 +- .../layers/vocab_parallel_embedding.py | 5 +- vllm/model_executor/model_loader/loader.py | 205 +++-- vllm/model_executor/model_loader/neuron.py | 36 +- vllm/model_executor/model_loader/openvino.py | 3 +- vllm/model_executor/models/__init__.py | 3 + vllm/model_executor/models/arctic.py | 4 +- vllm/model_executor/models/baichuan.py | 4 +- vllm/model_executor/models/bart.py | 4 +- vllm/model_executor/models/blip.py | 79 +- vllm/model_executor/models/blip2.py | 7 +- vllm/model_executor/models/bloom.py | 4 +- vllm/model_executor/models/chameleon.py | 4 +- vllm/model_executor/models/chatglm.py | 4 +- vllm/model_executor/models/clip.py | 105 ++- vllm/model_executor/models/commandr.py | 4 +- vllm/model_executor/models/dbrx.py | 4 +- vllm/model_executor/models/deepseek.py | 4 +- vllm/model_executor/models/deepseek_v2.py | 4 +- vllm/model_executor/models/eagle.py | 3 +- vllm/model_executor/models/exaone.py | 617 +++++++++++++++ vllm/model_executor/models/falcon.py | 4 +- vllm/model_executor/models/fuyu.py | 3 +- vllm/model_executor/models/gemma.py | 4 +- vllm/model_executor/models/gemma2.py | 4 +- vllm/model_executor/models/gpt2.py | 4 +- vllm/model_executor/models/gpt_bigcode.py | 4 +- vllm/model_executor/models/gpt_j.py | 4 +- vllm/model_executor/models/gpt_neox.py | 4 +- vllm/model_executor/models/granite.py | 543 ++++++++++++++ vllm/model_executor/models/intern_vit.py | 64 +- vllm/model_executor/models/internlm2.py | 63 +- vllm/model_executor/models/internvl.py | 3 +- vllm/model_executor/models/jais.py | 4 +- vllm/model_executor/models/jamba.py | 13 +- vllm/model_executor/models/llama.py | 4 +- vllm/model_executor/models/llava.py | 3 +- vllm/model_executor/models/llava_next.py | 3 +- vllm/model_executor/models/medusa.py | 2 +- vllm/model_executor/models/minicpm.py | 4 +- vllm/model_executor/models/minicpmv.py | 4 +- vllm/model_executor/models/mixtral.py | 4 +- vllm/model_executor/models/mixtral_quant.py | 4 +- vllm/model_executor/models/mlp_speculator.py | 3 +- vllm/model_executor/models/mpt.py | 4 +- vllm/model_executor/models/nemotron.py | 4 +- vllm/model_executor/models/olmo.py | 4 +- vllm/model_executor/models/opt.py | 4 +- vllm/model_executor/models/orion.py | 4 +- vllm/model_executor/models/paligemma.py | 52 +- vllm/model_executor/models/persimmon.py | 4 +- vllm/model_executor/models/phi.py | 4 +- vllm/model_executor/models/phi3_small.py | 4 +- vllm/model_executor/models/phi3v.py | 57 +- vllm/model_executor/models/phimoe.py | 620 ++++++++++++++++ vllm/model_executor/models/qwen.py | 4 +- vllm/model_executor/models/qwen2.py | 4 +- vllm/model_executor/models/qwen2_moe.py | 4 +- vllm/model_executor/models/siglip.py | 211 +----- vllm/model_executor/models/stablelm.py | 4 +- vllm/model_executor/models/starcoder2.py | 4 +- vllm/model_executor/models/ultravox.py | 9 +- vllm/model_executor/models/utils.py | 16 +- vllm/model_executor/models/xverse.py | 4 +- vllm/multimodal/base.py | 4 +- vllm/multimodal/utils.py | 20 +- vllm/scripts.py | 9 + vllm/sequence.py | 74 +- vllm/spec_decode/batch_expansion.py | 3 +- vllm/spec_decode/draft_model_runner.py | 4 +- vllm/spec_decode/medusa_worker.py | 4 +- vllm/spec_decode/mlp_speculator_worker.py | 4 +- vllm/spec_decode/multi_step_worker.py | 5 +- vllm/spec_decode/ngram_worker.py | 3 +- vllm/spec_decode/proposer_worker_base.py | 3 +- .../spec_decode/smaller_tp_proposer_worker.py | 3 +- vllm/spec_decode/spec_decode_worker.py | 10 +- vllm/spec_decode/top1_proposer.py | 4 +- vllm/spec_decode/util.py | 8 +- vllm/transformers_utils/config.py | 19 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/exaone.py | 190 +++++ vllm/transformers_utils/configs/granite.py | 199 +++++ vllm/transformers_utils/tokenizer.py | 4 +- vllm/transformers_utils/utils.py | 16 + vllm/utils.py | 101 +++ vllm/worker/cpu_model_runner.py | 4 +- vllm/worker/enc_dec_model_runner.py | 3 +- vllm/worker/model_runner.py | 47 +- vllm/worker/model_runner_base.py | 4 +- vllm/worker/multi_step_model_runner.py | 244 +++++- vllm/worker/multi_step_worker.py | 11 +- vllm/worker/neuron_model_runner.py | 4 +- vllm/worker/openvino_model_runner.py | 3 +- vllm/worker/openvino_worker.py | 3 +- vllm/worker/tpu_model_runner.py | 57 +- vllm/worker/tpu_worker.py | 4 - vllm/worker/utils.py | 2 +- vllm/worker/worker.py | 4 +- vllm/worker/worker_base.py | 4 +- vllm/worker/xpu_model_runner.py | 4 +- 223 files changed, 10234 insertions(+), 1556 deletions(-) create mode 100644 csrc/mamba/causal_conv1d/causal_conv1d.cu create mode 100644 csrc/mamba/causal_conv1d/causal_conv1d.h create mode 100644 csrc/mamba/causal_conv1d/static_switch.h create mode 100644 csrc/mamba/mamba_ssm/selective_scan.h create mode 100644 csrc/mamba/mamba_ssm/selective_scan_fwd.cu create mode 100644 csrc/mamba/mamba_ssm/static_switch.h delete mode 100644 requirements-mamba.txt create mode 100644 tests/compile/test_wrapper.py create mode 100644 tests/data/test_config.yaml create mode 100644 tests/entrypoints/test_chat_utils.py create mode 100644 tests/kernels/test_awq_triton.py create mode 100644 tests/kernels/test_causal_conv1d.py create mode 100644 tests/kernels/test_mamba_ssm.py create mode 100644 tests/models/test_fp8kv_flashinfer.py create mode 100644 tests/models/test_granite.py create mode 100644 tests/models/test_phimoe.py create mode 100644 tests/tpu/__init__.py create mode 100644 tests/tpu/test_custom_dispatcher.py create mode 100644 vllm/compilation/__init__.py create mode 100644 vllm/compilation/wrapper.py create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/mamba/__init__.py create mode 100644 vllm/model_executor/layers/mamba/ops/__init__.py create mode 100644 vllm/model_executor/layers/mamba/ops/causal_conv1d.py create mode 100644 vllm/model_executor/layers/mamba/ops/mamba_ssm.py create mode 100644 vllm/model_executor/layers/quantization/awq_triton.py create mode 100644 vllm/model_executor/models/exaone.py create mode 100644 vllm/model_executor/models/granite.py create mode 100644 vllm/model_executor/models/phimoe.py create mode 100644 vllm/transformers_utils/configs/exaone.py create mode 100644 vllm/transformers_utils/configs/granite.py create mode 100644 vllm/transformers_utils/utils.py diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index bb9cd43e2df0..064883859218 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,5 +1,4 @@ Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 335ffd83fcd7..6989c94d46a8 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -12,4 +12,4 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9f449ff650b9..86eddb576c42 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -90,6 +90,7 @@ steps: - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/openai + - pytest -v -s entrypoints/test_chat_utils.py - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" @@ -173,6 +174,7 @@ steps: - vllm/ commands: - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_wrapper.py - label: Vision Language Models Test # 42min diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b0d0ba904c3..923ed084ffd9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" diff --git a/Dockerfile b/Dockerfile index 36fcc2f83e9f..ec6069f605eb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt -COPY requirements-mamba.txt requirements-mamba.txt -RUN python3 -m pip install packaging -RUN python3 -m pip install -r requirements-mamba.txt # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt #################### DEV IMAGE #################### -#################### MAMBA Build IMAGE #################### -FROM dev as mamba-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} - -WORKDIR /usr/src/mamba - -COPY requirements-mamba.txt requirements-mamba.txt - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel -r requirements-mamba.txt \ - --no-build-isolation --no-deps --no-cache-dir - -#################### MAMBA Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base @@ -179,13 +160,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose -RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir - RUN --mount=type=cache,target=/root/.cache/pip \ . /etc/environment && \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl + python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl #################### vLLM installation IMAGE #################### diff --git a/Dockerfile.tpu b/Dockerfile.tpu index 1cf43247e978..3a11c6721ead 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -1,4 +1,4 @@ -ARG NIGHTLY_DATE="20240808" +ARG NIGHTLY_DATE="20240828" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fe687da49290..e38ceaa22295 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -61,15 +61,22 @@ class BenchmarkMetrics: mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float - p99_ttft_ms: float + percentiles_ttft_ms: List[Tuple[float, float]] mean_tpot_ms: float median_tpot_ms: float std_tpot_ms: float - p99_tpot_ms: float + percentiles_tpot_ms: List[Tuple[float, float]] mean_itl_ms: float median_itl_ms: float std_itl_ms: float - p99_itl_ms: float + percentiles_itl_ms: List[Tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: List[Tuple[float, float]] def sample_sharegpt_requests( @@ -235,6 +242,8 @@ def calculate_metrics( outputs: List[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, + selected_percentile_metrics: List[str], + selected_percentiles: List[float], ) -> Tuple[BenchmarkMetrics, List[int]]: actual_output_lens: List[int] = [] total_input = 0 @@ -242,6 +251,7 @@ def calculate_metrics( itls: List[float] = [] tpots: List[float] = [] ttfts: List[float] = [] + e2els: List[float] = [] for i in range(len(outputs)): if outputs[i].success: # We use the tokenizer to count the number of output tokens for all @@ -258,6 +268,7 @@ def calculate_metrics( (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) itls += outputs[i].itl ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) completed += 1 else: actual_output_lens.append(0) @@ -276,17 +287,25 @@ def calculate_metrics( output_throughput=sum(actual_output_lens) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend - median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, - p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], mean_tpot_ms=np.mean(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, - p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], mean_itl_ms=np.mean(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, - p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.median(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.mean(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], ) return metrics, actual_output_lens @@ -304,6 +323,8 @@ async def benchmark( request_rate: float, disable_tqdm: bool, profile: bool, + selected_percentile_metrics: List[str], + selected_percentiles: List[str], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -392,6 +413,8 @@ async def benchmark( outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, + selected_percentile_metrics=selected_percentile_metrics, + selected_percentiles=selected_percentiles, ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) @@ -407,23 +430,6 @@ async def benchmark( metrics.input_throughput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) - print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) - print("{:<40} {:<10.2f}".format("Median TTFT (ms):", - metrics.median_ttft_ms)) - print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) - print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)', - n=50, - c='-')) - print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) - print("{:<40} {:<10.2f}".format("Median TPOT (ms):", - metrics.median_tpot_ms)) - print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) - print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) - print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) - print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) - print("=" * 50) result = { "duration": benchmark_duration, @@ -433,18 +439,6 @@ async def benchmark( "request_throughput": metrics.request_throughput, "input_throughput": metrics.input_throughput, "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "std_ttft_ms": metrics.std_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "std_tpot_ms": metrics.std_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "std_itl_ms": metrics.std_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, "input_lens": [output.prompt_len for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], @@ -452,6 +446,47 @@ async def benchmark( "generated_texts": [output.generated_text for output in outputs], "errors": [output.error for output in outputs], } + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function print and add statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + return result @@ -550,6 +585,10 @@ def main(args: argparse.Namespace): request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], )) # Save config and results to json @@ -765,6 +804,23 @@ def main(args: argparse.Namespace): "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" " format.", ) + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-seperated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " + "Default value is \"ttft,tpot,itl\".") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-seperated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\". " + "Use \"--percentile-metrics\" to select metrics.", + ) args = parser.parse_args() main(args) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu new file mode 100644 index 000000000000..88a64a8ece58 --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -0,0 +1,700 @@ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu +// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu +#include +#include +#include + +#include "causal_conv1d.h" +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include + +#include "static_switch.h" + + + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +void set_conv_params_fwd(ConvParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, + // device pointers + const at::Tensor x, + const at::Tensor weight, + const at::Tensor out, + void* bias_ptr, + bool silu_activation) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. + params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias_ptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. + params.x_batch_stride = x.stride(0); + params.x_c_stride = x.stride(1); + params.x_l_stride = x.stride(-1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(0); + params.out_c_stride = out.stride(1); + params.out_l_stride = out.stride(-1); +} + + +at::Tensor +causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, + const c10::optional &bias_, + const c10::optional &seq_idx_, + const c10::optional &initial_states_, + const c10::optional &final_states_out_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); + const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; + + if (is_channel_last) { + TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); + TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); + } + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + if (seq_idx_.has_value()) { + TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); + auto seq_idx = seq_idx_.value(); + TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); + TORCH_CHECK(seq_idx.is_cuda()); + TORCH_CHECK(seq_idx.is_contiguous()); + CHECK_SHAPE(seq_idx, batch_size, seqlen); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + + if (seq_idx_.has_value()) { + params.seq_idx_ptr = seq_idx_.value().data_ptr(); + } else { + params.seq_idx_ptr = nullptr; + } + + if (initial_states_.has_value()) { + TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); + auto initial_states = initial_states_.value(); + TORCH_CHECK(initial_states.scalar_type() == input_type); + TORCH_CHECK(initial_states.is_cuda()); + CHECK_SHAPE(initial_states, batch_size, dim, width - 1); + TORCH_CHECK(initial_states.stride(1) == 1); + params.initial_states_ptr = initial_states.data_ptr(); + params.initial_states_batch_stride = initial_states.stride(0); + params.initial_states_c_stride = initial_states.stride(1); + params.initial_states_l_stride = initial_states.stride(2); + } else { + params.initial_states_ptr = nullptr; + } + + if (final_states_out_.has_value()) { + TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout"); + auto final_states = final_states_out_.value(); + TORCH_CHECK(final_states.scalar_type() == input_type); + TORCH_CHECK(final_states.is_cuda()); + CHECK_SHAPE(final_states, batch_size, dim, width - 1); + TORCH_CHECK(final_states.stride(1) == 1); + params.final_states_ptr = final_states.data_ptr(); + params.final_states_batch_stride = final_states.stride(0); + params.final_states_c_stride = final_states.stride(1); + params.final_states_l_stride = final_states.stride(2); + } else { + params.final_states_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { + if (!is_channel_last) { + causal_conv1d_fwd_cuda(params, stream); + } else { + causal_conv1d_channellast_fwd_cuda(params, stream); + } + }); + return out; +} + + +at::Tensor +causal_conv1d_update(const at::Tensor &x, + const at::Tensor &conv_state, + const at::Tensor &weight, + const c10::optional &bias_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim); + CHECK_SHAPE(conv_state, batch_size, dim, width); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + params.conv_state_ptr = conv_state.data_ptr(); + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); + return out; +} + +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + static constexpr int kSmemIOSize = kIsVecLoad + ? 0 + : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. + if (tidx == 0) { + input_t zeros[kNElts] = {0}; + smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; + } + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + if constexpr(kIsVecLoad) { + typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + __syncthreads(); + typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); + } + x += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current chunk. + if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + + float x_vals[2 * kNElts]; + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + + float out_vals[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + } + + if (params.silu_activation) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } + } + + input_t out_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + if constexpr(kIsVecLoad) { + typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + } + out += kChunkSize; + } +} + + +template +void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + + auto kernel = &causal_conv1d_fwd_kernel; + + if (kSmemSize >= 48 * 1024) { + #ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + #else + // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; + #endif + } + kernel<<>>(params); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template +struct Causal_conv1d_channellast_fwd_kernel_traits { + // The cache line is 128 bytes, and we try to read 16 bytes per thread. + // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. + // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 + // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static_assert(kNThreads % 32 == 0); + static constexpr int kNWarps = kNThreads / 32; + static constexpr int kWidth = kWidth_; + static constexpr int kChunkSizeL = kChunkSizeL_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static constexpr int kNEltsPerRow = 128 / kNBytes; + static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now + static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); + static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now + static_assert(kNColsPerWarp * kNThreadsPerRow == 32); + static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; + static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; + static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + // using BlockLoadT = cub::BlockLoad; + // using BlockStoreT = cub::BlockStore; + // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), + // sizeof(typename BlockStoreT::TempStorage)}); + // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; + constexpr int kLPerLoad = Ktraits::kNColsPerLoad; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; + + const int batch_id = blockIdx.x; + const int chunk_l_id = blockIdx.y; + const int chunk_c_id = blockIdx.z; + const int tid = threadIdx.x; + const int l_idx = tid / kNThreadsPerC; + const int c_idx = tid % kNThreadsPerC; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + weight_t *weight = reinterpret_cast(params.weight_ptr) + + chunk_c_id * kChunkSizeC * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) + + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; + input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr + : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + // The last L-chunk will also have enough info to write to final states, since it also contain a few x values + // from the previous L-chunk. + input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr + : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); + } + reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + // Load the elements from the previous chunk that are needed for convolution. + if (l_idx < kWidth - 1) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); + } else if (initial_states != nullptr + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); + } + reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + + __syncthreads(); + + if (final_states != nullptr + && l_idx < kWidth - 1 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) + // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] + *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; + } + + constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); + constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; + static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); + // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity + static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); + static_assert((kLPerThread & (kLPerThread - 1)) == 0); + static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); + static_assert(kNThreadsPerRow <= 32); + + const int row_idx = tid / kNThreadsPerRow; + const int col_idx = tid % kNThreadsPerRow; + + float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); + float weight_vals[kWidth] = {0}; + if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; + } + } + float x_vals[kWidth - 1 + kLPerThread]; + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + } + int seq_idx_thread[kWidth - 1 + kLPerThread]; + if constexpr (kHasSeqIdx) { + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; + } + } + + float out_vals[kLPerThread]; + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + out_vals[i] = bias_val; + const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + if constexpr (!kHasSeqIdx) { + out_vals[i] += weight_vals[w] * x_vals[i + w]; + } else { + out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; + } + } + if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } + __syncthreads(); + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t out_vals_store[kNElts]; + reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; + } + } + +} + +template +void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { + using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; + const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; + dim3 grid(params.batch, n_chunks_L, n_chunks_C); + dim3 block(Ktraits::kNThreads); + auto kernel = &causal_conv1d_channellast_fwd_kernel; + // if (kSmemSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + // } + // kernel<<>>(params); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +/////// + + + + +template +struct Causal_conv1d_update_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_update_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + float weight_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + } + + float x_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } + x_vals[kWidth - 1] = float(x[0]); + #pragma unroll + for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } + } + + float out_val = bias_val; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + if (channel_id < params.dim) { out[0] = input_t(out_val); } +} + +template +void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h new file mode 100644 index 000000000000..bb25314c8bbb --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -0,0 +1,144 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h +#pragma once + +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ConvParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, width; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. + void *__restrict__ x_ptr; + void *__restrict__ weight_ptr; + void *__restrict__ bias_ptr; + void *__restrict__ out_ptr; + + void *__restrict__ conv_state_ptr; + + void *__restrict__ seq_idx_ptr; + + // No __restrict__ since initial_states could be the same as final_states. + void * initial_states_ptr; + index_t initial_states_batch_stride; + index_t initial_states_l_stride; + index_t initial_states_c_stride; + + void * final_states_ptr; + index_t final_states_batch_stride; + index_t final_states_l_stride; + index_t final_states_c_stride; +}; + + +#ifndef USE_ROCM + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor_sync(uint32_t(-1), val, offset); + } + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor(val, offset); + } + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h new file mode 100644 index 000000000000..ef74bf447f84 --- /dev/null +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h new file mode 100644 index 000000000000..0070c92f6cd0 --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -0,0 +1,276 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h + +#pragma once + +#ifndef USE_ROCM + #include +#else + #include +#endif +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; + void *__restrict__ index_ptr; +}; + + + + +#ifndef USE_ROCM + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + + +#define MAX_DSTATE 256 + + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + #ifdef USE_ROCM + , Ktraits::kNThreads * Ktraits::kNLoads + #endif + + ); + } else { + typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_index(int *u, + int (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_index_vec = reinterpret_cast(smem_load_index); + Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu new file mode 100644 index 000000000000..df968dda92ad --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -0,0 +1,593 @@ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh +#include +#include +#include +#include "selective_scan.h" + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + namespace cub = hipcub; +#endif + +#include "selective_scan.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + static constexpr bool kUseIndex = kUseIndex_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsIndex = kNItems / 4; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kUseIndex = Ktraits::kUseIndex; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + int *index = !kUseIndex ? nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + int index_vals_load[kNRows][kNItems]; + + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (kUseIndex) { + load_index(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize); + } + } + if constexpr (kUseIndex) { + index += kChunkSize; + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 )); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + // Reset A bar for cumulative sequences (Real) + if constexpr (kUseIndex) { + if (index_vals_load[r][i] == 0) { + thread_data[i].x = 0.f; + } + } + + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f)); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + typename Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); + out_vals[r][i] += thread_data[i].y * C_val; + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * 1; + Cvar += kChunkSize * 1; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size + constexpr bool kIsVariableB = true; + constexpr bool kIsVariableC = true; + constexpr bool kHasZ = true; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + + #ifndef USE_ROCM + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #else + if (params.seqlen <= 256) { + selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #endif +} + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const torch::Tensor u, + const torch::Tensor delta, + const torch::Tensor A, + const torch::Tensor B, + const torch::Tensor C, + const torch::Tensor out, + const torch::Tensor z, + const torch::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus, + void* index_ptr) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + + params.index_ptr = index_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +std::vector +selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, + const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + const c10::optional &index_, + const c10::optional &x) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen ); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + + TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + if (index_.has_value()) { + auto index = index_.value(); + TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(index.is_cuda()); + CHECK_SHAPE(index, batch_size, seqlen); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + if (x.has_value()){ + auto _x = x.value(); + TORCH_CHECK(_x.scalar_type() == weight_type); + TORCH_CHECK(_x.is_cuda()); + TORCH_CHECK(_x.stride(-1) == 1); + CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); + } + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.value().data_ptr(), + has_z, + delta_softplus, + index_.has_value() ? index_.value().data_ptr() : nullptr); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); + }); + std::vector result = {out, x.value()}; + if (has_z) { result.push_back(out_z); } + return result; +} + diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h new file mode 100644 index 000000000000..840cb2374a2f --- /dev/null +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/ops.h b/csrc/ops.h index 6bf0cff23252..8d24545de898 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -195,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); +std::vector selective_scan_fwd( + const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& C, + const c10::optional& D_, + const c10::optional& z_, + const c10::optional& delta_bias_, bool delta_softplus, + const c10::optional& index_, + const c10::optional& x); + +at::Tensor causal_conv1d_update(const at::Tensor& x, + const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation); + +at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& seq_idx_, + const c10::optional& initial_states_, + const c10::optional& final_states_out_, + bool silu_activation); + #ifndef USE_ROCM using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6d1f53b75f4e..7783acd741f5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA, &cutlass_scaled_mm_supports_fp8); + // Mamba selective scan kernel + ops.def( + "selective_scan_fwd(Tensor! u, Tensor! delta," + "Tensor! A, Tensor! B, Tensor! C," + "Tensor? D_, Tensor? z_, Tensor? delta_bias_," + "bool delta_softplus," + "Tensor? index_, Tensor? x) -> Tensor[]"); + ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); + + ops.def( + "causal_conv1d_update(Tensor! x," + "Tensor! conv_state," + "Tensor! weight," + "Tensor? bias_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); + + ops.def( + "causal_conv1d_fwd(Tensor! x, Tensor! weight," + "Tensor? bias_," + "Tensor? seq_idx_," + "Tensor? initial_states_," + "Tensor? final_states_out_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif // Quantized GEMM for GPTQ. diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 95a9be780663..c358e23b6a37 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -11,5 +11,6 @@ pydantic >= 2.8 torch py-cpuinfo transformers -mistral_common >= 1.3.4 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +mistral_common >= 1.3.4 +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index 31ae30ad302b..d0c2498d8849 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -56,9 +56,10 @@ First, install the dependencies: $ pip uninstall torch torch-xla -y $ # Install PyTorch and PyTorch XLA. - $ export DATE="+20240808" - $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl - $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl + $ export DATE="20240828" + $ export TORCH_VERSION="2.5.0" + $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl + $ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl $ # Install JAX and Pallas. $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 223c68b40766..2c20b6e48407 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -51,6 +51,10 @@ Decoder-only Language Models - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - + * - :code:`ExaoneForCausalLM` + - EXAONE-3 + - :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. + - ✅︎ * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. @@ -143,6 +147,10 @@ Decoder-only Language Models - Phi-3-Small - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. - + * - :code:`PhiMoEForCausalLM` + - Phi-3.5-MoE + - :code:`microsoft/Phi-3.5-MoE-instruct`, etc. + - * - :code:`PersimmonForCausalLM` - Persimmon - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. diff --git a/docs/source/performance_benchmark/benchmarks.rst b/docs/source/performance_benchmark/benchmarks.rst index 9a23aab10d03..e5c8d6a55de6 100644 --- a/docs/source/performance_benchmark/benchmarks.rst +++ b/docs/source/performance_benchmark/benchmarks.rst @@ -20,4 +20,4 @@ The performance benchmarks and nightly benchmarks can be triggered by submitting .. note:: - Please refer to `vLLM performance benchmark descriptions `_ and `vLLM nightly benchmark descriptions `_ for detailed descriptions on benchmark environment, workload and metrics. + Please refer to `vLLM performance benchmark descriptions `_ and `vLLM nightly benchmark descriptions `_ for detailed descriptions on benchmark environment, workload and metrics. diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a06c30d9c48c..b2acde390083 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -111,6 +111,32 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) :prog: vllm serve ``` +### Config file + +The `serve` module can also accept arguments from a config file in +`yaml` format. The arguments in the yaml must be specified using the +long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server): + +For example: + +```yaml +# config.yaml + +host: "127.0.0.1" +port: 6379 +uvicorn-log-level: "info" +``` + +```bash +$ vllm serve SOME_MODEL --config config.yaml +``` +--- +**NOTE** +In case an argument is supplied using command line and the config file, the value from the commandline will take precedence. +The order of priorities is `command line > config file values > defaults`. + +--- + ## Tool calling in the chat completion API vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap. diff --git a/examples/offline_inference_neuron.py b/examples/offline_inference_neuron.py index 5ecbbf020ab8..2856be7c864e 100644 --- a/examples/offline_inference_neuron.py +++ b/examples/offline_inference_neuron.py @@ -1,5 +1,12 @@ +import os + from vllm import LLM, SamplingParams +# creates XLA hlo graphs for all the context length buckets. +os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" +# creates XLA hlo graphs for all the token gen buckets. +os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" + # Sample prompts. prompts = [ "Hello, my name is", @@ -19,8 +26,8 @@ # Currently, this is a known limitation in continuous batching support # in transformers-neuronx. # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=128, - block_size=128, + max_model_len=2048, + block_size=2048, # The device can be automatically detected when AWS Neuron SDK is installed. # The device argument can be either unspecified for automated detection, # or explicitly assigned. diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index be90394511f8..e1d4055763e5 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -1,7 +1,13 @@ """An example showing how to use vLLM to serve VLMs. Launch the vLLM server with the following command: + +(single image inference with Llava) vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja + +(multi-image inference with Phi-3.5-vision-instruct) +vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ + --trust-remote-code --limit-mm-per-prompt image=2 """ import base64 @@ -84,3 +90,36 @@ def encode_image_base64_from_url(image_url: str) -> str: result = chat_completion_from_base64.choices[0].message.content print(f"Chat completion output:{result}") + +# Multi-image input inference +image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" +image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" +chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What are the animals in these images?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url_duck + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url_lion + }, + }, + ], + }], + model=model, + max_tokens=64, +) + +result = chat_completion_from_url.choices[0].message.content +print(f"Chat completion output:{result}") diff --git a/requirements-common.txt b/requirements-common.txt index 61daf9981975..4c5b681a0d5a 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -22,8 +22,7 @@ typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq msgspec -librosa # Required for audio processing -soundfile # Required for audio processing gguf == 0.9.1 importlib_metadata mistral_common >= 1.3.4 +pyyaml diff --git a/requirements-mamba.txt b/requirements-mamba.txt deleted file mode 100644 index 1838e87d063d..000000000000 --- a/requirements-mamba.txt +++ /dev/null @@ -1,3 +0,0 @@ -# Mamba dependencies -mamba-ssm>=1.2.2 -causal-conv1d>=1.2.0 diff --git a/requirements-test.txt b/requirements-test.txt index cdbc3e50cc9e..58cf1716b45c 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -11,12 +11,14 @@ pytest-shard # testing utils awscli -einops # required for MPT and qwen-vl +einops # required for MPT, qwen-vl and Mamba httpx +librosa # required for audio test peft requests ray sentence-transformers # required for embedding +soundfile # required for audio test compressed-tensors==0.4.0 # required for compressed-tensors timm # required for internvl test transformers_stream_generator # required for qwen-vl test @@ -30,4 +32,4 @@ aiohttp # quantization bitsandbytes==0.42.0 -buildkite-test-collector==0.1.8 \ No newline at end of file +buildkite-test-collector==0.1.8 diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 5eb27b39eb62..4c606cf0a910 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -4,4 +4,4 @@ # Dependencies for TPU # Currently, the TPU backend uses a nightly version of PyTorch XLA. # You can install the dependencies in Dockerfile.tpu. -ray +ray[default] diff --git a/setup.py b/setup.py index 21b0422c0f0b..38d3f41663f2 100644 --- a/setup.py +++ b/setup.py @@ -501,6 +501,7 @@ def _read_requirements(filename: str) -> List[str]: ext_modules=ext_modules, extras_require={ "tensorizer": ["tensorizer>=2.9.0"], + "audio": ["librosa", "soundfile"] # Required for audio processing }, cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {}, package_data=package_data, diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py new file mode 100644 index 000000000000..cef516ade27e --- /dev/null +++ b/tests/compile/test_wrapper.py @@ -0,0 +1,59 @@ +from typing import Optional + +import torch + +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther + + +class MyMod(torch.nn.Module): + + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + if cache is not None: + return x + cache + return x * 2 + + +class MyWrapper(TorchCompileWrapperWithCustomDispacther): + + def __init__(self, model): + self.model = model + compiled_callable = torch.compile(self.forward, backend="eager") + super().__init__(compiled_callable) + + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + # this is the function to be compiled + return self.model(x, cache) + + def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + # let torch.compile compile twice + if len(self.compiled_codes) == 2: + dispatch_id = 0 if cache is None else 1 + with self.dispatch_to_code(dispatch_id): + return self.forward(x, cache) + else: + return self.compiled_callable(x, cache) + + +def test_torch_compile_wrapper(): + mod = MyMod() + wrappers = [] + for i in range(3): + torch._dynamo.reset() + wrapper = MyWrapper(mod) + wrappers.append(wrapper) + x = torch.tensor([1]) + wrapper(x, None) # profile run, compile + # create a cache tensor + cache = torch.tensor([2]) + wrapper(x, cache) # warm up with cache, recompile + + # for new input, dispatch to the compiled code directly + new_x = torch.tensor([3]) + assert wrapper(new_x, + None).item() == 6 # dispatch to the first compiled code + assert wrapper( + new_x, cache).item() == 5 # dispatch to the second compiled code + + for wrapper in wrappers: + # make sure they have independent compiled codes + assert len(wrapper.compiled_codes) == 2 diff --git a/tests/conftest.py b/tests/conftest.py index d8264f65b614..e66a14598c34 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -209,8 +209,14 @@ class HfRunner: def wrap_device(self, input: _T) -> _T: if not is_cpu(): + # Check if the input is already on the GPU + if hasattr(input, 'device') and input.device.type == "cuda": + return input # Already on GPU, no need to move return input.to("cuda") else: + # Check if the input is already on the CPU + if hasattr(input, 'device') and input.device.type == "cpu": + return input # Already on CPU, no need to move return input.to("cpu") def __init__( diff --git a/tests/data/test_config.yaml b/tests/data/test_config.yaml new file mode 100644 index 000000000000..20d499624de2 --- /dev/null +++ b/tests/data/test_config.yaml @@ -0,0 +1,2 @@ +port: 12312 +tensor_parallel_size: 2 diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 3783b7cd66a6..c3a6c65be1d9 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from unittest.mock import MagicMock +from vllm.config import MultiModalConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat @@ -20,6 +21,7 @@ class MockModelConfig: max_model_len = 100 tokenizer_revision = None embedding_mode = False + multimodal_config = MultiModalConfig() @dataclass diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index d2ef3c2071ef..f61fa127b7d0 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -6,11 +6,10 @@ from vllm.multimodal.utils import encode_image_base64, fetch_image -from ...utils import VLLM_PATH, RemoteOpenAIServer +from ...utils import RemoteOpenAIServer -MODEL_NAME = "llava-hf/llava-1.5-7b-hf" -LLAVA_CHAT_TEMPLATE = VLLM_PATH / "examples/template_llava.jinja" -assert LLAVA_CHAT_TEMPLATE.exists() +MODEL_NAME = "microsoft/Phi-3.5-vision-instruct" +MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -24,13 +23,9 @@ @pytest.fixture(scope="module") def server(): args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "4096", - "--enforce-eager", - "--chat-template", - str(LLAVA_CHAT_TEMPLATE), + "--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs", + "5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", + f"image={MAXIMUM_IMAGES}" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -84,7 +79,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=596, total_tokens=606) + completion_tokens=10, prompt_tokens=772, total_tokens=782) message = choice.message message = chat_completion.choices[0].message @@ -139,7 +134,7 @@ async def test_single_chat_session_image_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=596, total_tokens=606) + completion_tokens=10, prompt_tokens=772, total_tokens=782) message = choice.message message = chat_completion.choices[0].message @@ -217,26 +212,22 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_url: str): + image_urls: List[str]): messages = [{ "role": "user", "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { + *({ "type": "image_url", "image_url": { "url": image_url } - }, + } for image_url in image_urls), { "type": "text", "text": "What's in this image?" @@ -244,20 +235,30 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, ], }] - with pytest.raises(openai.BadRequestError): # test multi-image input - await client.chat.completions.create( + if len(image_urls) > MAXIMUM_IMAGES: + with pytest.raises(openai.BadRequestError): # test multi-image input + await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 + else: + chat_completion = await client.chat.completions.create( model=model_name, messages=messages, max_tokens=10, temperature=0.0, ) - - # the server should still work afterwards - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - completion = completion.choices[0].text - assert completion is not None and len(completion) >= 0 + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py new file mode 100644 index 000000000000..53f99189beb1 --- /dev/null +++ b/tests/entrypoints/test_chat_utils.py @@ -0,0 +1,305 @@ +import warnings + +import pytest +from PIL import Image + +from vllm.assets.image import ImageAsset +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import parse_chat_messages +from vllm.multimodal.utils import encode_image_base64 +from vllm.transformers_utils.tokenizer_group import TokenizerGroup + +PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" + + +@pytest.fixture(scope="module") +def phi3v_model_config(): + return ModelConfig(PHI3V_MODEL_ID, + PHI3V_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="bfloat16", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }) + + +@pytest.fixture(scope="module") +def phi3v_tokenizer(): + return TokenizerGroup( + tokenizer_id=PHI3V_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + + +@pytest.fixture(scope="module") +def image_url(): + image = ImageAsset('cherry_blossom') + base64 = encode_image_base64(image.pil_image) + return f"data:image/jpeg;base64,{base64}" + + +@pytest.mark.asyncio +async def test_parse_chat_messages_with_image_url(phi3v_model_config, + phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert isinstance(mm_data["image"], Image.Image) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images(phi3v_model_config, + phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?" + }] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert len(mm_data["image"]) == 2 + + +@pytest.mark.asyncio +async def test_parse_chat_messages_placeholder_already_in_prompt( + phi3v_model_config, phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert len(mm_data["image"]) == 2 + + +@pytest.mark.asyncio +async def test_parse_chat_messages_placeholder_one_already_in_prompt( + phi3v_model_config, phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to the other one?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " + "other one?" + }] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert len(mm_data["image"]) == 2 + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_across_messages( + phi3v_model_config, phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about this one?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?" + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "<|image_2|>\nWhat about this one?" + }, + ] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert len(mm_data["image"]) == 2 + + +@pytest.mark.asyncio +async def test_parse_chat_messages_rejects_too_many_images_in_one_message( + phi3v_model_config, phi3v_tokenizer, image_url): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="coroutine 'async_get_and_parse_image' was never awaited") + with pytest.raises( + ValueError, + match="At most 2 image\\(s\\) may be provided in one request\\." + ): + parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_rejects_too_many_images_across_messages( + phi3v_model_config, phi3v_tokenizer, image_url): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="coroutine 'async_get_and_parse_image' was never awaited") + with pytest.raises( + ValueError, + match="At most 2 image\\(s\\) may be provided in one request\\." + ): + parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about these two?" + }] + }], phi3v_model_config, phi3v_tokenizer) diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py new file mode 100644 index 000000000000..198d40a155cc --- /dev/null +++ b/tests/kernels/test_awq_triton.py @@ -0,0 +1,169 @@ +"""Tests for the AWQ Triton kernel. + +Run `pytest tests/kernels/test_awq_triton.py`. +""" +import pytest +import torch + +from vllm.model_executor.layers.quantization.awq_triton import ( + AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) + +device = "cuda" + + +def reverse_awq_order(t: torch.Tensor): + bits = 4 + AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + reverse_order_tensor = torch.arange( + t.shape[-1], + dtype=torch.int32, + device=t.device, + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] + reverse_order_tensor = reverse_order_tensor.view(-1) + + t = t[:, reverse_order_tensor] & 0xF + return t + + +# qweights - [R , C // 8], int32 +# scales - [R // G, C ], float16 +# zeros - [R // G, C // 8], int32 +def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, + qzeros: torch.Tensor, + group_size: int) -> torch.Tensor: + + if group_size == -1: + group_size = qweight.shape[0] + + bits = 4 + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + iweights = torch.bitwise_right_shift(qweight[:, :, None], + shifts[None, None, :]).to(torch.int8) + + iweights = iweights.view(iweights.shape[0], -1) + + zeros = torch.bitwise_right_shift(qzeros[:, :, None], + shifts[None, None, :]).to(torch.int8) + zeros = zeros.view(qzeros.shape[0], -1) + zeros = reverse_awq_order(zeros) + + iweights = reverse_awq_order(iweights) + + iweights = torch.bitwise_and(iweights, (2**bits) - 1) + zeros = torch.bitwise_and(zeros, (2**bits) - 1) + + scales = scales.repeat_interleave(group_size, dim=0) + zeros = zeros.repeat_interleave(group_size, dim=0) + return (iweights - zeros) * scales + + +# qweights - [R , C // 8], int32 +# scales - [R // G, C ], float16 +# zeros - [R // G, C // 8], int32 +@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024]) +@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128]) +@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) +def test_dequantize(qweight_rows, qweight_cols, group_size): + + if group_size == -1: + group_size = qweight_rows + + qweight_dtype = torch.int32 + scales_rows = qweight_rows // group_size + scales_cols = qweight_cols * 8 + scales_dtype = torch.float16 + zeros_rows = scales_rows + zeros_cols = qweight_cols + zeros_dtype = torch.int32 + + torch.manual_seed(0) + + qweight = torch.randint(0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + dtype=qweight_dtype, + device=device) + scales = torch.rand(scales_rows, + scales_cols, + dtype=scales_dtype, + device=device) + zeros = torch.randint(0, + torch.iinfo(torch.int32).max, + (zeros_rows, zeros_cols), + dtype=zeros_dtype, + device=device) + + iweights_triton = awq_dequantize_triton(qweight, scales, zeros) + + assert (not torch.any(torch.isinf(iweights_triton)) + and not torch.any(torch.isnan(iweights_triton))) + + iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size) + + torch.testing.assert_close(iweights_triton, iweights_torch) + + +# input - [N, K] +# qweight - [K, M // 8] +# qzeros - [K // G, M // 8] +# scales - [K // G, M] +@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32]) +@pytest.mark.parametrize("K", [128]) +@pytest.mark.parametrize("M", [16, 24, 32]) +@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("splitK", [1, 8]) +def test_gemm(N, K, M, splitK, group_size): + + if group_size == -1: + group_size = K + + split_k_iters = splitK + + input_rows = N + input_cols = K + input_dtype = torch.float32 + qweight_rows = input_cols + qweight_cols = M // 8 + scales_rows = qweight_rows // group_size + scales_cols = M + scales_dtype = torch.float32 + qzeros_rows = scales_rows + qzeros_cols = qweight_cols + + torch.manual_seed(0) + + input = torch.rand((input_rows, input_cols), + dtype=input_dtype, + device=device) + qweight = torch.randint(0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + device=device) + qzeros = torch.randint(0, + torch.iinfo(torch.int32).max, + (qzeros_rows, qzeros_cols), + device=device) + scales = torch.rand((scales_rows, scales_cols), + dtype=scales_dtype, + device=device) + + output_triton = awq_gemm_triton(input, qweight, scales, qzeros, + split_k_iters) + + assert (not torch.any(torch.isinf(output_triton)) + and not torch.any(torch.isnan(output_triton))) + + dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros) + + output_torch = torch.matmul(input, dequantized_weights) + + assert (not torch.any(torch.isinf(output_torch)) + and not torch.any(torch.isnan(output_torch))) + + torch.testing.assert_close(output_triton.cpu(), + output_torch.cpu(), + atol=1e-1, + rtol=1e-1) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py new file mode 100644 index 000000000000..7bf338b36953 --- /dev/null +++ b/tests/kernels/test_causal_conv1d.py @@ -0,0 +1,205 @@ +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange + +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update_ref(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + batch, dim = x.shape + width = weight.shape[1] + assert conv_state.shape == (batch, dim, width) + assert weight.shape == (dim, width) + conv_state.copy_(torch.roll(conv_state, shifts=-1, + dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + out = torch.sum(conv_state * weight, dim=-1) # (B D) + if bias is not None: + out += bias + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + +@pytest.mark.parametrize("return_final_states", [False, True]) +@pytest.mark.parametrize("has_initial_states", [False, True]) +@pytest.mark.parametrize("channel_last", [False, True]) +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize("seqlen", [128, 512, 4096]) +@pytest.mark.parametrize('dim', [64, 4096 + 32]) +@pytest.mark.parametrize('batch', [1, 2]) +def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, + itype, channel_last, has_initial_states, + return_final_states): + if not channel_last and (has_initial_states or return_final_states): + pytest.skip( + "Only channel_last support initial_states or return_final_states") + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.random.manual_seed(0) + if not channel_last: + x = torch.randn(batch, + 4096 + dim + 64, + seqlen, + device=device, + dtype=itype)[:, 4096:4096 + dim, :] + else: + x = rearrange( + torch.randn(batch, + seqlen, + 4096 + dim + 64, + device=device, + dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s") + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + if has_initial_states: + initial_states = torch.randn(batch, + width - 1, + dim, + device=device, + dtype=itype).transpose(1, 2) + else: + initial_states = None + x_ref = x.detach().clone() + weight_ref = weight.detach().clone() + bias_ref = bias.detach().clone() if bias is not None else None + initial_states_ref = initial_states.detach().clone( + ) if initial_states is not None else None + activation = None if not silu_activation else "silu" + out, final_states = causal_conv1d_fn( + x, + weight, + bias, + initial_states=initial_states, + return_final_states=return_final_states, + activation=activation) + out_ref, final_states_ref = causal_conv1d_ref( + x_ref, + weight_ref, + bias_ref, + initial_states=initial_states_ref, + return_final_states=return_final_states, + activation=activation) + if return_final_states: + assert final_states is not None and final_states_ref is not None + assert torch.allclose(final_states, + final_states_ref, + rtol=rtol, + atol=atol) + + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + if return_final_states: + out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) + out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +@pytest.mark.parametrize("batch", [1, 2]) +def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, + itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.random.manual_seed(0) + batch = 2 + x = torch.randn(batch, dim, device=device, dtype=itype) + conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) + weight = torch.randn(dim, + width, + device=device, + dtype=itype, + requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state.detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation) + out_ref = causal_conv1d_update_ref(x, + conv_state_ref, + weight, + bias, + activation=activation) + + assert torch.equal(conv_state, conv_state_ref) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index f109792ad251..67f12cf1ee08 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -73,11 +73,14 @@ def ref_paged_attn( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode -def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], - num_heads: Tuple[int, - int], head_size: int, - dtype: torch.dtype, block_size: int, - soft_cap: Optional[float]) -> None: +def test_flashinfer_decode_with_paged_kv( + kv_lens: List[int], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) num_seqs = len(kv_lens) @@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], scale = head_size**-0.5 query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_value_cache = torch.randn(NUM_BLOCKS, 2, block_size, @@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], wrapper = flashinfer.\ BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", use_tensor_cores=( - (num_query_heads//num_kv_heads) not in (1, 2, 4, 8)) + (num_query_heads//num_kv_heads) > 4) ) wrapper.begin_forward(kv_indptr, kv_indices, @@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], soft_cap=soft_cap) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) +@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)]) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +def test_flashinfer_prefill_with_paged_fp8_kv( + seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], + head_size: int, dtype: torch.dtype, block_size: int, + soft_cap: Optional[float]) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + kv_cache_dtype = torch.float8_e4m3fn + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + NUM_BLOCKS_FP8 = 2048 + key_value_cache = torch.randn(NUM_BLOCKS_FP8, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + + k_scale = key_cache.amax().item() / 448.0 + v_scale = value_cache.amax().item() / 448.0 + + kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], + dim=1).to(kv_cache_dtype) + + assert (kv_cache_fp8.shape == key_value_cache.shape) + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS_FP8, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + qo_indptr = [0] + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + qo_indptr.append(qo_indptr[-1] + query_lens[i]) + + qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32) + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, "NHD") + wrapper.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + ) + + output = wrapper.forward(query, + kv_cache_fp8, + logits_soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale) + + ref_output = ref_paged_attn(query=query, + key_cache=key_cache.squeeze(1), + value_cache=value_cache.squeeze(1), + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap) + del query + del block_tables + # verify prefill fp8 + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)]) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@torch.inference_mode +def test_flashinfer_decode_with_paged_fp8_kv( + kv_lens: List[int], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: + # test doesn't work for num_heads = (16,16) + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + use_tensor_cores = (num_query_heads // num_kv_heads) > 4 + kv_cache_dtype = torch.float8_e4m3fn + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + NUM_BLOCKS_FP8 = 2048 + key_value_cache = torch.randn(NUM_BLOCKS_FP8, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + + k_scale = key_cache.amax().item() / 448.0 + v_scale = value_cache.amax().item() / 448.0 + + key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype) + value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype) + assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1) + kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS_FP8, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.\ + BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", + use_tensor_cores=use_tensor_cores) + wrapper.begin_forward(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + data_type=dtype) + output = wrapper.forward(query, + kv_cache_fp8, + logits_soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale) + key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) + value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) + + ref_output = ref_paged_attn(query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap) + # Temporary fix: Increasing the tolerance. Seems like a flashinfer issue + torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py new file mode 100644 index 000000000000..d3cb0a8656a0 --- /dev/null +++ b/tests/kernels/test_mamba_ssm.py @@ -0,0 +1,324 @@ +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) + + +def selective_state_update_ref(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * + A) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange( + B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) + state.copy_(state * dA + + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_ref(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + prev_state: r(B D N), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + if position_indices is not None and position_indices[0, i] == 0: + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("scan_chunks", [1, 2, 3]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, + has_z, has_delta_bias, delta_softplus, + return_last_state, seqlen, itype, wtype, scan_chunks): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 4 + dstate = 8 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + if not is_variable_B: + B_shape = [dim, dstate] + elif varBC_groups == 1: + B_shape = [batch_size, dstate, seqlen] + else: + B_shape = [batch_size, varBC_groups, dstate, seqlen] + B = torch.randn(B_shape, + device=device, + dtype=wtype if not is_variable_B else itype) + if not is_variable_C: + C_shape = [dim, dstate] + elif varBC_groups == 1: + C_shape = [batch_size, dstate, seqlen] + else: + C_shape = [batch_size, varBC_groups, dstate, seqlen] + C = torch.randn(C_shape, + device=device, + dtype=wtype if not is_variable_C else itype) + D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None + z = torch.randn(batch_size, dim, seqlen, device=device, + dtype=itype) if has_z else None + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) + ) if has_delta_bias else None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + delta = (0.5 * + torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) + state = None + state_ref = None + out = None + out_ref = None + outs = [] + for c in range(scan_chunks): + chunked_prompt_len = seqlen // scan_chunks + chunk_start = chunked_prompt_len * c + chunk_end = chunked_prompt_len * (c + 1) + if c == scan_chunks - 1: + chunk_end = seqlen + _B = B + if is_variable_B: + _B = B[..., chunk_start:chunk_end] + _C = C + if is_variable_B: + _C = C[..., chunk_start:chunk_end] + _z = z + if has_z: + assert z is not None + _z = z[..., chunk_start:chunk_end] + out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end], + delta[..., chunk_start:chunk_end], + A, + _B, + _C, + D, + z=_z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state, + prev_state=state if c > 0 else None) + outs.append(out) + if return_last_state: + state = rest[0] + if len(outs) > 1: + out = torch.cat(outs, dim=-1) + out_ref, *rest = selective_scan_ref(u, + delta, + A, + B, + C, + D, + z=z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state) + if return_last_state: + state_ref = rest[0] + + assert out is not None and out_ref is not None + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + assert state is not None and state_ref is not None + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_selective_state_update(dim, dstate, has_z, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 1 + state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state.detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/models/test_fp8kv_flashinfer.py b/tests/models/test_fp8kv_flashinfer.py new file mode 100644 index 000000000000..ff2a44162b6c --- /dev/null +++ b/tests/models/test_fp8kv_flashinfer.py @@ -0,0 +1,96 @@ +# flake8: noqa +"""Tests fp8 models against ground truth generation +This verifies the flashinfer backend with fp8 +quantization and fp8 KV Cache without scaling +factors Note: these tests will only pass on H100 GPU. +""" +import os +from typing import List + +import pytest +from transformers import AutoTokenizer + +from tests.quantization.utils import is_quant_method_supported +from vllm import LLM, SamplingParams + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +MODELS = [ + "nm-testing/Meta-Llama-3-8B-Instruct-FP8", +] + +EXPECTED_STRS_MAP = { + "nm-testing/Meta-Llama-3-8B-Instruct-FP8": { + "auto": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, mushi o', + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', + ] + } +} + + +# This test compares against golden strings for exact match since +# there is no baseline implementation to compare against +# and is unstable w.r.t specifics of the fp8 implementation or +# the hardware being run on. +# No assert to prevent it from breaking the build +@pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_name", MODELS) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.parametrize("backend", ["XFORMERS", "FLASHINFER"]) +def test_models(example_prompts, model_name, kv_cache_dtype, backend) -> None: + # Note that the golden strings may not work for FLASHINFER Backend. + # The intention is to test the path + os.environ["VLLM_ATTENTION_BACKEND"] = backend + model = LLM(model=model_name, + max_model_len=MAX_MODEL_LEN, + trust_remote_code=True, + quantization="fp8", + kv_cache_dtype=kv_cache_dtype) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + formatted_prompts = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + tokenize=False, + add_generation_prompt=True) + for prompt in example_prompts + ] + + params = SamplingParams(max_tokens=20, temperature=0) + generations: List[str] = [] + # Note: these need to be run 1 at a time due to numerical precision, + # since the expected strs were generated this way. + for prompt in formatted_prompts: + outputs = model.generate(prompt, params) + generations.append(outputs[0].outputs[0].text) + del model + + print(f"Testing: {model_name} with kv_cache_dtype: {kv_cache_dtype}") + expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] + for i in range(len(example_prompts)): + generated_str = generations[i] + expected_str = expected_strs[i] + print(f"generated_str\n: {generated_str}") + print(f"expected_str\n: {expected_str}") diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py new file mode 100644 index 000000000000..2435b5dc3ff8 --- /dev/null +++ b/tests/models/test_granite.py @@ -0,0 +1,49 @@ +"""Compare the outputs of HF and vLLM for Granite models using greedy sampling. + +Run `pytest tests/models/test_granite.py`. +""" +import importlib.metadata + +import pytest + +from .utils import check_logprobs_close + +TRANSFORMERS_VERSION = tuple( + map(int, + importlib.metadata.version("transformers").split("."))) + +MODELS = [ + "ibm/PowerLM-3b", +] + + +# GraniteForCausalLM will be in transformers >= 4.45 +@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45), + reason="granite model test requires transformers >= 4.45") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + # TODO(sang): Sliding window should be tested separately. + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_intern_vit.py b/tests/models/test_intern_vit.py index e980446ff357..816f846f69ba 100644 --- a/tests/models/test_intern_vit.py +++ b/tests/models/test_intern_vit.py @@ -6,8 +6,6 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor -from vllm.model_executor.models.intern_vit import InternVisionModel - from ..conftest import _ImageAssets, cleanup pytestmark = pytest.mark.vlm @@ -49,6 +47,7 @@ def run_intern_vit_test( for pixel_value in pixel_values ] + from vllm.model_executor.models.intern_vit import InternVisionModel vllm_model = InternVisionModel(config) vllm_model.load_weights(hf_model.state_dict().items()) diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py index 243bc857c88d..42732cebc656 100644 --- a/tests/models/test_internvl.py +++ b/tests/models/test_internvl.py @@ -6,9 +6,6 @@ from PIL.Image import Image from transformers import AutoConfig -from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END, - IMG_START, - image_to_pixel_values) from vllm.multimodal.utils import rescale_image_size from vllm.utils import is_cpu @@ -33,35 +30,6 @@ ] -class InternVLProcessor: - """A simple processor for InternVL2 HF model which misses a processor.""" - - def __init__(self, hf_runner: HfRunner): - self.num_image_token = hf_runner.model.num_image_token - self.tokenizer = hf_runner.tokenizer - self.dtype = hf_runner.model.dtype - - self.config = AutoConfig.from_pretrained(hf_runner.model_name) - self.vision_config = self.config.vision_config - self.use_thumbnail = self.config.use_thumbnail - self.min_num = self.config.min_dynamic_patch - self.max_num = self.config.max_dynamic_patch - self.image_size = self.vision_config.image_size - - def __call__(self, text: str, images: Image, **kwargs): - pixel_values = image_to_pixel_values(images, self.image_size, - self.min_num, self.max_num, - self.use_thumbnail).to(self.dtype) - num_patches_list = [pixel_values.shape[0]] - for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token * num_patches - image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('', image_tokens, 1) - prompt = self.tokenizer(text, return_tensors="pt") - prompt.update({"pixel_values": pixel_values}) - return prompt - - # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py def generate( self, @@ -127,6 +95,37 @@ def run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). + class InternVLProcessor: + """A simple processor for InternVL2 which misses a processor.""" + + def __init__(self, hf_runner: HfRunner): + self.num_image_token = hf_runner.model.num_image_token + self.tokenizer = hf_runner.tokenizer + self.dtype = hf_runner.model.dtype + + self.config = AutoConfig.from_pretrained(hf_runner.model_name) + self.vision_config = self.config.vision_config + self.use_thumbnail = self.config.use_thumbnail + self.min_num = self.config.min_dynamic_patch + self.max_num = self.config.max_dynamic_patch + self.image_size = self.vision_config.image_size + + def __call__(self, text: str, images: Image, **kwargs): + from vllm.model_executor.models.internvl import ( + IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values) + pixel_values = image_to_pixel_values( + images, self.image_size, self.min_num, self.max_num, + self.use_thumbnail).to(self.dtype) + num_patches_list = [pixel_values.shape[0]] + for num_patches in num_patches_list: + context_tokens = IMG_CONTEXT * self.num_image_token \ + * num_patches + image_tokens = IMG_START + context_tokens + IMG_END + text = text.replace('', image_tokens, 1) + prompt = self.tokenizer(text, return_tensors="pt") + prompt.update({"pixel_values": pixel_values}) + return prompt + # max_model_len should be greater than image_feature_size with vllm_runner(model, max_model_len=4096, diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 93634f245cee..9d7da5f803ea 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -179,3 +179,20 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + + +@pytest.mark.parametrize("model", models) +def test_context_length_too_short(vllm_runner, image_assets, model): + images = [asset.pil_image for asset in image_assets] + + with pytest.raises(ValueError, match="too long to fit into the model"): + vllm_model = vllm_runner( + model, + max_model_len=128, # LLaVA has a feature size of 576 + enforce_eager=True, + ) + + with vllm_model: + vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]], + max_tokens=1, + images=[images[0]]) diff --git a/tests/models/test_phimoe.py b/tests/models/test_phimoe.py new file mode 100644 index 000000000000..2fb2eecc9467 --- /dev/null +++ b/tests/models/test_phimoe.py @@ -0,0 +1,111 @@ +"""Compare the outputs of HF and vLLM for moe models using greedy sampling. + +Run `pytest tests/models/test_phimoe.py`. +""" +import pytest +import torch + +from vllm.utils import is_cpu + +from .utils import check_logprobs_close + +MODELS = [ + "microsoft/Phi-3.5-MoE-instruct", +] + + +def test_phimoe_routing_function(): + from vllm.model_executor.models.phimoe import phimoe_routing_function + test_case = { + 0: { + "hidden_states": + torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.float32, + requires_grad=False).view(4, 2), + "gating_output": + torch.tensor([0.1, 0.2, 0.3, 0.4], + dtype=torch.float32, + requires_grad=False), + "topk": + 2, + "renormalize": + False, + }, + 1: { + "hidden_states": + torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.float32, + requires_grad=False).view(4, 2), + "gating_output": + torch.tensor([0.4, 0.2, 0.3, 0.4], + dtype=torch.float32, + requires_grad=False), + "topk": + 2, + "renormalize": + False, + } + } + + ground_truth = { + 0: { + "topk_weights": + torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False), + "topk_ids": + torch.tensor([3, 2], dtype=torch.long, requires_grad=False), + }, + 1: { + "topk_weights": + torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False), + "topk_ids": + torch.tensor([0, 3], dtype=torch.long, requires_grad=False), + } + } + + for test_id in test_case: + topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id]) + assert torch.allclose(topk_weights, + ground_truth[test_id]["topk_weights"]) + assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"]) + + +def get_gpu_memory(): + try: + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + gpu_memory = props.total_memory / (1024**3) + return gpu_memory + except Exception: + return 0 + + +@pytest.mark.skipif(condition=is_cpu(), + reason="This test takes a lot time to run on CPU, " + "and vllm CI's disk space is not enough for this model.") +@pytest.mark.skipif(condition=get_gpu_memory() < 100, + reason="Skip this test if GPU memory is insufficient.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_ultravox.py b/tests/models/test_ultravox.py index 98de10aa0840..23008f9b8b56 100644 --- a/tests/models/test_ultravox.py +++ b/tests/models/test_ultravox.py @@ -1,11 +1,9 @@ from typing import List, Optional, Tuple, Type -import librosa import numpy as np import pytest from transformers import AutoModel, AutoTokenizer, BatchEncoding -from vllm.assets.audio import AudioAsset from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -21,6 +19,7 @@ @pytest.fixture(scope="session") def audio_and_sample_rate(): + from vllm.assets.audio import AudioAsset return AudioAsset("mary_had_lamb").audio_and_sample_rate @@ -109,6 +108,7 @@ def process(hf_inputs: BatchEncoding): dtype=dtype, postprocess_inputs=process, auto_cls=AutoModel) as hf_model: + import librosa hf_outputs_per_audio = [ hf_model.generate_greedy_logprobs_limit( diff --git a/tests/models/utils.py b/tests/models/utils.py index ff29a0ae81d6..93ec03995094 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union -from vllm.sequence import SampleLogprobs +from vllm.sequence import Logprob, SampleLogprobs TokensText = Tuple[List[int], str] @@ -38,34 +38,39 @@ def check_outputs_equal( float]], SampleLogprobs]]] +# Allow for tokens to be represented as str's rather than IDs +TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]], + List[Dict[str, + Logprob]]]]] + def check_logprobs_close( *, - outputs_0_lst: Sequence[TokensTextLogprobs], - outputs_1_lst: Sequence[TokensTextLogprobs], + outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], + outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, warn_on_mismatch: bool = True, -): - """ - Compare the logprobs of two sequences generated by different models, + always_check_logprobs: bool = False, +) -> None: + """Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. - Arguments: - - * outputs_0_lst: First sequence to compare - * outputs_0_lst: Second sequence to compare - * name_0: sequence #0 name - * name_1: sequence #1 name - * num_outputs_0_skip_tokens: If > 0, specifies the number of initial + Args: + outputs_0_lst: First sequence to compare + outputs_0_lst: Second sequence to compare + name_0: sequence #0 name + name_1: sequence #1 name + num_outputs_0_skip_tokens: If > 0, specifies the number of initial sequence #0 tokens & logprobs to discard before comparison, i.e. all of sequence #1 will be compared to sequence #0 beginning at index num_outputs_0_skip_tokens - * warn_on_mismatch: Issue a warning if there is token-wise or text-wise + warn_on_mismatch: Issue a warning if there is token-wise or text-wise mismatch between the two sequences + always_check_logprobs: If true, check logprobs even when tokens match """ assert len(outputs_0_lst) == len(outputs_1_lst) @@ -94,8 +99,12 @@ def check_logprobs_close( for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - # If generated tokens don't match, then - if output_id_0 != output_id_1: + is_tok_mismatch = output_id_0 != output_id_1 + + # If generated tokens don't match + # or it is desired to always check logprobs, + # then + if is_tok_mismatch or always_check_logprobs: logprobs_elem_0 = logprobs_0[idx] logprobs_elem_1 = logprobs_1[idx] @@ -111,7 +120,7 @@ def check_logprobs_close( assert output_id_0 in logprobs_elem_1, fail_msg assert output_id_1 in logprobs_elem_0, fail_msg - if warn_on_mismatch: + if warn_on_mismatch and is_tok_mismatch: with warnings.catch_warnings(): # This ensures that repeated warnings are shown # in the output, not just the first occurrence diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index ad99d70d7417..d054ca341694 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -1,10 +1,12 @@ # Test the AsyncLLMEngine with multi-step-decoding -from typing import List +from typing import List, Optional import pytest -from ..utils import RemoteOpenAIServer +from ..models.utils import check_logprobs_close +from ..utils import (completions_with_server_args, get_client_text_generations, + get_client_text_logprob_generations) MODELS = [ "JackFram/llama-160m", @@ -23,22 +25,6 @@ ] -async def completions_with_server_args(prompts: List[str], model_name: str, - server_cli_args: List[str]): - - outputs = None - with RemoteOpenAIServer(model_name, server_cli_args) as server: - async with server.get_async_client() as client: - outputs = await client.completions.create(model=model_name, - prompt=prompts, - temperature=0, - stream=False, - max_tokens=5) - assert outputs is not None - - return outputs - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize(("tp_size, pp_size"), [ (1, 1), @@ -47,10 +33,43 @@ async def completions_with_server_args(prompts: List[str], model_name: str, @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.asyncio -async def test_multi_step(example_prompts, model: str, tp_size: int, - pp_size: int, eager_mode: int, - num_scheduler_steps: int, num_prompts: int): +async def test_multi_step( + example_prompts, + model: str, + tp_size: int, + pp_size: int, + eager_mode: int, + num_scheduler_steps: int, + num_prompts: int, + is_async: bool, + num_logprobs: Optional[int], +) -> None: + """Test vLLM engine with multi-step scheduling in an OpenAI-protocol + client/server environment. + + Set up an engine with single-step scheduling as a ground-truth reference. + + Send a completions API request to both engines with the same prompts. + + Validate: + * Generated tokens match + * Generated logprobs are all very close + + Args: + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + tp_size: degree of tensor-parallelism + pp_size: degree of pipeline-parallelism + eager_mode + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + """ prompts = example_prompts if len(prompts) < num_prompts: @@ -62,9 +81,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ms_server_args = DEFAULT_SERVER_ARGS + \ ["--num-scheduler-steps", f"{num_scheduler_steps}"] - # Disable output proc callback as its not supported - # with multi-step right now - ms_server_args += ["--disable-async-output-proc"] + if not is_async: + ms_server_args += ["--disable-async-output-proc"] + if eager_mode: ms_server_args.append("--enforce-eager") @@ -75,14 +94,36 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, str(pp_size), ] + # Spin up client/server & issue completion API requests. + # Default `max_wait_seconds` is 240 but was empirically + # was raised 3x to 720 *just for this test* due to + # observed timeouts in GHA CI ref_completions = await completions_with_server_args( - prompts, model, server_args + distributed_args) + prompts, + model, + server_args + distributed_args, + num_logprobs, + max_wait_seconds=3 * 240) test_completions = await completions_with_server_args( - prompts, model, ms_server_args + distributed_args) - - def get_text_generations(completions): - return [x.text for x in completions.choices] - - ref_generations = get_text_generations(ref_completions) - test_generations = get_text_generations(test_completions) + prompts, + model, + ms_server_args + distributed_args, + num_logprobs, + max_wait_seconds=3 * 240) + + # Assert multi-step scheduling produces identical tokens + # to single-step scheduling. + ref_generations = get_client_text_generations(ref_completions) + test_generations = get_client_text_generations(test_completions) assert ref_generations == test_generations + + # Assert multi-step scheduling produces nearly-identical logprobs + # to single-step scheduling. + ref_text_logprobs = get_client_text_logprob_generations(ref_completions) + test_text_logprobs = get_client_text_logprob_generations(test_completions) + check_logprobs_close( + outputs_0_lst=ref_text_logprobs, + outputs_1_lst=test_text_logprobs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 36f610ba74f0..50c85df932e2 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -1,8 +1,10 @@ # Test the LLMEngine with multi-step-decoding +from typing import Optional + import pytest -from ..models.utils import check_outputs_equal +from ..models.utils import check_logprobs_close, check_outputs_equal MODELS = [ "JackFram/llama-160m", @@ -18,10 +20,45 @@ @pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, tp_size: int, max_tokens: int, - enforce_eager: int, num_scheduler_steps: int, - num_prompts: int) -> None: +@pytest.mark.parametrize("num_logprobs", [None, 5]) +def test_multi_step_llm( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], +) -> None: + """Test vLLM engine with multi-step scheduling via sync LLM Engine. + + Set up a HuggingFace (HF) transformers model as a ground-truth reference. + + Prompt them with the same example prompts. + + Validate: + * Generated tokens match + * Generated logprobs are all very close + + Args: + hf_runner: HF transformers model runner fixture + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + """ prompts = example_prompts if len(prompts) < num_prompts: @@ -29,21 +66,37 @@ def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, prompts = prompts[:num_prompts] assert len(prompts) == num_prompts - with vllm_runner(model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7, - tensor_parallel_size=tp_size, - use_v2_block_manager=True, - num_scheduler_steps=num_scheduler_steps) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs)) with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) + hf_outputs = (hf_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + hf_model.generate_greedy_logprobs_limit( + prompts, max_tokens, num_logprobs)) + + if num_logprobs is None: + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + else: + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py index f19a0f33fe06..e9562d2048f0 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_base.py @@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists(): result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])}) + + +def test_multimodal_input_batch_mixed_stacking_depths(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 3, 3]) + c = torch.rand([1, 4, 3]) + + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]}) + + result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}]) + assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]}) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index b760e9ccb6b7..3f0c6cbc051a 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -2,85 +2,115 @@ Run `pytest tests/quantization/test_bitsandbytes.py`. ''' + +import gc + import pytest import torch from tests.quantization.utils import is_quant_method_supported -from vllm import SamplingParams -models_to_test = [ +models_4bit_to_test = [ ('huggyllama/llama-7b', 'quantize model inflight'), - ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'), ] +models_pre_qaunt_4bit_to_test = [ + ('lllyasviel/omost-llama-3-8b-4bits', + 'read pre-quantized 4-bit NF4 model'), + ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', + 'read pre-quantized 4-bit FP4 model'), +] + +models_pre_quant_8bit_to_test = [ + ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'), +] + + +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_test) +def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = {"load_in_4bit": True} + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name, hf_model_kwargs) + + +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", + models_pre_qaunt_4bit_to_test) +def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name) + @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", models_to_test) -def test_load_bnb_model(vllm_runner, model_name, description) -> None: +@pytest.mark.parametrize("model_name, description", + models_pre_quant_8bit_to_test) +def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name) + + +def log_generated_texts(prompts, outputs, runner_name): + logged_texts = [] + for i, (_, generated_text) in enumerate(outputs): + log_entry = { + "prompt": prompts[i], + "runner_name": runner_name, + "generated_text": generated_text, + } + logged_texts.append(log_entry) + return logged_texts + + +def validate_generated_texts(hf_runner, + vllm_runner, + prompts, + model_name, + hf_model_kwargs=None): + + if hf_model_kwargs is None: + hf_model_kwargs = {} + + # Run with HF runner + with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + hf_outputs = llm.generate_greedy(prompts, 8) + hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") + + # Clean up the GPU memory for the next test + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + #Run with vLLM runner with vllm_runner(model_name, quantization='bitsandbytes', load_format='bitsandbytes', - enforce_eager=True) as llm: - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 - - # check the weights in MLP & SelfAttention are quantized to torch.uint8 - qweight = model.model.layers[0].mlp.gate_up_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].mlp.down_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].self_attn.o_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].self_attn.qkv_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}') - - # some weights should not be quantized - weight = model.lm_head.weight - assert weight.dtype != torch.uint8, ( - 'lm_head weight dtype should not be torch.uint8') - - weight = model.model.embed_tokens.weight - assert weight.dtype != torch.uint8, ( - 'embed_tokens weight dtype should not be torch.uint8') - - weight = model.model.layers[0].input_layernorm.weight - assert weight.dtype != torch.uint8, ( - 'input_layernorm weight dtype should not be torch.uint8') - - weight = model.model.layers[0].post_attention_layernorm.weight - assert weight.dtype != torch.uint8, ( - 'input_layernorm weight dtype should not be torch.uint8') - - # check the output of the model is expected - sampling_params = SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=8) - - prompts = ['That which does not kill us', 'To be or not to be,'] - expected_outputs = [ - 'That which does not kill us makes us stronger.', - 'To be or not to be, that is the question.' - ] - outputs = llm.generate(prompts, sampling_params=sampling_params) - assert len(outputs) == len(prompts) - - for index in range(len(outputs)): - # compare the first line of the output - actual_output = outputs[index][1][0].split('\n', 1)[0] - expected_output = expected_outputs[index].split('\n', 1)[0] - - assert len(actual_output) >= len(expected_output), ( - f'Actual {actual_output} should be larger than or equal to ' - f'expected {expected_output}') - actual_output = actual_output[:len(expected_output)] - - assert actual_output == expected_output, ( - f'Expected: {expected_output}, but got: {actual_output}') + enforce_eager=True, + gpu_memory_utilization=0.8) as llm: + vllm_outputs = llm.generate_greedy(prompts, 8) + vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") + + # Clean up the GPU memory for the next test + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + # Compare the generated strings + for hf_log, vllm_log in zip(hf_logs, vllm_logs): + hf_str = hf_log["generated_text"] + vllm_str = vllm_log["generated_text"] + prompt = hf_log["prompt"] + assert hf_str == vllm_str, (f"Model: {model_name}" + f"Mismatch between HF and vLLM outputs:\n" + f"Prompt: {prompt}\n" + f"HF Output: '{hf_str}'\n" + f"vLLM Output: '{vllm_str}'") diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 3ce4a5f65819..91a9d879eb4a 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -44,12 +44,16 @@ def mock_causal_accepted_tensor( ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, - disable_bonus_tokens: bool, seed: int, - device: str): +def test_correct_output_format(which_tokens_accepted: str, seed: int, + disable_bonus_tokens: bool, device: str, + use_flashinfer: bool): """Verify the output has correct format given predetermined accepted matrix. """ + if use_flashinfer and disable_bonus_tokens: + pytest.skip("Flashinfer rejection sampler must enable bonus token.") + set_random_seed(seed) torch.set_default_device(device) @@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str, dtype=torch.int64) rejection_sampler = RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens) + disable_bonus_tokens=disable_bonus_tokens, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, @@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str, @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("batch_size", list(range(1, 32))) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, - device: str): + device: str, use_flashinfer: bool): torch.set_default_device(device) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, @pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) @pytest.mark.parametrize("n_rep", [100]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, - frac_seeded: float, n_rep: int, - device: str): + frac_seeded: float, n_rep: int, device: str, + use_flashinfer: bool): torch.set_default_device(device) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, assert torch.equal(results[j][i], results[0][i]) +@pytest.mark.parametrize("k", [1, 3, 6]) +@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) +@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_compare_nonflashinfer_backend(k: int, vocab_size: int, + batch_size: int, device: str): + """ + Test the flashinfer and nonflashinfer backend generate + the same output metrics. + """ + torch.set_default_device(device) + torch.manual_seed(0) + draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + + num_accepted_tokens = [] + num_emitted_tokens = [] + num_draft_tokens = [] + + def get_seeded_seqs(): + return { + i: torch.Generator(device=device).manual_seed(i) + for i in range(batch_size) + } + + for use_flashinfer in [True, False]: + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) + rejection_sampler.init_gpu_tensors(device=device) + # We use seeded sequences to ensure the same tokens are accepted + # for both flashinfer and nonflashinfer backends. + seeded_seqs = get_seeded_seqs() + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids, seeded_seqs) + num_accepted_tokens.append(rejection_sampler.num_accepted_tokens) + num_emitted_tokens.append(rejection_sampler.num_emitted_tokens) + num_draft_tokens.append(rejection_sampler.num_draft_tokens) + + assert num_accepted_tokens[0] == num_accepted_tokens[1] + assert num_emitted_tokens[0] == num_emitted_tokens[1] + assert num_draft_tokens[0] == num_draft_tokens[1] + + @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) @pytest.mark.parametrize("which_token_ids", ["bonus_token_ids", "draft_token_ids"]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_raises_when_vocab_oob(above_or_below_vocab_range: str, - which_token_ids: str, device: str): + which_token_ids: str, device: str, + use_flashinfer: bool): k = 3 batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - rejection_sampler = RejectionSampler(strict_mode=True) + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer, + strict_mode=True) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) @pytest.mark.parametrize("seed", list(range(5))) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_rejection_sampling_approximates_target_distribution( - seed: int, draft_and_target_probs_equal: bool): + seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool): """Verify rejection sampling approximates target distribution, despite sampling from a potentially distinct draft distribution. @@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution( """ torch.set_default_device("cpu") set_random_seed(seed) - helper = _CorrectnessTestHelper( vocab_size=10, - rejection_sampler=RejectionSampler(), + rejection_sampler=RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer), ) draft_probs, target_probs, reference_probs = helper.generate_probs_for_test( @@ -398,10 +476,10 @@ def _estimate_rejection_sampling_pdf( draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat( num_samples, 1, 1) - # Repeat target probs num_samples * k times. + # Repeat target probs num_samples * (k + 1) times. # Rejection sampler requires bonus token probs, but they aren't used. target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat( - num_samples, self.k, 1) + num_samples, self.k + 1, 1) # Randomly sample draft token ids from draft probs. draft_token_ids = torch.multinomial(draft_probs[:, 0, :], diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index aa3c1d29bdb3..e81ec4a0fdf1 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler() typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, size=(batch_size, k), dtype=torch.int64) # Verify that sampling succeeds for all cases. - typical_acceptance_sampler(target_probs, + typical_acceptance_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, oob_token_ids[0][0] = rogue_token_id with pytest.raises(AssertionError): - typical_acceptance_sampler(target_probs, + typical_acceptance_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens( typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) draft_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, k), @@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens( size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int, # Simulate temperature 0 probability distribution for target probabilities # and create target probabilities such that only 1 token id has # probability 1.0 - target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( - batch_size, k, vocab_size) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] # Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0 draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, @@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int, # fallback to the greedy sampling for selecting 1 token for each sequence. # Verify the same. output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, # For sequences 0 and 2 set the distribution to a temperature # zero distribution. For sequences 1 and 3 set it to a uniform # distribution. - target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] + target_probs = target_with_bonus_probs[:, :-1] draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, zero_temperature_token_ids) uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) @@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, # Create a temperature zero target probability distribution and ensure # all draft token ids correspond to the tokens with 1.0 probability. # Verify that all of them are accepted. - target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] draft_token_ids = zero_temperature_token_ids bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, draft_token_ids = torch.cat( (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, # 0.00001. Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0. Without any changes to the posterior thresholds # none of the draft tokens are accepted. - target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( + batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] target_probs[target_probs == 0] = 0.00001 draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, zero_temperature_token_ids) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index ada6c37d9af8..e7a0af437763 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -5,9 +5,10 @@ import pytest import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, - SamplerOutput, get_all_seq_ids) + get_all_seq_ids) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 9ae1b4bc40f0..501d05756e01 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -7,8 +7,9 @@ import pytest import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput +from vllm.sequence import ExecuteModelRequest, SequenceOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) @@ -229,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, assert torch.equal(actual.bonus_token_ids, target_token_ids.reshape(batch_size, k + 1)[:, -1:]) - assert torch.equal( - actual.target_probs, - target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) + assert torch.equal(actual.target_with_bonus_probs, + target_token_probs.reshape(batch_size, k + 1, -1)) assert torch.equal(actual.draft_token_ids, proposal_token_ids) assert torch.equal(actual.draft_probs, proposal_probs) diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 06780d4b8cd0..195fce64822b 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -4,10 +4,12 @@ import torch from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.sampler import _get_ranks from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids -from vllm.spec_decode.util import split_batch_by_proposal_len +from vllm.spec_decode.util import (get_sampled_token_logprobs, + split_batch_by_proposal_len) def test_get_all_seq_ids(): @@ -126,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method): return sampler else: raise ValueError(f"Invalid sampler name {acceptance_sampler_method}") + + +def test_get_sampled_token_logprobs(): + """Verify get_sampled_token_logprobs returns consistent rankings + with regular get_ranks when probabilities match exactly. + """ + logprob_tensor = torch.tensor( + [[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size) + sampled_token_tensor = torch.tensor([[1, + 0]]) # shape (num_steps, batch_size) + ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor, + sampled_token_tensor) + + ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)), + sampled_token_tensor.reshape(-1)) + + assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 60b36a33d907..9075a433eb66 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -8,12 +8,12 @@ import torch from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, CompletionSequenceGroupOutput, Logprob, - SamplerOutput, SequenceData, SequenceGroupMetadata, - SequenceOutput) + SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 1ae349e808e0..348ba7dd41d9 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,9 +2,10 @@ import pytest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, SamplerOutput, - SequenceData, SequenceOutput) + CompletionSequenceGroupOutput, SequenceData, + SequenceOutput) from .core.utils import create_dummy_prompt diff --git a/tests/test_utils.py b/tests/test_utils.py index c157be1c08f8..c7cb663068c0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -132,6 +132,16 @@ def parser(): return parser +@pytest.fixture +def parser_with_config(): + parser = FlexibleArgumentParser() + parser.add_argument('serve') + parser.add_argument('--config', type=str) + parser.add_argument('--port', type=int) + parser.add_argument('--tensor-parallel-size', type=int) + return parser + + def test_underscore_to_dash(parser): args = parser.parse_args(['--image_input_type', 'pixel_values']) assert args.image_input_type == 'pixel_values' @@ -176,3 +186,37 @@ def test_missing_required_argument(parser): parser.add_argument('--required-arg', required=True) with pytest.raises(SystemExit): parser.parse_args([]) + + +def test_cli_override_to_config(parser_with_config): + args = parser_with_config.parse_args([ + 'serve', '--config', './data/test_config.yaml', + '--tensor-parallel-size', '3' + ]) + assert args.tensor_parallel_size == 3 + args = parser_with_config.parse_args([ + 'serve', '--tensor-parallel-size', '3', '--config', + './data/test_config.yaml' + ]) + assert args.tensor_parallel_size == 3 + + +def test_config_args(parser_with_config): + args = parser_with_config.parse_args( + ['serve', '--config', './data/test_config.yaml']) + assert args.tensor_parallel_size == 2 + + +def test_config_file(parser_with_config): + with pytest.raises(FileNotFoundError): + parser_with_config.parse_args(['serve', '--config', 'test_config.yml']) + + with pytest.raises(ValueError): + parser_with_config.parse_args( + ['serve', '--config', './data/test_config.json']) + + with pytest.raises(ValueError): + parser_with_config.parse_args([ + 'serve', '--tensor-parallel-size', '3', '--config', '--batch-size', + '32' + ]) diff --git a/tests/tpu/__init__.py b/tests/tpu/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 5a432fb78b3d..d8df86b2aaa1 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,6 +5,10 @@ import depyf +# disable custom dispatcher, let Dynamo takes over +# all the control +os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0" + temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): cur_dir = os.path.dirname(__file__) @@ -16,19 +20,36 @@ compiled_code = sorted( glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) -full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0] + # we should only trigger Dynamo compilation three times: -# one for the profiling phase (and the compiled artifact will be discarded) +# one for the profiling phase without kv cache # one for the prefill phase with symbolic shapes # one for the decode phase with symbolic shapes # and later calls should not trigger Dynamo compilation again. # NOTE: it might still trigger XLA compilation. # check we have three compiled code +# this is the assumption when we use the custom dispatcher assert len(compiled_code) == 3 -# check the first compilation is discarded -with open(full_code) as f: - full_code_content = f.read() - profile_function = compiled_code[0].split(".")[0] - assert profile_function not in full_code_content +# check all the compilations are as expected +compiled_fn = sorted( + glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py"))) + +# the first compilation is the profiling phase, +# it should not have any kv cache +with open(compiled_fn[0]) as f: + content = f.read() + assert "kv_caches" not in content + +# the second compilation is the prefill phase, +# it should have kv cache and the flash_attention op +with open(compiled_fn[1]) as f: + content = f.read() + assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content + +# the third compilation is the decode phase, +# it should have kv cache and the paged_attention op +with open(compiled_fn[2]) as f: + content = f.read() + assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py new file mode 100644 index 000000000000..7f3fb595321a --- /dev/null +++ b/tests/tpu/test_custom_dispatcher.py @@ -0,0 +1,9 @@ +from ..utils import compare_two_settings + + +def test_custom_dispatcher(): + compare_two_settings("google/gemma-2b", + arg1=["--enforce-eager"], + arg2=["--enforce-eager"], + env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}, + env2={}) diff --git a/tests/utils.py b/tests/utils.py index 7f181a21b4de..f33340f0c755 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,9 +11,11 @@ import openai import requests +from openai.types.completion import Completion from transformers import AutoTokenizer from typing_extensions import ParamSpec +from tests.models.utils import TextTextLogprobs from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs @@ -432,3 +434,61 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: f" args {args} and kwargs {kwargs}") return wrapper + + +async def completions_with_server_args( + prompts: List[str], + model_name: str, + server_cli_args: List[str], + num_logprobs: Optional[int], + max_wait_seconds: int = 240, +) -> Completion: + '''Construct a remote OpenAI server, obtain an async client to the + server & invoke the completions API to obtain completions. + + Args: + prompts: test prompts + model_name: model to spin up on the vLLM server + server_cli_args: CLI args for starting the server + num_logprobs: Number of logprobs to report (or `None`) + max_wait_seconds: timeout interval for bringing up server. + Default: 240sec + + Returns: + OpenAI Completion instance + ''' + + outputs = None + with RemoteOpenAIServer(model_name, + server_cli_args, + max_wait_seconds=max_wait_seconds) as server: + client = server.get_async_client() + outputs = await client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=5, + logprobs=num_logprobs) + assert outputs is not None + + return outputs + + +def get_client_text_generations(completions: Completion) -> List[str]: + '''Extract generated tokens from the output of a + request made to an Open-AI-protocol completions endpoint. + ''' + return [x.text for x in completions.choices] + + +def get_client_text_logprob_generations( + completions: Completion) -> List[TextTextLogprobs]: + '''Operates on the output of a request made to an Open-AI-protocol + completions endpoint; obtains top-rank logprobs for each token in + each :class:`SequenceGroup` + ''' + text_generations = get_client_text_generations(completions) + text = ''.join(text_generations) + return [(text_generations, text, + (None if x.logprobs is None else x.logprobs.top_logprobs)) + for x in completions.choices] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ae90af563c0c..fe254732e730 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm._core_ext import ScalarType from vllm.logger import init_logger from vllm.platforms import current_platform @@ -177,12 +178,20 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_dequantize_triton) + return awq_dequantize_triton(qweight, scales, zeros) return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_gemm_triton) + return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) @@ -491,6 +500,36 @@ def ggml_mul_mat_a8( return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + seq_idx_: Optional[torch.Tensor], + initial_states_: Optional[torch.Tensor], + final_states_out_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, + initial_states_, final_states_out_, + silu_activation) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, index_: Optional[torch.Tensor], + x: Optional[torch.Tensor]) -> List[torch.Tensor]: + return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, + delta_bias_, delta_softplus, index_, + x) + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a8d76b79ff20..aa9d4a71dbf8 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -83,6 +83,15 @@ def copy_blocks( def get_supported_head_sizes() -> List[int]: return [64, 128, 256] + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + class FlashInferState(AttentionState): @@ -177,8 +186,12 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_indices_buffer, _last_page_len_buffer, "NHD", use_tensor_cores) - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) paged_kv_indptr_tensor_host = torch.arange(0, batch_size + 1, @@ -366,7 +379,8 @@ def prefill_metadata(self) -> Optional["FlashInferMetadata"]: def decode_metadata(self) -> Optional["FlashInferMetadata"]: # Currently chunked prefill is not supported if self.num_prefills > 0: - assert self.num_decode_tokens == 0 + assert self.num_decode_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") return None return self @@ -576,8 +590,13 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + return FlashInferMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -661,7 +680,6 @@ def forward( if attn_metadata.num_decode_tokens > 0: assert attn_metadata.num_prefill_tokens == 0, ( "Chunked prefill is not supported with flashinfer yet.") - if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( @@ -674,6 +692,12 @@ def forward( k_scale, v_scale, ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if self.kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) query = query.contiguous( ) # Flashinfer requires query to be contiguous @@ -711,5 +735,7 @@ def forward( query, kv_cache, sm_scale=self.scale, - logits_soft_cap=self.logits_soft_cap) + logits_soft_cap=self.logits_soft_cap, + k_scale=k_scale, + v_scale=v_scale) return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index ac03b6d8b1ea..83fdef16ef5c 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -123,7 +123,13 @@ def __init__( raise NotImplementedError("TPU version must be 4 or higher.") self.megacore_mode = None - tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + if "lite" not in tpu_type: if self.num_kv_heads % 2 == 0: self.megacore_mode = "kv_head" diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 54558fc2d7e5..855586d4e596 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -226,6 +226,10 @@ def which_attn_to_use( elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): logger.info( "Cannot use FlashAttention-2 backend for FP8 KV cache.") + logger.warning( + "Please use FlashInfer backend with FP8 KV Cache for " + "better performance by setting environment variable " + "VLLM_ATTENTION_BACKEND=FLASHINFER") selected_backend = _Backend.XFORMERS elif block_size % 16 != 0: logger.info( diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py new file mode 100644 index 000000000000..c3d863299dd0 --- /dev/null +++ b/vllm/compilation/wrapper.py @@ -0,0 +1,81 @@ +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from types import CodeType +from typing import Callable, List + +import torch + +import vllm.envs as envs + + +class TorchCompileWrapperWithCustomDispacther: + """ + A wrapper class for torch.compile, with a custom dispatch logic. + Subclasses should: + 1. Implement the forward method + 2. Implement the dispatch logic in the __call__ method + It can use `self.compiled_codes` to access the compiled bytecode, + and `with self.dispatch_to_code(index):` to dispatch to + the compiled code. + 3. Implement the `__init__` method to determine how to call + `torch.compile` over the forward method. + """ + + def __init__(self, compiled_callable: Callable): + self.compiled_callable = compiled_callable + self.original_code_object = self.__class__.forward.__code__ + self.compiled_codes: List[CodeType] = [] + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + + # read the env var to determine whether to use the custom dispatcher + # subclasses can use this to switch between the custom dispatcher + # and the default Dynamo guard mechanism. + self.use_custom_dispatcher: bool = \ + envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER + + def __call__(self, *args, **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level. + NOTE: this function can have additional arguments beyond the forward + method, for directly dispatching to the compiled code. + """ + return self.compiled_callable(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + ... + + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + """Hook to save the compiled bytecode for direct execution.""" + if old_code is not self.original_code_object: + return + # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 + frame = sys._getframe() + while True: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == old_code + + if frame.f_locals["self"] is not self: + return + + self.compiled_codes.append(new_code) + + @contextmanager + def dispatch_to_code(self, index: int): + """Context manager to dispatch to the compiled code. + Why does this work? Because Dynamo guarantees that the compiled + bytecode has exactly the same arguments, cell variables, and free + variables as the original code. Therefore we can directly switch + the code object in the function and call it. + + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/config.py b/vllm/config.py index 0339024296d7..66889bf0afb9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -32,6 +32,7 @@ logger = init_logger(__name__) _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _PP_SUPPORTED_MODELS = [ "AquilaModel", "AquilaForCausalLM", "DeepseekV2ForCausalLM", @@ -258,7 +259,7 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["gptq", "squeezellm", "fp8"] + rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"] optimized_quantization_methods = [ "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", @@ -313,6 +314,12 @@ def _verify_quantization(self) -> None: "%s quantization is not fully " "optimized yet. The speed can be slower than " "non-quantized models.", self.quantization) + if (self.quantization == "awq" and is_hip() + and not envs.VLLM_USE_TRITON_AWQ): + logger.warning( + "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" + " is not set, enabling VLLM_USE_TRITON_AWQ.") + envs.VLLM_USE_TRITON_AWQ = True def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: @@ -332,10 +339,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - if device_config.device_type != "cuda": + if device_config.device_type not in ("cuda", "tpu"): logger.warning( - "Async output processing is only supported for CUDA." - " Disabling it for other platforms.") + "Async output processing is only supported for CUDA or TPU. " + "Disabling it for other platforms.") self.use_async_output_proc = False return @@ -390,6 +397,8 @@ def verify_with_parallel_config( raise ValueError( "BitAndBytes quantization with TP or PP is not supported yet.") + # Remove the constraint after the bitsandbytes issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 if self.quantization == "bitsandbytes" and self.enforce_eager is False: logger.warning("CUDA graph is not supported on BitAndBytes yet, " "fallback to the eager mode.") @@ -554,6 +563,10 @@ def is_embedding_model(self) -> bool: """Extract the embedding model flag.""" return self.embedding_mode + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + class CacheConfig: """Configuration for the KV cache. @@ -930,25 +943,36 @@ def __init__(self, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, - embedding_mode: Optional[bool] = False, + embedding_mode: bool = False, + is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, send_delta_data: bool = False) -> None: - if max_num_batched_tokens is not None: - self.max_num_batched_tokens = max_num_batched_tokens - else: + if max_num_batched_tokens is None: if enable_chunked_prefill: # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. - self.max_num_batched_tokens = 512 - elif embedding_mode: - # For embedding, choose specific value for higher throughput - self.max_num_batched_tokens = max( - max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) + max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + max_num_batched_tokens = max(max_model_len, 2048) + + if embedding_mode: + # For embedding, choose specific value for higher throughput + max_num_batched_tokens = max( + max_num_batched_tokens, + _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + max_num_batched_tokens = max( + max_num_batched_tokens, + _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + self.max_num_batched_tokens = max_num_batched_tokens + if enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 51fde6e4eb7a..4c2f71582031 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1107,10 +1107,7 @@ def schedule( if not self.cache_config.enable_prefix_caching: common_computed_block_nums = [] - # TODO: Combine multi-step and async postprocessor - allow_async_output_proc: bool = ( - self.use_async_output_proc - and not self.scheduler_config.is_multi_step) + allow_async_output_proc: bool = self.use_async_output_proc # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 81a141e86206..765a0f9cb1c8 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,3 +1,5 @@ +import os + import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -5,11 +7,12 @@ from vllm.platforms import current_platform if current_platform.is_tpu(): - import ray import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from torch_xla._internal import pjrt + from vllm.executor import ray_utils + class TpuCommunicator: @@ -24,9 +27,29 @@ def __init__(self, group: ProcessGroup): # be simply calculated as follows. global_rank = dist.get_rank(group) global_world_size = dist.get_world_size(group) - num_nodes = len(ray.nodes()) + + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + local_world_size = global_world_size // num_nodes local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + pjrt.initialize_multiprocess(local_rank, local_world_size) xr._init_world_size_ordinal() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6e66198e203f..8dbe6504d21b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -16,6 +16,7 @@ from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import FlexibleArgumentParser if TYPE_CHECKING: @@ -753,7 +754,7 @@ def from_cli_args(cls, args: argparse.Namespace): def create_engine_config(self) -> EngineConfig: # gguf file needs a specific model loader and doesn't use hf_repo - if self.model.endswith(".gguf"): + if check_gguf_file(self.model): self.quantization = self.load_format = "gguf" # bitsandbytes quantization needs a specific model loader @@ -921,6 +922,7 @@ def create_engine_config(self) -> EngineConfig: delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, + is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37696bf1d9dc..159281dabde4 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -22,11 +22,12 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -279,6 +280,10 @@ async def step_async( scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # skip the scheduler if there are any remaining steps in the seq groups. @@ -289,17 +294,27 @@ async def step_async( # Clear outputs on scheduler iteration start ctx.request_outputs.clear() + # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # If current scheduler iteration has no async postprocessor, - # then we need first to drain the pending async postprocessor - # before moving forward + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -311,9 +326,6 @@ async def step_async( assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -339,8 +351,13 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ - virtual_engine] + async_callback = self.async_callback_multi_step[ + virtual_engine] if use_async_and_multi_step \ + else self.async_callback[virtual_engine] + + execute_model_req.async_callback = async_callback + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step # Execute the model. output = await self.model_executor.execute_model_async( @@ -350,7 +367,7 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -362,22 +379,25 @@ async def step_async( seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # Clear the cache if we have finished all the steps if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - # Cache results in engine - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: self._process_model_outputs(virtual_engine=virtual_engine, @@ -390,7 +410,11 @@ async def step_async( self.do_tracing(scheduler_outputs) else: - ctx.request_outputs = [] + # Multi-step case + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a6de8817946c..1eab83f3b988 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -33,6 +33,7 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) @@ -40,8 +41,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) + Sequence, SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -91,7 +92,8 @@ class SchedulerOutputState: @dataclass class SchedulerContext: - output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata], + output_queue: Deque[Tuple[Optional[List[SamplerOutput]], + List[SequenceGroupMetadata], SchedulerOutputs]] = field( default_factory=lambda: deque()) @@ -432,6 +434,13 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for v_id in range(self.parallel_config.pipeline_parallel_size) ] + self.async_callback_multi_step = [ + functools.partial(self._process_model_outputs, + virtual_engine=v_id, + is_async=False) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1240,28 +1249,49 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, virtual_engine: int, - is_async: bool) -> None: + def _process_model_outputs(self, + virtual_engine: int, + is_async: bool, + sampler_output: Optional[SamplerOutput] = None, + is_last_output: bool = False) -> None: """Apply the model output to the sequences in the scheduled seq groups. virtual_engine: The engine id to operate on + is_async: Indicates whether this postprocessor runs in parallel with the GPU forward pass and is processing tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) + sampler_output: Used with multi-step execution to provide + sampler_output of each step + is_last_output: Used with multi-step execution to indicate + the last step (of each multi-step group) + Returns RequestOutputs that can be returned to the client. """ now = time.time() + is_multi_step = sampler_output is not None + ctx: SchedulerContext = self.scheduler_contexts[virtual_engine] if len(ctx.output_queue) == 0: return None - (outputs, seq_group_metadata_list, - scheduler_outputs) = ctx.output_queue.popleft() + if is_multi_step: + # Async + multi-step case + (outputs, seq_group_metadata_list, + scheduler_outputs) = ctx.output_queue[0] + assert outputs is None + outputs = [sampler_output] + else: + # Async standard case + (outputs, seq_group_metadata_list, + scheduler_outputs) = ctx.output_queue.popleft() + + assert outputs is not None # Sanity check assert len(seq_group_metadata_list) == len( @@ -1320,7 +1350,11 @@ def _process_model_outputs(self, virtual_engine: int, self.output_processor.process_outputs(seq_group, output, is_async) - # Free the finished sequence groups. + # For async + multi-step, free finished seqs and create outputs + # only on the final step. + if is_multi_step and not is_last_output: + return + for scheduler in self.scheduler: scheduler.free_finished_seq_groups() @@ -1328,7 +1362,7 @@ def _process_model_outputs(self, virtual_engine: int, for i, _ in enumerate(seq_group_metadata_list): scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - if i in finished_before: + if not is_multi_step and i in finished_before: continue # Avoids double processing seq_group = scheduled_seq_group.seq_group @@ -1342,7 +1376,11 @@ def _process_model_outputs(self, virtual_engine: int, request_output = RequestOutputFactory.create(seq_group) ctx.request_outputs.append(request_output) - if is_async: + # For async + multi-step, do stats only on the last output. + # Otherwise, do stats if the execution is async + do_stats = is_multi_step or is_async + + if do_stats: # Log stats. self.do_log_stats(scheduler_outputs, outputs, finished_before) @@ -1437,7 +1475,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "as performance will be severely degraded otherwise.") # For llm_engine, there is no pipeline parallel support, so the engine - # used is always 0 + # used is always 0. virtual_engine = 0 # These are cached outputs from previous iterations. None if on first @@ -1447,6 +1485,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # Skip the scheduler if there are any remaining steps in the seq groups. @@ -1462,11 +1504,22 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -1478,9 +1531,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -1505,8 +1555,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ - virtual_engine] + async_callback = self.async_callback_multi_step[ + virtual_engine] if use_async_and_multi_step \ + else self.async_callback[virtual_engine] + + execute_model_req.async_callback = async_callback + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step output = self.model_executor.execute_model( execute_model_req=execute_model_req) @@ -1518,7 +1573,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. - if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -1535,18 +1590,23 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() - # Add results to the output_queue - # (for async or non-async postprocessing) - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + # Add results to the output_queue + # (for async or non-async postprocessing) + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len(output) == 1, ("Multi step decoding does not work " - "with async output processing.") + if output and allow_async_output_proc: + assert len(output) == 1, ( + "Multi step decoding does not work " + "with async output processing.") - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path if not allow_async_output_proc: @@ -1560,7 +1620,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.do_tracing(scheduler_outputs) else: # Multi-step case - ctx.request_outputs = [] + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) @@ -1948,7 +2011,26 @@ def is_embedding_model(self): def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): - prompt_key = "encoder_prompt_token_ids" \ - if self.is_encoder_decoder_model() else "prompt_token_ids" - if not inputs.get(prompt_key): + if self.is_encoder_decoder_model(): + prompt_ids = inputs.get("encoder_prompt_token_ids") + else: + prompt_ids = inputs.get("prompt_token_ids") + + if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). " + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 49a33ded5fca..e182cee8ba18 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -4,6 +4,8 @@ from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) +from vllm.engine.output_processor.single_step import ( + single_step_process_prompt_logprob) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams @@ -46,9 +48,16 @@ def __init__( def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: - # TODO(sang): Prompt logprob currently not implemented in multi step - # workers. - self._log_prompt_logprob_unsupported_warning_once() + """Process prompt logprobs associated with each step of a multi-step- + scheduled computation. + + Args: + seq_group: the outputs are associated with this :class:`SequenceGroup` + outputs: the :class:`SequenceGroupOutput`s for all scheduler steps + """ + for output in outputs: + # Concatenate single-step prompt logprob processing results. + single_step_process_prompt_logprob(self, seq_group, output) @staticmethod @functools.lru_cache() @@ -79,9 +88,15 @@ def process_outputs(self, # TODO: Add support for async if necessary assert not is_async + # Sequences can be in RUNNING or FINISHED_ABORTED state + # once scheduled, as a sequence is moved to FINSIHED_ABORTED + # if a client disconnects from the api server. seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) + if seqs is None: + seqs = sequence_group.get_seqs( + status=SequenceStatus.FINISHED_ABORTED) - assert seqs, "expected running sequences" + assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" assert len(seqs) == 1, ( "Beam search not supported in multi-step decoding.") seq = seqs[0] diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4b0c3f37a5e2..422e6d30522f 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -15,6 +15,44 @@ logger = init_logger(__name__) +def single_step_process_prompt_logprob( + sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, + output: SequenceGroupOutput) -> None: + """Process prompt logprobs associated with the :class:`SequenceGroupOutput` + for a given step. + + Do nothing if the output has no prompt logprobs. + + Account for the fact that transformers do not compute first-token logprobs. + + Args: + sg_output_proc: :class:`SequenceGroupOutputProcessor` instance + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ + prompt_logprobs = output.prompt_logprobs + + # If this is the first (or only) "chunk" of the prefill, we need + # to prepend None to the list of prompt logprobs. The reason for this + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not + # have a logprob associated with it. + if prompt_logprobs is not None: + if not seq_group.prompt_logprobs: + prompt_logprobs = [None] + prompt_logprobs + seq_group.prompt_logprobs = [] + + assert hasattr(sg_output_proc, 'detokenizer') + if (seq_group.sampling_params.detokenize + and sg_output_proc.detokenizer): + sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( + seq_group, + prompt_logprobs, + position_offset=len(seq_group.prompt_logprobs)) + + seq_group.prompt_logprobs.extend(prompt_logprobs) + + class SingleStepOutputProcessor(SequenceGroupOutputProcessor): """SequenceGroupOutputProcessor which handles "output processing" logic, which happens after the model returns generated token ids and before @@ -60,27 +98,16 @@ def process_outputs(self, sequence_group: SequenceGroup, def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with one step of a single-step- + scheduled computation. + + Args: + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ assert len(outputs) == 1, ("Single step should only has 1 output.") output = outputs[0] - prompt_logprobs = output.prompt_logprobs - - # If this is the first (or only) "chunk" of the prefill, we need - # to prepend None to the list of prompt logprobs. The reason for this - # is that for N prompt tokens, the Sampler will generate N-1 total - # prompt logprobs during prefill since the token at idx 0 will not - # have a logprob associated with it. - if prompt_logprobs is not None: - if not seq_group.prompt_logprobs: - prompt_logprobs = [None] + prompt_logprobs - seq_group.prompt_logprobs = [] - - if seq_group.sampling_params.detokenize and self.detokenizer: - self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, - prompt_logprobs, - position_offset=len(seq_group.prompt_logprobs)) - - seq_group.prompt_logprobs.extend(prompt_logprobs) + single_step_process_prompt_logprob(self, seq_group, output) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput, diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 57cc33d91118..76782888031e 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -2,7 +2,8 @@ from typing import Sequence as GenericSequence from typing import Union -from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import PoolerOutput, SequenceGroupOutput def create_output_by_sequence_group( diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 1deb75167bc7..34ae79f5fa8d 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -5,11 +5,11 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptInputs from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput from vllm.transformers_utils.tokenizer import AnyTokenizer diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c5368ac3bf02..c70c6d9330b1 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,9 +1,10 @@ +import asyncio import codecs -from dataclasses import dataclass +from collections import defaultdict from functools import lru_cache from pathlib import Path -from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, - Union) +from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping, + Optional, Tuple, Union) # yapf conflicts with isort for this block # yapf: disable @@ -80,10 +81,90 @@ class ConversationMessage(TypedDict): content: str -@dataclass(frozen=True) -class ChatMessageParseResult: - messages: List[ConversationMessage] - mm_futures: List[Awaitable[MultiModalDataDict]] +class MultiModalItemTracker: + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. + """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + self._model_config = model_config + self._tokenizer = tokenizer + self._allowed_items = (model_config.multimodal_config.limit_per_prompt + if model_config.multimodal_config else {}) + self._consumed_items = {k: 0 for k in self._allowed_items} + self._futures: List[Awaitable[MultiModalDataDict]] = [] + + @staticmethod + @lru_cache(maxsize=None) + def _cached_token_str(tokenizer: AnyTokenizer, token_index: int): + return tokenizer.decode(token_index) + + def add(self, modality: Literal["image", "audio"], + mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]: + """ + Adds the multi-modal item to the current prompt and returns the + placeholder string to use, if any. + """ + allowed_count = self._allowed_items.get(modality, 1) + current_count = self._consumed_items.get(modality, 0) + 1 + if current_count > allowed_count: + raise ValueError( + f"At most {allowed_count} {modality}(s) may be provided in " + "one request.") + + self._consumed_items[modality] = current_count + self._futures.append(mm_future) + + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + model_type = self._model_config.hf_config.model_type + if modality == "image": + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return f"<|image_{current_count}|>" + if model_type == "minicpmv": + return "(./)" + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): + # These models do not use image tokens in the prompt + return None + if model_type.startswith("llava"): + return MultiModalItemTracker._cached_token_str( + self._tokenizer, + self._model_config.hf_config.image_token_index) + if model_type in ("chameleon", "internvl_chat"): + return "" + + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "audio": + if model_type == "ultravox": + return "<|reserved_special_token_0|>" + raise TypeError(f"Unknown model type: {model_type}") + else: + raise TypeError(f"Unknown modality: {modality}") + + @staticmethod + async def _combine(futures: List[Awaitable[MultiModalDataDict]]): + mm_lists: Mapping[str, List[object]] = defaultdict(list) + + # Merge all the multi-modal items + for single_mm_data in (await asyncio.gather(*futures)): + for mm_key, mm_item in single_mm_data.items(): + if isinstance(mm_item, list): + mm_lists[mm_key].extend(mm_item) + else: + mm_lists[mm_key].append(mm_item) + + # Unpack any single item lists for models that don't expect multiple. + return { + mm_key: mm_list[0] if len(mm_list) == 1 else mm_list + for mm_key, mm_list in mm_lists.items() + } + + def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]: + return MultiModalItemTracker._combine( + self._futures) if self._futures else None def load_chat_template( @@ -112,44 +193,30 @@ def load_chat_template( return resolved_chat_template -@lru_cache(maxsize=None) -def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer, - modality: Literal["image", "audio"]) -> Optional[str]: - # TODO: Let user specify how to insert image tokens into prompt - # (similar to chat template) - model_type = model_config.hf_config.model_type - if modality == "image": - if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer - return "<|image_1|>" - if model_type == "minicpmv": - return "(./)" - if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): - # These models do not use image tokens in the prompt - return None - if model_type.startswith("llava"): - return tokenizer.decode(model_config.hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat"): - return "" - - raise TypeError(f"Unknown model type: {model_type}") - elif modality == "audio": - if model_type == "ultravox": - return "<|reserved_special_token_0|>" - raise TypeError(f"Unknown model type: {model_type}") - else: - raise TypeError(f"Unknown modality: {modality}") - - # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_token_str: str, +def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], text_prompt: str) -> str: """Combine multimodal prompts for a multimodal language model""" - # NOTE: For now we assume all model architectures use the same - # placeholder + text prompt format. This may change in the future. - return f"{placeholder_token_str}\n{text_prompt}" + # Look through the text prompt to check for missing placeholders + missing_placeholders = [] + for placeholder in placeholder_counts: + + # For any existing placeholder in the text prompt, we leave it as is + placeholder_counts[placeholder] -= text_prompt.count(placeholder) + + if placeholder_counts[placeholder] < 0: + raise ValueError( + f"Found more '{placeholder}' placeholders in input prompt than " + "actual multimodal data items.") + + missing_placeholders.extend([placeholder] * + placeholder_counts[placeholder]) + + # NOTE: For now we always add missing placeholders at the front of + # the prompt. This may change to be customizable in the future. + return "\n".join(missing_placeholders + [text_prompt]) _TextParser = TypeAdapter(ChatCompletionContentPartTextParam) @@ -160,12 +227,12 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str, def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], - model_config: ModelConfig, - tokenizer: AnyTokenizer, -) -> ChatMessageParseResult: + mm_tracker: MultiModalItemTracker, +) -> List[ConversationMessage]: texts: List[str] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] - modality: Literal["image", "audio"] = "image" + + # multimodal placeholder_string : count + mm_placeholder_counts: Dict[str, int] = {} for part in parts: part_type = part["type"] @@ -173,11 +240,6 @@ def _parse_chat_message_content_parts( text = _TextParser.validate_python(part)["text"] texts.append(text) elif part_type == "image_url": - modality = "image" - if len(mm_futures) > 0: - raise NotImplementedError( - "Multiple multimodal inputs is currently not supported.") - image_url = _ImageParser.validate_python(part)["image_url"] if image_url.get("detail", "auto") != "auto": @@ -185,60 +247,44 @@ def _parse_chat_message_content_parts( "'image_url.detail' is currently not supported and " "will be ignored.") - image_future = async_get_and_parse_image(image_url["url"]) - mm_futures.append(image_future) + image_coro = async_get_and_parse_image(image_url["url"]) + placeholder = mm_tracker.add("image", image_coro) + if placeholder: + mm_placeholder_counts[placeholder] = mm_placeholder_counts.get( + placeholder, 0) + 1 elif part_type == "audio_url": - modality = "audio" - if len(mm_futures) > 0: - raise NotImplementedError( - "Multiple multimodal inputs is currently not supported.") - audio_url = _AudioParser.validate_python(part)["audio_url"] - audio_future = async_get_and_parse_audio(audio_url["url"]) - mm_futures.append(audio_future) + audio_coro = async_get_and_parse_audio(audio_url["url"]) + placeholder = mm_tracker.add("audio", audio_coro) + if placeholder: + mm_placeholder_counts[placeholder] = mm_placeholder_counts.get( + placeholder, 0) + 1 else: raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, + text_prompt) - if mm_futures: - placeholder_token_str = _mm_token_str(model_config, tokenizer, - modality) - if placeholder_token_str is not None: - if placeholder_token_str in text_prompt: - logger.warning( - "Detected multi-modal token string in the text prompt. " - "Skipping prompt formatting.") - else: - text_prompt = _get_full_multimodal_text_prompt( - placeholder_token_str=placeholder_token_str, - text_prompt=text_prompt, - ) - - messages = [ConversationMessage(role=role, content=text_prompt)] - - return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) + return [ConversationMessage(role=role, content=text_prompt)] def _parse_chat_message_content( - message: ChatCompletionMessageParam, - model_config: ModelConfig, - tokenizer: AnyTokenizer, -) -> ChatMessageParseResult: + message: ChatCompletionMessageParam, + mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]: role = message["role"] content = message.get("content") if content is None: - return ChatMessageParseResult(messages=[], mm_futures=[]) + return [] if isinstance(content, str): - messages = [ConversationMessage(role=role, content=content)] - return ChatMessageParseResult(messages=messages, mm_futures=[]) + return [ConversationMessage(role=role, content=content)] return _parse_chat_message_content_parts( role, content, # type: ignore - model_config, - tokenizer, + mm_tracker, ) @@ -246,18 +292,16 @@ def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, -) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]: +) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]: conversation: List[ConversationMessage] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] + mm_tracker = MultiModalItemTracker(model_config, tokenizer) for msg in messages: - parse_result = _parse_chat_message_content(msg, model_config, - tokenizer) + sub_messages = _parse_chat_message_content(msg, mm_tracker) - conversation.extend(parse_result.messages) - mm_futures.extend(parse_result.mm_futures) + conversation.extend(sub_messages) - return conversation, mm_futures + return conversation, mm_tracker.all_mm_data() def apply_chat_template( diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index a472e12e8ca4..c457555c54b9 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,11 +1,13 @@ import asyncio +import pickle from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Mapping, Optional +from typing import Any, AsyncGenerator, Iterator, Mapping, Optional from uuid import uuid4 import cloudpickle import zmq import zmq.asyncio +from zmq.asyncio import Socket from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -115,18 +117,21 @@ def __init__(self, rpc_path: str): self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server = self.context.socket(zmq.constants.DEALER) + self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) self.to_rpc_server.bind(rpc_path) # In process proxy to RPC Server (uses memory-based messaging). - self.from_api_server = self.context.socket(zmq.constants.ROUTER) + self.from_api_server: Socket = self.context.socket( + zmq.constants.ROUTER) self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) self.from_api_server.bind(INPROC_PROXY_PATH) # Asyncio background task for the proxy. - self.proxy_task = asyncio.create_task( + self.proxy_in_task = asyncio.create_task( self.run_proxy(self.from_api_server, self.to_rpc_server)) + self.proxy_out_task = asyncio.create_task( + self.run_proxy(self.to_rpc_server, self.from_api_server)) # Since we open 1 inproc socket per request, we have a hard cap on # the number of requests that can run in vLLM w. frontend @@ -136,20 +141,11 @@ def __init__(self, rpc_path: str): # 1 for generate(), 1 for abort(), do_log_stats(), check_health() self.limit_concurrency = socket_limit // 2 - 2 - async def run_proxy(self, socket_from, socket_to): + async def run_proxy(self, socket_from: Socket, socket_to: Socket): """Background task that runs a proxy""" - poller = zmq.asyncio.Poller() - poller.register(socket_from, zmq.constants.POLLIN) - poller.register(socket_to, zmq.constants.POLLIN) while True: - events_lst = await poller.poll() - events = dict(events_lst) - if socket_from in events: - identity, msg = await socket_from.recv_multipart() - await socket_to.send_multipart([identity, msg]) - if socket_to in events: - identity, msg = await socket_to.recv_multipart() - await socket_from.send_multipart([identity, msg]) + frames = await socket_from.recv_multipart(copy=False) + await socket_to.send_multipart(frames, copy=False) async def setup(self): """Setup the client before it starts sending server requests.""" @@ -180,7 +176,7 @@ def close(self): self.context.destroy() @contextmanager - def to_proxy_socket(self): + def to_proxy_socket(self) -> Iterator[Socket]: # Connect to the RPCServer via the proxy. # Raise a sensible error if the client was already closed. @@ -208,7 +204,8 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, with self.to_proxy_socket() as socket: # Ping RPCServer with a request. - await socket.send_multipart([cloudpickle.dumps(request)]) + await socket.send_multipart((cloudpickle.dumps(request), ), + copy=False) # Make sure the server responds if await socket.poll(timeout=self._data_timeout) == 0: @@ -216,7 +213,8 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, f"{self._data_timeout} ms") # Await the data from the Server. - data = cloudpickle.loads(await socket.recv()) + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) if isinstance(data, Exception): # Re-raise exceptions returned by the server @@ -234,23 +232,22 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, return data - async def _send_one_way_rpc_request( - self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[zmq.asyncio.Socket] = None): + async def _send_one_way_rpc_request(self, + request: RPC_REQUEST_TYPE, + error_message: str, + socket: Optional[Socket] = None): """Send one-way RPC request to trigger an action.""" - async def do_rpc_call(socket: zmq.asyncio.Socket, - request: RPC_REQUEST_TYPE): + async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - await socket.send_multipart([cloudpickle.dumps(request)]) + await socket.send_multipart((cloudpickle.dumps(request), )) if await socket.poll(timeout=self._data_timeout) == 0: raise TimeoutError("Server didn't reply within " f"{self._data_timeout} ms") - return cloudpickle.loads(await socket.recv()) + frame = await socket.recv(copy=False) + return pickle.loads(frame.buffer) # Make a new socket connection. if socket is None: @@ -386,21 +383,19 @@ async def generate( try: with self.to_proxy_socket() as socket: # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart([ - cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)) - ]) + await socket.send_multipart((cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)), )) # Stream back the results from the RPC Server. while not finished: - message = await socket.recv() - request_output = cloudpickle.loads(message) + message = await socket.recv(copy=False) + request_output = pickle.loads(message.buffer) if isinstance(request_output, Exception): # On exception, check if the server is still healthy @@ -424,9 +419,7 @@ async def generate( if not finished and not self._errored: await self.abort(request_id) - async def check_health(self, - socket: Optional[zmq.asyncio.Socket] = None - ) -> None: + async def check_health(self, socket: Optional[Socket] = None) -> None: """Raise if unhealthy""" await self._send_one_way_rpc_request( @@ -451,4 +444,4 @@ async def stop_profile(self) -> None: await self._send_one_way_rpc_request( request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") \ No newline at end of file + error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 738d12bbef05..bebc2faedb68 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -1,4 +1,5 @@ import asyncio +import pickle import signal from typing import Any, Coroutine, Union @@ -7,6 +8,8 @@ import zmq import zmq.asyncio from typing_extensions import Never +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket from vllm import AsyncEngineArgs, AsyncLLMEngine from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, @@ -35,7 +38,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs, self.context = zmq.asyncio.Context() # Init socket. - self.socket = self.context.socket(zmq.constants.DEALER) + self.socket: Socket = self.context.socket(zmq.constants.DEALER) self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) self.socket.connect(rpc_path) @@ -63,30 +66,31 @@ async def get_config(self, identity, request): else: raise ValueError("Unknown Config Request: %s", request) - await self.socket.send_multipart( - [identity, cloudpickle.dumps(config)]) + await self.socket.send_multipart((identity, pickle.dumps(config)), + copy=False) except Exception as e: - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + await self.socket.send_multipart((identity, pickle.dumps(e)), + copy=False) async def is_tracing_enabled(self, identity): """Send the is_tracing_enabled flag""" tracing_flag = await self.engine.is_tracing_enabled() await self.socket.send_multipart( - [identity, cloudpickle.dumps(tracing_flag)]) + (identity, pickle.dumps(tracing_flag))) async def do_log_stats(self, identity): """Log stats and confirm success.""" await self.engine.do_log_stats() await self.socket.send_multipart( - [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) + (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) async def is_server_ready(self, identity): """Notify the client that we are ready.""" await self.socket.send_multipart( - [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) + (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) async def abort(self, identity, request: RPCAbortRequest): """Abort request and notify the client of success.""" @@ -96,7 +100,7 @@ async def abort(self, identity, request: RPCAbortRequest): result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR except Exception as e: result = e - await self.socket.send_multipart([identity, cloudpickle.dumps(result)]) + await self.socket.send_multipart((identity, pickle.dumps(result))) async def generate(self, identity, generate_request: RPCGenerateRequest): try: @@ -110,45 +114,47 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): async for request_output in results_generator: await self.socket.send_multipart( - [identity, cloudpickle.dumps(request_output)]) + (identity, pickle.dumps(request_output)), copy=False) except Exception as e: - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + await self.socket.send_multipart((identity, pickle.dumps(e)), + copy=False) async def check_health(self, identity): try: await self.engine.check_health() await self.socket.send_multipart( - [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) + (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) except Exception as e: - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + await self.socket.send_multipart((identity, pickle.dumps(e)), + copy=False) async def start_profile(self, identity): logger.info("Starting profiler...") await self.engine.start_profile() logger.info("Profiler started.") - await self.socket.send_multipart([ + await self.socket.send_multipart(( identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) + pickle.dumps(VLLM_RPC_SUCCESS_STR), + )) async def stop_profile(self, identity): logger.info("Stopping profiler...") await self.engine.stop_profile() logger.info("Profiler stopped.") - await self.socket.send_multipart([ + await self.socket.send_multipart(( identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) + pickle.dumps(VLLM_RPC_SUCCESS_STR), + )) def _make_handler_coro(self, identity, - message) -> Coroutine[Any, Any, Never]: + message: Frame) -> Coroutine[Any, Any, Never]: """Route the zmq message to the handler coroutine.""" - request = cloudpickle.loads(message) + request = cloudpickle.loads(message.buffer) if isinstance(request, RPCGenerateRequest): return self.generate(identity, request) @@ -189,7 +195,7 @@ async def run_server_loop(self): running_tasks = set() while True: # Wait for a request. - identity, message = await self.socket.recv_multipart() + identity, message = await self.socket.recv_multipart(copy=False) # Process the request async. task = asyncio.create_task( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d31ac4995fe2..f7576509d06c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -94,7 +94,7 @@ async def create_chat_completion( tokenizer = await self.async_engine_client.get_tokenizer( lora_request) - conversation, mm_futures = parse_chat_messages( + conversation, mm_data_future = parse_chat_messages( request.messages, model_config, tokenizer) tool_dicts = None if request.tools is None else [ @@ -116,12 +116,8 @@ async def create_chat_completion( mm_data: Optional[MultiModalDataDict] = None try: - if len(mm_futures): - # since we support only single mm data currently - assert len( - mm_futures - ) == 1, "Multiple 'image_url' input is currently not supported." - mm_data = await mm_futures[0] + if mm_data_future: + mm_data = await mm_data_future except Exception as e: logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 1aeabb7a7d72..fc9ca29e9cf8 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -65,10 +65,10 @@ async def create_tokenize( if isinstance(request, TokenizeChatRequest): model_config = self.model_config - conversation, mm_futures = parse_chat_messages( + conversation, mm_data_future = parse_chat_messages( request.messages, model_config, tokenizer) - if mm_futures: + if mm_data_future: logger.warning( "Multi-modal inputs are ignored during tokenization") diff --git a/vllm/envs.py b/vllm/envs.py index 24e09ee0e055..3c6b6adff82f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -31,6 +31,7 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: bool = False + VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" @@ -196,6 +197,10 @@ def get_default_config_root(): # Internal flag to enable Dynamo graph capture "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": + lambda: + (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in + ("true", "1")), # local rank of the process in the distributed setting, used to determine # the GPU device id @@ -348,7 +353,7 @@ def get_default_config_root(): os.path.join(get_default_cache_root(), "vllm", "xla_cache"), )), "VLLM_FUSED_MOE_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")), + lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": @@ -400,6 +405,10 @@ def get_default_config_root(): "VLLM_TORCH_PROFILER_DIR": lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + + # If set, vLLM will use Triton implementations of AWQ. + "VLLM_USE_TRITON_AWQ": + lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), } # end-env-vars-definition diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 37d12725bd1e..21ad43f64168 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -11,8 +11,9 @@ ResultHandler, WorkerMonitor) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, get_vllm_instance_id, make_async) from vllm.worker.worker_base import WorkerWrapperBase diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 1a35a7c3b8f7..ad84422ee212 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -6,7 +6,8 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest logger = init_logger(__name__) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 422bef107f35..c96cb0f2c298 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -6,8 +6,9 @@ PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest class ExecutorBase(ABC): diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 795692195f84..947776e5d6ef 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -3,8 +3,9 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 02b2499be465..9c6d4051eb3f 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -14,7 +14,8 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.triton_utils import maybe_set_triton_cache_manager from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, get_distributed_init_method, get_open_port, diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 02627de3e0be..f2fcfa58b26e 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -3,7 +3,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 867859d8d3d7..78606e223aa7 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -9,7 +9,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 760c06cb6c06..ab8844bcdafe 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -12,7 +12,8 @@ from vllm.executor.msgspec_utils import encode_hook from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 7048d4798072..8c8b5f741488 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -10,7 +10,8 @@ from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.tpu_executor import TPUExecutor from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -70,6 +71,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_module_name = "vllm.worker.tpu_worker" worker_class_name = "TPUWorker" + # GKE does not fetch environment information from metadata server + # and instead sets these from within the Ray process. Therefore we + # need to override the Ray environment variables manually. + override_env = {} + if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: + override_env.update({ + "TPU_CHIPS_PER_HOST_BOUNDS": + os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] + }) + if "TPU_HOST_BOUNDS" in os.environ: + override_env.update( + {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) + worker = ray.remote( num_cpus=0, resources={"TPU": 1}, @@ -80,6 +94,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) + if override_env: + worker.override_env_vars.remote(override_env) worker_ip = ray.get(worker.get_node_ip.remote()) if worker_ip == driver_ip and self.driver_dummy_worker is None: @@ -95,12 +111,40 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Else, added to the list of workers. self.workers.append(worker) + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) if self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any TPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " "TPU node.") + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + # Get the set of TPU IDs used on each node. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index bfdd0f5cf97b..59e9854393b6 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,3 +1,4 @@ +import os import time from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -84,6 +85,9 @@ def execute_model_spmd( return output + def override_env_vars(self, vars: Dict[str, str]): + os.environ.update(vars) + ray_import_err = None except ImportError as e: @@ -291,3 +295,28 @@ def initialize_ray_cluster( _verify_bundles(current_placement_group, parallel_config, device_str) # Set the placement group in the parallel config parallel_config.placement_group = current_placement_group + + +def get_num_tpu_nodes() -> int: + from ray._private.accelerators import TPUAcceleratorManager + cluster_resources = ray.cluster_resources() + total_tpus = int(cluster_resources["TPU"]) + tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + assert total_tpus % tpus_per_node == 0 + return total_tpus // tpus_per_node + + +def get_num_nodes_in_placement_group() -> int: + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + num_nodes = 0 + + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg["bundles_to_node_id"].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) + + return num_nodes diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 253c8abdc1ad..0af8ba41e24d 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -5,7 +5,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 774204dd4612..bada56068507 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -9,7 +9,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async from vllm.worker.worker_base import WorkerBase diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 0bbc1844ef45..619408b9315c 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -160,6 +160,9 @@ def _bgmv_expand( return -bgmv_expand = torch.library.custom_op("lora::bgmv_expand", - _bgmv_expand, - mutates_args=["output_tensor"]) +try: + bgmv_expand = torch.library.custom_op("lora::bgmv_expand", + _bgmv_expand, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_expand = _bgmv_expand diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index 87d7d9902a4c..c16db233891a 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -173,6 +173,9 @@ def _bgmv_expand_slice( return -bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", - _bgmv_expand_slice, - mutates_args=["output_tensor"]) +try: + bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", + _bgmv_expand_slice, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_expand_slice = _bgmv_expand_slice diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py index c979d758492d..0846ff36b169 100644 --- a/vllm/lora/ops/bgmv_shrink.py +++ b/vllm/lora/ops/bgmv_shrink.py @@ -142,6 +142,9 @@ def _bgmv_shrink( return -bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", - _bgmv_shrink, - mutates_args=["output_tensor"]) +try: + bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", + _bgmv_shrink, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_shrink = _bgmv_shrink diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 80a0b605b0fe..c71332d8bdfb 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -192,6 +192,9 @@ def _sgmv_expand( return -sgmv_expand = torch.library.custom_op("lora::sgmv_expand", - _sgmv_expand, - mutates_args=["output_tensor"]) +try: + sgmv_expand = torch.library.custom_op("lora::sgmv_expand", + _sgmv_expand, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_expand = _sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 53237166a1c6..b4ae9a2acbb5 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -205,6 +205,9 @@ def _sgmv_expand_slice( return -sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", - _sgmv_expand_slice, - mutates_args=["output_tensor"]) +try: + sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", + _sgmv_expand_slice, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 51d2a09eee94..c0791c260e91 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -189,6 +189,9 @@ def _sgmv_shrink( return -sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", - _sgmv_shrink, - mutates_args=["output_tensor"]) +try: + sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", + _sgmv_shrink, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_shrink = _sgmv_shrink diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index d666fc293757..6d5c83429996 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -10,10 +10,8 @@ import torch from vllm.triton_utils import HAS_TRITON -from vllm.utils import is_xpu -# FIXME: xpu path doesn't support torch.library.custom_op -if HAS_TRITON and not is_xpu(): +if HAS_TRITON: from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..cd0cdbea0c33 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..ba9041d00850 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..57055453aa24 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3584": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d2b152320e11..05169eaddb25 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2,7 +2,7 @@ import functools import json import os -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import torch import triton @@ -446,7 +446,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor, rand_perm1: torch.Tensor, rand_perm2: torch.Tensor, topk: int, - renormalize: bool, + custom_routing_function: Optional[Callable] = None, + renormalize: bool = True, override_config: Optional[Dict[str, Any]] = None, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -497,8 +498,12 @@ def fused_marlin_moe(hidden_states: torch.Tensor, E = w1.shape[0] N = w2.shape[1] * 16 - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + if custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) get_config_func = functools.partial(try_get_optimal_moe_config, w1.shape, @@ -695,6 +700,7 @@ def fused_moe( use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -742,9 +748,12 @@ def fused_moe( topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk, renormalize, num_expert_group, topk_group) - else: + elif custom_routing_function is None: topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) return fused_experts(hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 61ebef5e11f4..3df0b61a9ebe 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,6 +1,6 @@ from abc import abstractmethod from enum import Enum -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch @@ -62,15 +62,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -79,17 +82,21 @@ def apply(self, renormalize=renormalize, use_grouped_topk=use_grouped_topk, topk_group=topk_group, - num_expert_group=num_expert_group) - - def forward_cuda(self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) @@ -101,7 +108,8 @@ def forward_cuda(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -114,20 +122,24 @@ def forward_cpu(self, *args, **kwargs): raise NotImplementedError( "The CPU backend currently does not support MoE.") - def forward_tpu(self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def forward_tpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe assert not use_grouped_topk assert num_expert_group is None assert topk_group is None + assert custom_routing_function is None return fused_moe(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -172,6 +184,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", + custom_routing_function: Optional[Callable] = None, ): super().__init__() @@ -190,6 +203,7 @@ def __init__( assert num_expert_group is not None and topk_group is not None self.num_expert_group = num_expert_group self.topk_group = topk_group + self.custom_routing_function = custom_routing_function if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -390,7 +404,8 @@ def select_experts(hidden_states: torch.Tensor, use_grouped_topk: bool, renormalize: bool, topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None): + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None): from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, grouped_topk) @@ -405,11 +420,17 @@ def select_experts(hidden_states: torch.Tensor, renormalize=renormalize, num_expert_group=num_expert_group, topk_group=topk_group) - else: + elif custom_routing_function is None: topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) return topk_weights, topk_ids @@ -426,7 +447,8 @@ def forward(self, hidden_states: torch.Tensor, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, topk_group=self.topk_group, - num_expert_group=self.num_expert_group) + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1cad4e55f51e..1163cc727762 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -23,7 +23,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", - "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod" + "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod" ] @@ -35,9 +36,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -def adjust_bitsandbytes_shard(param: Parameter, - qkv_offsets: Dict[str, Tuple[int, int]], - loaded_shard_id: str) -> Tuple[int, int]: +def adjust_bitsandbytes_4bit_shard(param: Parameter, + qkv_offsets: Dict[str, Tuple[int, int]], + loaded_shard_id: str) -> Tuple[int, int]: """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" total, _ = qkv_offsets["total"] @@ -504,8 +505,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - use_bitsandbytes = getattr(param, "use_bitsandbytes", False) - if use_bitsandbytes: + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + if use_bitsandbytes_4bit: shard_size = loaded_weight.shape[output_dim] shard_offset = loaded_weight.shape[output_dim] * \ loaded_shard_id @@ -857,8 +859,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - use_bitsandbytes = getattr(param, "use_bitsandbytes", False) - if use_bitsandbytes: + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), "k": (self.num_heads * self.head_size, @@ -870,7 +873,7 @@ def weight_loader(self, ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 0) } - shard_size, shard_offset = adjust_bitsandbytes_shard( + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( param, orig_qkv_offsets, loaded_shard_id) if is_gguf_weight: diff --git a/vllm/model_executor/layers/mamba/__init__.py b/vllm/model_executor/layers/mamba/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/mamba/ops/__init__.py b/vllm/model_executor/layers/mamba/ops/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py new file mode 100644 index 000000000000..413c8bc227ae --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024, Tri Dao. + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: str = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert (initial_states is + None), "initial_states must be None if seq_idx is not None" + assert (not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and (initial_states.stride(2) != 1 + and initial_states.stride(1) != 1): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert (final_states_out.stride(2) == 1 + or final_states_out.stride(1) == 1) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty(batch, + width - 1, + dim, + device=x.device, + dtype=x.dtype).transpose(1, 2) + else: + final_states_out = None + + out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states, + final_states_out, activation + in ["silu", "swish"]) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation_bool = activation in ["silu", "swish"] + return ops.causal_conv1d_update(x, conv_state, weight, bias, + activation_bool) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py new file mode 100644 index 000000000000..869c69214caf --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -0,0 +1,346 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import torch +import triton +import triton.language as tl +from packaging import version + +from vllm import _custom_ops as ops + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics( + {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + # Matrix dimensions + batch, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // + nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // + nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, + state, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else + (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else + ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( + -1) == 0 and dt_bias.stride(-1) == 0 + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + batch, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + *(dt_bias.stride(0), + dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_fn(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = B.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) + n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) + x = torch.zeros(( + u.shape[0], + u.shape[1], + n_chunks, + int(A.shape[1] * 2), + ), + device=u.device, + dtype=torch.float32, + requires_grad=False) + x[:, :, 0, 0::2] = 1 + if prev_state is not None: + x[:, :, 0, 1::2].copy_(prev_state) + out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, + delta_softplus, position_indices, x) + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if z is None: + return out if not return_last_state else (out, last_state) + else: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py new file mode 100644 index 000000000000..ad706f28a742 --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -0,0 +1,304 @@ +import torch +import triton +import triton.language as tl + +AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +@triton.jit +def awq_dequantize_kernel( + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr): + # Setup the pids. + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + # Compute offsets and masks for qweight_ptr. + offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + + masks = masks_y[:, None] & masks_x[None, :] + + # Compute offsets and masks for result output ptr. + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange( + 0, BLOCK_SIZE_X * 8) + result_offsets = (8 * num_cols * result_offsets_y[:, None] + + result_offsets_x[None, :]) + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + # Load the weights. + iweights = tl.load(qweight_ptr + offsets, masks) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Use this to compute a set of shifts that can be used to unpack and + # reorder the values in iweights and zeros. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + iweights = (iweights >> shifts) & 0xF + + # Compute zero offsets and masks. + zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + + tl.arange(0, BLOCK_SIZE_Y) // group_size) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + # Load the zeros. + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + zeros = (zeros >> shifts) & 0xF + + # Compute scale offsets and masks. + scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + + tl.arange(0, BLOCK_SIZE_Y) // group_size) + scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + + tl.arange(0, BLOCK_SIZE_X * 8)) + scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + + scale_offsets_x[None, :]) + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + # Load the scales. + scales = tl.load(scales_ptr + scale_offsets, scale_masks) + + # Dequantize. + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + + # Finally, store. + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, + group_size, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = c_ptr.type.element_ty + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # accumulator = tl.arange(0, BLOCK_SIZE_N) + # accumulator = tl.broadcast_to(accumulator[None, :], + # (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # accumulator = accumulator & 0x0 + # accumulator = accumulator.to(accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=accumulator_dtype) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Create the necessary shifts to use to unpack. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], + (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_am = offsets_am < M + + offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) + + tl.arange(0, BLOCK_SIZE_N) // 8) + masks_bn = offsets_bn < N // 8 + + offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) + + tl.arange(0, BLOCK_SIZE_N) // 8) + masks_zn = offsets_zn < N // 8 + + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + masks_sn = offsets_sn < N + + offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offsets_a = K * offsets_am[:, None] + offsets_k[None, :] + offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :] + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv + # block_offset = BLOCK_SIZE_K * SPLIT_K + # for k in range(0, (K + block_offset - 1) // (block_offset)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + + # Dequantize b. + offsets_szk = ( + (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + + tl.arange(0, BLOCK_SIZE_K) // group_size) + offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] + masks_zk = offsets_szk < K // group_size + masks_z = masks_zk[:, None] & masks_zn[None, :] + zeros_ptrs = zeros_ptr + offsets_z + zeros = tl.load(zeros_ptrs, mask=masks_z) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_sk = offsets_szk < K // group_size + masks_s = masks_sk[:, None] & masks_sn[None, :] + scales_ptrs = scales_ptr + offsets_s + scales = tl.load(scales_ptrs, mask=masks_s) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = (b - zeros) * scales + b = b.to(c_ptr.type.element_ty) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K * SPLIT_K + a_ptrs += BLOCK_SIZE_K * SPLIT_K + b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +# qweights - [K , M // 8], int32 +# scales - [K // G, M ], float16 +# zeros - [K // G, M // 8], int32 +def awq_dequantize_triton(qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32) -> torch.Tensor: + K = qweight.shape[0] + M = scales.shape[1] + group_size = qweight.shape[0] // scales.shape[0] + + assert K > 0 and M > 0 + assert scales.shape[0] == K // group_size and scales.shape[1] == M + assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + # Result tensor: + # number of rows = same as input tensor + # number of cols = 8 x input tensor num cols + result = torch.empty(qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype) + + Y = qweight.shape[0] # num rows + X = qweight.shape[1] # num cols + + grid = lambda META: ( + triton.cdiv(X, META['BLOCK_SIZE_X']), + triton.cdiv(Y, META['BLOCK_SIZE_Y']), + ) + awq_dequantize_kernel[grid](qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y) + + return result + + +# input - [M, K] +# qweight - [K, N // 8] +# qzeros - [K // G, N // 8] +# scales - [K // G, N] +# split_k_iters - parallelism along K-dimension, int, power of 2. +def awq_gemm_triton(input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32) -> torch.Tensor: + M, K = input.shape + N = qweight.shape[1] * 8 + group_size = qweight.shape[0] // qzeros.shape[0] + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 + assert split_k_iters <= 32 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( + N, META['BLOCK_SIZE_N']), + split_k_iters, + ) + + result = torch.zeros((M, N), dtype=scales.dtype, device=input.device) + + # A = input, B = qweight, C = result + # A = M x K, B = K x N, C = M x N + awq_gemm_kernel[grid](input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters) + + return result diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index c143d1a8f2bc..66bc5395dbd7 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn.parameter import Parameter from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) @@ -15,8 +14,28 @@ class BitsAndBytesConfig(QuantizationConfig): Reference: https://arxiv.org/abs/2305.14314 """ - def __init__(self, ) -> None: - pass + def __init__( + self, + load_in_8bit: bool = False, + load_in_4bit: bool = True, + bnb_4bit_compute_dtype: str = "float32", + bnb_4bit_quant_type: str = "fp4", + bnb_4bit_use_double_quant: bool = False, + llm_int8_enable_fp32_cpu_offload: bool = False, + llm_int8_has_fp16_weight: bool = False, + llm_int8_skip_modules: Optional[Any] = None, + llm_int8_threshold: float = 0.0, + ) -> None: + + self.load_in_8bit = load_in_8bit + self.load_in_4bit = load_in_4bit + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.llm_int8_skip_modules = llm_int8_skip_modules + self.llm_int8_threshold = llm_int8_threshold def __repr__(self) -> str: return "BitsAndBytesConfig" @@ -41,7 +60,46 @@ def get_config_filenames() -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": - return cls() + + def get_safe_value(config, keys, default_value=None): + try: + value = cls.get_from_keys(config, keys) + return value if value is not None else default_value + except ValueError: + return default_value + + load_in_8bit = get_safe_value(config, ["load_in_8bit"], + default_value=False) + load_in_4bit = get_safe_value(config, ["load_in_4bit"], + default_value=True) + bnb_4bit_compute_dtype = get_safe_value(config, + ["bnb_4bit_compute_dtype"], + default_value="float32") + bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], + default_value="fp4") + bnb_4bit_use_double_quant = get_safe_value( + config, ["bnb_4bit_use_double_quant"], default_value=False) + llm_int8_enable_fp32_cpu_offload = get_safe_value( + config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False) + llm_int8_has_fp16_weight = get_safe_value(config, + ["llm_int8_has_fp16_weight"], + default_value=False) + llm_int8_skip_modules = get_safe_value(config, + ["llm_int8_skip_modules"], + default_value=[]) + llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], + default_value=0.0) + + return cls( + load_in_8bit=load_in_8bit, + load_in_4bit=load_in_4bit, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + llm_int8_skip_modules=llm_int8_skip_modules, + llm_int8_threshold=llm_int8_threshold) def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["BitsAndBytesLinearMethod"]: @@ -78,39 +136,58 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - quant_ratio = 0 - if params_dtype.is_floating_point: - quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo( - torch.uint8).bits + from bitsandbytes.nn import Int8Params + + def calculate_quant_ratio(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits + else: + return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits + + def create_qweight_for_8bit(): + qweight = Int8Params( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight, + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": 1, + "use_bitsandbytes_8bit": True, + "generation": 0 + }) + return qweight + + def create_qweight_for_4bit(): + quant_ratio = calculate_quant_ratio(params_dtype) + + total_size = input_size_per_partition * sum(output_partition_sizes) + if total_size % quant_ratio != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape.") + + qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio, + 1, + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True + }) + return qweight + + if self.quant_config.load_in_8bit: + qweight = create_qweight_for_8bit() else: - quant_ratio = torch.iinfo(params_dtype).bits // torch.iinfo( - torch.uint8).bits - - if input_size_per_partition * sum( - output_partition_sizes) % quant_ratio != 0: - raise ValueError( - "The input size is not aligned with the quantized " - "weight shape. ") - qweight = Parameter( - torch.empty( - input_size_per_partition * sum(output_partition_sizes) // - quant_ratio, - 1, - dtype=torch.uint8, - ), - requires_grad=False, - ) - - set_weight_attrs( - qweight, - { - "input_dim": 0, - # In bitsandbytes, a tensor of shape [n,m] is quantized to - #[n*m/pack_ratio, 1],so the output_dim is 0 - "output_dim": 0, - "pack_factor": quant_ratio, - "use_bitsandbytes": True, - }) + qweight = create_qweight_for_4bit() + layer.register_parameter("qweight", qweight) set_weight_attrs(qweight, extra_weight_attrs) @@ -119,6 +196,88 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.quant_config.load_in_8bit: + return self._apply_8bit_weight(layer, x, bias) + else: + return self._apply_4bit_weight(layer, x, bias) + + def _apply_8bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import MatmulLtState, matmul + + original_type = x.dtype + bf_x = x.to(torch.bfloat16) + + qweight = layer.qweight + offsets = qweight.bnb_shard_offsets + quant_states = qweight.bnb_quant_state + matmul_states = qweight.matmul_state + generation = qweight.generation + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.float16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + + # in profile_run or the first generation of inference, + # create new matmul_states + if generation == 0 or generation == 1: + matmul_states[i] = MatmulLtState() + matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] + matmul_states[i].SCB = quant_states[i] + matmul_states[i].threshold = ( + self.quant_config.llm_int8_threshold) + matmul_states[i].has_fp16_weights = ( + self.quant_config.llm_int8_has_fp16_weight) + matmul_states[i].is_training = False + if matmul_states[i].threshold > 0.0 and not matmul_states[ + i].has_fp16_weights: + matmul_states[i].use_pool = True + + new_x = bf_x.unsqueeze(0) + + out[:, current_index:current_index + output_size] = matmul( + new_x, + qweight[offsets[i]:offsets[i + 1]], + state=matmul_states[i]) + + current_index += output_size + + # only update the matmul_states if it is not profile_run + if (generation > 0 + and not self.quant_config.llm_int8_has_fp16_weight + and matmul_states[i].CB is not None + and matmul_states[i].CxB is not None): + del matmul_states[i].CB + qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB + + out = out.to(original_type) + + if bias is not None: + out += bias + + qweight.generation += 1 + + return out + + def _apply_4bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # only load the bitsandbytes module when needed from bitsandbytes import matmul_4bit diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 0e0ab9ce9169..36323493d601 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,6 +1,6 @@ import enum from enum import Enum -from typing import List, Optional +from typing import Callable, List, Optional import torch @@ -256,15 +256,18 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, ) replace_tensor("w2_weight_scale", marlin_w2_scales) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_marlin_moe) @@ -278,6 +281,7 @@ def apply(self, layer.w13_g_idx_sort_indices, layer.w2_g_idx_sort_indices, top_k, + custom_routing_function, renormalize=renormalize, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index dabf17df78fe..116a4ea0aed8 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch @@ -96,15 +96,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w2_scale", w2_scale) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( @@ -114,7 +117,8 @@ def apply(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(x, layer.w13_weight, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1817dbcb023a..32affe06b89b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch from torch.nn import Module @@ -468,15 +468,18 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) return - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -487,7 +490,8 @@ def apply(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(x, layer.w13_weight, diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py index ae34e01497db..be8235b468f6 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -7,7 +7,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import ModelWeightParameter ACTIVATION_SCHEMES = ["none"] @@ -64,16 +64,16 @@ def create_weights(self, layer: Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + + weight_loader = extra_weight_attrs.get("weight_loader") + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - **extra_weight_attrs, - "input_dim": 1, - "output_dim": 0, - }) def _quantize_weight( self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -92,6 +92,7 @@ def _quantize_weight( return qweight, qscale def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = Parameter(layer.weight.data, requires_grad=False) device = layer.weight.device qweight, qscale = self._quantize_weight(layer.weight) qweight = qweight.to(device) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 2124196d06f9..b2f333a5bcc8 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -1,12 +1,28 @@ from functools import cached_property +from importlib.util import find_spec from typing import Dict, List, Optional, Tuple import torch import torch.jit +import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeStochasticBaseSampler) +logger = init_logger(__name__) + +if find_spec("flashinfer"): + """ + Consider utilizing the FlashInfer rejection sampling kernel initially, + as it employs a dedicated kernel rather than relying on + Torch tensor operations. This design choice helps to fuse operations, + reduce memory I/O, and consequently enhances performance. + """ + from flashinfer.sampling import chain_speculative_sampling +else: + chain_speculative_sampling = None + class RejectionSampler(SpecDecodeStochasticBaseSampler): """Apply modified rejection sampling as described in "Accelerating Large @@ -16,7 +32,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): def __init__(self, disable_bonus_tokens: bool = True, - strict_mode: bool = False): + strict_mode: bool = False, + use_flashinfer: Optional[bool] = None): """Create a rejection sampler. Args: @@ -26,13 +43,29 @@ def __init__(self, strict_mode: Whether or not to perform shape/device/dtype checks during sampling. This catches correctness issues but adds nontrivial latency. + use_falshinfer: We will use this parameter to determine whether + to use the FlashInfer rejection sampling kernel or not. If it's + None, we will use the default value from the environment variable. + This parameter is only used for testing purposes. """ super().__init__(disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) + if use_flashinfer is None: + self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and ( + chain_speculative_sampling is not None) + else: + self.use_flashinfer = use_flashinfer + + if self.use_flashinfer: + assert not disable_bonus_tokens, \ + "flashinfer will enable bonus token by default" + logger.info("Use flashinfer for rejection sampling.") + else: + logger.info("Use pytorch for rejection sampling.") def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -50,9 +83,9 @@ def forward( sequence. Args: - target_probs: The probability distribution over token ids given - context according to the target model. - shape = [batch_size, num_speculative_tokens, vocab_size] + target_with_bonus_probs: The probability distribution + over token ids given context according to the target model. + shape = [batch_size, num_speculative_tokens + 1, vocab_size] bonus_token_ids: The "bonus" token ids that are accepted iff all speculative tokens in a sequence are accepted. @@ -78,23 +111,52 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. if self._strict_mode: - self._raise_if_incorrect_input(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) - accepted, recovered_token_ids = ( - self._batch_modified_rejection_sampling( - target_probs, - draft_probs, - draft_token_ids, - seeded_seqs, - )) + batch_size, k, _ = draft_probs.shape - output_token_ids = self._create_output( - accepted, - recovered_token_ids, - draft_token_ids, - bonus_token_ids, - ) + # batch_size = 0 when all requests in the batch are + # non_spec requests. In this case, output_token_ids is + # just an empty tensor. + if batch_size == 0: + return torch.empty(0, k + 1, device=draft_probs.device, dtype=int) + + # If use Flashinfer chain_speculative_sampling kernel + # for rejection sampling + if self.use_flashinfer: + batch_size, k, _ = draft_probs.shape + uniform_samples = self._create_uniform_samples( + seeded_seqs, batch_size, k, draft_probs.device) + output_token_ids, accepted_token_num, emitted_token_num \ + = chain_speculative_sampling( + draft_probs, draft_token_ids, uniform_samples, + target_with_bonus_probs) + + # num_emitted_tokens returned by flashinfer + # does not include the bonus token + # Flashinfer stops at the first token that violates + # the condition p >= q and does not include recovery/bonus token. + # Therefore, we need to add batch_size here. + self.num_accepted_tokens += accepted_token_num.sum() + self.num_emitted_tokens += emitted_token_num.sum() + batch_size + self.num_draft_tokens += batch_size * k + else: + accepted, recovered_token_ids = ( + self._batch_modified_rejection_sampling( + target_with_bonus_probs[:, :-1], + draft_probs, + draft_token_ids, + seeded_seqs, + )) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) return output_token_ids @@ -135,6 +197,63 @@ def _batch_modified_rejection_sampling( return accepted, recovered_token_ids + def _create_uniform_samples(self, + seeded_seqs: Optional[Dict[int, + torch.Generator]], + batch_size: int, k: int, + device: torch.device) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + for specific sequences. + + This method creates a tensor of shape `(batch_size, k + 1)` filled + with uniform random values in the range [0, 1). If `seeded_seqs` + is provided, the sequences corresponding to specific indices + will be generated using the provided `torch.Generator` for + reproducibility. The other sequences will be generated without + a seed. + + Args: + seeded_seqs : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. If `None`, all samples are + generated without a seed. + batch_size : int + The number of sequences to generate. + k : int + The number of random samples per sequence. + device : torch.device + The device on which to allocate the tensor. + + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(batch_size, k + 1)` containing uniform + random values in the range [0, 1). + """ + if not seeded_seqs: + return torch.rand(batch_size, k + 1, device=device) + + uniform_rand = torch.empty(batch_size, k + 1, device=device) + + non_seeded_indices = [] + for idx in range(batch_size): + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.append(idx) + else: + uniform_rand[idx, :] = torch.rand(1, + k + 1, + dtype=self.probs_dtype, + device=device, + generator=generator) + if non_seeded_indices: + uniform_rand[non_seeded_indices, :] = torch.rand( + len(non_seeded_indices), + k + 1, + dtype=self.probs_dtype, + device=device) + return uniform_rand + def _get_accepted( self, target_probs: torch.Tensor, # [batch_size, k, vocab_size] @@ -175,29 +294,8 @@ def _get_accepted( selected_target_probs = target_probs[batch_indices, probs_indicies, draft_token_ids] - if not seeded_seqs: - uniform_rand = torch.rand_like(selected_target_probs) - else: - uniform_rand = torch.empty_like(selected_target_probs) - - non_seeded_indices = [] - for idx in range(batch_size): - generator = seeded_seqs.get(idx) - if generator is None: - non_seeded_indices.append(idx) - else: - uniform_rand[idx, :] = torch.rand( - 1, - k, - dtype=self.probs_dtype, - device=target_probs.device, - generator=generator) - if non_seeded_indices: - uniform_rand[non_seeded_indices, :] = torch.rand( - len(non_seeded_indices), - k, - dtype=self.probs_dtype, - device=target_probs.device) + uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size, + k - 1, target_probs.device) capped_ratio = torch.minimum( selected_target_probs / selected_draft_probs, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 0562b71aa749..c5a0278e485d 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -503,8 +503,8 @@ def __init__( dtype: torch.dtype, short_factor: List[float], long_factor: List[float], - short_mscale: float = 1.0, - long_mscale: float = 1.0, + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, ): super().__init__() @@ -523,18 +523,22 @@ def __init__( self.base = base self.short_factor = short_factor self.long_factor = long_factor - self.short_mscale = short_mscale - self.long_mscale = long_mscale - - scale = (self.max_position_embeddings / - self.original_max_position_embeddings) + scale = self.max_position_embeddings / \ + self.original_max_position_embeddings if scale <= 1.0: - self.scaling_factor = 1.0 + scaling_factor = 1.0 else: - self.scaling_factor = math.sqrt( + scaling_factor = math.sqrt( 1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale short_cache = self._compute_cos_sin_cache( original_max_position_embeddings, short_factor, short_mscale) @@ -571,8 +575,8 @@ def _compute_cos_sin_cache( inv_freq = self._compute_inv_freq(rescale_factors) t = torch.arange(max_position_embeddings, dtype=torch.float) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * mscale * self.scaling_factor - sin = freqs.sin() * mscale * self.scaling_factor + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale cache = torch.cat((cos, sin), dim=-1) return cache diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 7344d59e988f..c00da106734a 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,13 +1,16 @@ """A layer that samples the next tokens from the model's outputs.""" import itertools import warnings +from dataclasses import dataclass from importlib.util import find_spec from math import inf -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union +import msgspec import torch import torch.nn as nn +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.triton_utils import HAS_TRITON if HAS_TRITON: @@ -19,8 +22,7 @@ SequenceGroupToSample) from vllm.sampling_params import SamplingType from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceOutput) + PromptLogprobs, SampleLogprobs, SequenceOutput) if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -35,6 +37,116 @@ # (num_token_ids, num_parent_ids) per sequence group. SampleResultType = List[Tuple[List[int], List[int]]] +# Types of temporary data structures used for +# computing sample_result +SampleMetadataType = Dict[SamplingType, Tuple[List[int], + List[SequenceGroupToSample]]] +MultinomialSamplesType = Dict[SamplingType, torch.Tensor] +SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] + + +# Encapsulates temporary data structures for computing +# sample_result. +# +# * For multi-step scheduling: must be returned +# by `Sampler.forward()` and used later to compute the pythonized +# sample_result +# +# * For single-step scheduling: consumed immediately +# inside `Sampler.forward()` to compute pythonized sample_result. +@dataclass +class SampleResultArgsType: + sample_metadata: SampleMetadataType + multinomial_samples: MultinomialSamplesType + sample_results_dict: SampleResultsDictType + sampling_metadata: SamplingMetadata + greedy_samples: Optional[torch.Tensor] + beam_search_logprobs: Optional[torch.Tensor] + + +# Union of non-deferred (single-step scheduling) +# vs deferred (multi-step scheduling) +# sample result types +MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] + +# Abbreviation of the _sample() return type +SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] + + +class SamplerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """For each sequence group, we generate a list of SequenceOutput object, + each of which contains one possible candidate for the next token. + + This data structure implements methods, so it can be used like a list, but + also has optional fields for device tensors. + """ + + outputs: List[CompletionSequenceGroupOutput] + + # On-device tensor containing probabilities of each token. + sampled_token_probs: Optional[torch.Tensor] = None + + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + + # Holds either (1) the pythonized sampler result (single-step scheduling) + # or (2) what will be arguments for later deferred pythonization of the + # sampler result (muliti-step scheduling) + deferred_sample_results_args: Optional[SampleResultArgsType] = None + + # On-device tensor containing the sampled token ids. + sampled_token_ids: Optional[torch.Tensor] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None + + # Spec decode metrics populated by workers. + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None + + # Optional last hidden states from the model. + hidden_states: Optional[torch.Tensor] = None + + # Optional prefill hidden states from the model + # (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + + # Time taken in the forward pass for this across all workers + model_forward_time: Optional[float] = None + + # Time taken in the model execute function. This will include model forward, + # block/sync across workers, cpu-gpu sync time and sampling time. + model_execute_time: Optional[float] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -98,6 +210,19 @@ def forward( sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: """ + Single-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Pythonize sampling result & logprobs tensor + + Multi-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Defer Pythonization of sampling result & logprobs + tensor + * Encapsulate arguments required for deferred Pythonization + in the :class:`SamplerOutput` structure + Args: logits: (num_tokens, vocab_size). sampling_metadata: Metadata for sampling. @@ -150,7 +275,7 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results, maybe_sampled_tokens_tensor = _sample( + maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( probs, logprobs, sampling_metadata, @@ -160,20 +285,28 @@ def forward( ) if self.include_gpu_probs_tensor: + # Since we will defer sampler result Pythonization, + # preserve GPU-side tensors in support of later + # deferred pythonization of logprobs assert maybe_sampled_tokens_tensor is not None on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) else: + # Since Pythonization has already happened, don't preserve + # GPU-side tensors. on_device_tensors = None # Get the logprobs query results. prompt_logprobs = None sample_logprobs = None if not sampling_metadata.skip_sampler_cpu_output: - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) + # Pythonize logprobs now (GPU -> CPU); do not defer. + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + prompt_logprobs, sample_logprobs = get_logprobs( + logprobs, sampling_metadata, maybe_deferred_sample_results) return _build_sampler_output( - sample_results, + maybe_deferred_sample_results, sampling_metadata, prompt_logprobs, sample_logprobs, @@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer( return batch_next_token_ids.view(-1, num_samples) +def get_pythonized_sample_results( + sample_result_args: SampleResultArgsType) -> SampleResultType: + '''This function consumes GPU-side sampler results and computes + Pythonized CPU-side sampler results (GPU -> CPU sync.) + + Single-step scheduling: this function is invoked at sampling-time + for immediate Pythonization. + + Multi-step scheduling: Pythonization is deferred until after multiple + GPU-side steps have been completed. + + Args: + sample_result_args: GPU-side inputs to the Pythonization process + + Returns: + Pythonized sampler results + ''' + + ( + sample_metadata, + sampling_metadata, + greedy_samples, + multinomial_samples, + beam_search_logprobs, + sample_results_dict, + ) = ( + sample_result_args.sample_metadata, + sample_result_args.sampling_metadata, + sample_result_args.greedy_samples, + sample_result_args.multinomial_samples, + sample_result_args.beam_search_logprobs, + sample_result_args.sample_results_dict, + ) + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample(seq_groups, + multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + return [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + + def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, @@ -550,7 +737,19 @@ def _sample_with_torch( sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[SampleResultType, Optional[torch.Tensor]]: +) -> SampleReturnType: + '''Torch-oriented _sample() implementation. + + Single-step scheduling: + * Perform GPU-side sampling computation + * Immediately Pythonize sampling result + + Multi-step scheduling: + * Perform GPU-side sampling computation + * Defer Pythonization & preserve GPU-side + tensors required for Pythonization + ''' + categorized_seq_group_ids: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType} @@ -560,10 +759,11 @@ def _sample_with_torch( sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata: Dict[SamplingType, - Tuple[List[int], List[SequenceGroupToSample]]] = {} - multinomial_samples: Dict[SamplingType, torch.Tensor] = {} + sample_results_dict: SampleResultsDictType = {} + sample_metadata: SampleMetadataType = {} + multinomial_samples: MultinomialSamplesType = {} + greedy_samples: Optional[torch.Tensor] = None + beam_search_logprobs: Optional[torch.Tensor] = None # Create output tensor for sampled token ids. if include_gpu_probs_tensor: @@ -638,32 +838,29 @@ def _sample_with_torch( else: raise ValueError(f"Unsupported sampling type: {sampling_type}") - # GPU<->CPU sync happens in the loop below. - # This also converts the sample output to Python objects. + # Encapsulate arguments for computing Pythonized sampler + # results, whether deferred or otherwise. + maybe_deferred_args = SampleResultArgsType( + sampling_metadata=sampling_metadata, + sample_metadata=sample_metadata, + multinomial_samples=multinomial_samples, + greedy_samples=greedy_samples, + beam_search_logprobs=beam_search_logprobs, + sample_results_dict=sample_results_dict) + if not sampling_metadata.skip_sampler_cpu_output: - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, - SamplingType.RANDOM_SEED): - sample_results = _random_sample( - seq_groups, multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] + # GPU<->CPU sync happens here. + # This also converts the sampler output to a Python object. + # Return Pythonized sampler result & sampled token ids + return get_pythonized_sample_results( + maybe_deferred_args), sampled_token_ids_tensor else: - sample_results = [] - - return sample_results, sampled_token_ids_tensor + # Defer sampler result Pythonization; return deferred + # Pythonization args & sampled token ids + return ( + maybe_deferred_args, + sampled_token_ids_tensor, + ) def _sample_with_triton_kernel( @@ -755,7 +952,7 @@ def _sample( sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[SampleResultType, Optional[torch.Tensor]]: +) -> SampleReturnType: """ Args: probs: (num_query_tokens_in_batch, num_vocab) @@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: return result.sum(1).add_(1) -def _get_logprobs( +def get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sample_results: SampleResultType, @@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( - sample_results: SampleResultType, + maybe_deferred_sample_results: MaybeDeferredSampleResultType, sampling_metadata: SamplingMetadata, prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], sample_logprobs: Optional[List[SampleLogprobs]], @@ -1143,14 +1340,21 @@ def _build_sampler_output( speculative decoding rejection sampling. """ sampler_output: List[CompletionSequenceGroupOutput] = [] - if not skip_sampler_cpu_output: + + if skip_sampler_cpu_output: + assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) + deferred_sample_results_args = maybe_deferred_sample_results + else: assert prompt_logprobs is not None assert sample_logprobs is not None + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + deferred_sample_results_args = None for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): + maybe_deferred_sample_results, + prompt_logprobs, sample_logprobs): seq_ids = seq_group.seq_ids next_token_ids, parent_ids = sample_result seq_outputs: List[SequenceOutput] = [] @@ -1176,7 +1380,7 @@ def _build_sampler_output( sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, logprobs=logprobs_tensor, - ) + deferred_sample_results_args=deferred_sample_results_args) def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 467c43c41550..f9532dffa92c 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -130,29 +130,35 @@ def _create_output( def _raise_if_incorrect_input( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: - self._raise_if_incorrect_shape(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_incorrect_dtype(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_inconsistent_device(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + self._raise_if_incorrect_shape(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_incorrect_dtype(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_inconsistent_device(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1], draft_token_ids, bonus_token_ids) def _raise_if_incorrect_shape( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: (target_batch_size, num_target_probs, - target_vocab_size) = target_probs.shape + target_vocab_size) = target_with_bonus_probs.shape + + # Does not count the extra token + num_target_probs -= 1 # validate the shape of draft token ids. draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape @@ -175,12 +181,12 @@ def _raise_if_incorrect_shape( def _raise_if_incorrect_dtype( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: - assert target_probs.dtype == self.probs_dtype + assert target_with_bonus_probs.dtype == self.probs_dtype assert draft_token_ids.dtype == self.token_id_dtype assert bonus_token_ids.dtype == self.token_id_dtype if draft_probs is not None: @@ -188,15 +194,16 @@ def _raise_if_incorrect_dtype( def _raise_if_inconsistent_device( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] - if t is not None + t.device for t in [ + target_with_bonus_probs, bonus_token_ids, draft_probs, + draft_token_ids + ] if t is not None ] assert all([devices[0] == device for device in devices]) @@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler): @abstractmethod def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): @abstractmethod def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index a87ea0eee57d..7428d33ea720 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -41,7 +41,7 @@ def __init__( def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -80,8 +80,9 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. if self._strict_mode: - self._raise_if_incorrect_input(target_probs, draft_token_ids, - bonus_token_ids) + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids) + target_probs = target_with_bonus_probs[:, :-1] accepted = self._evaluate_accepted_tokens(target_probs, draft_token_ids) recovered_token_ids = self._replacement_token_ids(target_probs) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 3ba15573c217..b26a3227e693 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -351,7 +351,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.weight_type = loaded_weight.item() return elif isinstance(param, UninitializedParameter): - param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + shape = list(loaded_weight.shape) + if output_dim is not None: + shape[output_dim] = shape[output_dim] // self.tp_size + param.materialize(tuple(shape), dtype=loaded_weight.dtype) # If parameter does not have output dim, then it should # be copied onto all gpus (e.g. g_idx for act_order gptq). diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 2f6cdbc6ce3e..553fa848489b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -771,7 +771,11 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): return pt_weights_iterator(hf_weights_files) def _get_quantized_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], pre_quant: bool + self, + model_name_or_path: str, + revision: Optional[str], + pre_quant: bool, + load_8bit: bool, ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, @@ -780,11 +784,9 @@ def _get_quantized_weights_iterator( # only load the bitsandbytes module when needed try: import bitsandbytes - from bitsandbytes.functional import QuantState if bitsandbytes.__version__ < "0.42.0": raise ImportError("bitsandbytes version is wrong. Please " "install bitsandbytes>=0.42.0.") - from bitsandbytes.functional import quantize_4bit except ImportError as err: raise ImportError("Please install bitsandbytes>=0.42.0 via " "`pip install bitsandbytes>=0.42.0` to use " @@ -793,80 +795,111 @@ def _get_quantized_weights_iterator( hf_weights_files, use_safetensors = self._prepare_weights( model_name_or_path, revision) - quant_state_dict = {} - - def quantized_checkpoint() -> Generator: - # First iterate over all quant state weights - weight_iterator = self._hf_weight_iter(hf_weights_files, - use_safetensors) - temp_state_dict = {} - for weight_name, weight_tensor in weight_iterator: - if weight_name.endswith(".weight"): - continue - # TODO: only nf4 quantization is supported for now - if weight_name.endswith(".quant_state.bitsandbytes__fp4"): - raise NotImplementedError( - "Only bitsandbytes_nf4 quantization" - f"is supported for now. {weight_name} is fp4 quantized" - ) - temp_state_dict[weight_name] = weight_tensor + quant_state_dict: Dict[str, Any] = {} - # Closure to parse quant_state for each prequant weight - def _parse_quant_state(param_name: str, - temp_state_dict: Dict) -> QuantState: - quant_state = {} - for k in temp_state_dict: - if param_name + "." in k: - quant_state[k] = temp_state_dict[k] - # bitsandbytes library requires - # weight.quant_state.bitsandbytes__nf4 in CPU - quant_state[param_name + - ".quant_state.bitsandbytes__nf4"] = quant_state[ - param_name + - ".quant_state.bitsandbytes__nf4"].cpu().data - return QuantState.from_dict(quant_state, device="cuda") - - # Second iterate over all prequant and normal weights - # pre quantized weights would have a quant_state - for weight_name, weight_tensor in self._hf_weight_iter( - hf_weights_files, use_safetensors): - # Filter out all weights whose suffix is not ".weight" - if not weight_name.endswith(".weight"): - continue - if weight_name + ".quant_state.bitsandbytes__nf4" \ - in temp_state_dict: - quant_state = _parse_quant_state(weight_name, - temp_state_dict) - weight_name = weight_name.replace(".weight", ".qweight") - quant_state_dict[weight_name] = quant_state - yield weight_name.replace(".weight", - ".qweight"), weight_tensor - else: - yield weight_name, weight_tensor - - def generator() -> Generator: - for weight_name, weight_tensor in self._hf_weight_iter( - hf_weights_files, use_safetensors): - if any(target_module in weight_name - for target_module in self.target_modules): - weight_name = weight_name.replace(".weight", ".qweight") - # bitsandbytes requires data in GPU - loaded_weight = weight_tensor.cuda().data - with set_default_torch_dtype(torch.float32): - processed_weight, quant_state = quantize_4bit( - loaded_weight, - compress_statistics=True, - quant_type="nf4") - - quant_state_dict[weight_name] = quant_state - else: - processed_weight = weight_tensor + if pre_quant: + if load_8bit: + return self._quantized_8bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + else: + return self._quantized_4bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict - yield weight_name, processed_weight + return self._unquantized_generator(hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict - if pre_quant: - return quantized_checkpoint(), quant_state_dict - return generator(), quant_state_dict + def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if not weight_name.lower().endswith(".scb"): + continue + + weight_key = weight_name.lower().replace(".scb", ".qweight") + quant_state_dict[weight_key] = weight_tensor + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + + if not weight_name.endswith(".weight"): + continue + + qweight_name = weight_name.replace(".weight", ".qweight") + if qweight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield qweight_name, weight_tensor + else: + yield weight_name, weight_tensor + + def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith(".weight"): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in weight_name: + temp_state_dict[weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + # Filter out all weights whose suffix is not ".weight" + if not weight_name.endswith(".weight"): + continue + if (f"{weight_name}.quant_state.bitsandbytes__nf4" \ + in temp_state_dict) or \ + (f"{weight_name}.quant_state.bitsandbytes__fp4" \ + in temp_state_dict): + quant_state = _parse_quant_state(weight_name, temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def _unquantized_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import quantize_4bit + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if any(target_module in weight_name + for target_module in self.target_modules): + weight_name = weight_name.replace(".weight", ".qweight") + # bitsandbytes requires data in GPU + loaded_weight = weight_tensor.cuda().data + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None: @@ -883,16 +916,26 @@ def _load_weights(self, model_config: ModelConfig, logger.info("Loading weights with BitsAndBytes quantization. " " May take a while ...") - is_quantized_checkpoint = False quant_config = getattr(model_config.hf_config, "quantization_config", None) - if quant_config is not None and quant_config.get( - 'quant_method') == "bitsandbytes": - is_quantized_checkpoint = True + + pre_quant = False + if quant_config is not None: + quant_method = quant_config.get('quant_method') + if quant_method == "bitsandbytes": + pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization") + + load_8bit = False + if pre_quant: + load_8bit = quant_config.get('load_in_8bit', False) qweight_iterator, quant_state_dict = \ self._get_quantized_weights_iterator( - model_config.model, model_config.revision, is_quantized_checkpoint) + model_config.model, model_config.revision, pre_quant, load_8bit) model.load_weights(qweight_iterator) @@ -942,6 +985,10 @@ def _load_weights(self, model_config: ModelConfig, offsets = np.concatenate(([0], np.cumsum(num_elements))) set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + if load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)}) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 07e23aca6cc5..7396ac833e78 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,7 +1,7 @@ """Utilities for selecting and loading neuron models.""" import importlib import os -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -10,9 +10,8 @@ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput TORCH_DTYPE_TO_NEURON_AMP = { "auto": "f32", @@ -109,6 +108,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str: f"{list(_NEURON_SUPPORTED_MODELS.keys())}") +def _get_buckets(env: str, default_value: List[int]) -> List[int]: + env_value = os.getenv(env) + if env_value is None: + return default_value + buckets_remove_empty = filter( + lambda x: x is not None and len(x.strip()) > 0, env_value.split(",")) + buckets_int = map(int, buckets_remove_empty) + buckets_list = list(buckets_int) + return buckets_list + + def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: @@ -123,14 +133,18 @@ def get_neuron_model(model_config: ModelConfig, neuron_config = NeuronConfig( continuous_batching=continuous_batching_config) + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + # Load the weights from the cached or downloaded files. - model.load_weights( - model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=[scheduler_config.max_model_len], - n_positions=[scheduler_config.max_model_len], - batch_size=scheduler_config.max_num_seqs) + model.load_weights(model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) return model.eval() diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index 5c522a61732a..3c1f6fa76989 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -15,9 +15,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import (LogitsProcessor, _prune_hidden_states) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput logger = init_logger(__name__) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 8591c276b001..e30370596496 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -22,6 +22,7 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), @@ -49,6 +50,7 @@ "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), + "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), @@ -63,6 +65,7 @@ "EAGLEModel": ("eagle", "EAGLE"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM") } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 28f69cfbc46b..efa044d0b5e9 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig, DeepSpeedFPParameter) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.arctic import ArcticConfig logger = init_logger(__name__) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 73711d8eb518..bdd76b11384c 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index f78400b0df7b..9b4c4be7fcb0 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors logger = logging.get_logger(__name__) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 830680fd990b..e6acf8cd5d5b 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -7,12 +7,14 @@ import torch.nn as nn from PIL import Image from transformers import Blip2VisionConfig, BlipVisionConfig -from transformers.models.blip.modeling_blip import BlipAttention +from xformers import ops as xops from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal.utils import (cached_get_tokenizer, @@ -154,6 +156,77 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings +class BlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + self.projection = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.projection(out) + + return attn_output + + class BlipMLP(nn.Module): def __init__(self, @@ -188,7 +261,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.self_attn = BlipAttention(config) + self.self_attn = BlipAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = BlipMLP(config, quant_config=quant_config) @@ -199,7 +272,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8be786fd3f6f..39f2b2d853a6 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -13,13 +13,13 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) @@ -714,8 +714,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): use_default_weight_loading = False if "vision" in name: if self.vision_model is not None: - # We only do sharding for language model and - # not vision model for now. + # BlipVisionModel does not need sharding use_default_weight_loading = True else: for (param_name, weight_name, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 07ee0e3c531d..831b3f20457a 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index b25f5d521a9b..47e020e8ecb7 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -33,7 +33,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 4949d0232fab..35f1ed5ef5d3 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -20,12 +20,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 69bb9f6f3afe..ddfec91d6cab 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -7,12 +7,14 @@ import torch.nn as nn from PIL import Image from transformers import CLIPVisionConfig -from transformers.models.clip.modeling_clip import CLIPAttention +from xformers import ops as xops from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -160,6 +162,78 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.out_proj(out) + + return attn_output + + class CLIPMLP(nn.Module): def __init__(self, @@ -192,7 +266,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.self_attn = CLIPAttention(config) + self.self_attn = CLIPAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config, quant_config=quant_config) @@ -204,7 +278,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states @@ -304,7 +378,15 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None): def device(self): return next(self.parameters()).device + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] params_dict = dict(self.named_parameters()) layer_count = len(self.vision_model.encoder.layers) @@ -318,7 +400,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if layer_idx >= layer_count: continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index f63cf246e510..be7f19d15b62 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -38,14 +38,14 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, row_parallel_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors @torch.compile diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index dca959798e8b..6160197dc19d 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -17,13 +17,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dbrx import DbrxConfig diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 7a27e1388e98..61cc917ab620 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -43,12 +43,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class DeepseekMLP(nn.Module): diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index c7f3af0ccb26..8cbd9435ec7c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -43,12 +43,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 99c825ff6357..ad1ab0231d86 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -5,12 +5,13 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import ModelRegistry from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.eagle import EAGLEConfig diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py new file mode 100644 index 000000000000..4a1c367de3f6 --- /dev/null +++ b/vllm/model_executor/models/exaone.py @@ -0,0 +1,617 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py +# Copyright 2024 The LG U+ CTO AI Tech Lab. +# Copyright 2021 The LG AI Research EXAONE Lab +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.exaone import ExaoneConfig +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +class ExaoneGatedMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.c_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class ExaoneAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class ExaoneBlockAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.attention = ExaoneAttention( + config=config, + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=bias, + cache_config=cache_config, + prefix=prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + return self.attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + +class ExaoneDecoderLayer(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.attn = ExaoneBlockAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.mlp = ExaoneGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.activation_function, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class ExaoneModel(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.wte = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.wte = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.wte = PPMissingLayer() + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: ExaoneDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.h", + ) + if get_pp_group().is_last_rank: + self.ln_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + else: + self.ln_f = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +class ExaoneForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "c_fc_0", + "c_fc_1", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "out_proj", + "gate_up_proj", + "c_proj", + "wte", + "lm_head", + ] + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "c_fc_0": ("gate_up_proj", 0), + "c_fc_1": ("gate_up_proj", 1), + } + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.transformer = ExaoneModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + "residual": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".c_fc_0", 0), + (".gate_up_proj", ".c_fc_1", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.transformer.h[layer_idx], nn.Identity): + layer_self_attn = self.transformer.h[layer_idx].attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 7b97b3d255df..b474d35baf89 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -39,12 +39,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import RWConfig FalconConfig = Union[HF_FalconConfig, RWConfig] diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 6cdf331fed8b..beeae1422957 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -31,6 +31,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -39,7 +40,7 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index e1041edf81b0..36fd38983128 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 5e0f8b70d4b8..90449ec51ef0 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index bfc231282952..fb5a297661dd 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b93fb8d69b2d..fe5ec1082760 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 4d52b448049b..664d775c8ba4 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class GPTJAttention(nn.Module): diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 2adecf7fa9ef..5f6f1e388054 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class GPTNeoXAttention(nn.Module): diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py new file mode 100644 index 000000000000..b0325e8b616c --- /dev/null +++ b/vllm/model_executor/models/granite.py @@ -0,0 +1,543 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM Granite model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.granite import GraniteConfig +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +class GraniteMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GraniteAttention(nn.Module): + + def __init__( + self, + config: GraniteConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = GraniteAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = GraniteMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + return hidden_states + + +class GraniteModel(nn.Module): + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + hidden_states *= self.config.embedding_multiplier + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = GraniteModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + logits /= self.config.logits_scaling + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 54c933e3e495..ad5919150cad 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -10,10 +10,13 @@ import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig +from xformers import ops as xops +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -81,7 +84,11 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: class InternAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PretrainedConfig): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -94,9 +101,13 @@ def __init__(self, config: PretrainedConfig): f' {self.num_heads}).') self.scale = self.head_dim**-0.5 - self.qkv = nn.Linear(self.embed_dim, - 3 * self.embed_dim, - bias=config.qkv_bias) + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) self.qk_normalization = config.qk_normalization @@ -104,25 +115,40 @@ def __init__(self, config: PretrainedConfig): self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) - self.proj = nn.Linear(self.embed_dim, self.embed_dim) + self.proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - - if self.qk_normalization: - B_, H_, N_, D_ = q.shape - q = self.q_norm.forward_native(q.transpose(1, 2).flatten( - -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) - k = self.k_norm.forward_native(k.transpose(1, 2).flatten( - -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, C) + q = q.view(B, N, self.num_heads_per_partition, self.head_dim) + k = k.view(B, N, self.num_heads_per_partition, self.head_dim) + v = v.view(B, N, self.num_heads_per_partition, self.head_dim) - x = self.proj(x) + if self.qk_normalization: + B_, N_, H_, D_ = q.shape + q = self.q_norm.forward_native(q.flatten(-2, + -1)).view(B_, N_, H_, D_) + k = self.k_norm.forward_native(k.flatten(-2, + -1)).view(B_, N_, H_, D_) + + x = xops.memory_efficient_attention_forward( + q, + k, + v, + scale=self.scale, + ) + x = x.view(B, N, -1) + + x, _ = self.proj(x) return x @@ -161,7 +187,7 @@ def __init__(self, self.intermediate_size = config.intermediate_size self.norm_type = config.norm_type - self.attn = InternAttention(config) + self.attn = InternAttention(config, quant_config=quant_config) self.mlp = InternMLP(config, quant_config=quant_config) self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 3482d941cb89..14c91bae5cce 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from functools import partial from typing import Any, Dict, Iterable, List, Optional, Tuple import torch @@ -7,7 +8,10 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import (get_tensor_model_parallel_rank, get_pp_group, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -17,12 +21,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -73,20 +77,21 @@ def __init__( ) -> None: super().__init__() self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size + assert self.total_num_heads % self.tp_size == 0 + self.num_heads = self.total_num_heads // self.tp_size self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: + if self.total_num_kv_heads >= self.tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 + assert self.total_num_kv_heads % self.tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -125,11 +130,27 @@ def __init__( quant_config=quant_config) def split_qkv(self, qkv: torch.Tensor): - qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128) - q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2) - q = q.reshape(-1, self.q_size) - k = k.reshape(-1, self.kv_size) - v = v.reshape(-1, self.kv_size) + seq_len = qkv.shape[0] + if self.tp_size > 1: + qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size + qkv = tensor_model_parallel_all_gather(qkv) + qkv = torch.split(qkv, qkv_map, dim=-1) + qkv = qkv[::3] + qkv[1::3] + qkv[2::3] + qkv = torch.cat(qkv, dim=-1) + + qkv = qkv.view(seq_len, self.total_num_kv_heads, + self.key_value_groups + 2, self.head_dim) + q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2) + q = q.reshape(seq_len, self.q_size * self.tp_size) + k = k.reshape(seq_len, self.kv_size * self.tp_size) + v = v.reshape(seq_len, self.kv_size * self.tp_size) + + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] return q, k, v def forward( @@ -216,7 +237,6 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", ) -> None: super().__init__() self.config = config @@ -226,11 +246,10 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: InternLMDecoderLayer(config, cache_config, - quant_config), - prefix=f"{prefix}.layers") + self.layers = nn.ModuleList([ + InternLMDecoderLayer(config, cache_config, quant_config) + for _ in range(config.num_hidden_layers) + ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( @@ -258,7 +277,7 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): + for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( positions, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 2fc239f0698a..0e4569428aa7 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -18,13 +18,14 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index a550f7e6c97a..b0fbb7e9829e 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -35,12 +35,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import JAISConfig from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index caeda4e42d8a..73be7ffed0f8 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -4,9 +4,6 @@ from typing import Dict, Iterable, List, Optional, Tuple import torch -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn from torch.nn.parameter import Parameter from transformers import JambaConfig @@ -24,16 +21,20 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) @@ -161,7 +162,7 @@ def mamba_forward(self, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_state.copy_(conv_states) - hidden_states = causal_conv1d_fn( + hidden_states, _ = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0c67a9b8e198..e55c01316087 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -42,13 +42,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_hip from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 490c93294d50..43c485bdf366 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -11,10 +11,11 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_max_clip_image_tokens, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 048ca16974e3..5a179e960371 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -15,10 +15,11 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 55d42952cd0c..619a5cd00d6b 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -4,11 +4,11 @@ import torch.nn as nn from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.medusa import MedusaConfig diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index ff42bdefe026..a135118bc748 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -44,13 +44,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 6a3d5422e0ce..dd10729b9ffb 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -44,7 +44,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -57,7 +57,7 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .idefics2_vision_model import Idefics2VisionTransformer diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 413783ba4b25..e744e36ac08b 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -39,13 +39,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 8bdd52b34317..68471f6ac77d 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -42,12 +42,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class MixtralMLP(nn.Module): diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 9b96ecb78a3c..42ccd0129816 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -6,11 +6,10 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import MLPSpeculatorConfig SQRT2 = 2**0.5 diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 1a8e514a7ae8..0fcbf06e1a06 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -17,12 +17,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.mpt import MPTConfig diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 7d92a1ffe55d..e9ff12de2094 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -37,13 +37,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 8de124cd034d..97749725dd13 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OlmoAttention(nn.Module): diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c0d2d537e731..88d2bcb9f0c9 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OPTLearnedPositionalEmbedding(nn.Embedding): diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index fab35f0b882a..b01ce87adfa4 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -21,12 +21,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OrionMLP(nn.Module): diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 46ee4c3208b7..9b29ff69808a 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -11,13 +11,13 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.gemma import GemmaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, @@ -145,7 +145,6 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # TODO(ywang96): Port over SiglipVisionModel & TP self.vision_tower = SiglipVisionModel(config.vision_config) self.multi_modal_projector = PaliGemmaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, @@ -308,34 +307,27 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if key_to_modify in name: name = name.replace(key_to_modify, new_key) use_default_weight_loading = False - if "vision" in name: - if self.vision_tower is not None: - # We only do sharding for language model and - # not vision model for now. - use_default_weight_loading = True + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break else: - for (param_name, shard_name, - shard_id) in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # lm_head is not used in vllm as it is tied with - # embed_token. To prevent errors, skip loading - # lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - use_default_weight_loading = True + # lm_head is not used in vllm as it is tied with + # embed_token. To prevent errors, skip loading + # lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + use_default_weight_loading = True if use_default_weight_loading: param = params_dict[name] diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 3300939c7b10..f8fc1cd8ef1f 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -37,12 +37,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class PersimmonMLP(nn.Module): diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index f31b5162aac9..15c21cfa2d8a 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -52,12 +52,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index df01bfa3d8e6..afc6fe9844ad 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -16,12 +16,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors def load_column_parallel_weight(param: torch.nn.Parameter, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index bec1d3538850..c449e0fc759a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -31,7 +31,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.clip import CLIPVisionModel @@ -39,7 +39,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip @@ -71,6 +71,23 @@ projection_dim=768) +def _init_img_processor(hf_config: PretrainedConfig): + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + layer_idx = hf_config.img_processor.get('layer_idx', -2) + + # Initialize the CLIP only up to the required feature layer + if layer_idx < 0: + num_hidden_layers = clip_config.num_hidden_layers + \ + layer_idx + 1 + else: + num_hidden_layers = layer_idx + 1 + + img_processor = CLIPVisionModel( + clip_config, num_hidden_layers_override=num_hidden_layers) + + return img_processor + + class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: Union[torch.Tensor, List[torch.Tensor]] @@ -139,18 +156,8 @@ def __init__(self, config: PretrainedConfig) -> None: hidden_size = config.n_embd if hasattr( config, 'n_embd') else config.hidden_size - clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG - self.layer_idx = config.img_processor.get('layer_idx', -2) - - # Initialize the CLIP only up to the required feature layer - if self.layer_idx < 0: - num_hidden_layers = clip_config.num_hidden_layers + \ - self.layer_idx + 1 - else: - num_hidden_layers = self.layer_idx + 1 + self.img_processor = _init_img_processor(config) - self.img_processor = CLIPVisionModel( - clip_config, num_hidden_layers_override=num_hidden_layers) image_dim_out = config.img_processor['image_dim_out'] self.num_img_tokens = config.img_processor['num_img_tokens'] @@ -656,23 +663,27 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + + # TODO(ChristopherCho): This is a temporary fix to load + # the vision weights with CLIPVisionModel.load_weights() + vision_weights = [] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: + # Skip loading the img_processor weights since they are + # loaded separately. + if "vision_embed_tokens.img_processor" in name: + vision_weights.append((name, loaded_weight)) continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in name: name = name.replace(key_to_modify, new_key) for (param_name, weight_name, shard_id) in stacked_params_mapping: - # We only do sharding for language model - # and not vision model for now. - if "vision_embed_tokens" in name and self.vision_embed_tokens: - continue if weight_name not in name: continue + param = params_dict[name.replace(weight_name, param_name)] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -686,3 +697,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + # We use regex to extract the sub-module name + # from "model.vision_embed_tokens.img_processor.*" + vision_weights = [ + (re.search(r"vision_embed_tokens\.img_processor\.(.*)", + n).group(1), w) for n, w in vision_weights + ] + self.vision_embed_tokens.img_processor.load_weights(vision_weights) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py new file mode 100644 index 000000000000..25bc0590c745 --- /dev/null +++ b/vllm/model_executor/models/phimoe.py @@ -0,0 +1,620 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only PhiMoE model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA + + +class PhiMoEConfig(PretrainedConfig): + + model_type = "phimoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + attention_bias=False, + lm_head_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class mp(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + scores: torch.Tensor, + multiplier: torch.Tensor, + selected_experts: torch.Tensor, + masked_gates: torch.Tensor, + mask_for_one: torch.Tensor, + ): + ctx.save_for_backward(multiplier, selected_experts, masked_gates) + return multiplier * mask_for_one + + @staticmethod + def backward( + ctx, + grad_at_output: torch.Tensor, + ): + multiplier, selected_experts, masked_gates = ctx.saved_tensors + + grad_at_output = grad_at_output * multiplier + + grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1) + grad_at_scores_expaned.scatter_add_( + dim=-1, + index=selected_experts, + src=grad_at_output, + ) + + return ( + grad_at_scores_expaned, + None, + None, + None, + None, + ) + + +def sparsemixer(scores, jitter_eps=0.01): + ################ first expert ################ + + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ( + (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # apply mask + masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) + selected_experts = max_ind + + # compute scores for gradients + masked_gates = torch.softmax(masked_gates, dim=-1) + multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) + + multiplier = multiplier_o + + # masked out first expert + masked_scores = torch.scatter( + scores, + -1, + selected_experts, + float("-inf"), + ) + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, + keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ( + (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # apply mask + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, + float("-inf")) + selected_experts_top2 = max_ind + # compute scores for gradients + masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) + multiplier_top2 = masked_gates_top2.gather(dim=-1, + index=selected_experts_top2) + + multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), + dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +def phimoe_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert topk == 2, "Only top-2 routing is supported" + assert renormalize is False, "Renormalization is not supported" + + topk_weights, topk_ids = sparsemixer(gating_output) + return topk_weights, topk_ids + + +class PhiMoE(nn.Module): + """A tensor-parallel MoE implementation for PhiMoE that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + custom_routing_function=phimoe_routing_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class PhiMoEAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[dict] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=None, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=True, + quant_config=None, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class PhiMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = PhiMoEAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=config.rope_scaling, + ) + self.block_sparse_moe = PhiMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + + # Self Attention + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states, residual + + +class PhiMoEModel(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], attn_metadata, + residual) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class PhiMoEForCausalLM(nn.Module, SupportsLoRA): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = PhiMoEModel(config, + cache_config, + quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size), + quant_config=None, + bias=True, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index b7d017d5f3ea..8298e3bac446 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -22,12 +22,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b95987c16ebc..a64e08c422bc 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -40,13 +40,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6f838947fbf2..56129515ca8d 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -45,12 +45,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 073f60bb3a05..e6f95af0ff49 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -9,12 +9,10 @@ from PIL import Image from torch import nn from transformers import SiglipVisionConfig -from transformers.models.siglip.modeling_siglip import SiglipAttention -from vllm_flash_attn import flash_attn_func -from xformers.ops import memory_efficient_attention +from xformers import ops as xops from vllm.config import ModelConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -221,9 +219,7 @@ def forward(self, return embeddings -# NOTE: Not used - kept for later when we TP the ViT -# TODO(ChristopherCho): Implement TP version of Attention -class SiglipTPAttention(nn.Module): +class SiglipAttention(nn.Module): def __init__( self, @@ -233,38 +229,30 @@ def __init__( super().__init__() self.config = config self.embed_dim = config.hidden_size - - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - if self.total_num_heads % tp_size != 0: - raise ValueError( - f"Number of attention heads ({self.total_num_heads}) " - "must be divisible by the tensor model parallel size" - f" ({tp_size}).") - - self.num_heads = self.total_num_heads // tp_size - self.head_dim = self.embed_dim // self.total_num_heads - if self.head_dim * self.total_num_heads != self.embed_dim: + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: raise ValueError(f"embed_dim must be divisible by num_heads (got " "`embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") - self.qkv_size = self.num_heads * self.head_dim + self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, head_size=self.head_dim, - total_num_heads=self.total_num_heads, + total_num_heads=self.num_heads, quant_config=quant_config, ) + self.out_proj = RowParallelLinear( input_size=self.embed_dim, output_size=self.embed_dim, quant_config=quant_config, ) - self.attn_fn = self._basic_attention_forward + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) def forward( self, @@ -274,163 +262,29 @@ def forward( batch_size, q_len, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv_states.split( - [self.qkv_size] * 3, dim=-1) - - attn_output = self.attn_fn( - q=query_states, - k=key_states, - v=value_states, - batch_size=batch_size, - q_len=q_len, - ) - - attn_output, _ = self.out_proj(attn_output) - return attn_output - - def _basic_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - k = k.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - v = v.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - k_v_seq_len = k.shape[-2] - attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale - - if attn_weights.size() != ( - batch_size, - self.num_heads, - q_len, - k_v_seq_len, - ): - raise ValueError( - "Attention weights should be of size " - f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}") - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, - dtype=torch.float32).to(q.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, v) - - if attn_output.size() != ( - batch_size, - self.num_heads, - q_len, - self.head_dim, - ): - raise ValueError( - "`attn_output` should be of size " - f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -# TODO(ChristopherCho): flash_attn_func is not working properly. -# It constantly throws a CUDA error. -class SiglipFlashAttention2(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attn_fn = self._flash_attention_forward - - # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449 - # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133 - def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args, - **kwargs): - """Implements the multihead softmax attention. - Arguments - --------- - q, k, v: The tensor containing the - query, key, and value. (B, S, H, D) - """ - - q = q.view(batch_size, q_len, self.num_heads, self.head_dim) - k = k.view(batch_size, q_len, self.num_heads, self.head_dim) - v = v.view(batch_size, q_len, self.num_heads, self.head_dim) - - attn_output = flash_attn_func( - q, - k, - v, - dropout_p=self.dropout, - causal=False, - ) - - attn_output = attn_output.reshape(batch_size, q_len, - self.embed_dim).contiguous() + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) return attn_output -# NOTE: Not used - kept for later when we TP the ViT -class SiglipSdpaAttention(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_causal = False - self.attn_fn = self._sdpa_attention_forward - - def _sdpa_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - k = k.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - v = v.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, q_len, self.embed_dim) - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -class SiglipxFormersAttention(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attn_fn = self._xformers_attention_forward - - def _xformers_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, self.head_dim) - k = k.view(batch_size, q_len, self.num_heads, self.head_dim) - v = v.view(batch_size, q_len, self.num_heads, self.head_dim) - - attn_output = memory_efficient_attention(q, - k, - v, - p=0.0, - scale=self.scale) - attn_output = attn_output.reshape(batch_size, q_len, - self.embed_dim).contiguous() - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -SIGLIP_ATTENTION_CLASSES = { - "eager": SiglipTPAttention, - "flash_attention_2": SiglipFlashAttention2, - "sdpa": SiglipSdpaAttention, - "xformers": SiglipxFormersAttention, -} - - class SiglipMLP(nn.Module): def __init__( @@ -473,8 +327,7 @@ def __init__( super().__init__() self.embed_dim = config.hidden_size - # TODO(ChristopherCho): use TP'ed Attention block - self.self_attn = SiglipAttention(config) + self.self_attn = SiglipAttention(config, quant_config=quant_config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -491,7 +344,7 @@ def forward( residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index decbf89d27c7..6236426dcd4e 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -36,12 +36,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class StablelmMLP(nn.Module): diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index d1b1d210b727..d3a3a83c8437 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -35,12 +35,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class Starcoder2Attention(nn.Module): diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 03d622322551..7994945c5ac3 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -8,7 +8,6 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union, cast) -import librosa import numpy as np import torch import torch.utils.checkpoint @@ -27,6 +26,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.utils import (filter_weights, @@ -37,7 +37,7 @@ from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.transformers_utils.configs.ultravox import UltravoxConfig _AUDIO_PLACEHOLDER_TOKEN = 128002 @@ -106,6 +106,11 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): feature_extractor = whisper_feature_extractor(ctx) if sr != feature_extractor.sampling_rate: + try: + import librosa + except ImportError: + raise ImportError( + "Please install vllm[audio] for audio support.") from None audio = librosa.resample(audio, orig_sr=sr, target_sr=feature_extractor.sampling_rate) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index aa28600875ea..af782a5b3d88 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,6 @@ from typing import (Callable, Dict, Iterable, List, Literal, Optional, Protocol, Tuple, Union, overload) -import numpy as np import torch import torch.nn as nn from torch.func import functional_call @@ -96,12 +95,13 @@ def flatten_bn( def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: """ - Recursively concatenates NestedTensors along any heterogeneously sized - dimensions. + Recursively flattens and concatenates NestedTensors on all but the last + dimension. """ if isinstance(embeddings, torch.Tensor): - return embeddings + # Flatten all but the last dimension. + return embeddings.flatten(0, -2) return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) @@ -136,15 +136,13 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, assert isinstance(num_expected_tokens, int) flattened = _flatten_embeddings(multimodal_embeddings) - *dims, embed_dim = flattened.shape - num_multimodal_embeddings = np.prod(dims) - if num_multimodal_embeddings != num_expected_tokens: + if flattened.shape[0] != num_expected_tokens: expr = _embedding_count_expression(multimodal_embeddings) raise ValueError( - f"Attempted to assign {expr} = {num_multimodal_embeddings} " + f"Attempted to assign {expr} = {flattened.shape[0]} " f"multimodal tokens to {num_expected_tokens} placeholders") - inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim) + inputs_embeds[mask] = flattened return inputs_embeds diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index c0bafa9367e4..24cc3728f85e 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index c02e61596927..17ef9938d057 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -54,8 +54,8 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: return nested_tensors stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] - if is_list_of(stacked, list): - # Do not stack nested lists + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. return stacked tensors_ = cast(List[torch.Tensor], stacked) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 989b2e1a814c..4bed267e9963 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,11 +1,9 @@ import base64 from functools import lru_cache from io import BytesIO -from typing import List, Optional, Tuple, TypeVar, Union +from typing import Any, List, Optional, Tuple, TypeVar, Union -import librosa import numpy as np -import soundfile from PIL import Image from vllm.connections import global_http_connection @@ -73,10 +71,22 @@ async def async_fetch_image(image_url: str, return image.convert(image_mode) +def try_import_audio_packages() -> Tuple[Any, Any]: + try: + import librosa + import soundfile + except ImportError: + raise ImportError( + "Please install vllm[audio] for audio support.") from None + return librosa, soundfile + + def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: """ Load audio from a URL. """ + librosa, _ = try_import_audio_packages() + if audio_url.startswith("http"): audio_bytes = global_http_connection.get_bytes( audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) @@ -95,6 +105,8 @@ async def async_fetch_audio( """ Asynchronously fetch audio from a URL. """ + librosa, _ = try_import_audio_packages() + if audio_url.startswith("http"): audio_bytes = await global_http_connection.async_get_bytes( audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) @@ -123,6 +135,8 @@ def encode_audio_base64( sampling_rate: int, ) -> str: """Encode audio as base64.""" + _, soundfile = try_import_audio_packages() + buffered = BytesIO() soundfile.write(buffered, audio, sampling_rate, format="WAV") diff --git a/vllm/scripts.py b/vllm/scripts.py index a9ddfcf86413..e557961a335b 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -125,6 +125,15 @@ def main(): serve_parser.add_argument("model_tag", type=str, help="The model tag to serve") + serve_parser.add_argument( + "--config", + type=str, + default='', + required=False, + help="Read CLI options from a config file." + "Must be a YAML with the following options:" + "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server" + ) serve_parser = make_arg_parser(serve_parser) serve_parser.set_defaults(dispatch_function=serve) diff --git a/vllm/sequence.py b/vllm/sequence.py index 3125acc6fd53..87b3d21fa7ae 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1060,76 +1060,6 @@ def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" -class SamplerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """For each sequence group, we generate a list of SequenceOutput object, - each of which contains one possible candidate for the next token. - - This data structure implements methods, so it can be used like a list, but - also has optional fields for device tensors. - """ - - outputs: List[CompletionSequenceGroupOutput] - - # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional[torch.Tensor] = None - - # On-device tensor containing the logprobs of each token. - logprobs: Optional["torch.Tensor"] = None - - # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional[torch.Tensor] = None - # CPU tensor containing the sampled token ids. Used during multi-step to - # return the sampled token ids from last rank to AsyncLLMEngine to be - # 'broadcasted' to all other PP ranks for next step. - sampled_token_ids_cpu: Optional[torch.Tensor] = None - - # Spec decode metrics populated by workers. - spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None - - # Optional last hidden states from the model. - hidden_states: Optional[torch.Tensor] = None - - # Optional prefill hidden states from the model - # (used for models like EAGLE). - prefill_hidden_states: Optional[torch.Tensor] = None - - # Time taken in the forward pass for this across all workers - model_forward_time: Optional[float] = None - - # Time taken in the model execute function. This will include model forward, - # block/sync across workers, cpu-gpu sync time and sampling time. - model_execute_time: Optional[float] = None - - def __getitem__(self, idx: int): - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - def __repr__(self) -> str: - """Show the shape of a tensor instead of its values to reduce noise. - """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None - else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else - self.sampled_token_ids.shape) - return ( - f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr}, " - f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") - - class PoolerOutput( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -1295,6 +1225,7 @@ class ExecuteModelRequest( last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback async_callback: Optional[Callable] = None + use_async_and_multi_step: bool = False @property def is_first_multi_step(self) -> bool: @@ -1341,4 +1272,5 @@ def clone( finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) + async_callback=self.async_callback, + use_async_and_multi_step=self.use_async_and_multi_step) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 8a691d65aaa0..b2204e8b27af 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -5,8 +5,9 @@ import torch from vllm import SamplingParams +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, - SamplerOutput, SequenceData, SequenceGroupMetadata, + SequenceData, SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index aedf0a83da07..6e35e4029438 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -3,6 +3,7 @@ import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.sampler import SamplerOutput try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -16,8 +17,7 @@ PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.multimodal import MultiModalInputs -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SamplerOutput) +from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index d1809e49c2a8..0d233f393cb8 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -4,8 +4,8 @@ import torch from vllm.model_executor import SamplingMetadata -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py index 76e444387816..fc41bb82ea34 100644 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -3,8 +3,8 @@ import torch from vllm.model_executor import SamplingMetadata -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 2dfbacfb7b75..4b53fbe056c4 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -4,8 +4,9 @@ import torch -from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput, - SequenceData, SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, + SequenceGroupMetadata) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 806480b5c892..36e5e1774aa0 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -3,7 +3,8 @@ import torch -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index efb8ee25ba2f..28a537593f26 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposer from vllm.worker.worker_base import LoraNotSupportedWorkerBase diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py index 215ede52fb81..8896b7dbc6b8 100644 --- a/vllm/spec_decode/smaller_tp_proposer_worker.py +++ b/vllm/spec_decode/smaller_tp_proposer_worker.py @@ -6,7 +6,8 @@ init_model_parallel_group, patch_tensor_parallel_group) from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9b1f21fcb492..91f0a98c7bc3 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -8,12 +8,13 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, - HiddenStates, SamplerOutput, SequenceGroupMetadata, + HiddenStates, SequenceGroupMetadata, get_all_seq_ids, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner @@ -624,8 +625,8 @@ def _verify_tokens( seq_group_metadata_list, proposal_lens_list) original_indices = spec_indices + non_spec_indices - # Get probabilities of target model, excluding bonus token. - proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1] + # Get probabilities of target model, including bonus tokens. + proposal_verifier_probs = proposal_scores.probs[spec_indices] # Get non-speculative sampled tokens from target model. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] @@ -650,13 +651,12 @@ def _verify_tokens( } accepted_token_ids = self.spec_decode_sampler( - target_probs=proposal_verifier_probs, + target_with_bonus_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_probs=proposal_probs, draft_token_ids=proposal_token_ids, **sampler_extra_kwargs, ) - # Append output tokens from non-speculative sequences to # the accepted token ids tensor. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index aa993e539b6d..f6a52a516075 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -2,8 +2,8 @@ import torch -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index d18ee47e23a5..54e718bc4901 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -4,9 +4,9 @@ import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SamplerOutput, SequenceGroupMetadata, - SequenceOutput) + SequenceGroupMetadata, SequenceOutput) SeqId = int @@ -43,8 +43,8 @@ def get_sampled_token_logprobs( sampled_token_ids, ] expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( -1, -1, vocab_size) - sampled_token_ids_ranks = (logprob_tensor >= - expanded_selected_logprobs).sum(-1) + sampled_token_ids_ranks = (logprob_tensor > + expanded_selected_logprobs).sum(-1).add_(1) return sampled_token_ids_ranks, selected_logprobs diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index c2276b075c1d..dfe83ddb731d 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -11,11 +11,12 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - EAGLEConfig, InternVLChatConfig, - JAISConfig, MedusaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, RWConfig, - UltravoxConfig) + EAGLEConfig, ExaoneConfig, + InternVLChatConfig, JAISConfig, + MedusaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + RWConfig, UltravoxConfig) +from vllm.transformers_utils.utils import check_gguf_file if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -34,6 +35,7 @@ "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, "eagle": EAGLEConfig, + "exaone": ExaoneConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "ultravox": UltravoxConfig, @@ -55,7 +57,7 @@ def get_config( ) -> PretrainedConfig: # Separate model folder from file path for GGUF models - is_gguf = Path(model).is_file() and Path(model).suffix == ".gguf" + is_gguf = check_gguf_file(model) if is_gguf: kwargs["gguf_file"] = Path(model).name model = Path(model).parent @@ -107,8 +109,11 @@ def get_hf_image_processor_config( revision: Optional[str] = None, **kwargs, ) -> Dict[str, Any]: + # ModelScope does not provide an interface for image_processor + if VLLM_USE_MODELSCOPE: + return dict() # Separate model folder from file path for GGUF models - if Path(model).is_file() and Path(model).suffix == ".gguf": + if check_gguf_file(model): model = Path(model).parent return get_image_processor_config(model, revision=revision, **kwargs) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index dc2fd6a859e3..736878b35ad4 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,6 +1,7 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig +from vllm.transformers_utils.configs.exaone import ExaoneConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. @@ -22,6 +23,7 @@ "JAISConfig", "MedusaConfig", "EAGLEConfig", + "ExaoneConfig", "MLPSpeculatorConfig", "NemotronConfig", "UltravoxConfig", diff --git a/vllm/transformers_utils/configs/exaone.py b/vllm/transformers_utils/configs/exaone.py new file mode 100644 index 000000000000..805b8ad93003 --- /dev/null +++ b/vllm/transformers_utils/configs/exaone.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copied from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/configuration_exaone.py +# Copyright 2021 The LG AI Research EXAONE Lab. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Exaone model configuration""" + +from typing import Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {} + + +class ExaoneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class: + `~transformers.ExaoneModel`. It is used to instantiate a GPT Lingvo model + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Exaone + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` + and can be used to control the model outputs. Read the documentation from : + class:`~transformers.PretrainedConfig` for more information. + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50257): + Vocabulary size of the GPT Lingvo model. Defines the number of + different tokens that can be represented by the :obj:`inputs_ids` + passed when calling :class:`~transformers.ExaoneModel`. Vocabulary + size of the model. + Defines the different tokens that can be represented by the + `inputs_ids` passed to the forward method of :class: + `~transformers.EXAONEModel`. + hidden_size (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (:obj:`int`, `optional`, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi + Head Attention (MHA), if `num_key_value_heads=1 the model will use + Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed by meanpooling + all the original heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to `num_attention_heads`. + rotary_pct (`float`, *optional*, defaults to 0.25): + percentage of hidden dimensions to allocate to rotary embeddings + intermediate_size (:obj:`int`, `optional`, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in + the Transformer encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, + defaults to :obj:`"gelu_new"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, :obj:`"gelu"`, :obj:`"relu"`, + :obj:`"selu"` and :obj:`"gelu_new"` are supported. + embed_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the + embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling + :class:`~transformers.EXAONEModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + Only relevant if ``config.is_decoder=True``. + gradient_checkpointing (:obj:`bool`, `optional`, + defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense + of slower backward pass. + Example:: + + >>> from transformers import ExoneModel, ExaoneConfig + + >>> # Initializing a EXAONE configuration + >>> configuration = ExaoneConfig() + + >>> # Initializing a model from configuration + >>> model = ExoneModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + + model_type = "exaone" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=102400, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + intermediate_size=None, + activation_function="silu", + rotary_pct=0.25, + resid_dropout=0.0, + embed_dropout=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, + initializer_range=0.02, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=True, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_layers + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + if intermediate_size: + self.intermediate_size = intermediate_size + else: + self.intermediate_size = hidden_size * 4 + self.activation_function = activation_function + self.resid_dropout = resid_dropout + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rotary_pct = rotary_pct + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.use_logit_cap = kwargs.pop("use_logit_cap", False) + self.ln_no_scale = kwargs.pop("ln_no_scale", False) + self.use_gated = kwargs.pop("use_gated", False) + self.use_emb_norm = kwargs.pop("use_emb_norm", False) + self.use_rotary_pos = kwargs.pop("use_rotary_pos", False) + self.rotary_type = kwargs.pop("rotary_type", None) + self.scaling_factor = kwargs.pop("scaling_factor", 1) + self.use_absolute_pos = kwargs.pop("use_absolute_pos", True) + self.use_extra_logit = kwargs.pop("use_extra_logit", True) + self.rotary_expand_length = kwargs.pop("rotary_expand_length", None) + self.rotary_base = kwargs.pop("rotary_base", 10000.0) + self.use_qkv_fuse = kwargs.pop("use_qkv_fuse", False) + self.rescale_before_lm_head = kwargs.pop("rescale_before_lm_head", + (rotary_pct == 0.25)) + if self.use_rotary_pos: + self.use_absolute_pos = False diff --git a/vllm/transformers_utils/configs/granite.py b/vllm/transformers_utils/configs/granite.py new file mode 100644 index 000000000000..c12838be5d38 --- /dev/null +++ b/vllm/transformers_utils/configs/granite.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Granite model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class GraniteConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of + a [`GraniteModel`]. It is used to instantiate an Granite + model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Granite-3B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Granite model. Defines the number of + different tokens that can be represented by the `inputs_ids` + passed when calling [`GraniteModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi + Head Attention (MHA), if `num_key_value_heads=1` the model will use + Multi Query Attention (MQA) otherwise GQA is used. When converting + a multi-head checkpoint to a GQA checkpoint, each group key and + value head should be constructed by meanpooling all the original + heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the + decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only relevant if + `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE + embeddings. Currently supports two scaling strategies: linear and + dynamic. Their scaling factor must be a float greater than 1. The + expected format is + `{"type": strategy name, "factor": scaling factor}`. + When using this flag, don't update `max_position_embeddings` to + the expected new maximum. See the following thread for more + information on how these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. + This is an experimental feature, subject to breaking API changes + in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output + projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers + in the MLP layers. + embedding_multiplier (`float`, *optional*, defaults to 1.0): + embedding multiplier + logits_scaling (`float`, *optional*, defaults to 1.0): + divisor for output logits + residual_multiplier (`float`, *optional*, defaults to 1.0): + residual multiplier + attention_multiplier (`float`, *optional*, defaults to 1.0): + attention multiplier + + ```python + >>> from transformers import GraniteModel, GraniteConfig + + >>> # Initializing a Granite granite-3b style configuration + >>> configuration = GraniteConfig() + + >>> # Initializing a model from the granite-7b style configuration + >>> model = GraniteModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "granite" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + embedding_multiplier=1.0, + logits_scaling=1.0, + residual_multiplier=1.0, + attention_multiplier=1.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + self.embedding_multiplier = embedding_multiplier + self.logits_scaling = logits_scaling + self.residual_multiplier = residual_multiplier + self.attention_multiplier = attention_multiplier + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + rope_config_validation(self) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2866975850db..f9fb8d1e103b 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -12,6 +12,7 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizers import (BaichuanTokenizer, MistralTokenizer) +from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import make_async logger = init_logger(__name__) @@ -96,8 +97,7 @@ def get_tokenizer( kwargs["truncation_side"] = "left" # Separate model folder from file path for GGUF models - is_gguf = Path(tokenizer_name).is_file() and Path( - tokenizer_name).suffix == ".gguf" + is_gguf = check_gguf_file(tokenizer_name) if is_gguf: kwargs["gguf_file"] = Path(tokenizer_name).name tokenizer_name = Path(tokenizer_name).parent diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py new file mode 100644 index 000000000000..7a9041b04fbb --- /dev/null +++ b/vllm/transformers_utils/utils.py @@ -0,0 +1,16 @@ +from os import PathLike +from pathlib import Path +from typing import Union + + +def check_gguf_file(model: Union[str, PathLike]) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + with open(model, "rb") as f: + header = f.read(4) + return header == b"GGUF" diff --git a/vllm/utils.py b/vllm/utils.py index dab8e5fe0435..657a3ecef696 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -25,6 +25,7 @@ import psutil import torch import torch.types +import yaml from packaging.version import Version from typing_extensions import ParamSpec, TypeIs, assert_never @@ -1093,6 +1094,9 @@ def parse_args(self, args=None, namespace=None): if args is None: args = sys.argv[1:] + if '--config' in args: + args = FlexibleArgumentParser._pull_args_from_config(args) + # Convert underscores to dashes and vice versa in argument names processed_args = [] for arg in args: @@ -1109,6 +1113,103 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) + @staticmethod + def _pull_args_from_config(args: List[str]) -> List[str]: + """Method to pull arguments specified in the config file + into the command-line args variable. + + The arguments in config file will be inserted between + the argument list. + + example: + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + ```python + $: vllm {serve,chat,complete} "facebook/opt-12B" \ + --config config.yaml -tp 2 + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--config', 'config.yaml', + '-tp', '2' + ] + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', + '-tp', '2' + ] + ``` + + Please note how the config args are inserted after the sub command. + this way the order of priorities is maintained when these are args + parsed by super(). + """ + assert args.count( + '--config') <= 1, "More than one config file specified!" + + index = args.index('--config') + if index == len(args) - 1: + raise ValueError("No config file specified! \ + Please check your command-line arguments.") + + file_path = args[index + 1] + + config_args = FlexibleArgumentParser._load_config_file(file_path) + + # 0th index is for {serve,chat,complete} + # followed by config args + # followed by rest of cli args. + # maintaining this order will enforce the precedence + # of cli > config > defaults + args = [args[0]] + config_args + args[1:index] + args[index + 2:] + + return args + + @staticmethod + def _load_config_file(file_path: str) -> List[str]: + """Loads a yaml file and returns the key value pairs as a + flattened list with argparse like pattern + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + returns: + processed_args: list[str] = [ + '--port': '12323', + '--tensor-parallel-size': '4' + ] + + """ + + extension: str = file_path.split('.')[-1] + if extension not in ('yaml', 'yml'): + raise ValueError( + "Config file must be of a yaml/yml type.\ + %s supplied", extension) + + # only expecting a flat dictionary of atomic types + processed_args: List[str] = [] + + config: Dict[str, Union[int, str]] = {} + try: + with open(file_path, 'r') as config_file: + config = yaml.safe_load(config_file) + except Exception as ex: + logger.error( + "Unable to read the config file at %s. \ + Make sure path is correct", file_path) + raise ex + + for key, value in config.items(): + processed_args.append('--' + key) + processed_args.append(str(value)) + + return processed_args + async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f69afa4c4314..7205b1a7beb8 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -10,11 +10,11 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 5c700229660c..d6189d82d51d 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -16,9 +16,10 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2b287a5d2715..8a3c99a45b14 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -29,6 +29,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import (supports_lora, @@ -41,8 +42,7 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, supports_dynamo) @@ -60,10 +60,14 @@ LORA_WARMUP_RANK = 8 _BATCH_SIZE_ALIGNMENT = 8 -# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +# all the token sizes that **can** be captured by cudagraph. +# they can be arbitrarily large. +# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. +# the actual sizes to capture will be determined by the model, +# depending on the model's max_num_seqs. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) ] _NUM_WARMUP_ITERS = 2 @@ -92,6 +96,7 @@ class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[Callable] = None + use_async_and_multi_step: bool = False def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -659,7 +664,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): def _use_captured_graph(self, batch_size: int, max_decode_seq_len: int) -> bool: return (self.decode_only and not self.runner.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and batch_size <= self.runner.max_batchsize_to_capture and max_decode_seq_len <= self.runner.max_seq_len_to_capture) def build(self) -> ModelInputForGPU: @@ -845,6 +850,8 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture + self.max_batchsize_to_capture = _get_max_graph_batch_size( + self.scheduler_config.max_num_seqs) self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) @@ -862,7 +869,7 @@ def __init__( # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) num_attn_heads = self.model_config.get_num_attention_heads( self.parallel_config) @@ -1123,10 +1130,6 @@ def profile_run(self) -> None: device=self.device) self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() - - # reset and discard the guard and compiled bytecode for profiling runs - torch._dynamo.reset() - return def remove_all_loras(self): @@ -1221,7 +1224,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: start_time = time.perf_counter() # Prepare dummy inputs. These will be reused for all batch sizes. - max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + max_batch_size = self.max_batchsize_to_capture input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() @@ -1249,8 +1252,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: None ] * self.parallel_config.pipeline_parallel_size - graph_batch_size = _get_graph_batch_size( - self.scheduler_config.max_num_seqs) + graph_batch_size = self.max_batchsize_to_capture batch_size_capture_list = [ bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] @@ -1676,3 +1678,22 @@ def _get_graph_batch_size(batch_size: int) -> int: else: return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + +def _get_max_graph_batch_size(max_num_seqs: int) -> int: + """ + max_num_seqs: Maximum number of sequences in a batch. + _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. + + pad the max_num_seqs if necessary by calling _get_graph_batch_size, + which will deal with some edge cases like 1, 2, 4. + + if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size. + if not, it means the padded size is larger than the largest size in + _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE. + """ + padded_size = _get_graph_batch_size(max_num_seqs) + if padded_size in _BATCH_SIZES_TO_CAPTURE: + return padded_size + assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] + return _BATCH_SIZES_TO_CAPTURE[-1] diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 90c39407d726..f8fd9d801d28 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -5,9 +5,9 @@ import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata if TYPE_CHECKING: from vllm.attention import AttentionMetadata diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 521205eca05a..be0c75bc00db 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,5 +1,8 @@ +import dataclasses +import functools from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -13,9 +16,12 @@ from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, + SamplerOutput, + SamplingMetadata, get_logprobs, + get_pythonized_sample_results) from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceOutput) + Logprob, SequenceGroupMetadata, SequenceOutput) from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -51,6 +57,8 @@ class ModelOutput: sampler_output_ready_event: torch.cuda.Event sampled_token_ids: Optional[torch.Tensor] = None pythonized: bool = False + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None def pythonize(self, input_metadata: "StatefulModelInput", copy_stream: torch.cuda.Stream, @@ -76,7 +84,9 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", blocking: bool) -> bool: """ If blocking is set, will block until the forward pass for the output is - ready and pythonize the output. + ready and pythonize the output. Upon completing Pythonization, erases + self.logprobs (note that a non-blocking call that is performed when + the sampler output is not yet ready, will not erase self.logprobs.) """ assert self.sampled_token_ids is not None if not blocking and not self.sampler_output_ready_event.query(): @@ -87,7 +97,15 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", with torch.cuda.stream(copy_stream): _pythonize_sampler_output(input_metadata, self.sampler_output, pinned_sampled_token_buffer, - self.sampled_token_ids) + self.sampled_token_ids, self.logprobs) + + # Erase the logprobs GPU-side tensor. + # Note that although _pythonize_sampler_output() runs in its + # own CUDA stream, nonetheless _pythonize_sampler_output() + # cannot return until Pythonization is complete; therefore + # we know that by the time the CPU reaches this point, + # `self.logprobs` is no longer needed. + self.logprobs = None return True @@ -215,6 +233,46 @@ def prepare_model_input( ) return model_input + def _async_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Callable): + # Proceed with pythonization and output_proc in order. + # Stop on the first one that fails to pythonize + cont = True + for model_output in model_input.cached_outputs: + if not model_output.pythonized: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + output_proc_callback( + sampler_output=model_output.sampler_output) + else: + cont = False + + if not cont: + break + + def _final_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Optional[Callable]): + assert model_input.frozen_model_input is not None + + outputs = [] + for output_id in range(len(model_input.cached_outputs)): + is_last_output = output_id == len(model_input.cached_outputs) - 1 + + output = model_input.cached_outputs[output_id] + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + if model_input.frozen_model_input.use_async_and_multi_step: + assert output_proc_callback is not None + output_proc_callback(sampler_output=output.sampler_output, + is_last_output=is_last_output) + + outputs.append(output.sampler_output) + + return outputs + @torch.inference_mode() def execute_model( self, @@ -271,6 +329,20 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) + output_proc_callback = None + if frozen_model_input.use_async_and_multi_step: + output_proc_callback = frozen_model_input.async_callback + assert output_proc_callback is not None + async_callback = functools.partial( + self._async_process_outputs, + model_input=model_input, + output_proc_callback=output_proc_callback) + + frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=async_callback) + assert frozen_model_input is not None + # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, kv_caches, @@ -294,16 +366,23 @@ def execute_model( 0].sampled_token_ids.cpu() model_input.cached_outputs.append( ModelOutput(output[0], output_ready_event, - output[0].sampled_token_ids, False)) - # make sure we dont try to serialize any GPU tensors + output[0].sampled_token_ids, False, + output[0].logprobs)) + + # These GPU tensors are not required by multi-step; + # erase them to ensure they are not pythonized or + # transferred to CPU output[0].sampled_token_ids = None output[0].sampled_token_probs = None output[0].logprobs = None + # Pythonize the output if CPU is ahead and the previous step is # ready. - for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + if not frozen_model_input.use_async_and_multi_step: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) model_input.current_step += 1 @@ -316,11 +395,8 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = [] - for output in model_input.cached_outputs: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - outputs.append(output.sampler_output) + outputs = self._final_process_outputs(model_input, + output_proc_callback) return outputs # should be [SamplerOutput] @@ -409,12 +485,75 @@ def vocab_size(self) -> int: return self._base_model_runner.vocab_size -def _pythonize_sampler_output(model_input: StatefulModelInput, - output: SamplerOutput, - pinned_sampled_token_buffer: torch.Tensor, - sampled_token_ids: torch.Tensor) -> None: +DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]], + Optional[List[SampleLogprobs]]] + + +def deferred_pythonize_logprobs( + output: SamplerOutput, + sampling_metadata: SamplingMetadata, + logprobs_tensor: Optional[torch.Tensor], +) -> DeferredLogprobsReturnType: + """Perform deferred logprob Pythonization. + + 1. Pythonize GPU-side sampler result tensors into CPU-side sampler result. + 2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists, + utilizing the Pythonized sampler result computed in step 1. + + These deferred computations are not required for single-step scheduling + or the `profile_run()` phase of multi-step scheduling. + + Args: + output: sampler output (under deferred Pythonization) + sampling_metadata + + Returns: + prompt_logprobs (CPU), sample_logprobs (CPU) + """ + + # - Deferred pythonization of sample result + sampler_result = get_pythonized_sample_results( + output.deferred_sample_results_args) + + # - Erase the GPU-side deferred sample_result + # computation args to ensure it is never + # pythonized or transferred to CPU + output.deferred_sample_results_args = None + + # - Deferred pythonization of logprobs + ( + prompt_logprobs, + sample_logprobs, + ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result) + assert len(prompt_logprobs) == len(sampling_metadata.seq_groups) + assert len(sample_logprobs) == len(sampling_metadata.seq_groups) + + return prompt_logprobs, sample_logprobs + + +def _pythonize_sampler_output( + model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor, + logprobs_tensor: Optional[torch.Tensor], +) -> None: """ This function is only called when the output tensors are ready. - See ModelOutput + See :class:`ModelOutput`. + + Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, + adding a Pythonized output data structure + (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`. + + Args: + model_input + output: sampler output + pinned_sampled_token_token_buffer: CPU-side pinned memory + (receives copy of + GPU-side token buffer.) + sampled_token_ids: GPU-side token buffer + logprobs_tensor: GPU-side tensor containing + logprobs computed during sampling """ assert model_input.frozen_model_input is not None @@ -434,8 +573,51 @@ def _pythonize_sampler_output(model_input: StatefulModelInput, sampling_metadata = frozen_model_input.sampling_metadata - for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, - samples_list): + skip_sampler_cpu_output = ( + frozen_model_input.sampling_metadata.skip_sampler_cpu_output) + + # We are guaranteed output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However this computation may be skipped entirely + # if no pythonization was deferred. + seq_groups = sampling_metadata.seq_groups + logprobs_are_requested = any([ + sg.sampling_params.logprobs is not None + or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups + ]) + do_pythonize_logprobs = (skip_sampler_cpu_output + and logprobs_are_requested) + ( + prompt_logprobs, + sample_logprobs, + ) = (deferred_pythonize_logprobs(output, sampling_metadata, + logprobs_tensor) + if do_pythonize_logprobs else (None, None)) + + for sgdx, (seq_group, + sample_result) in enumerate(zip(seq_groups, samples_list)): + + if do_pythonize_logprobs: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( # Utilize deferred pythonization results + prompt_logprobs[sgdx], + sample_logprobs[sgdx], + ) + elif logprobs_are_requested: + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( + # profile_run: use already-computed logprobs + output.outputs[sgdx].prompt_logprobs, + [sample.logprobs for sample in output.outputs[sgdx].samples]) + seq_ids = seq_group.seq_ids next_token_ids = sample_result parent_ids = [0] @@ -443,11 +625,19 @@ def _pythonize_sampler_output(model_input: StatefulModelInput, if seq_group.sampling_params.logits_processors: assert len(seq_group.sampling_params.logits_processors) == 0, ( "Logits Processors are not supported in multi-step decoding") - for parent_id, next_token_id in zip(parent_ids, next_token_ids): - # TODO(will): support logprobs - # Hard coded logprob + for tdx, (parent_id, + next_token_id) in enumerate(zip(parent_ids, next_token_ids)): seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, - {next_token_id: Logprob(logprob=-1)})) - output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) + (group_sample_logprobs[tdx] + if logprobs_are_requested else { + next_token_id: + Logprob(logprob=float('inf'), + rank=None, + decoded_token=None) + }))) + output.outputs.append( + CompletionSequenceGroupOutput( + seq_outputs, + (group_prompt_logprobs if logprobs_are_requested else None))) assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 2ed77dd698f5..517b0ab78c46 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -1,10 +1,12 @@ +import dataclasses from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch from vllm.distributed import broadcast_tensor_dict, get_pp_group -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.worker.model_runner_base import BroadcastableModelInput from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, StatefulModelInput) @@ -61,6 +63,13 @@ def _get_driver_input_and_broadcast( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=execute_model_req.async_callback, + use_async_and_multi_step=execute_model_req. + use_async_and_multi_step) else: # on subsequent steps we reuse the worker input and model input multi_step_state = self.multi_step_states[virtual_engine] diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 4f3fed2dbd72..f3defffdfa52 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -8,11 +8,11 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index a1d09a2f9e53..f335e4e32efd 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -11,10 +11,11 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.openvino import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import SequenceGroupMetadata logger = init_logger(__name__) diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index c47f9acc4423..36339e175d7b 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -14,7 +14,8 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.worker.openvino_model_runner import OpenVINOModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 01daa64b5a32..a0498315516b 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,6 +1,7 @@ import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, Union) from unittest.mock import patch import numpy as np @@ -10,14 +11,15 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceOutput) + Logprob, SequenceGroupMetadata, SequenceOutput) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -50,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase): best_of: List[int] seq_groups: List[List[int]] virtual_engine: int = 0 + async_callback: Optional[Callable] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -144,11 +147,7 @@ def load_model(self) -> None: ) model = model.eval() xm.wait_device_ops() - model = ModelWrapper(model) - self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) + self.model = ModelWrapper(model) def _dummy_run( self, @@ -235,8 +234,15 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) + self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + num_samples, + kv_caches, + is_prompt=is_prompt) def warmup_model( self, @@ -530,7 +536,7 @@ def _execute_model(*args): if getattr(arg, "context_lens", None) is not None: arg.context_lens = arg.context_lens.to(self.device) new_args.append(arg) - return self.model(*new_args) + return self.model(*new_args, is_prompt=is_prompt) num_prefills = model_input.attn_metadata.num_prefills is_prompt = num_prefills > 0 @@ -558,6 +564,8 @@ def _execute_model(*args): model_input.attn_metadata, model_input.input_lens[i:i + 1], model_input.t[i:i + 1], model_input.p[i:i + 1], model_input.num_samples, kv_caches) + if i == 0 and model_input.async_callback is not None: + model_input.async_callback() # Retrieve the outputs to CPU. next_token_ids += output_token_ids.cpu().tolist() start_idx = end_idx @@ -568,6 +576,8 @@ def _execute_model(*args): model_input.attn_metadata, model_input.input_lens, model_input.t, model_input.p, model_input.num_samples, kv_caches) + if model_input.async_callback is not None: + model_input.async_callback() # Retrieve the outputs to CPU. next_token_ids = output_token_ids.cpu().tolist() @@ -601,11 +611,32 @@ def _execute_model(*args): return [SamplerOutput(sampler_outputs)] -class ModelWrapper(nn.Module): +class ModelWrapper(TorchCompileWrapperWithCustomDispacther): def __init__(self, model: nn.Module): - super().__init__() self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) def forward( self, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 320b15d3604b..44fa3aed5816 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -143,10 +143,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_cpu_blocks = int(self.cache_config.swap_space_bytes // block_size_bytes) num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. - - # reset and discard the guard and compiled bytecode for profiling runs - torch._dynamo.reset() - return num_tpu_blocks, num_cpu_blocks def initialize_cache( diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index 79c48896469e..d73023e8e172 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - if enc_dec_mr.model_config.multimodal_config is not None: + if enc_dec_mr.model_config.is_multimodal_model: raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7ed609c3b447..0ff559a9af53 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -17,12 +17,12 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SamplerOutput, SequenceGroupMetadata, - SequenceGroupMetadataDelta) + SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 012043673b09..6ba4f272315c 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -11,9 +11,9 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SamplerOutput) +from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 3894658a095f..f9037625d4af 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -15,12 +15,12 @@ from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( From 0d6ac3a6f2cea0af0707519ce431ccc8ddd265fa Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Mon, 2 Sep 2024 21:36:26 +0530 Subject: [PATCH 14/32] Updating Branch --- vllm/transformers_utils/config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f3ac8d3178d4..dfe83ddb731d 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -16,6 +16,7 @@ MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, RWConfig, UltravoxConfig) +from vllm.transformers_utils.utils import check_gguf_file if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -56,7 +57,7 @@ def get_config( ) -> PretrainedConfig: # Separate model folder from file path for GGUF models - is_gguf = Path(model).is_file() and Path(model).suffix == ".gguf" + is_gguf = check_gguf_file(model) if is_gguf: kwargs["gguf_file"] = Path(model).name model = Path(model).parent @@ -112,7 +113,7 @@ def get_hf_image_processor_config( if VLLM_USE_MODELSCOPE: return dict() # Separate model folder from file path for GGUF models - if Path(model).is_file() and Path(model).suffix == ".gguf": + if check_gguf_file(model): model = Path(model).parent return get_image_processor_config(model, revision=revision, **kwargs) From 831d447fd159bd4a7280a36c552dfd0c76def637 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 3 Sep 2024 12:45:40 +0530 Subject: [PATCH 15/32] Refactor --- vllm/config.py | 20 +++++++++---- vllm/model_executor/models/internlm2.py | 39 ++++++++++++++++--------- vllm/model_executor/models/internvl.py | 2 +- vllm/model_executor/models/utils.py | 19 ++---------- 4 files changed, 43 insertions(+), 37 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 66889bf0afb9..d14cc78307f6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -35,11 +35,21 @@ _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _PP_SUPPORTED_MODELS = [ - "AquilaModel", "AquilaForCausalLM", "DeepseekV2ForCausalLM", - "InternLMForCausalLM", "JAISLMHeadModel", "LlamaForCausalLM", - "LLaMAForCausalLM", "MistralForCausalLM", "Phi3ForCausalLM", - "GPT2LMHeadModel", "MixtralForCausalLM", "NemotronForCausalLM", - "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", + "AquilaModel", + "AquilaForCausalLM", + "DeepseekV2ForCausalLM", + "InternLMForCausalLM", + "JAISLMHeadModel", + "LlamaForCausalLM", + "LLaMAForCausalLM", + "MistralForCausalLM", + "Phi3ForCausalLM", + "GPT2LMHeadModel", + "MixtralForCausalLM", + "NemotronForCausalLM", + "Qwen2ForCausalLM", + "Qwen2MoeForCausalLM", + "QWenLMHeadModel", "InternVLChatModel" ] diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index f191fbf3d0c2..70a7f235061a 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -8,7 +8,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank,get_pp_group, +from vllm.distributed import (get_tensor_model_parallel_rank, get_pp_group, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather) @@ -27,9 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors - -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) +from .utils import is_pp_missing_parameter, make_layers class InternLM2MLP(nn.Module): @@ -237,6 +235,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -246,14 +245,12 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: InternLMDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) @@ -266,7 +263,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: IntermediateTensors = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -277,7 +274,7 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(len(self.layers)): + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, @@ -338,6 +335,20 @@ def compute_logits( sampling_metadata) return logits + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def sample( self, logits: torch.Tensor, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 0e4569428aa7..d317fdce3ba6 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -446,7 +446,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> Union[SamplerOutput, IntermediateTensors]: + ) -> SamplerOutput: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: inputs_embeds = self.language_model.model.get_input_embeddings( diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index af782a5b3d88..90187a95eb90 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,4 @@ -from typing import (Callable, Dict, Iterable, List, Literal, Optional, Protocol, Tuple, +from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, Union, overload) import torch @@ -278,19 +278,4 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: for missing_layer_name in get_pp_missing_layer_names(model): if name.startswith(missing_layer_name): return True - return False - - -def make_empty_intermediate_tensors_factory(keys: List[str], - hidden_size: int) -> Callable: - - def make_empty_intermediate_tensors(batch_size: int, dtype: torch.dtype, - device: torch.device) -> NestedTensors: - return NestedTensors({ - key: torch.zeros((batch_size, hidden_size), - dtype=dtype, - device=device) - for key in keys - }) - - return make_empty_intermediate_tensors \ No newline at end of file + return False \ No newline at end of file From a3e9a989797eeac3477c512872c046fa9d5d4ca4 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 3 Sep 2024 13:01:33 +0530 Subject: [PATCH 16/32] Refactor --- vllm/model_executor/models/internlm2.py | 2 -- vllm/model_executor/models/internvl.py | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 70a7f235061a..c2f84788dfae 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -311,8 +311,6 @@ def __init__( self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) def forward( self, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index d317fdce3ba6..5ca8d0b6a292 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -341,8 +341,6 @@ def __init__(self, nn.Linear(llm_hidden_size, llm_hidden_size)) self.img_context_token_id = None - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() @@ -463,7 +461,7 @@ def forward( positions, kv_caches, attn_metadata, - intermediate_tensors, + None, inputs_embeds=inputs_embeds) return hidden_states From e218d07bce6207f0e4803d9b05d11c7151c313fe Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 3 Sep 2024 13:16:21 +0530 Subject: [PATCH 17/32] Refactor --- vllm/model_executor/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 90187a95eb90..16565e1467e8 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -278,4 +278,4 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: for missing_layer_name in get_pp_missing_layer_names(model): if name.startswith(missing_layer_name): return True - return False \ No newline at end of file + return False From ca6e920d2ad819acad5549354702e117f562417d Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 3 Sep 2024 22:15:30 +0530 Subject: [PATCH 18/32] test case completion --- vllm/model_executor/models/internlm2.py | 21 ++++++--------------- vllm/model_executor/models/internvl.py | 4 +++- vllm/model_executor/models/utils.py | 18 +++++++++++++++++- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index c2f84788dfae..7322f8dbace1 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import is_pp_missing_parameter, make_layers +from .utils import is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory class InternLM2MLP(nn.Module): @@ -251,6 +251,9 @@ def __init__( quant_config), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) @@ -311,6 +314,8 @@ def __init__( self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -333,20 +338,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: torch.Tensor, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 5ca8d0b6a292..d317fdce3ba6 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -341,6 +341,8 @@ def __init__(self, nn.Linear(llm_hidden_size, llm_hidden_size)) self.img_context_token_id = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() @@ -461,7 +463,7 @@ def forward( positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 16565e1467e8..1e70b1b86918 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,5 @@ from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, - Union, overload) + Union, overload, Callable) import torch import torch.nn as nn @@ -13,6 +13,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.multimodal.base import NestedTensors from vllm.utils import is_pin_memory_available +from vllm.sequence import IntermediateTensors def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): @@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: if name.startswith(missing_layer_name): return True return False + +def make_empty_intermediate_tensors_factory(keys: List[str], + hidden_size: int) -> Callable: + + def make_empty_intermediate_tensors( + batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + key: torch.zeros((batch_size, hidden_size), + dtype=dtype, + device=device) + for key in keys + }) + + return make_empty_intermediate_tensors From 7591225c24cec1a8fc2da801227cc90c6067e36a Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 3 Sep 2024 22:17:24 +0530 Subject: [PATCH 19/32] test case completion --- vllm/model_executor/models/internlm2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 7322f8dbace1..1ec90e50fbf1 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory +from .utils import (is_pp_missing_parameter, + make_layers, make_empty_intermediate_tensors_factory) class InternLM2MLP(nn.Module): From 85504b182da6daf4798a383c1d86a3313c78affd Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 3 Sep 2024 22:19:25 +0530 Subject: [PATCH 20/32] fixing imports --- vllm/model_executor/models/internlm2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 1ec90e50fbf1..11a8431a5e7f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -8,7 +8,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, get_pp_group, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather) @@ -27,8 +27,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors + from .utils import (is_pp_missing_parameter, - make_layers, make_empty_intermediate_tensors_factory) + make_empty_intermediate_tensors_factory, make_layers) class InternLM2MLP(nn.Module): From d14d6b6d1fc500a78baaa4aa1cb1d976b840ecf4 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 3 Sep 2024 22:20:54 +0530 Subject: [PATCH 21/32] fixing imports in utils.py --- vllm/model_executor/models/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 1e70b1b86918..4e69d23d0b27 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,5 @@ -from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, - Union, overload, Callable) +from typing import (Callable, Dict, Iterable, List, Literal, Optional, + Protocol, Tuple, Union, overload) import torch import torch.nn as nn @@ -12,8 +12,8 @@ from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry from vllm.multimodal.base import NestedTensors -from vllm.utils import is_pin_memory_available from vllm.sequence import IntermediateTensors +from vllm.utils import is_pin_memory_available def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): From 41f83dd8ecd397a96ba9363b7023c8df6830e930 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Tue, 3 Sep 2024 23:26:44 +0530 Subject: [PATCH 22/32] optional settings for tokeniser for trust_remote_code --- tests/distributed/test_pipeline_parallel.py | 31 +++++++++++---------- tests/utils.py | 6 +++- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 04ee5945e640..9df1d4d041c1 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -18,23 +18,23 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " +@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE " "MODEL_NAME, DIST_BACKEND"), [ - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 2, 1, 1, "OpenGVLab/InternVL2-8B", "ray"), + (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-8B", "ray"), ]) @fork_new_process_for_each_test -def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, +def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " @@ -72,9 +72,10 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, if EAGER_MODE: pp_args.append("--enforce-eager") tp_args.append("--enforce-eager") + if TRUST_REMOTE_CODE: + pp_args.append("--trust-remote-code") + tp_args.append("--trust-remote-code") pp_env = None - pp_args.append("--trust-remote-code") - tp_args.append("--trust-remote-code") if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2 and CHUNKED_PREFILL): # Test Ray ADAG for a subset of the tests diff --git a/tests/utils.py b/tests/utils.py index f33340f0c755..7c4a0a833fcb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -178,7 +178,11 @@ def compare_two_settings(model: str, env2: The second set of environment variables to pass to the API server. """ - tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + trust_remote_code = "--trust-remote-code" + if trust_remote_code in arg1 or trust_remote_code in arg2: + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + else: + tokenizer = AutoTokenizer.from_pretrained(model) prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] From 003466351df9ac6e2045321e0e065570bc2b61c6 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 10:46:13 +0530 Subject: [PATCH 23/32] InternLM2ForCausalLM and internlm/internlm2_5-7b-chat inclusion --- tests/distributed/test_pipeline_parallel.py | 3 ++- vllm/config.py | 3 ++- vllm/model_executor/models/utils.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 9df1d4d041c1..286f8dc07949 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -18,7 +18,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE " +@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " "MODEL_NAME, DIST_BACKEND"), [ (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), @@ -32,6 +32,7 @@ (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-8B", "ray"), + (1, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray") ]) @fork_new_process_for_each_test def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, MODEL_NAME, diff --git a/vllm/config.py b/vllm/config.py index d14cc78307f6..d55727e6ed70 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -50,7 +50,8 @@ "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", - "InternVLChatModel" + "InternVLChatModel", + "InternLM2ForCausalLM" ] diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 4e69d23d0b27..32edb462ede4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,4 @@ -from typing import (Callable, Dict, Iterable, List, Literal, Optional, +from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, Union, overload) import torch @@ -282,7 +282,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: return False def make_empty_intermediate_tensors_factory(keys: List[str], - hidden_size: int) -> Callable: + hidden_size: int): def make_empty_intermediate_tensors( batch_size: int, dtype: torch.dtype, From 0ebf029c3c108c72ee697de0c0cfd1974ed29efd Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 11:16:52 +0530 Subject: [PATCH 24/32] line formatting --- tests/distributed/test_pipeline_parallel.py | 36 ++++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 286f8dc07949..029464267ee6 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -18,22 +18,26 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " - "MODEL_NAME, DIST_BACKEND"), - [ - (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-8B", "ray"), - (1, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray") - ]) +@pytest.mark.parametrize( + ( + "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " + "MODEL_NAME, DIST_BACKEND" + ), + [ + (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-8B", "ray"), + (1, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"), + ], +) @fork_new_process_for_each_test def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND): From f689444a10a6277f4ed6be9e987e0a0fa8b1b053 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 11:18:36 +0530 Subject: [PATCH 25/32] line formatting --- tests/distributed/test_pipeline_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 029464267ee6..721e23a207b8 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -39,8 +39,8 @@ ], ) @fork_new_process_for_each_test -def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, MODEL_NAME, - DIST_BACKEND): +def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, + TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") From 0cd208c18c8f0d7e308612c348a84c091de6a6bc Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 11:49:10 +0530 Subject: [PATCH 26/32] sorting in utils.py --- vllm/model_executor/models/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 32edb462ede4..22563415cb14 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,5 @@ -from typing import (Dict, Iterable, List, Literal, Optional, - Protocol, Tuple, Union, overload) +from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, + Union, overload) import torch import torch.nn as nn From 9c3ef5f847dcedb9e83347842c3781537866b85d Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 11:59:07 +0530 Subject: [PATCH 27/32] formatting --- tests/distributed/test_pipeline_parallel.py | 6 ++---- tests/utils.py | 3 ++- vllm/model_executor/models/utils.py | 3 +-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 721e23a207b8..248e17667ab4 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -19,10 +19,8 @@ @pytest.mark.parametrize( - ( - "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " - "MODEL_NAME, DIST_BACKEND" - ), + ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " + "MODEL_NAME, DIST_BACKEND"), [ (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), diff --git a/tests/utils.py b/tests/utils.py index 7c4a0a833fcb..04067ef372ac 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -180,7 +180,8 @@ def compare_two_settings(model: str, trust_remote_code = "--trust-remote-code" if trust_remote_code in arg1 or trust_remote_code in arg2: - tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model, + trust_remote_code=True) else: tokenizer = AutoTokenizer.from_pretrained(model) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 22563415cb14..9025ef94ce41 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -281,8 +281,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: return True return False -def make_empty_intermediate_tensors_factory(keys: List[str], - hidden_size: int): +def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( batch_size: int, dtype: torch.dtype, From 7e32e349b01b949a9d17afd2414010d86020fccf Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 12:01:44 +0530 Subject: [PATCH 28/32] formatting --- vllm/model_executor/models/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 9025ef94ce41..8b80dda96db4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -281,6 +281,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: return True return False + def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( From 468c9941694cffa659656e0d5f3f9dc728737370 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 12:17:42 +0530 Subject: [PATCH 29/32] formatting --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index d55727e6ed70..bc4a9ebe409a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -50,8 +50,8 @@ "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", + "InternLM2ForCausalLM", "InternVLChatModel", - "InternLM2ForCausalLM" ] From 92cfe5330e1b82941320a78b81390d0a2f3d441b Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 12:21:22 +0530 Subject: [PATCH 30/32] formatting _PP_SUPPORTED_MODELS list --- vllm/config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index bc4a9ebe409a..e5e88c580919 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -35,23 +35,23 @@ _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _PP_SUPPORTED_MODELS = [ - "AquilaModel", "AquilaForCausalLM", + "AquilaModel", "DeepseekV2ForCausalLM", + "GPT2LMHeadModel", + "InternLM2ForCausalLM", "InternLMForCausalLM", + "InternVLChatModel", "JAISLMHeadModel", "LlamaForCausalLM", "LLaMAForCausalLM", "MistralForCausalLM", - "Phi3ForCausalLM", - "GPT2LMHeadModel", "MixtralForCausalLM", "NemotronForCausalLM", + "Phi3ForCausalLM", "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", - "InternLM2ForCausalLM", - "InternVLChatModel", ] From 9a31c28f5f8513c589885142e8ca0f35c1710b5c Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Wed, 4 Sep 2024 16:46:41 +0530 Subject: [PATCH 31/32] removal of internlm/internlm2_5-7b-chat in test_pipeline_parallel.py --- tests/distributed/test_pipeline_parallel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 248e17667ab4..00a3d555518d 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -32,7 +32,6 @@ (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-8B", "ray"), (1, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"), ], ) From cb3f60273cd7fac79d2c92d29fb11270787c0ad1 Mon Sep 17 00:00:00 2001 From: manikandan-tm Date: Thu, 5 Sep 2024 14:58:47 +0530 Subject: [PATCH 32/32] increasing TP_SIZE in test_pipeline_parallel.py --- tests/distributed/test_pipeline_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 00a3d555518d..637d2b30f6b1 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -32,7 +32,7 @@ (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"), + (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"), ], ) @fork_new_process_for_each_test