From 32491c612ef628e0ab94bc1bb87a83f6929e4a75 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sat, 27 Jan 2024 15:40:43 +0000 Subject: [PATCH 01/12] inital commit --- vllm/config.py | 27 +++++++++++++++++++++ vllm/engine/arg_utils.py | 14 +++++++++-- vllm/engine/llm_engine.py | 1 + vllm/entrypoints/llm.py | 7 ++++++ vllm/model_executor/layers/linear.py | 36 +++++++++++++++++++++++----- vllm/model_executor/weight_utils.py | 18 ++++++++++---- 6 files changed, 91 insertions(+), 12 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 8acd15a3b7d9a..3df920805d7a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -72,6 +72,7 @@ def __init__( tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, + sparsity: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, ) -> None: @@ -85,6 +86,7 @@ def __init__( self.revision = revision self.tokenizer_revision = tokenizer_revision self.quantization = quantization + self.sparsity = sparsity self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture @@ -106,6 +108,7 @@ def __init__( self._verify_load_format() self._verify_tokenizer_mode() self._verify_quantization() + self._verify_sparsity() self._verify_cuda_graph() def _verify_load_format(self) -> None: @@ -144,6 +147,30 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto' or 'slow'.") self.tokenizer_mode = tokenizer_mode + def _verify_sparsity(self) -> None: + supported_sparsity = ["sparse_w16a16"] + + if self.quantization is not None: + raise ValueError(f"Both sparsity and quantization detected. Only " + "one or the other is supported at a time.") + + if self.sparsity is not None: + if self.sparsity not in supported_sparsity: + raise ValueError(f"Unknown sparse method: {self.sparsity}. Must " + f"be one of {supported_sparse}.") + + hf_sparsity_config = getattr(self.hf_config, "sparsity_config", None) + if hf_sparsity_config is not None: + hf_sparsity_method = str(hf_sparse_config["sparse_method"]).lower() + if self.sparsity is None: + self.sparsity = hf_sparsity_method + elif self.sparsity != hf_sparsity_method: + raise ValueError( + "Sparsity method specified in the model config " + f"({hf_sparsity_method}) does not match the sparsity " + f"method specified in the `sparsity` argument " + f"({self.sparsity}).") + def _verify_quantization(self) -> None: supported_quantization = ["awq", "gptq", "squeezellm"] rocm_not_supported_quantization = ["awq"] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 090fa95bcac02..194356e365b85 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -33,6 +33,7 @@ class EngineArgs: revision: Optional[str] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None + sparsity: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 enable_lora: bool = False @@ -197,6 +198,15 @@ def add_cli_args( 'None, we assume the model weights are not ' 'quantized and use `dtype` to determine the data ' 'type of the weights.') + parser.add_argument('--sparsity', + '-s', + type=str, + choices=['sparse_w16a16', None], + default=None, + help='Method used to compress sparse weights. If ' + 'None, we first check the `sparsity_config` attribute ' + 'in the model config gile. If that is None we assume ' + 'the model weights are dense') parser.add_argument('--enforce-eager', action='store_true', help='Always use eager-mode PyTorch. 
If False, ' @@ -260,8 +270,8 @@ def create_engine_configs( self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.tokenizer_revision, self.max_model_len, - self.quantization, self.enforce_eager, - self.max_context_len_to_capture) + self.quantization, self.sparsity_config, + self.enforce_eager, self.max_context_len_to_capture) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0dedc232292dd..6c916695a74e2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -83,6 +83,7 @@ def __init__( f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"quantization={model_config.quantization}, " + f"sparsity={model_config.sparsity}, " f"enforce_eager={model_config.enforce_eager}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index aab0c9615f411..1f0fbb41729ce 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -43,6 +43,11 @@ class LLM: the `quantization_config` attribute in the model config file. If that is None, we assume the model weights are not quantized and use `dtype` to determine the data type of the weights. + sparsity: The format of the sparse model weights. Currently, + we support "sparse_w16a16". If None, we first check the `sparsity` + attribute in the model config file. If that is None, we assume the + model weights are dense and use `dtype` to determine the data + type of the weights. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. It can be a @@ -75,6 +80,7 @@ def __init__( tensor_parallel_size: int = 1, dtype: str = "auto", quantization: Optional[str] = None, + sparsity: Optional[str] = None, revision: Optional[str] = None, tokenizer_revision: Optional[str] = None, seed: int = 0, @@ -94,6 +100,7 @@ def __init__( tensor_parallel_size=tensor_parallel_size, dtype=dtype, quantization=quantization, + sparsity=sparsity, revision=revision, tokenizer_revision=tokenizer_revision, seed=seed, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5e1d63a6a62eb..0ffc49c083b14 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,6 +13,8 @@ divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs from vllm.logger import init_logger +from vllm.model_executor.layers.sparsity import SparseParameter +from vllm.model_executor.weight_utils import get_param_data logger = init_logger(__name__) @@ -195,7 +197,8 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) - param_data = param.data + param_data = get_param_data(param) + if output_dim is not None: shard_size = param_data.shape[output_dim] start_idx = tp_rank * shard_size @@ -204,6 +207,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + # If SparsParameter, repack dense data as sparse. 
+ if isinstance(param, SparseParameter): + param.pack() + def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -218,7 +225,6 @@ def forward(self, input_): output_bias = self.bias if self.skip_bias_add else None return output, output_bias - class MergedColumnParallelLinear(ColumnParallelLinear): """Packed linear layers with column parallelism. @@ -260,9 +266,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): - param_data = param.data + param_data = get_param_data(param) output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: + if isinstance(param, SparseParameter): + raise NotImplementedError( + "Passing loaded_shard_id=None not yet supported for SparseParameter") + # Loaded weight is already packed. if output_dim is None: assert param_data.shape == loaded_weight.shape @@ -312,6 +322,9 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + # If Parameter, repack dense data as sparse. + if isinstance(param, SparseParameter): + param.pack() class QKVParallelLinear(ColumnParallelLinear): """Linear layers for the attention's QKV transformation. @@ -373,10 +386,14 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - param_data = param.data + loaded_shard_id: Optional[str] = None): + param_data = get_param_data(param) output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: + if isinstance(param, SparseParameter): + raise NotImplementedError( + "Passing loaded_shard_id=None not yet supported for SparseParameter") + # Loaded weight is already packed. if output_dim is None: assert param_data.shape == loaded_weight.shape @@ -440,6 +457,9 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + # If SparseParameter, repack dense data as sparse. + if isinstance(param, SparseParameter): + param.pack() class RowParallelLinear(torch.nn.Module): """Linear layer with row parallelism. @@ -522,7 +542,7 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) - param_data = param.data + param_data = get_param_data(param) if input_dim is not None: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size @@ -530,6 +550,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + + # If SparseParameter, repack dense data as sparse. + if isinstance(param, SparseParameter): + param.pack() def forward(self, input_): # Set up backprop all-reduce. 
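
Each weight_loader in this file now follows the same round-trip: obtain a dense view of the parameter through get_param_data(), copy_() the checkpoint shard into it, and, when the parameter is sparse, pack() it back into compressed form. The sketch below models that flow in isolation; ToySparseParameter and toy_weight_loader are illustrative stand-ins invented here, since the real SparseParameter wraps nm_gpu's SparseTensor and is not reproduced in this series.

```python
import torch


class ToySparseParameter:
    """Illustrative stand-in for SparseParameter: keeps the weight compressed,
    exposing a dense scratch buffer only while a checkpoint shard is loaded."""

    def __init__(self, shape: torch.Size, dtype: torch.dtype):
        self.shape = shape
        self.dtype = dtype
        self.compressed = None  # stands in for the bitmask-compressed storage
        self.dense_data = None  # scratch buffer, alive only during loading

    def get_dense_data(self) -> torch.Tensor:
        # Materialize a dense buffer the loader can copy_() into.
        if self.compressed is not None:
            self.dense_data = self.compressed.to_dense()
        else:
            self.dense_data = torch.zeros(self.shape, dtype=self.dtype)
        return self.dense_data

    def pack(self) -> None:
        # Recompress the dense scratch and drop it to free memory.
        self.compressed = self.dense_data.to_sparse()  # toy "compression"
        self.dense_data = None


def toy_weight_loader(param, loaded_weight: torch.Tensor) -> None:
    # Mirrors the pattern added to every weight_loader: unpack -> copy -> repack.
    if isinstance(param, ToySparseParameter):
        param_data = param.get_dense_data()
    else:
        param_data = param.data
    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)
    if isinstance(param, ToySparseParameter):
        param.pack()


param = ToySparseParameter(torch.Size((4, 8)), torch.float16)
toy_weight_loader(param, torch.randn(4, 8, dtype=torch.float16))
print(param.compressed.to_dense().shape)  # torch.Size([4, 8])
```

The point of the detour through a dense scratch buffer is that the dense memory only exists per-parameter during loading; once pack() runs, the weight lives in compressed form again.
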
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 8e6f7a174f219..c314e15f53e4a 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -17,9 +17,12 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (get_quantization_config, QuantizationConfig) +from vllm.model_executor.layers.sparsity import SparseParameter -logger = init_logger(__name__) +# TODO: import safely +import nm_gpu +logger = init_logger(__name__) class Disabledtqdm(tqdm): @@ -276,13 +279,20 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: x = x[:] return x +def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: + """Gets parameter data in dense format.""" + if isinstance(param, SparseParameter): + return param.get_data_dense() + else: + return param.data -def default_weight_loader(param: torch.Tensor, +def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" assert param.size() == loaded_weight.size() - param.data.copy_(loaded_weight) - + get_param_data(param).copy_(loaded_weight) + if isinstance(param, SparseParameter): + param.pack() def initialize_dummy_weights( model: torch.nn.Module, From 8db3a546442ef9c604783416bc05d1fa4f9c0327 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sat, 27 Jan 2024 20:00:56 +0000 Subject: [PATCH 02/12] end to end appears to be working --- vllm/engine/arg_utils.py | 2 +- vllm/model_executor/layers/linear.py | 6 ++-- vllm/model_executor/model_loader.py | 25 ++++++++++++-- vllm/model_executor/parameters/__init__.py | 9 +++++ vllm/model_executor/parameters/sparsity.py | 39 ++++++++++++++++++++++ vllm/model_executor/utils.py | 1 - vllm/model_executor/weight_utils.py | 28 ++++++++++------ 7 files changed, 90 insertions(+), 20 deletions(-) create mode 100644 vllm/model_executor/parameters/__init__.py create mode 100644 vllm/model_executor/parameters/sparsity.py diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 194356e365b85..5f7d2ed25ea5d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -270,7 +270,7 @@ def create_engine_configs( self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.tokenizer_revision, self.max_model_len, - self.quantization, self.sparsity_config, + self.quantization, self.sparsity, self.enforce_eager, self.max_context_len_to_capture) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 0ffc49c083b14..b06c76e031c56 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,12 +13,10 @@ divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs from vllm.logger import init_logger -from vllm.model_executor.layers.sparsity import SparseParameter -from vllm.model_executor.weight_utils import get_param_data +from vllm.model_executor.parameters import SparseParameter, get_param_data logger = init_logger(__name__) - class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @@ -386,7 +384,7 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): + loaded_shard_id: Optional[str] = None): param_data = get_param_data(param) output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: diff --git 
a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 0f1125e5c8e3e..3dc16011b8144 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -8,10 +8,9 @@ from vllm.config import ModelConfig, LoRAConfig from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.weight_utils import (get_quant_config, +from vllm.model_executor.weight_utils import (get_quant_config, get_sparse_config, initialize_dummy_weights) - @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" @@ -36,7 +35,7 @@ def get_model(model_config: ModelConfig, lora_config: Optional[LoRAConfig] = None) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) - # Get the (maybe quantized) linear method. + # Get the (maybe sparse or quantized) linear method. linear_method = None if model_config.quantization is not None: quant_config = get_quant_config(model_config.quantization, @@ -58,6 +57,26 @@ def get_model(model_config: ModelConfig, f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}") linear_method = quant_config.get_linear_method() + if model_config.sparsity is not None: + sparse_config = get_sparse_config(model_config.sparsity, + model_config.model, + model_config.hf_config, + model_config.download_dir) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < sparse_config.get_min_capability(): + raise ValueError( + f"The sparsity method {model_config.sparsity} is not " + "supported for the current GPU. " + f"Minimum capability: {sparse_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = sparse_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for sparsity " + f"method {model_config.sparsity}. Supported dtypes: " + f"{supported_dtypes}") + linear_method = sparse_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): # Create a model instance. 
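
The sparsity branch added to get_model() mirrors the existing quantization gate: check the GPU's compute capability and the activation dtype before adopting the sparse linear method. Below is a standalone sketch of that gate; the function name check_sparsity_support is hypothetical, but the checks reproduce the logic in the hunk above.

```python
from typing import List

import torch


def check_sparsity_support(sparsity: str, min_capability: int,
                           supported_dtypes: List[torch.dtype],
                           act_dtype: torch.dtype) -> None:
    """Sketch of the eligibility gate get_model() applies before
    selecting a sparse linear method."""
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor  # e.g. (8, 6) -> 86
    if capability < min_capability:
        raise ValueError(
            f"The sparsity method {sparsity} is not supported for the "
            f"current GPU. Minimum capability: {min_capability}. "
            f"Current capability: {capability}.")
    if act_dtype not in supported_dtypes:
        raise ValueError(
            f"{act_dtype} is not supported for sparsity method {sparsity}. "
            f"Supported dtypes: {supported_dtypes}")


# sparse_w16a16 advertises capability >= 80 (Ampere) and fp16 activations,
# so on an A100/RTX 30xx-class GPU this should pass (requires CUDA):
# check_sparsity_support("sparse_w16a16", 80, [torch.half], torch.half)
```
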
diff --git a/vllm/model_executor/parameters/__init__.py b/vllm/model_executor/parameters/__init__.py new file mode 100644 index 0000000000000..75833a845c99e --- /dev/null +++ b/vllm/model_executor/parameters/__init__.py @@ -0,0 +1,9 @@ +import torch +from vllm.model_executor.parameters.sparsity import SparseParameter + +def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: + """Gets parameter data in dense format.""" + if isinstance(param, SparseParameter): + return param.get_dense_data() + else: + return param.data \ No newline at end of file diff --git a/vllm/model_executor/parameters/sparsity.py b/vllm/model_executor/parameters/sparsity.py new file mode 100644 index 0000000000000..b8a3f14b5e219 --- /dev/null +++ b/vllm/model_executor/parameters/sparsity.py @@ -0,0 +1,39 @@ +import torch +import nm_gpu +from nm_gpu.SparseTensor import ( + SparseBitmaskStorageFormat, SparseTensor +) + +class SparseParameter(SparseTensor): + @staticmethod + def __new__( + cls, + shape: torch.Size, + dtype: torch.dtype, + ): + assert torch.__version__ > (1, 10), "SparseTensor requires PyTorch 1.11+" + self = torch.Tensor._make_wrapper_subclass(cls, size=shape, dtype=dtype, requires_grad=False) + self.storage_format_cls = SparseBitmaskStorageFormat + self.compressed_data = None + self.dense_data = None + self._is_param = True + + return self + + def get_dense_data(self) -> torch.Tensor: + if self.dense_data is not None: + raise ValueError("Called get_data_dense() but dense_data already exists.") + self.dense_data = self._unpack() + return self.dense_data + + def _unpack(self) -> torch.Tensor: + if self.has_compressed_data(): + return self.compressed_data.to_dense() + else: + return torch.empty(size=self.shape, dtype=self.dtype, device="cuda") + + def pack(self) -> None: + if self.dense_data is None: + raise ValueError("Called pack() but dense_data does not exist.") + self.copy_(self.dense_data) + self.dense_data = None diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 336bc1cd005cf..23ae3eb1146c7 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -5,7 +5,6 @@ import numpy as np import torch - def set_random_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index c314e15f53e4a..46d2915729fda 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -17,10 +17,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (get_quantization_config, QuantizationConfig) -from vllm.model_executor.layers.sparsity import SparseParameter - -# TODO: import safely -import nm_gpu +from vllm.model_executor.layers.sparsity import (get_sparsity_config, + SparsityConfig) +from vllm.model_executor.parameters import SparseParameter, get_param_data logger = init_logger(__name__) @@ -84,6 +83,20 @@ def convert_bin_to_safetensor_file( if not torch.equal(pt_tensor, sf_tensor): raise RuntimeError(f"The output tensors do not match for key {k}") +# TODO(rib-2): Once we define hf_sparsity_config +def get_sparse_config( + sparsity: str, + model_name_or_path: str, + hf_config: PretrainedConfig, + cache_dir: Optional[str] = None, +) -> SparsityConfig: + sparsity_cls = get_sparsity_config(sparsity) + hf_sparsity_config = getattr(hf_config, "sparsity_config", None) + if hf_sparsity_config is not None: + raise NotImplementedError( + "Loading hf sparsity config not yet supported" + ) + return sparsity_cls() # 
TODO(woosuk): Move this to other place. def get_quant_config( @@ -279,13 +292,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: x = x[:] return x -def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: - """Gets parameter data in dense format.""" - if isinstance(param, SparseParameter): - return param.get_data_dense() - else: - return param.data - def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" From 16dd160feaf20c0b11d1aa35912650f6f5aade3e Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sat, 27 Jan 2024 20:14:08 +0000 Subject: [PATCH 03/12] updated README.md --- README.md | 133 +++++---------------- vllm/model_executor/layers/linear.py | 2 +- vllm/model_executor/parameters/__init__.py | 9 -- vllm/model_executor/parameters/sparsity.py | 39 ------ vllm/model_executor/weight_utils.py | 3 +- 5 files changed, 34 insertions(+), 152 deletions(-) delete mode 100644 vllm/model_executor/parameters/__init__.py delete mode 100644 vllm/model_executor/parameters/sparsity.py diff --git a/README.md b/README.md index c7ae85e7973db..95285b894d799 100644 --- a/README.md +++ b/README.md @@ -1,112 +1,41 @@ -

-[vLLM logo banner HTML]
+## Neural Magic vLLM

-Easy, fast, and cheap LLM serving for everyone
+Fork of vLLM with sparsity.

-| Documentation | Blog | Paper | Discord |
+### To Run
- ---- - -**The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)** - -We are thrilled to announce our second vLLM Meetup! -The vLLM team will share recent updates and roadmap. -We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations. -Please register [here](https://lu.ma/ygxbpzhl) and join us! - ---- - -*Latest News* 🔥 -- [2023/12] Added ROCm support to vLLM. -- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing). -- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there. -- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv! -- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM. -- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command! -- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds. -- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). - ---- -## About -vLLM is a fast and easy-to-use library for LLM inference and serving. - -vLLM is fast with: - -- State-of-the-art serving throughput -- Efficient management of attention key and value memory with **PagedAttention** -- Continuous batching of incoming requests -- Fast model execution with CUDA/HIP graph -- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629) -- Optimized CUDA kernels - -vLLM is flexible and easy to use with: - -- Seamless integration with popular Hugging Face models -- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more -- Tensor parallelism support for distributed inference -- Streaming outputs -- OpenAI-compatible API server -- Support NVIDIA GPUs and AMD GPUs - -vLLM seamlessly supports many Hugging Face models, including the following architectures: - -- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) -- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) -- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) -- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) -- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) -- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) -- GPT-2 (`gpt2`, `gpt2-xl`, etc.) -- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) -- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) -- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) 
-- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) -- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) -- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) -- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) -- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) -- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) -- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) -- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.) -- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) -- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) - -Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): +Clone and install nm_gpu: ```bash -pip install vllm +git clone https://github.com/neuralmagic/nm_gpu.git +cd nm_gpu +export TORCH_CUDA_ARCH_LIST=8.6 +pip install -e . ``` -## Getting Started - -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. -- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) -- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) -- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) - -## Contributing +Install: +```bash +cd ../ +pip install -e . +``` -We welcome and value any contributions and collaborations. -Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. +### Run Sample -## Citation +Run a 50% sparse model: -If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): -```bibtex -@inproceedings{kwon2023efficient, - title={Efficient Memory Management for Large Language Model Serving with PagedAttention}, - author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. 
Gonzalez and Hao Zhang and Ion Stoica}, - booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles}, - year={2023} -} -``` +```bash +from vllm import LLM, SamplingParams + +model = LLM( + "nm-testing/Llama-2-7b-pruned50-retrained", + sparsity="sparse_w16a16", # If left off, model will be loaded as dense + enforce_eager=True, # Does not work with cudagraphs yet + dtype="float16", + tensor_parallel_size=1, + max_model_len=1024 +) + +sampling_params = SamplingParams(max_tokens=100, temperature=0) +outputs = model.generate("Hello my name is", sampling_params=sampling_params) +outputs[0].outputs[0].text +``` \ No newline at end of file diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b06c76e031c56..f1e244e9c1f8d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,7 +13,7 @@ divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs from vllm.logger import init_logger -from vllm.model_executor.parameters import SparseParameter, get_param_data +from vllm.model_executor.layers.parameters import SparseParameter, get_param_data logger = init_logger(__name__) diff --git a/vllm/model_executor/parameters/__init__.py b/vllm/model_executor/parameters/__init__.py deleted file mode 100644 index 75833a845c99e..0000000000000 --- a/vllm/model_executor/parameters/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from vllm.model_executor.parameters.sparsity import SparseParameter - -def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: - """Gets parameter data in dense format.""" - if isinstance(param, SparseParameter): - return param.get_dense_data() - else: - return param.data \ No newline at end of file diff --git a/vllm/model_executor/parameters/sparsity.py b/vllm/model_executor/parameters/sparsity.py deleted file mode 100644 index b8a3f14b5e219..0000000000000 --- a/vllm/model_executor/parameters/sparsity.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -import nm_gpu -from nm_gpu.SparseTensor import ( - SparseBitmaskStorageFormat, SparseTensor -) - -class SparseParameter(SparseTensor): - @staticmethod - def __new__( - cls, - shape: torch.Size, - dtype: torch.dtype, - ): - assert torch.__version__ > (1, 10), "SparseTensor requires PyTorch 1.11+" - self = torch.Tensor._make_wrapper_subclass(cls, size=shape, dtype=dtype, requires_grad=False) - self.storage_format_cls = SparseBitmaskStorageFormat - self.compressed_data = None - self.dense_data = None - self._is_param = True - - return self - - def get_dense_data(self) -> torch.Tensor: - if self.dense_data is not None: - raise ValueError("Called get_data_dense() but dense_data already exists.") - self.dense_data = self._unpack() - return self.dense_data - - def _unpack(self) -> torch.Tensor: - if self.has_compressed_data(): - return self.compressed_data.to_dense() - else: - return torch.empty(size=self.shape, dtype=self.dtype, device="cuda") - - def pack(self) -> None: - if self.dense_data is None: - raise ValueError("Called pack() but dense_data does not exist.") - self.copy_(self.dense_data) - self.dense_data = None diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 46d2915729fda..90328dd023900 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -19,7 +19,8 @@ QuantizationConfig) from vllm.model_executor.layers.sparsity import (get_sparsity_config, SparsityConfig) -from vllm.model_executor.parameters import 
SparseParameter, get_param_data +from vllm.model_executor.layers.parameters import (get_param_data, + SparseParameter) logger = init_logger(__name__) From f7e2bf54f8dd37ab787d3d8933325c522a32114e Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sat, 27 Jan 2024 20:14:24 +0000 Subject: [PATCH 04/12] readded moved files --- .../layers/parameters/__init__.py | 9 +++++ .../layers/parameters/sparsity.py | 39 +++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 vllm/model_executor/layers/parameters/__init__.py create mode 100644 vllm/model_executor/layers/parameters/sparsity.py diff --git a/vllm/model_executor/layers/parameters/__init__.py b/vllm/model_executor/layers/parameters/__init__.py new file mode 100644 index 0000000000000..40ab2297e369e --- /dev/null +++ b/vllm/model_executor/layers/parameters/__init__.py @@ -0,0 +1,9 @@ +import torch +from vllm.model_executor.layers.parameters.sparsity import SparseParameter + +def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: + """Gets parameter data in dense format.""" + if isinstance(param, SparseParameter): + return param.get_dense_data() + else: + return param.data \ No newline at end of file diff --git a/vllm/model_executor/layers/parameters/sparsity.py b/vllm/model_executor/layers/parameters/sparsity.py new file mode 100644 index 0000000000000..b8a3f14b5e219 --- /dev/null +++ b/vllm/model_executor/layers/parameters/sparsity.py @@ -0,0 +1,39 @@ +import torch +import nm_gpu +from nm_gpu.SparseTensor import ( + SparseBitmaskStorageFormat, SparseTensor +) + +class SparseParameter(SparseTensor): + @staticmethod + def __new__( + cls, + shape: torch.Size, + dtype: torch.dtype, + ): + assert torch.__version__ > (1, 10), "SparseTensor requires PyTorch 1.11+" + self = torch.Tensor._make_wrapper_subclass(cls, size=shape, dtype=dtype, requires_grad=False) + self.storage_format_cls = SparseBitmaskStorageFormat + self.compressed_data = None + self.dense_data = None + self._is_param = True + + return self + + def get_dense_data(self) -> torch.Tensor: + if self.dense_data is not None: + raise ValueError("Called get_data_dense() but dense_data already exists.") + self.dense_data = self._unpack() + return self.dense_data + + def _unpack(self) -> torch.Tensor: + if self.has_compressed_data(): + return self.compressed_data.to_dense() + else: + return torch.empty(size=self.shape, dtype=self.dtype, device="cuda") + + def pack(self) -> None: + if self.dense_data is None: + raise ValueError("Called pack() but dense_data does not exist.") + self.copy_(self.dense_data) + self.dense_data = None From b3c32f36fa022cb76066472570868bc56613830d Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Sat, 27 Jan 2024 15:17:58 -0500 Subject: [PATCH 05/12] Update __init__.py --- vllm/model_executor/layers/parameters/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/parameters/__init__.py b/vllm/model_executor/layers/parameters/__init__.py index 40ab2297e369e..117ba5747d751 100644 --- a/vllm/model_executor/layers/parameters/__init__.py +++ b/vllm/model_executor/layers/parameters/__init__.py @@ -6,4 +6,4 @@ def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: if isinstance(param, SparseParameter): return param.get_dense_data() else: - return param.data \ No newline at end of file + return param.data From 7d59c40af42c1f2e86a47305f5f9dd4f8e4cc991 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Sat, 27 Jan 2024 
15:42:29 -0500 Subject: [PATCH 06/12] Update linear.py --- vllm/model_executor/layers/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f1e244e9c1f8d..fb2b227a2d50c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -205,7 +205,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If SparsParameter, repack dense data as sparse. + # If SparseParameter, repack dense data as sparse. if isinstance(param, SparseParameter): param.pack() From cd9c32d7aad5ed94b77b88e63190c2954131bfb1 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Sat, 27 Jan 2024 15:43:02 -0500 Subject: [PATCH 07/12] Update arg_utils.py --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5f7d2ed25ea5d..9b3a3e5823454 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -205,7 +205,7 @@ def add_cli_args( default=None, help='Method used to compress sparse weights. If ' 'None, we first check the `sparsity_config` attribute ' - 'in the model config gile. If that is None we assume ' + 'in the model config file. If that is None we assume ' 'the model weights are dense') parser.add_argument('--enforce-eager', action='store_true', From 669ab5b03b25df8fc0c16fae90255c0c9ca84b51 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 28 Jan 2024 22:44:36 +0000 Subject: [PATCH 08/12] missed files --- .../layers/sparsity/__init__.py | 18 ++++ .../layers/sparsity/base_config.py | 50 ++++++++++ .../layers/sparsity/sparse_w16a16.py | 98 +++++++++++++++++++ 3 files changed, 166 insertions(+) create mode 100644 vllm/model_executor/layers/sparsity/__init__.py create mode 100644 vllm/model_executor/layers/sparsity/base_config.py create mode 100644 vllm/model_executor/layers/sparsity/sparse_w16a16.py diff --git a/vllm/model_executor/layers/sparsity/__init__.py b/vllm/model_executor/layers/sparsity/__init__.py new file mode 100644 index 0000000000000..6c74bb66be70b --- /dev/null +++ b/vllm/model_executor/layers/sparsity/__init__.py @@ -0,0 +1,18 @@ +from typing import Type + +from vllm.model_executor.layers.sparsity.base_config import SparsityConfig +from vllm.model_executor.layers.sparsity.sparse_w16a16 import SparseW16A16Config + +_SPARSITY_CONFIG_REGISTRY = { + "sparse_w16a16": SparseW16A16Config, +} + +def get_sparsity_config(sparsity: str) -> Type[SparsityConfig]: + if sparsity not in _SPARSITY_CONFIG_REGISTRY: + raise ValueError(f"Invalid sparsity method: {sparsity}") + return _SPARSITY_CONFIG_REGISTRY[sparsity] + +__all__ = [ + "SparsityConfig", + "get_sparsity_config", +] \ No newline at end of file diff --git a/vllm/model_executor/layers/sparsity/base_config.py b/vllm/model_executor/layers/sparsity/base_config.py new file mode 100644 index 0000000000000..4017636a26d86 --- /dev/null +++ b/vllm/model_executor/layers/sparsity/base_config.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +import torch + +from vllm.model_executor.layers.linear import LinearMethodBase + +class SparsityConfig(ABC): + """Base class for sparsity configs.""" + + @abstractmethod + def get_name(self) -> str: + """Name of the sparse method.""" + raise NotImplementedError + + @abstractmethod + def 
get_supported_act_dtypes(self) -> List[torch.dtype]: + """List of supported act_dtypes.""" + raise NotImplementedError + + @abstractmethod + def get_min_capability(self) -> int: + """Minimum GPU capability to support the sparsity method.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_config_filenames() -> List[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: Dict[str, Any]) -> "BaseSparsityConfig": + """Create a config class from the model's sparse config.""" + raise NotImplementedError + + @staticmethod + def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: + """Get a value from the model's sparsity config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError(f"Cannot find any of {keys} in the model's " + "sparsity config.") + + @abstractmethod + def get_linear_method(self) -> LinearMethodBase: + """Get the linear method to use for the sparse linear layer.""" + raise NotImplementedError \ No newline at end of file diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py new file mode 100644 index 0000000000000..8f13920af64ec --- /dev/null +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py @@ -0,0 +1,98 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.sparsity.base_config import SparsityConfig +from vllm.model_executor.layers.parameters import SparseParameter + +import nm_gpu + +class SparseW16A16Config(SparsityConfig): + """Config class for SparseW16A16. + + TODO: Add based on need + """ + + def __init__(self) -> None: + # TODO: Add new configs here + pass + + def __repr__(self) -> str: + return "SparseW16A16Config()" + + @classmethod + def get_name(cls) -> str: + return "sparse_w16a16" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # TODO: Update after checks on more GPUs + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["sparsity_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config": + return cls() + + def get_linear_method(self) -> "SparseW16A16LinearMethod": + return SparseW16A16LinearMethod(self) + +class SparseW16A16LinearMethod(LinearMethodBase): + """Linear method for Sparse W16A16. + + Args: + sparsity_config: The sparse config. 
+ """ + + def __init__(self, sparsity_config: SparseW16A16Config): + self.sparsity_config = sparsity_config + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + weight = SparseParameter( + shape=torch.Size((output_size_per_partition, input_size_per_partition)), + dtype=params_dtype, + ) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + return {"weight": weight} + + def apply_weights( + self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + sparse_weight = weights["weight"] + + # Uncompress to dense + dense_weight = sparse_weight.to_dense() + + # # Uncomment to verify sparsity + # density = torch.count_nonzero( + # dense_weight).item() / dense_weight.numel() + # print(f"sparsity = {1.0 - density}") + + # Standard matrix multiply + if bias is not None: + output = F.linear(x, dense_weight, bias) + else: + output = F.linear(x, dense_weight) + + return output \ No newline at end of file From 611cb236bb05474bd696405a5b58984f8ba9abe4 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Mon, 29 Jan 2024 06:49:51 -0500 Subject: [PATCH 09/12] Update __init__.py newline --- vllm/model_executor/layers/sparsity/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sparsity/__init__.py b/vllm/model_executor/layers/sparsity/__init__.py index 6c74bb66be70b..49df5223f6161 100644 --- a/vllm/model_executor/layers/sparsity/__init__.py +++ b/vllm/model_executor/layers/sparsity/__init__.py @@ -15,4 +15,4 @@ def get_sparsity_config(sparsity: str) -> Type[SparsityConfig]: __all__ = [ "SparsityConfig", "get_sparsity_config", -] \ No newline at end of file +] From f9fe7813de55cf23556120163460f89f5160e5c9 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Mon, 29 Jan 2024 06:50:14 -0500 Subject: [PATCH 10/12] Update base_config.py --- vllm/model_executor/layers/sparsity/base_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sparsity/base_config.py b/vllm/model_executor/layers/sparsity/base_config.py index 4017636a26d86..52c787d18de3d 100644 --- a/vllm/model_executor/layers/sparsity/base_config.py +++ b/vllm/model_executor/layers/sparsity/base_config.py @@ -47,4 +47,4 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: @abstractmethod def get_linear_method(self) -> LinearMethodBase: """Get the linear method to use for the sparse linear layer.""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError From aa434c2371076c16f1fcf117186ee7b6363c3616 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Mon, 29 Jan 2024 06:50:32 -0500 Subject: [PATCH 11/12] Update sparse_w16a16.py newline --- vllm/model_executor/layers/sparsity/sparse_w16a16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py index 8f13920af64ec..497996266794c 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py @@ -95,4 +95,4 @@ def apply_weights( else: output = F.linear(x, dense_weight) - return output \ No newline at end of file + return output From 
87e7516228a19754a97eaa6071db924a66ac05e5 Mon Sep 17 00:00:00 2001 From: alexm Date: Thu, 1 Feb 2024 17:00:36 -0500 Subject: [PATCH 12/12] Sync with latest changes from magic_wand --- README.md | 6 +- examples/offline_bench.py | 111 ++++++++++++++++++ vllm/config.py | 18 +-- vllm/engine/arg_utils.py | 32 ++--- vllm/model_executor/layers/linear.py | 14 ++- .../layers/parameters/__init__.py | 1 + .../layers/parameters/sparsity.py | 27 +++-- .../layers/sparsity/__init__.py | 2 + .../layers/sparsity/base_config.py | 3 +- .../layers/sparsity/sparse_w16a16.py | 5 +- vllm/model_executor/model_loader.py | 8 +- vllm/model_executor/utils.py | 1 + vllm/model_executor/weight_utils.py | 12 +- 13 files changed, 188 insertions(+), 52 deletions(-) create mode 100644 examples/offline_bench.py diff --git a/README.md b/README.md index 95285b894d799..c126c65717f42 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,11 @@ Fork of vLLM with sparsity. ### To Run -Clone and install nm_gpu: +Clone and install magic_wand: ```bash -git clone https://github.com/neuralmagic/nm_gpu.git -cd nm_gpu +git clone https://github.com/neuralmagic/magic_wand.git +cd magic_wand export TORCH_CUDA_ARCH_LIST=8.6 pip install -e . ``` diff --git a/examples/offline_bench.py b/examples/offline_bench.py new file mode 100644 index 0000000000000..ae7b391da0c39 --- /dev/null +++ b/examples/offline_bench.py @@ -0,0 +1,111 @@ +import random +import time +import argparse + +from vllm import LLM, SamplingParams + +NUM_REQUESTS_DEFAULT = 256 +MAX_SEQ_LEN_DEFAULT = 1024 +MAX_TOKENS_DEFAULT = 128 +SAMPLE_PROMPTS = [ + # "Hello, my name is", + # "The president of the United States is", + # "The capital of France is", + "The future of AI is", +] + + +def run_bench(model_name, + model_revision, + is_sparse, + quant_method, + max_seq_len, + max_tokens, + num_requests, + num_gpus, + num_warmup_iters=1, + num_bench_iters=5, + possible_prompts=SAMPLE_PROMPTS, + enforce_eager=True): + print("Run bench with:") + print(f" model_name = {model_name}") + print(f" model_revision = {model_revision}") + print(f" is_sparse = {is_sparse}") + print(f" quant_method = {quant_method}") + print(f" max_seq_len = {max_seq_len}") + print(f" max_tokens = {max_tokens}") + print(f" num_requests = {num_requests}") + print(f" num_gpus = {num_gpus}") + print(f" num_warmup_iters = {num_warmup_iters}") + print(f" num_bench_iters = {num_bench_iters}") + + prompts = [] + for _ in range(num_requests): + index = random.randint(0, len(possible_prompts) - 1) + prompts.append(possible_prompts[index]) + + # Create sampling params + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=max_tokens) + + # Create LLM + llm = LLM( + model=model_name, + revision=model_revision, + sparsity="sparse_w16a16" if is_sparse else None, + enforce_eager=enforce_eager, + # dtype=torch.bfloat16, + tensor_parallel_size=num_gpus, + gpu_memory_utilization=0.9, + max_model_len=max_seq_len, + quantization=quant_method, + ) + + for i in range(num_warmup_iters): + start_time = time.time() + outputs = llm.generate(prompts, sampling_params) + elapsed_time = time.time() - start_time + print(f"Warmup iter {i} time: {elapsed_time} [secs]") + + iter_times = [] + for i in range(num_bench_iters): + start_time = time.time() + outputs = llm.generate(prompts, sampling_params) + iter_times.append(time.time() - start_time) + print(f"Bench iter {i} time: {iter_times[-1]} [secs]") + + average_iter_time = sum(iter_times) / num_bench_iters + print(f"Average per iter time: {average_iter_time} [secs]") + + # 
Print outputs of the last iter + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + return average_iter_time + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--model_revision", type=str, default=None) + parser.add_argument('--is_sparse', action='store_true') + parser.add_argument("--quant_method", type=str, default=None) + parser.add_argument("--max_seq_len", type=int, default=MAX_SEQ_LEN_DEFAULT) + parser.add_argument("--max_tokens", type=int, default=MAX_TOKENS_DEFAULT) + parser.add_argument("--num_requests", + type=int, + default=NUM_REQUESTS_DEFAULT) + parser.add_argument("--num_gpus", type=int, default=1) + parser.add_argument("--num_warmup_iters", type=int, default=1) + parser.add_argument("--num_bench_iters", type=int, default=5) + + args = parser.parse_args() + + run_bench(args.model_name, args.model_revision, args.is_sparse, + args.quant_method, args.max_seq_len, args.max_tokens, + args.num_requests, args.num_gpus, args.num_warmup_iters, + args.num_bench_iters) diff --git a/vllm/config.py b/vllm/config.py index 3df920805d7a2..d735819c0c2b1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -151,17 +151,17 @@ def _verify_sparsity(self) -> None: supported_sparsity = ["sparse_w16a16"] if self.quantization is not None: - raise ValueError(f"Both sparsity and quantization detected. Only " - "one or the other is supported at a time.") - - if self.sparsity is not None: - if self.sparsity not in supported_sparsity: - raise ValueError(f"Unknown sparse method: {self.sparsity}. Must " - f"be one of {supported_sparse}.") - + raise ValueError("Both sparsity and quantization detected. Only " + "one or the other is supported at a time.") + + if self.sparsity is not None and self.sparsity not in supported_sparsity: + raise ValueError(f"Unknown sparse method: {self.sparsity}. Must " + f"be one of {supported_sparsity}.") + hf_sparsity_config = getattr(self.hf_config, "sparsity_config", None) if hf_sparsity_config is not None: - hf_sparsity_method = str(hf_sparse_config["sparse_method"]).lower() + hf_sparsity_method = str( + hf_sparsity_config["sparse_method"]).lower() if self.sparsity is None: self.sparsity = hf_sparsity_method elif self.sparsity != hf_sparsity_method: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9b3a3e5823454..834908e0cf238 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -198,15 +198,16 @@ def add_cli_args( 'None, we assume the model weights are not ' 'quantized and use `dtype` to determine the data ' 'type of the weights.') - parser.add_argument('--sparsity', - '-s', - type=str, - choices=['sparse_w16a16', None], - default=None, - help='Method used to compress sparse weights. If ' - 'None, we first check the `sparsity_config` attribute ' - 'in the model config file. If that is None we assume ' - 'the model weights are dense') + parser.add_argument( + '--sparsity', + '-s', + type=str, + choices=['sparse_w16a16', None], + default=None, + help='Method used to compress sparse weights. If ' + 'None, we first check the `sparsity_config` attribute ' + 'in the model config file. If that is None we assume ' + 'the model weights are dense') parser.add_argument('--enforce-eager', action='store_true', help='Always use eager-mode PyTorch. 
If False, ' @@ -265,13 +266,12 @@ def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, Optional[LoRAConfig]]: - model_config = ModelConfig(self.model, self.tokenizer, - self.tokenizer_mode, self.trust_remote_code, - self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, - self.tokenizer_revision, self.max_model_len, - self.quantization, self.sparsity, - self.enforce_eager, self.max_context_len_to_capture) + model_config = ModelConfig( + self.model, self.tokenizer, self.tokenizer_mode, + self.trust_remote_code, self.download_dir, self.load_format, + self.dtype, self.seed, self.revision, self.tokenizer_revision, + self.max_model_len, self.quantization, self.sparsity, + self.enforce_eager, self.max_context_len_to_capture) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index fb2b227a2d50c..d09db721d712b 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -17,6 +17,7 @@ logger = init_logger(__name__) + class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @@ -223,6 +224,7 @@ def forward(self, input_): output_bias = self.bias if self.skip_bias_add else None return output, output_bias + class MergedColumnParallelLinear(ColumnParallelLinear): """Packed linear layers with column parallelism. @@ -269,7 +271,8 @@ def weight_loader(self, if loaded_shard_id is None: if isinstance(param, SparseParameter): raise NotImplementedError( - "Passing loaded_shard_id=None not yet supported for SparseParameter") + "Passing loaded_shard_id=None not yet supported for SparseParameter" + ) # Loaded weight is already packed. if output_dim is None: @@ -324,6 +327,7 @@ def weight_loader(self, if isinstance(param, SparseParameter): param.pack() + class QKVParallelLinear(ColumnParallelLinear): """Linear layers for the attention's QKV transformation. @@ -384,13 +388,14 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): + loaded_shard_id: Optional[str] = None): param_data = get_param_data(param) output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: if isinstance(param, SparseParameter): raise NotImplementedError( - "Passing loaded_shard_id=None not yet supported for SparseParameter") + "Passing loaded_shard_id=None not yet supported for SparseParameter" + ) # Loaded weight is already packed. if output_dim is None: @@ -459,6 +464,7 @@ def weight_loader(self, if isinstance(param, SparseParameter): param.pack() + class RowParallelLinear(torch.nn.Module): """Linear layer with row parallelism. @@ -548,7 +554,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - + # If SparseParameter, repack dense data as sparse. 
if isinstance(param, SparseParameter): param.pack() diff --git a/vllm/model_executor/layers/parameters/__init__.py b/vllm/model_executor/layers/parameters/__init__.py index 117ba5747d751..2d41190087a0d 100644 --- a/vllm/model_executor/layers/parameters/__init__.py +++ b/vllm/model_executor/layers/parameters/__init__.py @@ -1,6 +1,7 @@ import torch from vllm.model_executor.layers.parameters.sparsity import SparseParameter + def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: """Gets parameter data in dense format.""" if isinstance(param, SparseParameter): diff --git a/vllm/model_executor/layers/parameters/sparsity.py b/vllm/model_executor/layers/parameters/sparsity.py index b8a3f14b5e219..37ddd05d89636 100644 --- a/vllm/model_executor/layers/parameters/sparsity.py +++ b/vllm/model_executor/layers/parameters/sparsity.py @@ -1,36 +1,43 @@ import torch -import nm_gpu -from nm_gpu.SparseTensor import ( - SparseBitmaskStorageFormat, SparseTensor -) + +from magic_wand import SparseTensor, SparseBitmaskStorageFormat + class SparseParameter(SparseTensor): + @staticmethod def __new__( cls, shape: torch.Size, dtype: torch.dtype, ): - assert torch.__version__ > (1, 10), "SparseTensor requires PyTorch 1.11+" - self = torch.Tensor._make_wrapper_subclass(cls, size=shape, dtype=dtype, requires_grad=False) + assert torch.__version__ > (1, + 10), "SparseTensor requires PyTorch 1.11+" + self = torch.Tensor._make_wrapper_subclass(cls, + size=shape, + dtype=dtype, + requires_grad=False) self.storage_format_cls = SparseBitmaskStorageFormat self.compressed_data = None self.dense_data = None self._is_param = True - + return self def get_dense_data(self) -> torch.Tensor: if self.dense_data is not None: - raise ValueError("Called get_data_dense() but dense_data already exists.") + raise ValueError( + "Called get_data_dense() but dense_data already exists.") self.dense_data = self._unpack() return self.dense_data def _unpack(self) -> torch.Tensor: if self.has_compressed_data(): - return self.compressed_data.to_dense() + return self.compressed_data.decompress() else: - return torch.empty(size=self.shape, dtype=self.dtype, device="cuda") + return torch.empty(size=self.shape, + dtype=self.dtype, + device="cuda") def pack(self) -> None: if self.dense_data is None: diff --git a/vllm/model_executor/layers/sparsity/__init__.py b/vllm/model_executor/layers/sparsity/__init__.py index 49df5223f6161..411d1ff642266 100644 --- a/vllm/model_executor/layers/sparsity/__init__.py +++ b/vllm/model_executor/layers/sparsity/__init__.py @@ -7,11 +7,13 @@ "sparse_w16a16": SparseW16A16Config, } + def get_sparsity_config(sparsity: str) -> Type[SparsityConfig]: if sparsity not in _SPARSITY_CONFIG_REGISTRY: raise ValueError(f"Invalid sparsity method: {sparsity}") return _SPARSITY_CONFIG_REGISTRY[sparsity] + __all__ = [ "SparsityConfig", "get_sparsity_config", diff --git a/vllm/model_executor/layers/sparsity/base_config.py b/vllm/model_executor/layers/sparsity/base_config.py index 52c787d18de3d..aa09fb623bc00 100644 --- a/vllm/model_executor/layers/sparsity/base_config.py +++ b/vllm/model_executor/layers/sparsity/base_config.py @@ -5,6 +5,7 @@ from vllm.model_executor.layers.linear import LinearMethodBase + class SparsityConfig(ABC): """Base class for sparsity configs.""" @@ -31,7 +32,7 @@ def get_config_filenames() -> List[str]: @classmethod @abstractmethod - def from_config(cls, config: Dict[str, Any]) -> "BaseSparsityConfig": + def from_config(cls, config: Dict[str, Any]) -> "SparsityConfig": """Create a config class from the 
model's sparse config.""" raise NotImplementedError diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py index 497996266794c..771fae9b8ff45 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py @@ -7,7 +7,6 @@ from vllm.model_executor.layers.sparsity.base_config import SparsityConfig from vllm.model_executor.layers.parameters import SparseParameter -import nm_gpu class SparseW16A16Config(SparsityConfig): """Config class for SparseW16A16. @@ -46,6 +45,7 @@ def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config": def get_linear_method(self) -> "SparseW16A16LinearMethod": return SparseW16A16LinearMethod(self) + class SparseW16A16LinearMethod(LinearMethodBase): """Linear method for Sparse W16A16. @@ -65,7 +65,8 @@ def create_weights( params_dtype: torch.dtype, ) -> Dict[str, Any]: weight = SparseParameter( - shape=torch.Size((output_size_per_partition, input_size_per_partition)), + shape=torch.Size( + (output_size_per_partition, input_size_per_partition)), dtype=params_dtype, ) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 3dc16011b8144..aa777b88c216c 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -8,9 +8,11 @@ from vllm.config import ModelConfig, LoRAConfig from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.weight_utils import (get_quant_config, get_sparse_config, +from vllm.model_executor.weight_utils import (get_quant_config, + get_sparse_config, initialize_dummy_weights) + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" @@ -58,9 +60,9 @@ def get_model(model_config: ModelConfig, f"{supported_dtypes}") linear_method = quant_config.get_linear_method() if model_config.sparsity is not None: - sparse_config = get_sparse_config(model_config.sparsity, + sparse_config = get_sparse_config(model_config.sparsity, model_config.model, - model_config.hf_config, + model_config.hf_config, model_config.download_dir) capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 23ae3eb1146c7..336bc1cd005cf 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -5,6 +5,7 @@ import numpy as np import torch + def set_random_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 90328dd023900..33332e77ae8e2 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -19,11 +19,12 @@ QuantizationConfig) from vllm.model_executor.layers.sparsity import (get_sparsity_config, SparsityConfig) -from vllm.model_executor.layers.parameters import (get_param_data, +from vllm.model_executor.layers.parameters import (get_param_data, SparseParameter) logger = init_logger(__name__) + class Disabledtqdm(tqdm): def __init__(self, *args, **kwargs): @@ -84,7 +85,8 @@ def convert_bin_to_safetensor_file( if not torch.equal(pt_tensor, sf_tensor): raise RuntimeError(f"The output tensors do not match for key {k}") -# TODO(rib-2): Once we define hf_sparsity_config + +# TODO(rib-2): Once we define hf_sparsity_config def get_sparse_config( sparsity: str, model_name_or_path: str, @@ -95,10 +97,10 @@ def 
get_sparse_config( hf_sparsity_config = getattr(hf_config, "sparsity_config", None) if hf_sparsity_config is not None: raise NotImplementedError( - "Loading hf sparsity config not yet supported" - ) + "Loading hf sparsity config not yet supported") return sparsity_cls() + # TODO(woosuk): Move this to other place. def get_quant_config( quantization: str, @@ -293,6 +295,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: x = x[:] return x + def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" @@ -301,6 +304,7 @@ def default_weight_loader(param: torch.nn.Parameter, if isinstance(param, SparseParameter): param.pack() + def initialize_dummy_weights( model: torch.nn.Module, low: float = -1e-3,
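
For reference, the pieces introduced across this series chain together at load time roughly as follows. This is an illustrative wiring sketch, assuming this fork plus magic_wand are installed; get_model() performs the equivalent steps internally, and the sizes below are arbitrary.

```python
import torch

from vllm.model_executor.layers.sparsity import get_sparsity_config

# "sparse_w16a16" -> SparseW16A16Config (via _SPARSITY_CONFIG_REGISTRY)
sparse_config = get_sparsity_config("sparse_w16a16")()

# SparseW16A16LinearMethod: creates SparseParameter weights and, for now,
# decompresses to dense and falls back to F.linear in apply_weights().
linear_method = sparse_config.get_linear_method()

weights = linear_method.create_weights(
    input_size_per_partition=4096,
    output_size_per_partition=4096,
    input_size=4096,
    output_size=4096,
    params_dtype=torch.half,
)
# weights["weight"] is a SparseParameter; the weight loaders fill its dense
# scratch buffer from the checkpoint and then pack() it back down.
```
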