From e08acaa51252e4a40a090dc999d53ad66456a7cf Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Mon, 14 Aug 2023 20:20:05 +0800 Subject: [PATCH 01/52] add llama quant --- vllm/model_executor/__init__.py | 3 +- vllm/model_executor/model_loader.py | 24 +- vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/llamaq.py | 609 +++++++++++++++++++++++++ vllm/worker/worker.py | 5 +- 5 files changed, 638 insertions(+), 5 deletions(-) create mode 100644 vllm/model_executor/models/llamaq.py diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 36fc30f9c1e3..3fc785d894c1 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,9 +1,10 @@ from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader import get_model, get_quant_model_v2 from vllm.model_executor.utils import set_random_seed __all__ = [ "InputMetadata", "get_model", + "get_quant_model_v2", "set_random_seed", ] diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 526b4f8b5c87..ff26c7dbeeab 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -23,8 +23,10 @@ "GPTJForCausalLM": GPTJForCausalLM, "GPTNeoXForCausalLM": GPTNeoXForCausalLM, "InternLMForCausalLM": InternLMForCausalLM, - "LlamaForCausalLM": LlamaForCausalLM, - "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* + # "LlamaForCausalLM": LlamaForCausalLM, + # "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* + "LlamaForCausalLM": LlamaQForCausalLM, + "LLaMAForCausalLM": LlamaQForCausalLM, # For decapoda-research/llama-* "MPTForCausalLM": MPTForCausalLM, "OPTForCausalLM": OPTForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, @@ -101,3 +103,21 @@ def get_model(model_config: ModelConfig) -> nn.Module: model_config.load_format, model_config.revision) model = model.cuda() return model.eval() + + +def get_quant_model_v2(model_config: ModelConfig) -> nn.Module: + model_class = _get_model_architecture(model_config.hf_config) + torch.set_default_dtype(model_config.dtype) + + # Create a model instance. + # The weights will be initialized as empty tensors. 
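+ # NOTE: the int4_path / fp16_path constants a few lines below are hard-coded,
+ # environment-specific placeholders. fp16_path is expected to point at the
+ # original FP16 HF checkpoint (used for the embedding, lm_head and norm
+ # weights), while int4_path points at an AWQ quant_cache directory holding
+ # the packed qweight / qzeros / scales tensors for the decoder projections.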
+ model = model_class(model_config.hf_config) + + int4_path = "/workdir/code/awq-llama/quant_cache/llama" + fp16_path = "/mnt/dolphinfs/hdd_pool/docker/share/1/zhangpeng/zhangpeng/model_weights/llama/13b" + + model.load_mix_weights2(fp16_path, int4_path, model_config.download_dir, + model_config.use_np_weights) + model = model.cuda() + + return model.eval() \ No newline at end of file diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f20e5d8e6f20..f0525162172a 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -9,6 +9,7 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM from vllm.model_executor.models.internlm import InternLMForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.models.llamaq import LlamaQForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel @@ -25,6 +26,7 @@ "GPTNeoXForCausalLM", "InternLMForCausalLM", "LlamaForCausalLM", + "LlamaQForCausalLM", "MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel", diff --git a/vllm/model_executor/models/llamaq.py b/vllm/model_executor/models/llamaq.py new file mode 100644 index 000000000000..6b0a2da3484a --- /dev/null +++ b/vllm/model_executor/models/llamaq.py @@ -0,0 +1,609 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +import torch.nn.functional as F + +from transformers import LlamaConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.sampler import Sampler +# from vllm.model_executor.layers.temp_sampler import TempSampler +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.sequence import SequenceOutputs +from awq.quantize.qmodule import WQLinear +import awq_inference_engine +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast + +KVCache = Tuple[torch.Tensor, torch.Tensor] + +class QuantLlamaQRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + # Build here to make `torch.jit.trace` work. 
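+ # The cache registered below stores [cos | sin] concatenated along the last
+ # dim for every position, in half precision; forward() hands this buffer,
+ # together with self.dim, to awq_inference_engine.rotary_embedding_neox().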
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + + # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_sin_cache", cache.half(), persistent=False) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + positions: torch.Tensor, + ): + # Apply rotary embedding to the query and key before passing them + # to the attention op. + # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape) + query = query.contiguous() + key = key.contiguous() + awq_inference_engine.rotary_embedding_neox( + positions, + query, + key, + self.dim, + self.cos_sin_cache, + ) + return query, key + + + +class LlamaQMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.in_features = hidden_size + self.intermediate_size = intermediate_size + self.out_features = hidden_size + self.w_bit = 4 + self.g_size = 128 + + # self.gate_up_proj = WQLinear(self.w_bit, self.g_size, self.in_features, 2 * self.intermediate_size, False, 'cuda') + self.gate_proj = WQLinear(self.w_bit, self.g_size, self.in_features, self.intermediate_size, False, 'cuda') + self.up_proj = WQLinear(self.w_bit, self.g_size, self.in_features, self.intermediate_size, False, 'cuda') + self.down_proj = WQLinear(self.w_bit, self.g_size, self.intermediate_size, self.out_features, False, 'cuda') + + + def forward(self, x): + return self.down_proj(self.custom_LlamaQ_mlp(x)) + + def custom_LlamaQ_mlp(self, x): + out_shape = x.shape[:-1] + (self.intermediate_size, ) + x = x.reshape(-1, x.shape[-1]) + + gate_output = awq_inference_engine.gemm_forward_cuda( + x, self.gate_proj.qweight, self.gate_proj.scales, self.gate_proj.qzeros, 8 + ) + gate_output = F.silu(gate_output) + + up_output = awq_inference_engine.gemm_forward_cuda( + x, self.up_proj.qweight, self.up_proj.scales, self.up_proj.qzeros, 8 + ) + c = gate_output * up_output + c = c.reshape(out_shape) + return c + + +class LlamaQAttention2(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.scaling = self.head_dim**-0.5 + self.w_bit = 4 + self.g_size = 128 + + self.qkv_proj = WQLinear(self.w_bit, self.g_size, hidden_size, 3 * self.total_num_heads * self.head_dim, False, 'cuda') + self.o_proj = WQLinear(self.w_bit, self.g_size, self.total_num_heads * self.head_dim, hidden_size, False, 'cuda') + + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.scaling, + rotary_dim=self.head_dim) + 
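+ # Rough sketch of what the AWQ WQLinear tensors hold (w_bit=4, g_size=128):
+ #   qweight: int32 tensor packing eight 4-bit weights per element
+ #   qzeros:  packed 4-bit zero points, one per group of 128 input channels
+ #   scales:  fp16 scales, one per group of 128 input channels
+ # forward() below feeds these straight into
+ # awq_inference_engine.gemm_forward_cuda(x, qweight, scales, qzeros, 8);
+ # the trailing 8 is assumed to be the int32/int4 packing factor.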
+ def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # print(f"qkv proj size: {self.qkv_proj.shape}, hidden_states size: {hidden_states.shape} ") + # 这里把qkv_proj和o_proj都变成WQLinear + + out_shape = hidden_states.shape[:-1] + (self.total_num_heads * self.head_dim, ) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) + + qkv = awq_inference_engine.gemm_forward_cuda( + hidden_states, self.qkv_proj.qweight, self.qkv_proj.scales, self.qkv_proj.qzeros, 8 + ) + q, k, v = qkv.chunk(chunks=3, dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(positions, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + output = awq_inference_engine.gemm_forward_cuda( + hidden_states, self.o_proj.qweight, self.o_proj.scales, self.o_proj.qzeros, 8 + ) + + return output + + +class LlamaQAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + 3 * self.total_num_heads * self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + ) + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.scaling, + rotary_dim=self.head_dim) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # print(f"qkv proj size: {self.qkv_proj.shape}, hidden_states size: {hidden_states.shape} ") + # 这里把qkv_proj和o_proj都变成WQLinear + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(positions, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaQDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaQAttention2( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + ) + self.mlp = LlamaQMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + 
hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class LlamaQModel(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.embed_tokens = VocabParallelEmbedding( + vocab_size, config.hidden_size, perform_initialization=False) + self.layers = nn.ModuleList([ + LlamaQDecoderLayer(config) for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class LlamaQForCausalLM(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.model = LlamaQModel(config) + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.lm_head = ColumnParallelLinear(config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False) + self.sampler = Sampler(config.vocab_size) + #self.sampler = TempSampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata, cache_events) + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + input_metadata) + + return next_tokens + + _column_parallel_weights = [ + "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight", + "gate_proj.weight", "up_proj.weight" + ] + _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] + + _column_parallel_weights_fp16 = [ + "embed_tokens.weight", "lm_head.weight" + ] + + _row_parallel_weights_fp16 = [] + + _column_parallel_weights_int4 = [ + "qkv_proj.weight", "gate_proj.weight", "up_proj.weight" + ] + + _row_parallel_weights_int4 = ["o_proj.weight", "down_proj.weight"] + + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "rotary_emb.inv_freq" in name: + continue + + if "embed_tokens" in name or "lm_head" in name: + param = state_dict[name] + # Consider padding in the vocab size. 
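+ # The embedding / lm_head rows were padded up to a multiple of 64 when the
+ # modules were created (vocab_size = ((config.vocab_size + 63) // 64) * 64),
+ # so the checkpoint rows are extended with uninitialized extra rows here.
+ # For LLaMA's vocab_size of 32000 this is already a multiple of 64 and
+ # num_extra_rows works out to 0.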
+ padded_vocab_size = (param.shape[0] * + tensor_model_parallel_world_size) + num_extra_rows = padded_vocab_size - self.config.vocab_size + extra_rows = torch.empty(num_extra_rows, + loaded_weight.shape[1]) + extra_rows = extra_rows.to(loaded_weight) + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + + is_attention_weight = False + for stride_id, att_weight_name in enumerate( + ["q_proj", "k_proj", "v_proj"]): + if att_weight_name not in name: + continue + param = state_dict[name.replace(att_weight_name, "qkv_proj")] + shard_size = param.shape[0] // 3 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + loaded_weight = loaded_weight.repeat(1, 3) + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + loaded_weight = loaded_weight.repeat(1, 2) + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank) + + def load_mix_weights(self, + model_name_or_path: str, + q_weight_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + # load fp16 + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "rotary_emb.inv_freq" in name: + continue + + if "mlp" in name: + continue + + if "embed_tokens" in name or "lm_head" in name: + param = state_dict[name] + # Consider padding in the vocab size. 
+ padded_vocab_size = (param.shape[0] * + tensor_model_parallel_world_size) + num_extra_rows = padded_vocab_size - self.config.vocab_size + extra_rows = torch.empty(num_extra_rows, + loaded_weight.shape[1]) + extra_rows = extra_rows.to(loaded_weight) + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + + is_attention_weight = False + for stride_id, att_weight_name in enumerate( + ["q_proj", "k_proj", "v_proj"]): + if att_weight_name not in name: + continue + param = state_dict[name.replace(att_weight_name, "qkv_proj")] + shard_size = param.shape[0] // 3 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + # loaded_weight = loaded_weight.repeat(1, 3) + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + param = state_dict[name] + # print(f"name: {name}") + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank) + # load int4 + for name, loaded_weight in hf_model_weights_iterator( + q_weight_path, cache_dir, use_np_cache): + if "rotary_emb.inv_freq" in name: + continue + + if "embed_tokens" in name or "lm_head" in name: + continue + + if "self_attn" in name: + continue + + if "input_layernorm" in name or "post_attention_layernorm" in name: + continue + + if "norm.weight" in name: + continue + + param = state_dict[name] + print(f"name: {name}") + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank) + + def load_mix_weights2(self, + model_name_or_path: str, + q_weight_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + # load fp16 + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "rotary_emb.inv_freq" in name: + continue + + if "mlp" in name: + continue + + if "self_attn" in name: + continue + + if "embed_tokens" in name or "lm_head" in name: + param = state_dict[name] + # Consider padding in the vocab size. 
+ padded_vocab_size = (param.shape[0] * + tensor_model_parallel_world_size) + num_extra_rows = padded_vocab_size - self.config.vocab_size + extra_rows = torch.empty(num_extra_rows, + loaded_weight.shape[1]) + extra_rows = extra_rows.to(loaded_weight) + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + + param = state_dict[name] + print(f"fp16 layer name: {name}") + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights_fp16, + self._row_parallel_weights_fp16, + tensor_model_parallel_rank) + # load int4 + for name, loaded_weight in hf_model_weights_iterator( + q_weight_path, cache_dir, use_np_cache): + if "rotary_emb.inv_freq" in name: + continue + + if "embed_tokens" in name or "lm_head" in name: + continue + + # if "self_attn" in name: + # continue + + if "input_layernorm" in name or "post_attention_layernorm" in name: + continue + + if "norm.weight" in name: + continue + + is_attention_weight = False + for stride_id, att_weight_name in enumerate( + ["q_proj", "k_proj", "v_proj"]): + if att_weight_name not in name: + continue + print(f"int4 layer name: {name}") + param = state_dict[name.replace(att_weight_name, "qkv_proj")] + shard_size = param.shape[0] // 3 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + loaded_weight = loaded_weight.repeat(1, 3) + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + # print(f"param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + + param = state_dict[name] + # print(f"name: {name}") + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights_int4, + self._row_parallel_weights_int4, + tensor_model_parallel_rank) + + print(f"state dict keys: {state_dict.keys()}") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2d2021d9fe95..55b9b4033c07 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.model_executor import get_model, InputMetadata, set_random_seed +from vllm.model_executor import get_model, get_quant_model_v2, InputMetadata, set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sampling_params import SamplingParams @@ -64,7 +64,8 @@ def init_model(self): # Initialize the model. set_random_seed(self.model_config.seed) - self.model = get_model(self.model_config) + # self.model = get_model(self.model_config) + self.model = get_quant_model_v2(self.model_config) @torch.inference_mode() def profile_num_available_blocks( From 387c804d52a9d6c0a0d1111656bdd460a3f163f5 Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Mon, 14 Aug 2023 20:38:37 +0800 Subject: [PATCH 02/52] change weight path --- vllm/model_executor/model_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index ff26c7dbeeab..28821fb2351f 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -113,7 +113,7 @@ def get_quant_model_v2(model_config: ModelConfig) -> nn.Module: # The weights will be initialized as empty tensors. 
model = model_class(model_config.hf_config) - int4_path = "/workdir/code/awq-llama/quant_cache/llama" + int4_path = "/mnt/dolphinfs/hdd_pool/docker/share/1/zhangpeng/quanted/quant_cache/llama" fp16_path = "/mnt/dolphinfs/hdd_pool/docker/share/1/zhangpeng/zhangpeng/model_weights/llama/13b" model.load_mix_weights2(fp16_path, int4_path, model_config.download_dir, From 68cd1e04ca46c7ba8777111176361fe59d6351f8 Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Tue, 15 Aug 2023 12:44:25 +0800 Subject: [PATCH 03/52] fix weight load --- vllm/model_executor/models/llamaq.py | 122 ++++++++++++++++----------- 1 file changed, 74 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/models/llamaq.py b/vllm/model_executor/models/llamaq.py index 6b0a2da3484a..8d4a8cbf7201 100644 --- a/vllm/model_executor/models/llamaq.py +++ b/vllm/model_executor/models/llamaq.py @@ -127,8 +127,9 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + # tensor_model_parallel_world_size = ( + # get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = 1 self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = (self.total_num_heads // @@ -183,8 +184,9 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + # tensor_model_parallel_world_size = ( + # get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = 1 self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = (self.total_num_heads // @@ -235,7 +237,7 @@ class LlamaQDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaQAttention2( + self.self_attn = LlamaQAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, ) @@ -356,25 +358,32 @@ def forward( _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] _column_parallel_weights_fp16 = [ - "embed_tokens.weight", "lm_head.weight" + "embed_tokens.weight", "lm_head.weight", "model.norm.weight" ] _row_parallel_weights_fp16 = [] _column_parallel_weights_int4 = [ - "qkv_proj.weight", "gate_proj.weight", "up_proj.weight" + "qkv_proj.qweight", "gate_proj.qweight", "up_proj.qweight", + "qkv_proj.qzeros", "gate_proj.qzeros", "up_proj.qzeros", + "qkv_proj.scales", "gate_proj.scales", "up_proj.scales", + # "input_layernorm", "post_attention_layernorm" ] - _row_parallel_weights_int4 = ["o_proj.weight", "down_proj.weight"] + _row_parallel_weights_int4 = ["o_proj.qweight", "down_proj.qweight", + "o_proj.qzeros", "down_proj.qzeros", + "o_proj.scales", "down_proj.scales"] def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, use_np_cache: bool = False): - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - tensor_model_parallel_rank = get_tensor_model_parallel_rank() + # tensor_model_parallel_world_size = ( + # get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = 1 + # tensor_model_parallel_rank = get_tensor_model_parallel_rank() + tensor_model_parallel_rank = 0 state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( @@ -403,7 +412,6 @@ def load_weights(self, loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * 
(tensor_model_parallel_rank + 1)] - loaded_weight = loaded_weight.repeat(1, 3) param_slice = param.data[shard_size * stride_id:shard_size * (stride_id + 1)] assert param_slice.shape == loaded_weight.shape @@ -422,7 +430,6 @@ def load_weights(self, loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * (tensor_model_parallel_rank + 1)] - loaded_weight = loaded_weight.repeat(1, 2) param_slice = param.data[shard_size * stride_id:shard_size * (stride_id + 1)] assert param_slice.shape == loaded_weight.shape @@ -443,11 +450,31 @@ def load_mix_weights(self, q_weight_path: str, cache_dir: Optional[str] = None, use_np_cache: bool = False): - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - tensor_model_parallel_rank = get_tensor_model_parallel_rank() + # tensor_model_parallel_world_size = ( + # get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = 1 + # tensor_model_parallel_rank = get_tensor_model_parallel_rank() + tensor_model_parallel_rank = 0 state_dict = self.state_dict() + column_parallel_weights_fp16 = [ + # "embed_tokens.weight", "lm_head.weight", "model.norm.weight", + # "input_layernorm", "post_attention_layernorm" + "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight" + ] + + row_parallel_weights_fp16 = ["o_proj.weight"] + + column_parallel_weights_int4 = [ + "gate_proj.qweight", "up_proj.qweight", + "gate_proj.qzeros", "up_proj.qzeros", + "gate_proj.scales", "up_proj.scales" + ] + + row_parallel_weights_int4 = ["down_proj.qweight", "down_proj.qzeros", "down_proj.scales"] + + + # load fp16 for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, use_np_cache): @@ -478,7 +505,6 @@ def load_mix_weights(self, loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * (tensor_model_parallel_rank + 1)] - # loaded_weight = loaded_weight.repeat(1, 3) param_slice = param.data[shard_size * stride_id:shard_size * (stride_id + 1)] assert param_slice.shape == loaded_weight.shape @@ -489,10 +515,10 @@ def load_mix_weights(self, continue param = state_dict[name] - # print(f"name: {name}") + print(f"fp16 layer name: {name}") load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, + column_parallel_weights_fp16, + row_parallel_weights_fp16, tensor_model_parallel_rank) # load int4 for name, loaded_weight in hf_model_weights_iterator( @@ -509,14 +535,14 @@ def load_mix_weights(self, if "input_layernorm" in name or "post_attention_layernorm" in name: continue - if "norm.weight" in name: + if "model.norm.weight" in name: continue param = state_dict[name] - print(f"name: {name}") + print(f"int4 layer name: {name}") load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, + column_parallel_weights_int4, + row_parallel_weights_int4, tensor_model_parallel_rank) def load_mix_weights2(self, @@ -524,11 +550,16 @@ def load_mix_weights2(self, q_weight_path: str, cache_dir: Optional[str] = None, use_np_cache: bool = False): - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - tensor_model_parallel_rank = get_tensor_model_parallel_rank() + # tensor_model_parallel_world_size = ( + # get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = 1 + # tensor_model_parallel_rank = get_tensor_model_parallel_rank() + tensor_model_parallel_rank = 0 state_dict = self.state_dict() + # for name, param in 
state_dict.items(): + # print(f"state_dict name: {name}, param shape: {param.shape}") + # load fp16 for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, use_np_cache): @@ -541,6 +572,9 @@ def load_mix_weights2(self, if "self_attn" in name: continue + if "input_layernorm" in name or "post_attention_layernorm" in name: + continue + if "embed_tokens" in name or "lm_head" in name: param = state_dict[name] # Consider padding in the vocab size. @@ -564,16 +598,7 @@ def load_mix_weights2(self, if "rotary_emb.inv_freq" in name: continue - if "embed_tokens" in name or "lm_head" in name: - continue - - # if "self_attn" in name: - # continue - - if "input_layernorm" in name or "post_attention_layernorm" in name: - continue - - if "norm.weight" in name: + if "embed_tokens" in name or "lm_head" or "model.norm.weight" in name: continue is_attention_weight = False @@ -581,16 +606,17 @@ def load_mix_weights2(self, ["q_proj", "k_proj", "v_proj"]): if att_weight_name not in name: continue - print(f"int4 layer name: {name}") - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - loaded_weight = loaded_weight.repeat(1, 3) - param_slice = param.data[shard_size * stride_id:shard_size * + # print(f"int4 layer name: {name}") + # print(f"stride_id: {stride_id}, att_weight_name: {att_weight_name}") + param_name = name.replace(att_weight_name, "qkv_proj") + param = state_dict[param_name] + shard_size = param.shape[1] // 3 + # loaded_weight = loaded_weight[ + # shard_size * tensor_model_parallel_rank:shard_size * + # (tensor_model_parallel_rank + 1)] + param_slice = param.data[:, shard_size * stride_id:shard_size * (stride_id + 1)] - # print(f"param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") + # print(f"*** {param_name}*** param.shape: {param.shape}, param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_attention_weight = True @@ -600,10 +626,10 @@ def load_mix_weights2(self, param = state_dict[name] - # print(f"name: {name}") + print(f"int4 layer name: {name}") load_tensor_parallel_weights(param, loaded_weight, name, self._column_parallel_weights_int4, self._row_parallel_weights_int4, tensor_model_parallel_rank) - print(f"state dict keys: {state_dict.keys()}") + # print(f"state dict keys: {state_dict.keys()}") From ca088d69c1b425bdec9a2fc26e1ce0a6ec68887c Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Wed, 16 Aug 2023 20:34:36 +0800 Subject: [PATCH 04/52] merge gate and up matrix --- vllm/model_executor/models/llamaq.py | 130 +++++++++++++++------------ vllm/model_executor/weight_utils.py | 12 +++ 2 files changed, 86 insertions(+), 56 deletions(-) diff --git a/vllm/model_executor/models/llamaq.py b/vllm/model_executor/models/llamaq.py index 8d4a8cbf7201..ffc8e221c299 100644 --- a/vllm/model_executor/models/llamaq.py +++ b/vllm/model_executor/models/llamaq.py @@ -13,7 +13,8 @@ from vllm.model_executor.layers.sampler import Sampler # from vllm.model_executor.layers.temp_sampler import TempSampler from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) + load_tensor_parallel_weights, + load_tensor_parallel_weights2) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, 
get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( @@ -92,14 +93,17 @@ def __init__( self.w_bit = 4 self.g_size = 128 - # self.gate_up_proj = WQLinear(self.w_bit, self.g_size, self.in_features, 2 * self.intermediate_size, False, 'cuda') - self.gate_proj = WQLinear(self.w_bit, self.g_size, self.in_features, self.intermediate_size, False, 'cuda') - self.up_proj = WQLinear(self.w_bit, self.g_size, self.in_features, self.intermediate_size, False, 'cuda') + self.gate_up_proj = WQLinear(self.w_bit, self.g_size, self.in_features, 2 * self.intermediate_size, False, 'cuda') self.down_proj = WQLinear(self.w_bit, self.g_size, self.intermediate_size, self.out_features, False, 'cuda') + self.act_fn = SiluAndMul() def forward(self, x): - return self.down_proj(self.custom_LlamaQ_mlp(x)) + # return self.down_proj(self.custom_LlamaQ_mlp(x)) + gate_up = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x = self.down_proj(x) + return x def custom_LlamaQ_mlp(self, x): out_shape = x.shape[:-1] + (self.intermediate_size, ) @@ -108,7 +112,7 @@ def custom_LlamaQ_mlp(self, x): gate_output = awq_inference_engine.gemm_forward_cuda( x, self.gate_proj.qweight, self.gate_proj.scales, self.gate_proj.qzeros, 8 ) - gate_output = F.silu(gate_output) + gate_output = self.act_fn(gate_output) up_output = awq_inference_engine.gemm_forward_cuda( x, self.up_proj.qweight, self.up_proj.scales, self.up_proj.qzeros, 8 @@ -158,19 +162,12 @@ def forward( # print(f"qkv proj size: {self.qkv_proj.shape}, hidden_states size: {hidden_states.shape} ") # 这里把qkv_proj和o_proj都变成WQLinear - out_shape = hidden_states.shape[:-1] + (self.total_num_heads * self.head_dim, ) - hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - - qkv = awq_inference_engine.gemm_forward_cuda( - hidden_states, self.qkv_proj.qweight, self.qkv_proj.scales, self.qkv_proj.qzeros, 8 - ) + qkv = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn(positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) - output = awq_inference_engine.gemm_forward_cuda( - hidden_states, self.o_proj.qweight, self.o_proj.scales, self.o_proj.qzeros, 8 - ) + output = self.o_proj(attn_output) return output @@ -237,7 +234,7 @@ class LlamaQDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaQAttention( + self.self_attn = LlamaQAttention2( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, ) @@ -515,11 +512,12 @@ def load_mix_weights(self, continue param = state_dict[name] - print(f"fp16 layer name: {name}") + # print(f"fp16 layer name: {name}") load_tensor_parallel_weights(param, loaded_weight, name, column_parallel_weights_fp16, row_parallel_weights_fp16, tensor_model_parallel_rank) + print("****************** load int weight ***********************") # load int4 for name, loaded_weight in hf_model_weights_iterator( q_weight_path, cache_dir, use_np_cache): @@ -538,14 +536,33 @@ def load_mix_weights(self, if "model.norm.weight" in name: continue + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + + shard_size = param.shape[1] // 2 + start = shard_size * stride_id + end = shard_size * (stride_id + 1) + param_slice = param.data[:, start:end] + + print(f"{name} param_slice.shape: {param_slice.shape}, 
loaded_weight.shape: {loaded_weight.shape}") + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + param = state_dict[name] - print(f"int4 layer name: {name}") + # print(f"int4 layer name: {name}") load_tensor_parallel_weights(param, loaded_weight, name, column_parallel_weights_int4, row_parallel_weights_int4, tensor_model_parallel_rank) - def load_mix_weights2(self, + def load_int4_weights(self, model_name_or_path: str, q_weight_path: str, cache_dir: Optional[str] = None, @@ -557,51 +574,37 @@ def load_mix_weights2(self, tensor_model_parallel_rank = 0 state_dict = self.state_dict() - # for name, param in state_dict.items(): - # print(f"state_dict name: {name}, param shape: {param.shape}") - - # load fp16 + q_proj_shard_size = (self.config.hidden_size // tensor_model_parallel_world_size) + kv_proj_shard_size = (self.config.hidden_size // + self.config.num_attention_heads * + self.config.num_key_value_heads // tensor_model_parallel_world_size) + + print(f"q_proj_shard_size: {q_proj_shard_size}, kv_proj_shard_size: {kv_proj_shard_size}") + attention_weight_specs = [ + # (weight_name, shard_size, offset) + ("q_proj", q_proj_shard_size, 0), + ("k_proj", kv_proj_shard_size, q_proj_shard_size), + ("v_proj", kv_proj_shard_size, + q_proj_shard_size + kv_proj_shard_size), + ] + # load int4 for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, use_np_cache): + q_weight_path, cache_dir, use_np_cache): if "rotary_emb.inv_freq" in name: continue - - if "mlp" in name: - continue - - if "self_attn" in name: - continue - - if "input_layernorm" in name or "post_attention_layernorm" in name: - continue if "embed_tokens" in name or "lm_head" in name: param = state_dict[name] # Consider padding in the vocab size. 
- padded_vocab_size = (param.shape[0] * - tensor_model_parallel_world_size) + padded_vocab_size = (param.shape[0] * tensor_model_parallel_world_size) num_extra_rows = padded_vocab_size - self.config.vocab_size extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) extra_rows = extra_rows.to(loaded_weight) loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) - - param = state_dict[name] - print(f"fp16 layer name: {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights_fp16, - self._row_parallel_weights_fp16, - tensor_model_parallel_rank) - # load int4 - for name, loaded_weight in hf_model_weights_iterator( - q_weight_path, cache_dir, use_np_cache): - if "rotary_emb.inv_freq" in name: - continue - - if "embed_tokens" in name or "lm_head" or "model.norm.weight" in name: - continue is_attention_weight = False + for stride_id, att_weight_name in enumerate( ["q_proj", "k_proj", "v_proj"]): if att_weight_name not in name: @@ -609,6 +612,7 @@ def load_mix_weights2(self, # print(f"int4 layer name: {name}") # print(f"stride_id: {stride_id}, att_weight_name: {att_weight_name}") param_name = name.replace(att_weight_name, "qkv_proj") + param = state_dict[param_name] shard_size = param.shape[1] // 3 # loaded_weight = loaded_weight[ @@ -616,7 +620,7 @@ def load_mix_weights2(self, # (tensor_model_parallel_rank + 1)] param_slice = param.data[:, shard_size * stride_id:shard_size * (stride_id + 1)] - # print(f"*** {param_name}*** param.shape: {param.shape}, param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") + print(f"*** {param_name}*** param.shape: {param.shape}, param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_attention_weight = True @@ -624,12 +628,26 @@ def load_mix_weights2(self, if is_attention_weight: continue + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + + shard_size = param.shape[1] // 2 + start = shard_size * stride_id + end = shard_size * (stride_id + 1) + param_slice = param.data[:, start:end] + + print(f"{name} param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue param = state_dict[name] print(f"int4 layer name: {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights_int4, - self._row_parallel_weights_int4, + load_tensor_parallel_weights2(param, loaded_weight, name, tensor_model_parallel_rank) - - # print(f"state dict keys: {state_dict.keys()}") diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 74de96842296..4c76fbd5268e 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -294,6 +294,18 @@ def load_tensor_parallel_weights( f"{param.shape} != {loaded_weight.shape}") param.data.copy_(loaded_weight) +def load_tensor_parallel_weights2( + param: torch.Tensor, + loaded_weight: torch.Tensor, + param_name: str, + tensor_model_parallel_rank: int, +) -> None: + assert param.shape == loaded_weight.shape, ( + f"{param_name} shape mismatch between model and checkpoint: " + f"{param.shape} != {loaded_weight.shape}") + 
param.data.copy_(loaded_weight) + + def initialize_dummy_weights( model: torch.nn.Module, From 6bde51e48452138352d8a8a2b8ce580594349c16 Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Thu, 17 Aug 2023 16:53:45 +0800 Subject: [PATCH 05/52] use FTLlamaRMSNorm --- vllm/model_executor/models/llamaq.py | 60 ++++++++++++++++------------ 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/llamaq.py b/vllm/model_executor/models/llamaq.py index ffc8e221c299..6b9bc301411c 100644 --- a/vllm/model_executor/models/llamaq.py +++ b/vllm/model_executor/models/llamaq.py @@ -78,6 +78,24 @@ def forward( +class FTLlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + # print(f"norm weight shape {self.weight.shape}") + + def forward(self, x): + x = x.unsqueeze(0) + output = torch.empty_like(x) + awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon) + output = output.squeeze(0) + return output + + class LlamaQMLP(nn.Module): def __init__( @@ -243,10 +261,14 @@ def __init__(self, config: LlamaConfig): intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, ) - self.input_layernorm = RMSNorm(config.hidden_size, + self.input_layernorm = FTLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, + self.post_attention_layernorm = FTLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.input_layernorm = RMSNorm(config.hidden_size, + # eps=config.rms_norm_eps) + # self.post_attention_layernorm = RMSNorm(config.hidden_size, + # eps=config.rms_norm_eps) def forward( self, @@ -290,7 +312,8 @@ def __init__(self, config: LlamaConfig): self.layers = nn.ModuleList([ LlamaQDecoderLayer(config) for _ in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = FTLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -513,9 +536,7 @@ def load_mix_weights(self, param = state_dict[name] # print(f"fp16 layer name: {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights_fp16, - row_parallel_weights_fp16, + load_tensor_parallel_weights2(param, loaded_weight, name, tensor_model_parallel_rank) print("****************** load int weight ***********************") # load int4 @@ -557,9 +578,7 @@ def load_mix_weights(self, param = state_dict[name] # print(f"int4 layer name: {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights_int4, - row_parallel_weights_int4, + load_tensor_parallel_weights2(param, loaded_weight, name, tensor_model_parallel_rank) def load_int4_weights(self, @@ -573,20 +592,9 @@ def load_int4_weights(self, # tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = 0 state_dict = self.state_dict() + # for name, weight in state_dict.items(): + # print(f"state dict name: {name}") - q_proj_shard_size = (self.config.hidden_size // tensor_model_parallel_world_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads // tensor_model_parallel_world_size) - - print(f"q_proj_shard_size: {q_proj_shard_size}, kv_proj_shard_size: {kv_proj_shard_size}") - 
attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), - ] # load int4 for name, loaded_weight in hf_model_weights_iterator( q_weight_path, cache_dir, use_np_cache): @@ -620,7 +628,7 @@ def load_int4_weights(self, # (tensor_model_parallel_rank + 1)] param_slice = param.data[:, shard_size * stride_id:shard_size * (stride_id + 1)] - print(f"*** {param_name}*** param.shape: {param.shape}, param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") + # print(f"*** {param_name}*** param.shape: {param.shape}, param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_attention_weight = True @@ -639,7 +647,7 @@ def load_int4_weights(self, end = shard_size * (stride_id + 1) param_slice = param.data[:, start:end] - print(f"{name} param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") + # print(f"{name} param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True @@ -648,6 +656,8 @@ def load_int4_weights(self, continue param = state_dict[name] - print(f"int4 layer name: {name}") + # print(f"int4 layer name: {name}") + if "norm" in name: + print(f"{name} shape: {loaded_weight.shape}") load_tensor_parallel_weights2(param, loaded_weight, name, tensor_model_parallel_rank) From 931e51c815172ee5eb18d17deaf3ae8333d6ffa6 Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Mon, 28 Aug 2023 17:20:47 +0800 Subject: [PATCH 06/52] support bitsandbytes int8 --- vllm/model_executor/models/llama.py | 71 +++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0b7f4181a150..b80be1fc67e7 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -78,10 +78,35 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x): - gate_up, _ = self.gate_up_proj(x) + gate_up = self.gate_up_proj(x) x = self.act_fn(gate_up) - x, _ = self.down_proj(x) + x = self.down_proj(x) return x + + def trans_int8(self): + int8_gate_up = Linear8bitLt( + self.gate_up_proj.in_features, + self.gate_up_proj.out_features, + self.gate_up_proj.bias is not None, + has_fp16_weights=False, + threshold=6.0, + ) + int8_gate_up.weight = bnb.nn.Int8Params( + self.gate_up_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False + ).to(self.gate_up_proj.weight.dtype) + self.gate_up_proj = int8_gate_up + + int8_down = Linear8bitLt( + self.down_proj.in_features, + self.down_proj.out_features, + self.down_proj.bias is not None, + has_fp16_weights=False, + threshold=6.0, + ) + int8_down.weight = bnb.nn.Int8Params( + self.down_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False + ).to(self.down_proj.weight.dtype) + self.down_proj = int8_down class LlamaAttention(nn.Module): @@ -96,7 +121,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() + # tp_size = get_tensor_model_parallel_world_size() + tp_size = 1 self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size @@ -141,13 +167,38 
@@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) + qkv = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn(positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) - output, _ = self.o_proj(attn_output) + output = self.o_proj(attn_output) return output + + def trans_int8(self): + int8_qkv = Linear8bitLt( + self.qkv_proj.in_features, + self.qkv_proj.out_features, + self.qkv_proj.bias is not None, + has_fp16_weights=False, + threshold=6.0, + ) + int8_qkv.weight = bnb.nn.Int8Params( + self.qkv_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False + ).to(self.qkv_proj.weight.dtype) + self.qkv_proj = int8_qkv + + int8_o = Linear8bitLt( + self.o_proj.in_features, + self.o_proj.out_features, + self.o_proj.bias is not None, + has_fp16_weights=False, + threshold=6.0, + ) + int8_o.weight = bnb.nn.Int8Params( + self.o_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False + ).to(self.o_proj.weight.dtype) + self.o_proj = int8_o class LlamaDecoderLayer(nn.Module): @@ -326,6 +377,9 @@ def load_weights(self, ] state_dict = self.state_dict() + # for name, param in state_dict.items(): + # print(f"state_dict name: {name}, param shape: {param.shape}") + for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: @@ -398,3 +452,10 @@ def load_weights(self, column_parallel_weights, row_parallel_weights, tensor_model_parallel_rank) + + for i in range(len(self.model.layers)): + layer = self.model.layers[i] + layer.mlp.trans_int8() + layer.self_attn.trans_int8() + + print(self.model) \ No newline at end of file From c0c2a4de5cfd75e65c3aeaffc0fcc1ec5c54ab23 Mon Sep 17 00:00:00 2001 From: zhangpeng156 Date: Wed, 30 Aug 2023 10:38:52 +0800 Subject: [PATCH 07/52] llama support bnb 4bit --- vllm/model_executor/models/llama.py | 37 +++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b80be1fc67e7..51a05a4a09d6 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -107,6 +107,25 @@ def trans_int8(self): self.down_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False ).to(self.down_proj.weight.dtype) self.down_proj = int8_down + + def trans_fp4(self): + int8_gate_up = LinearFP4( + self.gate_up_proj.in_features, + self.gate_up_proj.out_features, + self.gate_up_proj.bias is not None + ) + int8_gate_up.weight = bnb.nn.Params4bit( + self.gate_up_proj.weight.data.clone(), requires_grad=False).to(self.gate_up_proj.weight.dtype) + self.gate_up_proj = int8_gate_up + + int8_down = LinearFP4( + self.down_proj.in_features, + self.down_proj.out_features, + self.down_proj.bias is not None + ) + int8_down.weight = bnb.nn.Params4bit( + self.down_proj.weight.data.clone(), requires_grad=False).to(self.down_proj.weight.dtype) + self.down_proj = int8_down class LlamaAttention(nn.Module): @@ -200,6 +219,24 @@ def trans_int8(self): ).to(self.o_proj.weight.dtype) self.o_proj = int8_o + def trans_fp4(self): + int8_qkv = LinearFP4( + self.qkv_proj.in_features, + self.qkv_proj.out_features, + self.qkv_proj.bias is not None + ) + int8_qkv.weight = bnb.nn.Params4bit( + self.qkv_proj.weight.data.clone(), requires_grad=False).to(self.qkv_proj.weight.dtype) + self.qkv_proj = 
int8_qkv + + int8_o = LinearFP4( + self.o_proj.in_features, + self.o_proj.out_features, + self.o_proj.bias is not None + ) + int8_o.weight = bnb.nn.Params4bit( + self.o_proj.weight.data.clone(), requires_grad=False).to(self.o_proj.weight.dtype) + self.o_proj = int8_o class LlamaDecoderLayer(nn.Module): From 3bb6e31280a9607f9192a2efa03c74294b8cfe53 Mon Sep 17 00:00:00 2001 From: Lin Pengyun Date: Tue, 19 Sep 2023 10:58:20 +0800 Subject: [PATCH 08/52] support kv cache quantization --- csrc/attention.cpp | 22 ++ csrc/attention/attention_dtypes.h | 1 + csrc/attention/attention_kernels.cu | 466 +++++++++++++++++++++++- csrc/attention/dtype_float32.cuh | 8 + csrc/attention/dtype_int8.cuh | 49 +++ csrc/cache.cpp | 15 + csrc/cache_kernels.cu | 101 +++++ csrc/quant_utils.cuh | 235 ++++++++++++ tests/kernels/test_attention.py | 308 +++++++++++++++- tests/kernels/test_cache.py | 152 ++++++++ vllm/config.py | 7 + vllm/engine/arg_utils.py | 12 + vllm/model_executor/__init__.py | 3 +- vllm/model_executor/layers/attention.py | 74 ++-- vllm/model_executor/model_loader.py | 19 +- vllm/model_executor/models/llama.py | 119 +----- vllm/worker/cache_engine.py | 3 +- vllm/worker/worker.py | 4 +- 18 files changed, 1469 insertions(+), 129 deletions(-) create mode 100644 csrc/attention/dtype_int8.cuh create mode 100644 csrc/quant_utils.cuh diff --git a/csrc/attention.cpp b/csrc/attention.cpp index 6be8a6d25ae4..e1b8159feb79 100644 --- a/csrc/attention.cpp +++ b/csrc/attention.cpp @@ -14,9 +14,31 @@ void single_query_cached_kv_attention( int max_context_len, const c10::optional& alibi_slopes); +void single_query_cached_kv_quantized_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "single_query_cached_kv_attention", &single_query_cached_kv_attention, "Compute the attention between an input query and the cached key/value tensors"); + m.def( + "single_query_cached_kv_quantized_attention", + &single_query_cached_kv_quantized_attention, + "Compute the attention between an input query and the cached & quantized key/value tensors" + ); } diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index 88b4eddec7fc..ce1a03375233 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -4,3 +4,4 @@ #include "dtype_float16.cuh" #include "dtype_float32.cuh" #include "dtype_bfloat16.cuh" +#include "dtype_int8.cuh" diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 3fc5860bf147..5cd5aeeddbc5 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -17,7 +17,7 @@ */ #include #include - +#include "../quant_utils.cuh" #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -338,6 +338,282 @@ __global__ void single_query_cached_kv_attention_kernel( } } +template< + typename scalar_t, + typename cache_type, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> 
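+// The quantized variant below mirrors single_query_cached_kv_attention_kernel,
+// except that the key/value cache is stored as `cache_type` (e.g. int8, see
+// dtype_int8.cuh) rather than `scalar_t`. Each cache vector is dequantized on
+// the fly with the per-tensor (k_scale, k_zp) / (v_scale, v_zp) parameters
+// before the usual dot-product / softmax / weighted-sum computation; the exact
+// dequantization formula is defined by dequant() in csrc/quant_utils.cuh.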
+__global__ void single_query_cached_kv_attention_quantized_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_type* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_type* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) { + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int kv_head_idx = head_mapping[head_idx]; + const int seq_idx = blockIdx.y; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Vec_quant = typename Vec::Type; + using Vec_dequant = typename FloatVec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. 
+ float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_type); + float qk_max = -FLT_MAX; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_type* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + // dequant and conversion + Vec_quant k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + Vec_dequant k_vec_dequant = dequant(k_vec_quant, k_scale, k_zp); + k_vecs[j] = vec_conversion(k_vec_dequant); + // k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. 
+ qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using V_vec_quant = typename Vec::Type; + using V_vec_dequant = typename FloatVec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + + const cache_type* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + // dequant and conversion + V_vec_quant v_vec_quant = *reinterpret_cast(v_ptr + offset); + V_vec_dequant v_vec_dequant = dequant(v_vec_quant, v_scale, v_zp); + V_vec v_vec = vec_conversion(v_vec_dequant); + // V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. 
+ if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} } // namespace vllm #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ @@ -357,6 +633,28 @@ __global__ void single_query_cached_kv_attention_kernel( kv_block_stride, \ kv_head_stride); +// specifying cache type to int8 manually +#define LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + vllm::single_query_cached_kv_attention_quantized_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, \ + k_scale, \ + k_zp, \ + v_scale, \ + v_zp); + // TODO(woosuk): Tune NUM_THREADS. template< typename T, @@ -442,6 +740,94 @@ void single_query_cached_kv_attention_launcher( } } +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void single_query_cached_kv_attention_quantized_launcher( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, + const c10::optional& alibi_slopes, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? 
+ reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + int8_t* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); // TODO: support other types + int8_t* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); // TODO: support other types + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192. + // case 32: + // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // break; + case 64: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 112: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + // case 160: + // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // break; + // case 192: + // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // break; + case 256: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + #define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ single_query_cached_kv_attention_launcher( \ out, \ @@ -455,6 +841,24 @@ void single_query_cached_kv_attention_launcher( max_context_len, \ alibi_slopes); +#define CALL_QUANTIZED_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_attention_quantized_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes, \ + k_scale, \ + k_zp, \ + v_scale, \ + k_zp); + + // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ @@ -491,6 +895,40 @@ void single_query_cached_kv_attention_launcher( break; \ } +#define CALL_QUANTIZED_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + /* case 1: */ \ + /* CALL_KERNEL_LAUNCHER(T, 1); */ \ + /* break; */ \ + /* case 2: */ \ + /* CALL_KERNEL_LAUNCHER(T, 2); */ \ + /* break; */ \ + /* case 4: */ \ + /* CALL_KERNEL_LAUNCHER(T, 4); */ \ + /* break; */ \ + case 8: \ + CALL_QUANTIZED_KERNEL_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_QUANTIZED_KERNEL_LAUNCHER(T, 16); \ + break; \ + /*case 32: \ + CALL_QUANTIZED_KERNEL_LAUNCHER(T, 32); \ + break;*/ \ + /* case 64: */ \ + /* CALL_KERNEL_LAUNCHER(T, 64); */ \ + /* break; */ \ + /* case 128: */ \ + /* CALL_KERNEL_LAUNCHER(T, 128); */ \ + /* break; */ \ + /* case 256: */ \ + /* CALL_KERNEL_LAUNCHER(T, 256); */ \ + /* break; */ \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + void single_query_cached_kv_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] @@ -514,6 +952,32 @@ void single_query_cached_kv_attention( } } +void single_query_cached_kv_quantized_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) { + if (query.dtype() == at::ScalarType::Float) { + CALL_QUANTIZED_KERNEL_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_QUANTIZED_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_QUANTIZED_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} #undef WARP_SIZE #undef MAX #undef MIN diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb..51407f35e2d0 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -86,6 +86,14 @@ inline __device__ float4 add(float4 a, float4 b) { return c; } +// for compiling, the above function seems to be useless +inline __device__ Float4_ add(Float4_ a, Float4_ b) { + Float4_ c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + // Vector multiplication. 
template<> inline __device__ float mul(float a, float b) { diff --git a/csrc/attention/dtype_int8.cuh b/csrc/attention/dtype_int8.cuh new file mode 100644 index 000000000000..91e6ec40b038 --- /dev/null +++ b/csrc/attention/dtype_int8.cuh @@ -0,0 +1,49 @@ +#pragma once + +#include +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +namespace vllm { +// define int8 vector types for quantization of kv cache + +template<> +struct Vec { + using Type = int8_t; +}; + +template<> +struct Vec { + using Type = int16_t; +}; + +template<> +struct Vec { + using Type = int32_t; +}; + +template<> +struct Vec { + using Type = int64_t; +}; + +template<> +struct FloatVec { + using Type = float; +}; + +template<> +struct FloatVec { + using Type = float2; +}; + +template<> +struct FloatVec { + using Type = Float4_; +}; + +template<> +struct FloatVec { + using Type = Float8_; +}; +} diff --git a/csrc/cache.cpp b/csrc/cache.cpp index 9ae17bb2985c..5ada275ad472 100644 --- a/csrc/cache.cpp +++ b/csrc/cache.cpp @@ -27,6 +27,17 @@ void gather_cached_kv( torch::Tensor& value_cache, torch::Tensor& slot_mapping); +void reshape_and_cache_quantized( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "swap_blocks", @@ -44,4 +55,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "gather_cached_kv", &gather_cached_kv, "Gather key and value from the cache into contiguous QKV tensors"); + m.def( + "reshape_and_cache_quantized", + &reshape_and_cache_quantized, + "Reshape and quantized key and value tensors and cache them"); } diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ddad2b5a29b9..85865eca4466 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -7,6 +7,7 @@ #include #include #include +#include "quant_utils.cuh" void swap_blocks( torch::Tensor& src, @@ -128,6 +129,9 @@ void copy_blocks( dim3 block(std::min(1024, numel_per_block)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( + at::ScalarType::Half, + // at::ScalarType::BFloat16, + at::ScalarType::Char, key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), @@ -137,6 +141,7 @@ void copy_blocks( })); } + namespace vllm { template @@ -181,6 +186,54 @@ __global__ void reshape_and_cache_kernel( } } +template // cache_dtype can only be int8_t for now +__global__ void reshape_and_cache_quantized_kernel( + const attn_dtype* __restrict__ key, // [num_tokens, num_heads, head_size] + const attn_dtype* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_dtype* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + cache_dtype* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) { + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int 
block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int src_key_idx = token_idx * key_stride + i; + const int src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int tgt_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; + // TODO (Lin Pengyun): use vector reading and quantization to improve IO ultilization + attn_dtype tgt_key = __ldg(&key[src_key_idx]); + key_cache[tgt_key_idx] = quant(tgt_key, k_scale, k_zp); + attn_dtype tgt_value = __ldg(&value[src_value_idx]); + value_cache[tgt_value_idx] = quant(tgt_value, v_scale, v_zp); + } +} } // namespace vllm void reshape_and_cache( @@ -221,6 +274,54 @@ void reshape_and_cache( }); } +void reshape_and_cache_quantized( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) +{ + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + at::ScalarType::Half, + at::ScalarType::BFloat16, + key.scalar_type(), + "reshape_and_cache_quantized_kernel", + [&] { + vllm::reshape_and_cache_quantized_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size, + x, + k_scale, + k_zp, + v_scale, + v_zp); + }); +} + namespace vllm { // Grid: (num_blocks, block_size). 
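For orientation, the new csrc/quant_utils.cuh added below implements the int8 mapping used for the quantized KV cache: quantization clamps and rounds (x - zp) / scale into int8, and dequantization recovers q * scale + zp. The attention and cache kernels added earlier in this patch apply the dequant step on each cache fetch (and the quant step on each cache write), while the softmax and accumulation still run in FP32. The per-layer parameters [k_scale, k_zp, v_scale, v_zp] are read later in this patch by get_quant_model_kv in vllm/model_executor/model_loader.py from files named layers.{i}.past_kv_scale.{rank}.weight. A minimal PyTorch sketch of the same round-trip, for illustration only (the tensor shape and the example scale/zero-point values are assumptions, not values taken from the patch):

    import torch

    def quantize_int8(x: torch.Tensor, scale: float, zp: float) -> torch.Tensor:
        # Same affine mapping as the quant() helpers in quant_utils.cuh:
        # round(clamp((x - zp) / scale, -128, 127)), stored as int8.
        q = torch.round(torch.clamp((x.float() - zp) / scale, -128.0, 127.0))
        return q.to(torch.int8)

    def dequantize_int8(q: torch.Tensor, scale: float, zp: float) -> torch.Tensor:
        # Same as the dequant() helpers and the reference code in the tests: q * scale + zp.
        return q.float() * scale + zp

    # Round-trip with arbitrary illustrative parameters; the reconstruction error is
    # bounded by scale / 2 per element as long as (x - zp) / scale stays in [-128, 127].
    x = 0.05 * torch.randn(4, 8)
    scale, zp = 1e-2, 0.0
    x_hat = dequantize_int8(quantize_int8(x, scale, zp), scale, zp)
    assert torch.allclose(x, x_hat, atol=scale)

Quantize and dequantize are exact inverses up to rounding, since ((x - zp) / scale) * scale + zp == x, which is why the reference implementations in tests/kernels can compare the kernel output against a plain floating-point attention computed on the dequantized cache.
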
diff --git a/csrc/quant_utils.cuh b/csrc/quant_utils.cuh new file mode 100644 index 000000000000..f2639ba4cf9c --- /dev/null +++ b/csrc/quant_utils.cuh @@ -0,0 +1,235 @@ +#pragma once + +#include +#include +#include +#include +#include "attention/attention_dtypes.h" +#include "attention/dtype_float32.cuh" +using namespace vllm; + +// this function is for function matching, delete it after writing customized dispatch functions +inline __device__ int8_t quant(double a, const float scale, const float zp) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, (a - zp) / scale))); + return int8; +} + +inline __device__ int8_t quant(float a, const float scale, const float zp) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, (a - zp) / scale))); + return int8; +} + +inline __device__ short quant(float2 a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + + int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale))); + return int16; +} + +inline __device__ int32_t quant(float4 a, const float scale, const float zp) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + + int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale))); + int8[2] = round(max(-128.f, min(127.f, (a.z - zp) / scale))); + int8[3] = round(max(-128.f, min(127.f, (a.w - zp) / scale))); + return int32; +} + +// float16 to int8 +inline __device__ int8_t quant(uint16_t a, const float scale, const float zp) +{ + int8_t int8; + float b = half_to_float(a); + int8 = round(max(-128.f, min(127.f, (b - zp) / scale))); + return int8; +} + +// float16x2 to int8x2 +inline __device__ int16_t quant(uint32_t a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + float2 b = half2_to_float2(a); + + int8[0] = round(max(-128.f, min(127.f, (b.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (b.y - zp) / scale))); + return int16; +} + +// float16x4 to int8x4 +inline __device__ int32_t quant(uint2 a, const float scale, const float zp) +{ + union { + int16_t int16[2]; + int32_t int32; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + return int32; +} + +// float16x8 to int8x8 +inline __device__ int64_t quant(uint4 a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + int16[2] = quant(a.z, scale, zp); + int16[3] = quant(a.w, scale, zp); + return int64; +} + +// int8 to float32, then `vec_conversion` to target format +inline __device__ float dequant(int8_t a, const float scale, const float zp) +{ + float b = a * scale + zp; + return b; +} + +// int8x2 to float32x2 +inline __device__ float2 dequant(int16_t a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + int16 = a; + + float2 b; + b.x = int8[0] * scale + zp; + b.y = int8[1] * scale + zp; + return b; +} + +// int8x4 to float32x4 +inline __device__ Float4_ dequant(int32_t a, const float scale, const float zp) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int32 = a; + + Float4_ b; + b.x.x = (int8[0] * scale) + zp; + b.x.y = (int8[1] * scale) + zp; + b.y.x = (int8[2] * scale) + zp; + b.y.y = (int8[3] * scale) + zp; + return b; +} + +inline __device__ Float8_ dequant(int64_t a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + int64 = a; + + Float8_ b; 
+ b.x = dequant(int16[0], scale, zp); + b.y = dequant(int16[1], scale, zp); + b.z = dequant(int16[2], scale, zp); + b.w = dequant(int16[3], scale, zp); + return b; +} + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} + +template<> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template<> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + + return b; +} + +template<> +__inline__ __device__ float4 vec_conversion(const Float4_& a) +{ + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template<> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +template<> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { + return __float22bfloat162_rn(a); +} + +template<> +__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { + bf16_4_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + return b; +} + +template<> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { + bf16_8_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + b.z = vec_conversion<__nv_bfloat162, float2>(a.z); + b.w = vec_conversion<__nv_bfloat162, float2>(a.w); + return b; +} diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 18985669d159..4d575428d646 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -99,6 +99,145 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +def ref_single_query_cached_kv_attention_quantized( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + k_scale: float, + k_zp: float, + v_scale: float, + v_zp: float, +) -> None: + num_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + + num_input_tokens = query.shape[0] + for i in range(num_input_tokens): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_heads, head_size) + k = k.to(torch.float32) + k = k * k_scale + k_zp + k = k.to(q.dtype) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + v = v.to(torch.float32) + v = v * v_scale + v_zp + v = v.to(q.dtype) + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + scale = 1.0 / (head_size**0.5) + out = ref_masked_attention(q, keys, values, scale) + out = out.view(num_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: torch.Tensor, + key: torch.Tensor, + value: 
torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + head_size = query.shape[-1] + scale = 1.0 / (head_size**0.5) + + num_seqs = len(cu_seq_lens) - 1 + ref_outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device='cuda') + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def ref_multi_query_cached_kv_attention( + cu_query_lens: List[int], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + num_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + scale = 1.0 / (head_size**0.5) + + num_queries = len(cu_query_lens) - 1 + ref_outputs = [] + for i in range(num_queries): + start_idx = cu_query_lens[i] + end_idx = cu_query_lens[i + 1] + query_len = end_idx - start_idx + context_len = int(context_lens[i]) + block_table = block_tables[i] + + # Create attention mask + attn_mask = torch.triu(torch.ones(query_len, context_len), + diagonal=context_len - query_len + 1) * -1e5 + attn_mask = attn_mask.to(dtype=dtype, device='cuda') + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + keys, + values, + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + @torch.inference_mode() def test_single_query_cached_kv_attention( kv_cache_factory, @@ -231,7 +370,109 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_multi_query_kv_attention( +def run_single_query_cached_kv_attention_quantized( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + num_kv_heads: int = None, + k_scale: float = 1e-2, + k_zp: float = 0.0, + v_scale: float = 1e-2, + v_zp: float = 0.0, +) -> None: + qkv = torch.empty(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + qkv.uniform_(-1e-3, 1e-3) + query, _, _ = qkv.unbind(dim=1) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_block_shape = (num_heads, head_size // x, block_size, x) + key_cache = torch.empty(size=(num_blocks, *key_block_shape), + dtype=torch.int8, ## fixed this to int8 + device='cuda') + key_cache.random_(-1, 2) ## change data range + value_block_shape = (num_heads, head_size, block_size) + value_cache = torch.empty(size=(num_blocks, *value_block_shape), + dtype=torch.int8, ## fixed this to int8 + device='cuda') + value_cache.random_(-1, 2) ## change data range + + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in 
range(num_tokens)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') + + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_tokens): + block_table = [ + random.randint(0, num_blocks - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda') + head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda") + + scale = float(1.0 / (head_size**0.5)) + + num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + assert num_heads % num_kv_heads == 0 + num_queries_per_kv = num_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + + output = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + attention_ops.single_query_cached_kv_quantized_attention( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + None, # ALiBi slopes. + k_scale, + k_zp, + v_scale, + v_zp, + ) + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention_quantized( + ref_output, + query, + key_cache, + value_cache, + block_tables, + context_lens, + k_scale, + k_zp, + v_scale, + v_zp, + ) + # NOTE(woosuk): Due to the difference in the data types the two + # implementations use for attention softmax logits and accumulation, + # there is a small difference in the final outputs. + # We should use a relaxed tolerance for the test. + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +@torch.inference_mode() +def run_multi_query_kv_attention( num_seqs: int, num_heads: Tuple[int, int], head_size: int, @@ -284,3 +525,68 @@ def test_multi_query_kv_attention( dtype, ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +def test_single_query_cached_kv_attention() -> None: + torch.random.manual_seed(TEST_SEED) + torch.cuda.manual_seed(TEST_SEED) + for dtype in [torch.half, torch.bfloat16, torch.float]: + for block_size in [8, 16, 32]: + for head_size in [64, 80, 96, 112, 128, 256]: + print(f'Testing single_query_cached_kv_attention with ' + f'dtype={dtype}, block_size={block_size}, ' + f'head_size={head_size}') + run_single_query_cached_kv_attention( + num_tokens=37, + num_heads=3, + head_size=head_size, + block_size=block_size, + num_blocks=1024, + dtype=dtype, + ) + + +def test_single_query_cached_kv_attention_quantized() -> None: + torch.random.manual_seed(TEST_SEED) + torch.cuda.manual_seed(TEST_SEED) + for dtype in [ + torch.half, + torch.bfloat16, + torch.float, + ]: + for block_size in [8, + 16, + ]: + for head_size in [64, + 80, + 96, + 112, + 128, + 256, + ]: + print(f'Testing single_query_cached_kv_attention with ' + f'dtype={dtype}, block_size={block_size}, ' + f'head_size={head_size}') + run_single_query_cached_kv_attention_quantized( + num_tokens=37, + num_heads=3, + head_size=head_size, + block_size=block_size, + num_blocks=1024, + dtype=dtype, + ) + + +def test_multi_query_kv_attention() -> None: + torch.random.manual_seed(TEST_SEED) + torch.cuda.manual_seed(TEST_SEED) + for dtype in [torch.half, torch.bfloat16, torch.float]: + for head_size in [64, 80, 96, 112, 128, 256]: + print(f'Testing multi_query_kv_attention with dtype={dtype}, ' + f'head_size={head_size}') + run_multi_query_kv_attention( + num_seqs=5, + 
num_heads=3, + head_size=head_size, + dtype=dtype, + ) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index cca037df235d..7e449cb182b3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -144,3 +144,155 @@ def test_reshape_and_cache( assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(value_cache, cloned_value_cache) + + +@torch.inference_mode() +def run_reshape_and_cache_quantized( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + k_scale: float = 3.0, + k_zp: float = 0.0, + v_scale: float = 3.0, + v_zp: float = 0.0, +) -> None: + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + _, key, value = qkv.unbind(dim=1) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache = torch.randint(-10, 10, size=key_cache_shape, dtype=torch.int8, device='cuda') ## change to int8 + cloned_key_cache = key_cache.clone() + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_cache = torch.randint(-10, 10, size=value_cache_shape, + dtype=torch.int8, ## change to int8 + device='cuda') + cloned_value_cache = value_cache.clone() + + cache_ops.reshape_and_cache_quantized(key, value, key_cache, value_cache, + slot_mapping, k_scale, k_zp, v_scale, v_zp) + lower_bound, upper_bound = torch.tensor([-128.0], dtype=dtype, device='cuda'), torch.tensor([127.0], dtype=dtype, device='cuda') + ## quantize and store here + reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x) + reshaped_key = torch.maximum(lower_bound, torch.minimum(upper_bound, (reshaped_key - k_zp) / k_scale)) + reshaped_key = torch.round(reshaped_key) + reshaped_key = reshaped_key.to(torch.int8) ## change to int8 + quantized_value = torch.maximum(lower_bound, torch.minimum(upper_bound, (value - v_zp) / v_scale)) + quantized_value = torch.round(quantized_value) + quantized_value = quantized_value.to(torch.int8) + + for i in range(num_tokens): + block_idx = torch.div(slot_mapping[i], + block_size, + rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] + cloned_value_cache[block_idx, :, :, block_offset] = quantized_value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) + + +@torch.inference_mode() +def run_gather_cached_kv( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, +) -> None: + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + _, key, value = qkv.unbind(dim=1) + + qkv_clone = qkv.clone() + _, cloned_key, cloned_value = qkv_clone.unbind(dim=1) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_cache = 
torch.randn(size=value_cache_shape, + dtype=dtype, + device='cuda') + + cache_ops.gather_cached_kv(key, value, key_cache, value_cache, + slot_mapping) + + # Reference implementation. + for i in range(num_tokens): + reshaped_key = cloned_key.reshape(num_tokens, num_heads, + head_size // x, x) + block_idx = torch.div(slot_mapping[i], + block_size, + rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :] + cloned_value[i] = value_cache[block_idx, :, :, block_offset] + + assert torch.allclose(key, cloned_key) + assert torch.allclose(value, cloned_value) + + +def test_copy_blocks() -> None: + for dtype in [torch.half, torch.bfloat16, torch.float]: + run_copy_blocks(num_mappings=23, + num_layers=7, + num_heads=17, + head_size=16, + block_size=8, + num_blocks=1024, + dtype=dtype) + + +def test_reshape_and_cache() -> None: + for dtype in [torch.half, torch.bfloat16, torch.float]: + run_reshape_and_cache(num_tokens=3, + num_heads=2, + head_size=16, + block_size=8, + num_blocks=2, + dtype=dtype) + + +def test_reshape_and_cache_quantized() -> None: + for dtype in [torch.half, torch.bfloat16, torch.float]: + run_reshape_and_cache_quantized(num_tokens=3, + num_heads=2, + head_size=16, + block_size=8, + num_blocks=2, + dtype=dtype) + + +def test_gather_cached_kv() -> None: + for dtype in [torch.half, torch.bfloat16, torch.float]: + run_gather_cached_kv(num_tokens=3, + num_heads=2, + head_size=16, + block_size=8, + num_blocks=2, + dtype=dtype) diff --git a/vllm/config.py b/vllm/config.py index dd92fbccd899..39d04aff1058 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -60,6 +60,8 @@ def __init__( revision: Optional[str], max_model_len: Optional[int] = None, quantization: Optional[str] = None, + kv_cache_dtype: str = None, ## for kv cache quantization, only for int8 right now + kv_quant_params_path: str = None, ## path for kv scales and zero points ) -> None: self.model = model self.tokenizer = tokenizer @@ -74,6 +76,10 @@ def __init__( self.hf_config = get_config(model, trust_remote_code, revision) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self._verify_load_format() + ## for kv cache quantization + self.kv_cache_dtype = _STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] if kv_cache_dtype else self.dtype + self.quant_kv_cache = self.kv_cache_dtype == self.dtype + self.kv_quant_params_path = kv_quant_params_path self._verify_tokenizer_mode() self._verify_quantization() self.max_model_len = None @@ -296,6 +302,7 @@ def __init__(self, max_num_batched_tokens: int, max_num_seqs: int, _STR_DTYPE_TO_TORCH_DTYPE = { + "int8": torch.int8, "half": torch.float16, "float16": torch.float16, "float": torch.float32, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a03155a4929d..d43e83016fc2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -103,6 +103,18 @@ def add_cli_args( default=None, help='model context length. 
If unspecified, ' 'will be automatically derived from the model.') + # kv cache quantization + parser.add_argument( + '--kv-cache-dtype', + type=str, + default="float16", + help='data type for kv cache') + parser.add_argument( + 'kv-quant-params-path', + type=str, + default=None, + help="path to kv scales and zero points" + ) # Parallel arguments parser.add_argument('--worker-use-ray', action='store_true', diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 3fc785d894c1..ab7a59dab318 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,5 +1,5 @@ from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.model_loader import get_model, get_quant_model_v2 +from vllm.model_executor.model_loader import get_model, get_quant_model_v2, get_quant_model_kv from vllm.model_executor.utils import set_random_seed __all__ = [ @@ -7,4 +7,5 @@ "get_model", "get_quant_model_v2", "set_random_seed", + "get_quant_model_kv" ] diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 5e9360a3c20e..a1d6dfd35dd1 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -56,7 +56,9 @@ def __init__(self, num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None) -> None: + num_kv_heads: Optional[int] = None, + quant_kv_cache: bool = False, + kv_quant_params: List[int] = None) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size @@ -65,6 +67,8 @@ def __init__(self, assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.quant_kv_cache = quant_kv_cache + self.kv_quant_params = kv_quant_params self.head_mapping = torch.repeat_interleave( torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"), self.num_queries_per_kv) @@ -144,19 +148,35 @@ def single_query_cached_kv_attention( input_metadata: metadata for paged attention. """ block_size = value_cache.shape[3] - attention_ops.single_query_cached_kv_attention( - output, - query, - key_cache, - value_cache, - self.head_mapping, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - None, # alibi_slopes - ) + if self.quant_kv_cache: + attention_ops.single_query_cached_kv_quantized_attention( + output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + None, # alibi_slopes + *self.kv_quant_params, + ) + else: + attention_ops.single_query_cached_kv_attention( + output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + None, # alibi_slopes + ) def forward( self, @@ -221,13 +241,23 @@ def forward( if (num_valid_tokens > 0 and key_cache is not None and value_cache is not None): # The stride is 3 because the key and value are sliced from qkv. 
- cache_ops.reshape_and_cache( - key[:num_valid_tokens], - value[:num_valid_tokens], - key_cache, - value_cache, - input_metadata.slot_mapping, - ) + if self.quant_kv_cache: + cache_ops.reshape_and_cache_quantized( + key[:num_valid_tokens], + value[:num_valid_tokens], + key_cache, + value_cache, + input_metadata.slot_mapping, + *self.kv_quant_params, + ) + else: + cache_ops.reshape_and_cache( + key[:num_valid_tokens], + value[:num_valid_tokens], + key_cache, + value_cache, + input_metadata.slot_mapping, + ) if input_metadata.num_generation_tokens > 0: # Decoding run. @@ -259,6 +289,8 @@ def __init__( base: int = 10000, num_kv_heads: Optional[int] = None, is_neox_style: bool = True, + quant_kv_cache: bool = False, + kv_quant_params: torch.Tensor = None, ) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads) self.is_neox_style = is_neox_style diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 28821fb2351f..acd09753b184 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -2,11 +2,12 @@ import contextlib from typing import Type +import numpy as np import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig +from vllm.config import ModelConfig, ParallelConfig from vllm.model_executor.models import * # pylint: disable=wildcard-import from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -105,6 +106,22 @@ def get_model(model_config: ModelConfig) -> nn.Module: return model.eval() +def get_quant_model_kv(model_config: ModelConfig, parallel_config: ParallelConfig, + rank: int): + num_layers = model_config.get_num_layers(parallel_config) + ## num_layers * [k_scale, k_zp, v_scale, v_zp] + kv_quant_params_list = [] + for i in range(num_layers): + path = model_config.kv_quant_params_path + f"/layers.{i}.past_kv_scale.{rank}.weight" + kv_quant_params = list(np.fromfile(path, dtype=np.float32)) + kv_quant_params_list.append(kv_quant_params) + model_class = _get_model_architecture(model_config.hf_config) + torch.set_default_dtype(model_config.dtype) + model = model_class(model_config.hf_config, model_config.quant_kv_cache, kv_quant_params) + model = model.cuda() + return model.eval() + + def get_quant_model_v2(model_config: ModelConfig) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) torch.set_default_dtype(model_config.dtype) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 51a05a4a09d6..8c919a2890a6 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -82,50 +82,6 @@ def forward(self, x): x = self.act_fn(gate_up) x = self.down_proj(x) return x - - def trans_int8(self): - int8_gate_up = Linear8bitLt( - self.gate_up_proj.in_features, - self.gate_up_proj.out_features, - self.gate_up_proj.bias is not None, - has_fp16_weights=False, - threshold=6.0, - ) - int8_gate_up.weight = bnb.nn.Int8Params( - self.gate_up_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False - ).to(self.gate_up_proj.weight.dtype) - self.gate_up_proj = int8_gate_up - - int8_down = Linear8bitLt( - self.down_proj.in_features, - self.down_proj.out_features, - self.down_proj.bias is not None, - has_fp16_weights=False, - threshold=6.0, - ) - int8_down.weight = bnb.nn.Int8Params( - self.down_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False - ).to(self.down_proj.weight.dtype) - self.down_proj = int8_down - - 
def trans_fp4(self): - int8_gate_up = LinearFP4( - self.gate_up_proj.in_features, - self.gate_up_proj.out_features, - self.gate_up_proj.bias is not None - ) - int8_gate_up.weight = bnb.nn.Params4bit( - self.gate_up_proj.weight.data.clone(), requires_grad=False).to(self.gate_up_proj.weight.dtype) - self.gate_up_proj = int8_gate_up - - int8_down = LinearFP4( - self.down_proj.in_features, - self.down_proj.out_features, - self.down_proj.bias is not None - ) - int8_down.weight = bnb.nn.Params4bit( - self.down_proj.weight.data.clone(), requires_grad=False).to(self.down_proj.weight.dtype) - self.down_proj = int8_down class LlamaAttention(nn.Module): @@ -137,11 +93,12 @@ def __init__( num_kv_heads: int, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, + quant_kv_cache: bool = False, + kv_quant_params: List[int] = None ) -> None: super().__init__() self.hidden_size = hidden_size - # tp_size = get_tensor_model_parallel_world_size() - tp_size = 1 + tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size @@ -176,7 +133,9 @@ def __init__( self.scaling, base=self.rope_theta, rotary_dim=self.head_dim, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + quant_kv_cache=quant_kv_cache, + kv_quant_params=kv_quant_params) def forward( self, @@ -193,50 +152,7 @@ def forward( input_metadata, cache_event) output = self.o_proj(attn_output) return output - - def trans_int8(self): - int8_qkv = Linear8bitLt( - self.qkv_proj.in_features, - self.qkv_proj.out_features, - self.qkv_proj.bias is not None, - has_fp16_weights=False, - threshold=6.0, - ) - int8_qkv.weight = bnb.nn.Int8Params( - self.qkv_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False - ).to(self.qkv_proj.weight.dtype) - self.qkv_proj = int8_qkv - - int8_o = Linear8bitLt( - self.o_proj.in_features, - self.o_proj.out_features, - self.o_proj.bias is not None, - has_fp16_weights=False, - threshold=6.0, - ) - int8_o.weight = bnb.nn.Int8Params( - self.o_proj.weight.data.clone(), requires_grad=False, has_fp16_weights=False - ).to(self.o_proj.weight.dtype) - self.o_proj = int8_o - - def trans_fp4(self): - int8_qkv = LinearFP4( - self.qkv_proj.in_features, - self.qkv_proj.out_features, - self.qkv_proj.bias is not None - ) - int8_qkv.weight = bnb.nn.Params4bit( - self.qkv_proj.weight.data.clone(), requires_grad=False).to(self.qkv_proj.weight.dtype) - self.qkv_proj = int8_qkv - - int8_o = LinearFP4( - self.o_proj.in_features, - self.o_proj.out_features, - self.o_proj.bias is not None - ) - int8_o.weight = bnb.nn.Params4bit( - self.o_proj.weight.data.clone(), requires_grad=False).to(self.o_proj.weight.dtype) - self.o_proj = int8_o + class LlamaDecoderLayer(nn.Module): @@ -244,6 +160,8 @@ def __init__( self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + quant_kv_cache: bool = False, + kv_quant_params: List[int] = None ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -255,6 +173,8 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, quant_config=quant_config, + quant_kv_cache=quant_kv_cache, + kv_quant_params=kv_quant_params ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -301,6 +221,8 @@ def __init__( self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + quant_kv_cache: bool = False, + kv_quant_params_list: List[List[int]] = None ) -> None: super().__init__() self.config = 
config @@ -311,7 +233,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( vocab_size, config.hidden_size, perform_initialization=False) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) + LlamaDecoderLayer(config, quant_config, quant_kv_cache, kv_quant_params_list[i]) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -348,11 +270,13 @@ def __init__( self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + quant_kv_cache: bool = False, + kv_quant_params_list: List[List[int]] = None ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = LlamaModel(config, quant_config) + self.model = LlamaModel(config, quant_config, quant_kv_cache, kv_quant_params_list) vocab_size = ((config.vocab_size + 63) // 64) * 64 # NOTE: The LM head is not quantized. self.lm_head = ParallelLinear.column(config.hidden_size, @@ -406,7 +330,7 @@ def load_weights(self, self.config.num_attention_heads * self.config.num_key_value_heads // tp_size) attention_weight_specs = [ - # (weight_name, shard_size, offset) + # (weight_name, shard_size, offset), ("q_proj", q_proj_shard_size, 0), ("k_proj", kv_proj_shard_size, q_proj_shard_size), ("v_proj", kv_proj_shard_size, @@ -489,10 +413,3 @@ def load_weights(self, column_parallel_weights, row_parallel_weights, tensor_model_parallel_rank) - - for i in range(len(self.model.layers)): - layer = self.model.layers[i] - layer.mlp.trans_int8() - layer.self_attn.trans_int8() - - print(self.model) \ No newline at end of file diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 3d5a723d9d42..8471bac36b4d 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -34,7 +34,8 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) self.num_heads = model_config.get_num_heads(parallel_config) - self.dtype = model_config.dtype + ## for kv cache quantization + self.dtype = model_config.kv_cache_dtype self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 55b9b4033c07..cb5579f93089 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.model_executor import get_model, get_quant_model_v2, InputMetadata, set_random_seed +from vllm.model_executor import get_model, get_quant_model_v2, InputMetadata, set_random_seed, get_quant_model_kv from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sampling_params import SamplingParams @@ -65,7 +65,7 @@ def init_model(self): # Initialize the model. 
set_random_seed(self.model_config.seed) # self.model = get_model(self.model_config) - self.model = get_quant_model_v2(self.model_config) + self.model = get_quant_model_kv(self.model_config, self.parallel_config, self.rank) @torch.inference_mode() def profile_num_available_blocks( From bc9fada305ccad9fac6b86c93e91fdee41b5c065 Mon Sep 17 00:00:00 2001 From: Lin Pengyun Date: Tue, 19 Sep 2023 16:05:07 +0800 Subject: [PATCH 09/52] fix python code --- vllm/config.py | 2 +- vllm/engine/arg_utils.py | 11 +++++++---- vllm/model_executor/layers/attention.py | 2 +- vllm/model_executor/model_loader.py | 19 ++++++++++--------- vllm/model_executor/models/__init__.py | 4 ++-- vllm/model_executor/models/llama.py | 18 +++++++++--------- vllm/worker/cache_engine.py | 2 +- 7 files changed, 31 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 39d04aff1058..4f9168f524d3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -78,7 +78,7 @@ def __init__( self._verify_load_format() ## for kv cache quantization self.kv_cache_dtype = _STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] if kv_cache_dtype else self.dtype - self.quant_kv_cache = self.kv_cache_dtype == self.dtype + self.quant_kv_cache = not self.kv_cache_dtype == self.dtype self.kv_quant_params_path = kv_quant_params_path self._verify_tokenizer_mode() self._verify_quantization() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d43e83016fc2..c4b987761869 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -30,6 +30,8 @@ class EngineArgs: disable_log_stats: bool = False revision: Optional[str] = None quantization: Optional[str] = None + kv_cache_dtype: str = "float16" + kv_quant_params_path: str = None def __post_init__(self): if self.tokenizer is None: @@ -107,12 +109,12 @@ def add_cli_args( parser.add_argument( '--kv-cache-dtype', type=str, - default="float16", + default=EngineArgs.kv_cache_dtype, help='data type for kv cache') parser.add_argument( - 'kv-quant-params-path', + '--kv-quant-params-path', type=str, - default=None, + default=EngineArgs.kv_quant_params_path, help="path to kv scales and zero points" ) # Parallel arguments @@ -186,7 +188,8 @@ def create_engine_configs( self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, - self.max_model_len, self.quantization) + self.max_model_len, self.quantization, + self.kv_cache_dtype, self.kv_quant_params_path) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index a1d6dfd35dd1..dc090a0886be 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -58,7 +58,7 @@ def __init__(self, scale: float, num_kv_heads: Optional[int] = None, quant_kv_cache: bool = False, - kv_quant_params: List[int] = None) -> None: + kv_quant_params: List[float] = None) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index acd09753b184..7e05e1eadafa 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -24,10 +24,10 @@ "GPTJForCausalLM": GPTJForCausalLM, "GPTNeoXForCausalLM": GPTNeoXForCausalLM, "InternLMForCausalLM": InternLMForCausalLM, - # "LlamaForCausalLM": LlamaForCausalLM, - # "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* - 
"LlamaForCausalLM": LlamaQForCausalLM, - "LLaMAForCausalLM": LlamaQForCausalLM, # For decapoda-research/llama-* + "LlamaForCausalLM": LlamaForCausalLM, + "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* + # "LlamaForCausalLM": LlamaQForCausalLM, + # "LLaMAForCausalLM": LlamaQForCausalLM, # For decapoda-research/llama-* "MPTForCausalLM": MPTForCausalLM, "OPTForCausalLM": OPTForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, @@ -111,13 +111,14 @@ def get_quant_model_kv(model_config: ModelConfig, parallel_config: ParallelConfi num_layers = model_config.get_num_layers(parallel_config) ## num_layers * [k_scale, k_zp, v_scale, v_zp] kv_quant_params_list = [] - for i in range(num_layers): - path = model_config.kv_quant_params_path + f"/layers.{i}.past_kv_scale.{rank}.weight" - kv_quant_params = list(np.fromfile(path, dtype=np.float32)) - kv_quant_params_list.append(kv_quant_params) + if model_config.quant_kv_cache: + for i in range(num_layers): + path = model_config.kv_quant_params_path + f"/layers.{i}.past_kv_scale.{rank}.weight" + kv_quant_params = list(np.fromfile(path, dtype=np.float32)) + kv_quant_params_list.append(kv_quant_params) model_class = _get_model_architecture(model_config.hf_config) torch.set_default_dtype(model_config.dtype) - model = model_class(model_config.hf_config, model_config.quant_kv_cache, kv_quant_params) + model = model_class(model_config.hf_config, model_config.quant_kv_cache, kv_quant_params_list) model = model.cuda() return model.eval() diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f0525162172a..4481e85fee5d 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -9,7 +9,7 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM from vllm.model_executor.models.internlm import InternLMForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.models.llamaq import LlamaQForCausalLM +# from vllm.model_executor.models.llamaq import LlamaQForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel @@ -26,7 +26,7 @@ "GPTNeoXForCausalLM", "InternLMForCausalLM", "LlamaForCausalLM", - "LlamaQForCausalLM", + # "LlamaQForCausalLM", "MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8c919a2890a6..90f1a4aca42b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -78,9 +78,9 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x): - gate_up = self.gate_up_proj(x) + gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) - x = self.down_proj(x) + x, _ = self.down_proj(x) return x @@ -94,7 +94,7 @@ def __init__( rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, quant_kv_cache: bool = False, - kv_quant_params: List[int] = None + kv_quant_params: List[float] = None ) -> None: super().__init__() self.hidden_size = hidden_size @@ -145,12 +145,12 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - qkv = self.qkv_proj(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn(positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) - 
output = self.o_proj(attn_output) + output, _ = self.o_proj(attn_output) return output @@ -161,7 +161,7 @@ def __init__( config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, quant_kv_cache: bool = False, - kv_quant_params: List[int] = None + kv_quant_params: List[float] = None ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -222,7 +222,7 @@ def __init__( config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, quant_kv_cache: bool = False, - kv_quant_params_list: List[List[int]] = None + kv_quant_params_list: List[List[float]] = None ) -> None: super().__init__() self.config = config @@ -233,7 +233,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( vocab_size, config.hidden_size, perform_initialization=False) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config, quant_kv_cache, kv_quant_params_list[i]) + LlamaDecoderLayer(config, quant_config, quant_kv_cache, kv_quant_params_list[i] if quant_kv_cache else None) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -271,7 +271,7 @@ def __init__( config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, quant_kv_cache: bool = False, - kv_quant_params_list: List[List[int]] = None + kv_quant_params_list: List[List[float]] = None ) -> None: super().__init__() self.config = config diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 8471bac36b4d..2f3fd3237042 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -153,7 +153,7 @@ def get_cache_block_size( key_cache_block = block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) - dtype_size = _get_dtype_size(model_config.dtype) + dtype_size = _get_dtype_size(model_config.kv_cache_dtype) return dtype_size * total From 976874d99802c61eb7b245caad586c6ad38bb288 Mon Sep 17 00:00:00 2001 From: Lin Pengyun Date: Wed, 20 Sep 2023 16:14:52 +0800 Subject: [PATCH 10/52] merge and reformat --- csrc/cache_kernels.cu | 7 +------ csrc/dispatch_utils.h | 11 ++++++++++- vllm/model_executor/layers/attention.py | 2 +- vllm/model_executor/model_loader.py | 2 +- vllm/model_executor/models/llama.py | 4 +++- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 85865eca4466..948193278d29 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -128,10 +128,7 @@ void copy_blocks( dim3 grid(num_layers, num_pairs); dim3 block(std::min(1024, numel_per_block)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - at::ScalarType::Half, - // at::ScalarType::BFloat16, - at::ScalarType::Char, + VLLM_DISPATCH_QUANT_TYPES( key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), @@ -298,8 +295,6 @@ void reshape_and_cache_quantized( dim3 block(std::min(num_heads * head_size, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - at::ScalarType::Half, - at::ScalarType::BFloat16, key.scalar_type(), "reshape_and_cache_quantized_kernel", [&] { diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 7c0c49d392a9..921d453b703c 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -7,8 +7,17 @@ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) 
\ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + // AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) + +#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index dc090a0886be..48cb1a2e1ee4 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -292,7 +292,7 @@ def __init__( quant_kv_cache: bool = False, kv_quant_params: torch.Tensor = None, ) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads) + super().__init__(num_heads, head_size, scale, num_kv_heads, quant_kv_cache, kv_quant_params) self.is_neox_style = is_neox_style # Create the cos and sin cache. diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 7e05e1eadafa..3caa8dce79ad 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -118,7 +118,7 @@ def get_quant_model_kv(model_config: ModelConfig, parallel_config: ParallelConfi kv_quant_params_list.append(kv_quant_params) model_class = _get_model_architecture(model_config.hf_config) torch.set_default_dtype(model_config.dtype) - model = model_class(model_config.hf_config, model_config.quant_kv_cache, kv_quant_params_list) + model = model_class(model_config.hf_config, None, model_config.quant_kv_cache, kv_quant_params_list) ## None is for quant config model = model.cuda() return model.eval() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 90f1a4aca42b..4972c1812104 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -232,9 +232,11 @@ def __init__( vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( vocab_size, config.hidden_size, perform_initialization=False) + # print(kv_quant_params_list) + # print(quant_kv_cache) self.layers = nn.ModuleList([ LlamaDecoderLayer(config, quant_config, quant_kv_cache, kv_quant_params_list[i] if quant_kv_cache else None) - for _ in range(config.num_hidden_layers) + for i in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From 2c0c3116c989f6fcce9fdd24cb3e8d89d9e32a46 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Wed, 20 Sep 2023 17:06:45 +0800 Subject: [PATCH 11/52] add int8gemm --- csrc/int8gemm/bindings.cpp | 20 + csrc/int8gemm/bmm.cu | 211 ++++++++ csrc/int8gemm/fused.cu | 25 + csrc/int8gemm/include/bmm.h | 10 + csrc/int8gemm/include/common.h | 11 + csrc/int8gemm/include/fused.h | 16 + csrc/int8gemm/include/linear.h | 43 ++ csrc/int8gemm/linear.cu | 491 ++++++++++++++++++ csrc/int8gemm/setup.py | 29 ++ setup.py | 20 +- .../layers/int8_linear/__init__.py | 0 .../layers/int8_linear/quantization.py | 97 ++++ .../model_executor/layers/int8_linear/test.py | 38 ++ .../layers/int8_linear/w8a8linear.py | 217 ++++++++ 14 files changed, 1227 insertions(+), 1 deletion(-) create mode 100644 
csrc/int8gemm/bindings.cpp create mode 100644 csrc/int8gemm/bmm.cu create mode 100644 csrc/int8gemm/fused.cu create mode 100644 csrc/int8gemm/include/bmm.h create mode 100644 csrc/int8gemm/include/common.h create mode 100644 csrc/int8gemm/include/fused.h create mode 100644 csrc/int8gemm/include/linear.h create mode 100644 csrc/int8gemm/linear.cu create mode 100644 csrc/int8gemm/setup.py create mode 100644 vllm/model_executor/layers/int8_linear/__init__.py create mode 100644 vllm/model_executor/layers/int8_linear/quantization.py create mode 100644 vllm/model_executor/layers/int8_linear/test.py create mode 100644 vllm/model_executor/layers/int8_linear/w8a8linear.py diff --git a/csrc/int8gemm/bindings.cpp b/csrc/int8gemm/bindings.cpp new file mode 100644 index 000000000000..4eaf7bc3b7e7 --- /dev/null +++ b/csrc/int8gemm/bindings.cpp @@ -0,0 +1,20 @@ +#include "include/bmm.h" +#include "include/fused.h" +#include "include/linear.h" +#include +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_relu_a8_w8_b8_o8", &linear_relu_a8_w8_b8_o8, + "Linear ReLU (INT8)"); + m.def("linear_a8_w8_b32_o32", &linear_a8_w8_b32_o32, "Linear (INT32)"); + m.def("linear_a8_w8_bfp32_ofp32", &linear_a8_w8_bfp32_ofp32, + "Linear (I8-OFP32)"); + m.def("linear_a8_w8_b32_o32_with_scaling", &linear_a8_w8_b32_o32_with_scaling, + "Linear (INT32) with scaling"); + m.def("linear_a8_w8_b8_o8", &linear_a8_w8_b8_o8, "Linear (INT8)"); + m.def("dq_add_layernorm_q", &dq_add_layernorm_q, + "DQ + Add + LayerNorm (INT8)"); + m.def("bmm_s8t_s8n_s8t", &bmm_s8t_s8n_s8t, "BMM (INT8 IO) A x B.T"); + m.def("bmm_s8t_s8n_f32t", &bmm_s8t_s8n_f32t, "BMM (INT8 I FP32 O) A x B.T"); + m.def("bmm_s8t_s8n_s32t", &bmm_s8t_s8n_s32t, + "BMM (INT8 In Int32 Out) A x B.T"); +} diff --git a/csrc/int8gemm/bmm.cu b/csrc/int8gemm/bmm.cu new file mode 100644 index 000000000000..93b8e06f96d8 --- /dev/null +++ b/csrc/int8gemm/bmm.cu @@ -0,0 +1,211 @@ +#include "include/bmm.h" +#include "include/common.h" +#include +#include +#include +#include +#include +#include + +torch::Tensor bmm_s8t_s8n_f32t(torch::Tensor A, torch::Tensor B, float alpha) { + int batch_size = A.size(0); + int M = A.size(1); + int N = B.size(1); + int K = A.size(2); + + auto C = torch::empty({batch_size, M, N}, + torch::dtype(torch::kFloat32).device(A.device())); + int lda = A.size(2); + int ldb = B.size(2); + int ldc = C.size(2); + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementOutput = float; + using ElementInputA = int8_t; + using ElementInputB = int8_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementComputeEpilogue>; + + using Gemm = cutlass::gemm::device::GemmBatched< + ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, + LayoutOutput, ElementAccumulator, cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp>; + + long long int batch_stride_A = M * K; + long long int batch_stride_B = N * K; + long long int batch_stride_C = M * N; + + Gemm gemm_op; + typename Gemm::Arguments arguments{ + {M, N, K}, {A.data_ptr(), lda}, + batch_stride_A, {B.data_ptr(), ldb}, + batch_stride_B, {C.data_ptr(), ldc}, + batch_stride_C, {C.data_ptr(), 
ldc}, + batch_stride_C, {alpha, 0}, + batch_size}; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } + return C; +} + +torch::Tensor bmm_s8t_s8n_s8t(torch::Tensor A, torch::Tensor B, float alpha) { + int batch_size = A.size(0); + int M = A.size(1); + int N = B.size(1); + int K = A.size(2); + + auto C = torch::empty({batch_size, M, N}, + torch::dtype(torch::kInt8).device(A.device())); + int lda = A.size(2); + int ldb = B.size(2); + int ldc = C.size(2); + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementOutput = int8_t; + using ElementInputA = int8_t; + using ElementInputB = int8_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementComputeEpilogue>; + + using Gemm = cutlass::gemm::device::GemmBatched< + ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, + LayoutOutput, ElementAccumulator, cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp>; + + long long int batch_stride_A = M * K; + long long int batch_stride_B = N * K; + long long int batch_stride_C = M * N; + + Gemm gemm_op; + typename Gemm::Arguments arguments{ + {M, N, K}, {A.data_ptr(), lda}, + batch_stride_A, {B.data_ptr(), ldb}, + batch_stride_B, {C.data_ptr(), ldc}, + batch_stride_C, {C.data_ptr(), ldc}, + batch_stride_C, {alpha, 0}, + batch_size}; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } + return C; +} + +torch::Tensor bmm_s8t_s8n_s32t(torch::Tensor A, torch::Tensor B) { + int batch_size = A.size(0); + int M = A.size(1); + int N = B.size(1); + int K = A.size(2); + + auto C = torch::empty({batch_size, M, N}, + 
torch::dtype(torch::kInt32).device(A.device())); + int lda = A.size(2); + int ldb = B.size(2); + int ldc = C.size(2); + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementOutput = int32_t; + using ElementInputA = int8_t; + using ElementInputB = int8_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = int32_t; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementComputeEpilogue>; + + using Gemm = cutlass::gemm::device::GemmBatched< + ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, + LayoutOutput, ElementAccumulator, cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp>; + + long long int batch_stride_A = M * K; + long long int batch_stride_B = N * K; + long long int batch_stride_C = M * N; + + Gemm gemm_op; + + ElementComputeEpilogue alpha = 1; + + cutlass::Status status = gemm_op({{M, N, K}, + {A.data_ptr(), lda}, + batch_stride_A, + {B.data_ptr(), ldb}, + batch_stride_B, + {C.data_ptr(), ldc}, + batch_stride_C, + {C.data_ptr(), ldc}, + batch_stride_C, + {alpha, 0}, + batch_size}); + + if (status != cutlass::Status::kSuccess) { + std::cout << "cutlass error code: " << (int)status << std::endl; + } + return C; +} \ No newline at end of file diff --git a/csrc/int8gemm/fused.cu b/csrc/int8gemm/fused.cu new file mode 100644 index 000000000000..ee5e99d89f68 --- /dev/null +++ b/csrc/int8gemm/fused.cu @@ -0,0 +1,25 @@ +#include "include/fused.h" +#include "include/common.h" + + +std::tuple // (residual_output (FP), ln_output (INT8)) +dq_add_layernorm_q( + torch::Tensor input, // INT32 + float input_scale, // FP + torch::Tensor residual_input, // FP + torch::Tensor gamma, // FP + torch::Tensor beta, // FP + float epsilon // FP + ) // The output scale has already been fused into gamma and beta +{ + // residual_output = residual_input + input * input_scale + auto residual_output_fp = torch::add(residual_input, input, input_scale); + + auto ln_output_fp = + torch::layer_norm(residual_output_fp, {residual_output_fp.size(-1)}, + gamma, beta, epsilon); + ln_output_fp.clamp_(-128, 127).round_(); + auto ln_output_int8 = ln_output_fp.to(torch::kChar); + return std::make_tuple(residual_output_fp, ln_output_int8); +} \ No newline at end of file diff --git a/csrc/int8gemm/include/bmm.h b/csrc/int8gemm/include/bmm.h new file mode 100644 index 000000000000..847265d72a33 --- /dev/null +++ b/csrc/int8gemm/include/bmm.h @@ -0,0 +1,10 @@ +#ifndef BMM_H +#define BMM_H +#include +torch::Tensor bmm_s8t_s8n_f32t(torch::Tensor A, torch::Tensor B, float alpha); + +torch::Tensor bmm_s8t_s8n_s8t(torch::Tensor A, torch::Tensor B, float alpha); + +torch::Tensor bmm_s8t_s8n_s32t(torch::Tensor A, torch::Tensor B); + +#endif // BMM_H \ No newline at end of file diff --git a/csrc/int8gemm/include/common.h b/csrc/int8gemm/include/common.h new file mode 100644 index 000000000000..2f3bdd3221b8 --- /dev/null +++ b/csrc/int8gemm/include/common.h @@ -0,0 +1,11 @@ +#ifndef COMMON_H +#define COMMON_H +#include +#include +#include +#include +#include +#include + + +#endif // !COMMON \ No newline at end of file diff --git a/csrc/int8gemm/include/fused.h b/csrc/int8gemm/include/fused.h new file mode 100644 index 000000000000..42ac634507ef --- /dev/null 
+++ b/csrc/int8gemm/include/fused.h @@ -0,0 +1,16 @@ +#ifndef FUSED_H +#define FUSED_H + +#include + +std::tuple // (residual_output (FP), ln_output (INT8)) +dq_add_layernorm_q(torch::Tensor input, // INT32 + float input_scale, // FP + torch::Tensor residual_input, // FP + torch::Tensor gamma, // FP + torch::Tensor beta, // FP + float epsilon // FP +); // The output scale has already been fused into gamma and beta + +#endif // FUSED_H \ No newline at end of file diff --git a/csrc/int8gemm/include/linear.h b/csrc/int8gemm/include/linear.h new file mode 100644 index 000000000000..5df6ac6f1bc1 --- /dev/null +++ b/csrc/int8gemm/include/linear.h @@ -0,0 +1,43 @@ +#ifndef LINEAR_H +#define LINEAR_H +#include + +// used by out_proj and fc2, return INT32 +torch::Tensor linear_a8_w8_b32_o32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias // INT32 +); + +// used by out_proj and fc2, return INT32 +torch::Tensor linear_a8_w8_b32_o32_with_scaling(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT32 + float alpha, // FP32 + float beta // FP32 +); + +// used by out_proj and fc2, return FP32 +torch::Tensor linear_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +); + +// used by fc1, return INT8 +torch::Tensor linear_relu_a8_w8_b8_o8(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT8 + float alpha, // FP32 + float beta // FP32 +); + +// used by q_proj, k_proj, v_proj, return INT8 +torch::Tensor linear_a8_w8_b8_o8(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT8 + float alpha, // FP32 + float beta // FP32 +); + +#endif // LINEAR_HS \ No newline at end of file diff --git a/csrc/int8gemm/linear.cu b/csrc/int8gemm/linear.cu new file mode 100644 index 000000000000..0e11d7b46175 --- /dev/null +++ b/csrc/int8gemm/linear.cu @@ -0,0 +1,491 @@ +#include "include/linear.h" +#include "include/common.h" + +#include +#include +#include + +#include +#include +#include + +// used by out_proj and fc2, return INT32 +torch::Tensor linear_a8_w8_b32_o32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias // INT32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = int32_t; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. 
Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementComputeEpilogue, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } + + return out; +} + +// used by out_proj and fc2, return INT32 +torch::Tensor linear_a8_w8_b32_o32_with_scaling(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT32 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using 
ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementComputeEpilogue>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } + + return out; +} + +// used by out_proj and fc2, return FP32 +torch::Tensor linear_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = float; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; 
// <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementComputeEpilogue>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } + + return out; +} + + +// used by q_proj, k_proj, v_proj, return INT8 +torch::Tensor linear_a8_w8_b8_o8(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT8 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of 
elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + auto device = input.device(); + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement, status: " + + std::to_string((int)status)); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize, status: " + + std::to_string((int)status)); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run, status: " + + std::to_string((int)status)); + } + + return out; +} + +// used by fc1 +torch::Tensor linear_relu_a8_w8_b8_o8(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT8 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of 
elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement, status: " + + std::to_string((int)status)); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize, status: " + + std::to_string((int)status)); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run, status: " + + std::to_string((int)status)); + } + + return out; +} \ No newline at end of file diff --git 
a/csrc/int8gemm/setup.py b/csrc/int8gemm/setup.py new file mode 100644 index 000000000000..e1e5da93d48d --- /dev/null +++ b/csrc/int8gemm/setup.py @@ -0,0 +1,29 @@ +# adapt from https://github.com/Guangxuan-Xiao/torch-int +from setuptools import setup, find_packages +from torch.utils import cpp_extension + +setup( + name='intgemm', + ext_modules=[ + cpp_extension.CUDAExtension( + name='intgemm._CUDA', + sources=[ + 'linear.cu', + 'bmm.cu', + 'fused.cu', + 'bindings.cpp', + ], + include_dirs=['include'], + extra_link_args=['-lcublas_static', '-lcublasLt_static', + '-lculibos', '-lcudart', '-lcudart_static', + '-lrt', '-lpthread', '-ldl', '-L/usr/lib/x86_64-linux-gnu/'], + extra_compile_args={'cxx': ['-std=c++14', '-O3'], + 'nvcc': ['-O3', '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__']}, + ), + ], + cmdclass={ + 'build_ext': cpp_extension.BuildExtension.with_options(use_ninja=False) + }, + packages=find_packages( + exclude=['notebook', 'scripts', 'tests']), +) diff --git a/setup.py b/setup.py index 047ee8d0e894..37113961cdf7 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,24 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ext_modules = [] +# Int8GEMM(cutlass required) +i8gemm_extension = CUDAExtension( + name='vllm.i8gemm', + sources=[ + 'csrc/int8gemm/linear.cu', + 'csrc/int8gemm/bmm.cu', + 'csrc/int8gemm/fused.cu', + # 'csrc/int8gemm/bindings.cpp', + ], + include_dirs=['csrc/int8gemm/include'], + extra_link_args=['-lcublas_static', '-lcublasLt_static', + '-lculibos', '-lcudart', '-lcudart_static', + '-lrt', '-lpthread', '-ldl', '-L/usr/lib/x86_64-linux-gnu/'], + extra_compile_args={'cxx': ['-std=c++14', '-O3'], + 'nvcc': ['-O3', '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__']}, +) +ext_modules.append(i8gemm_extension) + # Cache operations. 
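Once the pybind bindings are compiled into the vllm.i8gemm extension above (bindings.cpp is still commented out at this point and is re-enabled later in the series), a quick smoke test of the batched INT8 kernel could look like the sketch below; it assumes an Ampere-class (SM80) GPU, and the shapes are purely illustrative:

    import torch
    from vllm import i8gemm

    a = torch.randint(-128, 128, (4, 16, 64), dtype=torch.int8, device="cuda")  # (batch, M, K)
    b = torch.randint(-128, 128, (4, 32, 64), dtype=torch.int8, device="cuda")  # (batch, N, K)
    c = i8gemm.bmm_s8t_s8n_s32t(a, b)  # (batch, M, N) INT32, computes a @ b.transpose(1, 2)
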
cache_extension = CUDAExtension( name="vllm.cache_ops", @@ -217,5 +235,5 @@ def get_requirements() -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension}, + cmdclass={"build_ext": BuildExtension.with_options(use_ninja=False)}, ) diff --git a/vllm/model_executor/layers/int8_linear/__init__.py b/vllm/model_executor/layers/int8_linear/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/int8_linear/quantization.py b/vllm/model_executor/layers/int8_linear/quantization.py new file mode 100644 index 000000000000..555e09c39618 --- /dev/null +++ b/vllm/model_executor/layers/int8_linear/quantization.py @@ -0,0 +1,97 @@ +# adapt from https://github.com/Guangxuan-Xiao/torch-int +import torch +import numpy as np + +@torch.no_grad() +def quantize_per_tensor_absmax(t): + scale = t.abs().max() / 127 + if not t.is_cuda: + # half rounding is not supported on CPU + t = t.float() + # use inplace operation to save memory + t.div_(scale).round_() + t_q = t.to(torch.int8) + return t_q, scale + +@torch.no_grad() +def quantize_weight_per_channel_absmax(w): + # w: [out_channel, in_channel] + scales = w.abs().max(dim=1)[0] / 127 + scales = scales.view(-1, 1) + if not w.is_cuda: + # half rounding is not supported on CPU + w = w.float() + # use inplace operation to save memory + w.div_(scales).round_().clamp_(-128, 127) + w_q = w.to(torch.int8) + return w_q, scales + + +@torch.no_grad() +def dynamic_quantize_activation_per_tensor_zeropoint(t): + max_val = t.max()[0] + min_val = t.min()[0] + quant_min = -127 + quant_max = 127 + nudged_scale = (max_val - min_val) / (quant_max - quant_min) + zp = (max_val + min_val) / 2 + zp = (zp / nudged_scale).round() * nudged_scale + t -= zp + max_val = (max_val - min_val) / 2 + + max_val = torch.clamp(max_val, min=1e-8) / 127 + q_act = (t / max_val).round().clamp(-128, 127).to(torch.int8) + return q_act, max_val, zp + + +@torch.no_grad() +def dynamic_quantize_activation_per_tensor_absmax(t): + max_val = t.abs().max() + max_val = torch.clamp(max_val, min=1e-8) / 127 + q_act = (t / max_val).round().clamp(-128, 127).to(torch.int8) + return q_act, max_val + + +@torch.no_grad() +def dynamic_quantize_activation_per_token_absmax(t): + max_val = t.abs().max(dim=-1, keepdim=True)[0] + max_val = torch.clamp(max_val, min=1e-8) / 127 + t.div_(max_val).round_().clamp_(-128, 127) + q_act = t.to(torch.int8) + return q_act, max_val + +@torch.no_grad() +def fake_quantize_activation_per_tensor_absmax(t): + max_val = t.abs().max() + max_val = torch.clamp(max_val, min=1e-8) / 127 + t.div_(max_val).round_().clamp_(-128, 127).mul_(max_val) + return t + + +@torch.no_grad() +def fake_quantize_activation_per_token_absmax(t): + max_val = t.abs().max(dim=-1, keepdim=True)[0] + max_val = torch.clamp(max_val, min=1e-8) / 127 + t.div_(max_val).round_().clamp_(-128, 127).mul_(max_val) + return t + + +@torch.no_grad() +def dequantize_activation_w_per_channel_a_per_token(q_act, w_scales, a_scales): + # q_act: [B, dim] + # w_scales: [dim] + # a_scales: [B 1] + dtype = a_scales.dtype + q_act = q_act.to(torch.float32) + q_act.mul_(w_scales.reshape(1, -1)).mul_(a_scales.reshape(-1, 1)) + return q_act.to(dtype) + +@torch.no_grad() +def dequantize_activation_w_per_channel_a_per_tensor(q_act, w_scales, a_scales): + # q_act: [..., dim] + # w_scales: [dim] + # a_scales: [1] + dtype = a_scales.dtype + q_act = q_act.to(torch.float32) + q_act = q_act * w_scales.reshape(1, -1) * a_scales + 
return q_act.to(dtype) diff --git a/vllm/model_executor/layers/int8_linear/test.py b/vllm/model_executor/layers/int8_linear/test.py new file mode 100644 index 000000000000..b83d36e9b08e --- /dev/null +++ b/vllm/model_executor/layers/int8_linear/test.py @@ -0,0 +1,38 @@ +import torch +from w8a8linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear +from icecream import ic + +@torch.no_grad() +def test_w8a8b8o8_linear(): + B, M, N = 128, 512, 1024 + x = torch.randn(B, M) + x_scale = x.abs().max() / 127 + qx = (x / x_scale).round().to(torch.int8) + linear = torch.nn.Linear(M, N, bias=True) + y_gt = linear(x) + y_scale = y_gt.abs().max() / 127 + q_linear = W8A8B8O8Linear.from_float(linear, x_scale, y_scale).cuda() + q_y = q_linear(qx.cuda()).cpu() + y_hat = q_y * y_scale + r2 = (y_gt - y_hat).pow(2).mean() / y_gt.pow(2).mean() + ic(r2) + +@torch.no_grad() +def test_w8a8bfp32ofp32_linear(): + B, M, N = 128, 512, 1024 + x = torch.randn(B, M) + x_scale = x.abs().max() / 127 + qx = (x / x_scale).round().to(torch.int8) + linear = torch.nn.Linear(M, N, bias=True) + y_gt = linear(x) + q_linear = W8A8BFP32OFP32Linear.from_float(linear, x_scale).cuda() + y_hat = q_linear(qx.cuda()).cpu() + r2 = (y_gt - y_hat).pow(2).mean() / y_gt.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + print('test_w8a8b8o8_linear') + test_w8a8b8o8_linear() + print('test_w8a8bfp32ofp32_linear') + test_w8a8bfp32ofp32_linear() \ No newline at end of file diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py new file mode 100644 index 000000000000..cbbd569dc3c2 --- /dev/null +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -0,0 +1,217 @@ +# adapt from https://github.com/Guangxuan-Xiao/torch-int +import torch +from intgemm._CUDA import (linear_a8_w8_b32_o32, + linear_relu_a8_w8_b8_o8, + linear_a8_w8_b8_o8, + linear_a8_w8_b32_o32_with_scaling, + linear_a8_w8_bfp32_ofp32 + ) +from quantization import ( + quantize_per_tensor_absmax, + quantize_weight_per_channel_absmax, + fake_quantize_activation_per_tensor_absmax, + fake_quantize_activation_per_token_absmax, +) + +class W8A8B8O8Linear(torch.nn.Module): + # For qkv_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('a', torch.tensor(alpha)) + self.register_buffer('b', torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_b8_o8(x, self.weight, self.bias, + self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + int8_module = W8A8B8O8Linear( + module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + # int8_bias, bias_scale should be 0, 0.0 + mockbias = torch.zeros((1, module.out_features), dtype=torch.int8, requires_grad=False) + int8_bias, bias_scale = quantize_per_tensor_absmax(mockbias) + alpha = input_scale * 
weight_scale / output_scale + beta = bias_scale / output_scale + int8_module.weight = int8_weight + int8_module.bias = int8_bias + int8_module.a = alpha + int8_module.b = beta + return int8_module + + +class W8A8B8O8LinearWithSFactor(torch.nn.Module): + # For qkv_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0, inscale=1.0, ouscale=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('a', torch.tensor(alpha)) + self.register_buffer('b', torch.tensor(beta)) + self.register_buffer('inscale', torch.tensor(inscale)) + self.register_buffer('ouscale', torch.tensor(ouscale)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_b8_o8(x, self.weight, self.bias, + self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + int8_module = W8A8B8O8LinearWithSFactor( + module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + mockbias = torch.zeros((1, module.out_features), dtype=torch.int8, requires_grad=False) + int8_bias, bias_scale = quantize_per_tensor_absmax(mockbias) + alpha = input_scale * weight_scale / output_scale + beta = bias_scale / output_scale + int8_module.weight = int8_weight + int8_module.bias = int8_bias + int8_module.a = alpha + int8_module.b = beta + int8_module.inscale = input_scale + int8_module.ouscale = output_scale + return int8_module + + +class W8A8BFP32OFP32Linear(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.float32, requires_grad=False)) + self.register_buffer('a', torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + self.bias = self.bias.to(torch.float32) + # beta should be 1 ? 
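# Note: beta = 1.0 appears to be the intended value here. The bias is stored in
# FP32 and never rescaled, so it is added as-is, while alpha = input_scale *
# weight_scale (set in from_float() below) already dequantizes the INT32
# accumulator: the kernel computes y = alpha * (q_x @ q_w.T) + 1.0 * bias.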
+ y = linear_a8_w8_bfp32_ofp32( + x, self.weight, self.bias, self.a.item(), 1) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32Linear( + module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + mockbias = torch.zeros((1, module.out_features), dtype=torch.float, requires_grad=False) + int8_module.bias = mockbias.to(torch.float32) + int8_module.a = alpha + int8_module.input_scale = input_scale + int8_module.weight_scale = weight_scale + return int8_module + + +class W8A8BFP32OFP32LinearWithSFactor(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.float32, requires_grad=False)) + self.register_buffer('a', torch.tensor(alpha)) + self.register_buffer('inscale', torch.tensor(inscale)) + + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + self.inscale = self.inscale.to(*args, **kwargs) + self.inscale = self.inscale.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + self.bias = self.bias.to(torch.float32) + # beta should be 1 ? 
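+ # NOTE: unlike W8A8BFP32OFP32Linear above, this variant also stores the per-tensor
+ # activation scale (inscale), so a caller can quantize the fp16 input on the fly,
+ # e.g. x_q = (x / inscale).round().clamp(-128, 127).to(torch.int8); as written
+ # here, x must already be int8.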
+ y = linear_a8_w8_bfp32_ofp32( + x, self.weight, self.bias, self.a.item(), 1) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32LinearWithSFactor( + module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + mockbias = torch.zeros((1, module.out_features), dtype=torch.float, requires_grad=False) + int8_module.bias = mockbias.to(torch.float32) + int8_module.a = alpha + int8_module.inscale = torch.tensor(input_scale) + return int8_module From 27e3b4b832374db27cc6e30f47b1a0cb90393fcd Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Wed, 20 Sep 2023 21:01:34 +0800 Subject: [PATCH 12/52] support int8 inference --- csrc/int8gemm/bindings.cpp | 4 + csrc/int8gemm/setup.py | 29 --- setup.py | 4 +- vllm/config.py | 2 +- vllm/engine/arg_utils.py | 2 +- vllm/model_executor/__init__.py | 3 +- .../layers/int8_linear/w8a8linear.py | 31 ++- vllm/model_executor/model_loader.py | 27 +-- vllm/model_executor/models/llama.py | 179 +++++++++++++++--- .../quantization_utils/__init__.py | 2 + .../quantization_utils/smoothquant.py | 70 +++++++ vllm/worker/worker.py | 2 +- 12 files changed, 254 insertions(+), 101 deletions(-) delete mode 100644 csrc/int8gemm/setup.py create mode 100644 vllm/model_executor/quantization_utils/smoothquant.py diff --git a/csrc/int8gemm/bindings.cpp b/csrc/int8gemm/bindings.cpp index 4eaf7bc3b7e7..3bc20df7fbb0 100644 --- a/csrc/int8gemm/bindings.cpp +++ b/csrc/int8gemm/bindings.cpp @@ -2,6 +2,10 @@ #include "include/fused.h" #include "include/linear.h" #include + +/* +adapt from https://github.com/Guangxuan-Xiao/torch-int +*/ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("linear_relu_a8_w8_b8_o8", &linear_relu_a8_w8_b8_o8, "Linear ReLU (INT8)"); diff --git a/csrc/int8gemm/setup.py b/csrc/int8gemm/setup.py deleted file mode 100644 index e1e5da93d48d..000000000000 --- a/csrc/int8gemm/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -# adapt from https://github.com/Guangxuan-Xiao/torch-int -from setuptools import setup, find_packages -from torch.utils import cpp_extension - -setup( - name='intgemm', - ext_modules=[ - cpp_extension.CUDAExtension( - name='intgemm._CUDA', - sources=[ - 'linear.cu', - 'bmm.cu', - 'fused.cu', - 'bindings.cpp', - ], - include_dirs=['include'], - extra_link_args=['-lcublas_static', '-lcublasLt_static', - '-lculibos', '-lcudart', '-lcudart_static', - '-lrt', '-lpthread', '-ldl', '-L/usr/lib/x86_64-linux-gnu/'], - extra_compile_args={'cxx': ['-std=c++14', '-O3'], - 'nvcc': ['-O3', '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__']}, - ), - ], - cmdclass={ - 'build_ext': cpp_extension.BuildExtension.with_options(use_ninja=False) - }, - packages=find_packages( - exclude=['notebook', 'scripts', 'tests']), -) diff --git a/setup.py b/setup.py index 37113961cdf7..8a461c8cae63 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: 'csrc/int8gemm/linear.cu', 'csrc/int8gemm/bmm.cu', 'csrc/int8gemm/fused.cu', - # 'csrc/int8gemm/bindings.cpp', + 'csrc/int8gemm/bindings.cpp', ], include_dirs=['csrc/int8gemm/include'], extra_link_args=['-lcublas_static', '-lcublasLt_static', @@ -235,5 +235,5 @@ def get_requirements() -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - cmdclass={"build_ext": 
BuildExtension.with_options(use_ninja=False)}, + cmdclass={"build_ext": BuildExtension}, ) diff --git a/vllm/config.py b/vllm/config.py index 4f9168f524d3..3e2d50425c2f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -112,7 +112,7 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq"] + supported_quantization = ["awq", "smoothquant"] if self.quantization is None: return quantization = self.quantization.lower() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c4b987761869..ad3b3883107b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -168,7 +168,7 @@ def add_cli_args( parser.add_argument('--quantization', '-q', type=str, - choices=['awq', None], + choices=['awq', "smoothquant", None], default=None, help='Method used to quantize the weights') return parser diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index ab7a59dab318..81f0eeef2397 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,11 +1,10 @@ from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.model_loader import get_model, get_quant_model_v2, get_quant_model_kv +from vllm.model_executor.model_loader import get_model, get_quant_model_kv from vllm.model_executor.utils import set_random_seed __all__ = [ "InputMetadata", "get_model", - "get_quant_model_v2", "set_random_seed", "get_quant_model_kv" ] diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py index cbbd569dc3c2..0d77e70b519d 100644 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -1,12 +1,7 @@ # adapt from https://github.com/Guangxuan-Xiao/torch-int import torch -from intgemm._CUDA import (linear_a8_w8_b32_o32, - linear_relu_a8_w8_b8_o8, - linear_a8_w8_b8_o8, - linear_a8_w8_b32_o32_with_scaling, - linear_a8_w8_bfp32_ofp32 - ) -from quantization import ( +from vllm import i8gemm +from .quantization import ( quantize_per_tensor_absmax, quantize_weight_per_channel_absmax, fake_quantize_activation_per_tensor_absmax, @@ -37,10 +32,12 @@ def to(self, *args, **kwargs): def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) - y = linear_a8_w8_b8_o8(x, self.weight, self.bias, + x = (x / self.inscale).clamp(-128, 127).to(torch.int8) + y = i8gemm.linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) y = y.view(*x_shape[:-1], -1) - return y + # FIXME: Just adapt to ParallelLinears' output + return y, None @staticmethod def from_float(module: torch.nn.Linear, input_scale, output_scale): @@ -85,10 +82,10 @@ def to(self, *args, **kwargs): def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) - y = linear_a8_w8_b8_o8(x, self.weight, self.bias, + y = i8gemm.linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) y = y.view(*x_shape[:-1], -1) - return y + return y, None @staticmethod def from_float(module: torch.nn.Linear, input_scale, output_scale): @@ -139,11 +136,10 @@ def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) self.bias = self.bias.to(torch.float32) - # beta should be 1 ? 
- y = linear_a8_w8_bfp32_ofp32( + y = i8gemm.linear_a8_w8_bfp32_ofp32( x, self.weight, self.bias, self.a.item(), 1) y = y.view(*x_shape[:-1], -1) - return y + return y, None @staticmethod def from_float(module: torch.nn.Linear, input_scale): @@ -196,12 +192,13 @@ def to(self, *args, **kwargs): def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) + # quant activation + x = (x / self.inscale).clamp(-128, 127).to(torch.int8) self.bias = self.bias.to(torch.float32) - # beta should be 1 ? - y = linear_a8_w8_bfp32_ofp32( + y = i8gemm.linear_a8_w8_bfp32_ofp32( x, self.weight, self.bias, self.a.item(), 1) y = y.view(*x_shape[:-1], -1) - return y + return y, None @staticmethod def from_float(module: torch.nn.Linear, input_scale): diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 3caa8dce79ad..4bb0084dce45 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -89,10 +89,15 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. + if model_config.quant_kv_cache: + for i in range(num_layers): + path = model_config.kv_quant_params_path + f"/layers.{i}.past_kv_scale.{rank}.weight" + kv_quant_params = list(np.fromfile(path, dtype=np.float32)) + kv_quant_params_list.append(kv_quant_params) if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION: - model = model_class(model_config.hf_config, quant_config) + model = model_class(model_config.hf_config, quant_config, model_config.quant_kv_cache, kv_quant_params_list) else: - model = model_class(model_config.hf_config) + model = model_class(model_config.hf_config, None, model_config.quant_kv_cache, kv_quant_params_list) if model_config.load_format == "dummy": model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign @@ -120,22 +125,4 @@ def get_quant_model_kv(model_config: ModelConfig, parallel_config: ParallelConfi torch.set_default_dtype(model_config.dtype) model = model_class(model_config.hf_config, None, model_config.quant_kv_cache, kv_quant_params_list) ## None is for quant config model = model.cuda() - return model.eval() - - -def get_quant_model_v2(model_config: ModelConfig) -> nn.Module: - model_class = _get_model_architecture(model_config.hf_config) - torch.set_default_dtype(model_config.dtype) - - # Create a model instance. - # The weights will be initialized as empty tensors. 
- model = model_class(model_config.hf_config) - - int4_path = "/mnt/dolphinfs/hdd_pool/docker/share/1/zhangpeng/quanted/quant_cache/llama" - fp16_path = "/mnt/dolphinfs/hdd_pool/docker/share/1/zhangpeng/zhangpeng/model_weights/llama/13b" - - model.load_mix_weights2(fp16_path, int4_path, model_config.download_dir, - model_config.use_np_weights) - model = model.cuda() - return model.eval() \ No newline at end of file diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4972c1812104..7b5d5cd04889 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,6 +37,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.layers.int8_linear.w8a8linear import W8A8BFP32OFP32LinearWithSFactor from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( @@ -60,18 +61,32 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.gate_up_proj = ParallelLinear.column(hidden_size, + if quant_config is not None and quant_config.get_name() == "smoothquant": + self.gate_up_proj = W8A8BFP32OFP32LinearWithSFactor(hidden_size, 2 * intermediate_size, bias=False, gather_output=False, perform_initialization=False, quant_config=quant_config) - self.down_proj = ParallelLinear.row(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False, - quant_config=quant_config) + self.down_proj = W8A8BFP32OFP32LinearWithSFactor(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + quant_config=quant_config) + else: + self.gate_up_proj = ParallelLinear.column(hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=quant_config) + self.down_proj = ParallelLinear.row(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -111,23 +126,32 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.qkv_proj = ParallelLinear.column( - hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, - bias=False, - gather_output=False, - perform_initialization=False, - quant_config=quant_config, - ) - self.o_proj = ParallelLinear.row( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False, - quant_config=quant_config, - ) + if quant_config is not None and quant_config.get_name() == "smoothquant": + self.qkv_proj = W8A8BFP32OFP32LinearWithSFactor( + hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim) + self.o_proj = W8A8BFP32OFP32LinearWithSFactor( + self.total_num_heads * self.head_dim, + hidden_size) + else: + self.qkv_proj = ParallelLinear.column( + hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=quant_config, + ) + self.o_proj = ParallelLinear.row( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + quant_config=quant_config, + ) self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, self.scaling, @@ -311,6 +335,14 @@ def load_weights(self, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): + if self.quant_config is not None and self.quant_config.get_name() == "smoothquant": + return self._load_int8_weights( + model_name_or_path, + cache_dir, + load_format, + revision + ) + if self.quant_config is None: weight_suffixes = ["weight"] else: @@ -332,7 +364,7 @@ def load_weights(self, self.config.num_attention_heads * self.config.num_key_value_heads // tp_size) attention_weight_specs = [ - # (weight_name, shard_size, offset), + # (weight_name, shard_size, offset) ("q_proj", q_proj_shard_size, 0), ("k_proj", kv_proj_shard_size, q_proj_shard_size), ("v_proj", kv_proj_shard_size, @@ -340,9 +372,6 @@ def load_weights(self, ] state_dict = self.state_dict() - # for name, param in state_dict.items(): - # print(f"state_dict name: {name}, param shape: {param.shape}") - for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: @@ -415,3 +444,97 @@ def load_weights(self, column_parallel_weights, row_parallel_weights, tensor_model_parallel_rank) + + def _load_int8_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + # TODO: support tp in intlinear + tp_size = 1 + tensor_model_parallel_rank = 0 + q_proj_shard_size = (self.config.hidden_size // tp_size) + kv_proj_shard_size = (self.config.hidden_size // + self.config.num_attention_heads * + self.config.num_key_value_heads // tp_size) + attention_weight_specs = [ + # (weight_name, shard_size, offset) + ("q_proj", q_proj_shard_size, 0), + ("k_proj", kv_proj_shard_size, q_proj_shard_size), + ("v_proj", kv_proj_shard_size, + q_proj_shard_size + kv_proj_shard_size), + ] + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + + is_packed = False + is_transposed = False + if self.quant_config is not None: + is_packed = 
self.quant_config.is_packed(name) + is_transposed = self.quant_config.is_transposed(name) + if is_transposed: + loaded_weight = convert_pyslice_to_tensor(loaded_weight) + loaded_weight = loaded_weight.T + + is_attention_weight = False + for weight_name, shard_size, offset in attention_weight_specs: + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "qkv_proj")] + if is_transposed: + param = param.T + + if is_packed: + shard_size //= self.quant_config.pack_factor + offset //= self.quant_config.pack_factor + + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[offset:offset + shard_size] + assert param_slice.shape == loaded_weight.shape + + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + if is_transposed: + param = param.T + + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + if is_transposed: + param = param.T + + if "embed_tokens" in name or "lm_head" in name: + load_padded_tensor_parallel_vocab(param, loaded_weight, + tensor_model_parallel_rank) + continue + + load_tensor_parallel_weights(param, loaded_weight, name, + column_parallel_weights, + row_parallel_weights, + tensor_model_parallel_rank) diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py index df67758f7110..5fb6547ffeb0 100644 --- a/vllm/model_executor/quantization_utils/__init__.py +++ b/vllm/model_executor/quantization_utils/__init__.py @@ -1,10 +1,12 @@ from typing import Type from vllm.model_executor.quantization_utils.awq import AWQConfig +from vllm.model_executor.quantization_utils.smoothquant import SmoothQuantConfig from vllm.model_executor.quantization_utils.base import QuantizationConfig _QUANTIZATION_REGISTRY = { "awq": AWQConfig, + "smoothquant": SmoothQuantConfig, } diff --git a/vllm/model_executor/quantization_utils/smoothquant.py b/vllm/model_executor/quantization_utils/smoothquant.py new file mode 100644 index 000000000000..d516286c2a33 --- /dev/null +++ b/vllm/model_executor/quantization_utils/smoothquant.py @@ -0,0 +1,70 @@ +from typing import Any, Dict, List + +import torch + +from vllm.model_executor.quantization_utils.base import QuantizationConfig + + +class SmoothQuantConfig(QuantizationConfig): + """Config class for SmoothQuant + + Reference: https://github.com/mit-han-lab/smoothquant + """ + + def __init__( + self, + weight_bits: int = 8, + quant_type: str = "tensor" + ) -> None: + self.weight_bits = weight_bits + self.quant_type = quant_type + + if self.weight_bits != 8: + raise ValueError( + "Currently, only w8a8 quantization is supported for " + f"SmoothQuant, but got {self.weight_bits} bits.") + if self.quant_type != "tensor": + raise ValueError( + "Currently, only tensor wise quantization is supported for " + f"SmoothQuant, but got {self.quant_type} type quantization.") + + 
def __repr__(self) -> str: + return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, " + f"quant_type={self.quant_type})") + + @classmethod + def get_name(cls) -> str: + return "smoothquant" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.float] + + @classmethod + def get_min_capability(cls) -> int: + # The smoothquant kernel only supports Ampere or newer GPUs. + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + quant_type = cls.get_from_keys(config, ["quant_type", "q_type"]) + return cls(weight_bits, quant_type) + + @classmethod + def get_packed_tensor_names(cls) -> List[str]: + return [] + + @classmethod + def get_transposed_tensor_names(cls) -> List[str]: + return ["weight", "bias"] + + @classmethod + def get_tp_tensor_names(cls) -> List[str]: + return ["weight", "bias"] + diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index cb5579f93089..e9237e5cc71b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.model_executor import get_model, get_quant_model_v2, InputMetadata, set_random_seed, get_quant_model_kv +from vllm.model_executor import get_model, InputMetadata, set_random_seed, get_quant_model_kv from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sampling_params import SamplingParams From e6f45ffc2ed58e679c7e734c80472a9fb91f179c Mon Sep 17 00:00:00 2001 From: sleepcoo Date: Thu, 21 Sep 2023 16:54:39 +0800 Subject: [PATCH 13/52] Reduce alpha,beta unnecessary d2h --- .../layers/int8_linear/w8a8linear.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py index 0d77e70b519d..2825d62548c8 100644 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -22,6 +22,12 @@ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): self.register_buffer('a', torch.tensor(alpha)) self.register_buffer('b', torch.tensor(beta)) + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + self.b = self.b.cpu() + return self + def to(self, *args, **kwargs): super().to(*args, **kwargs) self.weight = self.weight.to(*args, **kwargs) @@ -72,6 +78,12 @@ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0, inscale=1.0, self.register_buffer('inscale', torch.tensor(inscale)) self.register_buffer('ouscale', torch.tensor(ouscale)) + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + self.b = self.b.cpu() + return self + def to(self, *args, **kwargs): super().to(*args, **kwargs) self.weight = self.weight.to(*args, **kwargs) @@ -121,6 +133,7 @@ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): def _apply(self, fn): # prevent the bias from being converted to half super()._apply(fn) + self.a = self.a.cpu() self.bias = self.bias.to(torch.float32) return self @@ -174,6 +187,7 @@ def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): def _apply(self, fn): # prevent the bias from being converted to half super()._apply(fn) + self.a = self.a.cpu() self.bias 
= self.bias.to(torch.float32) return self From 96c10ca35bb99f55595674f05f8306c9b892bb1b Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Thu, 21 Sep 2023 14:14:04 +0800 Subject: [PATCH 14/52] fix weight load --- vllm/model_executor/model_loader.py | 6 +++++- vllm/model_executor/models/llama.py | 14 ++++---------- .../quantization_utils/smoothquant.py | 5 ++++- vllm/worker/worker.py | 4 ++-- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 4bb0084dce45..c9557e7e9547 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -59,7 +59,9 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") -def get_model(model_config: ModelConfig) -> nn.Module: +def get_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + rank: int) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the quantization config. @@ -89,6 +91,8 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. + num_layers = model_config.get_num_layers(parallel_config) + kv_quant_params_list = [] if model_config.quant_kv_cache: for i in range(num_layers): path = model_config.kv_quant_params_path + f"/layers.{i}.past_kv_scale.{rank}.weight" diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 7b5d5cd04889..d5257420ca59 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -63,17 +63,9 @@ def __init__( super().__init__() if quant_config is not None and quant_config.get_name() == "smoothquant": self.gate_up_proj = W8A8BFP32OFP32LinearWithSFactor(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - perform_initialization=False, - quant_config=quant_config) + 2 * intermediate_size) self.down_proj = W8A8BFP32OFP32LinearWithSFactor(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False, - quant_config=quant_config) + hidden_size) else: self.gate_up_proj = ParallelLinear.column(hidden_size, 2 * intermediate_size, @@ -496,6 +488,7 @@ def _load_int8_weights(self, shard_size * tensor_model_parallel_rank:shard_size * (tensor_model_parallel_rank + 1)] param_slice = param.data[offset:offset + shard_size] + print(f"{name} param shape: {param.shape} param_slice shape:{param_slice.shape} weight shape:{loaded_weight.shape}") assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) @@ -518,6 +511,7 @@ def _load_int8_weights(self, (tensor_model_parallel_rank + 1)] param_slice = param.data[shard_size * stride_id:shard_size * (stride_id + 1)] + print(f"{name} param shape: {param.shape} param_slice shape:{param_slice.shape} weight shape:{loaded_weight.shape}") assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True diff --git a/vllm/model_executor/quantization_utils/smoothquant.py b/vllm/model_executor/quantization_utils/smoothquant.py index d516286c2a33..09ddb40242e5 100644 --- a/vllm/model_executor/quantization_utils/smoothquant.py +++ b/vllm/model_executor/quantization_utils/smoothquant.py @@ -48,7 +48,10 @@ def get_min_capability(cls) -> int: @classmethod def get_config_filenames(cls) -> List[str]: """List of filenames to search for in the 
model directory.""" - raise NotImplementedError + return [ + "quant_config.json", + "quantize_config.json", + ] @classmethod def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e9237e5cc71b..5bed88727b0a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -64,8 +64,8 @@ def init_model(self): # Initialize the model. set_random_seed(self.model_config.seed) - # self.model = get_model(self.model_config) - self.model = get_quant_model_kv(self.model_config, self.parallel_config, self.rank) + self.model = get_model(self.model_config, self.parallel_config, self.rank) + # self.model = get_quant_model_kv(self.model_config, self.parallel_config, self.rank) @torch.inference_mode() def profile_num_available_blocks( From 4be7d834b3533a3af813cd121c7603780ed60e88 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Fri, 22 Sep 2023 15:01:58 +0800 Subject: [PATCH 15/52] fix weight load --- vllm/model_executor/layers/layernorm.py | 28 ++++++++ vllm/model_executor/models/llama.py | 64 +++++++++++++++---- .../quantization_utils/smoothquant.py | 2 +- vllm/worker/worker.py | 1 - 4 files changed, 81 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 731bc7cbf53f..fced7b42cb07 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -30,3 +30,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.variance_epsilon, ) return out + +class I8RMSNorm(nn.Module): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. + Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + layernorm_ops.rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + # TODO: kernel fusion + q_out = out.round().clamp(-128, 127).to(torch.int8) + return q_out diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d5257420ca59..54d95f19e41b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -33,11 +33,11 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm, I8RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantized_linear import ParallelLinear -from vllm.model_executor.layers.int8_linear.w8a8linear import W8A8BFP32OFP32LinearWithSFactor +from vllm.model_executor.layers.int8_linear.w8a8linear import W8A8BFP32OFP32LinearWithSFactor, W8A8BFP32OFP32Linear from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( @@ -62,7 +62,7 @@ def __init__( ) -> None: super().__init__() if quant_config is not None and quant_config.get_name() == "smoothquant": - self.gate_up_proj = W8A8BFP32OFP32LinearWithSFactor(hidden_size, + self.gate_up_proj = W8A8BFP32OFP32Linear(hidden_size, 2 * 
intermediate_size) self.down_proj = W8A8BFP32OFP32LinearWithSFactor(intermediate_size, hidden_size) @@ -86,8 +86,10 @@ def __init__( def forward(self, x): gate_up, _ = self.gate_up_proj(x) + gate_up = gate_up.half() x = self.act_fn(gate_up) x, _ = self.down_proj(x) + x = x.half() return x @@ -119,10 +121,9 @@ def __init__( self.rope_theta = rope_theta if quant_config is not None and quant_config.get_name() == "smoothquant": - self.qkv_proj = W8A8BFP32OFP32LinearWithSFactor( + self.qkv_proj = W8A8BFP32OFP32Linear( hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim) + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim) self.o_proj = W8A8BFP32OFP32LinearWithSFactor( self.total_num_heads * self.head_dim, hidden_size) @@ -162,11 +163,13 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) + qkv = qkv.half() q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn(positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) output, _ = self.o_proj(attn_output) + output = output.half() return output @@ -198,9 +201,9 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, ) - self.input_layernorm = RMSNorm(config.hidden_size, + self.input_layernorm = I8RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, + self.post_attention_layernorm = I8RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -449,6 +452,22 @@ def _load_int8_weights(self, kv_proj_shard_size = (self.config.hidden_size // self.config.num_attention_heads * self.config.num_key_value_heads // tp_size) + + if self.quant_config is None: + weight_suffixes = ["weight"] + else: + weight_suffixes = self.quant_config.get_tp_tensor_names() + + column_parallel_weights: List[str] = [] + for layer in self._column_parallel_layers: + for suffix in weight_suffixes: + column_parallel_weights.append(f"{layer}.{suffix}") + + row_parallel_weights: List[str] = [] + for layer in self._row_parallel_layers: + for suffix in weight_suffixes: + row_parallel_weights.append(f"{layer}.{suffix}") + attention_weight_specs = [ # (weight_name, shard_size, offset) ("q_proj", q_proj_shard_size, 0), @@ -462,6 +481,10 @@ def _load_int8_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue + print(f"{name} origin weight shape: {loaded_weight.shape}") + # bias is useless for llama + if "bias" in name: + continue is_packed = False is_transposed = False @@ -477,12 +500,17 @@ def _load_int8_weights(self, if weight_name not in name: continue param = state_dict[name.replace(weight_name, "qkv_proj")] - if is_transposed: - param = param.T + # if is_transposed: + # param = param.T if is_packed: shard_size //= self.quant_config.pack_factor offset //= self.quant_config.pack_factor + + if "proj.a" in name: + param.copy_(loaded_weight) + is_attention_weight = True + continue loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * @@ -502,13 +530,20 @@ def _load_int8_weights(self, if weight_name not in name: continue param = state_dict[name.replace(weight_name, "gate_up_proj")] - if is_transposed: - param = param.T + # if is_transposed: + # loaded_weight = loaded_weight.T + + if "proj.a" in name: + param.copy_(loaded_weight) + is_gate_up_weight = True + continue shard_size = param.shape[0] // 2 loaded_weight = loaded_weight[ 
shard_size * tensor_model_parallel_rank:shard_size * (tensor_model_parallel_rank + 1)] + if is_transposed: + loaded_weight = loaded_weight.T param_slice = param.data[shard_size * stride_id:shard_size * (stride_id + 1)] print(f"{name} param shape: {param.shape} param_slice shape:{param_slice.shape} weight shape:{loaded_weight.shape}") @@ -523,6 +558,11 @@ def _load_int8_weights(self, if is_transposed: param = param.T + #copy down and out pro + if "proj.a" in name or "bias" in name or "inscale" in name: + param.copy_(loaded_weight) + continue + if "embed_tokens" in name or "lm_head" in name: load_padded_tensor_parallel_vocab(param, loaded_weight, tensor_model_parallel_rank) diff --git a/vllm/model_executor/quantization_utils/smoothquant.py b/vllm/model_executor/quantization_utils/smoothquant.py index 09ddb40242e5..2b68974553e2 100644 --- a/vllm/model_executor/quantization_utils/smoothquant.py +++ b/vllm/model_executor/quantization_utils/smoothquant.py @@ -69,5 +69,5 @@ def get_transposed_tensor_names(cls) -> List[str]: @classmethod def get_tp_tensor_names(cls) -> List[str]: - return ["weight", "bias"] + return ["weight"] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 5bed88727b0a..96123c7ead2e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -65,7 +65,6 @@ def init_model(self): # Initialize the model. set_random_seed(self.model_config.seed) self.model = get_model(self.model_config, self.parallel_config, self.rank) - # self.model = get_quant_model_kv(self.model_config, self.parallel_config, self.rank) @torch.inference_mode() def profile_num_available_blocks( From be6f7b8221fb984cce2762c0c1a13366726ea53b Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Fri, 22 Sep 2023 17:45:56 +0800 Subject: [PATCH 16/52] fix ln layer init --- vllm/model_executor/__init__.py | 5 ++--- vllm/model_executor/model_loader.py | 17 ----------------- vllm/model_executor/models/llama.py | 17 +++++++++++------ vllm/worker/worker.py | 2 +- 4 files changed, 14 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 81f0eeef2397..e1c687aa5aef 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,10 +1,9 @@ from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.model_loader import get_model, get_quant_model_kv +from vllm.model_executor.model_loader import get_model from vllm.model_executor.utils import set_random_seed __all__ = [ "InputMetadata", "get_model", - "set_random_seed", - "get_quant_model_kv" + "set_random_seed" ] diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index c9557e7e9547..51b64c7bedaa 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -113,20 +113,3 @@ def get_model(model_config: ModelConfig, model_config.load_format, model_config.revision) model = model.cuda() return model.eval() - - -def get_quant_model_kv(model_config: ModelConfig, parallel_config: ParallelConfig, - rank: int): - num_layers = model_config.get_num_layers(parallel_config) - ## num_layers * [k_scale, k_zp, v_scale, v_zp] - kv_quant_params_list = [] - if model_config.quant_kv_cache: - for i in range(num_layers): - path = model_config.kv_quant_params_path + f"/layers.{i}.past_kv_scale.{rank}.weight" - kv_quant_params = list(np.fromfile(path, dtype=np.float32)) - kv_quant_params_list.append(kv_quant_params) - model_class = _get_model_architecture(model_config.hf_config) - 
torch.set_default_dtype(model_config.dtype) - model = model_class(model_config.hf_config, None, model_config.quant_kv_cache, kv_quant_params_list) ## None is for quant config - model = model.cuda() - return model.eval() \ No newline at end of file diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 54d95f19e41b..55821592aa65 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -201,10 +201,16 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, ) - self.input_layernorm = I8RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = I8RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + if quant_config is not None and quant_config.get_name() == "smoothquant": + self.input_layernorm = I8RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = I8RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + else: + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -251,8 +257,7 @@ def __init__( vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( vocab_size, config.hidden_size, perform_initialization=False) - # print(kv_quant_params_list) - # print(quant_kv_cache) + self.layers = nn.ModuleList([ LlamaDecoderLayer(config, quant_config, quant_kv_cache, kv_quant_params_list[i] if quant_kv_cache else None) for i in range(config.num_hidden_layers) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 96123c7ead2e..321f352ab0ad 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.model_executor import get_model, InputMetadata, set_random_seed, get_quant_model_kv +from vllm.model_executor import get_model, InputMetadata, set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sampling_params import SamplingParams From 5ffc537da9a5db3a959ba395f4317be1c09144bb Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Tue, 26 Sep 2023 19:00:24 +0800 Subject: [PATCH 17/52] rms norm fusion --- csrc/layernorm.cpp | 21 ++-- csrc/layernorm_kernels.cu | 128 ++++++++++++++++++------ csrc/quant_utils.cuh | 32 ++++++ vllm/model_executor/layers/layernorm.py | 23 +++-- 4 files changed, 156 insertions(+), 48 deletions(-) diff --git a/csrc/layernorm.cpp b/csrc/layernorm.cpp index 749ca5f92154..c1092b9583e7 100644 --- a/csrc/layernorm.cpp +++ b/csrc/layernorm.cpp @@ -1,14 +1,17 @@ #include -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); +void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, + float epsilon); + +void invoke_rms_norm_quant(torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &gamma, // [hidden_size] + float epsilon); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + m.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + m.def("invoke_rms_norm_quant", &invoke_rms_norm_quant, + "Apply Root Mean Square (RMS) Normalization to the input tensor and " + "quant output."); } diff --git 
a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f932b9e2d615..a20b9d4ddbe0 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,25 +1,25 @@ -#include #include +#include #include "dispatch_utils.h" +#include "quant_utils.cuh" #include "reduction_utils.cuh" namespace vllm { // TODO(woosuk): Further optimize this kernel. -template -__global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [num_tokens, hidden_size] - const scalar_t* __restrict__ input, // [num_tokens, hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ void +rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size] + const scalar_t *__restrict__ input, // [num_tokens, hidden_size] + const scalar_t *__restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, + const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float) input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -29,34 +29,104 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; + } +} + +template +__global__ void RMSLayerNorm(const T *__restrict input, + const T *__restrict gamma, int8_t *output, + const float layernorm_eps, int m, int n) { + // layernorm module in the T5 style No bias and no subtraction of mean. + const int tid = threadIdx.x; + + __shared__ float s_variance; + float variance = 0.0f; + + float local_var_sum = 0.0f; + for (int i = tid; i < n; i += blockDim.x) { + // float diff = (float)(ldg(&input[blockIdx.x * n + i])); + float diff = (float)(input[blockIdx.x * n + i]); + local_var_sum += diff * diff; + } + variance = blockReduceSum(local_var_sum); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / (float)n + layernorm_eps); + } + __syncthreads(); + + for (int i = tid; i < n; i += blockDim.x) { + output[blockIdx.x * n + i] = + // float_to_int8_rn((((float)input[blockIdx.x * n + i]) * s_variance) * + // (float)(ldg(&gamma[i]))); + float_to_int8_rn((((float)input[blockIdx.x * n + i]) * s_variance) * + (float)(gamma[i])); } } +template +void invokeRMSLayerNorm(int8_t *out, const T *input, const T *gamma, + // const T* beta, + const float layernorm_eps, const int m, const int n, + cudaStream_t stream) { + // if (beta != nullptr) { + // invokeGeneralLayerNorm(out, input, gamma, beta, layernorm_eps, m, n, + // (float*)nullptr, 0, stream); return; + // } + + dim3 grid(m); + dim3 block(min(n, 1024)); + + /* For general cases, n is equal to hidden_units, e.g., 512/1024. + Since we have warp shuffle inside the code, block.x % 32 should be 0. 
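+ (Each thread strides over the whole row in the kernel above, so any block size
+ that is a multiple of the warp size is functionally correct; the sizing below
+ only affects occupancy.)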
+ */ + if (n % 32 != 0) { + block.x = 1024; + } + + block.x = + block.x / (4 / sizeof(T)); // if using half, only need half of block.x + + /* should pay attention to the rsqrt precision*/ + RMSLayerNorm<<>>(input, gamma, out, layernorm_eps, + m, n); // For gpt-3 +} + } // namespace vllm -void rms_norm( - torch::Tensor& out, // [num_tokens, hidden_size] - torch::Tensor& input, // [num_tokens, hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +void rms_norm(torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &weight, // [hidden_size] + float epsilon) { int num_tokens = input.size(0); int hidden_size = input.size(1); dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "rms_norm_kernel", - [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); +} + +void invoke_rms_norm_quant(torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &gamma, // [hidden_size] + float epsilon) { + int m = input.size(0); + int n = input.size(1); + dim3 grid(m); + dim3 block(min(n, 1024)); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "invokeRMSLayerNorm", [&] { + vllm::RMSLayerNorm<<>>( + input.data_ptr(), gamma.data_ptr(), out.data_ptr(), + epsilon, m, n); + }); } diff --git a/csrc/quant_utils.cuh b/csrc/quant_utils.cuh index f2639ba4cf9c..6b84931ac228 100644 --- a/csrc/quant_utils.cuh +++ b/csrc/quant_utils.cuh @@ -233,3 +233,35 @@ __inline__ __device__ bf16_8_t vec_conversion(const Float8_ & b.w = vec_conversion<__nv_bfloat162, float2>(a.w); return b; } + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} + +template +inline __device__ T ldg(const T* val) { + return __ldg(val); +} + +#if ENABLE_BF16 +template<> +inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} + +template<> +inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} +#endif // ENABLE_BF16 diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index fced7b42cb07..99f8927a2fe3 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -48,13 +48,16 @@ def __init__( self.variance_epsilon = eps def forward(self, x: torch.Tensor) -> torch.Tensor: - out = torch.empty_like(x) - layernorm_ops.rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - # TODO: kernel fusion - q_out = out.round().clamp(-128, 127).to(torch.int8) - return q_out + # out = torch.empty_like(x) + # layernorm_ops.rms_norm( + # out, + # x, + # self.weight.data, + # self.variance_epsilon, + # ) + # # TODO: kernel fusion + # q_out = out.round().clamp(-128, 127).to(torch.int8) + # return q_out + out 
= torch.empty_like(x, dtype=torch.int8) + layernorm_ops.invoke_rms_norm_quant(out, x, self.weight.data, self.variance_epsilon) + return out From 347397cddd2ab89f1850b3babd391a9f8ce7ccbc Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Tue, 26 Sep 2023 19:01:14 +0800 Subject: [PATCH 18/52] fix w8a8 linear --- vllm/model_executor/layers/int8_linear/w8a8linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py index 2825d62548c8..9a452d591e31 100644 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -38,7 +38,6 @@ def to(self, *args, **kwargs): def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) - x = (x / self.inscale).clamp(-128, 127).to(torch.int8) y = i8gemm.linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) y = y.view(*x_shape[:-1], -1) @@ -94,6 +93,7 @@ def to(self, *args, **kwargs): def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) + x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) y = i8gemm.linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) y = y.view(*x_shape[:-1], -1) @@ -207,7 +207,7 @@ def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) # quant activation - x = (x / self.inscale).clamp(-128, 127).to(torch.int8) + x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) self.bias = self.bias.to(torch.float32) y = i8gemm.linear_a8_w8_bfp32_ofp32( x, self.weight, self.bias, self.a.item(), 1) From 030a100fb8d187a3a6e9def99df3771998c0c435 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Tue, 26 Sep 2023 19:01:48 +0800 Subject: [PATCH 19/52] use same scale across tensors --- vllm/model_executor/models/llama.py | 37 +++++++++---------- .../quantization_utils/smoothquant.py | 2 +- vllm/model_executor/weight_utils.py | 12 ------ 3 files changed, 18 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 55821592aa65..5e6a3b3daa66 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -61,7 +61,9 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - if quant_config is not None and quant_config.get_name() == "smoothquant": + self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" + + if self.use_int8: self.gate_up_proj = W8A8BFP32OFP32Linear(hidden_size, 2 * intermediate_size) self.down_proj = W8A8BFP32OFP32LinearWithSFactor(intermediate_size, @@ -79,14 +81,15 @@ def __init__( input_is_parallel=True, perform_initialization=False, quant_config=quant_config) + if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" - "Only silu is supported for now.") + "Only silu is supported for now.") self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) - gate_up = gate_up.half() + # FIXME: currently gate up share same scale, plan to use seperate scales x = self.act_fn(gate_up) x, _ = self.down_proj(x) x = x.half() @@ -119,8 +122,9 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta + self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" - if quant_config is not None and quant_config.get_name() == "smoothquant": + if self.use_int8: self.qkv_proj = W8A8BFP32OFP32Linear( hidden_size, (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim) @@ -164,6 +168,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) qkv = qkv.half() + # FIXME: currently qkv share same scale, plan to use seperate scales q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn(positions, q, k, v, k_cache, v_cache, @@ -486,7 +491,6 @@ def _load_int8_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - print(f"{name} origin weight shape: {loaded_weight.shape}") # bias is useless for llama if "bias" in name: continue @@ -505,14 +509,15 @@ def _load_int8_weights(self, if weight_name not in name: continue param = state_dict[name.replace(weight_name, "qkv_proj")] - # if is_transposed: - # param = param.T + if is_transposed: + param = param.T if is_packed: shard_size //= self.quant_config.pack_factor offset //= self.quant_config.pack_factor - if "proj.a" in name: + # share use same scale in quantizatin + if "proj.a" in name or "proj.inscale" in name: param.copy_(loaded_weight) is_attention_weight = True continue @@ -521,7 +526,6 @@ def _load_int8_weights(self, shard_size * tensor_model_parallel_rank:shard_size * (tensor_model_parallel_rank + 1)] param_slice = param.data[offset:offset + shard_size] - print(f"{name} param shape: {param.shape} param_slice shape:{param_slice.shape} weight shape:{loaded_weight.shape}") assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) @@ -535,10 +539,11 @@ def _load_int8_weights(self, if weight_name not in name: continue param = state_dict[name.replace(weight_name, "gate_up_proj")] - # if is_transposed: - # loaded_weight = loaded_weight.T + if is_transposed: + loaded_weight = loaded_weight.T - if "proj.a" in name: + # share use same scale in quantizatin + if "proj.a" in name or "proj.inscale" in name: param.copy_(loaded_weight) is_gate_up_weight = True continue @@ -547,11 +552,8 @@ def _load_int8_weights(self, loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * (tensor_model_parallel_rank + 1)] - if is_transposed: - loaded_weight = loaded_weight.T param_slice = param.data[shard_size * stride_id:shard_size * (stride_id + 1)] - print(f"{name} param shape: {param.shape} param_slice shape:{param_slice.shape} weight shape:{loaded_weight.shape}") assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True @@ -563,11 +565,6 @@ def _load_int8_weights(self, if is_transposed: param = param.T - #copy down and out pro - if "proj.a" in name or "bias" in name or "inscale" in name: - param.copy_(loaded_weight) - continue - if "embed_tokens" in name or "lm_head" in name: load_padded_tensor_parallel_vocab(param, loaded_weight, 
tensor_model_parallel_rank) diff --git a/vllm/model_executor/quantization_utils/smoothquant.py b/vllm/model_executor/quantization_utils/smoothquant.py index 2b68974553e2..1b4f64a94573 100644 --- a/vllm/model_executor/quantization_utils/smoothquant.py +++ b/vllm/model_executor/quantization_utils/smoothquant.py @@ -65,7 +65,7 @@ def get_packed_tensor_names(cls) -> List[str]: @classmethod def get_transposed_tensor_names(cls) -> List[str]: - return ["weight", "bias"] + return [] @classmethod def get_tp_tensor_names(cls) -> List[str]: diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 4c76fbd5268e..74de96842296 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -294,18 +294,6 @@ def load_tensor_parallel_weights( f"{param.shape} != {loaded_weight.shape}") param.data.copy_(loaded_weight) -def load_tensor_parallel_weights2( - param: torch.Tensor, - loaded_weight: torch.Tensor, - param_name: str, - tensor_model_parallel_rank: int, -) -> None: - assert param.shape == loaded_weight.shape, ( - f"{param_name} shape mismatch between model and checkpoint: " - f"{param.shape} != {loaded_weight.shape}") - param.data.copy_(loaded_weight) - - def initialize_dummy_weights( model: torch.nn.Module, From 2805edc3be743b7e8af389a25311f77a283b1ae9 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Wed, 27 Sep 2023 11:55:04 +0800 Subject: [PATCH 20/52] add ftgemm --- csrc/ftgemm/CMakeLists.txt | 65 ++ csrc/ftgemm/allocator.h | 426 +++++++++ csrc/ftgemm/bindings.cpp | 212 +++++ csrc/ftgemm/cublasAlgoMap.cc | 188 ++++ csrc/ftgemm/cublasAlgoMap.h | 108 +++ csrc/ftgemm/cublasINT8MMWrapper.cc | 840 +++++++++++++++++ csrc/ftgemm/cublasINT8MMWrapper.h | 68 ++ csrc/ftgemm/cublasMMWrapper.cc | 851 ++++++++++++++++++ csrc/ftgemm/cublasMMWrapper.h | 177 ++++ csrc/ftgemm/cuda_utils.cc | 381 ++++++++ csrc/ftgemm/cuda_utils.h | 461 ++++++++++ csrc/ftgemm/int8_utils.cuh | 51 ++ csrc/ftgemm/transform_layout.cu | 127 +++ csrc/ftgemm/transform_layout.h | 31 + setup.py | 18 + .../layers/int8_linear/w8a8linear.py | 122 +++ 16 files changed, 4126 insertions(+) create mode 100644 csrc/ftgemm/CMakeLists.txt create mode 100644 csrc/ftgemm/allocator.h create mode 100644 csrc/ftgemm/bindings.cpp create mode 100644 csrc/ftgemm/cublasAlgoMap.cc create mode 100644 csrc/ftgemm/cublasAlgoMap.h create mode 100644 csrc/ftgemm/cublasINT8MMWrapper.cc create mode 100644 csrc/ftgemm/cublasINT8MMWrapper.h create mode 100644 csrc/ftgemm/cublasMMWrapper.cc create mode 100644 csrc/ftgemm/cublasMMWrapper.h create mode 100644 csrc/ftgemm/cuda_utils.cc create mode 100644 csrc/ftgemm/cuda_utils.h create mode 100644 csrc/ftgemm/int8_utils.cuh create mode 100644 csrc/ftgemm/transform_layout.cu create mode 100644 csrc/ftgemm/transform_layout.h diff --git a/csrc/ftgemm/CMakeLists.txt b/csrc/ftgemm/CMakeLists.txt new file mode 100644 index 000000000000..0425e4acb0e3 --- /dev/null +++ b/csrc/ftgemm/CMakeLists.txt @@ -0,0 +1,65 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CUDA_STANDARD 14) + +find_package(CUDA REQUIRED) +find_package(Python REQUIRED) +set(Torch_DIR "/usr/local/lib/python3.9/site-packages/torch/share/cmake/Torch/") +find_package(Torch REQUIRED) +set(pybind11_DIR "/usr/local/lib/python3.9/site-packages/pybind11/share/cmake/pybind11") +find_package(pybind11 REQUIRED) + + +include_directories(${CUDA_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS}) +set(CUDA_LIBRARIES "/usr/local/cuda/lib64") +set(TORCH_LIBRARIES "/usr/local/lib/python3.9/site-packages/torch/lib") +link_directories(${CUDA_LIBRARIES} ${TORCH_LIBRARIES}) + +add_library(cuda_utils STATIC cuda_utils.cc) +set_property(TARGET cuda_utils PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET cuda_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(cuda_utils PUBLIC -lcudart) + +add_library(cublasAlgoMap STATIC cublasAlgoMap.cc) +set_property(TARGET cublasAlgoMap PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET cublasAlgoMap PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(cublasAlgoMap PUBLIC -lcublas -lcudart -lcurand cuda_utils) + +add_library(cublasMMWrapper STATIC cublasMMWrapper.cc) +set_property(TARGET cublasMMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET cublasMMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(cublasMMWrapper PUBLIC -lcublas -lcudart -lcurand cublasAlgoMap cuda_utils) + +add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) +set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(cublasINT8MMWrapper PUBLIC -lcublasLt -lcudart -lcurand -lcublas cublasAlgoMap cublasMMWrapper cuda_utils) + +add_library(transformLayout STATIC transform_layout.cu) +set_property(TARGET transformLayout PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET transformLayout PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(transformLayout PUBLIC -lcudart -lcurand -lcublas) + + +find_package(Python COMPONENTS Interpreter Development REQUIRED) +pybind11_add_module(int8_gemm MODULE bindings.cpp) +target_link_libraries(int8_gemm PUBLIC -lpython3.9 -ltorch -ltorch_python -lcudart cublasINT8MMWrapper cublasAlgoMap transformLayout) + + + + + diff --git a/csrc/ftgemm/allocator.h b/csrc/ftgemm/allocator.h new file mode 100644 index 000000000000..82e7da567b9b --- /dev/null +++ b/csrc/ftgemm/allocator.h @@ -0,0 +1,426 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +/** + * Memory Allocator + **/ + +#pragma once + +#include "cuda_utils.h" +#include +#include +#include + +#ifdef GOOGLE_CUDA +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#endif + +#ifdef TORCH_CUDA +#include "torch/extension.h" +#include +#endif + +// #include "src/fastertransformer/utils/logger.h" + +#if defined(CUDART_VERSION) && CUDART_VERSION < 11020 +#define CUDA_MEMORY_POOL_DISABLED +#endif + +enum class AllocatorType { CUDA, TF, TH }; + +enum class ReallocType { + INCREASE, + REUSE, + DECREASE, +}; + +class IAllocator { +public: + virtual ~IAllocator(){}; + + virtual void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) = 0; + virtual void free(void **ptr, bool is_host = false) const = 0; + virtual void setStream(cudaStream_t stream) = 0; + virtual cudaStream_t returnStream() = 0; + virtual void memSet(void *ptr, const int val, const size_t size) = 0; + + template + void *reMalloc(T *ptr, size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + size = ((size + 31) / 32) * 32; // make the buffer align with 32 bytes + void *void_ptr = (void *)ptr; + void *ptr_address = getAddress(void_ptr); + if (isExist(ptr_address)) { + ReallocType realloc_type = isReMalloc(ptr_address, size); + if (realloc_type == ReallocType::INCREASE) { + // FT_LOG_DEBUG("ReMalloc the buffer %p since it is too small.", + // void_ptr); + free((void **)(&void_ptr), is_host); + return malloc(size, is_set_zero, is_host); + } +#if !defined(CUDA_MEMORY_POOL_DISABLED) + else if (realloc_type == ReallocType::DECREASE) { + // FT_LOG_DEBUG("ReMalloc the buffer %p to release unused memory to + // memory pools.", void_ptr); + free((void **)(&void_ptr), is_host); + return malloc(size, is_set_zero, is_host); + } +#endif + else { + // FT_LOG_DEBUG("Reuse original buffer %p with size %d and do nothing + // for reMalloc.", void_ptr, size); + if (is_set_zero) { + memSet(void_ptr, 0, size); + } + return void_ptr; + } + } else { + // FT_LOG_DEBUG("Cannot find buffer %p, mallocing new one.", void_ptr); + return malloc(size, is_set_zero, is_host); + } + } + +protected: + virtual bool isExist(void *address) const = 0; + virtual ReallocType isReMalloc(void *address, size_t size) const = 0; + + void *getAddress(void *ptr) const { return ptr; } +}; + +template class Allocator; + +template <> class Allocator : public IAllocator { +private: + const int device_id_; + cudaStream_t stream_ = 0; // initialize as default stream + std::unordered_map *pointer_mapping_; + + bool isExist(void *address) const { + return pointer_mapping_->count(address) > 0; + } + ReallocType isReMalloc(void *address, size_t size) const { + FT_CHECK(isExist(address)); + if (pointer_mapping_->at(address) < size) { + return ReallocType::INCREASE; + } else if (pointer_mapping_->at(address) == size) { + return ReallocType::REUSE; + } else { + return ReallocType::DECREASE; + } + } + +public: + Allocator(int device_id) : device_id_(device_id) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + pointer_mapping_ = new std::unordered_map(); +#if defined(CUDA_MEMORY_POOL_DISABLED) + // 
FT_LOG_WARNING( + // "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync + // cudaMalloc/Free." "Note this may lead to hang with NCCL kernels + // launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP"); +#else + int device_count = 1; + check_cuda_error(cudaGetDeviceCount(&device_count)); + cudaMemPool_t mempool; + check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); + cudaMemAccessDesc desc = {}; + int peer_access_available = 0; + for (int i = 0; i < device_count; i++) { + if (i == device_id) { + continue; + } + check_cuda_error( + cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); + if (!peer_access_available) { + // FT_LOG_WARNING("Device " + std::to_string(device_id) + " peer access + // Device " + std::to_string(i) + // + " is not available."); + continue; + } + desc.location.type = cudaMemLocationTypeDevice; + desc.location.id = i; + desc.flags = cudaMemAccessFlagsProtReadWrite; + check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); + } + // set memory pool threshold to avoid shrinking the pool + uint64_t setVal = UINT64_MAX; + check_cuda_error(cudaMemPoolSetAttribute( + mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); +#endif + } + + virtual ~Allocator() { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + while (!pointer_mapping_->empty()) { + free((void **)(&pointer_mapping_->begin()->first)); + } + delete pointer_mapping_; + } + + void setStream(cudaStream_t stream) { stream_ = stream; } + + cudaStream_t returnStream() { return stream_; }; + + void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (size == 0) { + return nullptr; + } + void *ptr = nullptr; + int o_device = 0; + + check_cuda_error(getSetDevice(device_id_, &o_device)); + if (is_host) { + check_cuda_error(cudaMallocHost(&ptr, (size_t)(ceil(size / 32.)) * 32)); + } else { +#if defined(CUDA_MEMORY_POOL_DISABLED) + check_cuda_error(cudaMalloc(&ptr, (size_t)(ceil(size / 32.)) * 32)); +#else + check_cuda_error( + cudaMallocAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, stream_)); +#endif + } + if (is_set_zero) { + check_cuda_error( + cudaMemsetAsync(ptr, 0, (size_t)(ceil(size / 32.)) * 32, stream_)); + } + check_cuda_error(getSetDevice(o_device)); + // FT_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size); + + pointer_mapping_->insert({getAddress(ptr), size}); + + return ptr; + } + + void free(void **ptr, bool is_host = false) const { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + void *address = getAddress(*ptr); + if (*ptr != nullptr) { + int o_device = 0; + if (pointer_mapping_->count(address)) { + // FT_LOG_DEBUG("Free buffer %p", address); + check_cuda_error(getSetDevice(device_id_, &o_device)); + if (is_host) { + check_cuda_error(cudaFreeHost(*ptr)); + } else { +#if defined(CUDA_MEMORY_POOL_DISABLED) + check_cuda_error(cudaFree(*ptr)); +#else + check_cuda_error(cudaFreeAsync(*ptr, stream_)); + cudaStreamSynchronize(stream_); +#endif + } + check_cuda_error(getSetDevice(o_device)); + pointer_mapping_->erase(address); + } else { + // FT_LOG_WARNING("pointer_mapping_ does not have information of ptr at + // %p.", address); + } + } + *ptr = nullptr; + return; + } + + void memSet(void *ptr, const int val, const size_t size) { + check_cuda_error(cudaMemsetAsync(ptr, val, size, stream_)); + } +}; + +#ifdef GOOGLE_CUDA +using namespace tensorflow; +template <> class Allocator : public IAllocator { + OpKernelContext *context_; + std::unordered_map *pointer_mapping_; + cudaStream_t stream_; + + bool 
isExist(void *address) const { + return pointer_mapping_->count(address) > 0; + } + ReallocType isReMalloc(void *address, size_t size) const { + FT_CHECK(isExist(address)); + size_t current_buffer_size = 1; + for (int i = 0; i < pointer_mapping_->at(address).dims(); i++) { + current_buffer_size *= pointer_mapping_->at(address).dim_size(i); + } + // FT_LOG_DEBUG("current_buffer_size: %d, new buffer: %d", + // current_buffer_size, size); + if (current_buffer_size < size) { + return ReallocType::INCREASE; + } else if (current_buffer_size == size) { + return ReallocType::REUSE; + } else { + return ReallocType::DECREASE; + } + } + +public: + Allocator(OpKernelContext *context, cudaStream_t stream) + : context_(context), stream_(stream) { + pointer_mapping_ = new std::unordered_map(); + } + + void setStream(cudaStream_t stream) { stream_ = stream; } + + cudaStream_t returnStream() { return stream_; }; + + void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + tensorflow::Tensor buf; + long long int buf_size = ((long long int)ceil(size / 32.) * 32); + tensorflow::Status status; + if (is_host) { + tensorflow::AllocatorAttributes pinned_allocator; + pinned_allocator.set_on_host(true); + pinned_allocator.set_gpu_compatible(true); + status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf, + pinned_allocator); + } else { + status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf); + } + + if (status != tensorflow::Status::OK()) { + throw std::runtime_error("TF error: context->allocate_temp failed"); + } + + auto flat = buf.flat(); + void *ptr = (void *)flat.data(); + if (is_set_zero) { + cudaMemsetAsync(ptr, 0, buf_size, stream_); + } + pointer_mapping_->insert({getAddress(ptr), buf}); + + return ptr; + } + + void free(void **ptr, bool is_host = false) const { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + void *address = getAddress(*ptr); + pointer_mapping_->erase(address); + *ptr = nullptr; + return; + } + + virtual ~Allocator() { + while (!pointer_mapping_->empty()) { + void *ptr = pointer_mapping_->begin()->second.flat().data(); + free((void **)(&ptr)); + } + pointer_mapping_->clear(); + delete pointer_mapping_; + } + + void memSet(void *ptr, const int val, const size_t size) { + check_cuda_error(cudaMemsetAsync(ptr, val, size, stream_)); + } +}; +#endif + +#ifdef TORCH_CUDA +template <> class Allocator : public IAllocator { + std::unordered_map *pointer_mapping_; + + bool isExist(void *address) const { + return pointer_mapping_->count(address) > 0; + } + ReallocType isReMalloc(void *address, size_t size) const { + FT_CHECK(isExist(address)); + size_t current_buffer_size = 1; + for (int i = 0; i < pointer_mapping_->at(address).dim(); i++) { + current_buffer_size *= pointer_mapping_->at(address).size(i); + } + // FT_LOG_DEBUG( + // "current_buffer_size: %d, original buffer: %p, new buffer: %d", + // current_buffer_size, address, size); + if (current_buffer_size < size) { + return ReallocType::INCREASE; + } else if (current_buffer_size == size) { + return ReallocType::REUSE; + } else { + return ReallocType::DECREASE; + } + } + +public: + Allocator() { + pointer_mapping_ = new std::unordered_map(); + } + + void setStream(cudaStream_t stream) { + // nothing to do here; + } + + cudaStream_t returnStream() { + // nothing to do here; + return 0; + }; + + void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + int64_t buf_size = 
static_cast(ceil(size / 32.)) * 32; + torch::Tensor buf; + if (is_host) { + buf = torch::empty( + {buf_size}, + torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true)); + } else { + buf = torch::empty({buf_size}, + torch::dtype(torch::kUInt8).device(torch::kCUDA)); + } + void *ptr = buf.data_ptr(); + if (is_set_zero) { + cudaMemset(ptr, 0, buf_size); + } + // FT_LOG_DEBUG("malloc buffer %p with size %ld", ptr, buf_size); + pointer_mapping_->insert({getAddress(ptr), buf}); + return ptr; + } + + void free(void **ptr, bool is_host = false) const { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + void *address = getAddress(*ptr); + pointer_mapping_->erase(address); + *ptr = nullptr; + return; + } + + virtual ~Allocator() { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + while (!pointer_mapping_->empty()) { + void *ptr = pointer_mapping_->begin()->second.data_ptr(); + free((void **)(&ptr)); + } + pointer_mapping_->clear(); + delete pointer_mapping_; + } + + void memSet(void *ptr, const int val, const size_t size) { + check_cuda_error(cudaMemset(ptr, val, size)); + } +}; +#endif diff --git a/csrc/ftgemm/bindings.cpp b/csrc/ftgemm/bindings.cpp new file mode 100644 index 000000000000..9623a37c4dd5 --- /dev/null +++ b/csrc/ftgemm/bindings.cpp @@ -0,0 +1,212 @@ +#include +#include +#include "cublasAlgoMap.h" +#include "cublasINT8MMWrapper.h" +#include "transform_layout.h" + +class FTGEMM { +private: + cublasINT8MMWrapper *int8_gemm_wrapper = nullptr; + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +public: + FTGEMM(); + ~FTGEMM(); + + void linear_a8_w8_o32(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output); + void linear_a8_w8_o32_(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output); + void linear_a8_w8_o8(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output, float alpha); + void linear_a8_w8_o8_(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output, float alpha); + void linear_a8_w8_ofp32(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output, float alpha); + void transform_row_to_col32(torch::Tensor &input, torch::Tensor &out); + void transform_col32_to_row(torch::Tensor &input, torch::Tensor &out); + void transform_row_to_ampere(torch::Tensor &input, torch::Tensor &out); + void transform_row_to_turing(torch::Tensor &input, torch::Tensor &out); + +}; +FTGEMM::FTGEMM() { + // cublasAlgoMap *cublas_algo_map = new cublasAlgoMap("igemm_config.in"); + cublasAlgoMap *cublas_algo_map = new cublasAlgoMap(); + std::mutex *cublas_wrapper_mutex = new std::mutex(); + bool use_ORDER_COL32_2R_4R4 = true; + + // const cudaStream_t stream; + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cublasLtHandle_t cublaslt_handle; +// cudaStreamCreate(&stream); + cublasLtCreate(&cublaslt_handle); + + int8_gemm_wrapper = + new cublasINT8MMWrapper(cublaslt_handle, this->stream, cublas_algo_map, + cublas_wrapper_mutex, use_ORDER_COL32_2R_4R4); +} + +FTGEMM::~FTGEMM() {} + +void FTGEMM::linear_a8_w8_o32(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out // INT32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + int32_t *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr, + weight_ptr); +} + +void 
FTGEMM::linear_a8_w8_o32_(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out // INT32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + int32_t *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr, + weight_ptr); +} + +void FTGEMM::linear_a8_w8_o8(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out, // INT8 + float alpha // FP32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + int8_t *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, + weight_ptr); +} + +void FTGEMM::linear_a8_w8_o8_(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out, // INT8 + float alpha // FP32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + int8_t *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, + weight_ptr); +} + +void FTGEMM::linear_a8_w8_ofp32(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out, // INT8 + float alpha // FP32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + float *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm_f(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, + weight_ptr); +} + +void FTGEMM::transform_row_to_col32(torch::Tensor &input, torch::Tensor &out) { + int m = input.size(0); + int n = input.size(1); + int m_ = out.size(0); + int n_ = out.size(1); + + assert(m == m_); + assert(n == n_); + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int8_t *input_ptr = input.data_ptr(); + int8_t *out_ptr = out.data_ptr(); + invokeRowMajorToCOL32(out_ptr, input_ptr, m, n, this->stream); + // invokeRowMajorToCOL32(out_ptr, input_ptr, m, n, stream); +} + +void FTGEMM::transform_col32_to_row(torch::Tensor &input, torch::Tensor &out) { + int m = input.size(0); + int n = input.size(1); + int m_ = out.size(0); + int n_ = out.size(1); + + assert(m == m_); + assert(n == n_); + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int8_t *input_ptr = input.data_ptr(); + int8_t *out_ptr = out.data_ptr(); + invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, this->stream); + // invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, stream); +} + +void FTGEMM::transform_row_to_ampere(torch::Tensor &input, torch::Tensor &out) { + int m = input.size(0); + int n = input.size(1); + int m_ = out.size(0); + int n_ = out.size(1); + + assert(m == m_); + assert(n == n_); + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int8_t *input_ptr = input.data_ptr(); + int8_t *out_ptr = out.data_ptr(); + invokeRowMajorToAmpere(out_ptr, input_ptr, m, n, this->stream); + // invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, stream); +} + +void FTGEMM::transform_row_to_turing(torch::Tensor &input, torch::Tensor &out) { + int m = input.size(0); + int n = input.size(1); + int m_ = out.size(0); + int n_ = out.size(1); + + assert(m == m_); + assert(n == n_); + 
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int8_t *input_ptr = input.data_ptr(); + int8_t *out_ptr = out.data_ptr(); + invokeRowMajorToTuring(out_ptr, input_ptr, m, n, this->stream); + // invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, stream); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + pybind11::class_(m, "FTGEMM") + .def(pybind11::init<>()) + .def("linear_a8_w8_o32", &FTGEMM::linear_a8_w8_o32) + .def("linear_a8_w8_o8", &FTGEMM::linear_a8_w8_o8) + .def("linear_a8_w8_o8_", &FTGEMM::linear_a8_w8_o8_) + .def("linear_a8_w8_o32_", &FTGEMM::linear_a8_w8_o32_) + .def("linear_a8_w8_ofp32", &FTGEMM::linear_a8_w8_ofp32) + .def("transform_row_to_col32", &FTGEMM::transform_row_to_col32) + .def("transform_col32_to_row", &FTGEMM::transform_col32_to_row) + .def("transform_row_to_ampere", &FTGEMM::transform_row_to_ampere) + .def("transform_row_to_turing", &FTGEMM::transform_row_to_turing); +} diff --git a/csrc/ftgemm/cublasAlgoMap.cc b/csrc/ftgemm/cublasAlgoMap.cc new file mode 100644 index 000000000000..61e41438c6a8 --- /dev/null +++ b/csrc/ftgemm/cublasAlgoMap.cc @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cublasAlgoMap.h" + +cublasAlgoMap::cublasAlgoMap(const std::string filename, + const std::string sp_config_filename) + : config_filename_(filename), sp_config_filename_(sp_config_filename) { + loadGemmConfig(); + loadSpGemmConfig(); +} + +cublasAlgoMap::cublasAlgoMap(const cublasAlgoMap &algo_map) + : config_filename_(algo_map.config_filename_), + sp_config_filename_(algo_map.sp_config_filename_), + algo_map_(algo_map.algo_map_), sp_algo_map_(algo_map.sp_algo_map_) {} + +cublasAlgoMap::~cublasAlgoMap() { algo_map_.clear(); } + +void cublasAlgoMap::loadGemmConfig() { + FILE *fd; + fd = fopen(config_filename_.c_str(), "r"); + if (fd == NULL) { + std::cout << "[WARNING] " << config_filename_ + << " is not found; using default GEMM algo" << std::endl; + return; + } + + int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val; + int batch_size, seq_len, head_num, size_per_head, dataType; + int swizzle, reductionScheme, workspaceSize, stages; + int inner_shapeId, cluster_shapeId, mma_shapeId, cga_shapeId, sche_mode; + float exec_time; + char tmp[1024]; + if (!fgets(tmp, 1024, fd)) { + printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); + exit(-1); + } + while (fscanf(fd, + "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d " +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + "%d %d " +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + "%d %d %d " +#endif + "%f\n", + &batch_size, &seq_len, &head_num, &size_per_head, &dataType, + &batchCount2, &n2, &m2, &k2, &algoId, &customOption, &tile, + &splitK_val, &swizzle, &reductionScheme, &workspaceSize, + &stages, +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + &inner_shapeId, 
&cluster_shapeId, +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + &mma_shapeId, &cga_shapeId, &sche_mode, +#endif + &exec_time) != EOF) { + if (dataType != FLOAT_DATATYPE && dataType != HALF_DATATYPE && + dataType != BFLOAT16_DATATYPE && dataType != INT8_DATATYPE && + dataType != FP8_DATATYPE) { + printf("[WARNING][readAlgoFromConfig] wrong dataType %d!\n", dataType); + continue; + } + cublasAlgoConfig_t markStr{batchCount2, m2, n2, k2, + static_cast(dataType)}; + // workspaceSize should be zero + if (algo_map_.find(markStr) == algo_map_.end()) { + algo_map_[markStr].algoId = algoId; + algo_map_[markStr].customOption = customOption; + algo_map_[markStr].tile = tile; + algo_map_[markStr].splitK_val = splitK_val; + algo_map_[markStr].swizzle = swizzle; + algo_map_[markStr].reductionScheme = reductionScheme; + algo_map_[markStr].workspaceSize = workspaceSize; + algo_map_[markStr].stages = stages; +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + algo_map_[markStr].inner_shapeId = (uint16_t)inner_shapeId; + algo_map_[markStr].cluster_shapeId = (uint16_t)cluster_shapeId; +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + algo_map_[markStr].mma_shapeId = (uint16_t)mma_shapeId; + algo_map_[markStr].cga_shapeId = (uint16_t)cga_shapeId; + algo_map_[markStr].sche_mode = (uint16_t)sche_mode; +#endif + algo_map_[markStr].exec_time = exec_time; + } + } + fclose(fd); +} + +bool cublasAlgoMap::isExist(const int batch_count, const int m, const int n, + const int k, const CublasDataType data_type) { + cublasAlgoConfig_t mark{batch_count, n, m, k, data_type}; + return algo_map_.find(mark) != algo_map_.end(); +} + +cublasLtMatmulAlgo_info cublasAlgoMap::getAlgo(const int batch_count, + const int m, const int n, + const int k, + const CublasDataType data_type) { + cublasAlgoConfig_t mark{batch_count, n, m, k, data_type}; + if (algo_map_.find(mark) != algo_map_.end()) { + return algo_map_[mark]; + } else { + cublasLtMatmulAlgo_info tmp_algo; + tmp_algo.algoId = static_cast(data_type == FLOAT_DATATYPE + ? 
CUBLAS_GEMM_DEFAULT + : CUBLAS_GEMM_DEFAULT_TENSOR_OP); + tmp_algo.customOption = -1; + tmp_algo.tile = -1; + tmp_algo.splitK_val = -1; + tmp_algo.swizzle = -1; + tmp_algo.reductionScheme = -1; + tmp_algo.workspaceSize = -1; + tmp_algo.stages = -1; + tmp_algo.exec_time = -1.0f; + return tmp_algo; + } +} + +void cublasAlgoMap::loadSpGemmConfig() { + if (sp_config_filename_.empty()) { + return; + } + FILE *fd = fopen(sp_config_filename_.c_str(), "r"); + if (fd == NULL) { + printf("[WARNING] %s is not found; using SPGEMM algo id 0\n", + sp_config_filename_.c_str()); + return; + } + sp_algo_map_.clear(); + int batch_size, seq_len, head_num, size_per_head, data_type; + int batchCount, m, n, k, algoId; + float exec_time; + char tmp[1024]; + if (!fgets(tmp, 1024, fd)) { + printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); + exit(-1); + } + while (fscanf(fd, "%d %d %d %d %d ### %d %d %d %d %d %f\n", &batch_size, + &seq_len, &head_num, &size_per_head, &data_type, &batchCount, + &m, &n, &k, &algoId, &exec_time) != EOF) { + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k); + std::string markStr(mark); + sp_algo_map_[markStr] = algoId; + } + fclose(fd); +} + +int cublasAlgoMap::getSpAlgo(const int batch_count, const int m, const int n, + const int k) { + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k); + if (sp_algo_map_.find(mark) != sp_algo_map_.end()) { + return sp_algo_map_[mark]; + } else { + // for remove padding, select algo 1 for simplicity + return 0; + } +} + +bool cublasAlgoMap::isUseSparse(const int batch_count, const int m, const int n, + const int k) { + // not available to use cusparselt. + if (m % 8 != 0 || n % 8 != 0 || k % 8 != 0) { + return false; + } + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k); + if (sp_algo_map_.find(mark) != sp_algo_map_.end()) { + return sp_algo_map_[mark] != -1; + } else { + // no gemm test case, choose sparse according to sparse flag + return true; + } +} diff --git a/csrc/ftgemm/cublasAlgoMap.h b/csrc/ftgemm/cublasAlgoMap.h new file mode 100644 index 000000000000..beb9d3a23d90 --- /dev/null +++ b/csrc/ftgemm/cublasAlgoMap.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cuda_utils.h" +#include +#include +#include +#include +#include +#include +#include + +#pragma once + +#define GEMM_NUM 6 +#define GEMM_CONFIG "gemm_config.in" +#define IGEMM_CONFIG "igemm_config.in" +#define SPGEMM_CONFIG "spgemm_config.in" +#define SPIGEMM_CONFIG "spigemm_config.in" + +typedef struct { + int algoId, customOption, tile, splitK_val; + int swizzle, reductionScheme, workspaceSize; + // only used in cublasLt >= 11.0 + int stages; +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + uint16_t inner_shapeId, cluster_shapeId; +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + uint16_t mma_shapeId, cga_shapeId, sche_mode; +#endif + float exec_time; +} cublasLtMatmulAlgo_info; + +/* Structure to store information about different run trials */ +typedef struct { + cublasLtMatmulAlgo_t algo; + cublasStatus_t status; + float time; + size_t workspaceSize; // actual memory workspace needed + cublasMath_t mathMode; + cublasLtReductionScheme_t reductionScheme; + int customOption; + float wavesCount; +} customMatmulPerf_t; + +struct cublasAlgoConfig_t { + int batch_count; + int m; + int n; + int k; + CublasDataType data_type; + bool operator==(cublasAlgoConfig_t const &config) const { + return (batch_count == config.batch_count) && (m == config.m) && + (n == config.n) && (k == config.k) && + (data_type == config.data_type); + } +}; + +class cublasAlgoConfig_hasher { +public: + std::size_t operator()(cublasAlgoConfig_t const &config) const { + return config.batch_count * 98317ull ^ config.m * 49157ull ^ + config.n * 24593ull ^ config.k * 196613ull ^ + static_cast(config.data_type) * 6151ull; + } +}; + +class cublasAlgoMap { +private: + std::unordered_map + algo_map_; + std::string config_filename_; + std::string sp_config_filename_; + std::map sp_algo_map_; + +public: + cublasAlgoMap(){}; + explicit cublasAlgoMap(const std::string filename, + const std::string sp_config_filename = ""); + cublasAlgoMap(const cublasAlgoMap &map); + ~cublasAlgoMap(); + void loadGemmConfig(); + void loadSpGemmConfig(); + int getSpAlgo(const int batch_count, const int m, const int n, const int k); + bool isUseSparse(const int batch_count, const int m, const int n, + const int k); + + bool isExist(const int batch_count, const int m, const int n, const int k, + const CublasDataType data_type); + + cublasLtMatmulAlgo_info getAlgo(const int batch_count, const int m, + const int n, const int k, + const CublasDataType data_type); +}; diff --git a/csrc/ftgemm/cublasINT8MMWrapper.cc b/csrc/ftgemm/cublasINT8MMWrapper.cc new file mode 100644 index 000000000000..23f68c971414 --- /dev/null +++ b/csrc/ftgemm/cublasINT8MMWrapper.cc @@ -0,0 +1,840 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cublasINT8MMWrapper.h" + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! 
+#endif + +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, + std::mutex *mu, + bool use_ORDER_COL32_2R_4R4) + : cublasMMWrapper(nullptr, cublaslt_handle, stream, cublas_algo_map, mu, + nullptr), + use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} + +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, + std::mutex *mu, + bool use_ORDER_COL32_2R_4R4) + : cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, + mu, nullptr), + use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} + +#ifdef SPARSITY_ENABLED +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, + cusparseLtHandle_t cusparselt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, + std::mutex *mu, + bool use_ORDER_COL32_2R_4R4) + : cublasMMWrapper(nullptr, cublaslt_handle, cusparselt_handle, stream, + cublas_algo_map, mu, nullptr), + use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} +#endif + +cublasINT8MMWrapper::~cublasINT8MMWrapper() { mu_ = nullptr; } + +cublasINT8MMWrapper::cublasINT8MMWrapper(const cublasINT8MMWrapper &wrapper) + : +#ifdef SPARSITY_ENABLED + cublasMMWrapper(nullptr, wrapper.cublaslt_handle_, + wrapper.cusparselt_handle_, wrapper.stream_, + wrapper.cublas_algo_map_, wrapper.mu_, + wrapper.allocator_), +#else + cublasMMWrapper(nullptr, wrapper.cublaslt_handle_, wrapper.stream_, + wrapper.cublas_algo_map_, wrapper.mu_, + wrapper.allocator_), +#endif + use_ORDER_COL32_2R_4R4_(wrapper.use_ORDER_COL32_2R_4R4_) { +} + +// for int8 cublasLtMM with algo +// ATransform should be m*n, CUBLASLT_ORDER_COL32 +// kernel should be n*k, CUBLASLT_ORDER_COL4_4R2_8C or +// CUBLASLT_ORDER_COL32_2R_4R4 res is m*n, CUBLASLT_ORDER_COL32 +void cublasINT8MMWrapper::Gemm(int *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + + cublasLtOrder_t order_matrixB; +#if (CUDART_VERSION >= 11000) + if (use_ORDER_COL32_2R_4R4_) { + order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; + } else { + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + } +#else + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; +#endif + + int ldaTransform = 32 * m; + int ldbTransform; + if (use_ORDER_COL32_2R_4R4_) { + ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; + } else { + ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; + } + int ldcTransform = 32 * m; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); + cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, 
ldbTransform); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_matrixB, sizeof(order_matrixB)); + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, m, n, ldcTransform); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + int alphaI = 1; + int betaI = 0; + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + // printf("find algo %s\n", markStr.c_str()); + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + if (use_ORDER_COL32_2R_4R4_) { + algoId = 7; + } else { + algoId = 6; + } + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + if (use_ORDER_COL32_2R_4R4_) { + stages = 15; + } else { + stages = 13; + } + 
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alphaI, ATransform, + AtransformDesc, kernel, BtransformDesc, &betaI, res, + CtransformDesc, res, CtransformDesc, + (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// Atransform: mxk CUDA_R_8I +// kernel: nxk CUDA_R_8I +// res: mxn CUDA_R_32I +// alpha: CUDA_R_32I should be 1 +// beta: CUDA_R_32I should be 0 +// computeType: CUBLAS_COMPUTE_32I +void cublasINT8MMWrapper::Gemm_(int *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTranspose, sizeof(cublasOperation_t)); + + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); + + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); + + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, n, m, n); + + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + int alphaI = 1; + int betaI = 0; + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + // printf("find algo %s\n", markStr.c_str()); + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + 
cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + stages = 17; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alphaI, kernel, AtransformDesc, + ATransform, BtransformDesc, &betaI, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// for int8 IO cublasLtMM with algo +// ATransform should be m*k CUBLASLT_ORDER_COL32 +// kernel should be n*k CUBLASLT_ORDER_COL4_4R2_8C +// res is m*n CUBLASLT_ORDER_COL32 +void cublasINT8MMWrapper::Gemm(int8_t *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; + // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE + // cublasLtPointerMode_t pointerMode = + // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cudaDataType_t scaleType = CUDA_R_32F; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + + cublasLtOrder_t order_matrixB; +#if (CUDART_VERSION >= 11000) + if (use_ORDER_COL32_2R_4R4_) { + order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; + } else { + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + } +#else + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; +#endif + + int ldaTransform = 32 * m; + + int ldbTransform; + if (use_ORDER_COL32_2R_4R4_) { + ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; + } else { + ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; + } + + int ldcTransform = 32 * m; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + 
cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, + &scaleType, sizeof(scaleType)); + // cublasLtMatmulDescSetAttribute(matmulDesc, + // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, + // sizeof(cublasLtPointerMode_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); + cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, ldbTransform); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_matrixB, sizeof(order_matrixB)); + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, m, n, ldcTransform); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + if (use_ORDER_COL32_2R_4R4_) { + algoId = 7; + } else { + algoId = 6; + } + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + 
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + if (use_ORDER_COL32_2R_4R4_) { + stages = 15; + } else { + stages = 13; + } + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + float beta = 0.0f; + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, + ATransform, BtransformDesc, &beta, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// Atransform: mxk CUDA_R_8I +// kernel: nxk CUDA_R_8I +// res: mxn CUDA_R_8I +// alpha: CUDA_R_32F +// beta: CUDA_R_32F +// computeType: CUBLAS_COMPUTE_32I +void cublasINT8MMWrapper::Gemm_(int8_t *res, int batchCount, int m, int n, + int k, int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, + const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; + // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE + // cublasLtPointerMode_t pointerMode = + // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cudaDataType_t scaleType = CUDA_R_32F; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, + &scaleType, sizeof(scaleType)); + // cublasLtMatmulDescSetAttribute(matmulDesc, + // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, + // sizeof(cublasLtPointerMode_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); + + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); + + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, n, m, n); + + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + 
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, n, m, k, INT8_DATATYPE)) { + findAlgo = 1; + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, n, m, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + stages = 17; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + float beta = 0.0f; + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, + ATransform, BtransformDesc, &beta, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? 
(&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// Atransform: mxk CUDA_R_8I +// kernel: nxk CUDA_R_8I +// res: mxn CUDA_R_32F +// alpha: CUDA_R_32F +// beta: CUDA_R_32F +// computeType: CUBLAS_COMPUTE_32F +void cublasINT8MMWrapper::Gemm_f(float *res, int batchCount, int m, int n, + int k, int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, + const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; + // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE + // cublasLtPointerMode_t pointerMode = + // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cudaDataType_t scaleType = CUDA_R_32F; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; +#else + cudaDataType_t computeType = CUDA_R_32F; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + // cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTranspose, sizeof(cublasOperation_t)); + + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, + &scaleType, sizeof(scaleType)); + // cublasLtMatmulDescSetAttribute(matmulDesc, + // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, + // sizeof(cublasLtPointerMode_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32F, n, m, n); + + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, n, m, k, INT8_DATATYPE)) { + findAlgo = 1; + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, n, m, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32F, CUDA_R_32F, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, 
CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32F, CUDA_R_32F, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + stages = 17; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + float beta = 0.0f; + findAlgo = 0; + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, + ATransform, BtransformDesc, &beta, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +bool cublasINT8MMWrapper::getUseOrderCol322R4R4() { + return use_ORDER_COL32_2R_4R4_; +} diff --git a/csrc/ftgemm/cublasINT8MMWrapper.h b/csrc/ftgemm/cublasINT8MMWrapper.h new file mode 100644 index 000000000000..cbd2879a36b0 --- /dev/null +++ b/csrc/ftgemm/cublasINT8MMWrapper.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cublasAlgoMap.h" +#include "cublasMMWrapper.h" +#include "cuda_utils.h" +#include +#include +#include +#include +#include +#include + +#pragma once + +class cublasINT8MMWrapper : public cublasMMWrapper { +private: + bool use_ORDER_COL32_2R_4R4_; + +public: + cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle_, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, + bool use_ORDER_COL32_2R_4R4); + + cublasINT8MMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, + bool use_ORDER_COL32_2R_4R4); + + ~cublasINT8MMWrapper(); + + cublasINT8MMWrapper(const cublasINT8MMWrapper &wrapper); + + void Gemm(int *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const int8_t *ATransform, + const int8_t *kernel); + + void Gemm_(int *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const int8_t *ATransform, + const int8_t *kernel); + + void Gemm(int8_t *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel); + + void Gemm_(int8_t *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel); + + void Gemm_f(float *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel); + + bool getUseOrderCol322R4R4(); +}; \ No newline at end of file diff --git a/csrc/ftgemm/cublasMMWrapper.cc b/csrc/ftgemm/cublasMMWrapper.cc new file mode 100644 index 000000000000..184304af74b7 --- /dev/null +++ b/csrc/ftgemm/cublasMMWrapper.cc @@ -0,0 +1,851 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cublasMMWrapper.h" +#include "cuda_utils.h" +#include + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! 
+#endif + +cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, std::mutex *mu, + IAllocator *allocator) + : cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle), + stream_(stream), cublas_algo_map_(cublas_algo_map), mu_(mu), + allocator_(allocator) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (allocator_ != nullptr) { + cublas_workspace_ = + allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false); + } +} + +#ifdef SPARSITY_ENABLED +cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, + cusparseLtHandle_t cusparselt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, std::mutex *mu, + IAllocator *allocator) + : cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle), + cusparselt_handle_(cusparselt_handle), stream_(stream), + cublas_algo_map_(cublas_algo_map), mu_(mu), allocator_(allocator) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (allocator_ != nullptr) { + cublas_workspace_ = + allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false); + } +} +#endif + +cublasMMWrapper::~cublasMMWrapper() { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + mu_ = nullptr; + if (allocator_ != nullptr) { + allocator_->free((void **)(&cublas_workspace_)); + allocator_ = nullptr; + } +} + +cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper &wrapper) + : cublas_handle_(wrapper.cublas_handle_), + cublaslt_handle_(wrapper.cublaslt_handle_), +#ifdef SPARSITY_ENABLED + cusparselt_handle_(wrapper.cusparselt_handle_), +#endif + stream_(wrapper.stream_), cublas_algo_map_(wrapper.cublas_algo_map_), + mu_(wrapper.mu_), allocator_(wrapper.allocator_) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (allocator_ != nullptr) { + cublas_workspace_ = + allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false); + } +} + +void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, + const void *alpha, const void *A, + cudaDataType_t Atype, int lda, const void *B, + cudaDataType_t Btype, int ldb, const void *beta, + void *C, cudaDataType_t Ctype, int ldc, + cudaDataType_t computeType, cublasGemmAlgo_t algo) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + mu_->lock(); + check_cuda_error(cublasGemmEx(cublas_handle_, transa, transb, m, n, k, alpha, + A, Atype, lda, B, Btype, ldb, beta, C, Ctype, + ldc, computeType, algo)); + sync_check_cuda_error(); + mu_->unlock(); +} + +void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *A, + const int lda, const void *B, const int ldb, void *C, + const int ldc) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); +} + +void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *A, + const int lda, const void *B, const int ldb, void *C, + const int ldc, float f_alpha, float f_beta) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + half h_alpha = (half)(f_alpha); + half h_beta = (half)(f_beta); + + mu_->lock(); + // TODO: default cublas libs + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + bool using_cublasLt = (Atype_ == CUDA_R_16F) ? true : false; + int batch_count = 1; + // fp32 use cublas as default + // fp16 use cublasLt as default + const void *alpha = is_fp16_computeType ? 
reinterpret_cast(&h_alpha) + : reinterpret_cast(&f_alpha); + const void *beta = is_fp16_computeType ? reinterpret_cast(&h_beta) + : reinterpret_cast(&f_beta); + + int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, + getCublasDataType(Atype_)); + + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(Atype_)); + if (findAlgo) { + if (info.stages != -1) { + using_cublasLt = true; + } else { + using_cublasLt = false; + } + } + + if (using_cublasLt) { + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cudaDataType_t scaleType; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType; +#else + cudaDataType_t computeType; +#endif + + if (is_fp16_computeType) { +#if (CUDART_VERSION >= 11000) + computeType = CUBLAS_COMPUTE_16F; +#else + computeType = CUDA_R_16F; +#endif + scaleType = CUDA_R_16F; + } else { +#if (CUDART_VERSION >= 11000) + computeType = CUBLAS_COMPUTE_32F; +#else + computeType = CUDA_R_32F; +#endif + scaleType = CUDA_R_32F; + } + + // -------------------------------------- + // Create descriptors for the original matrices + cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, + transa == CUBLAS_OP_N ? k : m, lda); + cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, + transb == CUBLAS_OP_N ? n : k, ldb); + cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&operationDesc, computeType); +#endif + + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &transa, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &transb, sizeof(cublasOperation_t)); + + cublasLtMatmulAlgo_t algo; + void *workSpace = cublas_workspace_; + int workspaceSize = cublas_workspace_ == NULL ? 
0 : CUBLAS_WORKSPACE_SIZE; + if (findAlgo) { + if (info.workspaceSize > workspaceSize) { + findAlgo = 0; + } else { + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, scaleType, Atype_, + Btype_, Ctype_, Ctype_, info.algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), + sizeof(info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_TILE_ID, + &(info.tile), sizeof(info.tile)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), + sizeof(info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), + sizeof(info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(info.reductionScheme), sizeof(info.reductionScheme)); + +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), + sizeof(info.stages)); +#endif + +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), + sizeof(info.inner_shapeId)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, + &(info.cluster_shapeId), sizeof(info.cluster_shapeId)); +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), + sizeof(info.mma_shapeId)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), + sizeof(info.cga_shapeId)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), + sizeof(info.sche_mode)); +#endif + } + } + + cublasLtMatmul(cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, + beta, C, Cdesc, C, Cdesc, (findAlgo == 1 ? 
(&algo) : NULL), + workSpace, workspaceSize, stream_); + + cublasLtMatmulDescDestroy(operationDesc); + cublasLtMatrixLayoutDestroy(Adesc); + cublasLtMatrixLayoutDestroy(Bdesc); + cublasLtMatrixLayoutDestroy(Cdesc); + sync_check_cuda_error(); + } else { + int cublasAlgo = info.algoId; + check_cuda_error(cublasGemmEx(cublas_handle_, transa, transb, m, n, k, + alpha, A, Atype_, lda, B, Btype_, ldb, beta, + C, Ctype_, ldc, computeType_, + static_cast(cublasAlgo))); + sync_check_cuda_error(); + } + mu_->unlock(); +} + +void cublasMMWrapper::setFP32GemmConfig() { + Atype_ = CUDA_R_32F; + Btype_ = CUDA_R_32F; + Ctype_ = CUDA_R_32F; + computeType_ = CUDA_R_32F; +} + +void cublasMMWrapper::setFP16GemmConfig() { + Atype_ = CUDA_R_16F; + Btype_ = CUDA_R_16F; + Ctype_ = CUDA_R_16F; + computeType_ = CUDA_R_32F; +} + +#ifdef ENABLE_BF16 +void cublasMMWrapper::setBF16GemmConfig() { + Atype_ = CUDA_R_16BF; + Btype_ = CUDA_R_16BF; + Ctype_ = CUDA_R_16BF; + computeType_ = CUDA_R_32F; +} +#endif + +void cublasMMWrapper::setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, + cudaDataType_t cType, + cudaDataType_t computeType) { + Atype_ = aType; + Btype_ = bType; + Ctype_ = cType; + computeType_ = computeType; +} + +CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type) { + if (data_type == CUDA_R_16F) { + return HALF_DATATYPE; + } else if (data_type == CUDA_R_32F) { + return FLOAT_DATATYPE; + } +#ifdef ENABLE_BF16 + else if (data_type == CUDA_R_16BF) { + return BFLOAT16_DATATYPE; + } +#endif + return FLOAT_DATATYPE; +} + +#if (CUDART_VERSION >= 11000) +// input, weight, output are row-major +// only works for cublas 11.x +void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *A, + const int lda, const void *B, const int ldb, + const void *bias, void *C, const int ldc) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + cudaDataType_t Atype, Btype, Ctype; + cublasComputeType_t computeType; + cudaDataType_t scaleType; + float alpha_float = 1.0f; + float beta_float = 0.0f; + half alpha_half = half(1.0f); + half beta_half = half(0.0f); + void *alpha, *beta; + + // int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + if (Atype_ == CUDA_R_32F) { + computeType = CUBLAS_COMPUTE_32F_FAST_TF32; + Atype = CUDA_R_32F; + Btype = CUDA_R_32F; + Ctype = CUDA_R_32F; + scaleType = CUDA_R_32F; + alpha = &alpha_float; + beta = &beta_float; + } else if (Atype_ == CUDA_R_16BF) { + computeType = CUBLAS_COMPUTE_32F_FAST_TF32; + Atype = CUDA_R_16BF; + Btype = CUDA_R_16BF; + Ctype = CUDA_R_16BF; + scaleType = CUDA_R_32F; + alpha = &alpha_float; + beta = &beta_float; + } else { + computeType = CUBLAS_COMPUTE_16F; + Atype = CUDA_R_16F; + Btype = CUDA_R_16F; + Ctype = CUDA_R_16F; + scaleType = CUDA_R_16F; + alpha = &alpha_half; + beta = &beta_half; + } + + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS; + cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, + (transa == CUBLAS_OP_N) ? k : m, lda); + cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, + (transb == CUBLAS_OP_N) ? 
n : k, ldb); + cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc); + + cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &transa, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &transb, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epi, sizeof(cublasLtEpilogue_t)); + cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, + sizeof(const void *)); + check_cuda_error(cublasLtMatmul(cublaslt_handle_, operationDesc, alpha, A, + Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, + NULL, NULL, 0, stream_)); + cublasLtMatrixLayoutDestroy(Adesc); + cublasLtMatrixLayoutDestroy(Bdesc); + cublasLtMatrixLayoutDestroy(Cdesc); + cublasLtMatmulDescDestroy(operationDesc); +} +#endif +void cublasMMWrapper::setStream(cudaStream_t stream) { stream_ = stream; } + +void cublasMMWrapper::stridedBatchedGemm( + cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const int lda, + const int64_t strideA, const void *B, const int ldb, const int64_t strideB, + void *C, const int ldc, const int64_t strideC, const int batch_count, + const float f_alpha, const float f_beta) { + half h_alpha = (half)f_alpha; + half h_beta = (half)f_beta; + + mu_->lock(); + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + const void *alpha = is_fp16_computeType + ? reinterpret_cast(&h_alpha) + : reinterpret_cast(&f_alpha); + const void *beta = is_fp16_computeType + ? reinterpret_cast(&h_beta) + : reinterpret_cast(&f_beta); + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(Atype_)); + + check_cuda_error(cublasGemmStridedBatchedEx( + cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, strideA, + B, Btype_, ldb, strideB, beta, C, Ctype_, ldc, strideC, batch_count, + computeType_, static_cast(info.algoId))); + + mu_->unlock(); +} + +void cublasMMWrapper::stridedBatchedGemm( + cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const float f_alpha, const void *A, + cudaDataType_t AType, const int lda, const int64_t strideA, const void *B, + cudaDataType_t BType, const int ldb, const int64_t strideB, + const float f_beta, void *C, cudaDataType_t CType, const int ldc, + const int64_t strideC, const int batch_count, cudaDataType_t computeType) { + half h_alpha = (half)f_alpha; + half h_beta = (half)f_beta; + + mu_->lock(); + int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0; + const void *alpha = is_fp16_computeType + ? reinterpret_cast(&h_alpha) + : reinterpret_cast(&f_alpha); + const void *beta = is_fp16_computeType + ? 
reinterpret_cast(&h_beta) + : reinterpret_cast(&f_beta); + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(Atype_)); + + check_cuda_error(cublasGemmStridedBatchedEx( + cublas_handle_, transa, transb, m, n, k, alpha, A, AType, lda, strideA, B, + BType, ldb, strideB, beta, C, CType, ldc, strideC, batch_count, + computeType, static_cast(info.algoId))); + + mu_->unlock(); +} + +void cublasMMWrapper::batchedGemm(cublasOperation_t transa, + cublasOperation_t transb, const int m, + const int n, const int k, + const void *const *A, const int lda, + const void *const *B, const int ldb, + void *const *C, const int ldc, + const int batch_count) { + float f_alpha = static_cast(1.0f); + float f_beta = static_cast(0.0f); + + half h_alpha = (half)1.0f; + half h_beta = (half)0.0f; + + mu_->lock(); + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + const void *alpha = is_fp16_computeType ? reinterpret_cast(&h_alpha) + : reinterpret_cast(&f_alpha); + const void *beta = is_fp16_computeType ? reinterpret_cast(&h_beta) + : reinterpret_cast(&f_beta); + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(Atype_)); + + check_cuda_error(cublasGemmBatchedEx( + cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_, + ldb, beta, C, Ctype_, ldc, batch_count, computeType_, + static_cast(info.algoId))); + mu_->unlock(); +} + +bool cublasMMWrapper::isFuseBatchGemm(const int batch_count, const int m, + const int k, const int n) { + CublasDataType data_type = getCublasDataType(Atype_); + + if (cublas_algo_map_->isExist(batch_count, m, k, n, data_type) == false || + cublas_algo_map_->isExist(1, m, k, n, data_type) == false) { + return false; + } else { + return cublas_algo_map_->getAlgo(batch_count, m, k, n, data_type) + .exec_time < + 3 * cublas_algo_map_->getAlgo(1, m, k, n, data_type).exec_time; + } +} + +#ifdef SPARSITY_ENABLED +void cublasMMWrapper::SpGemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, + const void *A, const void *B, void *C) { + if (Atype_ != CUDA_R_16F || Btype_ != CUDA_R_16F || Ctype_ != CUDA_R_16F) { + throw std::runtime_error( + "\n[FT][ERROR] sparse GEMM only supports FP16 data type now."); + } + static bool not_printed_fp32_accumulation_warning = true; + if (computeType_ != CUDA_R_16F && not_printed_fp32_accumulation_warning) { + printf("[FT][WARNING] cublasMMWrapper sets to FP32 compute type, " + "but sparse gemm will use FP16 compute type since cusparselt " + "supports FP16 accumulation only.\n"); + not_printed_fp32_accumulation_warning = false; + } + cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOperation_t opA = (transa == CUBLAS_OP_N) + ? CUSPARSE_OPERATION_NON_TRANSPOSE + : CUSPARSE_OPERATION_TRANSPOSE; + cusparseOperation_t opB = (transb == CUBLAS_OP_N) + ? CUSPARSE_OPERATION_NON_TRANSPOSE + : CUSPARSE_OPERATION_TRANSPOSE; + cusparseComputeType compute_type = CUSPARSE_COMPUTE_16F; + cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulAlgSelection_t alg_sel; + cusparseLtMatmulPlan_t plan; + + bool is_rowmajor = (order == CUSPARSE_ORDER_ROW); + bool isA_transposed = (opA != CUSPARSE_OPERATION_NON_TRANSPOSE); + bool isB_transposed = (opB != CUSPARSE_OPERATION_NON_TRANSPOSE); + auto num_A_rows = (isA_transposed) ? k : m; + auto num_A_cols = (isA_transposed) ? m : k; + auto num_B_rows = (isB_transposed) ? n : k; + auto num_B_cols = (isB_transposed) ? 
k : n; + auto num_C_rows = m; + auto num_C_cols = n; + unsigned alignment = 16; + auto lda = (is_rowmajor) ? num_A_cols : num_A_rows; + auto ldb = (is_rowmajor) ? num_B_cols : num_B_rows; + auto ldc = (is_rowmajor) ? num_C_cols : num_C_rows; + float _alpha(1.0f); + float _beta(0.0f); + + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", 1, m, n, k); + if (sp_mat_A_desc_map_.find(mark) != sp_mat_A_desc_map_.end()) { + CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( + &cusparselt_handle_, &matmul, opA, opB, &sp_mat_A_desc_map_[mark], + &sp_mat_B_desc_map_[mark], &sp_mat_C_desc_map_[mark], + &sp_mat_C_desc_map_[mark], compute_type)) + } else { + // initializing MatDesc takes a lot of time + cusparseLtMatDescriptor_t matA, matB, matC; + sp_mat_A_desc_map_[mark] = matA; + sp_mat_B_desc_map_[mark] = matB; + sp_mat_C_desc_map_[mark] = matC; + CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( + &cusparselt_handle_, &sp_mat_A_desc_map_[mark], num_A_rows, num_A_cols, + lda, alignment, Atype_, order, CUSPARSELT_SPARSITY_50_PERCENT)) + CHECK_CUSPARSE(cusparseLtDenseDescriptorInit( + &cusparselt_handle_, &sp_mat_B_desc_map_[mark], num_B_rows, num_B_cols, + ldb, alignment, Btype_, order)) + CHECK_CUSPARSE(cusparseLtDenseDescriptorInit( + &cusparselt_handle_, &sp_mat_C_desc_map_[mark], num_C_rows, num_C_cols, + ldc, alignment, Ctype_, order)) + CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( + &cusparselt_handle_, &matmul, opA, opB, &sp_mat_A_desc_map_[mark], + &sp_mat_B_desc_map_[mark], &sp_mat_C_desc_map_[mark], + &sp_mat_C_desc_map_[mark], compute_type)) + } + mu_->lock(); + CHECK_CUSPARSE(cusparseLtMatmulAlgSelectionInit( + &cusparselt_handle_, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)) + int alg = cublas_algo_map_->getSpAlgo(1, num_A_rows, num_B_cols, num_A_cols); + CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute( + &cusparselt_handle_, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, + sizeof(alg))) + size_t workspace_size; + CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&cusparselt_handle_, &alg_sel, + &workspace_size)) + CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&cusparselt_handle_, &plan, &matmul, + &alg_sel, workspace_size)) + + void *d_workspace = nullptr; + int num_streams = 1; + cudaStream_t streams[1] = {stream_}; + CHECK_CUSPARSE(cusparseLtMatmul(&cusparselt_handle_, &plan, &_alpha, A, B, + &_beta, C, C, d_workspace, streams, + num_streams)) + CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan)) + sync_check_cuda_error(); + mu_->unlock(); +} + +size_t cublasMMWrapper::getSparseMatrixSize(int m, int k) { + // Get a compressed matrix size of shape (m, k) used in cusparselt. 
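+  // The size computed below is for cusparseLt's 2:4 structured-sparse
+  // (50% sparsity) compressed representation of an FP16 (m, k) matrix in
+  // column-major order, as configured by the descriptor initialized here.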
+ auto Atype_ = CUDA_R_16F; + cusparseOrder_t order = CUSPARSE_ORDER_COL; + unsigned alignment = 16; + int num_A_rows = m; + int num_A_cols = k; + int lda = num_A_rows; + + cusparseLtMatDescriptor_t matA; + CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( + &cusparselt_handle_, &matA, num_A_rows, num_A_cols, lda, alignment, + Atype_, order, CUSPARSELT_SPARSITY_50_PERCENT)); + size_t compressed_size = 0; + CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&cusparselt_handle_, &matA, + &compressed_size)); + return compressed_size; +} + +void cublasMMWrapper::compressMatrix(const void *input, void *output, + const int m, const int k) { + cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseLtMatDescriptor_t matA; + unsigned alignment = 16; + CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( + &cusparselt_handle_, &matA, m, k, m, alignment, CUDA_R_16F, order, + CUSPARSELT_SPARSITY_50_PERCENT)) + CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &matA, true, opA, + input, output, stream_)) + sync_check_cuda_error(); +} + +bool cublasMMWrapper::isUseSparse(const int batch_count, const int m, + const int n, const int k) { + return cublas_algo_map_->isUseSparse(batch_count, m, n, k); +} +#endif + +std::pair cublasMMWrapper::findBestAlgo( + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, const void *A, cublasLtMatrixLayout_t Adesc, + const void *B, cublasLtMatrixLayout_t Bdesc, const void *beta, + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, cudaStream_t stream) { +#if (CUBLAS_VERSION) < 11601 + FT_CHECK_WITH_INFO(false, "CUBLAS version too low."); + return {false, cublasLtMatmulAlgo_t{}}; +#else + size_t returnSize; + int32_t pointer_mode; + cublasLtMatmulDescGetAttribute(computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode), + &returnSize); + + std::vector heuristics(200); + cublasLtMatmulPreference_t preference; + check_cuda_error(cublasLtMatmulPreferenceCreate(&preference)); + check_cuda_error(cublasLtMatmulPreferenceInit(preference)); + uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size))); +#if (CUBLAS_VERSION) <= 12000 + uint32_t pointer_mode_mask = 0; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, + sizeof(pointer_mode_mask))); +#endif + + int return_count = 0; + auto ret = cublasLtMatmulAlgoGetHeuristic( + lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, + heuristics.size(), heuristics.data(), &return_count); + heuristics.resize(return_count); + + std::map> algo_results; + for (const auto &heuristic : heuristics) { + cublasLtMatmulAlgo_t algo = heuristic.algo; + int32_t algo_id; + cublasLtMatmulAlgoConfigGetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize); + + cudaEvent_t start_event, stop_event; + cudaEventCreate(&start_event); + cudaEventCreate(&stop_event); + + float my_alpha = 1.0f; + float my_beta = 0.0f; + + for (int i = 0; i < 11; i++) { + float duration_ms; + cudaEventRecord(start_event, stream); + check_cuda_error(cublasLtMatmul( + lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, + D, Ddesc, &algo, cublas_workspace_, CUBLAS_WORKSPACE_SIZE, stream)); + cudaEventRecord(stop_event, stream); + 
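+      // Each candidate algorithm is timed over 11 iterations; the sorted
+      // per-algo timings are compared by their median (results[5]) further
+      // below to select the fastest algorithm.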
cudaEventSynchronize(stop_event); + cudaEventElapsedTime(&duration_ms, start_event, stop_event); + + algo_results[algo_id].push_back(duration_ms); + } + std::sort(algo_results[algo_id].begin(), algo_results[algo_id].end()); + } + + cublasLtMatmulHeuristicResult_t result; + float best_time = INFINITY; + for (const auto &heuristic : heuristics) { + cublasLtMatmulAlgo_t algo = heuristic.algo; + int32_t algo_id; + cublasLtMatmulAlgoConfigGetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize); + const auto &results = algo_results[algo_id]; + + if (results.size() > 0 && results[5] < best_time) { + best_time = results[5]; + result = heuristic; + } + } + + return {best_time != INFINITY, result.algo}; +#endif +} + +cublasMMWrapper::MatrixLayout +cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc) { + size_t returnSize; + MatrixLayout m_layout; + + cublasLtMatrixLayoutGetAttribute(Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, + &std::get<0>(m_layout), + sizeof(std::get<0>(m_layout)), &returnSize); + cublasLtMatrixLayoutGetAttribute(Mdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &std::get<1>(m_layout), + sizeof(std::get<1>(m_layout)), &returnSize); + cublasLtMatrixLayoutGetAttribute(Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, + &std::get<2>(m_layout), + sizeof(std::get<2>(m_layout)), &returnSize); + cublasLtMatrixLayoutGetAttribute(Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, + &std::get<3>(m_layout), + sizeof(std::get<3>(m_layout)), &returnSize); + + return m_layout; +} + +cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper( + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, const void *A, cublasLtMatrixLayout_t Adesc, + const void *B, cublasLtMatrixLayout_t Bdesc, const void *beta, + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t *algo, + void *workspace, size_t workspaceSizeInBytes, cudaStream_t stream) { + cache_idx_t cache_idx{computeDesc, + {createMatrixLayout(Adesc), createMatrixLayout(Bdesc), + createMatrixLayout(Cdesc), createMatrixLayout(Ddesc)}}; + + cublasLtMatmulAlgo_t algo_value; + bool found_algo = false; + if (algo == nullptr) { + if (algo_cache.find(cache_idx) == algo_cache.end()) { + auto result = findBestAlgo(lightHandle, computeDesc, alpha, A, Adesc, B, + Bdesc, beta, C, Cdesc, D, Ddesc, stream); + if (result.first) { + algo_cache[cache_idx] = result.second; + algo_value = result.second; + found_algo = true; + } + } else { + algo_value = algo_cache[cache_idx]; + found_algo = true; + } + } + + return cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, + beta, C, Cdesc, D, Ddesc, + found_algo ? &algo_value : algo, workspace, + workspaceSizeInBytes, stream); +} + +void cublasMMWrapper::_Int8Gemm(const int m, const int n, const int k, + const int8_t *A, const int lda, const int8_t *B, + const int ldb, void *C, const int ldc, + const void *alpha, const int mode, + const bool per_column_scaling) { +/* mode: + * - 0: int8 * int8 -> int32 -> int8 + * - 1: int8 * int8 -> int32 -> int32 + */ +#if (CUBLAS_VERSION) < 11601 + FT_CHECK_WITH_INFO(false, "CUBLAS version too low."); +#else + + mu_->lock(); + const auto op_a = CUBLAS_OP_T; + const auto op_b = CUBLAS_OP_N; + const auto dataType = CUDA_R_8I; + const auto resultType = mode == 0 ? CUDA_R_8I : CUDA_R_32I; + const auto computeType = CUBLAS_COMPUTE_32I; + const auto scaleType = mode == 0 ? 
CUDA_R_32F : CUDA_R_32I; + const int batch_count = 1; + const void *beta; + + int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, + getCublasDataType(dataType)); + + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(dataType)); + + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + + // -------------------------------------- + // Create descriptors for the original matrices + check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, dataType, k, m, lda)); + check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, dataType, k, n, ldb)); + check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, resultType, m, n, ldc)); + + check_cuda_error( + cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType)); + + auto pointer_mode = CUBLASLT_POINTER_MODE_HOST; + if (mode == 0) { + pointer_mode = per_column_scaling + ? CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST + : CUBLASLT_POINTER_MODE_DEVICE; + } + check_cuda_error( + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &op_a, sizeof(cublasOperation_t))); + check_cuda_error( + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &op_b, sizeof(cublasOperation_t))); + check_cuda_error( + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, + &op_b, sizeof(cublasOperation_t))); + check_cuda_error(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, + sizeof(pointer_mode))); + + const int32_t int_one = 1; + const int32_t int_zero = 0; + const float float_zero = 0; + if (mode == 0) { + beta = per_column_scaling ? &float_zero : NULL; + } else { + alpha = &int_one; + beta = &int_zero; + } + + cublasLtMatmulAlgo_t algo; + void *workSpace = cublas_workspace_; + int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; + + sync_check_cuda_error(); + auto ret = cublasLtMatmulWrapper(cublaslt_handle_, operationDesc, alpha, A, + Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, + NULL, workSpace, workspaceSize, stream_); + check_cuda_error(ret); + sync_check_cuda_error(); + + cublasLtMatmulDescDestroy(operationDesc); + cublasLtMatrixLayoutDestroy(Adesc); + cublasLtMatrixLayoutDestroy(Bdesc); + cublasLtMatrixLayoutDestroy(Cdesc); + sync_check_cuda_error(); + mu_->unlock(); +#endif +} + +void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, + const int8_t *A, const int lda, const int8_t *B, + const int ldb, int8_t *C, const int ldc, + const float *alpha, + const bool per_column_scaling) { + return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, alpha, 0, + per_column_scaling); +} + +void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, + const int8_t *A, const int lda, const int8_t *B, + const int ldb, int32_t *C, const int ldc) { + return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, (float *)nullptr, 1, false); +} diff --git a/csrc/ftgemm/cublasMMWrapper.h b/csrc/ftgemm/cublasMMWrapper.h new file mode 100644 index 000000000000..69f229246ea9 --- /dev/null +++ b/csrc/ftgemm/cublasMMWrapper.h @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "allocator.h" +#include "cublasAlgoMap.h" +#include "cuda_utils.h" +#include +#include +#include +#include +#include +#include + +#pragma once + +class cublasMMWrapper { +protected: + cublasHandle_t cublas_handle_; + cublasLtHandle_t cublaslt_handle_; +#ifdef SPARSITY_ENABLED + cusparseLtHandle_t cusparselt_handle_; + std::map sp_mat_A_desc_map_; + std::map sp_mat_B_desc_map_; + std::map sp_mat_C_desc_map_; +#endif + + cudaDataType_t Atype_; + cudaDataType_t Btype_; + cudaDataType_t Ctype_; + cudaDataType_t computeType_; + + cudaStream_t stream_; + cublasAlgoMap *cublas_algo_map_; + std::mutex *mu_; + + IAllocator *allocator_ = nullptr; + void *cublas_workspace_ = nullptr; + + friend class cublasINT8MMWrapper; + + void _Int8Gemm(const int m, const int n, const int k, const int8_t *A, + const int lda, const int8_t *B, const int ldb, void *C, + const int ldc, const void *alpha, const int mode, + const bool per_column_scaling); + +public: + cublasMMWrapper(cublasHandle_t cublas_handle_, + cublasLtHandle_t cublaslt_handle_, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, IAllocator *allocator); + +#ifdef SPARSITY_ENABLED + cublasMMWrapper(cublasHandle_t cublas_handle_, + cublasLtHandle_t cublaslt_handle_, + cusparseLtHandle_t cusparselt_handle, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, IAllocator *allocator); +#endif + + ~cublasMMWrapper(); + + cublasMMWrapper(const cublasMMWrapper &wrapper); + + virtual void cublasVersionCheck() { return; }; + cublasStatus_t cublasLtMatmulWrapper( + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, const void *A, cublasLtMatrixLayout_t Adesc, + const void *B, cublasLtMatrixLayout_t Bdesc, const void *beta, + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t *algo, + void *workspace, size_t workspaceSizeInBytes, cudaStream_t stream); + + std::pair + findBestAlgo(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, const void *A, cublasLtMatrixLayout_t Adesc, + const void *B, cublasLtMatrixLayout_t Bdesc, const void *beta, + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, cudaStream_t stream); + + using MatrixLayout = + std::tuple; + using cache_idx_t = + std::tuple>; + std::map algo_cache; + + MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *alpha, const void *A, + cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, + int ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, + cudaDataType_t computeType, cublasGemmAlgo_t algo); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const int lda, + const void *B, const int ldb, void *C, const int ldc); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const int lda, + const void *B, const int 
ldb, void *C, const int ldc, float f_alpha, + float f_beta); + + void Int8Gemm(const int m, const int n, const int k, const int8_t *A, + const int lda, const int8_t *B, const int ldb, int8_t *C, + const int ldc, const float *alpha, + const bool per_column_scaling = false); + + void Int8Gemm(const int m, const int n, const int k, const int8_t *A, + const int lda, const int8_t *B, const int ldb, int32_t *C, + const int ldc); + + void setFP32GemmConfig(); + void setFP16GemmConfig(); +#ifdef ENABLE_BF16 + void setBF16GemmConfig(); +#endif + void setStream(cudaStream_t stream); + + void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, + cudaDataType_t cType, cudaDataType_t computeType); + + CublasDataType getCublasDataType(cudaDataType_t data_type); + +#if (CUDART_VERSION >= 11000) + void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const int lda, + const void *B, const int ldb, const void *bias, void *C, + const int ldc); +#endif + + void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *A, + const int lda, const int64_t strideA, const void *B, + const int ldb, const int64_t strideB, void *C, + const int ldc, const int64_t strideC, + const int batchCount, const float f_alpha = 1.0f, + const float f_beta = 0.0f); + + void stridedBatchedGemm( + cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const float f_alpha, const void *A, + cudaDataType_t AType, const int lda, const int64_t strideA, const void *B, + cudaDataType_t BType, const int ldb, const int64_t strideB, + const float f_beta, void *C, cudaDataType_t CType, const int ldc, + const int64_t strideC, const int batch_count, cudaDataType_t computeType); + + void batchedGemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *const *A, + const int lda, const void *const *B, const int ldb, + void *const *C, const int ldc, const int batch_count); + + bool isFuseBatchGemm(const int batch_count, const int m, const int k, + const int n); + +#ifdef SPARSITY_ENABLED + void SpGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const void *B, void *C); + + size_t getSparseMatrixSize(int m, int k); + void compressMatrix(const void *input, void *output, const int m, + const int k); + + bool isUseSparse(const int batch_count, const int m, const int n, + const int k); +#endif +}; diff --git a/csrc/ftgemm/cuda_utils.cc b/csrc/ftgemm/cuda_utils.cc new file mode 100644 index 000000000000..be8c6d7fb46b --- /dev/null +++ b/csrc/ftgemm/cuda_utils.cc @@ -0,0 +1,381 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cuda_utils.h" +// #include "cuda_fp8_utils.h" + +/* **************************** debug tools ********************************* */ + +template +void print_to_file(const T *result, const int size, const char *file, + cudaStream_t stream, std::ios::openmode open_mode) { + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + printf("[INFO] file: %s with size %d.\n", file, size); + std::ofstream outFile(file, open_mode); + if (outFile) { + T *tmp = new T[size]; + check_cuda_error(cudaMemcpyAsync(tmp, result, sizeof(T) * size, + cudaMemcpyDeviceToHost, stream)); + for (int i = 0; i < size; ++i) { + float val = (float)(tmp[i]); + outFile << val << std::endl; + } + delete[] tmp; + } else { + throw std::runtime_error(std::string("[FT][ERROR] Cannot open file: ") + + file + "\n"); + } + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +} + +template void print_to_file(const float *result, const int size, + const char *file, cudaStream_t stream, + std::ios::openmode open_mode); +template void print_to_file(const half *result, const int size, + const char *file, cudaStream_t stream, + std::ios::openmode open_mode); +// #ifdef ENABLE_BF16 +// template void print_to_file( +// const __nv_bfloat16* result, const int size, const char* file, +// cudaStream_t stream, std::ios::openmode open_mode); +// #endif + +template +void print_abs_mean(const T *buf, uint size, cudaStream_t stream, + std::string name) { + if (buf == nullptr) { + // FT_LOG_WARNING("It is an nullptr, skip!"); + return; + } + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + T *h_tmp = new T[size]; + cudaMemcpyAsync(h_tmp, buf, sizeof(T) * size, cudaMemcpyDeviceToHost, stream); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + double sum = 0.0f; + uint64_t zero_count = 0; + float max_val = -1e10; + bool find_inf = false; + for (uint i = 0; i < size; i++) { + if (std::isinf((float)(h_tmp[i]))) { + find_inf = true; + continue; + } + sum += abs((double)h_tmp[i]); + if ((float)h_tmp[i] == 0.0f) { + zero_count++; + } + max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i])); + } + printf("[INFO][FT] %20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, " + "find inf: %s", + name.c_str(), size, sum / size, sum, max_val, + find_inf ? "true" : "false"); + std::cout << std::endl; + delete[] h_tmp; + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +} + +template void print_abs_mean(const float *buf, uint size, cudaStream_t stream, + std::string name); +template void print_abs_mean(const half *buf, uint size, cudaStream_t stream, + std::string name); +// #ifdef ENABLE_BF16 +// template void print_abs_mean(const __nv_bfloat16* buf, uint size, +// cudaStream_t stream, std::string name); #endif +template void print_abs_mean(const int *buf, uint size, cudaStream_t stream, + std::string name); +template void print_abs_mean(const uint *buf, uint size, cudaStream_t stream, + std::string name); +template void print_abs_mean(const int8_t *buf, uint size, cudaStream_t stream, + std::string name); +// #ifdef ENABLE_FP8 +// template void print_abs_mean(const __nv_fp8_e4m3* buf, uint size, +// cudaStream_t stream, std::string name); #endif + +template void print_to_screen(const T *result, const int size) { + if (result == nullptr) { + // FT_LOG_WARNING("It is an nullptr, skip! 
\n"); + return; + } + T *tmp = reinterpret_cast(malloc(sizeof(T) * size)); + check_cuda_error( + cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); + for (int i = 0; i < size; ++i) { + printf("%d, %f\n", i, static_cast(tmp[i])); + } + free(tmp); +} + +template void print_to_screen(const float *result, const int size); +template void print_to_screen(const half *result, const int size); +// #ifdef ENABLE_BF16 +// template void print_to_screen(const __nv_bfloat16* result, const int size); +// #endif +template void print_to_screen(const int *result, const int size); +template void print_to_screen(const uint *result, const int size); +template void print_to_screen(const bool *result, const int size); +// #ifdef ENABLE_FP8 +// template void print_to_screen(const __nv_fp8_e4m3* result, const int size); +// #endif + +template +void printMatrix(T *ptr, int m, int k, int stride, bool is_device_ptr) { + T *tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. + tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error( + cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%7.3f ", (float)tmp[ii * stride + jj]); + } else { + printf("%7d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +template void printMatrix(float *ptr, int m, int k, int stride, + bool is_device_ptr); +template void printMatrix(half *ptr, int m, int k, int stride, + bool is_device_ptr); +// #ifdef ENABLE_BF16 +// template void printMatrix(__nv_bfloat16* ptr, int m, int k, int stride, bool +// is_device_ptr); #endif + +void printMatrix(unsigned long long *ptr, int m, int k, int stride, + bool is_device_ptr) { + typedef unsigned long long T; + T *tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. + tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error( + cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4llu ", tmp[ii * stride + jj]); + } else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +void printMatrix(int *ptr, int m, int k, int stride, bool is_device_ptr) { + typedef int T; + T *tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. + tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error( + cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4d ", tmp[ii * stride + jj]); + } else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +void printMatrix(size_t *ptr, int m, int k, int stride, bool is_device_ptr) { + typedef size_t T; + T *tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. 
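+    // Copy the full m x stride device buffer to the host first; only the
+    // first k columns of each row are printed afterwards.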
+ tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error( + cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4ld ", tmp[ii * stride + jj]); + } else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +template void check_max_val(const T *result, const int size) { + T *tmp = new T[size]; + cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost); + float max_val = -100000; + for (int i = 0; i < size; i++) { + float val = static_cast(tmp[i]); + if (val > max_val) { + max_val = val; + } + } + delete tmp; + printf("[INFO][CUDA] addr %p max val: %f \n", result, max_val); +} + +template void check_max_val(const float *result, const int size); +template void check_max_val(const half *result, const int size); +// #ifdef ENABLE_BF16 +// template void check_max_val(const __nv_bfloat16* result, const int size); +// #endif + +template void check_abs_mean_val(const T *result, const int size) { + T *tmp = new T[size]; + cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost); + float sum = 0.0f; + for (int i = 0; i < size; i++) { + sum += abs(static_cast(tmp[i])); + } + delete tmp; + printf("[INFO][CUDA] addr %p abs mean val: %f \n", result, sum / size); +} + +template void check_abs_mean_val(const float *result, const int size); +template void check_abs_mean_val(const half *result, const int size); +// #ifdef ENABLE_BF16 +// template void check_abs_mean_val(const __nv_bfloat16* result, const int +// size); #endif + +/* ***************************** common utils ****************************** */ + +cudaError_t getSetDevice(int i_device, int *o_device) { + int current_dev_id = 0; + cudaError_t err = cudaSuccess; + + if (o_device != NULL) { + err = cudaGetDevice(¤t_dev_id); + if (err != cudaSuccess) { + return err; + } + if (current_dev_id == i_device) { + *o_device = i_device; + } else { + err = cudaSetDevice(i_device); + if (err != cudaSuccess) { + return err; + } + *o_device = current_dev_id; + } + } else { + err = cudaSetDevice(i_device); + if (err != cudaSuccess) { + return err; + } + } + + return cudaSuccess; +} + +// FtCudaDataType getModelFileType(std::string ini_file, std::string +// section_name) +// { +// FtCudaDataType model_file_type; +// INIReader reader = INIReader(ini_file); +// if (reader.ParseError() < 0) { +// FT_LOG_WARNING("Can't load %s. Use FP32 as default", +// ini_file.c_str()); model_file_type = FtCudaDataType::FP32; +// } +// else { +// std::string weight_data_type_str = +// std::string(reader.Get(section_name, "weight_data_type")); if +// (weight_data_type_str.find("fp32") != std::string::npos) { +// model_file_type = FtCudaDataType::FP32; +// } +// else if (weight_data_type_str.find("fp16") != std::string::npos) { +// model_file_type = FtCudaDataType::FP16; +// } +// else if (weight_data_type_str.find("bf16") != std::string::npos) { +// model_file_type = FtCudaDataType::BF16; +// } +// else { +// FT_LOG_WARNING("Invalid type %s. 
Use FP32 as default", +// weight_data_type_str.c_str()); model_file_type = +// FtCudaDataType::FP32; +// } +// } +// return model_file_type; +// } + +/* ************************** end of common utils ************************** */ diff --git a/csrc/ftgemm/cuda_utils.h b/csrc/ftgemm/cuda_utils.h new file mode 100644 index 000000000000..33713cf11757 --- /dev/null +++ b/csrc/ftgemm/cuda_utils.h @@ -0,0 +1,461 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// #include "3rdparty/INIReader.h" +// #include "cuda_bf16_wrapper.h" +// #include "src/fastertransformer/utils/logger.h" + +#include +#include +#include +#include +#include +#include +#include +#ifdef SPARSITY_ENABLED +#include +#endif + +#define MAX_CONFIG_NUM 20 +#define COL32_ 32 +// workspace for cublas gemm : 32MB +#define CUBLAS_WORKSPACE_SIZE 33554432 + +typedef struct __align__(4) { + half x, y, z, w; +} +half4; + +/* **************************** type definition ***************************** */ + +enum CublasDataType { + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, + BFLOAT16_DATATYPE = 2, + INT8_DATATYPE = 3, + FP8_DATATYPE = 4 +}; + +enum FtCudaDataType { FP32 = 0, FP16 = 1, BF16 = 2, INT8 = 3, FP8 = 4 }; + +enum class OperationType { FP32, FP16, BF16, INT8, FP8 }; + +/* **************************** debug tools ********************************* */ +static const char *_cudaGetErrorEnum(cudaError_t error) { + return cudaGetErrorString(error); +} + +static const char *_cudaGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; +} + +template +void check(T result, char const *const func, const char *const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) +#define check_cuda_error_2(val, file, line) check((val), #val, file, line) + +inline void syncAndCheck(const char *const file, int const line) { + // When FT_DEBUG_LEVEL=DEBUG, must check error + static char *level_name = std::getenv("FT_DEBUG_LEVEL"); + if (level_name != 
nullptr) { + static std::string level = std::string(level_name); + if (level == "DEBUG") { + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + if (result) { + throw std::runtime_error( + std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } + // FT_LOG_DEBUG(fmtstr("run syncAndCheck at %s:%d", file, line)); + } + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } +#endif +} + +#define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__) + +#define checkCUDNN(expression) \ + { \ + cudnnStatus_t status = (expression); \ + if (status != CUDNN_STATUS_SUCCESS) { \ + std::cerr << "Error on file " << __FILE__ << " line " << __LINE__ \ + << ": " << cudnnGetErrorString(status) << std::endl; \ + std::exit(EXIT_FAILURE); \ + } \ + } + +template +void print_to_file(const T *result, const int size, const char *file, + cudaStream_t stream = 0, + std::ios::openmode open_mode = std::ios::out); + +template +void print_abs_mean(const T *buf, uint size, cudaStream_t stream, + std::string name = ""); + +template void print_to_screen(const T *result, const int size); + +template +void printMatrix(T *ptr, int m, int k, int stride, bool is_device_ptr); + +void printMatrix(unsigned long long *ptr, int m, int k, int stride, + bool is_device_ptr); +void printMatrix(int *ptr, int m, int k, int stride, bool is_device_ptr); +void printMatrix(size_t *ptr, int m, int k, int stride, bool is_device_ptr); + +template void check_max_val(const T *result, const int size); + +template void check_abs_mean_val(const T *result, const int size); + +#define PRINT_FUNC_NAME_() \ + do { \ + std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \ + } while (0) + +[[noreturn]] inline void throwRuntimeError(const char *const file, + int const line, + std::string const &info = "") { + throw std::runtime_error(std::string("[FT][ERROR] ") + info + + " Assertion fail: " + file + ":" + + std::to_string(line) + " \n"); +} + +inline void myAssert(bool result, const char *const file, int const line, + std::string const &info = "") { + if (!result) { + throwRuntimeError(file, line, info); + } +} + +#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) +#define FT_CHECK_WITH_INFO(val, info) \ + do { \ + bool is_valid_val = (val); \ + if (!is_valid_val) { \ + fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ + } \ + } while (0) + +#define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info) + +#ifdef SPARSITY_ENABLED +#define CHECK_CUSPARSE(func) \ + { \ + cusparseStatus_t status = (func); \ + if (status != CUSPARSE_STATUS_SUCCESS) { \ + throw std::runtime_error( \ + std::string("[FT][ERROR] CUSPARSE API failed at line ") + \ + std::to_string(__LINE__) + " in file " + __FILE__ + ": " + \ + cusparseGetErrorString(status) + " " + std::to_string(status)); \ + } \ + } +#endif + +/*************Time Handling**************/ +class CudaTimer { +private: + cudaEvent_t event_start_; + cudaEvent_t event_stop_; + cudaStream_t stream_; + +public: + explicit CudaTimer(cudaStream_t stream = 0) { stream_ = stream; } + void start() { + check_cuda_error(cudaEventCreate(&event_start_)); + check_cuda_error(cudaEventCreate(&event_stop_)); + check_cuda_error(cudaEventRecord(event_start_, stream_)); + } + 
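+  // Descriptive note (added): stop() records the stop event, waits for it to
+  // complete, returns the elapsed time in milliseconds, and destroys both events.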
float stop() { + float time; + check_cuda_error(cudaEventRecord(event_stop_, stream_)); + check_cuda_error(cudaEventSynchronize(event_stop_)); + check_cuda_error(cudaEventElapsedTime(&time, event_start_, event_stop_)); + check_cuda_error(cudaEventDestroy(event_start_)); + check_cuda_error(cudaEventDestroy(event_stop_)); + return time; + } + ~CudaTimer() {} +}; + +static double diffTime(timeval start, timeval end) { + return (end.tv_sec - start.tv_sec) * 1000 + + (end.tv_usec - start.tv_usec) * 0.001; +} + +/* ***************************** common utils ****************************** */ + +inline void print_mem_usage(std::string time = "after allocation") { + size_t free_bytes, total_bytes; + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); + float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; + float total = static_cast(total_bytes) / 1024.0 / 1024.0 / 1024.0; + float used = total - free; + printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n", + time.c_str(), free, total, used); +} + +inline int getSMVersion() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + check_cuda_error(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline int getMaxSharedMemoryPerBlock() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int max_shared_memory_size = 0; + check_cuda_error(cudaDeviceGetAttribute( + &max_shared_memory_size, cudaDevAttrMaxSharedMemoryPerBlock, device)); + return max_shared_memory_size; +} + +inline std::string getDeviceName() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + cudaDeviceProp props; + check_cuda_error(cudaGetDeviceProperties(&props, device)); + return std::string(props.name); +} + +inline int div_up(int a, int n) { return (a + n - 1) / n; } + +cudaError_t getSetDevice(int i_device, int *o_device = NULL); + +inline int getDevice() { + int current_dev_id = 0; + check_cuda_error(cudaGetDevice(¤t_dev_id)); + return current_dev_id; +} + +inline int getDeviceCount() { + int count = 0; + check_cuda_error(cudaGetDeviceCount(&count)); + return count; +} + +template CublasDataType getCublasDataType() { + if (std::is_same::value) { + return HALF_DATATYPE; + } + // #ifdef ENABLE_BF16 + // else if (std::is_same::value) { + // return BFLOAT16_DATATYPE; + // } + // #endif + else if (std::is_same::value) { + return FLOAT_DATATYPE; + } else { + FT_CHECK(false); + return FLOAT_DATATYPE; + } +} + +template cudaDataType_t getCudaDataType() { + if (std::is_same::value) { + return CUDA_R_16F; + } + // #ifdef ENABLE_BF16 + // else if (std::is_same::value) { + // return CUDA_R_16BF; + // } + // #endif + else if (std::is_same::value) { + return CUDA_R_32F; + } else { + FT_CHECK(false); + return CUDA_R_32F; + } +} + +template struct getTypeFromCudaDataType { + using Type = float; +}; + +template <> struct getTypeFromCudaDataType { + using Type = half; +}; + +// #ifdef ENABLE_BF16 +// template<> +// struct getTypeFromCudaDataType { +// using Type = __nv_bfloat16; +// }; +// #endif + +// FtCudaDataType getModelFileType(std::string ini_file, std::string +// section_name); + +// clang-format off +template struct packed_type; +template <> struct packed_type { using type = float; }; // we don't need to pack float by default +template <> struct packed_type { using type = half2; }; + +// #ifdef ENABLE_BF16 +// 
template<> +// struct packed_type<__nv_bfloat16> { +// using type = __nv_bfloat162; +// }; +// #endif + +template struct num_elems; +template <> struct num_elems { static constexpr int value = 1; }; +template <> struct num_elems { static constexpr int value = 2; }; +template <> struct num_elems { static constexpr int value = 4; }; +template <> struct num_elems { static constexpr int value = 1; }; +template <> struct num_elems { static constexpr int value = 2; }; +// #ifdef ENABLE_BF16 +// template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; +// template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; +// #endif + +template struct packed_as; +template struct packed_as { using type = T; }; +template<> struct packed_as { using type = half2; }; +template<> struct packed_as { using type = float2; }; +template<> struct packed_as { using type = int16_t; }; +template<> struct packed_as { using type = int2; }; +template<> struct packed_as { using type = half; }; +// #ifdef ENABLE_BF16 +// template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; +// template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; +// #endif + +inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } +inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } +// clang-format on + +template +void compareTwoTensor(const T1 *pred, const T2 *ref, const int size, + const int print_size = 0, + const std::string filename = "") { + T1 *h_pred = new T1[size]; + T2 *h_ref = new T2[size]; + check_cuda_error( + cudaMemcpy(h_pred, pred, size * sizeof(T1), cudaMemcpyDeviceToHost)); + check_cuda_error( + cudaMemcpy(h_ref, ref, size * sizeof(T2), cudaMemcpyDeviceToHost)); + + FILE *fd = nullptr; + if (filename != "") { + fd = fopen(filename.c_str(), "w"); + fprintf(fd, "| %10s | %10s | %10s | %10s | \n", "pred", "ref", "abs_diff", + "rel_diff(%)"); + } + + if (print_size > 0) { + // FT_LOG_INFO(" id | pred | ref |abs diff | rel diff (%) |"); + } + float mean_abs_diff = 0.0f; + float mean_rel_diff = 0.0f; + int count = 0; + for (int i = 0; i < size; i++) { + if (i < print_size) { + // FT_LOG_INFO("%4d | % 6.4f | % 6.4f | % 6.4f | % 7.4f |", + // i, + // (float)h_pred[i], + // (float)h_ref[i], + // abs((float)h_pred[i] - (float)h_ref[i]), + // abs((float)h_pred[i] - (float)h_ref[i]) / + // (abs((float)h_ref[i]) + 1e-6f) * 100.f); + } + if ((float)h_pred[i] == 0) { + continue; + } + count += 1; + mean_abs_diff += abs((float)h_pred[i] - (float)h_ref[i]); + mean_rel_diff += abs((float)h_pred[i] - (float)h_ref[i]) / + (abs((float)h_ref[i]) + 1e-6f) * 100.f; + + if (fd != nullptr) { + fprintf(fd, "| %10.5f | %10.5f | %10.5f | %11.5f |\n", (float)h_pred[i], + (float)h_ref[i], abs((float)h_pred[i] - (float)h_ref[i]), + abs((float)h_pred[i] - (float)h_ref[i]) / + (abs((float)h_ref[i]) + 1e-6f) * 100.f); + } + } + mean_abs_diff = mean_abs_diff / (float)count; + mean_rel_diff = mean_rel_diff / (float)count; + // FT_LOG_INFO("mean_abs_diff: % 6.4f, mean_rel_diff: % 6.4f (%%)", + // mean_abs_diff, mean_rel_diff); + + if (fd != nullptr) { + fprintf(fd, "mean_abs_diff: % 6.4f, mean_rel_diff: % 6.4f (%%)", + mean_abs_diff, mean_rel_diff); + fclose(fd); + } + delete[] h_pred; + delete[] h_ref; +} + +/* ************************** end of common utils ************************** */ diff --git a/csrc/ftgemm/int8_utils.cuh b/csrc/ftgemm/int8_utils.cuh new file mode 100644 index 
000000000000..55b9c4f24de2 --- /dev/null +++ b/csrc/ftgemm/int8_utils.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} + +static inline __device__ uint32_t float4_to_char4(float x, + float y, + float z, + float w) { + uint32_t dst; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720 + uint32_t a; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); + uint32_t b; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y)); + uint32_t c; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z)); + uint32_t d; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c)); + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a)); +#else + char4 tmp; + tmp.x = x; + tmp.y = y; + tmp.z = z; + tmp.w = w; + dst = reinterpret_cast(tmp); +#endif + return dst; +} \ No newline at end of file diff --git a/csrc/ftgemm/transform_layout.cu b/csrc/ftgemm/transform_layout.cu new file mode 100644 index 000000000000..7bc6cfa6b95e --- /dev/null +++ b/csrc/ftgemm/transform_layout.cu @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// #include "src/fastertransformer/kernels/layout_transformer_int8_kernels.h" +#include "transform_layout.h" +#include + +// transform row-major to COL32 +// input matrix is (m, n) row-major +// output matrix is (m, n) COL32 +// n should be a multiple of 32 +// grid((n+31)/32, (m+31)/32) +// block(8, 32) +__global__ void rowMajorToCOL32_kernel(char4 *dst, const char4 *src, const int m, const int n) { + + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + + // COL32_col = n_id >> 5 ; COL32_row = (m_id << 5) + (n_id & 31); + // COL32_idx = (COL32_col << 5) * m + COL32_row = (n_id & 0xffffffe0)*m + + // (m_id << 5) + (n_id & 31) + dst[((n_id & 0xffffffe0) * m + (m_id << 5) + (n_id & 31)) >> 2] = + __ldg(src + ((m_id * n + n_id) >> 2)); + } +} + +__global__ void col32ToRowMajor_kernel(char4 *dst, const char4 *src, + const int m, const int n) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + int idx = m_id * n + n_id; + dst[(((idx >> 5) % m) * n + (((idx >> 5) / m) << 5) + (idx & 31)) >> 2] = + __ldg(src + (idx >> 2)); + } +} + +void invokeRowMajorToCOL32(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream) { + assert(n % 32 == 0); + rowMajorToCOL32_kernel<<>>((char4 *)dst, (const char4 *)src, m, n); +} + +void invokeCOL32ToRowMajor(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream) { + assert(n % 32 == 0); + col32ToRowMajor_kernel<<>>((char4 *)dst, (const char4 *)src, m, n); +} + +__global__ void rowMajorToAmpere_kernel(char4 *dst, const char4 *src, + const int m, const int n) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + int new_col = n_id >> 5; + int row_in_tile = m_id & 31; + int col_in_tile = n_id & 31; + int new_row = // CUBLASLT_ORDER_COL32_2R_4R4 + (((m_id >> 5) << 10) + + //(((row%8)/2*4+row/8)*2+row%2)*32+col + (((((((row_in_tile & 7) >> 1) << 2) + (row_in_tile >> 3)) << 1) + + (row_in_tile & 1)) + << 5) + + col_in_tile); + int idx = m_id * n + n_id; + dst[(new_col * (m << 5) + new_row) >> 2] = __ldg(src + (idx >> 2)); + } +} + +void invokeRowMajorToAmpere(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream) { + assert(n % 32 == 0); + rowMajorToAmpere_kernel<<>>((char4 *)dst, (const char4 *)src, m, n); +} + +__global__ void rowMajorToTuring_kernel(char4 *dst, const char4 *src, + const int m, const int n) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + int new_col = n_id >> 5; + int new_row = // CUBLASLT_ORDER_COL4_4R2_8C + ////m_id/8 is the number of tile of (8 rows 32 columns) -- + /// column-major /m_id%2 is even row, otherwise odd row + ////n_id%COL32_/8 is the number tile of (8 rows 8 columns) + (((((m_id >> 3) << 3) + ((m_id & 1) << 2) + ((n_id & 31) >> 3)) << 5) + + ////n_id%8 >= 4 is the right half of (8 rows 8 columns) tile + ////(m_id%8/2) is (the row id of alternating 4 rows) - 1 + (((((n_id & 7) >= 4) ? 
4 : 0) + ((m_id & 7) >> 1)) << 2) + + ////n_id%4 is the id of 4 cols + (n_id & 3)); + int idx = m_id * n + n_id; + dst[(new_col * (m << 5) + new_row) >> 2] = __ldg(src + (idx >> 2)); + } +} + +void invokeRowMajorToTuring(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream) { + assert(n % 32 == 0); + rowMajorToTuring_kernel<<>>((char4 *)dst, (const char4 *)src, m, n); +} diff --git a/csrc/ftgemm/transform_layout.h b/csrc/ftgemm/transform_layout.h new file mode 100644 index 000000000000..a695923b08f8 --- /dev/null +++ b/csrc/ftgemm/transform_layout.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "int8_utils.cuh" +#include +#include +#include + +void invokeRowMajorToCOL32(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream); +void invokeCOL32ToRowMajor(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream); +void invokeRowMajorToAmpere(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream); +void invokeRowMajorToTuring(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream); \ No newline at end of file diff --git a/setup.py b/setup.py index 8a461c8cae63..1707e98fcc89 100644 --- a/setup.py +++ b/setup.py @@ -109,6 +109,24 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ) ext_modules.append(i8gemm_extension) +# FTGEMM(cutlass required) +ftgemm_extension = CUDAExtension( + name='vllm.ftgemm', + sources=[ + 'csrc/ftgemm/bindings.cpp', + 'csrc/ftgemm/cublasAlgoMap.cc', + 'csrc/ftgemm/cublasINT8MMWrapper.cc', + 'csrc/ftgemm/cublasMMWrapper.cc', + 'csrc/ftgemm/cuda_utils.cc', + 'csrc/ftgemm/transform_layout.cu' + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, +) +ext_modules.append(ftgemm_extension) + # Cache operations. 
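# (vllm.cache_ops below builds the CUDA kernels that copy and reshape blocks of
# the paged KV cache; it is independent of the GEMM extensions added above.)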
cache_extension = CUDAExtension( name="vllm.cache_ops", diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py index 9a452d591e31..fe0ca8deeadc 100644 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -7,6 +7,9 @@ fake_quantize_activation_per_tensor_absmax, fake_quantize_activation_per_token_absmax, ) +from vllm.ftgemm import FTGEMM +ftgemm = FTGEMM() + class W8A8B8O8Linear(torch.nn.Module): # For qkv_proj @@ -226,3 +229,122 @@ def from_float(module: torch.nn.Linear, input_scale): int8_module.a = alpha int8_module.inscale = torch.tensor(input_scale) return int8_module + +# use ftgemm a8w8o8 +class W8A8BFP32OFP32LinearWithSFactorCublas(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.float32, requires_grad=False)) + self.register_buffer('a', torch.tensor(alpha)) + self.register_buffer('inscale', torch.tensor(inscale)) + + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.a = self.a.cpu() + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + self.inscale = self.inscale.to(*args, **kwargs) + self.inscale = self.inscale.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + # quant activation + x = (x / self.inscale).clamp(-128, 127).to(torch.int8) + # self.bias = self.bias.to(torch.float32) + y = torch.empty((x.shape[0], self.out_features), dtype=torch.int8, device=x.device) + ftgemm.linear_a8_w8_o8_(x, self.weight, y, self.a.item()) + # y = i8gemm.linear_a8_w8_bfp32_ofp32( + # x, self.weight, self.bias, self.a.item(), 1) + # int8 to float32 + y = y.to(torch.float32) + y = y.view(*x_shape[:-1], -1) + return y, None + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32LinearWithSFactorCublas( + module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + mockbias = torch.zeros((1, module.out_features), dtype=torch.float, requires_grad=False) + int8_module.bias = mockbias.to(torch.float32) + int8_module.a = alpha + int8_module.inscale = torch.tensor(input_scale) + return int8_module + +# use ftgemm a8w8o8 +class W8A8BFP32OFP32LinearCublas(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.float32, requires_grad=False)) + self.register_buffer('a', 
torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.a = self.a.cpu() + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + # self.bias = self.bias.to(torch.float32) + y = torch.empty((x.shape[0], self.out_features), dtype=torch.int8, device=x.device) + ftgemm.linear_a8_w8_o8_(x, self.weight, y, self.a.item()) + # y = i8gemm.linear_a8_w8_bfp32_ofp32( + # x, self.weight, self.bias, self.a.item(), 1) + # int8 to float32 + y = y.to(torch.float32) + y = y.view(*x_shape[:-1], -1) + return y, None + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32LinearCublas( + module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + mockbias = torch.zeros((1, module.out_features), dtype=torch.float, requires_grad=False) + int8_module.bias = mockbias.to(torch.float32) + int8_module.a = alpha + int8_module.input_scale = input_scale + int8_module.weight_scale = weight_scale + return int8_module \ No newline at end of file From 06cfa3f8e29580ebdbf39aee1b13aec7a1e2ce70 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Wed, 27 Sep 2023 16:27:14 +0800 Subject: [PATCH 21/52] fix cublas linear --- .../layers/int8_linear/w8a8linear.py | 54 ++++--------------- vllm/model_executor/models/llama.py | 14 ++--- 2 files changed, 18 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py index fe0ca8deeadc..32434091ee7b 100644 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -231,7 +231,7 @@ def from_float(module: torch.nn.Linear, input_scale): return int8_module # use ftgemm a8w8o8 -class W8A8BFP32OFP32LinearWithSFactorCublas(torch.nn.Module): +class W8A8OFP32LinearWithSFactorCublas(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): super().__init__() @@ -269,32 +269,16 @@ def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) # quant activation - x = (x / self.inscale).clamp(-128, 127).to(torch.int8) - # self.bias = self.bias.to(torch.float32) - y = torch.empty((x.shape[0], self.out_features), dtype=torch.int8, device=x.device) - ftgemm.linear_a8_w8_o8_(x, self.weight, y, self.a.item()) - # y = i8gemm.linear_a8_w8_bfp32_ofp32( - # x, self.weight, self.bias, self.a.item(), 1) - # int8 to float32 - y = y.to(torch.float32) + x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) + y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) + ftgemm.linear_a8_w8_o32_(x, self.weight, y) + y = y * self.a.item() y = y.view(*x_shape[:-1], -1) return y, None - @staticmethod - def from_float(module: torch.nn.Linear, input_scale): - int8_module = W8A8BFP32OFP32LinearWithSFactorCublas( - module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - alpha = input_scale * weight_scale - int8_module.weight = int8_weight - mockbias = torch.zeros((1, 
module.out_features), dtype=torch.float, requires_grad=False) - int8_module.bias = mockbias.to(torch.float32) - int8_module.a = alpha - int8_module.inscale = torch.tensor(input_scale) - return int8_module # use ftgemm a8w8o8 -class W8A8BFP32OFP32LinearCublas(torch.nn.Module): +class W8A8O32LinearCublas(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): super().__init__() @@ -325,26 +309,8 @@ def to(self, *args, **kwargs): def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) - # self.bias = self.bias.to(torch.float32) - y = torch.empty((x.shape[0], self.out_features), dtype=torch.int8, device=x.device) - ftgemm.linear_a8_w8_o8_(x, self.weight, y, self.a.item()) - # y = i8gemm.linear_a8_w8_bfp32_ofp32( - # x, self.weight, self.bias, self.a.item(), 1) - # int8 to float32 - y = y.to(torch.float32) + y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) + ftgemm.linear_a8_w8_o32_(x, self.weight, y) + y = y * self.a.item() y = y.view(*x_shape[:-1], -1) - return y, None - - @staticmethod - def from_float(module: torch.nn.Linear, input_scale): - int8_module = W8A8BFP32OFP32LinearCublas( - module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - alpha = input_scale * weight_scale - int8_module.weight = int8_weight - mockbias = torch.zeros((1, module.out_features), dtype=torch.float, requires_grad=False) - int8_module.bias = mockbias.to(torch.float32) - int8_module.a = alpha - int8_module.input_scale = input_scale - int8_module.weight_scale = weight_scale - return int8_module \ No newline at end of file + return y, None \ No newline at end of file diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5e6a3b3daa66..1d529b803bff 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,7 +37,9 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantized_linear import ParallelLinear -from vllm.model_executor.layers.int8_linear.w8a8linear import W8A8BFP32OFP32LinearWithSFactor, W8A8BFP32OFP32Linear +from vllm.model_executor.layers.int8_linear.w8a8linear import ( + W8A8OFP32LinearWithSFactorCublas, + W8A8O32LinearCublas) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( @@ -64,9 +66,9 @@ def __init__( self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" if self.use_int8: - self.gate_up_proj = W8A8BFP32OFP32Linear(hidden_size, - 2 * intermediate_size) - self.down_proj = W8A8BFP32OFP32LinearWithSFactor(intermediate_size, + self.gate_up_proj = W8A8O32LinearCublas(hidden_size, + 2 * intermediate_size) + self.down_proj = W8A8OFP32LinearWithSFactorCublas(intermediate_size, hidden_size) else: self.gate_up_proj = ParallelLinear.column(hidden_size, @@ -125,10 +127,10 @@ def __init__( self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" if self.use_int8: - self.qkv_proj = W8A8BFP32OFP32Linear( + self.qkv_proj = W8A8O32LinearCublas( hidden_size, (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim) - self.o_proj = W8A8BFP32OFP32LinearWithSFactor( + self.o_proj = W8A8OFP32LinearWithSFactorCublas( self.total_num_heads * self.head_dim, 
hidden_size) else: From 97b5c6945e6a6a21fa7979b415a3dcd8e6ee023e Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Wed, 27 Sep 2023 16:36:57 +0800 Subject: [PATCH 22/52] clean cublass gemm code --- csrc/ftgemm/CMakeLists.txt | 65 ------------------------------- csrc/ftgemm/allocator.h | 2 - csrc/ftgemm/cublasINT8MMWrapper.h | 1 + csrc/ftgemm/cuda_utils.cc | 57 +-------------------------- csrc/ftgemm/cuda_utils.h | 29 +------------- 5 files changed, 4 insertions(+), 150 deletions(-) delete mode 100644 csrc/ftgemm/CMakeLists.txt diff --git a/csrc/ftgemm/CMakeLists.txt b/csrc/ftgemm/CMakeLists.txt deleted file mode 100644 index 0425e4acb0e3..000000000000 --- a/csrc/ftgemm/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -cmake_minimum_required(VERSION 3.8) -set(CMAKE_CXX_STANDARD 14) -set(CMAKE_CUDA_STANDARD 14) - -find_package(CUDA REQUIRED) -find_package(Python REQUIRED) -set(Torch_DIR "/usr/local/lib/python3.9/site-packages/torch/share/cmake/Torch/") -find_package(Torch REQUIRED) -set(pybind11_DIR "/usr/local/lib/python3.9/site-packages/pybind11/share/cmake/pybind11") -find_package(pybind11 REQUIRED) - - -include_directories(${CUDA_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS}) -set(CUDA_LIBRARIES "/usr/local/cuda/lib64") -set(TORCH_LIBRARIES "/usr/local/lib/python3.9/site-packages/torch/lib") -link_directories(${CUDA_LIBRARIES} ${TORCH_LIBRARIES}) - -add_library(cuda_utils STATIC cuda_utils.cc) -set_property(TARGET cuda_utils PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET cuda_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(cuda_utils PUBLIC -lcudart) - -add_library(cublasAlgoMap STATIC cublasAlgoMap.cc) -set_property(TARGET cublasAlgoMap PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET cublasAlgoMap PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(cublasAlgoMap PUBLIC -lcublas -lcudart -lcurand cuda_utils) - -add_library(cublasMMWrapper STATIC cublasMMWrapper.cc) -set_property(TARGET cublasMMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET cublasMMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(cublasMMWrapper PUBLIC -lcublas -lcudart -lcurand cublasAlgoMap cuda_utils) - -add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) -set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(cublasINT8MMWrapper PUBLIC -lcublasLt -lcudart -lcurand -lcublas cublasAlgoMap cublasMMWrapper cuda_utils) - -add_library(transformLayout STATIC transform_layout.cu) -set_property(TARGET transformLayout PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET transformLayout PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(transformLayout PUBLIC -lcudart -lcurand -lcublas) - - -find_package(Python COMPONENTS Interpreter Development REQUIRED) 
-pybind11_add_module(int8_gemm MODULE bindings.cpp) -target_link_libraries(int8_gemm PUBLIC -lpython3.9 -ltorch -ltorch_python -lcudart cublasINT8MMWrapper cublasAlgoMap transformLayout) - - - - - diff --git a/csrc/ftgemm/allocator.h b/csrc/ftgemm/allocator.h index 82e7da567b9b..7ac1345441f1 100644 --- a/csrc/ftgemm/allocator.h +++ b/csrc/ftgemm/allocator.h @@ -41,8 +41,6 @@ #include #endif -// #include "src/fastertransformer/utils/logger.h" - #if defined(CUDART_VERSION) && CUDART_VERSION < 11020 #define CUDA_MEMORY_POOL_DISABLED #endif diff --git a/csrc/ftgemm/cublasINT8MMWrapper.h b/csrc/ftgemm/cublasINT8MMWrapper.h index cbd2879a36b0..ab3de04692af 100644 --- a/csrc/ftgemm/cublasINT8MMWrapper.h +++ b/csrc/ftgemm/cublasINT8MMWrapper.h @@ -60,6 +60,7 @@ class cublasINT8MMWrapper : public cublasMMWrapper { int64_t strideb, int64_t stridec, const float alpha, const int8_t *ATransform, const int8_t *kernel); + // w8a8sfp32ofp32 void Gemm_f(float *res, int batchCount, int m, int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, const int8_t *ATransform, const int8_t *kernel); diff --git a/csrc/ftgemm/cuda_utils.cc b/csrc/ftgemm/cuda_utils.cc index be8c6d7fb46b..dc0bc509fd81 100644 --- a/csrc/ftgemm/cuda_utils.cc +++ b/csrc/ftgemm/cuda_utils.cc @@ -49,11 +49,6 @@ template void print_to_file(const float *result, const int size, template void print_to_file(const half *result, const int size, const char *file, cudaStream_t stream, std::ios::openmode open_mode); -// #ifdef ENABLE_BF16 -// template void print_to_file( -// const __nv_bfloat16* result, const int size, const char* file, -// cudaStream_t stream, std::ios::openmode open_mode); -// #endif template void print_abs_mean(const T *buf, uint size, cudaStream_t stream, @@ -97,18 +92,12 @@ template void print_abs_mean(const float *buf, uint size, cudaStream_t stream, std::string name); template void print_abs_mean(const half *buf, uint size, cudaStream_t stream, std::string name); -// #ifdef ENABLE_BF16 -// template void print_abs_mean(const __nv_bfloat16* buf, uint size, -// cudaStream_t stream, std::string name); #endif template void print_abs_mean(const int *buf, uint size, cudaStream_t stream, std::string name); template void print_abs_mean(const uint *buf, uint size, cudaStream_t stream, std::string name); template void print_abs_mean(const int8_t *buf, uint size, cudaStream_t stream, std::string name); -// #ifdef ENABLE_FP8 -// template void print_abs_mean(const __nv_fp8_e4m3* buf, uint size, -// cudaStream_t stream, std::string name); #endif template void print_to_screen(const T *result, const int size) { if (result == nullptr) { @@ -126,15 +115,10 @@ template void print_to_screen(const T *result, const int size) { template void print_to_screen(const float *result, const int size); template void print_to_screen(const half *result, const int size); -// #ifdef ENABLE_BF16 -// template void print_to_screen(const __nv_bfloat16* result, const int size); -// #endif template void print_to_screen(const int *result, const int size); template void print_to_screen(const uint *result, const int size); template void print_to_screen(const bool *result, const int size); -// #ifdef ENABLE_FP8 -// template void print_to_screen(const __nv_fp8_e4m3* result, const int size); -// #endif + template void printMatrix(T *ptr, int m, int k, int stride, bool is_device_ptr) { @@ -174,9 +158,6 @@ template void printMatrix(float *ptr, int m, int k, int stride, bool is_device_ptr); template void printMatrix(half *ptr, int m, int k, int 
stride, bool is_device_ptr); -// #ifdef ENABLE_BF16 -// template void printMatrix(__nv_bfloat16* ptr, int m, int k, int stride, bool -// is_device_ptr); #endif void printMatrix(unsigned long long *ptr, int m, int k, int stride, bool is_device_ptr) { @@ -297,9 +278,6 @@ template void check_max_val(const T *result, const int size) { template void check_max_val(const float *result, const int size); template void check_max_val(const half *result, const int size); -// #ifdef ENABLE_BF16 -// template void check_max_val(const __nv_bfloat16* result, const int size); -// #endif template void check_abs_mean_val(const T *result, const int size) { T *tmp = new T[size]; @@ -314,9 +292,6 @@ template void check_abs_mean_val(const T *result, const int size) { template void check_abs_mean_val(const float *result, const int size); template void check_abs_mean_val(const half *result, const int size); -// #ifdef ENABLE_BF16 -// template void check_abs_mean_val(const __nv_bfloat16* result, const int -// size); #endif /* ***************************** common utils ****************************** */ @@ -348,34 +323,4 @@ cudaError_t getSetDevice(int i_device, int *o_device) { return cudaSuccess; } -// FtCudaDataType getModelFileType(std::string ini_file, std::string -// section_name) -// { -// FtCudaDataType model_file_type; -// INIReader reader = INIReader(ini_file); -// if (reader.ParseError() < 0) { -// FT_LOG_WARNING("Can't load %s. Use FP32 as default", -// ini_file.c_str()); model_file_type = FtCudaDataType::FP32; -// } -// else { -// std::string weight_data_type_str = -// std::string(reader.Get(section_name, "weight_data_type")); if -// (weight_data_type_str.find("fp32") != std::string::npos) { -// model_file_type = FtCudaDataType::FP32; -// } -// else if (weight_data_type_str.find("fp16") != std::string::npos) { -// model_file_type = FtCudaDataType::FP16; -// } -// else if (weight_data_type_str.find("bf16") != std::string::npos) { -// model_file_type = FtCudaDataType::BF16; -// } -// else { -// FT_LOG_WARNING("Invalid type %s. 
Use FP32 as default", -// weight_data_type_str.c_str()); model_file_type = -// FtCudaDataType::FP32; -// } -// } -// return model_file_type; -// } - /* ************************** end of common utils ************************** */ diff --git a/csrc/ftgemm/cuda_utils.h b/csrc/ftgemm/cuda_utils.h index 33713cf11757..e85d38bf53d4 100644 --- a/csrc/ftgemm/cuda_utils.h +++ b/csrc/ftgemm/cuda_utils.h @@ -16,10 +16,6 @@ #pragma once -// #include "3rdparty/INIReader.h" -// #include "cuda_bf16_wrapper.h" -// #include "src/fastertransformer/utils/logger.h" - #include #include #include @@ -346,27 +342,12 @@ template <> struct getTypeFromCudaDataType { using Type = half; }; -// #ifdef ENABLE_BF16 -// template<> -// struct getTypeFromCudaDataType { -// using Type = __nv_bfloat16; -// }; -// #endif - -// FtCudaDataType getModelFileType(std::string ini_file, std::string -// section_name); // clang-format off template struct packed_type; template <> struct packed_type { using type = float; }; // we don't need to pack float by default template <> struct packed_type { using type = half2; }; -// #ifdef ENABLE_BF16 -// template<> -// struct packed_type<__nv_bfloat16> { -// using type = __nv_bfloat162; -// }; -// #endif template struct num_elems; template <> struct num_elems { static constexpr int value = 1; }; @@ -374,10 +355,7 @@ template <> struct num_elems { static constexpr int va template <> struct num_elems { static constexpr int value = 4; }; template <> struct num_elems { static constexpr int value = 1; }; template <> struct num_elems { static constexpr int value = 2; }; -// #ifdef ENABLE_BF16 -// template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; -// template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; -// #endif + template struct packed_as; template struct packed_as { using type = T; }; @@ -386,10 +364,7 @@ template<> struct packed_as { using type = template<> struct packed_as { using type = int16_t; }; template<> struct packed_as { using type = int2; }; template<> struct packed_as { using type = half; }; -// #ifdef ENABLE_BF16 -// template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; -// template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; -// #endif + inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } From 4d5c1a7eaf01cafcad998c07d5ab04ae931f3e5c Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Wed, 27 Sep 2023 17:17:09 +0800 Subject: [PATCH 23/52] code clean --- csrc/{ftgemm => int8gemm/cublas}/allocator.h | 0 csrc/{ftgemm => int8gemm/cublas}/bindings.cpp | 51 ++++++++++--------- .../cublas}/cublasAlgoMap.cc | 0 .../cublas}/cublasAlgoMap.h | 0 .../cublas}/cublasINT8MMWrapper.cc | 0 .../cublas}/cublasINT8MMWrapper.h | 0 .../cublas}/cublasMMWrapper.cc | 0 .../cublas}/cublasMMWrapper.h | 0 .../{ftgemm => int8gemm/cublas}/cuda_utils.cc | 0 csrc/{ftgemm => int8gemm/cublas}/cuda_utils.h | 0 .../cublas}/int8_utils.cuh | 0 .../cublas}/transform_layout.cu | 0 .../cublas}/transform_layout.h | 0 csrc/int8gemm/{ => cutlass}/bindings.cpp | 0 csrc/int8gemm/{ => cutlass}/bmm.cu | 0 csrc/int8gemm/{ => cutlass}/fused.cu | 0 csrc/int8gemm/{ => cutlass}/include/bmm.h | 0 csrc/int8gemm/{ => cutlass}/include/common.h | 0 csrc/int8gemm/{ => cutlass}/include/fused.h | 0 csrc/int8gemm/{ => cutlass}/include/linear.h | 0 csrc/int8gemm/{ => cutlass}/linear.cu | 0 
setup.py | 30 +++++------ .../layers/int8_linear/w8a8linear.py | 13 +++-- 23 files changed, 48 insertions(+), 46 deletions(-) rename csrc/{ftgemm => int8gemm/cublas}/allocator.h (100%) rename csrc/{ftgemm => int8gemm/cublas}/bindings.cpp (81%) rename csrc/{ftgemm => int8gemm/cublas}/cublasAlgoMap.cc (100%) rename csrc/{ftgemm => int8gemm/cublas}/cublasAlgoMap.h (100%) rename csrc/{ftgemm => int8gemm/cublas}/cublasINT8MMWrapper.cc (100%) rename csrc/{ftgemm => int8gemm/cublas}/cublasINT8MMWrapper.h (100%) rename csrc/{ftgemm => int8gemm/cublas}/cublasMMWrapper.cc (100%) rename csrc/{ftgemm => int8gemm/cublas}/cublasMMWrapper.h (100%) rename csrc/{ftgemm => int8gemm/cublas}/cuda_utils.cc (100%) rename csrc/{ftgemm => int8gemm/cublas}/cuda_utils.h (100%) rename csrc/{ftgemm => int8gemm/cublas}/int8_utils.cuh (100%) rename csrc/{ftgemm => int8gemm/cublas}/transform_layout.cu (100%) rename csrc/{ftgemm => int8gemm/cublas}/transform_layout.h (100%) rename csrc/int8gemm/{ => cutlass}/bindings.cpp (100%) rename csrc/int8gemm/{ => cutlass}/bmm.cu (100%) rename csrc/int8gemm/{ => cutlass}/fused.cu (100%) rename csrc/int8gemm/{ => cutlass}/include/bmm.h (100%) rename csrc/int8gemm/{ => cutlass}/include/common.h (100%) rename csrc/int8gemm/{ => cutlass}/include/fused.h (100%) rename csrc/int8gemm/{ => cutlass}/include/linear.h (100%) rename csrc/int8gemm/{ => cutlass}/linear.cu (100%) diff --git a/csrc/ftgemm/allocator.h b/csrc/int8gemm/cublas/allocator.h similarity index 100% rename from csrc/ftgemm/allocator.h rename to csrc/int8gemm/cublas/allocator.h diff --git a/csrc/ftgemm/bindings.cpp b/csrc/int8gemm/cublas/bindings.cpp similarity index 81% rename from csrc/ftgemm/bindings.cpp rename to csrc/int8gemm/cublas/bindings.cpp index 9623a37c4dd5..fa35fe686dcc 100644 --- a/csrc/ftgemm/bindings.cpp +++ b/csrc/int8gemm/cublas/bindings.cpp @@ -1,18 +1,21 @@ +/* + gemm methods are adapted from ft +*/ #include #include #include "cublasAlgoMap.h" #include "cublasINT8MMWrapper.h" #include "transform_layout.h" -class FTGEMM { +class I8CUGEMM { private: cublasINT8MMWrapper *int8_gemm_wrapper = nullptr; // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); public: - FTGEMM(); - ~FTGEMM(); + I8CUGEMM(); + ~I8CUGEMM(); void linear_a8_w8_o32(torch::Tensor &input, torch::Tensor &weight, torch::Tensor &output); @@ -30,7 +33,7 @@ class FTGEMM { void transform_row_to_turing(torch::Tensor &input, torch::Tensor &out); }; -FTGEMM::FTGEMM() { +I8CUGEMM::I8CUGEMM() { // cublasAlgoMap *cublas_algo_map = new cublasAlgoMap("igemm_config.in"); cublasAlgoMap *cublas_algo_map = new cublasAlgoMap(); std::mutex *cublas_wrapper_mutex = new std::mutex(); @@ -47,9 +50,9 @@ FTGEMM::FTGEMM() { cublas_wrapper_mutex, use_ORDER_COL32_2R_4R4); } -FTGEMM::~FTGEMM() {} +I8CUGEMM::~I8CUGEMM() {} -void FTGEMM::linear_a8_w8_o32(torch::Tensor &input, // INT8 +void I8CUGEMM::linear_a8_w8_o32(torch::Tensor &input, // INT8 torch::Tensor &weight, // INT8 torch::Tensor &out // INT32 ) { @@ -66,7 +69,7 @@ void FTGEMM::linear_a8_w8_o32(torch::Tensor &input, // INT8 weight_ptr); } -void FTGEMM::linear_a8_w8_o32_(torch::Tensor &input, // INT8 +void I8CUGEMM::linear_a8_w8_o32_(torch::Tensor &input, // INT8 torch::Tensor &weight, // INT8 torch::Tensor &out // INT32 ) { @@ -83,7 +86,7 @@ void FTGEMM::linear_a8_w8_o32_(torch::Tensor &input, // INT8 weight_ptr); } -void FTGEMM::linear_a8_w8_o8(torch::Tensor &input, // INT8 +void I8CUGEMM::linear_a8_w8_o8(torch::Tensor &input, // INT8 
torch::Tensor &weight, // INT8 torch::Tensor &out, // INT8 float alpha // FP32 @@ -101,7 +104,7 @@ void FTGEMM::linear_a8_w8_o8(torch::Tensor &input, // INT8 weight_ptr); } -void FTGEMM::linear_a8_w8_o8_(torch::Tensor &input, // INT8 +void I8CUGEMM::linear_a8_w8_o8_(torch::Tensor &input, // INT8 torch::Tensor &weight, // INT8 torch::Tensor &out, // INT8 float alpha // FP32 @@ -119,7 +122,7 @@ void FTGEMM::linear_a8_w8_o8_(torch::Tensor &input, // INT8 weight_ptr); } -void FTGEMM::linear_a8_w8_ofp32(torch::Tensor &input, // INT8 +void I8CUGEMM::linear_a8_w8_ofp32(torch::Tensor &input, // INT8 torch::Tensor &weight, // INT8 torch::Tensor &out, // INT8 float alpha // FP32 @@ -137,7 +140,7 @@ void FTGEMM::linear_a8_w8_ofp32(torch::Tensor &input, // INT8 weight_ptr); } -void FTGEMM::transform_row_to_col32(torch::Tensor &input, torch::Tensor &out) { +void I8CUGEMM::transform_row_to_col32(torch::Tensor &input, torch::Tensor &out) { int m = input.size(0); int n = input.size(1); int m_ = out.size(0); @@ -152,7 +155,7 @@ void FTGEMM::transform_row_to_col32(torch::Tensor &input, torch::Tensor &out) { // invokeRowMajorToCOL32(out_ptr, input_ptr, m, n, stream); } -void FTGEMM::transform_col32_to_row(torch::Tensor &input, torch::Tensor &out) { +void I8CUGEMM::transform_col32_to_row(torch::Tensor &input, torch::Tensor &out) { int m = input.size(0); int n = input.size(1); int m_ = out.size(0); @@ -167,7 +170,7 @@ void FTGEMM::transform_col32_to_row(torch::Tensor &input, torch::Tensor &out) { // invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, stream); } -void FTGEMM::transform_row_to_ampere(torch::Tensor &input, torch::Tensor &out) { +void I8CUGEMM::transform_row_to_ampere(torch::Tensor &input, torch::Tensor &out) { int m = input.size(0); int n = input.size(1); int m_ = out.size(0); @@ -182,7 +185,7 @@ void FTGEMM::transform_row_to_ampere(torch::Tensor &input, torch::Tensor &out) { // invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, stream); } -void FTGEMM::transform_row_to_turing(torch::Tensor &input, torch::Tensor &out) { +void I8CUGEMM::transform_row_to_turing(torch::Tensor &input, torch::Tensor &out) { int m = input.size(0); int n = input.size(1); int m_ = out.size(0); @@ -198,15 +201,15 @@ void FTGEMM::transform_row_to_turing(torch::Tensor &input, torch::Tensor &out) { } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - pybind11::class_(m, "FTGEMM") + pybind11::class_(m, "I8CUGEMM") .def(pybind11::init<>()) - .def("linear_a8_w8_o32", &FTGEMM::linear_a8_w8_o32) - .def("linear_a8_w8_o8", &FTGEMM::linear_a8_w8_o8) - .def("linear_a8_w8_o8_", &FTGEMM::linear_a8_w8_o8_) - .def("linear_a8_w8_o32_", &FTGEMM::linear_a8_w8_o32_) - .def("linear_a8_w8_ofp32", &FTGEMM::linear_a8_w8_ofp32) - .def("transform_row_to_col32", &FTGEMM::transform_row_to_col32) - .def("transform_col32_to_row", &FTGEMM::transform_col32_to_row) - .def("transform_row_to_ampere", &FTGEMM::transform_row_to_ampere) - .def("transform_row_to_turing", &FTGEMM::transform_row_to_turing); + .def("linear_a8_w8_o32", &I8CUGEMM::linear_a8_w8_o32) + .def("linear_a8_w8_o8", &I8CUGEMM::linear_a8_w8_o8) + .def("linear_a8_w8_o8_", &I8CUGEMM::linear_a8_w8_o8_) + .def("linear_a8_w8_o32_", &I8CUGEMM::linear_a8_w8_o32_) + .def("linear_a8_w8_ofp32", &I8CUGEMM::linear_a8_w8_ofp32) + .def("transform_row_to_col32", &I8CUGEMM::transform_row_to_col32) + .def("transform_col32_to_row", &I8CUGEMM::transform_col32_to_row) + .def("transform_row_to_ampere", &I8CUGEMM::transform_row_to_ampere) + .def("transform_row_to_turing", &I8CUGEMM::transform_row_to_turing); } diff --git 
a/csrc/ftgemm/cublasAlgoMap.cc b/csrc/int8gemm/cublas/cublasAlgoMap.cc similarity index 100% rename from csrc/ftgemm/cublasAlgoMap.cc rename to csrc/int8gemm/cublas/cublasAlgoMap.cc diff --git a/csrc/ftgemm/cublasAlgoMap.h b/csrc/int8gemm/cublas/cublasAlgoMap.h similarity index 100% rename from csrc/ftgemm/cublasAlgoMap.h rename to csrc/int8gemm/cublas/cublasAlgoMap.h diff --git a/csrc/ftgemm/cublasINT8MMWrapper.cc b/csrc/int8gemm/cublas/cublasINT8MMWrapper.cc similarity index 100% rename from csrc/ftgemm/cublasINT8MMWrapper.cc rename to csrc/int8gemm/cublas/cublasINT8MMWrapper.cc diff --git a/csrc/ftgemm/cublasINT8MMWrapper.h b/csrc/int8gemm/cublas/cublasINT8MMWrapper.h similarity index 100% rename from csrc/ftgemm/cublasINT8MMWrapper.h rename to csrc/int8gemm/cublas/cublasINT8MMWrapper.h diff --git a/csrc/ftgemm/cublasMMWrapper.cc b/csrc/int8gemm/cublas/cublasMMWrapper.cc similarity index 100% rename from csrc/ftgemm/cublasMMWrapper.cc rename to csrc/int8gemm/cublas/cublasMMWrapper.cc diff --git a/csrc/ftgemm/cublasMMWrapper.h b/csrc/int8gemm/cublas/cublasMMWrapper.h similarity index 100% rename from csrc/ftgemm/cublasMMWrapper.h rename to csrc/int8gemm/cublas/cublasMMWrapper.h diff --git a/csrc/ftgemm/cuda_utils.cc b/csrc/int8gemm/cublas/cuda_utils.cc similarity index 100% rename from csrc/ftgemm/cuda_utils.cc rename to csrc/int8gemm/cublas/cuda_utils.cc diff --git a/csrc/ftgemm/cuda_utils.h b/csrc/int8gemm/cublas/cuda_utils.h similarity index 100% rename from csrc/ftgemm/cuda_utils.h rename to csrc/int8gemm/cublas/cuda_utils.h diff --git a/csrc/ftgemm/int8_utils.cuh b/csrc/int8gemm/cublas/int8_utils.cuh similarity index 100% rename from csrc/ftgemm/int8_utils.cuh rename to csrc/int8gemm/cublas/int8_utils.cuh diff --git a/csrc/ftgemm/transform_layout.cu b/csrc/int8gemm/cublas/transform_layout.cu similarity index 100% rename from csrc/ftgemm/transform_layout.cu rename to csrc/int8gemm/cublas/transform_layout.cu diff --git a/csrc/ftgemm/transform_layout.h b/csrc/int8gemm/cublas/transform_layout.h similarity index 100% rename from csrc/ftgemm/transform_layout.h rename to csrc/int8gemm/cublas/transform_layout.h diff --git a/csrc/int8gemm/bindings.cpp b/csrc/int8gemm/cutlass/bindings.cpp similarity index 100% rename from csrc/int8gemm/bindings.cpp rename to csrc/int8gemm/cutlass/bindings.cpp diff --git a/csrc/int8gemm/bmm.cu b/csrc/int8gemm/cutlass/bmm.cu similarity index 100% rename from csrc/int8gemm/bmm.cu rename to csrc/int8gemm/cutlass/bmm.cu diff --git a/csrc/int8gemm/fused.cu b/csrc/int8gemm/cutlass/fused.cu similarity index 100% rename from csrc/int8gemm/fused.cu rename to csrc/int8gemm/cutlass/fused.cu diff --git a/csrc/int8gemm/include/bmm.h b/csrc/int8gemm/cutlass/include/bmm.h similarity index 100% rename from csrc/int8gemm/include/bmm.h rename to csrc/int8gemm/cutlass/include/bmm.h diff --git a/csrc/int8gemm/include/common.h b/csrc/int8gemm/cutlass/include/common.h similarity index 100% rename from csrc/int8gemm/include/common.h rename to csrc/int8gemm/cutlass/include/common.h diff --git a/csrc/int8gemm/include/fused.h b/csrc/int8gemm/cutlass/include/fused.h similarity index 100% rename from csrc/int8gemm/include/fused.h rename to csrc/int8gemm/cutlass/include/fused.h diff --git a/csrc/int8gemm/include/linear.h b/csrc/int8gemm/cutlass/include/linear.h similarity index 100% rename from csrc/int8gemm/include/linear.h rename to csrc/int8gemm/cutlass/include/linear.h diff --git a/csrc/int8gemm/linear.cu b/csrc/int8gemm/cutlass/linear.cu similarity index 100% rename from 
csrc/int8gemm/linear.cu rename to csrc/int8gemm/cutlass/linear.cu diff --git a/setup.py b/setup.py index 1707e98fcc89..8c886e7a892a 100644 --- a/setup.py +++ b/setup.py @@ -95,12 +95,12 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: i8gemm_extension = CUDAExtension( name='vllm.i8gemm', sources=[ - 'csrc/int8gemm/linear.cu', - 'csrc/int8gemm/bmm.cu', - 'csrc/int8gemm/fused.cu', - 'csrc/int8gemm/bindings.cpp', + 'csrc/int8gemm/cutlass/linear.cu', + 'csrc/int8gemm/cutlass/bmm.cu', + 'csrc/int8gemm/cutlass/fused.cu', + 'csrc/int8gemm/cutlass/bindings.cpp', ], - include_dirs=['csrc/int8gemm/include'], + include_dirs=['csrc/int8gemm/cutlass/include'], extra_link_args=['-lcublas_static', '-lcublasLt_static', '-lculibos', '-lcudart', '-lcudart_static', '-lrt', '-lpthread', '-ldl', '-L/usr/lib/x86_64-linux-gnu/'], @@ -109,23 +109,23 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ) ext_modules.append(i8gemm_extension) -# FTGEMM(cutlass required) -ftgemm_extension = CUDAExtension( - name='vllm.ftgemm', +# int8gemm(cutlass required) +i8cugemm_extension = CUDAExtension( + name='vllm.i8cugemm', sources=[ - 'csrc/ftgemm/bindings.cpp', - 'csrc/ftgemm/cublasAlgoMap.cc', - 'csrc/ftgemm/cublasINT8MMWrapper.cc', - 'csrc/ftgemm/cublasMMWrapper.cc', - 'csrc/ftgemm/cuda_utils.cc', - 'csrc/ftgemm/transform_layout.cu' + 'csrc/int8gemm/cublas/bindings.cpp', + 'csrc/int8gemm/cublas/cublasAlgoMap.cc', + 'csrc/int8gemm/cublas/cublasINT8MMWrapper.cc', + 'csrc/int8gemm/cublas/cublasMMWrapper.cc', + 'csrc/int8gemm/cublas/cuda_utils.cc', + 'csrc/int8gemm/cublas/transform_layout.cu' ], extra_compile_args={ "cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS, }, ) -ext_modules.append(ftgemm_extension) +ext_modules.append(i8cugemm_extension) # Cache operations. cache_extension = CUDAExtension( diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py index 32434091ee7b..14144b436361 100644 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -7,9 +7,8 @@ fake_quantize_activation_per_tensor_absmax, fake_quantize_activation_per_token_absmax, ) -from vllm.ftgemm import FTGEMM -ftgemm = FTGEMM() - +from vllm.i8cugemm import I8CUGEMM +i8cugemm = I8CUGEMM() class W8A8B8O8Linear(torch.nn.Module): # For qkv_proj @@ -230,7 +229,7 @@ def from_float(module: torch.nn.Linear, input_scale): int8_module.inscale = torch.tensor(input_scale) return int8_module -# use ftgemm a8w8o8 +# use cublasgemm a8w8o8 class W8A8OFP32LinearWithSFactorCublas(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): @@ -271,13 +270,13 @@ def forward(self, x): # quant activation x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) - ftgemm.linear_a8_w8_o32_(x, self.weight, y) + i8cugemm.linear_a8_w8_o32_(x, self.weight, y) y = y * self.a.item() y = y.view(*x_shape[:-1], -1) return y, None -# use ftgemm a8w8o8 +# use cublasgemm a8w8o8 class W8A8O32LinearCublas(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): @@ -310,7 +309,7 @@ def forward(self, x): x_shape = x.shape x = x.view(-1, x_shape[-1]) y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) - ftgemm.linear_a8_w8_o32_(x, self.weight, y) + i8cugemm.linear_a8_w8_o32_(x, self.weight, y) y = y * self.a.item() y = y.view(*x_shape[:-1], -1) 
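        # y now holds the dequantized GEMM output: the INT8xINT8 kernel
        # accumulates in int32, and the multiplication by self.a above
        # (presumably input_scale * weight_scale, as in the cutlass-based
        # classes) promotes it to a floating-point tensor shaped like the
        # original input batch. The (output, bias) tuple mirrors the interface
        # of the parallel linear layers this class stands in for.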
return y, None \ No newline at end of file From 9176b1f72791fd63a5290087c373310f1f8b5523 Mon Sep 17 00:00:00 2001 From: Lin Pengyun Date: Wed, 27 Sep 2023 17:12:22 +0800 Subject: [PATCH 24/52] support generating kv quant parameters and evaluting kv quant models --- benchmarks/benchmark_evaluation.py | 181 ++++++++++++++++ benchmarks/mmlu_template.py | 119 +++++++++++ csrc/quant_utils.cuh | 1 + examples/offline_inference_quant.py | 107 ++++++++++ tests/kernels/test_cache.py | 239 ++++++++------------- vllm/engine/llm_engine.py | 4 +- vllm/kv_quant/calib_dataloader.py | 311 ++++++++++++++++++++++++++++ vllm/kv_quant/calibrate.py | 117 +++++++++++ vllm/kv_quant/calibration.py | 307 +++++++++++++++++++++++++++ vllm/kv_quant/export_kv_params.py | 123 +++++++++++ vllm/kv_quant/observer.py | 192 +++++++++++++++++ vllm/kv_quant/utils.py | 164 +++++++++++++++ 12 files changed, 1706 insertions(+), 159 deletions(-) create mode 100644 benchmarks/benchmark_evaluation.py create mode 100644 benchmarks/mmlu_template.py create mode 100644 examples/offline_inference_quant.py create mode 100644 vllm/kv_quant/calib_dataloader.py create mode 100644 vllm/kv_quant/calibrate.py create mode 100644 vllm/kv_quant/calibration.py create mode 100644 vllm/kv_quant/export_kv_params.py create mode 100644 vllm/kv_quant/observer.py create mode 100644 vllm/kv_quant/utils.py diff --git a/benchmarks/benchmark_evaluation.py b/benchmarks/benchmark_evaluation.py new file mode 100644 index 000000000000..4ac9af033098 --- /dev/null +++ b/benchmarks/benchmark_evaluation.py @@ -0,0 +1,181 @@ +import argparse +# import asyncio +# import json +import os +# import random +# import time +from typing import List, Tuple, Dict + +# import aiohttp +import numpy as np +import pandas as pd +# from transformers import PreTrainedTokenizerBase +# from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm import LLM, SamplingParams, RequestOutput +from mmlu_template import MMLUTemplate + +TEMPLATE_REGITRY = { + "mmlu": MMLUTemplate, +} + + +def sample_requests( + # dataset_path: str, + # num_requests: int, + # tokenizer: PreTrainedTokenizerBase, + dev_data_path: str, + test_data_path: str, + subjects: List[str], + dataset_template: str = "mmlu", + is_analyse: bool = False, +) -> List[Tuple[str, int, int]]: + # Load the dataset. 
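+    # For each subject, build a few-shot template from its dev split, format
+    # every row of the test split into a prompt, and record the gold answer.
+    # nums_questions keeps the per-subject question counts so the flat list of
+    # predictions can later be split back by subject when computing accuracy.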
+ nums_questions = [] + dataset = [] + labels = [] + template_class = TEMPLATE_REGITRY[dataset_template] + for subject in subjects: + test_dataset = pd.read_csv(os.path.join(test_data_path, subject + "_test.csv"), header=None) + nums_questions.append(len(test_dataset)) + template = template_class(subject, os.path.join(dev_data_path, subject + "_dev.csv"), is_analyse) + for idx in range(len(test_dataset)): + prompt = template.getTemplate(test_dataset, idx) + dataset.append(prompt) + labels.append(test_dataset.iloc[idx, -1]) + return dataset, labels, nums_questions + + +def run_vllm( + requests: List[str], + output_len: int, + model: str, + tokenizer: str, + kv_cache_dtype: str = "int8", + kv_quant_params_path: str = None, + tensor_parallel_size: int = 1, + seed: int = 0, + n: int = 1, + use_beam_search: bool = False, + trust_remote_code: bool = False, +) -> List[RequestOutput]: + llm = LLM( + model=model, + tokenizer=tokenizer, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + kv_cache_dtype=kv_cache_dtype, + kv_quant_params_path=kv_quant_params_path, + ) + for prompt in requests: + sampling_params = SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + ) + # FIXME(woosuk): Do not use internal method. + llm._add_request( + prompt=prompt, + prompt_token_ids=None, + sampling_params=sampling_params, + ) + + # FIXME(woosuk): Do use internal method. + return llm._run_engine(use_tqdm=True) + + +def evalute( + request_outputs: List[RequestOutput], + labels: List[str], + nums_questions: List[int], + subjects: List[str], + dataset_template: str = "mmlu", +) -> Dict[str, float]: + template_class = TEMPLATE_REGITRY[dataset_template] + pred = [template_class.findAnswer(r.outputs[0].text) for r in request_outputs] + ids = np.cumsum(nums_questions) + lhs = 0 + accs: List[float] = [] + for rhs in ids: + pred_paritition = np.array(pred[lhs: rhs]) + labels_partition = np.array(labels[lhs: rhs]) + acc = np.mean(pred_paritition == labels_partition) + accs.append(acc) + sub2acc = {sub: acc for sub, acc in zip(subjects, accs)} + return sub2acc + + +def main(args: argparse.Namespace): + subjects = [ + "abstract_algebra", + ] + dataset, labels, nums_questions = sample_requests( + args.dev_data_path, + args.test_data_path, + subjects, + is_analyse=args.is_analyse + ) + request_outputs = run_vllm( + dataset, + args.output_len, + args.model, + args.tokenizer, + args.kv_cache_dtype, + args.kv_quant_params_path, + args.tensor_parallel_size, + args.seed, args.n, + args.use_beam_search, + args.trust_remote_code, + ) + foo = request_outputs[0] + print(foo.outputs[0].text) + assert False + sub2acc = evalute( + request_outputs, + labels, + nums_questions, + subjects, + ) + print(sub2acc) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="evaluation for quantization.") + + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + 
parser.add_argument("--dev-data-path", + type=str, + default=None, + help="path to few-shot dataset") + parser.add_argument("--test-data-path", + type=str, + default=None, + help="path to test dataset") + parser.add_argument("--is-analyse", + action="store_true") + parser.add_argument("--output-len", + type=int, + default=100, + help="nums of max token for evaluation outputs") + parser.add_argument("--kv-cache-dtype", + type=str, + default="int8") + parser.add_argument("--kv-quant-params-path", + type=str, + default=None) + args = parser.parse_args() + main(args) diff --git a/benchmarks/mmlu_template.py b/benchmarks/mmlu_template.py new file mode 100644 index 000000000000..81a7f8bc6128 --- /dev/null +++ b/benchmarks/mmlu_template.py @@ -0,0 +1,119 @@ +import pandas as pd +import json +from langchain.prompts import PromptTemplate + +template = PromptTemplate( + input_variables=["question", "A", "B", "C", "D", "Answer"], + template= + """ +USER: {question} +A. {A} +B. {B} +C. {C} +D. {D} ASSISTANT: Answer: {Answer} +""", +) + +template_with_analyse = PromptTemplate( + input_variables=["question", "A", "B", "C", "D"], + template= + """ +Q:{question} +(A) {A} (B) {B} (C) {C} (D) {D} +A: Let's think step by step. +""", +) + + +def gen_prompt(train_df, subject, k=1): + prompt = "SYSTEM: The following are multiple choice questions (with answers) about {}," \ + "Please select the correct answer from the options.".format(subject.replace('_', ' ')) + + for i in range(k): + prompt += template.format(question=train_df.iloc[i, 0], + A=train_df.iloc[i, 1], + B=train_df.iloc[i, 2], + C=train_df.iloc[i, 3], + D=train_df.iloc[i, 4], + Answer=train_df.iloc[i, 5] + )[1:-1] + return prompt + + +## add an abstract base class or common base class for generality +class MMLUTemplate(): + + def __init__(self, subject, file_path, is_analyse): + self.fiveShotTemplate = "" + self.file_path = file_path + self.subject = subject + self.choices = ["A", "B", "C", "D"] + self.is_analyse = is_analyse + self.few_shot_template = "" + if not is_analyse: + self.getFewShotBaseTemplates() + else: + self.getFewShotBaseTemplateAnalyse() + + def getFewShotBaseTemplates(self, k=5): + """few_shot模板不带分析""" + dev_df = pd.read_csv(self.file_path, header=None) + + self.few_shot_template = gen_prompt(dev_df, self.subject, k) + return self.few_shot_template + + def getFewShotBaseTemplateAnalyse(self): + """few_shot模板带分析,更改json文件就行""" + mmlu_prompt = json.load(open('templates/lib_prompt/mmlu-cot.json')) + self.few_shot_template = mmlu_prompt[self.subject] + return self.few_shot_template + + def getTemplate(self, test_df, i): + """获得模板""" + if self.is_analyse: + templ = template_with_analyse.format( + question=test_df.iloc[i, 0], + A=test_df.iloc[i, 1], + B=test_df.iloc[i, 2], + C=test_df.iloc[i, 3], + D=test_df.iloc[i, 4] + ) + + return self.few_shot_template + "\n" + templ + + else: + prompt_end = template.format( + question=test_df.iloc[i, 0], + A=test_df.iloc[i, 1], + B=test_df.iloc[i, 2], + C=test_df.iloc[i, 3], + D=test_df.iloc[i, 4], + Answer='')[1:-5] + return self.few_shot_template + prompt_end + @staticmethod + def findAnswer(res): + """解析函数""" + # print("模型输出为:", res) + d = "NO" + for d_ in res: + if 65 <= ord(d_) <= 68: + d = d_ + break + # print("答案解析为:", d) + return d + + @staticmethod + def findAnwerUsingRule(res): + # print("模型输出为:", res) + result = "NO" + pattern = 'the answer is (' + try: + pred = res.lower().split(pattern)[1][0] + + if 65 <= ord(pred.upper()) <= 68: + result = pred.upper() + except: + pass + + # 
print("答案解析为:",result) + return result diff --git a/csrc/quant_utils.cuh b/csrc/quant_utils.cuh index 6b84931ac228..d26e754bb40e 100644 --- a/csrc/quant_utils.cuh +++ b/csrc/quant_utils.cuh @@ -1,3 +1,4 @@ +// Adated from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp #pragma once #include diff --git a/examples/offline_inference_quant.py b/examples/offline_inference_quant.py new file mode 100644 index 000000000000..29589ce30c23 --- /dev/null +++ b/examples/offline_inference_quant.py @@ -0,0 +1,107 @@ +import argparse +import os +from typing import List, Tuple, Dict + +import numpy as np +import pandas as pd +from vllm import LLM, SamplingParams, RequestOutput +from benchmarks.mmlu_template import MMLUTemplate + + +def sample_requests( + # dataset_path: str, + # num_requests: int, + # tokenizer: PreTrainedTokenizerBase, + dev_data_path: str, + test_data_path: str, + subjects: List[str], + # dataset_template: str = "mmlu", + is_analyse: bool = False, +) -> List[Tuple[str, int, int]]: + # Load the dataset. + nums_questions = [] + dataset = [] + labels = [] + template_class = MMLUTemplate + for subject in subjects: + test_dataset = pd.read_csv(os.path.join(test_data_path, subject + "_test.csv"), header=None) + nums_questions.append(len(test_dataset)) + template = template_class(subject, os.path.join(dev_data_path, subject + "_dev.csv"), is_analyse) + for idx in range(len(test_dataset)): + prompt = template.getTemplate(test_dataset, idx) + dataset.append(prompt) + labels.append(test_dataset.iloc[idx, -1]) + return dataset, labels, nums_questions + + +def main(args: argparse.Namespace): + subjects = ["abstract_algebra"] + llm = LLM( + model=args.model, + tokenizer=args.tokenizer, + tensor_parallel_size=args.tensor_parallel_size, + seed=args.seed, + trust_remote_code=args.trust_remote_code, + kv_cache_dtype=args.kv_cache_dtype, + kv_quant_params_path=args.kv_quant_params_path, + ) + requests, labels, _ = sample_requests( + args.dev_data_path, + args.test_data_path, + subjects, + args.is_analyse, + ) + prompt, label = requests[0], labels[0] + print(f"the correct answer is\n{label}") + sampling_params = SamplingParams( + n=args.n, + temperature=0.0 if args.use_beam_search else 1.0, + top_p=1.0, + use_beam_search=args.use_beam_search, + ignore_eos=True, + max_tokens=args.output_len, + ) + outputs = llm.generate(prompt, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="evaluation for quantization.") + + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument("--dev-data-path", + type=str, + default=None, + help="path to few-shot dataset") + parser.add_argument("--test-data-path", + type=str, + default=None, + help="path to test dataset") + 
parser.add_argument("--is-analyse", + action="store_true") + parser.add_argument("--output-len", + type=int, + default=200, + help="nums of max token for evaluation outputs") + parser.add_argument("--kv-cache-dtype", + type=str, + default="float16") + parser.add_argument("--kv-quant-params-path", + type=str, + default=None) + args = parser.parse_args() + main(args) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 7e449cb182b3..476007249ac2 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,80 +5,88 @@ from vllm import cache_ops -DTYPES = [torch.half, torch.bfloat16, torch.float] +DTYPES = [ + # torch.half, + # torch.bfloat16, + torch.float +] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing NUM_LAYERS = [5] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] -BLOCK_SIZES = [8, 16, 32] +BLOCK_SIZES = [ + 8, + 16, + 32, +] NUM_BLOCKS = [1024] # Arbitrary values for testing NUM_MAPPINGS = [32, 256] # Arbitrary values for testing SEEDS = [0] -@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) -@pytest.mark.parametrize("num_layers", NUM_LAYERS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_copy_blocks( - kv_cache_factory, - num_mappings: int, - num_layers: int, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, - seed: int, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - # Generate random block mappings where each source block is mapped to two - # destination blocks. - assert 2 * num_mappings <= num_blocks - src_blocks = random.sample(range(num_blocks), num_mappings) - remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) - dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) - block_mapping = {} - for i in range(num_mappings): - src = src_blocks[i] - dst1 = dst_blocks[2 * i] - dst2 = dst_blocks[2 * i + 1] - block_mapping[src] = [dst1, dst2] - - # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, block_size, - num_layers, num_heads, - head_size, dtype, seed) - - # Clone the KV caches. - cloned_key_caches = [key_cache.clone() for key_cache in key_caches] - cloned_value_caches = [value_cache.clone() for value_cache in value_caches] - - # Call the copy blocks kernel. - cache_ops.copy_blocks(key_caches, value_caches, block_mapping) - - # Run the reference implementation. - for src, dsts in block_mapping.items(): - for dst in dsts: - for cloned_key_cache in cloned_key_caches: - cloned_key_cache[dst] = cloned_key_cache[src] - for cloned_value_cache in cloned_value_caches: - cloned_value_cache[dst] = cloned_value_cache[src] - - # Compare the results. 
- for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): - assert torch.allclose(key_cache, cloned_key_cache) - for value_cache, cloned_value_cache in zip(value_caches, - cloned_value_caches): - assert torch.allclose(value_cache, cloned_value_cache) +# @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) +# @pytest.mark.parametrize("num_layers", NUM_LAYERS) +# @pytest.mark.parametrize("num_heads", NUM_HEADS) +# @pytest.mark.parametrize("head_size", HEAD_SIZES) +# @pytest.mark.parametrize("block_size", BLOCK_SIZES) +# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +# @pytest.mark.parametrize("dtype", DTYPES) +# @pytest.mark.parametrize("seed", SEEDS) +# @torch.inference_mode() +# def test_copy_blocks( +# kv_cache_factory, +# num_mappings: int, +# num_layers: int, +# num_heads: int, +# head_size: int, +# block_size: int, +# num_blocks: int, +# dtype: torch.dtype, +# seed: int, +# ) -> None: +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) + +# # Generate random block mappings where each source block is mapped to two +# # destination blocks. +# assert 2 * num_mappings <= num_blocks +# src_blocks = random.sample(range(num_blocks), num_mappings) +# remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) +# dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) +# block_mapping = {} +# for i in range(num_mappings): +# src = src_blocks[i] +# dst1 = dst_blocks[2 * i] +# dst2 = dst_blocks[2 * i + 1] +# block_mapping[src] = [dst1, dst2] + +# # Create the KV caches. +# key_caches, value_caches = kv_cache_factory(num_blocks, block_size, +# num_layers, num_heads, +# head_size, dtype, seed) + +# # Clone the KV caches. +# cloned_key_caches = [key_cache.clone() for key_cache in key_caches] +# cloned_value_caches = [value_cache.clone() for value_cache in value_caches] + +# # Call the copy blocks kernel. +# cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + +# # Run the reference implementation. +# for src, dsts in block_mapping.items(): +# for dst in dsts: +# for cloned_key_cache in cloned_key_caches: +# cloned_key_cache[dst] = cloned_key_cache[src] +# for cloned_value_cache in cloned_value_caches: +# cloned_value_cache[dst] = cloned_value_cache[src] + +# # Compare the results. 
+# for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): +# assert torch.allclose(key_cache, cloned_key_cache) +# for value_cache, cloned_value_cache in zip(value_caches, +# cloned_value_caches): +# assert torch.allclose(value_cache, cloned_value_cache) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -146,8 +154,15 @@ def test_reshape_and_cache( assert torch.allclose(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +# @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def run_reshape_and_cache_quantized( +def test_reshape_and_cache_quantized( num_tokens: int, num_heads: int, head_size: int, @@ -204,95 +219,3 @@ def run_reshape_and_cache_quantized( assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(value_cache, cloned_value_cache) - - -@torch.inference_mode() -def run_gather_cached_kv( - num_tokens: int, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, -) -> None: - num_slots = block_size * num_blocks - slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') - - qkv = torch.randn(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device='cuda') - _, key, value = qkv.unbind(dim=1) - - qkv_clone = qkv.clone() - _, cloned_key, cloned_value = qkv_clone.unbind(dim=1) - - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') - - value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_cache = torch.randn(size=value_cache_shape, - dtype=dtype, - device='cuda') - - cache_ops.gather_cached_kv(key, value, key_cache, value_cache, - slot_mapping) - - # Reference implementation. 
- for i in range(num_tokens): - reshaped_key = cloned_key.reshape(num_tokens, num_heads, - head_size // x, x) - block_idx = torch.div(slot_mapping[i], - block_size, - rounding_mode='floor') - block_offset = slot_mapping[i] % block_size - reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :] - cloned_value[i] = value_cache[block_idx, :, :, block_offset] - - assert torch.allclose(key, cloned_key) - assert torch.allclose(value, cloned_value) - - -def test_copy_blocks() -> None: - for dtype in [torch.half, torch.bfloat16, torch.float]: - run_copy_blocks(num_mappings=23, - num_layers=7, - num_heads=17, - head_size=16, - block_size=8, - num_blocks=1024, - dtype=dtype) - - -def test_reshape_and_cache() -> None: - for dtype in [torch.half, torch.bfloat16, torch.float]: - run_reshape_and_cache(num_tokens=3, - num_heads=2, - head_size=16, - block_size=8, - num_blocks=2, - dtype=dtype) - - -def test_reshape_and_cache_quantized() -> None: - for dtype in [torch.half, torch.bfloat16, torch.float]: - run_reshape_and_cache_quantized(num_tokens=3, - num_heads=2, - head_size=16, - block_size=8, - num_blocks=2, - dtype=dtype) - - -def test_gather_cached_kv() -> None: - for dtype in [torch.half, torch.bfloat16, torch.float]: - run_gather_cached_kv(num_tokens=3, - num_heads=2, - head_size=16, - block_size=8, - num_blocks=2, - dtype=dtype) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 743454301838..4214f835a2dc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -81,7 +81,9 @@ def __init__( f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"quantization={model_config.quantization}, " - f"seed={model_config.seed})") + f"seed={model_config.seed})" + f"kv_cache_type={model_config.kv_cache_dtype}" + f"use kv cache quantization: {model_config.quant_kv_cache}") # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config diff --git a/vllm/kv_quant/calib_dataloader.py b/vllm/kv_quant/calib_dataloader.py new file mode 100644 index 000000000000..bd0a86823577 --- /dev/null +++ b/vllm/kv_quant/calib_dataloader.py @@ -0,0 +1,311 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_wikitext2(tokenizer, nsamples, seed, seqlen, path=None): + """Load Wikitext-2 train and test datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized Wikitext-2 test set. 
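+
+    Note: each training sample is a pair (input_ids, target) where
+    input_ids has shape (1, seqlen) and target is a copy of it with every
+    position except the last set to -100 (i.e. masked out for the loss).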
+ """ + from datasets import load_dataset + traindata = load_dataset(path if path else 'wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset(path if path else 'wikitext', 'wikitext-2-raw-v1', split='test') + + trainenc = tokenizer('\n\n'.join(traindata['text']), return_tensors='pt') + testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(tokenizer, nsamples, seed, seqlen): + """Load PTB train and validation datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. + """ + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + valdata = load_dataset('ptb_text_only', + 'penn_treebank', + split='validation') + + trainenc = tokenizer('\n\n'.join(traindata['sentence']), + return_tensors='pt') + testenc = tokenizer('\n\n'.join(valdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(tokenizer, nsamples, seed, seqlen, path=None): + """Load C4 train and validation datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. 
+ """ + from datasets import load_dataset + traindata = load_dataset( + path if path else 'allenai/c4', + 'allenai--c4', + data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, + split='train', + use_auth_token=False) + valdata = load_dataset( + path if path else 'allenai/c4', + 'allenai--c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation', + use_auth_token=False) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]['text'], return_tensors='pt') + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + + class TokenizerWrapper: + + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_ptb_new(tokenizer, nsamples, seed, seqlen): + """Load PTB New train and validation datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. + """ + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4_new(tokenizer, nsamples, seed, seqlen): + """Load C4 New train and validation datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. 
+ """ + from datasets import load_dataset + traindata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, + split='train') + valdata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_pileval(tokenizer, nsamples, seed, seqlen=512): + """Load pileval train dataset and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. + """ + from datasets import load_dataset + from datasets.builder import DatasetGenerationError + try: + dataset = load_dataset( + 'json', + data_files='https://the-eye.eu/public/AI/pile/val.jsonl.zst', + split='train') + except DatasetGenerationError: + raise InterruptedError('There have been some issues when generating ' + 'the dataset, you could try to download it ' + 'locally first, and replace the `data_files`' + 'with local addresses or use other datasets ' + '(c4, wiki, ptb).') + dataset = dataset.shuffle(seed=seed) + samples = [] + n_run = 0 + for data in dataset: + line = data['text'] + line = line.strip() + line_encoded = tokenizer.encode(line) + if len(line_encoded) > 512: + continue + sample = torch.tensor([line_encoded]) + if sample.numel() == 0: + continue + samples.append(sample) + n_run += 1 + if n_run == nsamples: + break + # now concatenate all samples and split according to block size + cat_samples = torch.cat(samples, dim=1) + n_split = cat_samples.shape[1] // seqlen + print(f' * Split into {n_split} blocks') + return [ + cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split) + ], None + + +def get_calib_loaders(name, tokenizer, nsamples=128, seed=0, seqlen=2048, path=None): + """Get calibration data loaders for a dataset. + + Args: + name: Dataset name ('wikitext2', 'ptb', 'c4', etc). + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_data: Full tokenized validation set. 
+ """ + if 'wikitext2' in name: + return get_wikitext2(tokenizer, nsamples, seed, seqlen, path) + if 'ptb' in name: + if 'new' in name: + return get_ptb_new(tokenizer, nsamples, seed, seqlen, path) + return get_ptb(tokenizer, nsamples, seed, seqlen, path) + if 'c4' in name: + if 'new' in name: + return get_c4_new(tokenizer, nsamples, seed, seqlen, path) + return get_c4(tokenizer, nsamples, seed, seqlen, path) + + if 'pileval' in name: + return get_pileval(tokenizer, nsamples, seed, seqlen, path) diff --git a/vllm/kv_quant/calibrate.py b/vllm/kv_quant/calibrate.py new file mode 100644 index 000000000000..7097e29e9d98 --- /dev/null +++ b/vllm/kv_quant/calibrate.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Adapted from +# https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/lite/apis/calibrate.py + +# Copyright (c) OpenMMLab. All rights reserved. + +from pathlib import Path + +import fire +import torch +from accelerate import (infer_auto_device_map, init_empty_weights, + load_checkpoint_in_model) +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from vllm.kv_quant.calibration import CalibrationContext +from vllm.kv_quant.utils import collect_target_modules +from vllm.kv_quant.calib_dataloader import get_calib_loaders + +LAYER_TYPE_MAP = { + 'InternLMForCausalLM': 'InternLMDecoderLayer', + 'QWenLMHeadModel': 'QWenBlock', + 'BaiChuanForCausalLM': 'DecoderLayer', + 'LlamaForCausalLM': 'LlamaDecoderLayer', +} +NORM_TYPE_MAP = { + 'InternLMForCausalLM': 'InternLMRMSNorm', + 'QWenLMHeadModel': 'RMSNorm', + 'BaiChuanForCausalLM': 'RMSNorm', + 'LlamaForCausalLM': 'LlamaRMSNorm', +} + + +def calibrate(model: str, + calib_dataset: str = 'c4', + calib_samples: int = 128, + calib_seqlen: int = 2048, + work_dir: str = './work_dir', + device: str = 'cuda', + dataset_path: str = None) -> None: + """The main function for loading the model and performing calibration on a + given dataset. + + Args: + model (str): The model to be loaded. + calib_dataset (str, optional): The calibration dataset name. + Defaults to 'c4'. + calib_samples (int, optional): The number of samples for calibration. + Defaults to 128. + calib_seqlen (int, optional): The sequence length for calibration. + Defaults to 2048. + work_dir (str): The working directory for outputs. + Defaults to './work_dir'. + device (str, optional): The device to be used for calculation. + Defaults to 'cuda'. + """ + + assert calib_dataset in ['c4', 'ptb', 'wikitext2', 'pileval'], \ + 'Support only `c4`, `ptb`, `wikitext2` or `pileval`.' 
+ + # Load tokenizer and configuration + tokenizer = AutoTokenizer.from_pretrained(model, + use_fast=False, + trust_remote_code=True) + hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True) + checkpoint = hf_config._name_or_path + + with init_empty_weights(): + # Load model + model = AutoModelForCausalLM.from_pretrained(model, + torch_dtype=torch.float16, + trust_remote_code=True) + model.config.use_cache = False + + layer_type = LAYER_TYPE_MAP[type(model).__name__] + norm_type = NORM_TYPE_MAP[type(model).__name__] + + decoder_layers = collect_target_modules(model, layer_type) + + # Infer device map + device_map = infer_auto_device_map(model, + no_split_module_classes=[layer_type]) + for name in device_map.keys(): + if name in decoder_layers or 'lm_head' in name: + device_map[name] = 'cpu' + else: + device_map[name] = 0 + load_checkpoint_in_model(model, checkpoint, device_map) + + print('Loading calibrate dataset ...') + calib_loader, _ = get_calib_loaders(calib_dataset, + tokenizer, + nsamples=calib_samples, + seqlen=calib_seqlen, + path=dataset_path) + + # Initialize calibration context + calib_ctx = CalibrationContext(model, + tokenizer, + layer_type=layer_type, + norm_type=norm_type, + device=device) + + with calib_ctx: + all_data = torch.cat([ + data if isinstance(data, torch.Tensor) else data[0] + for data in calib_loader + ]).to(device) + calib_ctx.calibrate(all_data) + + # Create work directory if not exists + work_dir = Path(work_dir) + work_dir.mkdir(parents=True, exist_ok=True) + calib_ctx.export(work_dir) + + +if __name__ == '__main__': + fire.Fire(calibrate) diff --git a/vllm/kv_quant/calibration.py b/vllm/kv_quant/calibration.py new file mode 100644 index 000000000000..d38e9e486456 --- /dev/null +++ b/vllm/kv_quant/calibration.py @@ -0,0 +1,307 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial +from typing import Union + +import torch +from torch import nn +from transformers import PreTrainedTokenizer +from vllm.kv_quant.utils import (bimap_name_mod, collect_target_modules, + concat_decoder_layer_outputs, + split_decoder_layer_inputs) +from vllm.kv_quant.observer import ActivationObserver, KVCacheObserver + + +class CalibrationContext(): + """Calibration context manager for model quantization. + + Parameters: + - model: The target model to be calibrated and quantized + - tokenizer: The tokenizer used in the model training + - layer_type: Layer type to be targeted for calibration + - norm_type: Normalization type used for calibration + - device: Device on which model is to be calibrated ('cpu' or 'cuda') + """ + + inp_obs_group = 'inputs' + out_obs_group = 'outputs' + key_obs_group = 'keys' + value_obs_group = 'values' + + def __init__(self, + model: nn.Module, + tokenizer: PreTrainedTokenizer, + layer_type: Union[str, type], + norm_type: Union[str, type], + device: str = 'cuda') -> None: + """Initiate calibration context. + + Args: + model (nn.Module): Model to be calibrated. + tokenizer (PreTrainedTokenizer): Tokenizer of the given model. + layer_type (Union[str, type]): Type of the layers to be observed. + norm_type (Union[str, type]): Norm type used in the model. + device (str, optional): Device where the model should run. + Defaults to 'cuda'. 
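+
+        A minimal usage sketch (mirrors calibrate.py; names and paths are
+        illustrative only):
+
+            ctx = CalibrationContext(model, tokenizer,
+                                     layer_type='LlamaDecoderLayer',
+                                     norm_type='LlamaRMSNorm')
+            with ctx:
+                ctx.calibrate(input_ids)
+            ctx.export(Path('./work_dir'))  # Path from pathlib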
+ """ + + self.layer_type = layer_type + self.norm_type = norm_type + + num_kv_heads, num_attn_heads = self._guess_num_heads(model) + self.num_kv_heads = num_kv_heads + self.head_dim = model.config.hidden_size // num_attn_heads + self.model = model + del self.model.lm_head + + self.tokenizer = tokenizer + + # Collect modules to observe + self.name2layer = collect_target_modules(self.model, layer_type) + self.name2fc = {} + for l_name, layer in self.name2layer.items(): + name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name) + self.name2fc.update(name2fc) + self.name2norm = collect_target_modules(self.model, norm_type) + + maps = bimap_name_mod([self.name2layer, self.name2fc, self.name2norm]) + self.name2mod, self.mod2name = maps + + # Initialize observers + self._init_input_observers(self.name2fc) + self._init_output_observers(self.name2norm) + self._init_output_observers(self.name2fc) + self._init_kv_observers(self.name2layer) + + self.device = device + + def _guess_num_heads(self, model): + + if hasattr(model.config, 'num_key_value_heads'): + num_kv_heads = model.config.num_key_value_heads + else: + num_kv_heads = model.config.num_attention_heads + + num_attn_heads = model.config.num_attention_heads + + return num_kv_heads, num_attn_heads + + def _init_input_observers(self, name2mod): + """Initialize input observers for given modules.""" + for name, mod in name2mod.items(): + obs = ActivationObserver(mod.weight.size(-1)) + obs.global_available(name, group=self.inp_obs_group) + + def _init_output_observers(self, name2mod): + """Initialize output observers for given modules.""" + for name, mod in name2mod.items(): + obs = ActivationObserver(mod.weight.size(0)) + obs.global_available(name, group=self.out_obs_group) + + def _init_kv_observers(self, name2mod): + """Initialize KV observers for given modules.""" + for name in name2mod.keys(): + k_obs = KVCacheObserver(self.num_kv_heads, self.head_dim) + v_obs = KVCacheObserver(self.num_kv_heads, self.head_dim) + k_obs.global_available(name, group=self.key_obs_group) + v_obs.global_available(name, group=self.value_obs_group) + + def _insert_input_observers(self): + """Insert input observers into the target modules. + + This function registers a forward pre-hook on each target module to + observe the inputs. + """ + + def _input_hook(mod: nn.Module, inp: torch.Tensor): + m_name = self.mod2name[mod] + obs = ActivationObserver.find(m_name, group=self.inp_obs_group) + obs.observe(inp[0]) + + group = ActivationObserver.find_group(self.inp_obs_group) + for name in group.keys(): + mod = self.name2mod[name] + hook_fn = mod.register_forward_pre_hook(_input_hook) + self._hooks.append(hook_fn) + + def _insert_output_observers(self): + """Insert output observers into the target modules. + + This function registers a forward hook on each target module to observe + the outputs. 
+ """ + + def _output_hook(mod: nn.Module, inp: torch.Tensor, out: torch.Tensor): + m_name = self.mod2name[mod] + obs = ActivationObserver.find(m_name, group=self.out_obs_group) + obs.observe(out) + + group = ActivationObserver.find_group(self.out_obs_group) + for name in group.keys(): + mod = self.name2mod[name] + hook_fn = mod.register_forward_hook(_output_hook) + self._hooks.append(hook_fn) + + def _wrap_decoder_layers(self): + """Method to wrap the decoder layers' forward functions for observing + their key/value cache during batched forward passes.""" + + def _forward(mod, *args, **kwargs): + + mod.to(self.device) + batch_args, batch_kwargs = split_decoder_layer_inputs( + *args, **kwargs) + batch_outputs = [] + samples = len(batch_args) + + m_name = self.mod2name[mod] + k_obs = KVCacheObserver.find(m_name, group=self.key_obs_group) + v_obs = KVCacheObserver.find(m_name, group=self.value_obs_group) + + for i in range(len(batch_args)): + + if k_obs and v_obs: + batch_kwargs[i]['use_cache'] = True + out = self._ori_forwards[mod](*batch_args[i], + **batch_kwargs[i]) + out = list(out) + key, value = out.pop(-1) + k_obs.observe(key) + v_obs.observe(value) + + del key, value + torch.cuda.empty_cache() + batch_outputs.append(tuple(out)) + else: + batch_outputs.append(self._ori_forwards[mod]( + *batch_args[i], **batch_kwargs[i])) + + outputs = concat_decoder_layer_outputs(batch_outputs) + + del batch_outputs, batch_args, batch_kwargs, args + mod.to('cpu') + torch.cuda.empty_cache() + max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + print(f'{m_name}, samples: {samples}, ' + f'max gpu memory: {max_memory:.2f} GB') + return outputs + + for layer in self.name2layer.values(): + self._ori_forwards[layer] = layer.forward + layer.forward = partial(_forward, layer) + + def collect_inputs_stats(self): + """Collect statistics (min, max, absmax values) of the observed inputs. + + Returns a dictionary with these collected stats. + """ + inputs_stats = { + 'max': {}, + 'min': {}, + 'mean': {}, + 'absmax': {}, + 'absmean': {} + } + obs_group = ActivationObserver.find_group(self.inp_obs_group) + for name, obs in obs_group.items(): + inputs_stats['max'][name] = obs.max_val + inputs_stats['min'][name] = obs.min_val + inputs_stats['mean'][name] = obs.mean_val + inputs_stats['absmax'][name] = obs.absmax_val + inputs_stats['absmean'][name] = obs.absmean_val + return inputs_stats + + def collect_outputs_stats(self): + """Collect statistics (min, max, absmax values) of the observed + outputs. + + Returns a dictionary with these collected stats. + """ + outputs_stats = { + 'max': {}, + 'min': {}, + 'mean': {}, + 'absmax': {}, + 'absmean': {} + } + obs_group = ActivationObserver.find_group(self.out_obs_group) + for name, obs in obs_group.items(): + outputs_stats['max'][name] = obs.max_val + outputs_stats['min'][name] = obs.min_val + outputs_stats['mean'][name] = obs.mean_val + outputs_stats['absmax'][name] = obs.absmax_val + outputs_stats['absmean'][name] = obs.absmean_val + return outputs_stats + + def collect_kv_stats(self): + """Collect statistics (min, max, absmax values) of the observed keys + and values. + + Returns a tuple of two dictionaries with these collected stats. 
+ """ + key_stats = {'max': {}, 'min': {}, 'absmax': {}} + obs_group = KVCacheObserver.find_group(self.key_obs_group) + for name, obs in obs_group.items(): + key_stats['max'][name] = obs.max_val + key_stats['min'][name] = obs.min_val + key_stats['absmax'][name] = obs.absmax_val + + value_stats = {'max': {}, 'min': {}, 'absmax': {}} + obs_group = KVCacheObserver.find_group(self.value_obs_group) + for name, obs in obs_group.items(): + value_stats['max'][name] = obs.max_val + value_stats['min'][name] = obs.min_val + value_stats['absmax'][name] = obs.absmax_val + return key_stats, value_stats + + def export(self, out_dir): + """Export the calibration statistics (inputs, outputs, keys and values) + to specified directory. + + Args: + out_dir (Union[str, Path]): The directory path where the stats + will be saved. + """ + + inp_stats = self.collect_inputs_stats() + torch.save(inp_stats, out_dir / 'inputs_stats.pth') + + out_stats = self.collect_outputs_stats() + torch.save(out_stats, out_dir / 'outputs_stats.pth') + + key_stats, value_stats = self.collect_kv_stats() + torch.save(key_stats, out_dir / 'key_stats.pth') + torch.save(value_stats, out_dir / 'value_stats.pth') + + def calibrate(self, data): + """Forward pass through the model in inference mode with given data.""" + + if type(self.model).__name__ == 'QWenLMHeadModel': + model = self.model.transformer + else: + model = self.model.model + with torch.inference_mode(): + _ = model(data.to(self.device)) + + def __enter__(self): + """Prepares the Calibration object for a 'with' statement by + registering hooks and wrapping layer forward methods.""" + + self._hooks = list() + + self._ori_forwards = {} + for layer in self.name2layer.values(): + self._ori_forwards[layer] = layer.forward + + self._insert_input_observers() + self._insert_output_observers() + self._wrap_decoder_layers() + + def __exit__(self, exc_type, exc_value, traceback): + """Clean up after a 'with' statement by removing registered hooks, + restoring original forward methods, and if no exception occurred, + collecting all gathered statistics and saving them.""" + for h in self._hooks: + h.remove() + + for layer in self.name2layer.values(): + layer.forward = self._ori_forwards[layer] diff --git a/vllm/kv_quant/export_kv_params.py b/vllm/kv_quant/export_kv_params.py new file mode 100644 index 000000000000..e0cf47d9b751 --- /dev/null +++ b/vllm/kv_quant/export_kv_params.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from pathlib import Path +from typing import Union + +import numpy as np +import torch +import fire + + +def _export_sym(key_stats: dict, + value_stats: dict, + bits: int, + out_dir: Union[str, Path], + tp: int = 1) -> None: + """Export symmetric quantization parameters to specified directory.""" + keys_absmax = key_stats['absmax'] + values_absmax = value_stats['absmax'] + for layer_idx, name in enumerate(keys_absmax.keys()): + k_absmax = keys_absmax[name] + v_absmax = values_absmax[name] + + heads, dims = k_absmax.shape + assert heads % tp == 0 + + mp_k_absmax = torch.chunk(k_absmax, tp) + mp_v_absmax = torch.chunk(v_absmax, tp) + for i in range(tp): + # quant: q = f / scale + # dequant: f = q * scale + k_s = mp_k_absmax[i].max() / (2**(bits - 1) - 1) + v_s = mp_v_absmax[i].max() / (2**(bits - 1) - 1) + + kv_qparams = np.array([k_s, v_s], dtype=np.float32) + out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' # noqa: E501 + kv_qparams.tofile(out_path) + print(f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}') + + +def _export_asym(key_stats: dict, + value_stats: dict, + bits: int, + out_dir: Union[str, Path], + tp: int = 1) -> None: + """Export asymmetric quantization parameters to specified directory.""" + keys_min = key_stats['min'] + values_min = value_stats['min'] + + keys_max = key_stats['max'] + values_max = value_stats['max'] + for layer_idx, name in enumerate(keys_min.keys()): + k_max = keys_max[name] + v_max = values_max[name] + + k_min = keys_min[name] + v_min = values_min[name] + + heads, dims = k_min.shape + assert heads % tp == 0 + + tp_k_min = torch.chunk(k_min, tp) + tp_v_min = torch.chunk(v_min, tp) + + tp_k_max = torch.chunk(k_max, tp) + tp_v_max = torch.chunk(v_max, tp) + for i in range(tp): + # zp = (min+max) / 2 + # scale = (max-min) / 255 + # quant: q = (f-zp) / scale + # dequant: f = q * scale + zp + k_min = tp_k_min[i].min() + v_min = tp_v_min[i].min() + + k_max = tp_k_max[i].max() + v_max = tp_v_max[i].max() + + k_scale = (k_max - k_min) / (2**bits - 1) + v_scale = (v_max - v_min) / (2**bits - 1) + + k_zp = (k_max + k_min) / 2 + v_zp = (v_max + v_min) / 2 + + kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp], + dtype=np.float32) + out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' + kv_qparams.tofile(out_path) + print(f'Layer {layer_idx} MP {i} qparam: ' + f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}') + + +def main(work_dir: str, + kv_params_dir: str, + kv_bits: int = 8, + kv_sym: bool = False, + num_tp: int = 1) -> None: + """Main function to export key and value stats. + + Args: + work_dir (Union[str, Path]): Directory path where the stats are saved. + turbomind_dir (Union[str, Path]): Directory path where to + save the results. + kv_bits (int, optional): Number of bits for quantization. + Defaults to 8. + kv_sym (bool, optional): Whether to use symmetric quantizaiton. + Defaults to False. + num_tp (int, optional): Number of tensor parallelism. Defaults to 1. + """ + + work_dir = Path(work_dir) + + tm_dir = Path(kv_params_dir) + assert tm_dir.exists(), 'The specified TurboMind directory does not exist.' 
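+    # work_dir is expected to contain the key_stats.pth / value_stats.pth
+    # files produced by vllm.kv_quant.calibrate. An illustrative invocation
+    # (paths are placeholders):
+    #   python -m vllm.kv_quant.export_kv_params ./work_dir ./kv_params \
+    #       --kv_bits 8 --num_tp 1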
+ + key_stats = torch.load(work_dir / 'key_stats.pth') + value_stats = torch.load(work_dir / 'value_stats.pth') + + if kv_sym: + _export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp) + else: + _export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/vllm/kv_quant/observer.py b/vllm/kv_quant/observer.py new file mode 100644 index 000000000000..f36a63c0e0df --- /dev/null +++ b/vllm/kv_quant/observer.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Union +import torch +from torch import nn + + +class GlobalAvailMixin: + """Mixin class to make instances globally available.""" + + _instances: Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']] = { + 'default': {} + } + + def global_available(self, + key: Union[str, nn.Module] = 'default', + group: str = 'default') -> None: + """Make the instance globally available. + + Args: + key (Union[str, nn.Module], optional): Key to save the instance. + Defaults to 'default'. + group (str, optional): Group to save the instance. + Defaults to 'default'. + """ + self._save_instance(self, key, group) + + @classmethod + def _save_instance(cls, + instance: 'GlobalAvailMixin', + key: Union[str, nn.Module] = 'default', + group: str = 'default') -> None: + """Save the instance. + + Args: + instance (GlobalAvailMixin): Instance to save. + key (Union[str, nn.Module], optional): Key to save the instance. + Defaults to 'default'. + group (str, optional): Group to save the instance. + Defaults to 'default'. + """ + if group not in cls._instances: + assert isinstance(group, str) + cls._instances[group] = {} + + cls._instances[group][key] = instance + + @classmethod + def find(cls, + key: Union[str, nn.Module] = 'default', + group: str = 'default') -> Union[None, 'GlobalAvailMixin']: + """Find an instance by its key and group. + + Args: + key (Union[str, nn.Module], optional): Key of the instance. + Defaults to 'default'. + group (str, optional): Group of the instance. + Defaults to 'default'. + + Returns: + Union[None, GlobalAvailMixin]: The found instance, or None if + it does not exist. + """ + return cls._instances.get(group, {}).get(key) + + @classmethod + def find_group( + cls, + group: str) -> Dict[Union[str, nn.Module], 'GlobalAvailMixin']: + """Find all instances in a group. + + Args: + group (str): Group of the instances. + + Returns: + Dict[Union[str, nn.Module], GlobalAvailMixin]: All instances in + the group. + """ + return cls._instances.get(group, {}) + + @classmethod + def instances( + cls) -> Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']]: + """Get all instances.""" + return cls._instances + + +class KVCacheObserver(GlobalAvailMixin): + """A class to observe and record the max, min, and absolute max value of + given tensor.""" + + def __init__(self, num_head: int, head_dim: int) -> None: + """Constructor for KVCacheObserver. + + Args: + num_head : Number of heads + head_dim : Dimension of each head + """ + self.num_head = num_head + self.head_dim = head_dim + self.max_val = torch.full((num_head, head_dim), + -torch.inf, + dtype=torch.float16) + self.min_val = torch.full((num_head, head_dim), + torch.inf, + dtype=torch.float16) + self.absmax_val = torch.full((num_head, head_dim), + 0, + dtype=torch.float16) + + @torch.no_grad() + def observe(self, x: torch.Tensor) -> None: + """Function to observe the input tensor and update the max, min, and + absolute max values. 
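+        Accepts tensors laid out either as (bs, seqlen, heads, dims) or as
+        (bs, heads, seqlen, dims); the latter is transposed before the
+        per-head statistics are reduced over the batch and sequence axes.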
+ + Args: + x : Input tensor + """ + assert len(x.shape) == 4 + + if x.size(2) == self.num_head and x.size(3) == self.head_dim: + # layout: (bs, seqlen, heads, dims) + x = x + elif x.size(1) == self.num_head and x.size(3) == self.head_dim: + # layout: (bs, heads, seqlen, dims) + x = x.transpose(1, 2) + else: + raise RuntimeError + + cur_max = x.flatten(0, 1).max(0)[0].cpu() + cur_min = x.flatten(0, 1).min(0)[0].cpu() + cur_absmax = x.flatten(0, 1).abs().max(0)[0].cpu() + + self.max_val = torch.maximum(self.max_val, cur_max) + self.min_val = torch.minimum(self.min_val, cur_min) + self.absmax_val = torch.maximum(self.absmax_val, cur_absmax) + + +class ActivationObserver(GlobalAvailMixin): + """A class to observe and record the max, min, mean, absolute max, and + absolute mean value of a given tensor. + + Also keeps track of the number of batches observed. + """ + + def __init__(self, dim: int) -> None: + """Constructor for ActivationObserver. + + Args: + dim : Dimension of the tensor + """ + self.dim = dim + self.max_val = torch.full((dim, ), -torch.inf, dtype=torch.float16) + self.min_val = torch.full((dim, ), torch.inf, dtype=torch.float16) + self.absmax_val = torch.full((dim, ), 0, dtype=torch.float16) + self.absmean_val = torch.full((dim, ), 0, dtype=torch.float16) + self.mean_val = torch.full((dim, ), 0, dtype=torch.float16) + self.num_batches_tracked = 0 + + @torch.no_grad() + def observe(self, x: torch.Tensor) -> None: + """Function to observe the input tensor and update the max, min, mean, + absolute max, absolute mean values and number of batches tracked. + + Args: + x : Input tensor + """ + assert len(x.shape) == 3 + assert x.size(2) == self.dim + cur_val = x.flatten(0, 1) + cur_max = cur_val.max(0)[0].cpu() + cur_min = cur_val.min(0)[0].cpu() + cur_mean = cur_val.mean(0).cpu() + + cur_abs = cur_val.abs() + cur_absmax = cur_abs.max(0)[0].cpu() + cur_absmean = cur_abs.mean(0).cpu() + + self.max_val = torch.maximum(self.max_val, cur_max) + self.min_val = torch.minimum(self.min_val, cur_min) + self.absmax_val = torch.maximum(self.absmax_val, cur_absmax) + + # Update mean and absmean value with accumulated sum divided + # by total number of batches + self.mean_val = ( + (self.mean_val * self.num_batches_tracked + cur_mean) / + (self.num_batches_tracked + 1)) + self.absmean_val = ( + (self.absmean_val * self.num_batches_tracked + cur_absmean) / + (self.num_batches_tracked + 1)) + + # Increment the count of batches tracked + self.num_batches_tracked += 1 diff --git a/vllm/kv_quant/utils.py b/vllm/kv_quant/utils.py new file mode 100644 index 000000000000..309c48e3c213 --- /dev/null +++ b/vllm/kv_quant/utils.py @@ -0,0 +1,164 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List, Tuple, Union +import torch +from torch import nn + + +def split_decoder_layer_inputs( + *args: Union[torch.Tensor, Any], **kwargs: Union[torch.Tensor, Any] +) -> Tuple[List[List[Any]], List[Dict[str, Any]]]: + """This function splits batched decoder layer inputs into individual + elements. + + Args: + *args (Union[torch.Tensor, Any]): Positional arguments which could + be a mix of tensors and other types. + **kwargs (Union[torch.Tensor, Any]): Keyword arguments which could + be a mix of tensors and other types. + + Returns: + Tuple[List[List[Any]], List[Dict[str, Any]]]: A tuple containing two + lists, one for positional arguments, one for keyword arguments. + Each list contains individual elements from the batch. 
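+
+    Example (illustrative shapes):
+        hidden = torch.randn(4, 128, 4096)
+        mask = torch.ones(4, 1, 128, 128)
+        batch_args, batch_kwargs = split_decoder_layer_inputs(
+            hidden, attention_mask=mask)
+        # len(batch_args) == 4 and batch_args[0][0].shape == (1, 128, 4096);
+        # values whose first dimension is not the batch size are passed
+        # through unchanged.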
+ """ + + if not isinstance(args[0], torch.Tensor): + raise ValueError('The first argument must be a Tensor') + + bs = args[0].size(0) + + batch_args = [] + batch_kwargs = [] + for i in range(bs): + new_args = [] + # Iterate over each argument. If it's a torch.Tensor and its first + # dimension equals the batch size, then get the value corresponding + # to the current index, else directly add the whole value. + for val in args: + if isinstance(val, torch.Tensor) and val.size(0) == bs: + new_args.append(val[i:i + 1]) + else: + new_args.append(val) + + new_kwargs = {} + # Execute the same operation for the keyword arguments. + for name, val in kwargs.items(): + if isinstance(val, torch.Tensor) and val.size(0) == bs: + new_kwargs[name] = val[i:i + 1] + else: + new_kwargs[name] = val + + batch_args.append(new_args) + batch_kwargs.append(new_kwargs) + + return batch_args, batch_kwargs + + +def concat_decoder_layer_outputs( + batch_outputs: List[Tuple[Any]]) -> Tuple[Any]: + """This function concatenates individual decoder layer outputs into a + batched output. + + Args: + batch_outputs (List[Tuple[Any]]): A list of tuples, where each tuple + represents the output from an individual element in the batch. + + Returns: + Tuple[Any]: A tuple representing the batched output. + """ + + num_returns = len(batch_outputs[0]) + + def is_past_key_value(data: Any) -> bool: + """Check whether data is a past key-value pair. + + Args: + data (Any): The data to check. + + Returns: + bool: True if data is a past key-value pair, False otherwise. + """ + flag = isinstance(data, tuple) + flag = flag and len(data) == 2 + flag = flag and isinstance(data[0], torch.Tensor) + flag = flag and isinstance(data[1], torch.Tensor) + return flag + + new_outputs = [] + + # Iterate over all types of return values. + for i in range(num_returns): + # Check if the current element is a past key-value pair. + flag = is_past_key_value(batch_outputs[0][i]) + if flag: + # Concatenate the keys and values separately. + key = torch.cat([out[i][0] for out in batch_outputs]) + value = torch.cat([out[i][1] for out in batch_outputs]) + out_i = (key, value) + else: + # If it's not a past key-value pair, concatenate directly. + out_i = torch.cat([out[i] for out in batch_outputs]) + new_outputs.append(out_i) + + return tuple(new_outputs) + + +def collect_target_modules(model: nn.Module, + # target: Union[str, type], + target: str, + skip_names: List[str] = [], + prefix: str = '') -> Dict[str, nn.Module]: + """Collects the specific target modules from the model. + + Args: + model : The PyTorch module from which to collect the target modules. + target : The specific target to be collected. It can be a class of a + module or the name of a module. + skip_names : List of names of modules to be skipped during collection. + prefix : A string to be added as a prefix to the module names. + + Returns: + A dictionary mapping from module names to module instances. 
+ """ + + # if isinstance(target, LazyAttr): + # target = target.build() + + if not isinstance(target, (type, str)): + raise TypeError('Target must be a string (name of the module) ' + 'or a type (class of the module)') + + def _is_target(n, m): + if isinstance(target, str): + return target == type(m).__name__ and n not in skip_names + return isinstance(m, target) and n not in skip_names + + name2mod = {} + for name, mod in model.named_modules(): + m_name = f'{prefix}.{name}' if prefix else name + if _is_target(name, mod): + name2mod[m_name] = mod + return name2mod + + +def bimap_name_mod( + name2mod_mappings: List[Dict[str, nn.Module]] +) -> Tuple[Dict[str, nn.Module], Dict[nn.Module, str]]: + """Generates bidirectional maps from module names to module instances and + vice versa. + + Args: + name2mod_mappings : List of dictionaries each mapping from module + names to module instances. + + Returns: + Two dictionaries providing bidirectional mappings between module + names and module instances. + """ + + name2mod = {} + mod2name = {} + for mapping in name2mod_mappings: + mod2name.update({v: k for k, v in mapping.items()}) + name2mod.update(mapping) + return name2mod, mod2name From 9f872d9124e31dd2b6f5d77d9db2cbe2064a16a6 Mon Sep 17 00:00:00 2001 From: Lin Pengyun Date: Thu, 28 Sep 2023 14:37:40 +0800 Subject: [PATCH 25/52] modify test functions --- benchmarks/benchmark_evaluation.py | 9 +++------ tests/kernels/test_cache.py | 18 +++++++++++------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmark_evaluation.py b/benchmarks/benchmark_evaluation.py index 4ac9af033098..7bdc92b53fe0 100644 --- a/benchmarks/benchmark_evaluation.py +++ b/benchmarks/benchmark_evaluation.py @@ -28,7 +28,7 @@ def sample_requests( subjects: List[str], dataset_template: str = "mmlu", is_analyse: bool = False, -) -> List[Tuple[str, int, int]]: +) -> Tuple[List[str], List[str], List[int]]: # Load the dataset. 
nums_questions = [] dataset = [] @@ -110,7 +110,7 @@ def evalute( def main(args: argparse.Namespace): subjects = [ - "abstract_algebra", + "college_computer_science", ] dataset, labels, nums_questions = sample_requests( args.dev_data_path, @@ -130,9 +130,6 @@ def main(args: argparse.Namespace): args.use_beam_search, args.trust_remote_code, ) - foo = request_outputs[0] - print(foo.outputs[0].text) - assert False sub2acc = evalute( request_outputs, labels, @@ -173,7 +170,7 @@ def main(args: argparse.Namespace): help="nums of max token for evaluation outputs") parser.add_argument("--kv-cache-dtype", type=str, - default="int8") + default="float16") parser.add_argument("--kv-quant-params-path", type=str, default=None) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 476007249ac2..f277e8770c7d 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -186,7 +186,7 @@ def test_reshape_and_cache_quantized( device='cuda') _, key, value = qkv.unbind(dim=1) - x = 16 // torch.tensor([], dtype=dtype).element_size() + x = 16 // torch.tensor([], dtype=torch.int8).element_size() key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_cache = torch.randint(-10, 10, size=key_cache_shape, dtype=torch.int8, device='cuda') ## change to int8 cloned_key_cache = key_cache.clone() @@ -201,11 +201,15 @@ def test_reshape_and_cache_quantized( slot_mapping, k_scale, k_zp, v_scale, v_zp) lower_bound, upper_bound = torch.tensor([-128.0], dtype=dtype, device='cuda'), torch.tensor([127.0], dtype=dtype, device='cuda') ## quantize and store here - reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x) - reshaped_key = torch.maximum(lower_bound, torch.minimum(upper_bound, (reshaped_key - k_zp) / k_scale)) - reshaped_key = torch.round(reshaped_key) - reshaped_key = reshaped_key.to(torch.int8) ## change to int8 - quantized_value = torch.maximum(lower_bound, torch.minimum(upper_bound, (value - v_zp) / v_scale)) + ## quantize and store here + quantized_key = key.reshape(num_tokens, num_heads, head_size // x, x) + quantized_key = quantized_key.to(torch.float32) + quantized_key = torch.maximum(lower_bound, torch.minimum(upper_bound, (quantized_key - k_zp) / k_scale)) + quantized_key = torch.round(quantized_key) + quantized_key = quantized_key.to(torch.int8) ## change to int8 + + quantized_value = value.to(torch.float32) + quantized_value = torch.maximum(lower_bound, torch.minimum(upper_bound, (quantized_value - v_zp) / v_scale)) quantized_value = torch.round(quantized_value) quantized_value = quantized_value.to(torch.int8) @@ -214,7 +218,7 @@ def test_reshape_and_cache_quantized( block_size, rounding_mode='floor') block_offset = slot_mapping[i] % block_size - cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] + cloned_key_cache[block_idx, :, :, block_offset, :] = quantized_key[i] cloned_value_cache[block_idx, :, :, block_offset] = quantized_value[i] assert torch.allclose(key_cache, cloned_key_cache) From 892c589f73e486893d70069d0c3c9c762d0d25d5 Mon Sep 17 00:00:00 2001 From: Lin Pengyun Date: Thu, 28 Sep 2023 14:48:01 +0800 Subject: [PATCH 26/52] fix test code --- tests/kernels/test_cache.py | 126 ++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 63 deletions(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index f277e8770c7d..baa90dc675b9 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -24,69 +24,69 @@ SEEDS = [0] -# 
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) -# @pytest.mark.parametrize("num_layers", NUM_LAYERS) -# @pytest.mark.parametrize("num_heads", NUM_HEADS) -# @pytest.mark.parametrize("head_size", HEAD_SIZES) -# @pytest.mark.parametrize("block_size", BLOCK_SIZES) -# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -# @pytest.mark.parametrize("dtype", DTYPES) -# @pytest.mark.parametrize("seed", SEEDS) -# @torch.inference_mode() -# def test_copy_blocks( -# kv_cache_factory, -# num_mappings: int, -# num_layers: int, -# num_heads: int, -# head_size: int, -# block_size: int, -# num_blocks: int, -# dtype: torch.dtype, -# seed: int, -# ) -> None: -# random.seed(seed) -# torch.random.manual_seed(seed) -# torch.cuda.manual_seed(seed) - -# # Generate random block mappings where each source block is mapped to two -# # destination blocks. -# assert 2 * num_mappings <= num_blocks -# src_blocks = random.sample(range(num_blocks), num_mappings) -# remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) -# dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) -# block_mapping = {} -# for i in range(num_mappings): -# src = src_blocks[i] -# dst1 = dst_blocks[2 * i] -# dst2 = dst_blocks[2 * i + 1] -# block_mapping[src] = [dst1, dst2] - -# # Create the KV caches. -# key_caches, value_caches = kv_cache_factory(num_blocks, block_size, -# num_layers, num_heads, -# head_size, dtype, seed) - -# # Clone the KV caches. -# cloned_key_caches = [key_cache.clone() for key_cache in key_caches] -# cloned_value_caches = [value_cache.clone() for value_cache in value_caches] - -# # Call the copy blocks kernel. -# cache_ops.copy_blocks(key_caches, value_caches, block_mapping) - -# # Run the reference implementation. -# for src, dsts in block_mapping.items(): -# for dst in dsts: -# for cloned_key_cache in cloned_key_caches: -# cloned_key_cache[dst] = cloned_key_cache[src] -# for cloned_value_cache in cloned_value_caches: -# cloned_value_cache[dst] = cloned_value_cache[src] - -# # Compare the results. -# for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): -# assert torch.allclose(key_cache, cloned_key_cache) -# for value_cache, cloned_value_cache in zip(value_caches, -# cloned_value_caches): -# assert torch.allclose(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) +@pytest.mark.parametrize("num_layers", NUM_LAYERS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_copy_blocks( + kv_cache_factory, + num_mappings: int, + num_layers: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Generate random block mappings where each source block is mapped to two + # destination blocks. + assert 2 * num_mappings <= num_blocks + src_blocks = random.sample(range(num_blocks), num_mappings) + remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) + dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) + block_mapping = {} + for i in range(num_mappings): + src = src_blocks[i] + dst1 = dst_blocks[2 * i] + dst2 = dst_blocks[2 * i + 1] + block_mapping[src] = [dst1, dst2] + + # Create the KV caches. 
+ key_caches, value_caches = kv_cache_factory(num_blocks, block_size, + num_layers, num_heads, + head_size, dtype, seed) + + # Clone the KV caches. + cloned_key_caches = [key_cache.clone() for key_cache in key_caches] + cloned_value_caches = [value_cache.clone() for value_cache in value_caches] + + # Call the copy blocks kernel. + cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + # Run the reference implementation. + for src, dsts in block_mapping.items(): + for dst in dsts: + for cloned_key_cache in cloned_key_caches: + cloned_key_cache[dst] = cloned_key_cache[src] + for cloned_value_cache in cloned_value_caches: + cloned_value_cache[dst] = cloned_value_cache[src] + + # Compare the results. + for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): + assert torch.allclose(key_cache, cloned_key_cache) + for value_cache, cloned_value_cache in zip(value_caches, + cloned_value_caches): + assert torch.allclose(value_cache, cloned_value_cache) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) From 538947dda996f09b889e58fd0d0bc07ba3b5b993 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Thu, 28 Sep 2023 15:55:36 +0800 Subject: [PATCH 27/52] fix test attention --- tests/kernels/test_attention.py | 101 +++++++++++++------------------- 1 file changed, 42 insertions(+), 59 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 4d575428d646..ba7bfb1ef8a3 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -91,14 +91,6 @@ def ref_single_query_cached_kv_attention( out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) - -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("use_alibi", USE_ALIBI) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) def ref_single_query_cached_kv_attention_quantized( output: torch.Tensor, query: torch.Tensor, @@ -238,6 +230,13 @@ def ref_multi_query_cached_kv_attention( return ref_output +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_single_query_cached_kv_attention( kv_cache_factory, @@ -470,7 +469,42 @@ def run_single_query_cached_kv_attention_quantized( # We should use a relaxed tolerance for the test. 
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) +def test_single_query_cached_kv_attention_quantized() -> None: + # FIXME: set TEST_SEED + torch.random.manual_seed(0) + torch.cuda.manual_seed(0) + for dtype in [ + torch.half, + torch.bfloat16, + torch.float, + ]: + for block_size in [8, + 16, + ]: + for head_size in [64, + 80, + 96, + 112, + 128, + 256, + ]: + print(f'Testing single_query_cached_kv_attention with ' + f'dtype={dtype}, block_size={block_size}, ' + f'head_size={head_size}') + run_single_query_cached_kv_attention_quantized( + num_tokens=37, + num_heads=3, + head_size=head_size, + block_size=block_size, + num_blocks=1024, + dtype=dtype, + ) +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def run_multi_query_kv_attention( num_seqs: int, @@ -526,57 +560,6 @@ def run_multi_query_kv_attention( ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - -def test_single_query_cached_kv_attention() -> None: - torch.random.manual_seed(TEST_SEED) - torch.cuda.manual_seed(TEST_SEED) - for dtype in [torch.half, torch.bfloat16, torch.float]: - for block_size in [8, 16, 32]: - for head_size in [64, 80, 96, 112, 128, 256]: - print(f'Testing single_query_cached_kv_attention with ' - f'dtype={dtype}, block_size={block_size}, ' - f'head_size={head_size}') - run_single_query_cached_kv_attention( - num_tokens=37, - num_heads=3, - head_size=head_size, - block_size=block_size, - num_blocks=1024, - dtype=dtype, - ) - - -def test_single_query_cached_kv_attention_quantized() -> None: - torch.random.manual_seed(TEST_SEED) - torch.cuda.manual_seed(TEST_SEED) - for dtype in [ - torch.half, - torch.bfloat16, - torch.float, - ]: - for block_size in [8, - 16, - ]: - for head_size in [64, - 80, - 96, - 112, - 128, - 256, - ]: - print(f'Testing single_query_cached_kv_attention with ' - f'dtype={dtype}, block_size={block_size}, ' - f'head_size={head_size}') - run_single_query_cached_kv_attention_quantized( - num_tokens=37, - num_heads=3, - head_size=head_size, - block_size=block_size, - num_blocks=1024, - dtype=dtype, - ) - - def test_multi_query_kv_attention() -> None: torch.random.manual_seed(TEST_SEED) torch.cuda.manual_seed(TEST_SEED) From bf3eb58ac02f646977ebbbc05e8f7d20d1541d56 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Thu, 28 Sep 2023 16:34:25 +0800 Subject: [PATCH 28/52] evaluation support quant --- benchmarks/benchmark_evaluation.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/benchmarks/benchmark_evaluation.py b/benchmarks/benchmark_evaluation.py index 7bdc92b53fe0..00958cbb14c7 100644 --- a/benchmarks/benchmark_evaluation.py +++ b/benchmarks/benchmark_evaluation.py @@ -57,6 +57,7 @@ def run_vllm( n: int = 1, use_beam_search: bool = False, trust_remote_code: bool = False, + quantmethod: str = None, ) -> List[RequestOutput]: llm = LLM( model=model, @@ -66,6 +67,7 @@ def run_vllm( trust_remote_code=trust_remote_code, kv_cache_dtype=kv_cache_dtype, kv_quant_params_path=kv_quant_params_path, + quantization = quantmethod ) for prompt in requests: sampling_params = SamplingParams( @@ -110,7 +112,15 @@ def evalute( def main(args: argparse.Namespace): subjects = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", "college_computer_science", + 
"college_mathematics", ] dataset, labels, nums_questions = sample_requests( args.dev_data_path, @@ -129,6 +139,7 @@ def main(args: argparse.Namespace): args.seed, args.n, args.use_beam_search, args.trust_remote_code, + args.quantization ) sub2acc = evalute( request_outputs, @@ -174,5 +185,8 @@ def main(args: argparse.Namespace): parser.add_argument("--kv-quant-params-path", type=str, default=None) + parser.add_argument("--quantization", + type=str, + default=None) args = parser.parse_args() main(args) From a0be417c0235e30070dff940c750721383699c6b Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Thu, 28 Sep 2023 10:53:54 +0800 Subject: [PATCH 29/52] fuse dequant silu and quant --- csrc/activation.cpp | 33 ++++------ csrc/activation_kernels.cu | 131 +++++++++++++++++++++---------------- 2 files changed, 86 insertions(+), 78 deletions(-) diff --git a/csrc/activation.cpp b/csrc/activation.cpp index c100f89ac737..76dda452e822 100644 --- a/csrc/activation.cpp +++ b/csrc/activation.cpp @@ -1,28 +1,19 @@ #include -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); +void silu_and_mul(torch::Tensor &out, torch::Tensor &input); -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); +void gelu_new(torch::Tensor &out, torch::Tensor &input); -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); +void gelu_fast(torch::Tensor &out, torch::Tensor &input); + +void invoke_dequant_silu_and_mul_quant(torch::Tensor &out, torch::Tensor &input, + const float scale_gate, + const float scale_up, + const float scale_out); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - m.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - m.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + m.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + m.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); + m.def("invoke_dequant_silu_and_mul_quant", &invoke_dequant_silu_and_mul_quant, "Dequant input, apply silu act and quant output"); } diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index c6ae5db8f9c4..23969b3a56bb 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,21 +1,21 @@ -#include #include +#include #include "dispatch_utils.h" +#include "quant_utils.cuh" namespace vllm { -template -__device__ __forceinline__ T silu(const T& x) { +template __device__ __forceinline__ T silu(const T &x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } -template -__global__ void silu_and_mul_kernel( - scalar_t* __restrict__ out, // [num_tokens, d] - const scalar_t* __restrict__ input, // [num_tokens, 2, d] - const int d) { +template +__global__ void +silu_and_mul_kernel(scalar_t *__restrict__ out, // [num_tokens, d] + const scalar_t *__restrict__ input, // [num_tokens, 2, d] + const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); @@ -24,11 +24,24 @@ __global__ void silu_and_mul_kernel( } } +__global__ void dequant_silu_and_mul_quant_kernel( + int8_t *__restrict__ out, // [num_tokens, d] + const int32_t *__restrict__ input, // [num_tokens, 2, d] + const int d, const float scale_gate, const float scale_up, + const float scale_out) { + const int 
token_idx = blockIdx.x; + for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { + const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate; + const float y = + (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up; + out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out); + } +} + } // namespace vllm -void silu_and_mul( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, 2 * d] +void silu_and_mul(torch::Tensor &out, // [num_tokens, d] + torch::Tensor &input) // [num_tokens, 2 * d] { int num_tokens = input.size(0); int d = input.size(1) / 2; @@ -36,25 +49,35 @@ void silu_and_mul( dim3 grid(num_tokens); dim3 block(std::min(d, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "silu_and_mul_kernel", - [&] { - vllm::silu_and_mul_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - d); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] { + vllm::silu_and_mul_kernel<<>>( + out.data_ptr(), input.data_ptr(), d); + }); +} + +void invoke_dequant_silu_and_mul_quant(torch::Tensor &out, torch::Tensor &input, + const float scale_gate, + const float scale_up, + const float scale_out) { + int num_tokens = input.size(0); + int d = input.size(1) / 2; + + dim3 grid(num_tokens); + dim3 block(std::min(d, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::dequant_silu_and_mul_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), d, scale_gate, + scale_up, scale_out); } namespace vllm { // Element-wise activation kernel template. -template -__global__ void activation_kernel( - scalar_t* __restrict__ out, // [num_tokens, d] - const scalar_t* __restrict__ input, // [num_tokens, d] - const int d) { +template +__global__ void +activation_kernel(scalar_t *__restrict__ out, // [num_tokens, d] + const scalar_t *__restrict__ input, // [num_tokens, d] + const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * d + idx]); @@ -65,50 +88,44 @@ __global__ void activation_kernel( } // namespace vllm // Launch element-wise activation kernel. 
-#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int num_tokens = input.size(0); \ - int d = input.size(1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "activation_kernel", \ - [&] { \ - vllm::activation_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int num_tokens = input.size(0); \ + int d = input.size(1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); namespace vllm { -template -__device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float) (x * x * x); - const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); - return ((T) 0.5) * x * (((T) 1.0) + t); +template __device__ __forceinline__ T gelu_new_kernel(const T &x) { + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); } -template -__device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float) x; - const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); - return ((T) 0.5) * x * (((T) 1.0) + t); +template +__device__ __forceinline__ T gelu_fast_kernel(const T &x) { + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); } } // namespace vllm -void gelu_new( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, d] +void gelu_new(torch::Tensor &out, // [num_tokens, d] + torch::Tensor &input) // [num_tokens, d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, d] +void gelu_fast(torch::Tensor &out, // [num_tokens, d] + torch::Tensor &input) // [num_tokens, d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } From 52af06e6ac67f6782a76f1001cd284fdff7d9aad Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Thu, 28 Sep 2023 10:54:38 +0800 Subject: [PATCH 30/52] fuse dequant and add residual --- csrc/fused.cpp | 12 ++++++++++++ csrc/fused_kernels.cu | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 csrc/fused.cpp create mode 100644 csrc/fused_kernels.cu diff --git a/csrc/fused.cpp b/csrc/fused.cpp new file mode 100644 index 000000000000..eeede6f3bd94 --- /dev/null +++ b/csrc/fused.cpp @@ -0,0 +1,12 @@ +#include + +void invoke_dequant_add_residual( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &residual, // [num_tokens, hidden_size] + float scale); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("invoke_dequant_add_residual", &invoke_dequant_add_residual, + "Add the dequanted result and residual."); +} diff --git a/csrc/fused_kernels.cu b/csrc/fused_kernels.cu new file mode 100644 index 000000000000..a085a0718e2d --- /dev/null +++ b/csrc/fused_kernels.cu @@ -0,0 +1,38 @@ +#include +#include + +#include "dispatch_utils.h" + +namespace vllm { +template +__global__ void dequant_add_residual_kernel(const int32_t *__restrict__ input, + const T 
*__restrict__ residual, + T *__restrict__ output, + const float scale, int m, int n) { + const int tid = threadIdx.x; + for (int i = tid; i < n; i += blockDim.x) { + output[blockIdx.x * n + i] = + (T)((((float)input[blockIdx.x * n + i]) * scale) + + (float)residual[blockIdx.x * n + i]); + } +} +} // namespace vllm + +void invoke_dequant_add_residual( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &residual, // [num_tokens, hidden_size] + float scale) { + int m = input.size(0); + int n = input.size(1); + dim3 grid(m); + dim3 block(min(n, 1024)); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + residual.scalar_type(), "dequant_add_residual_kernel", [&] { + vllm::dequant_add_residual_kernel<<>>( + input.data_ptr(), residual.data_ptr(), + out.data_ptr(), scale, m, n); + }); +} From 627b766953da9ef5e9704cc6937db3619b428583 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Thu, 28 Sep 2023 10:55:50 +0800 Subject: [PATCH 31/52] fuse dequant, add residual, rms_norm and quant --- csrc/layernorm.cpp | 11 +++++ csrc/layernorm_kernels.cu | 91 +++++++++++++++++++++++++-------------- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/csrc/layernorm.cpp b/csrc/layernorm.cpp index c1092b9583e7..c9f917e1aab8 100644 --- a/csrc/layernorm.cpp +++ b/csrc/layernorm.cpp @@ -8,10 +8,21 @@ void invoke_rms_norm_quant(torch::Tensor &out, // [num_tokens, hidden_size] torch::Tensor &gamma, // [hidden_size] float epsilon); +void invoke_dequant_add_residual_rms_norm_quant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &residual, // [num_tokens, hidden_size] + torch::Tensor &gamma, // [hidden_size] + float epsilon, float scale); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rms_norm", &rms_norm, "Apply Root Mean Square (RMS) Normalization to the input tensor."); m.def("invoke_rms_norm_quant", &invoke_rms_norm_quant, "Apply Root Mean Square (RMS) Normalization to the input tensor and " "quant output."); + m.def("invoke_dequant_add_residual_rms_norm_quant", + &invoke_dequant_add_residual_rms_norm_quant, + "Add the dequanted result and residual, then use RMS norm and quant " + "output."); } diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index a20b9d4ddbe0..205016021a92 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -36,9 +36,10 @@ rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size] } template -__global__ void RMSLayerNorm(const T *__restrict input, - const T *__restrict gamma, int8_t *output, - const float layernorm_eps, int m, int n) { +__global__ void rms_norm_quant_kernel(const T *__restrict__ input, + const T *__restrict__ gamma, + int8_t *__restrict__ output, + const float layernorm_eps, int m, int n) { // layernorm module in the T5 style No bias and no subtraction of mean. 
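   // (Per token: out[i] = round_to_int8(in[i] * rsqrt(mean(in^2) + eps) * gamma[i]),
   //  where the mean over the hidden dimension is computed with a block-wide
   //  reduction.)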
const int tid = threadIdx.x; @@ -59,40 +60,40 @@ __global__ void RMSLayerNorm(const T *__restrict input, __syncthreads(); for (int i = tid; i < n; i += blockDim.x) { - output[blockIdx.x * n + i] = - // float_to_int8_rn((((float)input[blockIdx.x * n + i]) * s_variance) * - // (float)(ldg(&gamma[i]))); - float_to_int8_rn((((float)input[blockIdx.x * n + i]) * s_variance) * - (float)(gamma[i])); + output[blockIdx.x * n + i] = float_to_int8_rn( + (((float)input[blockIdx.x * n + i]) * s_variance) * (float)(gamma[i])); } } template -void invokeRMSLayerNorm(int8_t *out, const T *input, const T *gamma, - // const T* beta, - const float layernorm_eps, const int m, const int n, - cudaStream_t stream) { - // if (beta != nullptr) { - // invokeGeneralLayerNorm(out, input, gamma, beta, layernorm_eps, m, n, - // (float*)nullptr, 0, stream); return; - // } +__global__ void dequant_add_residual_rms_norm_quant_kernel( + const int32_t *__restrict__ input, const T *__restrict__ residual, + int8_t *__restrict__ output, const T *__restrict__ gamma, + const float layernorm_eps, const float scale, int m, int n) { + // layernorm module in the T5 style No bias and no subtraction of mean. + const int tid = threadIdx.x; - dim3 grid(m); - dim3 block(min(n, 1024)); + __shared__ float s_variance; + float variance = 0.0f; - /* For general cases, n is equal to hidden_units, e.g., 512/1024. - Since we have warp shuffle inside the code, block.x % 32 should be 0. - */ - if (n % 32 != 0) { - block.x = 1024; + float local_var_sum = 0.0f; + for (int i = tid; i < n; i += blockDim.x) { + float diff = (((float)input[blockIdx.x * n + i]) * scale) + + (float)residual[blockIdx.x * n + i]; + // float diff = (float)(input[blockIdx.x * n + i]); + local_var_sum += diff * diff; } + variance = blockReduceSum(local_var_sum); - block.x = - block.x / (4 / sizeof(T)); // if using half, only need half of block.x + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / (float)n + layernorm_eps); + } + __syncthreads(); - /* should pay attention to the rsqrt precision*/ - RMSLayerNorm<<>>(input, gamma, out, layernorm_eps, - m, n); // For gpt-3 + for (int i = tid; i < n; i += blockDim.x) { + output[blockIdx.x * n + i] = float_to_int8_rn( + (((float)input[blockIdx.x * n + i]) * s_variance) * (float)(gamma[i])); + } } } // namespace vllm @@ -124,9 +125,33 @@ void invoke_rms_norm_quant(torch::Tensor &out, // [num_tokens, hidden_size] dim3 block(min(n, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "invokeRMSLayerNorm", [&] { - vllm::RMSLayerNorm<<>>( - input.data_ptr(), gamma.data_ptr(), out.data_ptr(), - epsilon, m, n); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_quant_kernel", [&] { + vllm::rms_norm_quant_kernel<<>>( + input.data_ptr(), gamma.data_ptr(), + out.data_ptr(), epsilon, m, n); + }); +} + +void invoke_dequant_add_residual_rms_norm_quant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &residual, // [num_tokens, hidden_size] + torch::Tensor &gamma, // [hidden_size] + float epsilon, float scale) { + int m = input.size(0); + int n = input.size(1); + dim3 grid(m); + dim3 block(min(n, 1024)); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel", + [&] { + vllm::dequant_add_residual_rms_norm_quant_kernel + <<>>( + input.data_ptr(), residual.data_ptr(), + 
out.data_ptr(), gamma.data_ptr(), epsilon, + scale, m, n); + }); } From dfc957252f7ebd473a8759f8166fa5b6959d5f18 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Thu, 28 Sep 2023 10:56:49 +0800 Subject: [PATCH 32/52] fuse dequant and pos_encoding --- csrc/pos_encoding.cpp | 27 ++++--- csrc/pos_encoding_kernels.cu | 138 +++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 11 deletions(-) diff --git a/csrc/pos_encoding.cpp b/csrc/pos_encoding.cpp index eee0cf0d0fa0..60930025c988 100644 --- a/csrc/pos_encoding.cpp +++ b/csrc/pos_encoding.cpp @@ -1,16 +1,21 @@ #include -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox); +void invoke_dequant_rotary_embedding( + torch::Tensor &positions, // [num_tokens] + torch::Tensor &query, // [num_tokens, num_heads * head_size] + torch::Tensor &query_out, // [num_tokens, num_heads * head_size] + torch::Tensor &key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor &key_out, // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor &cos_sin_cache, // [max_position, rot_dim] + const float query_scale, const float key_scale, bool is_neox); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + m.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + m.def("invoke_dequant_rotary_embedding", &invoke_dequant_rotary_embedding, + "Dequant the input and apply rotary embedding."); } diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index b4351ee0d794..c12fd49fb26c 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -35,6 +35,38 @@ inline __device__ void apply_rotary_embedding( arr[y_index] = y * cos + x * sin; } +template +inline __device__ void apply_dequant_rotary_embedding( + const int32_t* __restrict__ arr, + scalar_t* __restrict__ arr_out, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim, + const float scale) +{ + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = __ldg(cos_ptr + x_index); + sin = __ldg(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. 
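+    // (Pairs adjacent elements (x_{2k}, x_{2k+1}); the NeoX branch above
+    //  instead pairs element k with element k + embed_dim.)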
+ x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = __ldg(cos_ptr + x_index / 2); + sin = __ldg(sin_ptr + x_index / 2); + } + + const scalar_t x = (scalar_t)((float)arr[x_index] * scale); + const scalar_t y = (scalar_t)((float)arr[y_index] * scale); + arr_out[x_index] = x * cos - y * sin; + arr_out[y_index] = y * cos + x * sin; +} + template __global__ void rotary_embedding_kernel( const int64_t* __restrict__ positions, // [num_tokens] @@ -75,6 +107,50 @@ __global__ void rotary_embedding_kernel( } } +template +__global__ void dequant_rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [num_tokens] + const int32_t* __restrict__ query, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ query_out, // [num_tokens, num_heads, head_size] + const int32_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] + scalar_t* __restrict__ key_out, // [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int query_stride, + const int key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size, + const float query_scale, + const float key_scale) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_dequant_rotary_embedding(query + token_head, query_out + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim, query_scale); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_dequant_rotary_embedding(key + token_head, key_out + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim, key_scale); + } +} + } // namespace vllm void rotary_embedding( @@ -125,3 +201,65 @@ void rotary_embedding( } }); } + + +void invoke_dequant_rotary_embedding( + torch::Tensor& positions, // [num_tokens] + torch::Tensor& query, // [num_tokens, num_heads * head_size] + torch::Tensor& query_out, // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor& key_out, // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + const float query_scale, + const float key_scale, + bool is_neox) { + int num_tokens = query.size(0); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(1) / head_size; + int num_kv_heads = key.size(1) / head_size; + int query_stride = query.stride(0); + int key_stride = key.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), + "dequant_rotary_embedding_kernel", + [&] { + if (is_neox) { + vllm::dequant_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + query_out.data_ptr(), + key.data_ptr(), + key_out.data_ptr(), + 
cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size, + query_scale, + key_scale); + } else { + vllm::dequant_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + query_out.data_ptr(), + key.data_ptr(), + key_out.data_ptr(), + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size, + query_scale, + key_scale); + } + }); +} From e025b664caa365f756f9554bcfe393b29a1a2fc3 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Thu, 28 Sep 2023 10:57:27 +0800 Subject: [PATCH 33/52] setup for fused kernels --- setup.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/setup.py b/setup.py index 8c886e7a892a..079a816bff45 100644 --- a/setup.py +++ b/setup.py @@ -138,6 +138,17 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ) ext_modules.append(cache_extension) +# Fuse kernels. +fused_extension = CUDAExtension( + name="vllm.fused_kernels", + sources=["csrc/fused.cpp", "csrc/fused_kernels.cu"], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, +) +ext_modules.append(fused_extension) + # Attention kernels. attention_extension = CUDAExtension( name="vllm.attention_ops", From 9eba3c389ee95bd538ea66c347d6fda8e49b8bda Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Thu, 28 Sep 2023 18:31:26 +0800 Subject: [PATCH 34/52] fix bugs --- csrc/layernorm_kernels.cu | 11 ++++++----- csrc/pos_encoding_kernels.cu | 5 +++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 205016021a92..42999f4dc2a8 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -78,9 +78,8 @@ __global__ void dequant_add_residual_rms_norm_quant_kernel( float local_var_sum = 0.0f; for (int i = tid; i < n; i += blockDim.x) { - float diff = (((float)input[blockIdx.x * n + i]) * scale) + - (float)residual[blockIdx.x * n + i]; - // float diff = (float)(input[blockIdx.x * n + i]); + float diff = ((((float)input[blockIdx.x * n + i]) * scale) + + (float)residual[blockIdx.x * n + i]); local_var_sum += diff * diff; } variance = blockReduceSum(local_var_sum); @@ -91,8 +90,10 @@ __global__ void dequant_add_residual_rms_norm_quant_kernel( __syncthreads(); for (int i = tid; i < n; i += blockDim.x) { - output[blockIdx.x * n + i] = float_to_int8_rn( - (((float)input[blockIdx.x * n + i]) * s_variance) * (float)(gamma[i])); + float tmp = ((((float)input[blockIdx.x * n + i]) * scale) + + (float)residual[blockIdx.x * n + i]); + output[blockIdx.x * n + i] = + float_to_int8_rn((tmp * s_variance) * (float)(gamma[i])); } } diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index c12fd49fb26c..7f62fbcdfd51 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -220,12 +220,13 @@ void invoke_dequant_rotary_embedding( int num_kv_heads = key.size(1) / head_size; int query_stride = query.stride(0); int key_stride = key.stride(0); - + std::cout << rot_dim << std::endl; + std::cout << query_stride << std::endl; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), + query_out.scalar_type(), "dequant_rotary_embedding_kernel", [&] { if (is_neox) { From 1e603485214a8f75d6b15ebb63001a35c22298e0 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Mon, 9 Oct 2023 19:54:58 +0800 Subject: [PATCH 35/52] add tests for fusion kernels --- 
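Note: the new activation test checks the fused kernel against plain PyTorch
ops. A minimal sketch of the reference arithmetic it relies on (the helper
name below is illustrative and is not part of this patch):

    import torch
    import torch.nn.functional as F

    def ref_dequant_silu_and_mul_quant(x_int32: torch.Tensor,
                                       scale_gate: float,
                                       scale_up: float,
                                       scale_out: float) -> torch.Tensor:
        # x_int32 holds the int32 outputs of the gate/up projections,
        # concatenated along the last dimension.
        d = x_int32.shape[-1] // 2
        gate = x_int32[..., :d].float() * scale_gate   # dequantize gate half
        up = x_int32[..., d:].float() * scale_up       # dequantize up half
        y = F.silu(gate) * up                          # SwiGLU activation
        # Requantize the activation output to int8 with saturation.
        return (y / scale_out).round().clamp(-128, 127).to(torch.int8)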
tests/kernels/test_activation.py | 37 +++++++++++ tests/kernels/test_layernorm.py | 98 +++++++++++++++++++++++++++++- tests/kernels/test_pos_encoding.py | 73 +++++++++++++++++++++- 3 files changed, 205 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 8aa35d2b2340..a1b485ec71e8 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -9,6 +9,9 @@ NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing SEEDS = [0] +SCALE_UP = [0.09, 1.2, 1.9] +SCALE_GATE = [2.17, 1.2, 1.9] +SCALE_OUT = [1.2, 1.9, 0.17] def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor: @@ -36,6 +39,40 @@ def test_silu_and_mul( assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale_gate", SCALE_GATE) +@pytest.mark.parametrize("scale_up", SCALE_UP) +@pytest.mark.parametrize("scale_out", SCALE_OUT) +@torch.inference_mode() +def test_dequant_silu_and_mul_quant( + num_tokens: int, + d: int, + dtype: torch.dtype, + seed: int, + scale_gate: float, + scale_up: float, + scale_out: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + # x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda') + x = torch.randint(-1000, 1000, (num_tokens, 2 * d), dtype=torch.int32, device='cuda') + x_ = torch.empty_like(x, dtype=dtype) + x_[:, :d] = x[:, :d] * scale_gate + x_[:, d:] = x[:, d:] * scale_up + out1 = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + activation_ops.silu_and_mul(out1, x_) + out1 = (out1 / scale_out).round().clamp(-128, 127).to(torch.int8) + # ref_out = ref_silu_and_mul(x) + + out2 = torch.empty(num_tokens, d, dtype=torch.int8, device='cuda') + activation_ops.invoke_dequant_silu_and_mul_quant(out2, x, scale_gate, scale_up, scale_out) + assert torch.allclose(out1, out2, atol=2) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index a63ef5cc76ff..67d4aab7a7cf 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -2,12 +2,13 @@ import torch import torch.nn as nn -from vllm import layernorm_ops +from vllm import layernorm_ops, fused_kernels DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] class RefRMSNorm(nn.Module): @@ -56,3 +57,98 @@ def test_rms_norm( ) ref_out = ref(x) assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5) + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_rms_norm_quant( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = float(hidden_size**-0.5) + x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x.uniform_(-scale, scale) + ref = RefRMSNorm(hidden_size).to(dtype).cuda() + + out1 = torch.empty_like(x) + layernorm_ops.rms_norm( + 
out1, + x, + ref.weight.data, + ref.variance_epsilon, + ) + out1 = out1.clamp(-128, 127).round().to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + layernorm_ops.invoke_rms_norm_quant(out2, x, ref.weight.data, ref.variance_epsilon) + assert torch.allclose(out1, out2, atol=1.0) + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_dequant_add_residual_rms_norm_quant( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + scale: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + s = float(hidden_size**-0.5) + residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + # x = torch.randint(torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max, (num_tokens, hidden_size), dtype=torch.int32, device="cuda") + x = torch.randint(-1000, 1000, (num_tokens, hidden_size), dtype=torch.int32, device="cuda") + residual.uniform_(-s, s) + ref = RefRMSNorm(hidden_size).to(dtype).cuda() + x_ = (x * scale + residual).to(dtype) + + out1 = torch.empty_like(x_) + layernorm_ops.rms_norm( + out1, + x_, + ref.weight.data, + ref.variance_epsilon, + ) + out1 = out1.round().clamp(-128, 127).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + layernorm_ops.invoke_dequant_add_residual_rms_norm_quant(out2, x, residual, ref.weight.data, ref.variance_epsilon, scale) + + assert torch.allclose(out1, out2, atol=1.0) + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_dequant_add_residual( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + scale: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + s = float(hidden_size**-0.5) + residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.randint(torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max, (num_tokens, hidden_size), dtype=torch.int32, device="cuda") + residual.uniform_(-s, s) + out1 = (x * scale + residual).to(dtype) + + out2 = torch.empty_like(x, dtype=dtype) + fused_kernels.invoke_dequant_add_residual(out2, x, residual, scale) + + assert torch.allclose(out1, out2, atol=0.001) \ No newline at end of file diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 0d255900d4c1..d3d9fedcf324 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -10,10 +10,12 @@ IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 96, 112, 128, 256] -ROTARY_DIMS = [None, 32] # None means rotary dim == head size +ROTARY_DIMS = [None] # None means rotary dim == head size NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing SEEDS = [0] +QUERY_SCALE = [0.09, 1.13, 1.78] +KEY_SCALE = [0.23, 0.78, 1.45] def rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -168,7 +170,74 @@ def test_rotary_embedding( ) ref_query = ref_query.view(num_tokens, num_heads * head_size) ref_key = ref_key.view(num_tokens, num_heads * head_size) - # Compare the results. 
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("query_scale", QUERY_SCALE) +@pytest.mark.parametrize("key_scale", KEY_SCALE) +@torch.inference_mode() +def test_dequant_rotary_embedding( + is_neox_style: bool, + num_tokens: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + query_scale: float, + key_scale: float, + max_position: int = 8192, + base: int = 10000, +) -> None: + if rotary_dim is None: + rotary_dim = head_size + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + positions = torch.randint(0, max_position, (num_tokens, ), device="cuda") + query = torch.randint(-1000, 1000, (num_tokens, num_heads * head_size), + dtype=torch.int32, + device="cuda") + key = torch.randint(-1000, 1000, (num_tokens, num_heads * head_size), + dtype=torch.int32, + device="cuda") + query_ = (query * query_scale).to(dtype) + key_ = (key * key_scale).to(dtype) + + # Create the rotary embedding. + inv_freq = 1.0 / (base**( + torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) + t = torch.arange(max_position).float() + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cos_sin_cache = torch.cat((cos, sin), dim=-1) + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + + ref_rotary_embedding = RefRotaryEmbedding( + dim=rotary_dim, + is_neox_style=is_neox_style, + max_position_embeddings=max_position, + base=base, + ).to(dtype=dtype, device="cuda") + ref_query, ref_key = ref_rotary_embedding( + positions, + query_.view(num_tokens, num_heads, head_size), + key_.view(num_tokens, num_heads, head_size), + ) + ref_query = ref_query.view(num_tokens, num_heads * head_size) + ref_key = ref_key.view(num_tokens, num_heads * head_size) + out2_query = query_.clone() + out2_key = key_.clone() + + pos_encoding_ops.invoke_dequant_rotary_embedding(positions, query, out2_query, key, out2_key, head_size, cos_sin_cache, query_scale, key_scale, is_neox_style) + assert torch.allclose(ref_key, out2_key, atol=1e-4) From eab850ddde1bfa2fd0454ea66517882e02bfeef8 Mon Sep 17 00:00:00 2001 From: Lin Pengyun Date: Thu, 12 Oct 2023 16:43:47 +0800 Subject: [PATCH 36/52] modify attention kernel test using pytest --- tests/kernels/test_attention.py | 519 ++++++++++++++------------------ 1 file changed, 228 insertions(+), 291 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index ba7bfb1ef8a3..141efdf3c8e8 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -11,13 +11,24 @@ MAX_SEQ_LEN = 8192 NUM_BLOCKS = 128 # Arbitrary values for testing -DTYPES = [torch.half, torch.bfloat16, torch.float] +DTYPES = [ + torch.half, + # torch.bfloat16, + torch.float, + ] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] -BLOCK_SIZES = [8, 16, 32] -USE_ALIBI = [False, True] +BLOCK_SIZES = [ + 8, + 16, + # 32, + ] +USE_ALIBI = [ + False, + True, + 
] SEEDS = [0] @@ -91,144 +102,6 @@ def ref_single_query_cached_kv_attention( out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) -def ref_single_query_cached_kv_attention_quantized( - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - k_scale: float, - k_zp: float, - v_scale: float, - v_zp: float, -) -> None: - num_heads = value_cache.shape[1] - head_size = value_cache.shape[2] - block_size = value_cache.shape[3] - - num_input_tokens = query.shape[0] - for i in range(num_input_tokens): - q = query[i].unsqueeze(0) - block_table = block_tables[i] - context_len = int(context_lens[i]) - - keys = [] - values = [] - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_heads, head_size) - k = k.to(torch.float32) - k = k * k_scale + k_zp - k = k.to(q.dtype) - keys.append(k) - - v = value_cache[block_number, :, :, block_offset] - v = v.to(torch.float32) - v = v * v_scale + v_zp - v = v.to(q.dtype) - values.append(v) - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - - scale = 1.0 / (head_size**0.5) - out = ref_masked_attention(q, keys, values, scale) - out = out.view(num_heads, head_size) - output[i].copy_(out, non_blocking=True) - - -def ref_multi_query_kv_attention( - cu_seq_lens: List[int], - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - dtype: torch.dtype, -) -> torch.Tensor: - head_size = query.shape[-1] - scale = 1.0 / (head_size**0.5) - - num_seqs = len(cu_seq_lens) - 1 - ref_outputs = [] - for i in range(num_seqs): - start_idx = cu_seq_lens[i] - end_idx = cu_seq_lens[i + 1] - seq_len = end_idx - start_idx - - # Create attention mask. 
- attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) - attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device='cuda') - - ref_output = ref_masked_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - scale, - attn_mask=attn_mask, - ) - ref_outputs.append(ref_output) - ref_output = torch.cat(ref_outputs, dim=0) - return ref_output - - -def ref_multi_query_cached_kv_attention( - cu_query_lens: List[int], - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - dtype: torch.dtype, -) -> torch.Tensor: - num_heads = value_cache.shape[1] - head_size = value_cache.shape[2] - block_size = value_cache.shape[3] - scale = 1.0 / (head_size**0.5) - - num_queries = len(cu_query_lens) - 1 - ref_outputs = [] - for i in range(num_queries): - start_idx = cu_query_lens[i] - end_idx = cu_query_lens[i + 1] - query_len = end_idx - start_idx - context_len = int(context_lens[i]) - block_table = block_tables[i] - - # Create attention mask - attn_mask = torch.triu(torch.ones(query_len, context_len), - diagonal=context_len - query_len + 1) * -1e5 - attn_mask = attn_mask.to(dtype=dtype, device='cuda') - - keys = [] - values = [] - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_heads, head_size) - keys.append(k) - - v = value_cache[block_number, :, :, block_offset] - values.append(v) - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - - ref_output = ref_masked_attention( - query[start_idx:end_idx], - keys, - values, - scale, - attn_mask=attn_mask, - ) - ref_outputs.append(ref_output) - ref_output = torch.cat(ref_outputs, dim=0) - return ref_output - @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -369,69 +242,235 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def run_single_query_cached_kv_attention_quantized( - num_tokens: int, - num_heads: int, +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: Tuple[int, int], head_size: int, + dtype: torch.dtype, + seed: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) + num_tokens = sum(seq_lens) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + qkv = torch.empty(num_tokens, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + + num_queries_per_kv = num_query_heads // num_kv_heads + if num_queries_per_kv > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.squeeze(0) + + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + ref_output = ref_multi_query_kv_attention( + cu_seq_lens, + 
query, + key, + value, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +# def test_single_query_cached_kv_attention_quantized() -> None: +# torch.random.manual_seed(TEST_SEED) +# torch.cuda.manual_seed(TEST_SEED) +# for dtype in [ +# torch.half, +# torch.bfloat16, +# torch.float, +# ]: +# for block_size in [8, +# 16, +# ]: +# for head_size in [64, +# 80, +# 96, +# 112, +# 128, +# 256, +# ]: +# print(f'Testing single_query_cached_kv_attention with ' +# f'dtype={dtype}, block_size={block_size}, ' +# f'head_size={head_size}') +# run_single_query_cached_kv_attention_quantized( +# num_tokens=37, +# num_heads=3, +# head_size=head_size, +# block_size=block_size, +# num_blocks=1024, +# dtype=dtype, +# ) + + +def ref_single_query_cached_kv_attention_quantized( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + k_zp: float, + v_scale: float, + v_zp: float, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + k = k.to(torch.float32) + k = k * k_scale + k_zp + k = k.to(q.dtype) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + v = v.to(torch.float32) + v = v * v_scale + v_zp + v = v.to(q.dtype) + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. 
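+            # ALiBi adds a per-head bias of slope * (key_pos - query_pos) to
+            # the attention logits; with a single query at position
+            # context_len - 1 this reduces to slope * (position_ids - context_len + 1).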
+ position_ids = torch.arange(context_len, device="cuda").int() + alibi_bias = (position_ids - context_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_single_query_cached_kv_attention_quantized( + # kv_cache_factory, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, block_size: int, - num_blocks: int, dtype: torch.dtype, - num_kv_heads: int = None, + seed: int, k_scale: float = 1e-2, k_zp: float = 0.0, v_scale: float = 1e-2, v_zp: float = 0.0, ) -> None: - qkv = torch.empty(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device='cuda') - qkv.uniform_(-1e-3, 1e-3) - query, _, _ = qkv.unbind(dim=1) - - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_block_shape = (num_heads, head_size // x, block_size, x) - key_cache = torch.empty(size=(num_blocks, *key_block_shape), - dtype=torch.int8, ## fixed this to int8 - device='cuda') - key_cache.random_(-1, 2) ## change data range - value_block_shape = (num_heads, head_size, block_size) - value_cache = torch.empty(size=(num_blocks, *value_block_shape), - dtype=torch.int8, ## fixed this to int8 - device='cuda') - value_cache.random_(-1, 2) ## change data range + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + query.uniform_(-scale, scale) - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)] + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device="cuda") + + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + # Create the block tables. 
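+    # Each sequence gets a block table mapping its logical KV-cache blocks to randomly chosen physical blocks.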
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size block_tables = [] - for _ in range(num_tokens): + for _ in range(num_seqs): block_table = [ - random.randint(0, num_blocks - 1) + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda') - head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda") - - scale = float(1.0 / (head_size**0.5)) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - assert num_heads % num_kv_heads == 0 - num_queries_per_kv = num_heads // num_kv_heads - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), - num_queries_per_kv) + # Create the KV caches. - output = torch.empty(num_tokens, - num_heads, - head_size, - dtype=dtype, - device='cuda') + # key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + # num_kv_heads, head_size, dtype, + # seed) + # key_cache, value_cache = key_caches[0], value_caches[0] + + x = 16 // torch.tensor([], dtype=torch.int8).element_size() ## use int8 dtype + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) + key_cache = torch.randint(-10, 10, size=key_cache_shape, dtype=torch.int8, device='cuda') + value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size) + value_cache = torch.randint(-10, 10, size=value_cache_shape, + dtype=torch.int8, ## change to int8 + device='cuda') + # Call the paged attention kernel. + output = torch.empty_like(query) attention_ops.single_query_cached_kv_quantized_attention( output, query, @@ -443,7 +482,7 @@ def run_single_query_cached_kv_attention_quantized( context_lens, block_size, max_context_len, - None, # ALiBi slopes. + alibi_slopes, # ALiBi slopes. k_scale, k_zp, v_scale, @@ -454,10 +493,13 @@ def run_single_query_cached_kv_attention_quantized( ref_single_query_cached_kv_attention_quantized( ref_output, query, + num_queries_per_kv, key_cache, value_cache, block_tables, context_lens, + scale, + alibi_slopes, k_scale, k_zp, v_scale, @@ -468,108 +510,3 @@ def run_single_query_cached_kv_attention_quantized( # there is a small difference in the final outputs. # We should use a relaxed tolerance for the test. 
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - -def test_single_query_cached_kv_attention_quantized() -> None: - # FIXME: set TEST_SEED - torch.random.manual_seed(0) - torch.cuda.manual_seed(0) - for dtype in [ - torch.half, - torch.bfloat16, - torch.float, - ]: - for block_size in [8, - 16, - ]: - for head_size in [64, - 80, - 96, - 112, - 128, - 256, - ]: - print(f'Testing single_query_cached_kv_attention with ' - f'dtype={dtype}, block_size={block_size}, ' - f'head_size={head_size}') - run_single_query_cached_kv_attention_quantized( - num_tokens=37, - num_heads=3, - head_size=head_size, - block_size=block_size, - num_blocks=1024, - dtype=dtype, - ) - -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def run_multi_query_kv_attention( - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - seed: int, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) - num_tokens = sum(seq_lens) - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - qkv = torch.empty(num_tokens, - num_query_heads + 2 * num_kv_heads, - head_size, - dtype=dtype, - device="cuda") - qkv.uniform_(-scale, scale) - query, key, value = qkv.split( - [num_query_heads, num_kv_heads, num_kv_heads], dim=1) - - num_queries_per_kv = num_query_heads // num_kv_heads - if num_queries_per_kv > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - ) - output = output.squeeze(0) - - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - ref_output = ref_multi_query_kv_attention( - cu_seq_lens, - query, - key, - value, - scale, - dtype, - ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - -def test_multi_query_kv_attention() -> None: - torch.random.manual_seed(TEST_SEED) - torch.cuda.manual_seed(TEST_SEED) - for dtype in [torch.half, torch.bfloat16, torch.float]: - for head_size in [64, 80, 96, 112, 128, 256]: - print(f'Testing multi_query_kv_attention with dtype={dtype}, ' - f'head_size={head_size}') - run_multi_query_kv_attention( - num_seqs=5, - num_heads=3, - head_size=head_size, - dtype=dtype, - ) From d3735c780149f0e0d0d9d2c04d13ce5a332c09df Mon Sep 17 00:00:00 2001 From: Lin Pengyun Date: Mon, 16 Oct 2023 16:05:01 +0800 Subject: [PATCH 37/52] fix quant parameter passing --- csrc/attention/attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 5cd5aeeddbc5..ddb2ad22b535 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -856,7 +856,7 @@ void single_query_cached_kv_attention_quantized_launcher( k_scale, \ k_zp, \ v_scale, \ - k_zp); + v_zp); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes From 3e7874ce548b0cad09e7ccd03d2029a59461c943 Mon Sep 17 00:00:00 2001 From: zhangying169 
Date: Tue, 17 Oct 2023 19:37:14 +0800 Subject: [PATCH 38/52] fix uncontiguous tensor case --- csrc/pos_encoding_kernels.cu | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 7f62fbcdfd51..4a36f1de4747 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -117,7 +117,9 @@ __global__ void dequant_rotary_embedding_kernel( const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, const int query_stride, + const int query_out_stride, const int key_stride, + const int key_out_stride, const int num_heads, const int num_kv_heads, const int head_size, @@ -136,8 +138,9 @@ __global__ void dequant_rotary_embedding_kernel( for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; const int token_head = token_idx * query_stride + head_idx * head_size; + const int token_out_head = token_idx * query_out_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_dequant_rotary_embedding(query + token_head, query_out + token_head, cos_ptr, + apply_dequant_rotary_embedding(query + token_head, query_out + token_out_head, cos_ptr, sin_ptr, rot_offset, embed_dim, query_scale); } @@ -145,8 +148,9 @@ __global__ void dequant_rotary_embedding_kernel( for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int head_idx = i / embed_dim; const int token_head = token_idx * key_stride + head_idx * head_size; + const int token_out_head = token_idx * key_out_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_dequant_rotary_embedding(key + token_head, key_out + token_head, cos_ptr, + apply_dequant_rotary_embedding(key + token_head, key_out + token_out_head, cos_ptr, sin_ptr, rot_offset, embed_dim, key_scale); } } @@ -220,8 +224,8 @@ void invoke_dequant_rotary_embedding( int num_kv_heads = key.size(1) / head_size; int query_stride = query.stride(0); int key_stride = key.stride(0); - std::cout << rot_dim << std::endl; - std::cout << query_stride << std::endl; + int query_out_stride = query_out.stride(0); + int key_out_stride = key_out.stride(0); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -239,7 +243,9 @@ void invoke_dequant_rotary_embedding( cos_sin_cache.data_ptr(), rot_dim, query_stride, + query_out_stride, key_stride, + key_out_stride, num_heads, num_kv_heads, head_size, @@ -255,7 +261,9 @@ void invoke_dequant_rotary_embedding( cos_sin_cache.data_ptr(), rot_dim, query_stride, + query_out_stride, key_stride, + key_out_stride, num_heads, num_kv_heads, head_size, From 4ee29a9e1d05ec9eaa8f9610b4e264e7f9fdb1c0 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Tue, 17 Oct 2023 19:37:55 +0800 Subject: [PATCH 39/52] add quant, dequant kernel --- csrc/fused.cpp | 14 ++++++++++ csrc/fused_kernels.cu | 62 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/csrc/fused.cpp b/csrc/fused.cpp index eeede6f3bd94..b3cd340047ca 100644 --- a/csrc/fused.cpp +++ b/csrc/fused.cpp @@ -6,7 +6,21 @@ void invoke_dequant_add_residual( torch::Tensor &residual, // [num_tokens, hidden_size] float scale); +void invoke_dequant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + float scale); + +void invoke_quant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + float 
scale);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("invoke_dequant_add_residual", &invoke_dequant_add_residual,
         "Add the dequanted result and residual.");
+  m.def("invoke_dequant", &invoke_dequant,
+        "Dequant.");
+  m.def("invoke_quant", &invoke_quant,
+        "Quant.");
 }
diff --git a/csrc/fused_kernels.cu b/csrc/fused_kernels.cu
index a085a0718e2d..83f357142aa9 100644
--- a/csrc/fused_kernels.cu
+++ b/csrc/fused_kernels.cu
@@ -2,6 +2,8 @@
 #include <ATen/cuda/CUDAContext.h>
 
 #include "dispatch_utils.h"
+#include "quant_utils.cuh"
+#include <assert.h>
 
 namespace vllm {
 template <typename T>
@@ -16,6 +18,28 @@ __global__ void dequant_add_residual_kernel(const int32_t *__restrict__ input,
                 (float)residual[blockIdx.x * n + i]);
   }
 }
+
+template <typename T>
+__global__ void dequant_kernel(const int32_t *__restrict__ input,
+                               T *__restrict__ output,
+                               const float scale, int m, int n, int input_stride, int out_stride) {
+  const int tid = threadIdx.x;
+  for (int i = tid; i < n; i += blockDim.x) {
+    output[blockIdx.x * out_stride + i] =
+        (T)(((float)input[blockIdx.x * input_stride + i]) * scale);
+  }
+}
+
+template <typename T>
+__global__ void quant_kernel(const T *__restrict__ input,
+                             int8_t *__restrict__ output,
+                             const float scale, int m, int n) {
+  const int tid = threadIdx.x;
+  for (int i = tid; i < n; i += blockDim.x) {
+    output[blockIdx.x * n + i] =
+        float_to_int8_rn(((float)input[blockIdx.x * n + i]) / scale);
+  }
+}
 } // namespace vllm
 
 void invoke_dequant_add_residual(
@@ -36,3 +60,41 @@ void invoke_dequant_add_residual(
             out.data_ptr<scalar_t>(), scale, m, n);
       });
 }
+
+void invoke_dequant(
+    torch::Tensor &out,   // [num_tokens, hidden_size]
+    torch::Tensor &input, // [num_tokens, hidden_size]
+    float scale) {
+  int m = input.size(0);
+  int n = input.size(1);
+  int input_stride = input.stride(0);
+  int out_stride = out.stride(0);
+  dim3 grid(m);
+  dim3 block(min(n, 1024));
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      out.scalar_type(), "dequant_kernel", [&] {
+        vllm::dequant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            input.data_ptr<int32_t>(), out.data_ptr<scalar_t>(), scale, m, n, input_stride, out_stride);
+      });
+}
+
+void invoke_quant(
+    torch::Tensor &out,   // [num_tokens, hidden_size]
+    torch::Tensor &input, // [num_tokens, hidden_size]
+    float scale) {
+  assert(input.is_contiguous());
+  assert(out.is_contiguous());
+  int m = input.size(0);
+  int n = input.size(1);
+  dim3 grid(m);
+  dim3 block(min(n, 1024));
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "quant_kernel", [&] {
+        vllm::quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), scale, m, n);
+      });
+}
\ No newline at end of file

From b746c0c4fce9c9feee2f41fd3896d7d84afc5500 Mon Sep 17 00:00:00 2001
From: zhangying169 
Date: Tue, 17 Oct 2023 19:38:48 +0800
Subject: [PATCH 40/52] optimize layernorm kernel

---
 csrc/layernorm_kernels.cu | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 42999f4dc2a8..1b5bf933edab 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -67,7 +67,7 @@ __global__ void rms_norm_quant_kernel(const T *__restrict__ input,
 
 template <typename T>
 __global__ void dequant_add_residual_rms_norm_quant_kernel(
-    const int32_t *__restrict__ input, const T *__restrict__ residual,
+    const int32_t *__restrict__ input, T *__restrict__ residual,
     int8_t *__restrict__ output, const T *__restrict__ gamma,
     const float layernorm_eps, const float scale, int m, int n) {
   // layernorm module in the T5 style No bias and no subtraction of mean. 
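+  // The fused kernel dequantizes the int32 GEMM output with `scale`, adds the fp residual (now written back into `residual`), applies RMSNorm, and re-quantizes to int8.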
@@ -80,6 +80,7 @@ __global__ void dequant_add_residual_rms_norm_quant_kernel( for (int i = tid; i < n; i += blockDim.x) { float diff = ((((float)input[blockIdx.x * n + i]) * scale) + (float)residual[blockIdx.x * n + i]); + residual[blockIdx.x * n + i] = (T)diff; local_var_sum += diff * diff; } variance = blockReduceSum(local_var_sum); @@ -90,10 +91,8 @@ __global__ void dequant_add_residual_rms_norm_quant_kernel( __syncthreads(); for (int i = tid; i < n; i += blockDim.x) { - float tmp = ((((float)input[blockIdx.x * n + i]) * scale) + - (float)residual[blockIdx.x * n + i]); output[blockIdx.x * n + i] = - float_to_int8_rn((tmp * s_variance) * (float)(gamma[i])); + float_to_int8_rn((((float)(residual[blockIdx.x * n + i])) * s_variance) * (float)(gamma[i])); } } From 8893069547312a9c9ace1d63bb21b66044a4e26c Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Tue, 17 Oct 2023 19:39:58 +0800 Subject: [PATCH 41/52] support quant method in examples --- examples/offline_inference_quant.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/offline_inference_quant.py b/examples/offline_inference_quant.py index 29589ce30c23..819eb2a4e72b 100644 --- a/examples/offline_inference_quant.py +++ b/examples/offline_inference_quant.py @@ -44,6 +44,7 @@ def main(args: argparse.Namespace): trust_remote_code=args.trust_remote_code, kv_cache_dtype=args.kv_cache_dtype, kv_quant_params_path=args.kv_quant_params_path, + quantization=args.quantization ) requests, labels, _ = sample_requests( args.dev_data_path, @@ -103,5 +104,8 @@ def main(args: argparse.Namespace): parser.add_argument("--kv-quant-params-path", type=str, default=None) + parser.add_argument("--quantization", + type=str, + default="smoothquant") args = parser.parse_args() main(args) From b3bdc50dc25b1586cbaa0d44a2e2c280004c3a73 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Tue, 17 Oct 2023 19:42:45 +0800 Subject: [PATCH 42/52] add python class DequantAddResidualI8RMSNormQuant, DequantPagedAttentionWithRoPEQuant, DequantSiluAndMulQuant --- vllm/model_executor/layers/activation.py | 34 +++++++ vllm/model_executor/layers/attention.py | 117 +++++++++++++++++++++++ vllm/model_executor/layers/layernorm.py | 47 +++++++-- 3 files changed, 188 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 9222fe27218c..1d3b87bca27a 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -21,7 +21,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device) activation_ops.silu_and_mul(out, x) return out + +class DequantSiluAndMulQuant(nn.Module): + """An activation function for SwiGLU. + + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2. 
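+    The int32 input from the quantized GEMM is dequantized with the shared gate/up scale `a`, activated, and re-quantized to int8 with `inscale`.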
+ + Shapes: + x: (num_tokens, 2 * d) + return: (num_tokens, d) + """ + def __init__(self, scale_in: float = 1.0, scale_out: float = 1.0) -> None: + super().__init__() + self.register_buffer('a', torch.tensor(scale_in, dtype=torch.float32, requires_grad=False)) + self.register_buffer('inscale', torch.tensor(scale_out, dtype=torch.float32, requires_grad=False)) + + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + self.inscale = self.inscale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + self.inscale = self.inscale.to(*args, **kwargs) + self.inscale = self.inscale.to(torch.float32) + return self + def forward(self, x: torch.Tensor) -> torch.Tensor: + num_tokens = x.shape[0] + d = x.shape[1] // 2 + out = torch.empty(num_tokens, d, dtype=torch.int8, device=x.device) + activation_ops.invoke_dequant_silu_and_mul_quant(out, x, self.a.item(), self.a.item(), self.inscale.item()) + return out class NewGELU(nn.Module): diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 48cb1a2e1ee4..c82df29e1af4 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -10,6 +10,7 @@ from vllm import attention_ops from vllm import cache_ops from vllm import pos_encoding_ops +from vllm import fused_kernels from vllm.model_executor.input_metadata import InputMetadata _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] @@ -362,6 +363,122 @@ def forward( ) +class DequantPagedAttentionWithRoPEQuant(PagedAttention): + """PagedAttention with rotary embedding.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + rotary_dim: int, + max_position: int = 8192, + base: int = 10000, + num_kv_heads: Optional[int] = None, + is_neox_style: bool = True, + quant_kv_cache: bool = False, + kv_quant_params: torch.Tensor = None, + dequant_scale: float = 1.0, + quant_scale: float = 1.0 + ) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, quant_kv_cache, kv_quant_params) + self.is_neox_style = is_neox_style + + # Create the cos and sin cache. + inv_freq = 1.0 / (base**(torch.arange( + 0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim)) + t = torch.arange(max_position, dtype=torch.float, device="cuda") + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + + # FIXME(woosuk): This assumes that we configure the default dtype when + # initializing the model. + # TODO(woosuk): Make it more robust. 
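+        # The cache dtype also sets the dtype of the dequantized q/k/v tensors built in forward().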
+ torch_dtype = torch.get_default_dtype() + cache = cache.to(torch_dtype) + # Embedding size: [max_position, rotary_dim] + self.register_buffer("cos_sin_cache", cache, persistent=False) + self.register_buffer('a', torch.tensor(dequant_scale, dtype=torch.float32, requires_grad=False)) + self.register_buffer('inscale', torch.tensor(quant_scale, dtype=torch.float32, requires_grad=False)) + + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + self.inscale = self.inscale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + self.inscale = self.inscale.to(*args, **kwargs) + self.inscale = self.inscale.to(torch.float32) + return self + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + """ PagedAttention forward pass with rotary embedding. + + Args: + positions: shape = [num_tokens] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for paged attention. + cache_event: event to wait for the cache operations to finish. + + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + # Apply rotary embedding to the query and key before passing them + # to the attention op. + query_dequant = torch.empty_like(query, dtype=self.cos_sin_cache.dtype) + key_dequant = torch.empty_like(key, dtype=self.cos_sin_cache.dtype) + value_dequant = torch.empty_like(value, dtype=self.cos_sin_cache.dtype) + + fused_kernels.invoke_dequant(value_dequant, value, self.a.item()) + pos_encoding_ops.invoke_dequant_rotary_embedding( + positions, + query, + query_dequant, + key, + key_dequant, + self.head_size, + self.cos_sin_cache, + self.a.item(), + self.a.item(), + self.is_neox_style, + ) + out = super().forward( + query_dequant, + key_dequant, + value_dequant, + key_cache, + value_cache, + input_metadata, + cache_event, + ) + quant_out = torch.empty_like(out, dtype=torch.int8) + fused_kernels.invoke_quant(quant_out, out, self.inscale.item()) + return quant_out + + class PagedAttentionWithALiBi(PagedAttention): """PagedAttention with ALiBi attention bias.""" diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 99f8927a2fe3..abe8a9844ec9 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -48,16 +48,43 @@ def __init__( self.variance_epsilon = eps def forward(self, x: torch.Tensor) -> torch.Tensor: - # out = torch.empty_like(x) - # layernorm_ops.rms_norm( - # out, - # x, - # self.weight.data, - # self.variance_epsilon, - # ) - # # TODO: kernel fusion - # q_out = out.round().clamp(-128, 127).to(torch.int8) - # return q_out out = torch.empty_like(x, dtype=torch.int8) layernorm_ops.invoke_rms_norm_quant(out, x, self.weight.data, self.variance_epsilon) return out + + +class DequantAddResidualI8RMSNormQuant(nn.Module): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. 
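+    Here x is the int32 input dequantized with scale `a` plus the fp residual; the updated residual and the int8-quantized normalized output are both returned.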
+ Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + scale: float = 1.0, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.register_buffer( + "a", torch.tensor(scale, dtype=torch.float32, requires_grad=False) + ) + self.variance_epsilon = eps + + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + return self + + def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x, dtype=torch.int8) + layernorm_ops.invoke_dequant_add_residual_rms_norm_quant(out, x, residual, self.weight.data, self.variance_epsilon, self.a.item()) + return residual, out From 219738f0f760dc0e404bec06a3c1ce81427c6f91 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Tue, 17 Oct 2023 19:55:34 +0800 Subject: [PATCH 43/52] add tests --- tests/kernels/test_activation.py | 26 +++++---- tests/kernels/test_layernorm.py | 89 ++++++++++++++++++++++++------ tests/kernels/test_pos_encoding.py | 89 ++++++++++++++++++------------ 3 files changed, 140 insertions(+), 64 deletions(-) diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index a1b485ec71e8..cd548cb000bb 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -32,8 +32,8 @@ def test_silu_and_mul( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda') - out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.silu_and_mul(out, x) ref_out = ref_silu_and_mul(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -54,22 +54,26 @@ def test_dequant_silu_and_mul_quant( seed: int, scale_gate: float, scale_up: float, - scale_out: float + scale_out: float, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) # x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda') - x = torch.randint(-1000, 1000, (num_tokens, 2 * d), dtype=torch.int32, device='cuda') + x = torch.randint( + -1000, 1000, (num_tokens, 2 * d), dtype=torch.int32, device="cuda" + ) x_ = torch.empty_like(x, dtype=dtype) x_[:, :d] = x[:, :d] * scale_gate x_[:, d:] = x[:, d:] * scale_up - out1 = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + out1 = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.silu_and_mul(out1, x_) out1 = (out1 / scale_out).round().clamp(-128, 127).to(torch.int8) # ref_out = ref_silu_and_mul(x) - out2 = torch.empty(num_tokens, d, dtype=torch.int8, device='cuda') - activation_ops.invoke_dequant_silu_and_mul_quant(out2, x, scale_gate, scale_up, scale_out) + out2 = torch.empty(num_tokens, d, dtype=torch.int8, device="cuda") + activation_ops.invoke_dequant_silu_and_mul_quant( + out2, x, scale_gate, scale_up, scale_out + ) assert torch.allclose(out1, out2, atol=2) @@ -86,8 +90,8 @@ def test_gelu_new( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device='cuda') - out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.gelu_new(out, 
x) ref_out = get_activation("gelu_new")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -105,8 +109,8 @@ def test_gelu_fast( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device='cuda') - out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.gelu_fast(out, x) ref_out = get_activation("gelu_fast")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 67d4aab7a7cf..22dd822f36b7 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -12,7 +12,6 @@ class RefRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): super().__init__() weight = torch.empty(hidden_size) @@ -24,8 +23,7 @@ def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) @@ -58,6 +56,7 @@ def test_rms_norm( ref_out = ref(x) assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5) + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -89,6 +88,7 @@ def test_rms_norm_quant( layernorm_ops.invoke_rms_norm_quant(out2, x, ref.weight.data, ref.variance_epsilon) assert torch.allclose(out1, out2, atol=1.0) + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -96,11 +96,7 @@ def test_rms_norm_quant( @pytest.mark.parametrize("scale", SCALE) @torch.inference_mode() def test_dequant_add_residual_rms_norm_quant( - num_tokens: int, - hidden_size: int, - dtype: torch.dtype, - seed: int, - scale: float + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -108,7 +104,9 @@ def test_dequant_add_residual_rms_norm_quant( s = float(hidden_size**-0.5) residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") # x = torch.randint(torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max, (num_tokens, hidden_size), dtype=torch.int32, device="cuda") - x = torch.randint(-1000, 1000, (num_tokens, hidden_size), dtype=torch.int32, device="cuda") + x = torch.randint( + -1000, 1000, (num_tokens, hidden_size), dtype=torch.int32, device="cuda" + ) residual.uniform_(-s, s) ref = RefRMSNorm(hidden_size).to(dtype).cuda() x_ = (x * scale + residual).to(dtype) @@ -122,10 +120,13 @@ def test_dequant_add_residual_rms_norm_quant( ) out1 = out1.round().clamp(-128, 127).to(torch.int8) out2 = torch.empty_like(x, dtype=torch.int8) - layernorm_ops.invoke_dequant_add_residual_rms_norm_quant(out2, x, residual, ref.weight.data, ref.variance_epsilon, scale) + layernorm_ops.invoke_dequant_add_residual_rms_norm_quant( + out2, x, residual, ref.weight.data, ref.variance_epsilon, scale + ) assert torch.allclose(out1, out2, atol=1.0) + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -133,22 
+134,76 @@ def test_dequant_add_residual_rms_norm_quant( @pytest.mark.parametrize("scale", SCALE) @torch.inference_mode() def test_dequant_add_residual( - num_tokens: int, - hidden_size: int, - dtype: torch.dtype, - seed: int, - scale: float + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) s = float(hidden_size**-0.5) residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") - x = torch.randint(torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max, (num_tokens, hidden_size), dtype=torch.int32, device="cuda") + x = torch.randint( + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + (num_tokens, hidden_size), + dtype=torch.int32, + device="cuda", + ) residual.uniform_(-s, s) out1 = (x * scale + residual).to(dtype) out2 = torch.empty_like(x, dtype=dtype) fused_kernels.invoke_dequant_add_residual(out2, x, residual, scale) - assert torch.allclose(out1, out2, atol=0.001) \ No newline at end of file + assert torch.allclose(out1, out2, atol=0.001) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_dequant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + s = float(hidden_size**-0.5) + # residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.randint( + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + (num_tokens, hidden_size), + dtype=torch.int32, + device="cuda", + ) + # residual.uniform_(-s, s) + out1 = (x * scale).to(dtype) + + out2 = torch.empty_like(x, dtype=dtype) + fused_kernels.invoke_dequant(out2, x, scale) + assert torch.allclose(out1, out2, atol=0.001) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + s = float(hidden_size**-0.5) + # residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + # residual.uniform_(-s, s) + out1 = (x / scale).round().clamp(-128, 127).to(torch.int8) + + out2 = torch.empty_like(x, dtype=torch.int8) + fused_kernels.invoke_quant(out2, x, scale) + assert torch.allclose(out1, out2, atol=1) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index d3d9fedcf324..8c99f3bc77d4 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -14,13 +14,13 @@ NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing SEEDS = [0] -QUERY_SCALE = [0.09, 1.13, 1.78] -KEY_SCALE = [0.23, 0.78, 1.45] +QUERY_SCALE = [0.0002, 0.0008] +KEY_SCALE = [0.0002, 0.0008] def rotate_neox(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), 
dim=-1) @@ -60,7 +60,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings # Create cos and sin embeddings. - inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) t = torch.arange(max_position_embeddings).float() freqs = torch.einsum("i,j->ij", t, inv_freq.float()) if is_neox_style: @@ -78,18 +78,19 @@ def forward( query: torch.Tensor, # [num_tokens, num_heads, head_size] key: torch.Tensor, # [num_tokens, num_heads, head_size] ) -> Tuple[torch.Tensor, torch.Tensor]: - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] query_rot = query_rot.transpose(0, 1) key_rot = key_rot.transpose(0, 1) cos = F.embedding(positions, self.cos_cached) sin = F.embedding(positions, self.sin_cached) - query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin, - self.is_neox_style) + query_rot, key_rot = apply_rope( + query_rot, key_rot, cos, sin, self.is_neox_style + ) query_rot = query_rot.transpose(0, 1).contiguous() key_rot = key_rot.transpose(0, 1).contiguous() @@ -124,25 +125,20 @@ def test_rotary_embedding( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - positions = torch.randint(0, max_position, (num_tokens, ), device="cuda") - query = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device="cuda") - key = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device="cuda") + positions = torch.randint(0, max_position, (num_tokens,), device="cuda") + query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") + key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") # Create the rotary embedding. - inv_freq = 1.0 / (base**( - torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim) + ) t = torch.arange(max_position).float() freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() sin = freqs.sin() cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") # Run the kernel. The kernel is in-place, so we need to clone the inputs. out_query = query.clone() @@ -203,25 +199,34 @@ def test_dequant_rotary_embedding( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - positions = torch.randint(0, max_position, (num_tokens, ), device="cuda") - query = torch.randint(-1000, 1000, (num_tokens, num_heads * head_size), - dtype=torch.int32, - device="cuda") - key = torch.randint(-1000, 1000, (num_tokens, num_heads * head_size), - dtype=torch.int32, - device="cuda") + positions = torch.randint(0, max_position, (num_tokens,), device="cuda") + query = torch.randint( + -1000, + 1000, + (num_tokens, num_heads * head_size), + dtype=torch.int32, + device="cuda", + ) + key = torch.randint( + -1000, + 1000, + (num_tokens, num_heads * head_size), + dtype=torch.int32, + device="cuda", + ) query_ = (query * query_scale).to(dtype) key_ = (key * key_scale).to(dtype) # Create the rotary embedding. 
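+    # The kernel expects cos and sin concatenated along the last dim: [max_position, rot_dim].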
- inv_freq = 1.0 / (base**( - torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim) + ) t = torch.arange(max_position).float() freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() sin = freqs.sin() cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") ref_rotary_embedding = RefRotaryEmbedding( dim=rotary_dim, @@ -236,8 +241,20 @@ def test_dequant_rotary_embedding( ) ref_query = ref_query.view(num_tokens, num_heads * head_size) ref_key = ref_key.view(num_tokens, num_heads * head_size) - out2_query = query_.clone() - out2_key = key_.clone() + out2_query = torch.empty_like(query_) + out2_key = torch.empty_like(key_) - pos_encoding_ops.invoke_dequant_rotary_embedding(positions, query, out2_query, key, out2_key, head_size, cos_sin_cache, query_scale, key_scale, is_neox_style) + pos_encoding_ops.invoke_dequant_rotary_embedding( + positions, + query, + out2_query, + key, + out2_key, + head_size, + cos_sin_cache, + query_scale, + key_scale, + is_neox_style, + ) assert torch.allclose(ref_key, out2_key, atol=1e-4) + assert torch.allclose(ref_query, out2_query, atol=1e-4) From 074e86b9fdf38c0ca46c782235203b0cc784416f Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Tue, 17 Oct 2023 19:56:22 +0800 Subject: [PATCH 44/52] add w8a8linear without quant and dequant --- .../layers/int8_linear/w8a8linear.py | 117 +++++++++++++++++- 1 file changed, 115 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py index 14144b436361..df6949361130 100644 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -229,7 +229,6 @@ def from_float(module: torch.nn.Linear, input_scale): int8_module.inscale = torch.tensor(input_scale) return int8_module -# use cublasgemm a8w8o8 class W8A8OFP32LinearWithSFactorCublas(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): @@ -276,7 +275,6 @@ def forward(self, x): return y, None -# use cublasgemm a8w8o8 class W8A8O32LinearCublas(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): @@ -312,4 +310,119 @@ def forward(self, x): i8cugemm.linear_a8_w8_o32_(x, self.weight, y) y = y * self.a.item() y = y.view(*x_shape[:-1], -1) + return y, None + +class W8A8OFP32LinearWithSFactorCublasNoQuant(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.float32, requires_grad=False)) + # self.register_buffer('a', torch.tensor(alpha)) + # self.register_buffer('inscale', torch.tensor(inscale)) + + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + # self.a = self.a.cpu() + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = 
self.bias.to(torch.float32) + # self.a = self.a.to(*args, **kwargs) + # self.a = self.a.to(torch.float32) + # self.inscale = self.inscale.to(*args, **kwargs) + # self.inscale = self.inscale.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + # quant activation + # x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) + y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) + i8cugemm.linear_a8_w8_o32_(x, self.weight, y) + # y = y * self.a.item() + y = y.view(*x_shape[:-1], -1) + return y, None +class W8A8O32LinearCublasNoDequant(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.float32, requires_grad=False)) + # self.register_buffer('a', torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + # self.a = self.a.cpu() + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) + i8cugemm.linear_a8_w8_o32_(x, self.weight, y) + # y = y * self.a.item() + y = y.view(*x_shape[:-1], -1) + return y, None + + +class W8A8O32Linear(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + (1, self.out_features), dtype=torch.float32, requires_grad=False)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) + i8cugemm.linear_a8_w8_o32_(x, self.weight, y) + y = y.view(*x_shape[:-1], -1) return y, None \ No newline at end of file From d69100d5f6bfe90cc93b2a77b447e543a230e3a9 Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Tue, 17 Oct 2023 20:09:15 +0800 Subject: [PATCH 45/52] adjust code for fusion --- vllm/model_executor/models/llama.py | 91 +++++++++++++++++++++++------ 1 file changed, 73 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 1d529b803bff..72314db173c4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -32,14 +32,18 @@ from transformers import LlamaConfig from vllm.model_executor.input_metadata import 
InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm, I8RMSNorm -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.activation import SiluAndMul, DequantSiluAndMulQuant +from vllm.model_executor.layers.layernorm import RMSNorm, I8RMSNorm, DequantAddResidualI8RMSNormQuant +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, DequantPagedAttentionWithRoPEQuant from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.layers.int8_linear.w8a8linear import ( W8A8OFP32LinearWithSFactorCublas, - W8A8O32LinearCublas) + W8A8O32LinearCublas, + W8A8O32LinearCublasNoDequant, + W8A8OFP32LinearWithSFactorCublasNoQuant, + W8A8O32Linear) +from vllm.model_executor.layers.fusion import DequantAddResidual from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( @@ -66,9 +70,14 @@ def __init__( self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" if self.use_int8: - self.gate_up_proj = W8A8O32LinearCublas(hidden_size, + # self.gate_up_proj = W8A8O32LinearCublas(hidden_size, + # 2 * intermediate_size) + self.gate_up_proj = W8A8O32Linear(hidden_size, 2 * intermediate_size) - self.down_proj = W8A8OFP32LinearWithSFactorCublas(intermediate_size, + + # self.down_proj = W8A8OFP32LinearWithSFactorCublas(intermediate_size, + # hidden_size) + self.down_proj = W8A8O32Linear(intermediate_size, hidden_size) else: self.gate_up_proj = ParallelLinear.column(hidden_size, @@ -87,14 +96,15 @@ def __init__( if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") - self.act_fn = SiluAndMul() + # self.act_fn = SiluAndMul() + self.act_fn = DequantSiluAndMulQuant() def forward(self, x): gate_up, _ = self.gate_up_proj(x) # FIXME: currently gate up share same scale, plan to use seperate scales x = self.act_fn(gate_up) x, _ = self.down_proj(x) - x = x.half() + # x = x.half() return x @@ -127,10 +137,16 @@ def __init__( self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" if self.use_int8: - self.qkv_proj = W8A8O32LinearCublas( + # self.qkv_proj = W8A8O32LinearCublas( + # hidden_size, + # (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim) + self.qkv_proj = W8A8O32Linear( hidden_size, (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim) - self.o_proj = W8A8OFP32LinearWithSFactorCublas( + # self.o_proj = W8A8OFP32LinearWithSFactorCublas( + # self.total_num_heads * self.head_dim, + # hidden_size) + self.o_proj = W8A8O32Linear( self.total_num_heads * self.head_dim, hidden_size) else: @@ -151,7 +167,15 @@ def __init__( perform_initialization=False, quant_config=quant_config, ) - self.attn = PagedAttentionWithRoPE(self.num_heads, + # self.attn = PagedAttentionWithRoPE(self.num_heads, + # self.head_dim, + # self.scaling, + # base=self.rope_theta, + # rotary_dim=self.head_dim, + # num_kv_heads=self.num_kv_heads, + # quant_kv_cache=quant_kv_cache, + # kv_quant_params=kv_quant_params) + self.attn = DequantPagedAttentionWithRoPEQuant(self.num_heads, self.head_dim, self.scaling, base=self.rope_theta, @@ -169,14 +193,14 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - qkv = qkv.half() + # qkv = qkv.half() # FIXME: currently qkv share same scale, plan to use seperate scales q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn(positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) output, _ = self.o_proj(attn_output) - output = output.half() + # output = output.half() return output @@ -211,8 +235,13 @@ def __init__( if quant_config is not None and quant_config.get_name() == "smoothquant": self.input_layernorm = I8RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = I8RMSNorm(config.hidden_size, + # self.post_attention_layernorm = I8RMSNorm(config.hidden_size, + # eps=config.rms_norm_eps) + self.dequant_add_residual_layernorm_quant = DequantAddResidualI8RMSNormQuant(config.hidden_size, eps=config.rms_norm_eps) + # self.attn_dequant_add_residual = DequantAddResidual() + # self.mlp_dequant_add_residual = DequantAddResidual() + self.dequant_add_residual = DequantAddResidual() else: self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -237,13 +266,17 @@ def forward( input_metadata=input_metadata, cache_event=cache_event, ) - hidden_states = residual + hidden_states + # hidden_states = residual + hidden_states + # hidden_states = self.attn_dequant_add_residual(residual, hidden_states) # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + # residual = hidden_states + residual, hidden_states = self.dequant_add_residual_layernorm_quant(residual, hidden_states) + # hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + # hidden_states = residual + hidden_states + # hidden_states = self.mlp_dequant_add_residual(residual, 
hidden_states) + hidden_states = self.dequant_add_residual(residual, hidden_states) return hidden_states @@ -505,6 +538,28 @@ def _load_int8_weights(self, if is_transposed: loaded_weight = convert_pyslice_to_tensor(loaded_weight) loaded_weight = loaded_weight.T + + is_fusion_weight = False + name_dict = { + "self_attn.q_proj.a": "self_attn.attn.a", + "self_attn.k_proj.a": "self_attn.attn.a", + "self_attn.v_proj.a": "self_attn.attn.a", + "self_attn.o_proj.inscale": "self_attn.attn.inscale", + "self_attn.o_proj.a": "dequant_add_residual_layernorm_quant.a", + "post_attention_layernorm.weight": "dequant_add_residual_layernorm_quant.weight", + "mlp.gate_proj.a": "mlp.act_fn.a", + "mlp.up_proj.a": "mlp.act_fn.a", + "mlp.down_proj.inscale": "mlp.act_fn.inscale", + "mlp.down_proj.a": "dequant_add_residual.a" + } + for weight_name in name_dict.keys(): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, name_dict[weight_name])] + param.copy_(loaded_weight) + is_fusion_weight = True + if is_fusion_weight: + continue is_attention_weight = False for weight_name, shard_size, offset in attention_weight_specs: From 3e81f3d611b6e874b3201f0a42b99b5424cf33f8 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Wed, 18 Oct 2023 10:45:19 +0800 Subject: [PATCH 46/52] rm obsolete file --- vllm/model_executor/models/llamaq.py | 663 --------------------------- 1 file changed, 663 deletions(-) delete mode 100644 vllm/model_executor/models/llamaq.py diff --git a/vllm/model_executor/models/llamaq.py b/vllm/model_executor/models/llamaq.py deleted file mode 100644 index 6b9bc301411c..000000000000 --- a/vllm/model_executor/models/llamaq.py +++ /dev/null @@ -1,663 +0,0 @@ -from typing import Dict, List, Optional, Tuple - -import torch -from torch import nn -import torch.nn.functional as F - -from transformers import LlamaConfig - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE -from vllm.model_executor.layers.sampler import Sampler -# from vllm.model_executor.layers.temp_sampler import TempSampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights, - load_tensor_parallel_weights2) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs -from awq.quantize.qmodule import WQLinear -import awq_inference_engine -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast - -KVCache = Tuple[torch.Tensor, torch.Tensor] - -class QuantLlamaQRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - - # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("cos_sin_cache", cache.half(), persistent=False) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - positions: torch.Tensor, - ): - # Apply rotary embedding to the query and key before passing them - # to the attention op. - # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape) - query = query.contiguous() - key = key.contiguous() - awq_inference_engine.rotary_embedding_neox( - positions, - query, - key, - self.dim, - self.cos_sin_cache, - ) - return query, key - - - -class FTLlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - # print(f"norm weight shape {self.weight.shape}") - - def forward(self, x): - x = x.unsqueeze(0) - output = torch.empty_like(x) - awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon) - output = output.squeeze(0) - return output - - -class LlamaQMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.in_features = hidden_size - self.intermediate_size = intermediate_size - self.out_features = hidden_size - self.w_bit = 4 - self.g_size = 128 - - self.gate_up_proj = WQLinear(self.w_bit, self.g_size, self.in_features, 2 * self.intermediate_size, False, 'cuda') - self.down_proj = WQLinear(self.w_bit, self.g_size, self.intermediate_size, self.out_features, False, 'cuda') - self.act_fn = SiluAndMul() - - - def forward(self, x): - # return self.down_proj(self.custom_LlamaQ_mlp(x)) - gate_up = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x = self.down_proj(x) - return x - - def custom_LlamaQ_mlp(self, x): - out_shape = x.shape[:-1] + (self.intermediate_size, ) - x = x.reshape(-1, x.shape[-1]) - - gate_output = awq_inference_engine.gemm_forward_cuda( - x, self.gate_proj.qweight, self.gate_proj.scales, self.gate_proj.qzeros, 8 - ) - gate_output = self.act_fn(gate_output) - - up_output = awq_inference_engine.gemm_forward_cuda( - x, self.up_proj.qweight, self.up_proj.scales, self.up_proj.qzeros, 8 - ) - c = gate_output * up_output - c = c.reshape(out_shape) - return c - - -class LlamaQAttention2(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - ): - super().__init__() - self.hidden_size = hidden_size - # tensor_model_parallel_world_size = ( - # get_tensor_model_parallel_world_size()) - tensor_model_parallel_world_size = 1 - self.total_num_heads = num_heads - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) - self.head_dim = 
hidden_size // self.total_num_heads - self.scaling = self.head_dim**-0.5 - self.w_bit = 4 - self.g_size = 128 - - self.qkv_proj = WQLinear(self.w_bit, self.g_size, hidden_size, 3 * self.total_num_heads * self.head_dim, False, 'cuda') - self.o_proj = WQLinear(self.w_bit, self.g_size, self.total_num_heads * self.head_dim, hidden_size, False, 'cuda') - - self.attn = PagedAttentionWithRoPE(self.num_heads, - self.head_dim, - self.scaling, - rotary_dim=self.head_dim) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], - ) -> torch.Tensor: - # print(f"qkv proj size: {self.qkv_proj.shape}, hidden_states size: {hidden_states.shape} ") - # 这里把qkv_proj和o_proj都变成WQLinear - - qkv = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - k_cache, v_cache = kv_cache - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) - output = self.o_proj(attn_output) - - return output - - -class LlamaQAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - ): - super().__init__() - self.hidden_size = hidden_size - # tensor_model_parallel_world_size = ( - # get_tensor_model_parallel_world_size()) - tensor_model_parallel_world_size = 1 - self.total_num_heads = num_heads - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) - self.head_dim = hidden_size // self.total_num_heads - self.scaling = self.head_dim**-0.5 - - self.qkv_proj = ColumnParallelLinear( - hidden_size, - 3 * self.total_num_heads * self.head_dim, - bias=False, - gather_output=False, - perform_initialization=False, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False, - ) - self.attn = PagedAttentionWithRoPE(self.num_heads, - self.head_dim, - self.scaling, - rotary_dim=self.head_dim) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], - ) -> torch.Tensor: - # print(f"qkv proj size: {self.qkv_proj.shape}, hidden_states size: {hidden_states.shape} ") - # 这里把qkv_proj和o_proj都变成WQLinear - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - k_cache, v_cache = kv_cache - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) - output, _ = self.o_proj(attn_output) - return output - - -class LlamaQDecoderLayer(nn.Module): - - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaQAttention2( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - ) - self.mlp = LlamaQMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = FTLlamaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = FTLlamaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - # self.input_layernorm = RMSNorm(config.hidden_size, - # eps=config.rms_norm_eps) - # self.post_attention_layernorm = RMSNorm(config.hidden_size, - # eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: 
InputMetadata, - cache_event: Optional[torch.cuda.Event], - ) -> torch.Tensor: - # Self Attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - cache_event=cache_event, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -class LlamaQModel(nn.Module): - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.embed_tokens = VocabParallelEmbedding( - vocab_size, config.hidden_size, perform_initialization=False) - self.layers = nn.ModuleList([ - LlamaQDecoderLayer(config) for _ in range(config.num_hidden_layers) - ]) - self.norm = FTLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): - if cache_events is None: - cache_event = None - else: - cache_event = cache_events[i] - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - cache_event, - ) - hidden_states = self.norm(hidden_states) - return hidden_states - - -class LlamaQForCausalLM(nn.Module): - - def __init__(self, config): - super().__init__() - self.config = config - self.model = LlamaQModel(config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) - self.sampler = Sampler(config.vocab_size) - #self.sampler = TempSampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata) - - return next_tokens - - _column_parallel_weights = [ - "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight", - "gate_proj.weight", "up_proj.weight" - ] - _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] - - _column_parallel_weights_fp16 = [ - "embed_tokens.weight", "lm_head.weight", "model.norm.weight" - ] - - _row_parallel_weights_fp16 = [] - - _column_parallel_weights_int4 = [ - "qkv_proj.qweight", "gate_proj.qweight", "up_proj.qweight", - "qkv_proj.qzeros", "gate_proj.qzeros", "up_proj.qzeros", - "qkv_proj.scales", "gate_proj.scales", "up_proj.scales", - # "input_layernorm", "post_attention_layernorm" - ] - - _row_parallel_weights_int4 = ["o_proj.qweight", "down_proj.qweight", - "o_proj.qzeros", "down_proj.qzeros", - "o_proj.scales", "down_proj.scales"] - - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - use_np_cache: 
bool = False): - # tensor_model_parallel_world_size = ( - # get_tensor_model_parallel_world_size()) - tensor_model_parallel_world_size = 1 - # tensor_model_parallel_rank = get_tensor_model_parallel_rank() - tensor_model_parallel_rank = 0 - state_dict = self.state_dict() - - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, use_np_cache): - if "rotary_emb.inv_freq" in name: - continue - - if "embed_tokens" in name or "lm_head" in name: - param = state_dict[name] - # Consider padding in the vocab size. - padded_vocab_size = (param.shape[0] * - tensor_model_parallel_world_size) - num_extra_rows = padded_vocab_size - self.config.vocab_size - extra_rows = torch.empty(num_extra_rows, - loaded_weight.shape[1]) - extra_rows = extra_rows.to(loaded_weight) - loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) - - is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: - continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: - continue - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) - - def load_mix_weights(self, - model_name_or_path: str, - q_weight_path: str, - cache_dir: Optional[str] = None, - use_np_cache: bool = False): - # tensor_model_parallel_world_size = ( - # get_tensor_model_parallel_world_size()) - tensor_model_parallel_world_size = 1 - # tensor_model_parallel_rank = get_tensor_model_parallel_rank() - tensor_model_parallel_rank = 0 - state_dict = self.state_dict() - - column_parallel_weights_fp16 = [ - # "embed_tokens.weight", "lm_head.weight", "model.norm.weight", - # "input_layernorm", "post_attention_layernorm" - "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight" - ] - - row_parallel_weights_fp16 = ["o_proj.weight"] - - column_parallel_weights_int4 = [ - "gate_proj.qweight", "up_proj.qweight", - "gate_proj.qzeros", "up_proj.qzeros", - "gate_proj.scales", "up_proj.scales" - ] - - row_parallel_weights_int4 = ["down_proj.qweight", "down_proj.qzeros", "down_proj.scales"] - - - - # load fp16 - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, use_np_cache): - if "rotary_emb.inv_freq" in name: - continue - - if "mlp" in name: - continue - - if "embed_tokens" in name or "lm_head" in name: - param = state_dict[name] - # Consider padding in the vocab size. 
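
# The padding handled here rounds the vocabulary up to the multiple of 64 used
# when the embedding / lm_head tables were allocated via
# ((vocab_size + 63) // 64) * 64 above, so the checkpoint tensor gains
# uninitialized rows before the copy. A minimal standalone sketch of that
# arithmetic follows; pad_vocab_rows is a hypothetical helper for illustration.
import torch

def pad_vocab_rows(loaded: torch.Tensor, vocab_size: int, multiple: int = 64) -> torch.Tensor:
    padded_vocab_size = ((vocab_size + multiple - 1) // multiple) * multiple
    num_extra_rows = padded_vocab_size - vocab_size
    extra_rows = loaded.new_empty(num_extra_rows, loaded.shape[1])  # same dtype/device
    return torch.cat([loaded, extra_rows], dim=0)

# e.g. vocab_size=32000 stays 32000 (0 extra rows); 32001 pads to 32064 (63 extra rows).
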
- padded_vocab_size = (param.shape[0] * - tensor_model_parallel_world_size) - num_extra_rows = padded_vocab_size - self.config.vocab_size - extra_rows = torch.empty(num_extra_rows, - loaded_weight.shape[1]) - extra_rows = extra_rows.to(loaded_weight) - loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) - - is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: - continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - param = state_dict[name] - # print(f"fp16 layer name: {name}") - load_tensor_parallel_weights2(param, loaded_weight, name, - tensor_model_parallel_rank) - print("****************** load int weight ***********************") - # load int4 - for name, loaded_weight in hf_model_weights_iterator( - q_weight_path, cache_dir, use_np_cache): - if "rotary_emb.inv_freq" in name: - continue - - if "embed_tokens" in name or "lm_head" in name: - continue - - if "self_attn" in name: - continue - - if "input_layernorm" in name or "post_attention_layernorm" in name: - continue - - if "model.norm.weight" in name: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - - shard_size = param.shape[1] // 2 - start = shard_size * stride_id - end = shard_size * (stride_id + 1) - param_slice = param.data[:, start:end] - - print(f"{name} param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: - continue - - param = state_dict[name] - # print(f"int4 layer name: {name}") - load_tensor_parallel_weights2(param, loaded_weight, name, - tensor_model_parallel_rank) - - def load_int4_weights(self, - model_name_or_path: str, - q_weight_path: str, - cache_dir: Optional[str] = None, - use_np_cache: bool = False): - # tensor_model_parallel_world_size = ( - # get_tensor_model_parallel_world_size()) - tensor_model_parallel_world_size = 1 - # tensor_model_parallel_rank = get_tensor_model_parallel_rank() - tensor_model_parallel_rank = 0 - state_dict = self.state_dict() - # for name, weight in state_dict.items(): - # print(f"state dict name: {name}") - - # load int4 - for name, loaded_weight in hf_model_weights_iterator( - q_weight_path, cache_dir, use_np_cache): - if "rotary_emb.inv_freq" in name: - continue - - if "embed_tokens" in name or "lm_head" in name: - param = state_dict[name] - # Consider padding in the vocab size. 
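
# Note the contrast with the fp16 loaders above: the AWQ int4 buffers
# (qweight / qzeros / scales) keep the output dimension on dim 1, so the
# q/k/v and gate/up copies a few lines below slice the fused parameter along
# dim 1 (param.shape[1] // 3 or // 2) instead of dim 0. A minimal sketch of
# that column-wise copy, with a hypothetical helper name, for illustration only.
import torch

def copy_qkv_shard_int4(qkv_param: torch.Tensor, loaded: torch.Tensor, which: str) -> None:
    # qkv_param: fused 2-D buffer with 3 * shard columns; loaded: one projection's buffer
    stride_id = ["q_proj", "k_proj", "v_proj"].index(which)
    shard = qkv_param.shape[1] // 3
    dst = qkv_param[:, shard * stride_id: shard * (stride_id + 1)]
    assert dst.shape == loaded.shape
    dst.copy_(loaded)
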
- padded_vocab_size = (param.shape[0] * tensor_model_parallel_world_size) - num_extra_rows = padded_vocab_size - self.config.vocab_size - extra_rows = torch.empty(num_extra_rows, - loaded_weight.shape[1]) - extra_rows = extra_rows.to(loaded_weight) - loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) - - is_attention_weight = False - - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: - continue - # print(f"int4 layer name: {name}") - # print(f"stride_id: {stride_id}, att_weight_name: {att_weight_name}") - param_name = name.replace(att_weight_name, "qkv_proj") - - param = state_dict[param_name] - shard_size = param.shape[1] // 3 - # loaded_weight = loaded_weight[ - # shard_size * tensor_model_parallel_rank:shard_size * - # (tensor_model_parallel_rank + 1)] - param_slice = param.data[:, shard_size * stride_id:shard_size * - (stride_id + 1)] - # print(f"*** {param_name}*** param.shape: {param.shape}, param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - - shard_size = param.shape[1] // 2 - start = shard_size * stride_id - end = shard_size * (stride_id + 1) - param_slice = param.data[:, start:end] - - # print(f"{name} param_slice.shape: {param_slice.shape}, loaded_weight.shape: {loaded_weight.shape}") - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: - continue - - param = state_dict[name] - # print(f"int4 layer name: {name}") - if "norm" in name: - print(f"{name} shape: {loaded_weight.shape}") - load_tensor_parallel_weights2(param, loaded_weight, name, - tensor_model_parallel_rank) From 0ea256fb0e7b275082dd6f52038d9e046214de2f Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Thu, 19 Oct 2023 16:03:15 +0800 Subject: [PATCH 47/52] fix llama --- .../layers/quantized_linear/__init__.py | 7 ++ vllm/model_executor/models/llama.py | 89 ++++++++----------- 2 files changed, 46 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index bcb9a54e7a2c..4d843d3eab98 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -2,9 +2,16 @@ AWQColumnParallelLinear, AWQRowParallelLinear) from vllm.model_executor.parallel_utils.tensor_parallel import ( ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.int8_linear.w8a8linear import ( + W8A8OFP32LinearWithSFactorCublas, + W8A8O32LinearCublas, + W8A8O32LinearCublasNoDequant, + W8A8OFP32LinearWithSFactorCublasNoQuant, + W8A8O32Linear) _QUANTIZED_LINEAR_REGISTRY = { "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), + "smoothquant": (ColumnParallelLinear, RowParallelLinear) } diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 72314db173c4..9e3a3bdf4894 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -69,16 +69,16 @@ def __init__( super().__init__() self.use_int8 = quant_config is not None and 
quant_config.get_name() == "smoothquant" + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + if self.use_int8: - # self.gate_up_proj = W8A8O32LinearCublas(hidden_size, - # 2 * intermediate_size) self.gate_up_proj = W8A8O32Linear(hidden_size, 2 * intermediate_size) - - # self.down_proj = W8A8OFP32LinearWithSFactorCublas(intermediate_size, - # hidden_size) self.down_proj = W8A8O32Linear(intermediate_size, hidden_size) + self.act_fn = DequantSiluAndMulQuant() else: self.gate_up_proj = ParallelLinear.column(hidden_size, 2 * intermediate_size, @@ -92,19 +92,13 @@ def __init__( input_is_parallel=True, perform_initialization=False, quant_config=quant_config) - - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - # self.act_fn = SiluAndMul() - self.act_fn = DequantSiluAndMulQuant() + self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) # FIXME: currently gate up share same scale, plan to use seperate scales x = self.act_fn(gate_up) x, _ = self.down_proj(x) - # x = x.half() return x @@ -137,18 +131,21 @@ def __init__( self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" if self.use_int8: - # self.qkv_proj = W8A8O32LinearCublas( - # hidden_size, - # (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim) self.qkv_proj = W8A8O32Linear( hidden_size, (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim) - # self.o_proj = W8A8OFP32LinearWithSFactorCublas( - # self.total_num_heads * self.head_dim, - # hidden_size) self.o_proj = W8A8O32Linear( self.total_num_heads * self.head_dim, hidden_size) + self.attn = DequantPagedAttentionWithRoPEQuant(self.num_heads, + self.head_dim, + self.scaling, + base=self.rope_theta, + rotary_dim=self.head_dim, + num_kv_heads=self.num_kv_heads, + quant_kv_cache=quant_kv_cache, + kv_quant_params=kv_quant_params) + else: self.qkv_proj = ParallelLinear.column( hidden_size, @@ -167,22 +164,14 @@ def __init__( perform_initialization=False, quant_config=quant_config, ) - # self.attn = PagedAttentionWithRoPE(self.num_heads, - # self.head_dim, - # self.scaling, - # base=self.rope_theta, - # rotary_dim=self.head_dim, - # num_kv_heads=self.num_kv_heads, - # quant_kv_cache=quant_kv_cache, - # kv_quant_params=kv_quant_params) - self.attn = DequantPagedAttentionWithRoPEQuant(self.num_heads, - self.head_dim, - self.scaling, - base=self.rope_theta, - rotary_dim=self.head_dim, - num_kv_heads=self.num_kv_heads, - quant_kv_cache=quant_kv_cache, - kv_quant_params=kv_quant_params) + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.scaling, + base=self.rope_theta, + rotary_dim=self.head_dim, + num_kv_heads=self.num_kv_heads, + quant_kv_cache=quant_kv_cache, + kv_quant_params=kv_quant_params) def forward( self, @@ -215,6 +204,7 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = LlamaAttention( @@ -232,15 +222,12 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, ) - if quant_config is not None and quant_config.get_name() == "smoothquant": + if self.use_int8: self.input_layernorm = I8RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # self.post_attention_layernorm = 
I8RMSNorm(config.hidden_size, - # eps=config.rms_norm_eps) + # kernel fusion, post_attention_layernorm are fused into DequantAddResidualI8RMSNormQuant self.dequant_add_residual_layernorm_quant = DequantAddResidualI8RMSNormQuant(config.hidden_size, eps=config.rms_norm_eps) - # self.attn_dequant_add_residual = DequantAddResidual() - # self.mlp_dequant_add_residual = DequantAddResidual() self.dequant_add_residual = DequantAddResidual() else: self.input_layernorm = RMSNorm(config.hidden_size, @@ -266,17 +253,19 @@ def forward( input_metadata=input_metadata, cache_event=cache_event, ) - # hidden_states = residual + hidden_states - # hidden_states = self.attn_dequant_add_residual(residual, hidden_states) - - # Fully Connected - # residual = hidden_states - residual, hidden_states = self.dequant_add_residual_layernorm_quant(residual, hidden_states) - # hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - # hidden_states = residual + hidden_states - # hidden_states = self.mlp_dequant_add_residual(residual, hidden_states) - hidden_states = self.dequant_add_residual(residual, hidden_states) + + if self.use_int8: + residual, hidden_states = self.dequant_add_residual_layernorm_quant(residual, hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.dequant_add_residual(residual, hidden_states) + else: + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states From 29939aa116ea5b989d3ef19de3610cc9c5ad9c28 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Thu, 19 Oct 2023 16:16:24 +0800 Subject: [PATCH 48/52] remove cutlass dependency --- csrc/int8gemm/cutlass/bindings.cpp | 24 -- csrc/int8gemm/cutlass/bmm.cu | 211 ----------- csrc/int8gemm/cutlass/fused.cu | 25 -- csrc/int8gemm/cutlass/include/bmm.h | 10 - csrc/int8gemm/cutlass/include/common.h | 11 - csrc/int8gemm/cutlass/include/fused.h | 16 - csrc/int8gemm/cutlass/include/linear.h | 43 --- csrc/int8gemm/cutlass/linear.cu | 491 ------------------------- setup.py | 20 +- 9 files changed, 1 insertion(+), 850 deletions(-) delete mode 100644 csrc/int8gemm/cutlass/bindings.cpp delete mode 100644 csrc/int8gemm/cutlass/bmm.cu delete mode 100644 csrc/int8gemm/cutlass/fused.cu delete mode 100644 csrc/int8gemm/cutlass/include/bmm.h delete mode 100644 csrc/int8gemm/cutlass/include/common.h delete mode 100644 csrc/int8gemm/cutlass/include/fused.h delete mode 100644 csrc/int8gemm/cutlass/include/linear.h delete mode 100644 csrc/int8gemm/cutlass/linear.cu diff --git a/csrc/int8gemm/cutlass/bindings.cpp b/csrc/int8gemm/cutlass/bindings.cpp deleted file mode 100644 index 3bc20df7fbb0..000000000000 --- a/csrc/int8gemm/cutlass/bindings.cpp +++ /dev/null @@ -1,24 +0,0 @@ -#include "include/bmm.h" -#include "include/fused.h" -#include "include/linear.h" -#include - -/* -adapt from https://github.com/Guangxuan-Xiao/torch-int -*/ -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_relu_a8_w8_b8_o8", &linear_relu_a8_w8_b8_o8, - "Linear ReLU (INT8)"); - m.def("linear_a8_w8_b32_o32", &linear_a8_w8_b32_o32, "Linear (INT32)"); - m.def("linear_a8_w8_bfp32_ofp32", &linear_a8_w8_bfp32_ofp32, - "Linear (I8-OFP32)"); - m.def("linear_a8_w8_b32_o32_with_scaling", &linear_a8_w8_b32_o32_with_scaling, - "Linear (INT32) with scaling"); - m.def("linear_a8_w8_b8_o8", &linear_a8_w8_b8_o8, "Linear 
(INT8)"); - m.def("dq_add_layernorm_q", &dq_add_layernorm_q, - "DQ + Add + LayerNorm (INT8)"); - m.def("bmm_s8t_s8n_s8t", &bmm_s8t_s8n_s8t, "BMM (INT8 IO) A x B.T"); - m.def("bmm_s8t_s8n_f32t", &bmm_s8t_s8n_f32t, "BMM (INT8 I FP32 O) A x B.T"); - m.def("bmm_s8t_s8n_s32t", &bmm_s8t_s8n_s32t, - "BMM (INT8 In Int32 Out) A x B.T"); -} diff --git a/csrc/int8gemm/cutlass/bmm.cu b/csrc/int8gemm/cutlass/bmm.cu deleted file mode 100644 index 93b8e06f96d8..000000000000 --- a/csrc/int8gemm/cutlass/bmm.cu +++ /dev/null @@ -1,211 +0,0 @@ -#include "include/bmm.h" -#include "include/common.h" -#include -#include -#include -#include -#include -#include - -torch::Tensor bmm_s8t_s8n_f32t(torch::Tensor A, torch::Tensor B, float alpha) { - int batch_size = A.size(0); - int M = A.size(1); - int N = B.size(1); - int K = A.size(2); - - auto C = torch::empty({batch_size, M, N}, - torch::dtype(torch::kFloat32).device(A.device())); - int lda = A.size(2); - int ldb = B.size(2); - int ldc = C.size(2); - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using ElementOutput = float; - using ElementInputA = int8_t; - using ElementInputB = int8_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementComputeEpilogue>; - - using Gemm = cutlass::gemm::device::GemmBatched< - ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, - LayoutOutput, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - EpilogueOp>; - - long long int batch_stride_A = M * K; - long long int batch_stride_B = N * K; - long long int batch_stride_C = M * N; - - Gemm gemm_op; - typename Gemm::Arguments arguments{ - {M, N, K}, {A.data_ptr(), lda}, - batch_stride_A, {B.data_ptr(), ldb}, - batch_stride_B, {C.data_ptr(), ldc}, - batch_stride_C, {C.data_ptr(), ldc}, - batch_stride_C, {alpha, 0}, - batch_size}; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } - return C; -} - -torch::Tensor bmm_s8t_s8n_s8t(torch::Tensor A, torch::Tensor B, float alpha) { - int batch_size = A.size(0); - int M = A.size(1); - int N = B.size(1); - int K = A.size(2); - - auto C = torch::empty({batch_size, M, N}, - torch::dtype(torch::kInt8).device(A.device())); - int lda = A.size(2); - int ldb = B.size(2); - int ldc = C.size(2); - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = 
cutlass::layout::RowMajor; - - using ElementOutput = int8_t; - using ElementInputA = int8_t; - using ElementInputB = int8_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementComputeEpilogue>; - - using Gemm = cutlass::gemm::device::GemmBatched< - ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, - LayoutOutput, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - EpilogueOp>; - - long long int batch_stride_A = M * K; - long long int batch_stride_B = N * K; - long long int batch_stride_C = M * N; - - Gemm gemm_op; - typename Gemm::Arguments arguments{ - {M, N, K}, {A.data_ptr(), lda}, - batch_stride_A, {B.data_ptr(), ldb}, - batch_stride_B, {C.data_ptr(), ldc}, - batch_stride_C, {C.data_ptr(), ldc}, - batch_stride_C, {alpha, 0}, - batch_size}; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } - return C; -} - -torch::Tensor bmm_s8t_s8n_s32t(torch::Tensor A, torch::Tensor B) { - int batch_size = A.size(0); - int M = A.size(1); - int N = B.size(1); - int K = A.size(2); - - auto C = torch::empty({batch_size, M, N}, - torch::dtype(torch::kInt32).device(A.device())); - int lda = A.size(2); - int ldb = B.size(2); - int ldc = C.size(2); - - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using ElementOutput = int32_t; - using ElementInputA = int8_t; - using ElementInputB = int8_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = int32_t; - - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementComputeEpilogue>; - - using Gemm = cutlass::gemm::device::GemmBatched< - ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, - LayoutOutput, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - EpilogueOp>; - - long long int batch_stride_A = M * K; - long long int batch_stride_B = N * K; - long long int batch_stride_C = M * N; - - Gemm gemm_op; - - ElementComputeEpilogue alpha = 1; - - cutlass::Status status = gemm_op({{M, N, K}, - {A.data_ptr(), lda}, - batch_stride_A, - {B.data_ptr(), ldb}, - batch_stride_B, - {C.data_ptr(), ldc}, - batch_stride_C, - {C.data_ptr(), ldc}, - batch_stride_C, - {alpha, 0}, - 
batch_size}); - - if (status != cutlass::Status::kSuccess) { - std::cout << "cutlass error code: " << (int)status << std::endl; - } - return C; -} \ No newline at end of file diff --git a/csrc/int8gemm/cutlass/fused.cu b/csrc/int8gemm/cutlass/fused.cu deleted file mode 100644 index ee5e99d89f68..000000000000 --- a/csrc/int8gemm/cutlass/fused.cu +++ /dev/null @@ -1,25 +0,0 @@ -#include "include/fused.h" -#include "include/common.h" - - -std::tuple // (residual_output (FP), ln_output (INT8)) -dq_add_layernorm_q( - torch::Tensor input, // INT32 - float input_scale, // FP - torch::Tensor residual_input, // FP - torch::Tensor gamma, // FP - torch::Tensor beta, // FP - float epsilon // FP - ) // The output scale has already been fused into gamma and beta -{ - // residual_output = residual_input + input * input_scale - auto residual_output_fp = torch::add(residual_input, input, input_scale); - - auto ln_output_fp = - torch::layer_norm(residual_output_fp, {residual_output_fp.size(-1)}, - gamma, beta, epsilon); - ln_output_fp.clamp_(-128, 127).round_(); - auto ln_output_int8 = ln_output_fp.to(torch::kChar); - return std::make_tuple(residual_output_fp, ln_output_int8); -} \ No newline at end of file diff --git a/csrc/int8gemm/cutlass/include/bmm.h b/csrc/int8gemm/cutlass/include/bmm.h deleted file mode 100644 index 847265d72a33..000000000000 --- a/csrc/int8gemm/cutlass/include/bmm.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef BMM_H -#define BMM_H -#include -torch::Tensor bmm_s8t_s8n_f32t(torch::Tensor A, torch::Tensor B, float alpha); - -torch::Tensor bmm_s8t_s8n_s8t(torch::Tensor A, torch::Tensor B, float alpha); - -torch::Tensor bmm_s8t_s8n_s32t(torch::Tensor A, torch::Tensor B); - -#endif // BMM_H \ No newline at end of file diff --git a/csrc/int8gemm/cutlass/include/common.h b/csrc/int8gemm/cutlass/include/common.h deleted file mode 100644 index 2f3bdd3221b8..000000000000 --- a/csrc/int8gemm/cutlass/include/common.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef COMMON_H -#define COMMON_H -#include -#include -#include -#include -#include -#include - - -#endif // !COMMON \ No newline at end of file diff --git a/csrc/int8gemm/cutlass/include/fused.h b/csrc/int8gemm/cutlass/include/fused.h deleted file mode 100644 index 42ac634507ef..000000000000 --- a/csrc/int8gemm/cutlass/include/fused.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef FUSED_H -#define FUSED_H - -#include - -std::tuple // (residual_output (FP), ln_output (INT8)) -dq_add_layernorm_q(torch::Tensor input, // INT32 - float input_scale, // FP - torch::Tensor residual_input, // FP - torch::Tensor gamma, // FP - torch::Tensor beta, // FP - float epsilon // FP -); // The output scale has already been fused into gamma and beta - -#endif // FUSED_H \ No newline at end of file diff --git a/csrc/int8gemm/cutlass/include/linear.h b/csrc/int8gemm/cutlass/include/linear.h deleted file mode 100644 index 5df6ac6f1bc1..000000000000 --- a/csrc/int8gemm/cutlass/include/linear.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef LINEAR_H -#define LINEAR_H -#include - -// used by out_proj and fc2, return INT32 -torch::Tensor linear_a8_w8_b32_o32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias // INT32 -); - -// used by out_proj and fc2, return INT32 -torch::Tensor linear_a8_w8_b32_o32_with_scaling(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // INT32 - float alpha, // FP32 - float beta // FP32 -); - -// used by out_proj and fc2, return FP32 -torch::Tensor linear_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - 
torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -); - -// used by fc1, return INT8 -torch::Tensor linear_relu_a8_w8_b8_o8(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // INT8 - float alpha, // FP32 - float beta // FP32 -); - -// used by q_proj, k_proj, v_proj, return INT8 -torch::Tensor linear_a8_w8_b8_o8(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // INT8 - float alpha, // FP32 - float beta // FP32 -); - -#endif // LINEAR_HS \ No newline at end of file diff --git a/csrc/int8gemm/cutlass/linear.cu b/csrc/int8gemm/cutlass/linear.cu deleted file mode 100644 index 0e11d7b46175..000000000000 --- a/csrc/int8gemm/cutlass/linear.cu +++ /dev/null @@ -1,491 +0,0 @@ -#include "include/linear.h" -#include "include/common.h" - -#include -#include -#include - -#include -#include -#include - -// used by out_proj and fc2, return INT32 -torch::Tensor linear_a8_w8_b32_o32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias // INT32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = int32_t; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType::NoBetaScaling>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - auto device = input.device(); - // use the broadcasted bias as the output - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - // Initialize alpha and beta for dot product computation - ElementComputeEpilogue alpha = ElementComputeEpilogue(1); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // 
<- reference to matrix B on device - out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - {alpha}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } - - return out; -} - -// used by out_proj and fc2, return INT32 -torch::Tensor linear_a8_w8_b32_o32_with_scaling(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // INT32 - float alpha, // FP32 - float beta // FP32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementComputeEpilogue>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - auto device = input.device(); - // use the broadcasted bias as the output - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // <- reference to 
matrix B on device - out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - {alpha, beta}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } - - return out; -} - -// used by out_proj and fc2, return FP32 -torch::Tensor linear_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = float; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementComputeEpilogue>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - auto device = input.device(); - // use the broadcasted bias as the output - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // <- reference to matrix B on device - 
out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - {alpha, beta}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } - - return out; -} - - -// used by q_proj, k_proj, v_proj, return INT8 -torch::Tensor linear_a8_w8_b8_o8(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // INT8 - float alpha, // FP32 - float beta // FP32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = int8_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::FastLinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - auto device = input.device(); - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // <- reference to matrix B on device - out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - 
{alpha, beta}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement, status: " + - std::to_string((int)status)); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize, status: " + - std::to_string((int)status)); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run, status: " + - std::to_string((int)status)); - } - - return out; -} - -// used by fc1 -torch::Tensor linear_relu_a8_w8_b8_o8(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // INT8 - float alpha, // FP32 - float beta // FP32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = int8_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. 
This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - EpilogueOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 3>; - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - auto device = input.device(); - // use the broadcasted bias as the output - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // <- reference to matrix B on device - out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - {alpha, beta}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement, status: " + - std::to_string((int)status)); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize, status: " + - std::to_string((int)status)); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run, status: " + - std::to_string((int)status)); - } - - return out; -} \ No newline at end of file diff --git a/setup.py b/setup.py index 079a816bff45..a8513eb582fc 100644 --- a/setup.py +++ b/setup.py @@ -91,25 +91,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ext_modules = [] -# Int8GEMM(cutlass required) -i8gemm_extension = CUDAExtension( - name='vllm.i8gemm', - sources=[ - 'csrc/int8gemm/cutlass/linear.cu', - 'csrc/int8gemm/cutlass/bmm.cu', - 'csrc/int8gemm/cutlass/fused.cu', - 'csrc/int8gemm/cutlass/bindings.cpp', - ], - include_dirs=['csrc/int8gemm/cutlass/include'], - extra_link_args=['-lcublas_static', '-lcublasLt_static', - '-lculibos', '-lcudart', '-lcudart_static', - '-lrt', '-lpthread', '-ldl', '-L/usr/lib/x86_64-linux-gnu/'], - 
extra_compile_args={'cxx': ['-std=c++14', '-O3'], - 'nvcc': ['-O3', '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__']}, -) -ext_modules.append(i8gemm_extension) - -# int8gemm(cutlass required) +# int8gemm i8cugemm_extension = CUDAExtension( name='vllm.i8cugemm', sources=[ From d8f7d5af30507f80a7d9ea6fe812833343bf082a Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Tue, 24 Oct 2023 16:30:27 +0800 Subject: [PATCH 49/52] add sq quantized linear --- .../layers/int8_linear/w8a8linear.py | 231 +----------------- .../layers/quantized_linear/__init__.py | 11 +- .../layers/quantized_linear/smoothquant.py | 56 +++++ vllm/model_executor/models/llama.py | 72 +++--- 4 files changed, 96 insertions(+), 274 deletions(-) create mode 100644 vllm/model_executor/layers/quantized_linear/smoothquant.py diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py index df6949361130..7cee0c565bef 100644 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ b/vllm/model_executor/layers/int8_linear/w8a8linear.py @@ -1,234 +1,8 @@ # adapt from https://github.com/Guangxuan-Xiao/torch-int import torch -from vllm import i8gemm -from .quantization import ( - quantize_per_tensor_absmax, - quantize_weight_per_channel_absmax, - fake_quantize_activation_per_tensor_absmax, - fake_quantize_activation_per_token_absmax, -) from vllm.i8cugemm import I8CUGEMM i8cugemm = I8CUGEMM() -class W8A8B8O8Linear(torch.nn.Module): - # For qkv_proj - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('a', torch.tensor(alpha)) - self.register_buffer('b', torch.tensor(beta)) - - def _apply(self, fn): - super()._apply(fn) - self.a = self.a.cpu() - self.b = self.b.cpu() - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = i8gemm.linear_a8_w8_b8_o8(x, self.weight, self.bias, - self.a.item(), self.b.item()) - y = y.view(*x_shape[:-1], -1) - # FIXME: Just adapt to ParallelLinears' output - return y, None - - @staticmethod - def from_float(module: torch.nn.Linear, input_scale, output_scale): - int8_module = W8A8B8O8Linear( - module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - # int8_bias, bias_scale should be 0, 0.0 - mockbias = torch.zeros((1, module.out_features), dtype=torch.int8, requires_grad=False) - int8_bias, bias_scale = quantize_per_tensor_absmax(mockbias) - alpha = input_scale * weight_scale / output_scale - beta = bias_scale / output_scale - int8_module.weight = int8_weight - int8_module.bias = int8_bias - int8_module.a = alpha - int8_module.b = beta - return int8_module - - -class W8A8B8O8LinearWithSFactor(torch.nn.Module): - # For qkv_proj - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0, inscale=1.0, ouscale=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - 
self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('a', torch.tensor(alpha)) - self.register_buffer('b', torch.tensor(beta)) - self.register_buffer('inscale', torch.tensor(inscale)) - self.register_buffer('ouscale', torch.tensor(ouscale)) - - def _apply(self, fn): - super()._apply(fn) - self.a = self.a.cpu() - self.b = self.b.cpu() - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) - y = i8gemm.linear_a8_w8_b8_o8(x, self.weight, self.bias, - self.a.item(), self.b.item()) - y = y.view(*x_shape[:-1], -1) - return y, None - - @staticmethod - def from_float(module: torch.nn.Linear, input_scale, output_scale): - int8_module = W8A8B8O8LinearWithSFactor( - module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - mockbias = torch.zeros((1, module.out_features), dtype=torch.int8, requires_grad=False) - int8_bias, bias_scale = quantize_per_tensor_absmax(mockbias) - alpha = input_scale * weight_scale / output_scale - beta = bias_scale / output_scale - int8_module.weight = int8_weight - int8_module.bias = int8_bias - int8_module.a = alpha - int8_module.b = beta - int8_module.inscale = input_scale - int8_module.ouscale = output_scale - return int8_module - - -class W8A8BFP32OFP32Linear(torch.nn.Module): - # For fc2 and out_proj - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.float32, requires_grad=False)) - self.register_buffer('a', torch.tensor(alpha)) - - def _apply(self, fn): - # prevent the bias from being converted to half - super()._apply(fn) - self.a = self.a.cpu() - self.bias = self.bias.to(torch.float32) - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - self.bias = self.bias.to(torch.float32) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - self.bias = self.bias.to(torch.float32) - y = i8gemm.linear_a8_w8_bfp32_ofp32( - x, self.weight, self.bias, self.a.item(), 1) - y = y.view(*x_shape[:-1], -1) - return y, None - - @staticmethod - def from_float(module: torch.nn.Linear, input_scale): - int8_module = W8A8BFP32OFP32Linear( - module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - alpha = input_scale * weight_scale - int8_module.weight = int8_weight - mockbias = torch.zeros((1, module.out_features), dtype=torch.float, requires_grad=False) - int8_module.bias = mockbias.to(torch.float32) - int8_module.a = alpha - int8_module.input_scale = input_scale - int8_module.weight_scale = weight_scale - return int8_module - - -class 
W8A8BFP32OFP32LinearWithSFactor(torch.nn.Module): - # For fc2 and out_proj - def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.float32, requires_grad=False)) - self.register_buffer('a', torch.tensor(alpha)) - self.register_buffer('inscale', torch.tensor(inscale)) - - - def _apply(self, fn): - # prevent the bias from being converted to half - super()._apply(fn) - self.a = self.a.cpu() - self.bias = self.bias.to(torch.float32) - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - self.bias = self.bias.to(torch.float32) - self.a = self.a.to(*args, **kwargs) - self.a = self.a.to(torch.float32) - self.inscale = self.inscale.to(*args, **kwargs) - self.inscale = self.inscale.to(torch.float32) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - # quant activation - x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) - self.bias = self.bias.to(torch.float32) - y = i8gemm.linear_a8_w8_bfp32_ofp32( - x, self.weight, self.bias, self.a.item(), 1) - y = y.view(*x_shape[:-1], -1) - return y, None - - @staticmethod - def from_float(module: torch.nn.Linear, input_scale): - int8_module = W8A8BFP32OFP32LinearWithSFactor( - module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - alpha = input_scale * weight_scale - int8_module.weight = int8_weight - mockbias = torch.zeros((1, module.out_features), dtype=torch.float, requires_grad=False) - int8_module.bias = mockbias.to(torch.float32) - int8_module.a = alpha - int8_module.inscale = torch.tensor(input_scale) - return int8_module - class W8A8OFP32LinearWithSFactorCublas(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): @@ -311,7 +85,8 @@ def forward(self, x): y = y * self.a.item() y = y.view(*x_shape[:-1], -1) return y, None - + + class W8A8OFP32LinearWithSFactorCublasNoQuant(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features): @@ -356,6 +131,8 @@ def forward(self, x): # y = y * self.a.item() y = y.view(*x_shape[:-1], -1) return y, None + + class W8A8O32LinearCublasNoDequant(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features): diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index 4d843d3eab98..e9781f1f73c6 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -1,17 +1,14 @@ from vllm.model_executor.layers.quantized_linear.awq import ( AWQColumnParallelLinear, AWQRowParallelLinear) +from vllm.model_executor.layers.quantized_linear.smoothquant import ( + SQColumnParallelLinear, SQRowParallelLinear) from vllm.model_executor.parallel_utils.tensor_parallel import ( ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.int8_linear.w8a8linear import ( - W8A8OFP32LinearWithSFactorCublas, - W8A8O32LinearCublas, - W8A8O32LinearCublasNoDequant, - W8A8OFP32LinearWithSFactorCublasNoQuant, - 
W8A8O32Linear) + _QUANTIZED_LINEAR_REGISTRY = { "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), - "smoothquant": (ColumnParallelLinear, RowParallelLinear) + "smoothquant": (SQColumnParallelLinear, SQRowParallelLinear) } diff --git a/vllm/model_executor/layers/quantized_linear/smoothquant.py b/vllm/model_executor/layers/quantized_linear/smoothquant.py new file mode 100644 index 000000000000..8ee3e5238d16 --- /dev/null +++ b/vllm/model_executor/layers/quantized_linear/smoothquant.py @@ -0,0 +1,56 @@ +from typing import Optional + +import torch +from torch.nn.parameter import Parameter +from vllm.model_executor.parallel_utils.tensor_parallel.layers import ( + ColumnParallelLinear, RowParallelLinear) +from vllm.i8cugemm import I8CUGEMM +i8cugemm = I8CUGEMM() + +class SQColumnParallelLinear(ColumnParallelLinear): + + def create_weights(self, dtype: torch.dtype) -> None: + assert self.input_size % self.quant_config.weight_bits == 0 + self.register_buffer('weight', + torch.randint(-127, + 127, + (self.output_size_per_partition, + self.input_size), + dtype=torch.int8, + requires_grad=False)) + + def apply_weights( + self, + x: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + + + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = torch.empty((x.shape[0], self.output_size_per_partition), dtype=torch.int32, device=x.device) + i8cugemm.linear_a8_w8_o32_(x, self.weight, y) + y = y.view(*x_shape[:-1], -1) + return y + +class SQRowParallelLinear(RowParallelLinear): + + def create_weights(self, dtype: torch.dtype) -> None: + assert (self.input_size_per_partition % + self.quant_config.weight_bits == 0) + self.register_buffer('weight', + torch.randint(-127, + 127, + (self.output_size, + self.input_size_per_partition), + dtype=torch.int8, + requires_grad=False)) + + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = torch.empty((x.shape[0], self.output_size), dtype=torch.int32, device=x.device) + i8cugemm.linear_a8_w8_o32_(x, self.weight, y) + y = y.view(*x_shape[:-1], -1) + return y \ No newline at end of file diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9e3a3bdf4894..a485997f3388 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -73,25 +73,22 @@ def __init__( raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") - if self.use_int8: - self.gate_up_proj = W8A8O32Linear(hidden_size, - 2 * intermediate_size) - self.down_proj = W8A8O32Linear(intermediate_size, - hidden_size) - self.act_fn = DequantSiluAndMulQuant() - else: - self.gate_up_proj = ParallelLinear.column(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - perform_initialization=False, - quant_config=quant_config) - self.down_proj = ParallelLinear.row(intermediate_size, - hidden_size, + self.gate_up_proj = ParallelLinear.column(hidden_size, + 2 * intermediate_size, bias=False, - input_is_parallel=True, + gather_output=False, perform_initialization=False, quant_config=quant_config) + self.down_proj = ParallelLinear.row(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + quant_config=quant_config) + # kernel fusion for int8 inference + if self.use_int8: + self.act_fn = DequantSiluAndMulQuant() + else: self.act_fn = SiluAndMul() def forward(self, x): @@ -130,13 +127,26 @@ def __init__( self.rope_theta = rope_theta self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" + self.qkv_proj = ParallelLinear.column( + hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=quant_config, + ) + self.o_proj = ParallelLinear.row( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + quant_config=quant_config, + ) + + # kernel fusion for int8 inference if self.use_int8: - self.qkv_proj = W8A8O32Linear( - hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim) - self.o_proj = W8A8O32Linear( - self.total_num_heads * self.head_dim, - hidden_size) self.attn = DequantPagedAttentionWithRoPEQuant(self.num_heads, self.head_dim, self.scaling, @@ -145,25 +155,7 @@ def __init__( num_kv_heads=self.num_kv_heads, quant_kv_cache=quant_kv_cache, kv_quant_params=kv_quant_params) - else: - self.qkv_proj = ParallelLinear.column( - hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, - bias=False, - gather_output=False, - perform_initialization=False, - quant_config=quant_config, - ) - self.o_proj = ParallelLinear.row( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False, - quant_config=quant_config, - ) self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, self.scaling, From 74bd08fd3314a4e24c5138a7af7b2242912d848d Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Tue, 24 Oct 2023 16:34:03 +0800 Subject: [PATCH 50/52] rm unit test for w8a8 linear --- .../layers/int8_linear/__init__.py | 0 .../layers/int8_linear/quantization.py | 97 --------- .../model_executor/layers/int8_linear/test.py | 38 ---- .../layers/int8_linear/w8a8linear.py | 205 ------------------ vllm/model_executor/models/llama.py | 6 - 5 files changed, 346 deletions(-) delete mode 100644 vllm/model_executor/layers/int8_linear/__init__.py delete mode 100644 vllm/model_executor/layers/int8_linear/quantization.py delete mode 100644 vllm/model_executor/layers/int8_linear/test.py delete mode 100644 vllm/model_executor/layers/int8_linear/w8a8linear.py diff --git a/vllm/model_executor/layers/int8_linear/__init__.py b/vllm/model_executor/layers/int8_linear/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git 
a/vllm/model_executor/layers/int8_linear/quantization.py b/vllm/model_executor/layers/int8_linear/quantization.py deleted file mode 100644 index 555e09c39618..000000000000 --- a/vllm/model_executor/layers/int8_linear/quantization.py +++ /dev/null @@ -1,97 +0,0 @@ -# adapt from https://github.com/Guangxuan-Xiao/torch-int -import torch -import numpy as np - -@torch.no_grad() -def quantize_per_tensor_absmax(t): - scale = t.abs().max() / 127 - if not t.is_cuda: - # half rounding is not supported on CPU - t = t.float() - # use inplace operation to save memory - t.div_(scale).round_() - t_q = t.to(torch.int8) - return t_q, scale - -@torch.no_grad() -def quantize_weight_per_channel_absmax(w): - # w: [out_channel, in_channel] - scales = w.abs().max(dim=1)[0] / 127 - scales = scales.view(-1, 1) - if not w.is_cuda: - # half rounding is not supported on CPU - w = w.float() - # use inplace operation to save memory - w.div_(scales).round_().clamp_(-128, 127) - w_q = w.to(torch.int8) - return w_q, scales - - -@torch.no_grad() -def dynamic_quantize_activation_per_tensor_zeropoint(t): - max_val = t.max()[0] - min_val = t.min()[0] - quant_min = -127 - quant_max = 127 - nudged_scale = (max_val - min_val) / (quant_max - quant_min) - zp = (max_val + min_val) / 2 - zp = (zp / nudged_scale).round() * nudged_scale - t -= zp - max_val = (max_val - min_val) / 2 - - max_val = torch.clamp(max_val, min=1e-8) / 127 - q_act = (t / max_val).round().clamp(-128, 127).to(torch.int8) - return q_act, max_val, zp - - -@torch.no_grad() -def dynamic_quantize_activation_per_tensor_absmax(t): - max_val = t.abs().max() - max_val = torch.clamp(max_val, min=1e-8) / 127 - q_act = (t / max_val).round().clamp(-128, 127).to(torch.int8) - return q_act, max_val - - -@torch.no_grad() -def dynamic_quantize_activation_per_token_absmax(t): - max_val = t.abs().max(dim=-1, keepdim=True)[0] - max_val = torch.clamp(max_val, min=1e-8) / 127 - t.div_(max_val).round_().clamp_(-128, 127) - q_act = t.to(torch.int8) - return q_act, max_val - -@torch.no_grad() -def fake_quantize_activation_per_tensor_absmax(t): - max_val = t.abs().max() - max_val = torch.clamp(max_val, min=1e-8) / 127 - t.div_(max_val).round_().clamp_(-128, 127).mul_(max_val) - return t - - -@torch.no_grad() -def fake_quantize_activation_per_token_absmax(t): - max_val = t.abs().max(dim=-1, keepdim=True)[0] - max_val = torch.clamp(max_val, min=1e-8) / 127 - t.div_(max_val).round_().clamp_(-128, 127).mul_(max_val) - return t - - -@torch.no_grad() -def dequantize_activation_w_per_channel_a_per_token(q_act, w_scales, a_scales): - # q_act: [B, dim] - # w_scales: [dim] - # a_scales: [B 1] - dtype = a_scales.dtype - q_act = q_act.to(torch.float32) - q_act.mul_(w_scales.reshape(1, -1)).mul_(a_scales.reshape(-1, 1)) - return q_act.to(dtype) - -@torch.no_grad() -def dequantize_activation_w_per_channel_a_per_tensor(q_act, w_scales, a_scales): - # q_act: [..., dim] - # w_scales: [dim] - # a_scales: [1] - dtype = a_scales.dtype - q_act = q_act.to(torch.float32) - q_act = q_act * w_scales.reshape(1, -1) * a_scales - return q_act.to(dtype) diff --git a/vllm/model_executor/layers/int8_linear/test.py b/vllm/model_executor/layers/int8_linear/test.py deleted file mode 100644 index b83d36e9b08e..000000000000 --- a/vllm/model_executor/layers/int8_linear/test.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch -from w8a8linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear -from icecream import ic - -@torch.no_grad() -def test_w8a8b8o8_linear(): - B, M, N = 128, 512, 1024 - x = torch.randn(B, M) - x_scale = 
x.abs().max() / 127 - qx = (x / x_scale).round().to(torch.int8) - linear = torch.nn.Linear(M, N, bias=True) - y_gt = linear(x) - y_scale = y_gt.abs().max() / 127 - q_linear = W8A8B8O8Linear.from_float(linear, x_scale, y_scale).cuda() - q_y = q_linear(qx.cuda()).cpu() - y_hat = q_y * y_scale - r2 = (y_gt - y_hat).pow(2).mean() / y_gt.pow(2).mean() - ic(r2) - -@torch.no_grad() -def test_w8a8bfp32ofp32_linear(): - B, M, N = 128, 512, 1024 - x = torch.randn(B, M) - x_scale = x.abs().max() / 127 - qx = (x / x_scale).round().to(torch.int8) - linear = torch.nn.Linear(M, N, bias=True) - y_gt = linear(x) - q_linear = W8A8BFP32OFP32Linear.from_float(linear, x_scale).cuda() - y_hat = q_linear(qx.cuda()).cpu() - r2 = (y_gt - y_hat).pow(2).mean() / y_gt.pow(2).mean() - ic(r2) - - -if __name__ == '__main__': - print('test_w8a8b8o8_linear') - test_w8a8b8o8_linear() - print('test_w8a8bfp32ofp32_linear') - test_w8a8bfp32ofp32_linear() \ No newline at end of file diff --git a/vllm/model_executor/layers/int8_linear/w8a8linear.py b/vllm/model_executor/layers/int8_linear/w8a8linear.py deleted file mode 100644 index 7cee0c565bef..000000000000 --- a/vllm/model_executor/layers/int8_linear/w8a8linear.py +++ /dev/null @@ -1,205 +0,0 @@ -# adapt from https://github.com/Guangxuan-Xiao/torch-int -import torch -from vllm.i8cugemm import I8CUGEMM -i8cugemm = I8CUGEMM() - -class W8A8OFP32LinearWithSFactorCublas(torch.nn.Module): - # For fc2 and out_proj - def __init__(self, in_features, out_features, alpha=1.0, inscale=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.float32, requires_grad=False)) - self.register_buffer('a', torch.tensor(alpha)) - self.register_buffer('inscale', torch.tensor(inscale)) - - - def _apply(self, fn): - # prevent the bias from being converted to half - super()._apply(fn) - self.a = self.a.cpu() - self.bias = self.bias.to(torch.float32) - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - self.bias = self.bias.to(torch.float32) - self.a = self.a.to(*args, **kwargs) - self.a = self.a.to(torch.float32) - self.inscale = self.inscale.to(*args, **kwargs) - self.inscale = self.inscale.to(torch.float32) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - # quant activation - x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) - y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) - i8cugemm.linear_a8_w8_o32_(x, self.weight, y) - y = y * self.a.item() - y = y.view(*x_shape[:-1], -1) - return y, None - - -class W8A8O32LinearCublas(torch.nn.Module): - # For fc2 and out_proj - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.float32, requires_grad=False)) - self.register_buffer('a', torch.tensor(alpha)) - - def _apply(self, fn): - # prevent the bias from being converted to half - super()._apply(fn) 
- self.a = self.a.cpu() - self.bias = self.bias.to(torch.float32) - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - self.bias = self.bias.to(torch.float32) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) - i8cugemm.linear_a8_w8_o32_(x, self.weight, y) - y = y * self.a.item() - y = y.view(*x_shape[:-1], -1) - return y, None - - -class W8A8OFP32LinearWithSFactorCublasNoQuant(torch.nn.Module): - # For fc2 and out_proj - def __init__(self, in_features, out_features): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.float32, requires_grad=False)) - # self.register_buffer('a', torch.tensor(alpha)) - # self.register_buffer('inscale', torch.tensor(inscale)) - - - def _apply(self, fn): - # prevent the bias from being converted to half - super()._apply(fn) - # self.a = self.a.cpu() - self.bias = self.bias.to(torch.float32) - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - self.bias = self.bias.to(torch.float32) - # self.a = self.a.to(*args, **kwargs) - # self.a = self.a.to(torch.float32) - # self.inscale = self.inscale.to(*args, **kwargs) - # self.inscale = self.inscale.to(torch.float32) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - # quant activation - # x = (x / self.inscale).round().clamp(-128, 127).to(torch.int8) - y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) - i8cugemm.linear_a8_w8_o32_(x, self.weight, y) - # y = y * self.a.item() - y = y.view(*x_shape[:-1], -1) - return y, None - - -class W8A8O32LinearCublasNoDequant(torch.nn.Module): - # For fc2 and out_proj - def __init__(self, in_features, out_features): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.float32, requires_grad=False)) - # self.register_buffer('a', torch.tensor(alpha)) - - def _apply(self, fn): - # prevent the bias from being converted to half - super()._apply(fn) - # self.a = self.a.cpu() - self.bias = self.bias.to(torch.float32) - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - self.bias = self.bias.to(torch.float32) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) - i8cugemm.linear_a8_w8_o32_(x, self.weight, y) - # y = y * self.a.item() - y = y.view(*x_shape[:-1], -1) - return y, None - - -class W8A8O32Linear(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.in_features = in_features - self.out_features = 
out_features - - self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, - self.in_features), dtype=torch.int8, requires_grad=False)) - self.register_buffer('bias', torch.zeros( - (1, self.out_features), dtype=torch.float32, requires_grad=False)) - - def _apply(self, fn): - # prevent the bias from being converted to half - super()._apply(fn) - self.bias = self.bias.to(torch.float32) - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - self.bias = self.bias.to(torch.float32) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = torch.empty((x.shape[0], self.out_features), dtype=torch.int32, device=x.device) - i8cugemm.linear_a8_w8_o32_(x, self.weight, y) - y = y.view(*x_shape[:-1], -1) - return y, None \ No newline at end of file diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a485997f3388..799d130321cc 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,12 +37,6 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, DequantPagedAttentionWithRoPEQuant from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantized_linear import ParallelLinear -from vllm.model_executor.layers.int8_linear.w8a8linear import ( - W8A8OFP32LinearWithSFactorCublas, - W8A8O32LinearCublas, - W8A8O32LinearCublasNoDequant, - W8A8OFP32LinearWithSFactorCublasNoQuant, - W8A8O32Linear) from vllm.model_executor.layers.fusion import DequantAddResidual from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) From e9b2fa428ca758f615ea66c4615e673777ee717d Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Tue, 24 Oct 2023 17:09:48 +0800 Subject: [PATCH 51/52] adjust i8 llama weight load --- vllm/model_executor/models/llama.py | 177 ++++------------------------ 1 file changed, 26 insertions(+), 151 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 799d130321cc..995974e8ac9a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -344,19 +344,27 @@ def forward( _column_parallel_layers = [] _row_parallel_layers = ["o_proj", "down_proj"] + _int8_scale_params = { + "self_attn.q_proj.a": "self_attn.attn.a", + "self_attn.k_proj.a": "self_attn.attn.a", + "self_attn.v_proj.a": "self_attn.attn.a", + "self_attn.o_proj.inscale": "self_attn.attn.inscale", + "self_attn.o_proj.a": "dequant_add_residual_layernorm_quant.a", + "post_attention_layernorm.weight": "dequant_add_residual_layernorm_quant.weight", + "mlp.gate_proj.a": "mlp.act_fn.a", + "mlp.up_proj.a": "mlp.act_fn.a", + "mlp.down_proj.inscale": "mlp.act_fn.inscale", + "mlp.down_proj.a": "dequant_add_residual.a" + } def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): + int8_fusion = False if self.quant_config is not None and self.quant_config.get_name() == "smoothquant": - return self._load_int8_weights( - model_name_or_path, - cache_dir, - load_format, - revision - ) + int8_fusion = True if self.quant_config is None: weight_suffixes = ["weight"] @@ -387,116 +395,6 @@ def load_weights(self, ] state_dict = self.state_dict() - for name, loaded_weight in hf_model_weights_iterator( - 
model_name_or_path, cache_dir, load_format, revision): - if "rotary_emb.inv_freq" in name: - continue - - is_packed = False - is_transposed = False - if self.quant_config is not None: - is_packed = self.quant_config.is_packed(name) - is_transposed = self.quant_config.is_transposed(name) - if is_transposed: - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - loaded_weight = loaded_weight.T - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - if is_transposed: - param = param.T - - if is_packed: - shard_size //= self.quant_config.pack_factor - offset //= self.quant_config.pack_factor - - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - if is_transposed: - param = param.T - - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: - continue - - param = state_dict[name] - if is_transposed: - param = param.T - - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, - tensor_model_parallel_rank) - - def _load_int8_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - # TODO: support tp in intlinear - tp_size = 1 - tensor_model_parallel_rank = 0 - q_proj_shard_size = (self.config.hidden_size // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads // tp_size) - - if self.quant_config is None: - weight_suffixes = ["weight"] - else: - weight_suffixes = self.quant_config.get_tp_tensor_names() - - column_parallel_weights: List[str] = [] - for layer in self._column_parallel_layers: - for suffix in weight_suffixes: - column_parallel_weights.append(f"{layer}.{suffix}") - - row_parallel_weights: List[str] = [] - for layer in self._row_parallel_layers: - for suffix in weight_suffixes: - row_parallel_weights.append(f"{layer}.{suffix}") - - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), - ] - state_dict = self.state_dict() - for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: @@ -513,28 +411,17 @@ def _load_int8_weights(self, if is_transposed: loaded_weight = convert_pyslice_to_tensor(loaded_weight) loaded_weight = 
loaded_weight.T - - is_fusion_weight = False - name_dict = { - "self_attn.q_proj.a": "self_attn.attn.a", - "self_attn.k_proj.a": "self_attn.attn.a", - "self_attn.v_proj.a": "self_attn.attn.a", - "self_attn.o_proj.inscale": "self_attn.attn.inscale", - "self_attn.o_proj.a": "dequant_add_residual_layernorm_quant.a", - "post_attention_layernorm.weight": "dequant_add_residual_layernorm_quant.weight", - "mlp.gate_proj.a": "mlp.act_fn.a", - "mlp.up_proj.a": "mlp.act_fn.a", - "mlp.down_proj.inscale": "mlp.act_fn.inscale", - "mlp.down_proj.a": "dequant_add_residual.a" - } - for weight_name in name_dict.keys(): - if weight_name not in name: + + if int8_fusion: + is_fusion_weight = False + for weight_name in self._int8_scale_params.keys(): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, self._int8_scale_params[weight_name])] + param.copy_(loaded_weight) + is_fusion_weight = True + if is_fusion_weight: continue - param = state_dict[name.replace(weight_name, name_dict[weight_name])] - param.copy_(loaded_weight) - is_fusion_weight = True - if is_fusion_weight: - continue is_attention_weight = False for weight_name, shard_size, offset in attention_weight_specs: @@ -547,12 +434,6 @@ def _load_int8_weights(self, if is_packed: shard_size //= self.quant_config.pack_factor offset //= self.quant_config.pack_factor - - # share use same scale in quantizatin - if "proj.a" in name or "proj.inscale" in name: - param.copy_(loaded_weight) - is_attention_weight = True - continue loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * @@ -572,13 +453,7 @@ def _load_int8_weights(self, continue param = state_dict[name.replace(weight_name, "gate_up_proj")] if is_transposed: - loaded_weight = loaded_weight.T - - # share use same scale in quantizatin - if "proj.a" in name or "proj.inscale" in name: - param.copy_(loaded_weight) - is_gate_up_weight = True - continue + param = param.T shard_size = param.shape[0] // 2 loaded_weight = loaded_weight[ @@ -605,4 +480,4 @@ def _load_int8_weights(self, load_tensor_parallel_weights(param, loaded_weight, name, column_parallel_weights, row_parallel_weights, - tensor_model_parallel_rank) + tensor_model_parallel_rank) \ No newline at end of file From 6f8878788587741042e738a594f69bfe08651a65 Mon Sep 17 00:00:00 2001 From: "aniz1905@gmail.com" Date: Thu, 26 Oct 2023 11:55:16 +0800 Subject: [PATCH 52/52] add fusion.py --- tests/kernels/test_layernorm.py | 2 +- vllm/model_executor/layers/fusion.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/layers/fusion.py diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 22dd822f36b7..1083d88e368b 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -154,7 +154,7 @@ def test_dequant_add_residual( out2 = torch.empty_like(x, dtype=dtype) fused_kernels.invoke_dequant_add_residual(out2, x, residual, scale) - assert torch.allclose(out1, out2, atol=0.001) + assert torch.allclose(out1, out2, atol=0.001), f"diff: {torch.max(out1 - out2)}" @pytest.mark.parametrize("num_tokens", NUM_TOKENS) diff --git a/vllm/model_executor/layers/fusion.py b/vllm/model_executor/layers/fusion.py new file mode 100644 index 000000000000..11832e3d8738 --- /dev/null +++ b/vllm/model_executor/layers/fusion.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + +from vllm import fused_kernels + + +class DequantAddResidual(nn.Module): + def __init__(self, scale: float = 1.0) -> 
None:
+        super().__init__()
+        # Dequantization scale; kept as an fp32 buffer and filled from the
+        # per-layer scale parameters during load_weights.
+        self.register_buffer(
+            "a", torch.tensor(scale, dtype=torch.float32, requires_grad=False)
+        )
+
+    def _apply(self, fn):
+        # Keep the dequantization scale on the CPU when the module is moved.
+        super()._apply(fn)
+        self.a = self.a.cpu()
+        return self
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.a = self.a.to(*args, **kwargs)
+        self.a = self.a.to(torch.float32)
+        return self
+
+    def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+        # Fused kernel: dequantizes the int32 GEMM output `x` with scale `a`
+        # and adds `residual` in a single pass.
+        out = torch.empty_like(residual)
+        fused_kernels.invoke_dequant_add_residual(out, x, residual, self.a.item())
+        return out
\ No newline at end of file
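
For reference, the sketch below (illustrative only, not part of the patches) emulates in plain PyTorch the int8 data path these commits assemble: static per-tensor activation quantization, an int8 x int8 -> int32 GEMM, and a dequant-plus-residual epilogue. In the patched code the custom i8cugemm.linear_a8_w8_o32_ kernel performs the GEMM and fused_kernels.invoke_dequant_add_residual performs the final line; the absmax scale handling follows the removed quantization.py helpers, and all tensor names and shapes here are made up for the example.

import torch

@torch.no_grad()
def int8_linear_with_fused_epilogue(
    x: torch.Tensor,          # fp activations, shape [tokens, in_features]
    w_q: torch.Tensor,        # int8 weight, shape [out_features, in_features]
    act_scale: float,         # static per-tensor activation scale
    weight_scale: float,      # per-tensor weight scale
    residual: torch.Tensor,   # fp residual, shape [tokens, out_features]
) -> torch.Tensor:
    # Static per-tensor activation quantization (absmax scheme, as in the
    # removed quantize_per_tensor_absmax helper).
    x_q = (x / act_scale).round().clamp(-128, 127).to(torch.int8)
    # int8 x int8 -> int32 GEMM; the patches call i8cugemm.linear_a8_w8_o32_ here.
    y_i32 = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
    # Dequantize with the combined scale and add the residual, which is what
    # DequantAddResidual does in one fused CUDA kernel.
    return residual + y_i32.to(torch.float32) * (act_scale * weight_scale)

# Toy shapes only; real scales come from the SmoothQuant calibration checkpoint.
tokens, hidden = 4, 64
w_fp = torch.randn(hidden, hidden)
weight_scale = (w_fp.abs().max() / 127).item()
w_q = (w_fp / weight_scale).round().clamp(-128, 127).to(torch.int8)
x = torch.randn(tokens, hidden)
act_scale = (x.abs().max() / 127).item()
out = int8_linear_with_fused_epilogue(x, w_q, act_scale, weight_scale,
                                      torch.randn(tokens, hidden))
print(out.shape)  # torch.Size([4, 64])

In the patched model the activation scale (inscale) and the combined dequantization scale (a) are not computed on the fly as above; they are loaded per layer from the checkpoint through the _int8_scale_params mapping added to llama.py, which routes q/k/v and gate/up projection scales into the fused attention and activation modules.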