From 4b9b955c39f9d5ddeb84bbdc6a60634bac521481 Mon Sep 17 00:00:00 2001 From: Nir David Date: Tue, 18 Jun 2024 11:39:37 +0300 Subject: [PATCH 01/13] support hqt on vllm --- setup.py | 2 +- vllm/attention/backends/habana_attn.py | 95 +++++++++++++++++-- vllm/model_executor/layers/linear.py | 14 ++- .../model_executor/layers/logits_processor.py | 5 +- vllm/utils.py | 15 +++ 5 files changed, 117 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index 964c467fd0a3f..c33557e9f115b 100644 --- a/setup.py +++ b/setup.py @@ -229,7 +229,7 @@ def _is_neuron() -> bool: torch_neuronx_installed = True try: subprocess.run(["neuron-ls"], capture_output=True, check=True) - except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): + except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError): torch_neuronx_installed = False return torch_neuronx_installed or envs.VLLM_BUILD_WITH_NEURON diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 017cf9c8933e5..a77a23593b7fb 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -5,9 +5,11 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type +import os import torch import math import vllm.hpu.xops as xops +import vllm.hpu.utils from vllm.hpu.attn_bias import (AttentionBias, LowerTriangularMaskWithTensorBias) @@ -17,6 +19,7 @@ from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) from vllm.logger import init_logger +from vllm.utils import Matmul, Softmax logger = init_logger(__name__) @@ -111,7 +114,11 @@ def __post_init__(self): self.attn_bias: Optional[List[AttentionBias]] = None -class HabanaAttentionImpl(AttentionImpl): +import habana_frameworks.torch.utils.experimental as htexp +PA_SPLIT_VALUE_DEFAULT = '0' if (htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi3) else '1' +PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', PA_SPLIT_VALUE_DEFAULT) == '1') + +class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -137,8 +144,12 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, ) -> None: + super(AttentionImpl, self).__init__() self.num_heads = num_heads self.head_size = head_size + self.qk_matmul = Matmul() + self.softmax = Softmax() + self.kv_matmul = Matmul() self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window @@ -155,6 +166,75 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. 
" f"Supported head sizes are: {suppored_head_sizes}.") + @vllm.hpu.utils.with_mark_steps + def prompt_attention(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + p: float = 0.0, + scale: Optional[float] = None, + ) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + query_heads = query.size(1) + kv_heads = key.size(1) + if query_heads != kv_heads: + query = query.unflatten(1, (kv_heads, -1)) + key = key.unflatten(1, (kv_heads, 1)) + value = value.unflatten(1, (kv_heads, 1)) + attn_bias = attn_bias.unsqueeze(2) + attn_weights = self.qk_matmul(query * scale, key.transpose(-1, -2)) + if attn_bias is not None: + attn_weights.add_(attn_bias) + attn_weights = self.softmax(attn_weights, dim=-1) + attn_weights = self.kv_matmul(attn_weights, value) + if query_heads != kv_heads: + attn_weights = attn_weights.flatten(1, 2) + attn_weights = attn_weights.transpose(1, 2) + return attn_weights + + def _fetch_from_cache(self, cache, blocks, permutations): + return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))] + + @vllm.hpu.utils.with_mark_steps + def paged_attention_v1(self, query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None: + seq_len = block_tables.size(1) + batch_size, query_heads, _ = query.shape + _, _, kv_heads, _ = key_cache.shape + min_inf = torch.finfo(query.dtype).min + mask = (torch.arange(0, seq_len * block_size, dtype=torch.int32, device=key_cache.device) + .view(1, -1) + .expand(batch_size, -1) + .ge(context_lens.view(-1, 1)) + .view(batch_size, 1, 1, -1)) + query.mul_(scale) + query = query.unsqueeze(-2) + keys = self._fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) + if query_heads != kv_heads: + query = query.unflatten(1, (kv_heads, -1)) + keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] + mask = mask.unsqueeze(2) + + attn_weights = [self.qk_matmul(query, k) for k in keys] + attn_weights = self.softmax(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf), + dim=-1) + + values = self._fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3)) + if PA_SPLIT_VALUE: + attn_weights = attn_weights.split(block_size, dim=-1) + else: + values = [torch.cat(values, dim=-2)] + attn_weights = [attn_weights] + if query_heads != kv_heads: + values = [v.unflatten(1, (kv_heads, 1)) for v in values] + attn_weights = [self.kv_matmul(a, v) for a, v in zip(attn_weights, values)] + if query_heads != kv_heads: + attn_weights = [a.flatten(1, 2) for a in attn_weights] + attn_weights = sum(attn_weights) + return attn_weights.squeeze(-2) + def forward( self, query: torch.Tensor, @@ -201,7 +281,7 @@ def forward( assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!' query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) - out = xops.prompt_attention( + out = self.prompt_attention( query.view(query_shape), key.view(kv_shape), value.view(kv_shape), @@ -227,17 +307,18 @@ def forward( ) if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
- output = HabanaPagedAttention.forward_decode( + block_size = value_cache.shape[1] + output = self.paged_attention_v1( query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + block_size, self.alibi_slopes, - kv_scale + attn_metadata.kv_cache_dtype, ) # Reshape the output tensor. diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 7726dcb9a5fbd..39f845c98d46c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -225,6 +225,7 @@ def __init__( quant_config) self.gather_output = gather_output + self.collective_func = tensor_model_parallel_all_gather # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() @@ -281,7 +282,7 @@ def forward(self, input_): output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) + output = self.collective_func(output_parallel) else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None @@ -624,6 +625,7 @@ def __init__( self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results + self.collective_func = tensor_model_parallel_all_reduce # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() @@ -674,8 +676,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def forward(self, input_): - # Set up backprop all-reduce. + def resolve_input(self, input_): if self.input_is_parallel: input_parallel = input_ else: @@ -683,12 +684,17 @@ def forward(self, input_): splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) input_parallel = splitted_input[tp_rank].contiguous() + return input_parallel + + def forward(self, input_): + # Set up backprop all-reduce. + input_parallel = self.resolve_input(input_) # Matrix multiply. assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_parallel) if self.reduce_results and self.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) + output_ = self.collective_func(output_parallel) else: output_ = output_parallel diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 3951619c6e3ec..502ab3f368d55 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -6,7 +6,7 @@ from vllm.distributed import tensor_model_parallel_gather, tensor_model_parallel_all_gather from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.utils import is_hpu +from vllm.utils import is_hpu, Matmul class LogitsProcessor(nn.Module): """Process logits and apply logits processors from sampling metadata. @@ -33,6 +33,7 @@ def __init__(self, self.logits_as_input = logits_as_input # original vocabulary size (without LoRA). self.org_vocab_size = org_vocab_size or vocab_size + self.matmul = Matmul() def forward( self, @@ -63,7 +64,7 @@ def forward( def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: # Get the logits for the next tokens. 
- logits = torch.matmul(hidden_states, embedding.t()) + logits = self.matmul(hidden_states, embedding.t()) if embedding_bias is not None: logits += embedding_bias # NOTE(kzawora): HPU PT bridge is missing support for single-rank gather. We'll use all-gather on Gaudi for now. diff --git a/vllm/utils.py b/vllm/utils.py index e7a2cde3e0f5d..8e5c220c9374b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -54,6 +54,21 @@ def reset(self) -> None: self.counter = 0 +class Matmul(torch.nn.Module): + def __init__(self): + super(Matmul, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class Softmax(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim = None, inv_head = None): + return torch.softmax(x, dim) + class LRUCache(Generic[T]): def __init__(self, capacity: int): From f3ffc8c44fc3441c2261f50832146f4a027f59cf Mon Sep 17 00:00:00 2001 From: Nir David Date: Tue, 18 Jun 2024 11:39:37 +0300 Subject: [PATCH 02/13] Support HQT on VLLM - KVCache and Mark Step uses --- vllm/attention/backends/habana_attn.py | 16 +++++++--------- vllm/hpu/utils.py | 13 ++++++++++++- vllm/model_executor/models/llama.py | 3 +++ vllm/utils.py | 3 ++- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index a77a23593b7fb..99602ee2b0876 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -8,8 +8,7 @@ import os import torch import math -import vllm.hpu.xops as xops -import vllm.hpu.utils +from vllm.hpu.utils import VLLMKVCache from vllm.hpu.attn_bias import (AttentionBias, LowerTriangularMaskWithTensorBias) @@ -150,6 +149,7 @@ def __init__( self.qk_matmul = Matmul() self.softmax = Softmax() self.kv_matmul = Matmul() + self.key_value_cache = VLLMKVCache() self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window @@ -166,7 +166,6 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - @vllm.hpu.utils.with_mark_steps def prompt_attention(self, query: torch.Tensor, key: torch.Tensor, @@ -198,7 +197,6 @@ def prompt_attention(self, def _fetch_from_cache(self, cache, blocks, permutations): return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))] - @vllm.hpu.utils.with_mark_steps def paged_attention_v1(self, query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None: seq_len = block_tables.size(1) batch_size, query_heads, _ = query.shape @@ -268,11 +266,11 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - HabanaPagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - attn_metadata.prefill_metadata is not None) + key_cache, value_cache = self.key_value_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype, + attn_metadata.prefill_metadata is not None) if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index 8d7f388cf262a..2956fed66c77e 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. ############################################################################### +import torch import habana_frameworks.torch as htorch +from vllm.attention.ops.habana_paged_attn import HabanaPagedAttention def with_mark_steps(fn): def wrapped(*args, **kwargs): @@ -96,4 +98,13 @@ def tfidf_backend(recipes): cm.ax_.set_ylabel("Source recipe number") plt.title(f'Recipe similarity ({backend_name})') return plt -# plt.savefig('similarity.png') \ No newline at end of file +# plt.savefig('similarity.png') + +class VLLMKVCache(torch.nn.Module): + def __init__(self): + super(VLLMKVCache, self).__init__() + + def forward(self, key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, is_prefill_metadata): + HabanaPagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, is_prefill_metadata) + return key_cache, value_cache diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f6d7fc8733fce..b773abaf57411 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -286,6 +286,8 @@ def forward( else: hidden_states = self.get_input_embeddings(input_ids) residual = None + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( @@ -295,6 +297,7 @@ def forward( attn_metadata, residual, ) + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/vllm/utils.py b/vllm/utils.py index 8e5c220c9374b..59389fe348ab3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,7 +31,7 @@ "half": torch.half, "bfloat16": torch.bfloat16, "float": torch.float, - "fp8": torch.uint8, + "fp8": torch.float8_e4m3fn, } @@ -69,6 +69,7 @@ def __init__(self): def forward(self, x, dim = None, inv_head = None): return torch.softmax(x, dim) + class LRUCache(Generic[T]): def __init__(self, capacity: int): From 8ffc3d0138cf283413de87a2cca84005112d64e2 Mon Sep 17 00:00:00 2001 From: Nir David Date: Tue, 18 Jun 2024 11:39:37 +0300 Subject: [PATCH 03/13] HQT on VLLM - prep model and finish measurements and multi cards run --- vllm/distributed/parallel_state.py | 3 + vllm/engine/llm_engine.py | 3 + vllm/entrypoints/llm.py | 3 + vllm/executor/habana_executor.py | 9 ++ vllm/executor/ray_habana_executor.py | 3 + .../layers/quantization/__init__.py | 2 + .../model_executor/layers/quantization/hqt.py | 113 ++++++++++++++++++ vllm/model_executor/model_loader/loader.py | 18 +-- vllm/worker/habana_model_runner.py | 30 ++++- vllm/worker/habana_worker.py | 12 ++ 10 files changed, 187 insertions(+), 9 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/hqt.py diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index be5bb4e857caf..d31e5a01266a9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -84,6 +84,9 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK global _LOCAL_RANK _LOCAL_RANK = local_rank + import os + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["RANK"] = str(rank) def initialize_model_parallel( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0db6cc77342ec..26258f42a61a6 100644 --- a/vllm/engine/llm_engine.py 
+++ b/vllm/engine/llm_engine.py @@ -537,6 +537,9 @@ def _process_model_outputs( request_outputs.append(request_output) return request_outputs + def finish_measurements(self): + self.model_executor.finish_measurements() + def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3ed660e183360..02fae07f9322c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -134,6 +134,9 @@ def set_tokenizer( ) -> None: self.llm_engine.tokenizer.tokenizer = tokenizer + def finish_measurements(self): + self.llm_engine.finish_measurements() + def generate( self, prompts: Optional[Union[str, List[str]]] = None, diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py index cfad194bf9cca..c0982d7864815 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/habana_executor.py @@ -82,6 +82,9 @@ def initialize_cache(self, num_gpu_blocks : int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) logger.info(f"init_cache_engine took {cache_init_m.get_summary_string()}") + def finish_measurements(self): + self.driver_worker.finish_measurements() + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: @@ -128,6 +131,12 @@ def check_health(self) -> None: # it's running. return + def shutdown(self) -> None: + self.driver_worker.shutdown_hqt() + + def __del__(self): + self.shutdown() + class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_habana_executor.py index a17f509f11658..fb269b07bf5ce 100644 --- a/vllm/executor/ray_habana_executor.py +++ b/vllm/executor/ray_habana_executor.py @@ -146,6 +146,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. 
max_parallel_loading_workers) + def finish_measurements(self): + self._run_workers("finish_measurements") + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1c652e347d4ad..4ece0ba7138bb 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -10,6 +10,7 @@ GPTQMarlinConfig) from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.hqt import HQTConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, @@ -19,6 +20,7 @@ "squeezellm": SqueezeLLMConfig, "gptq_marlin": GPTQMarlinConfig, "marlin": MarlinConfig, + "hqt": HQTConfig, } diff --git a/vllm/model_executor/layers/quantization/hqt.py b/vllm/model_executor/layers/quantization/hqt.py new file mode 100644 index 0000000000000..ab3857436582c --- /dev/null +++ b/vllm/model_executor/layers/quantization/hqt.py @@ -0,0 +1,113 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class HQTConfig(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + @classmethod + def get_name(cls) -> str: + return "hqt" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "HQTConfig": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) + + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["HQTLinearMethod"]: + if isinstance(layer, LinearBase): + return HQTLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_min_capability(self) -> int: + # The AWQ kernel only supports Turing or newer GPUs. + return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + +class HQTLinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. 
The weight scaling factor will be initialized after + the model weights are loaded. + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: HQTConfig, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) \ No newline at end of file diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bafa2de62e5df..c9182c0a910f5 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -26,6 +26,7 @@ get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models.llava import LlavaForConditionalGeneration +from vllm.utils import is_hpu _VISION_MODEL_CLASSES = [ LlavaForConditionalGeneration, @@ -40,14 +41,15 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + if not is_hpu(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. 
" + f"Current capability: {capability}.") supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 6a9cb6f066ea1..5a604255829b8 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -281,6 +281,11 @@ def load_model(self) -> None: parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, ) + if self.model_config.quantization == 'hqt': + import habana_quantization_toolkit + habana_quantization_toolkit.prep_model(self.model) + import habana_frameworks.torch.core as htcore + htcore.hpu_initialize(self.model, mark_only_scales_as_const=True) logger.info(f"Pre-loading model weights on {next(self.model.parameters()).device} took {m_getmodel.get_summary_string()}") # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged @@ -810,6 +815,10 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: {'prefill_metadata': prefill_metadata, 'decode_metadata': decode_metadata}) + def finish_measurements(self): + import habana_quantization_toolkit + habana_quantization_toolkit.finish_measurements(self.model.model) + @torch.inference_mode() def execute_model( self, @@ -942,10 +951,16 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): logger.info(f"[Warmup][{phase}][{i+1}/{max_i}] batch_size:{batch_size} seq_len:{seq_len} free_mem:{free_mem}") def warmup_all_buckets(self, buckets, is_prompt, kv_caches): + counter = 0 for i, (batch_size, seq_len) in enumerate(reversed(buckets)): mem_usage = 100.0 * HabanaMemoryProfiler.current_device_memory_usage() / HabanaMemoryProfiler.total_device_memory() self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + try: + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + except: + print(f"Failed on scenario {i+1}: batch_size={batch_size}, seq_len={seq_len}, is_prompt={is_prompt}") + counter += 1 + print(f"Failed warm-up scenarios = {counter}") def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem): total_batch_seq = 0.001 @@ -1008,6 +1023,19 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info(f"Warmup finished in {elapsed_time:.0f} secs, allocated {format_bytes(end_mem - start_mem)} of device memory") self.profiler.end() + def shutdown_hqt(self): + print('hqt shutdown') + if model_config := getattr(self, "model_config", None): + if getattr(model_config, "quantization", None) == 'hqt': + print('hqt shutdown start') + import habana_quantization_toolkit + if habana_quantization_toolkit is not None: + habana_quantization_toolkit.finish_measurements(self.model.model) + print('hqt shutdown') + + def __del__(self): + self.shutdown_hqt() + @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index e253e4479a855..57fc43285b75c 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -101,6 +101,9 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): + if self.model_config.quantization == 'hqt': + import habana_frameworks.torch.core as htcore + htcore.hpu_set_env() self.model_runner.load_model() @torch.inference_mode() @@ -184,6 +187,9 @@ def cache_swap( if blocks_to_copy.numel() > 
0: self.cache_engine.copy(blocks_to_copy) + def finish_measurements(self): + self.model_runner.finish_measurements() + @torch.inference_mode() def execute_model( self, @@ -236,6 +242,12 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: raise NotImplementedError("LoRA is not implemented for HPU backend.") + def shutdown_hqt(self): + self.model_runner.shutdown_hqt() + + def __del__(self): + self.shutdown_hqt() + @property def max_model_len(self) -> int: return self.model_config.max_model_len From f5f0972297ef87705165af24698f606049c9ca07 Mon Sep 17 00:00:00 2001 From: Nir David Date: Tue, 18 Jun 2024 11:39:37 +0300 Subject: [PATCH 04/13] HQT on VLLM - separate kv caches --- vllm/attention/backends/habana_attn.py | 19 ++++++++----------- vllm/hpu/cache_ops.py | 13 +++++++++++++ vllm/hpu/utils.py | 12 +++++++----- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 99602ee2b0876..a4a8f22b9090e 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -8,6 +8,7 @@ import os import torch import math +from vllm.hpu import cache_ops from vllm.hpu.utils import VLLMKVCache from vllm.hpu.attn_bias import (AttentionBias, LowerTriangularMaskWithTensorBias) @@ -149,7 +150,8 @@ def __init__( self.qk_matmul = Matmul() self.softmax = Softmax() self.kv_matmul = Matmul() - self.key_value_cache = VLLMKVCache() + self.key_cache = VLLMKVCache() + self.value_cache = VLLMKVCache() self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window @@ -194,9 +196,6 @@ def prompt_attention(self, attn_weights = attn_weights.transpose(1, 2) return attn_weights - def _fetch_from_cache(self, cache, blocks, permutations): - return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))] - def paged_attention_v1(self, query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None: seq_len = block_tables.size(1) batch_size, query_heads, _ = query.shape @@ -209,7 +208,7 @@ def paged_attention_v1(self, query, key_cache, value_cache, head_mapping, scale, .view(batch_size, 1, 1, -1)) query.mul_(scale) query = query.unsqueeze(-2) - keys = self._fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) + keys = self.key_cache.fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] @@ -219,7 +218,7 @@ def paged_attention_v1(self, query, key_cache, value_cache, head_mapping, scale, attn_weights = self.softmax(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf), dim=-1) - values = self._fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3)) + values = self.value_cache.fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3)) if PA_SPLIT_VALUE: attn_weights = attn_weights.split(block_size, dim=-1) else: @@ -266,11 +265,9 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. 
- key_cache, value_cache = self.key_value_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - attn_metadata.prefill_metadata is not None) + block_indices, block_offset = cache_ops.prepare_to_cache(key_cache, attn_metadata.slot_mapping) + key_cache = self.key_cache(key, key_cache, block_indices, block_offset) + value_cache = self.value_cache(value, value_cache, block_indices, block_offset) if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py index 56aafd2a4d0a9..4734333c99e49 100644 --- a/vllm/hpu/cache_ops.py +++ b/vllm/hpu/cache_ops.py @@ -19,6 +19,19 @@ def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, dtype, i value_cache.index_put_((indices, offsets), value) +def prepare_to_cache(cache, slot_mapping): + block_size = cache.size(1) + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + offsets = torch.fmod(slot_mapping, block_size) + + return indices, offsets + + +def insert_or_update_cache(input, cache, block_indices, block_offsets): + cache.index_put_((block_indices, block_offsets), input) + + def swap_blocks(src, dst, block_mapping): index_src = torch.zeros((1,), dtype=torch.int32, device=src.device) index_dst = torch.zeros((1,), dtype=torch.int32, device=dst.device) diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index 2956fed66c77e..bf8cd72b45df2 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -7,7 +7,7 @@ import torch import habana_frameworks.torch as htorch -from vllm.attention.ops.habana_paged_attn import HabanaPagedAttention +from vllm.hpu.cache_ops import insert_or_update_cache def with_mark_steps(fn): def wrapped(*args, **kwargs): @@ -104,7 +104,9 @@ class VLLMKVCache(torch.nn.Module): def __init__(self): super(VLLMKVCache, self).__init__() - def forward(self, key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, is_prefill_metadata): - HabanaPagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, is_prefill_metadata) - return key_cache, value_cache + def forward(self, input, cache, block_indices, block_offset): + insert_or_update_cache(input, cache, block_indices, block_offset) + return cache + + def fetch_from_cache(self, cache, blocks, permutations): + return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))] From c521c4dcae6410e40ffe458c0bd5cedad5a2e1d1 Mon Sep 17 00:00:00 2001 From: Nir David Date: Tue, 18 Jun 2024 11:39:37 +0300 Subject: [PATCH 05/13] HQT on VLLM - remove code duplications --- vllm/attention/backends/habana_attn.py | 91 ++++--------------------- vllm/attention/ops/habana_paged_attn.py | 10 +++ vllm/hpu/ops.py | 18 +++-- vllm/hpu/xops.py | 10 +-- vllm/worker/habana_model_runner.py | 8 ++- 5 files changed, 46 insertions(+), 91 deletions(-) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index a4a8f22b9090e..a202829f49257 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -8,7 +8,7 @@ import os import torch import math -from vllm.hpu import cache_ops +from vllm.hpu import cache_ops, xops from vllm.hpu.utils import VLLMKVCache from vllm.hpu.attn_bias import (AttentionBias, LowerTriangularMaskWithTensorBias) @@ -114,10 +114,6 @@ def __post_init__(self): self.attn_bias: Optional[List[AttentionBias]] = None -import 
habana_frameworks.torch.utils.experimental as htexp -PA_SPLIT_VALUE_DEFAULT = '0' if (htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi3) else '1' -PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', PA_SPLIT_VALUE_DEFAULT) == '1') - class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -168,70 +164,6 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - def prompt_attention(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, - p: float = 0.0, - scale: Optional[float] = None, - ) -> torch.Tensor: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - query_heads = query.size(1) - kv_heads = key.size(1) - if query_heads != kv_heads: - query = query.unflatten(1, (kv_heads, -1)) - key = key.unflatten(1, (kv_heads, 1)) - value = value.unflatten(1, (kv_heads, 1)) - attn_bias = attn_bias.unsqueeze(2) - attn_weights = self.qk_matmul(query * scale, key.transpose(-1, -2)) - if attn_bias is not None: - attn_weights.add_(attn_bias) - attn_weights = self.softmax(attn_weights, dim=-1) - attn_weights = self.kv_matmul(attn_weights, value) - if query_heads != kv_heads: - attn_weights = attn_weights.flatten(1, 2) - attn_weights = attn_weights.transpose(1, 2) - return attn_weights - - def paged_attention_v1(self, query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None: - seq_len = block_tables.size(1) - batch_size, query_heads, _ = query.shape - _, _, kv_heads, _ = key_cache.shape - min_inf = torch.finfo(query.dtype).min - mask = (torch.arange(0, seq_len * block_size, dtype=torch.int32, device=key_cache.device) - .view(1, -1) - .expand(batch_size, -1) - .ge(context_lens.view(-1, 1)) - .view(batch_size, 1, 1, -1)) - query.mul_(scale) - query = query.unsqueeze(-2) - keys = self.key_cache.fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) - if query_heads != kv_heads: - query = query.unflatten(1, (kv_heads, -1)) - keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] - mask = mask.unsqueeze(2) - - attn_weights = [self.qk_matmul(query, k) for k in keys] - attn_weights = self.softmax(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf), - dim=-1) - - values = self.value_cache.fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3)) - if PA_SPLIT_VALUE: - attn_weights = attn_weights.split(block_size, dim=-1) - else: - values = [torch.cat(values, dim=-2)] - attn_weights = [attn_weights] - if query_heads != kv_heads: - values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [self.kv_matmul(a, v) for a, v in zip(attn_weights, values)] - if query_heads != kv_heads: - attn_weights = [a.flatten(1, 2) for a in attn_weights] - attn_weights = sum(attn_weights) - return attn_weights.squeeze(-2) - def forward( self, query: torch.Tensor, @@ -276,13 +208,16 @@ def forward( assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!' 
query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) - out = self.prompt_attention( + out = xops.prompt_attention( query.view(query_shape), key.view(kv_shape), value.view(kv_shape), attn_bias=prefill_meta.attn_bias, p=0.0, scale=self.scale, + qk_matmul_op=self.qk_matmul, + softmax_op=self.softmax, + kv_matmul_op=self.kv_matmul, ) output = out.reshape(batch_size, seq_len, hidden_size) else: @@ -302,18 +237,22 @@ def forward( ) if decode_meta := attn_metadata.decode_metadata: # Decoding run. - block_size = value_cache.shape[1] - output = self.paged_attention_v1( + output = HabanaPagedAttention.forward_decode( query, key_cache, value_cache, - self.num_kv_heads, - self.scale, decode_meta.block_tables, decode_meta.seq_lens_tensor, - block_size, - self.alibi_slopes, attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + self.qk_matmul, + self.softmax, + self.kv_matmul, + self.key_cache.fetch_from_cache, + self.value_cache.fetch_from_cache, ) # Reshape the output tensor. diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index c8ed500f7af1c..6ba35a49ce06d 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -85,6 +85,11 @@ def forward_decode( scale: float, alibi_slopes: Optional[torch.Tensor], kv_scale: float, + qk_op=torch.matmul, + softmax_op=torch.softmax, + kv_op=torch.matmul, + keys_fetch=ops.fetch_from_cache, + values_fetch=ops.fetch_from_cache, ) -> torch.Tensor: block_size = value_cache.shape[1] return ops.paged_attention_v1( @@ -98,6 +103,11 @@ def forward_decode( block_size, alibi_slopes, kv_cache_dtype, + qk_op, + softmax_op, + kv_op, + keys_fetch, + values_fetch, ) @staticmethod diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 1f2e07bd59ccb..c6522f5c6c778 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -38,8 +38,8 @@ def fetch_from_cache(cache, blocks, permutations): return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))] -@hpu_utils.with_mark_steps -def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None: +def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None, + qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache) -> None: seq_len = block_tables.size(1) batch_size, query_heads, _ = query.shape _, _, kv_heads, _ = key_cache.shape @@ -51,18 +51,16 @@ def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block .view(batch_size, 1, 1, -1)) query.mul_(scale) query = query.unsqueeze(-2) - keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) + keys = keys_fetch_func(key_cache, block_tables, (0, 2, 3, 1)) if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] mask = mask.unsqueeze(2) + attn_weights = [qk_matmul_op(query, k) for k in keys] + attn_weights = softmax_op(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf), + dim=-1) - attn_weights = [torch.matmul(query, k) for k in keys] - attn_weights = (torch.cat(attn_weights, dim=-1) - .masked_fill(mask, min_inf) - .softmax(dim=-1)) - - values = fetch_from_cache(value_cache, 
block_tables, (0, 2, 1, 3)) + values = values_fetch_func(value_cache, block_tables, (0, 2, 1, 3)) if PA_SPLIT_VALUE: attn_weights = attn_weights.split(block_size, dim=-1) else: @@ -70,7 +68,7 @@ def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block attn_weights = [attn_weights] if query_heads != kv_heads: values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)] + attn_weights = [kv_matmul_op(a, v) for a, v in zip(attn_weights, values)] if query_heads != kv_heads: attn_weights = [a.flatten(1, 2) for a in attn_weights] attn_weights = sum(attn_weights) diff --git a/vllm/hpu/xops.py b/vllm/hpu/xops.py index d6404a4872c0d..cccafff6a8e55 100644 --- a/vllm/hpu/xops.py +++ b/vllm/hpu/xops.py @@ -11,7 +11,6 @@ import vllm.hpu.utils -@vllm.hpu.utils.with_mark_steps def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -19,6 +18,9 @@ def prompt_attention( attn_bias: Optional[torch.Tensor] = None, p: float = 0.0, scale: Optional[float] = None, + qk_matmul_op=torch.matmul, + softmax_op=torch.softmax, + kv_matmul_op=torch.matmul, ) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -30,11 +32,11 @@ def prompt_attention( key = key.unflatten(1, (kv_heads, 1)) value = value.unflatten(1, (kv_heads, 1)) attn_bias = attn_bias.unsqueeze(2) - attn_weights = torch.matmul(query * scale, key.transpose(-1, -2)) + attn_weights = qk_matmul_op(query * scale, key.transpose(-1, -2)) if attn_bias is not None: attn_weights.add_(attn_bias) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_weights = torch.matmul(attn_weights, value) + attn_weights = softmax_op(attn_weights, dim=-1) + attn_weights = kv_matmul_op(attn_weights, value) if query_heads != kv_heads: attn_weights = attn_weights.flatten(1, 2) attn_weights = attn_weights.transpose(1, 2) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 5a604255829b8..352cc839cac21 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -977,6 +977,7 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem): raise NotImplementedError(f'Unsupported graph allocation strategy: {strategy}') buckets = list(sorted(buckets, key=ordering)) + counter = 0 for idx, (batch_size, seq_len) in enumerate(buckets): # Graph memory usage is proportional to seq dimension in a batch batch_seq = batch_size * seq_len if is_prompt else batch_size @@ -986,11 +987,16 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem): self.graphed_buckets.add((batch_size, seq_len, is_prompt)) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + try: + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + except: + print(f"Failed on graph scenario {idx}: batch_size={batch_size}, seq_len={seq_len}, is_prompt={is_prompt}") + counter += 1 used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) available_mem -= used_mem total_mem += used_mem total_batch_seq += batch_seq + print(f"Failed warm-up graph scenarios = {counter}") graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt) logger.info(f'{phase} captured:{len(graphed)} ({100 * len(graphed) / num_candidates:.1f}%) used_mem:{format_bytes(total_mem)} buckets:{sorted(list(graphed))}') From 
64c8c7f66fbfe485ace2ab525c3b972aa3bf6344 Mon Sep 17 00:00:00 2001 From: Nir David Date: Tue, 18 Jun 2024 11:39:37 +0300 Subject: [PATCH 06/13] HQT on VLLM - move matmul and softmax to hpu utils and revert logits use of matmul class --- vllm/attention/backends/habana_attn.py | 3 +-- vllm/hpu/utils.py | 17 +++++++++++++++++ vllm/model_executor/layers/logits_processor.py | 5 ++--- vllm/utils.py | 16 ---------------- vllm/worker/habana_model_runner.py | 9 ++++++--- 5 files changed, 26 insertions(+), 24 deletions(-) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index a202829f49257..ab425050f5b05 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -9,7 +9,7 @@ import torch import math from vllm.hpu import cache_ops, xops -from vllm.hpu.utils import VLLMKVCache +from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache from vllm.hpu.attn_bias import (AttentionBias, LowerTriangularMaskWithTensorBias) @@ -19,7 +19,6 @@ from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) from vllm.logger import init_logger -from vllm.utils import Matmul, Softmax logger = init_logger(__name__) diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index bf8cd72b45df2..d99d5dd83fc9f 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -100,6 +100,23 @@ def tfidf_backend(recipes): return plt # plt.savefig('similarity.png') + +class Matmul(torch.nn.Module): + def __init__(self): + super(Matmul, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class Softmax(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim = None, inv_head = None): + return torch.softmax(x, dim) + + class VLLMKVCache(torch.nn.Module): def __init__(self): super(VLLMKVCache, self).__init__() diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 502ab3f368d55..3951619c6e3ec 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -6,7 +6,7 @@ from vllm.distributed import tensor_model_parallel_gather, tensor_model_parallel_all_gather from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.utils import is_hpu, Matmul +from vllm.utils import is_hpu class LogitsProcessor(nn.Module): """Process logits and apply logits processors from sampling metadata. @@ -33,7 +33,6 @@ def __init__(self, self.logits_as_input = logits_as_input # original vocabulary size (without LoRA). self.org_vocab_size = org_vocab_size or vocab_size - self.matmul = Matmul() def forward( self, @@ -64,7 +63,7 @@ def forward( def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: # Get the logits for the next tokens. - logits = self.matmul(hidden_states, embedding.t()) + logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: logits += embedding_bias # NOTE(kzawora): HPU PT bridge is missing support for single-rank gather. We'll use all-gather on Gaudi for now. 
diff --git a/vllm/utils.py b/vllm/utils.py index 59389fe348ab3..2fad29bf01617 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -54,22 +54,6 @@ def reset(self) -> None: self.counter = 0 -class Matmul(torch.nn.Module): - def __init__(self): - super(Matmul, self).__init__() - - def forward(self, x, y): - return torch.matmul(x, y) - - -class Softmax(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, dim = None, inv_head = None): - return torch.softmax(x, dim) - - class LRUCache(Generic[T]): def __init__(self, capacity: int): diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 352cc839cac21..2ae71a83e0833 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -281,18 +281,21 @@ def load_model(self) -> None: parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, ) - if self.model_config.quantization == 'hqt': + logger.info(f"Pre-loading model weights on {next(self.model.parameters()).device} took {m_getmodel.get_summary_string()}") + + if self.model_config.quantization == 'hqt': + with HabanaMemoryProfiler() as m_useHQT: import habana_quantization_toolkit habana_quantization_toolkit.prep_model(self.model) import habana_frameworks.torch.core as htcore htcore.hpu_initialize(self.model, mark_only_scales_as_const=True) - logger.info(f"Pre-loading model weights on {next(self.model.parameters()).device} took {m_getmodel.get_summary_string()}") + logger.info(f"HQT prep model took {m_useHQT.get_summary_string()}") # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: self.model = _maybe_wrap_in_hpu_graph(self.model) logger.info(f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}") - + self.model_memory_usage = m.consumed_device_memory logger.info(f"Loading model weights took in total {m.get_summary_string()}") From 2e291c5fedb87b50dda7e71a41f7e197a82b4101 Mon Sep 17 00:00:00 2001 From: Nir David Date: Wed, 19 Jun 2024 13:16:48 +0300 Subject: [PATCH 07/13] Move model to hpu when HQT is not used --- vllm/worker/habana_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 2ae71a83e0833..b5f12c4572130 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -290,6 +290,8 @@ def load_model(self) -> None: import habana_frameworks.torch.core as htcore htcore.hpu_initialize(self.model, mark_only_scales_as_const=True) logger.info(f"HQT prep model took {m_useHQT.get_summary_string()}") + else: + self.model = self.model.to("hpu") # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. 
This needs to be debugged with HabanaMemoryProfiler() as m_wrap: From 9d0fbb7ed0fb539d9c2ddb158a4d9a943e3e01d5 Mon Sep 17 00:00:00 2001 From: Nir David Date: Wed, 19 Jun 2024 16:22:48 +0300 Subject: [PATCH 08/13] fix CR comments --- vllm/config.py | 5 +++-- vllm/distributed/parallel_state.py | 3 --- vllm/engine/arg_utils.py | 5 +++-- vllm/model_executor/models/llama.py | 10 ++++++---- vllm/utils.py | 12 ++++++------ vllm/worker/habana_model_runner.py | 26 ++++++++------------------ vllm/worker/habana_worker.py | 12 ++++++++++++ 7 files changed, 38 insertions(+), 35 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 20e1fde194b6d..5daeeebd56e38 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -366,14 +366,15 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype == "fp8": + elif self.cache_dtype in ["fp8", "hf8"]: logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "But it may cause slight accuracy drop without scaling " "factors. FP8_E5M2 (without scaling) is only supported on " "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 " - "is instead supported for common inference criteria.") + "is instead supported for common inference criteria. " + "FP8_E4M3 is also supported on hpu.") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d31e5a01266a9..be5bb4e857caf 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -84,9 +84,6 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK global _LOCAL_RANK _LOCAL_RANK = local_rank - import os - os.environ["LOCAL_RANK"] = str(local_rank) - os.environ["RANK"] = str(rank) def initialize_model_parallel( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a8dcaef0e5754..af7d60b16812d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -186,12 +186,13 @@ def add_cli_args( parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], + choices=['auto', 'fp8', 'hf8'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' 'data type. FP8_E5M2 (without scaling) is only supported on cuda ' 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria.') + 'supported for common inference criteria. 
FP8_E4M3 is also supported ' + 'on hpu.') parser.add_argument( '--quantization-param-path', type=nullable_str, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b773abaf57411..6699b8884bc5f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -47,7 +47,7 @@ default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from vllm.utils import is_hip +from vllm.utils import is_hip, is_hpu class LlamaMLP(nn.Module): @@ -286,8 +286,9 @@ def forward( else: hidden_states = self.get_input_embeddings(input_ids) residual = None - import habana_frameworks.torch as htorch - htorch.core.mark_step() + if is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( @@ -297,7 +298,8 @@ def forward( attn_metadata, residual, ) - htorch.core.mark_step() + if is_hpu(): + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/vllm/utils.py b/vllm/utils.py index 2fad29bf01617..020e6871e6771 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,7 +31,8 @@ "half": torch.half, "bfloat16": torch.bfloat16, "float": torch.float, - "fp8": torch.float8_e4m3fn, + "fp8": torch.uint8, + "hf8": torch.float8_e4m3fn, } @@ -507,12 +508,12 @@ def current_device_memory_usage() -> float: # Return the device memory usage in bytes. free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info() return total_hpu_memory - free_hpu_memory - + def current_free_device_memory() -> float: # Return the device memory usage in bytes. free_hpu_memory, _ = torch.hpu.mem_get_info() return free_hpu_memory - + def total_device_memory() -> float: # Return the device memory usage in bytes. _, total_hpu_memory = torch.hpu.mem_get_info() @@ -521,11 +522,11 @@ def total_device_memory() -> float: def current_host_memory_usage() -> float: # Return the host memory usage in bytes. return HabanaMemoryProfiler.total_host_memory() - HabanaMemoryProfiler.current_free_host_memory() - + def current_free_host_memory() -> float: # Return the host memory usage in bytes. return psutil.virtual_memory().available - + def total_host_memory() -> float: # Return the host memory usage in bytes. 
return psutil.virtual_memory().total @@ -551,7 +552,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.final_host_memory = HabanaMemoryProfiler.current_host_memory_usage() self.consumed_device_memory = self.final_device_memory - self.initial_device_memory self.consumed_host_memory = self.final_host_memory - self.initial_host_memory - # Adapted from https://stackoverflow.com/a/49361727 diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index b5f12c4572130..9da749d3cd76f 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -283,15 +283,18 @@ def load_model(self) -> None: ) logger.info(f"Pre-loading model weights on {next(self.model.parameters()).device} took {m_getmodel.get_summary_string()}") + import habana_frameworks.torch.core as htcore if self.model_config.quantization == 'hqt': - with HabanaMemoryProfiler() as m_useHQT: + logger.info("Preparing model with HQT..") + with HabanaMemoryProfiler() as m_hqt: import habana_quantization_toolkit habana_quantization_toolkit.prep_model(self.model) - import habana_frameworks.torch.core as htcore htcore.hpu_initialize(self.model, mark_only_scales_as_const=True) - logger.info(f"HQT prep model took {m_useHQT.get_summary_string()}") + logger.info(f"Preparing model with HQT took {m_hqt.get_summary_string()}") else: self.model = self.model.to("hpu") + htcore.mark_step() + torch.hpu.synchronize() # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: @@ -956,16 +959,9 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): logger.info(f"[Warmup][{phase}][{i+1}/{max_i}] batch_size:{batch_size} seq_len:{seq_len} free_mem:{free_mem}") def warmup_all_buckets(self, buckets, is_prompt, kv_caches): - counter = 0 for i, (batch_size, seq_len) in enumerate(reversed(buckets)): - mem_usage = 100.0 * HabanaMemoryProfiler.current_device_memory_usage() / HabanaMemoryProfiler.total_device_memory() self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - try: - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - except: - print(f"Failed on scenario {i+1}: batch_size={batch_size}, seq_len={seq_len}, is_prompt={is_prompt}") - counter += 1 - print(f"Failed warm-up scenarios = {counter}") + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem): total_batch_seq = 0.001 @@ -982,7 +978,6 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem): raise NotImplementedError(f'Unsupported graph allocation strategy: {strategy}') buckets = list(sorted(buckets, key=ordering)) - counter = 0 for idx, (batch_size, seq_len) in enumerate(buckets): # Graph memory usage is proportional to seq dimension in a batch batch_seq = batch_size * seq_len if is_prompt else batch_size @@ -992,16 +987,11 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem): self.graphed_buckets.add((batch_size, seq_len, is_prompt)) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: - try: - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - except: - print(f"Failed on graph scenario {idx}: batch_size={batch_size}, seq_len={seq_len}, is_prompt={is_prompt}") - counter += 1 + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) used_mem = align_workers(mem_prof.consumed_device_memory, 
torch.distributed.ReduceOp.MAX) available_mem -= used_mem total_mem += used_mem total_batch_seq += batch_seq - print(f"Failed warm-up graph scenarios = {counter}") graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt) logger.info(f'{phase} captured:{len(graphed)} ({100 * len(graphed) / num_candidates:.1f}%) used_mem:{format_bytes(total_mem)} buckets:{sorted(list(graphed))}') diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index 57fc43285b75c..61c4b8af7b7b7 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -86,6 +86,16 @@ def __init__( self.cache_engine: CacheEngine self.hpu_cache: List[torch.Tensor] + def _set_env_vars(self): + local_rank = self.local_rank + if self.parallel_config.world_size == 1: + local_rank = -1 + import os + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["ID"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size) + os.environ["RANK"] = str(self.rank) + def init_device(self) -> None: if self.device_config.device.type == "hpu": self.device = torch.device("hpu") @@ -94,6 +104,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. + if self.model_config.quantization == 'hqt': + self._set_env_vars() init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) From 09e00787090b5bd8b7c7fa8da3a50c906cf4e357 Mon Sep 17 00:00:00 2001 From: Nir David Date: Fri, 21 Jun 2024 22:43:30 +0300 Subject: [PATCH 09/13] add model weights device load --- vllm/config.py | 2 ++ vllm/engine/arg_utils.py | 8 ++++++++ vllm/engine/llm_engine.py | 5 +++-- vllm/model_executor/model_loader/loader.py | 3 ++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5daeeebd56e38..a272dcf1963ce 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -475,10 +475,12 @@ class LoadConfig: mainly for profiling. "tensorizer" will use CoreWeave's tensorizer library for fast weight loading. + device: Device on which weights are loaded. 
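The HQT pieces added above in habana_model_runner.py and habana_worker.py (the _set_env_vars hook and the model preparation step) can be read together as the condensed sketch below. The function names are illustrative stand-ins; habana_quantization_toolkit.prep_model and htcore.hpu_initialize(..., mark_only_scales_as_const=True) are the calls added by this series and assume the Habana software stack is installed.

    import os
    import torch

    def set_hqt_env_vars(world_size: int, rank: int, local_rank: int) -> None:
        # HQT reads these variables; single-card runs report a local rank of -1.
        if world_size == 1:
            local_rank = -1
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["ID"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)

    def prepare_model_with_hqt(model: torch.nn.Module) -> torch.nn.Module:
        import habana_frameworks.torch.core as htcore
        import habana_quantization_toolkit

        # Swap in quantized modules, then let the bridge treat the measured
        # scales as constants when HPU graphs are captured.
        habana_quantization_toolkit.prep_model(model)
        htcore.hpu_initialize(model, mark_only_scales_as_const=True)
        return model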
""" load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO download_dir: Optional[str] = None + device: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field( default_factory=dict) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index af7d60b16812d..8fb10bf4d48ce 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -28,6 +28,7 @@ class EngineArgs: trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' + weights_load_device: Optional[str] = None dtype: str = 'auto' kv_cache_dtype: str = 'auto' quantization_param_path: Optional[str] = None @@ -168,6 +169,11 @@ def add_cli_args( '* "tensorizer" will load the weights using tensorizer from ' 'CoreWeave which assumes tensorizer_uri is set to the location of ' 'the serialized weights.') + parser.add_argument("--weights-load-device", + type=str, + default=EngineArgs.weights_load_device, + choices=["cuda", "neuron", "hpu", "cpu"], + help='Device on which weights are loaded.') parser.add_argument( '--dtype', type=str, @@ -575,9 +581,11 @@ def create_engine_config(self, ) -> EngineConfig: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + device = device_config.device_type if self.weights_load_device is None else self.weights_load_device load_config = LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + device=device, model_loader_extra_config=self.model_loader_extra_config, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 26258f42a61a6..d95cf030014da 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -104,8 +104,8 @@ def __init__( "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " "max_seq_len=%d, download_dir=%r, load_format=%s, " "tensor_parallel_size=%d, disable_custom_all_reduce=%s, " - "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " + "quantization=%s, weights_load_device=%s, enforce_eager=%s, " + "kv_cache_dtype=%s, quantization_param_path=%s, device_config=%s, " "decoding_config=%r, seed=%d, served_model_name=%s)", vllm.__version__, model_config.model, @@ -123,6 +123,7 @@ def __init__( parallel_config.tensor_parallel_size, parallel_config.disable_custom_all_reduce, model_config.quantization, + load_config.device, model_config.enforce_eager, cache_config.cache_dtype, model_config.quantization_param_path, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index c9182c0a910f5..ef53b7e73eaa3 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -220,9 +220,10 @@ def load_model(self, *, model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with torch.device(self.load_config.device): model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config) + logger.info("Loading weights on %s ...", self.load_config.device) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, From 24847a9d0ab26d4ee48c57d5bc5fe5a5aae8c611 Mon Sep 17 00:00:00 2001 From: Nir David Date: Wed, 26 Jun 2024 13:44:09 +0300 Subject: [PATCH 10/13] skip replay cached graphs during warmup --- vllm/worker/habana_model_runner.py | 5 +++-- 1 file 
changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 9da749d3cd76f..c928552256856 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -832,6 +832,7 @@ def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[torch.Tensor], + warmup_mode=False, ) -> Optional[SamplerOutput]: if self.is_driver_worker: event_start = self.profiler.get_timestamp_us() @@ -872,7 +873,7 @@ def execute_model( else: model_event_name = 'model_executable' with self.profiler.record_event('internal', model_event_name): - hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices, bypass_hpu_graphs=not use_graphs) + hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices, bypass_hpu_graphs=not use_graphs, warmup_mode=warmup_mode) # Compute the logits. with self.profiler.record_event('internal', f'compute_logits_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'): @@ -949,7 +950,7 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt, kv_caches) -> None: seqs = [self.create_dummy_seq_group_metadata(i, seq_len, is_prompt) for i in range(batch_size)] torch.hpu.synchronize() for _ in range(times): - self.execute_model(seqs, kv_caches) + self.execute_model(seqs, kv_caches, warmup_mode=True) torch.hpu.synchronize() self.profiler.end() gc.collect() From 90c25275394ad5044b92c81c510aa997955c7ac5 Mon Sep 17 00:00:00 2001 From: Nir David Date: Wed, 26 Jun 2024 13:46:34 +0300 Subject: [PATCH 11/13] HQT on VLLM - Enable split value in G3 --- vllm/hpu/ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index c6522f5c6c778..81cc6de8ae03f 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -14,8 +14,7 @@ import vllm.hpu.utils as hpu_utils -# FIXME: For some reason splitting value causes DFAs on G3. 
This needs to be debugged -PA_SPLIT_VALUE_DEFAULT = '0' if (htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi3) else '1' +PA_SPLIT_VALUE_DEFAULT = '1' PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', PA_SPLIT_VALUE_DEFAULT) == '1') From f7c21579aaf483d8c438dc4540ad35a27992e671 Mon Sep 17 00:00:00 2001 From: Nir David Date: Thu, 27 Jun 2024 14:02:32 +0300 Subject: [PATCH 12/13] pass optimizations flags only in Lazy mode --- vllm/worker/habana_model_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index c928552256856..1bff70e49c3af 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -145,8 +145,6 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') - if 'bypass_hpu_graphs' in kwargs: - kwargs.pop('bypass_hpu_graphs') # required for PT eager input_ids = kwargs['input_ids'] kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), input_ids.device, torch.bfloat16) hidden_states = self.model(*args, **kwargs) @@ -866,14 +864,15 @@ def execute_model( } if self.vision_language_config: execute_model_kwargs.update({"image_input": multi_modal_input}) - + if htorch.utils.internal.is_lazy(): + execute_model_kwargs.update({"bypass_hpu_graphs":not use_graphs, "warmup_mode":warmup_mode}) htorch.core.mark_step() if self.is_driver_worker: model_event_name = f"model_{'prompt' if is_prompt else 'decode'}_bs{batch_size}_seq{seq_len}_graphs{'T' if use_graphs else 'F'}" else: model_event_name = 'model_executable' with self.profiler.record_event('internal', model_event_name): - hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices, bypass_hpu_graphs=not use_graphs, warmup_mode=warmup_mode) + hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices) # Compute the logits. with self.profiler.record_event('internal', f'compute_logits_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'): From 608123bbc8c1ca9ad034c3ddc11ae451c6f23485 Mon Sep 17 00:00:00 2001 From: Nir David Date: Fri, 28 Jun 2024 14:28:45 +0300 Subject: [PATCH 13/13] barak rms norm optimization and a WA to remove transpose nodes --- vllm/attention/backends/habana_attn.py | 1 + vllm/attention/ops/habana_paged_attn.py | 2 ++ vllm/hpu/ops.py | 12 ++++++++++-- vllm/hpu/utils.py | 3 +++ vllm/model_executor/layers/layernorm.py | 8 ++++---- 5 files changed, 20 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index ab425050f5b05..4ec2500b914bc 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -252,6 +252,7 @@ def forward( self.kv_matmul, self.key_cache.fetch_from_cache, self.value_cache.fetch_from_cache, + self.key_cache.permute_cache, ) # Reshape the output tensor. 
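The lazy-mode gating from the "pass optimizations flags only in Lazy mode" patch can be read in isolation as the sketch below; build_extra_kwargs is an illustrative helper rather than code from the diff, and it assumes the Habana PyTorch bridge is importable.

    import habana_frameworks.torch as htorch

    def build_extra_kwargs(use_graphs: bool, warmup_mode: bool) -> dict:
        # HPU-graph-related kwargs are only meaningful in lazy mode; in eager
        # mode the wrapped forward() no longer strips them out, so they must
        # not be passed at all.
        extra = {}
        if htorch.utils.internal.is_lazy():
            extra["bypass_hpu_graphs"] = not use_graphs
            extra["warmup_mode"] = warmup_mode
        return extra

    # usage: execute_model_kwargs.update(build_extra_kwargs(use_graphs, warmup_mode))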
diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 6ba35a49ce06d..1fd0a79de44cd 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -90,6 +90,7 @@ def forward_decode( kv_op=torch.matmul, keys_fetch=ops.fetch_from_cache, values_fetch=ops.fetch_from_cache, + keys_permute=ops.permute_cache, ) -> torch.Tensor: block_size = value_cache.shape[1] return ops.paged_attention_v1( @@ -108,6 +109,7 @@ def forward_decode( kv_op, keys_fetch, values_fetch, + keys_permute, ) @staticmethod diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 81cc6de8ae03f..e028023474dec 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -37,8 +37,13 @@ def fetch_from_cache(cache, blocks, permutations): return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))] +def permute_cache(cache, permutations): + return [v.permute(permutations) for v in cache] + + def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None, - qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache) -> None: + qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache, + keys_permute=permute_cache) -> None: seq_len = block_tables.size(1) batch_size, query_heads, _ = query.shape _, _, kv_heads, _ = key_cache.shape @@ -50,11 +55,14 @@ def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block .view(batch_size, 1, 1, -1)) query.mul_(scale) query = query.unsqueeze(-2) - keys = keys_fetch_func(key_cache, block_tables, (0, 2, 3, 1)) + keys = keys_fetch_func(key_cache, block_tables, (0, 2, 1, 3)) if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] + keys = keys_permute(keys, (0, 1, 2, 4, 3)) mask = mask.unsqueeze(2) + else: + keys = keys_permute(keys, (0, 1, 3, 2)) attn_weights = [qk_matmul_op(query, k) for k in keys] attn_weights = softmax_op(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf), dim=-1) diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index d99d5dd83fc9f..560be9a5ab391 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -127,3 +127,6 @@ def forward(self, input, cache, block_indices, block_offset): def fetch_from_cache(self, cache, blocks, permutations): return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))] + + def permute_cache(self, cache, permutations): + return [v.permute(permutations) for v in cache] \ No newline at end of file diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a77b0f1d687eb..a64f5bc32488b 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -61,8 +61,8 @@ def forward( orig_shape = x.shape residual += x.view(residual.shape) # Note: FusedRMSNorm requires 3D tensors as inputs - x = FusedRMSNorm.apply(residual.float(), self.weight.float(), self.variance_epsilon) - return x.to(orig_dtype).view(orig_shape), residual + x = FusedRMSNorm.apply(residual, self.weight, self.variance_epsilon) + return x.view(orig_shape), residual ops.fused_add_rms_norm( x, residual, @@ -72,8 +72,8 @@ def forward( return x, residual if x.device.type == "hpu" and FusedRMSNorm: orig_dtype = x.dtype 
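The key-cache fetch/permute workaround in vllm/hpu/ops.py above can be exercised on its own. The sketch below mirrors the two helpers from the diff; the tensor shapes are made up for illustration and assume the paged layout (num_blocks, block_size, kv_heads, head_size) used by this backend.

    import torch

    def fetch_from_cache(cache, blocks, permutations):
        # One gathered, permuted view per column of the block table.
        return [cache.index_select(0, blocks[:, i]).permute(permutations)
                for i in range(blocks.size(1))]

    def permute_cache(cache, permutations):
        return [v.permute(permutations) for v in cache]

    key_cache = torch.randn(16, 128, 8, 64)      # num_blocks, block_size, kv_heads, head_size
    block_tables = torch.randint(0, 16, (2, 4))  # 2 sequences, 4 blocks each
    keys = fetch_from_cache(key_cache, block_tables, (0, 2, 1, 3))
    keys = permute_cache(keys, (0, 1, 3, 2))     # orientation expected by matmul(query, key)
    print(keys[0].shape)                         # torch.Size([2, 8, 64, 128])

Per the commit message, splitting the gather and the final transpose into separate permutes is a workaround to avoid extra transpose nodes; the composed permutation is the same as the previous single (0, 2, 3, 1) permute, so the result is numerically unchanged.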
- x = FusedRMSNorm.apply(x.float(), self.weight.float(), self.variance_epsilon) - return x.to(orig_dtype) + x = FusedRMSNorm.apply(x, self.weight, self.variance_epsilon) + return x out = torch.empty_like(x) ops.rms_norm( out,
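For reference, the plain-PyTorch sketch below shows roughly the computation the Habana FusedRMSNorm kernel performs (it is not the fused kernel itself), and that after this change the HPU path stays in the input dtype instead of round-tripping through float32.

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Normalize over the hidden dimension: x / sqrt(mean(x^2) + eps) * weight.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    x = torch.randn(2, 5, 64, dtype=torch.bfloat16)
    w = torch.ones(64, dtype=torch.bfloat16)
    out = rms_norm_ref(x, w, eps=1e-6)
    assert out.dtype == torch.bfloat16  # no float32 upcast, mirroring the change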