Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit b61bc82

Browse files
authored
Use naive decompress for SM<8.0 (#32)
A warning will be printed out if this case is triggered:

    WARNING 02-20 22:21:27 sparse_w16a16.py:32] Unstructured sparse kernels are not optimized for NVIDIA SM < 8.0. Naive decompress kernels will be used and can be slower than dense models

Works on a T4 with:

    from vllm import LLM, SamplingParams

    model = LLM(
        "nm-testing/opt-125m-pruned2.4",
        sparsity="sparse_w16a16",
        enforce_eager=True,
        dtype="float16",
    )

    sampling_params = SamplingParams(max_tokens=100, temperature=0)
    outputs = model.generate("Hello my name is", sampling_params=sampling_params)
    outputs[0].outputs[0].text

Test within colab: https://colab.research.google.com/drive/15xRvWX5gNaTb00BcaXhxwMm6yxavIKGN?usp=sharing
1 parent ab469e5 commit b61bc82

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

vllm/model_executor/layers/sparsity/sparse_w16a16.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,36 @@
22

33
import torch
44

5+
from vllm.logger import init_logger
56
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
67

78
from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
8-
from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)
9+
from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat,
10+
SparseBEGemmStorageFormat)
911

12+
logger = init_logger(__name__)
1013

11-
class SparseW16A16Config(SparsityConfig):
12-
"""Config class for SparseW16A16.
1314

14-
TODO: Add based on need
15-
"""
15+
class SparseW16A16Config(SparsityConfig):
16+
"""Config class for SparseW16A16."""
1617

1718
def __init__(self) -> None:
    # This sparsity scheme currently exposes no tunable options, so
    # construction is a no-op.
    pass
2020

2121
def __repr__(self) -> str:
    """Return a fixed textual representation (the config has no fields)."""
    return "SparseW16A16Config()"
2323

2424
@classmethod
def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
    """Select the compressed-weight storage format for the current GPU.

    Returns ``SparseBEGemmStorageFormat`` on devices with CUDA compute
    capability >= (8, 0); older devices fall back to
    ``SparseBitmaskStorageFormat`` (naive decompress) with a warning.

    NOTE(review): ``torch.cuda.get_device_capability()`` presumes a CUDA
    device is present — confirm callers only reach this on GPU.
    """
    capability = torch.cuda.get_device_capability()
    if capability < (8, 0):
        # Pre-Ampere GPUs: the optimized unstructured-sparse kernels are
        # unavailable, so use the slower naive-decompress format instead.
        logger.warning("Unstructured sparse kernels are not optimized for "
                       "NVIDIA SM < 8.0. Naive decompress kernels will be "
                       "used and can be slower than dense models")
        return SparseBitmaskStorageFormat
    return SparseBEGemmStorageFormat
2735

2836
@classmethod
2937
def get_name(cls) -> str:
@@ -35,8 +43,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
3543

3644
@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum supported CUDA compute capability.

    70 presumably encodes SM 7.0 (major * 10 + minor), matching the
    naive-decompress fallback for SM < 8.0 devices — TODO confirm the
    encoding against the base config's callers.
    """
    return 70
4047

4148
@classmethod
4249
def get_config_filenames(cls) -> List[str]:

0 commit comments

Comments
 (0)