This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit c45f20d

afeldman-nm authored and rsnm2 committed
Semi-structured 2:4 sparsity via SparseSemiStructuredTensor #4
The magic_wand semi_structured_sparse_tensor_linear branch integrates 2:4 semi-structured sparsity into SparseTensor. This PR adds a new sparsity config for 2:4 sparsity to neuralmagic-vllm, built on that SparseTensor 2:4 support. It also refactors the sparse linear method into a separate file, vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py, which supports all sparsity formats.
1 parent 0c9e195 commit c45f20d

8 files changed: +159 -66 lines
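For context (not part of the diff): "2:4 semi-structured" sparsity means every group of four consecutive weights along a row holds at most two nonzeros, which sparse tensor cores can exploit directly. The magic_wand SparseSemiStructuredStorageFormat used in this PR presumably wraps the same hardware 2:4 layout that PyTorch exposes through its prototype torch.sparse API. The sketch below is an illustration only, using plain PyTorch (2.1+ on a CUDA GPU with compute capability 8.0+), not magic_wand or vLLM; exact shape and dtype constraints vary by PyTorch version.

import torch
import torch.nn.functional as F
from torch.sparse import to_sparse_semi_structured

# Illustration only: a 128x128 fp16 weight with a fixed 2:4 pattern
# (two nonzeros in every group of four values along each row).
out_features, in_features = 128, 128
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool).repeat(out_features, in_features // 4)
weight = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda") * mask

# Compress to the hardware 2:4 layout and run a linear layer on it.
weight_24 = to_sparse_semi_structured(weight)
x = torch.randn(16, in_features, dtype=torch.float16, device="cuda")
y_sparse = F.linear(x, weight_24)   # dispatches to the 2:4 sparse kernel
y_dense = F.linear(x, weight)       # dense reference
print((y_sparse - y_dense).abs().max())  # should be small (fp16 rounding)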
New file (+12 lines, example script)

@@ -0,0 +1,12 @@
+from vllm import LLM, SamplingParams
+
+model = LLM("nm-testing/zephyr-50sparse-24",
+            sparsity="semi_structured_sparse_w16a16",
+            enforce_eager=True,
+            dtype="float16",
+            tensor_parallel_size=1,
+            max_model_len=1024)
+
+sampling_params = SamplingParams(max_tokens=100, temperature=0)
+outputs = model.generate("Hello my name is", sampling_params=sampling_params)
+print(outputs[0].outputs[0].text)

vllm/config.py (+1 -1)

@@ -158,7 +158,7 @@ def _verify_tokenizer_mode(self) -> None:
         self.tokenizer_mode = tokenizer_mode

     def _verify_sparsity(self) -> None:
-        supported_sparsity = ["sparse_w16a16"]
+        supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"]

         if self.quantization is not None:
             raise ValueError("Both sparsity and quantization detected. Only "

vllm/model_executor/layers/parameters/sparsity.py (+27 -7)

@@ -1,29 +1,35 @@
 import torch

-from magic_wand import SparseTensor, SparseBitmaskStorageFormat
+from typing import Type
+from magic_wand import (SparseTensor, CompressedStorageFormat,
+                        SparseBitmaskStorageFormat)


 class SparseParameter(SparseTensor):

     @staticmethod
-    def __new__(
-        cls,
-        shape: torch.Size,
-        dtype: torch.dtype,
-    ):
+    def __new__(cls,
+                shape: torch.Size,
+                dtype: torch.dtype,
+                storage_format_cls: Type[
+                    CompressedStorageFormat] = SparseBitmaskStorageFormat):
         assert torch.__version__ > (1,
                                     10), "SparseTensor requires PyTorch 1.11+"
+
         self = torch.Tensor._make_wrapper_subclass(cls,
                                                    size=shape,
                                                    dtype=dtype,
                                                    requires_grad=False)
-        self.storage_format_cls = SparseBitmaskStorageFormat
+        self.storage_format_cls = storage_format_cls
         self.compressed_data = None
         self.dense_data = None
         self._is_param = True

         return self

+    def has_compressed_data(self) -> bool:
+        return (self.compressed_data is not None)
+
     def get_dense_data(self) -> torch.Tensor:
         if self.dense_data is not None:
             raise ValueError(
@@ -39,6 +45,20 @@ def _unpack(self) -> torch.Tensor:
             dtype=self.dtype,
             device="cuda")

+    @classmethod
+    def _copy(cls, arg0, arg1):
+        assert arg0.shape == arg1.shape
+
+        if arg0.has_compressed_data():
+            arg0.compressed_data.copy_(arg1)
+        else:
+            arg0.compressed_data = arg0.storage_format_cls.compress(arg1)
+
+        return arg0
+
+    def copy_(self, src, non_blocking=False):
+        return SparseParameter._copy(self, src)
+
     def pack(self) -> None:
         if self.dense_data is None:
             raise ValueError("Called pack() but dense_data does not exist.")
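The copy_ override added above presumably lets weight loading write dense checkpoint tensors into a SparseParameter: the first copy compresses the incoming tensor with the configured storage format, and later copies overwrite the existing compressed data. The sketch below mirrors that behavior with plain PyTorch; ToyFormat is a hypothetical stand-in, not the real magic_wand CompressedStorageFormat API.

import torch

class ToyFormat:
    # Hypothetical stand-in for a magic_wand storage format (illustration only).
    def __init__(self, dense: torch.Tensor):
        self.mask = dense != 0
        self.values = dense[self.mask].clone()

    @classmethod
    def compress(cls, dense: torch.Tensor) -> "ToyFormat":
        return cls(dense)

    def copy_(self, dense: torch.Tensor) -> None:
        # Overwrite the compressed contents in place, as compressed_data.copy_(arg1) would.
        self.mask = dense != 0
        self.values = dense[self.mask].clone()

    def to_dense(self) -> torch.Tensor:
        out = torch.zeros(self.mask.shape, dtype=self.values.dtype)
        out[self.mask] = self.values
        return out

# Mirrors SparseParameter._copy: compress on the first copy, overwrite afterwards.
w = torch.randn(4, 8)
w[w.abs() < 0.7] = 0.0                  # sparsify the toy weight
compressed_data = None
if compressed_data is None:             # has_compressed_data() == False
    compressed_data = ToyFormat.compress(w)
else:
    compressed_data.copy_(w)
assert torch.equal(compressed_data.to_dense(), w)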

vllm/model_executor/layers/sparsity/__init__.py (+2)

@@ -2,9 +2,11 @@

 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
 from vllm.model_executor.layers.sparsity.sparse_w16a16 import SparseW16A16Config
+from vllm.model_executor.layers.sparsity.semi_structured_sparse_w16a16 import SemiStructuredSparseW16A16Config

 _SPARSITY_CONFIG_REGISTRY = {
     "sparse_w16a16": SparseW16A16Config,
+    "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config,
 }

vllm/model_executor/layers/sparsity/base_config.py (+7)

@@ -2,13 +2,20 @@
 from typing import Any, Dict, List

 import torch
+from typing import Type

 from vllm.model_executor.layers.linear import LinearMethodBase
+from magic_wand import CompressedStorageFormat


 class SparsityConfig(ABC):
     """Base class for sparsity configs."""

+    @abstractmethod
+    def get_storage_format_cls(self) -> Type[CompressedStorageFormat]:
+        """Sparse representation format"""
+        raise NotImplementedError
+
     @abstractmethod
     def get_name(self) -> str:
         """Name of the sparse method."""
vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py (new file, +46)

@@ -0,0 +1,46 @@
+import torch
+
+from typing import Any, Dict, List, Type
+from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
+from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
+from magic_wand import (CompressedStorageFormat,
+                        SparseSemiStructuredStorageFormat)
+
+
+class SemiStructuredSparseW16A16Config(SparsityConfig):
+    """Config class for SemiStructuredSparseW16A16."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __repr__(self) -> str:
+        return "SemiStructuredSparseW16A16Config()"
+
+    @classmethod
+    def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
+        return SparseSemiStructuredStorageFormat
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "semi_structured_sparse_w16a16"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # TODO: Update after checks on more GPUs
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["sparsity_config.json"]
+
+    @classmethod
+    def from_config(
+            cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config":
+        return cls()
+
+    def get_linear_method(self) -> "SparseW16A16LinearMethod":
+        return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
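Not part of the diff, but a short sketch of how the pieces above fit together at runtime: the sparsity string from the engine config is looked up in _SPARSITY_CONFIG_REGISTRY (see __init__.py above), and the resulting config supplies both the storage format and the shared linear method. This assumes the registry and config classes are importable exactly as shown in the diffs.

from vllm.model_executor.layers.sparsity import _SPARSITY_CONFIG_REGISTRY

# Resolve the config class registered for the new sparsity name.
config_cls = _SPARSITY_CONFIG_REGISTRY["semi_structured_sparse_w16a16"]
config = config_cls.from_config({})           # SemiStructuredSparseW16A16Config()

# The config picks the 2:4 storage format and hands it to the shared linear method.
print(config.get_storage_format_cls())        # SparseSemiStructuredStorageFormat
linear_method = config.get_linear_method()    # SparseW16A16LinearMethod(config, ...)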
vllm/model_executor/layers/sparsity/sparse_w16a16.py (+9 -58)

@@ -1,11 +1,11 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Type

 import torch
-import torch.nn.functional as F

-from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
-from vllm.model_executor.layers.parameters import SparseParameter
+
+from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
+from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)


 class SparseW16A16Config(SparsityConfig):
@@ -21,6 +21,10 @@ def __init__(self) -> None:
     def __repr__(self) -> str:
         return "SparseW16A16Config()"

+    @classmethod
+    def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
+        return SparseBitmaskStorageFormat
+
     @classmethod
     def get_name(cls) -> str:
         return "sparse_w16a16"
@@ -43,57 +47,4 @@ def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config":
         return cls()

     def get_linear_method(self) -> "SparseW16A16LinearMethod":
-        return SparseW16A16LinearMethod(self)
-
-
-class SparseW16A16LinearMethod(LinearMethodBase):
-    """Linear method for Sparse W16A16.
-
-    Args:
-        sparsity_config: The sparse config.
-    """
-
-    def __init__(self, sparsity_config: SparseW16A16Config):
-        self.sparsity_config = sparsity_config
-
-    def create_weights(
-        self,
-        input_size_per_partition: int,
-        output_size_per_partition: int,
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
-        weight = SparseParameter(
-            shape=torch.Size(
-                (output_size_per_partition, input_size_per_partition)),
-            dtype=params_dtype,
-        )
-
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-
-        return {"weight": weight}
-
-    def apply_weights(
-        self,
-        weights: Dict[str, Any],
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        sparse_weight = weights["weight"]
-
-        # Uncompress to dense
-        dense_weight = sparse_weight.to_dense()
-
-        # # Uncomment to verify sparsity
-        # density = torch.count_nonzero(
-        #     dense_weight).item() / dense_weight.numel()
-        # print(f"sparsity = {1.0 - density}")
-
-        # Standard matrix multiply
-        if bias is not None:
-            output = F.linear(x, dense_weight, bias)
-        else:
-            output = F.linear(x, dense_weight)
-
-        return output
+        return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py (new file, +55)

@@ -0,0 +1,55 @@
+from typing import Any, Dict, Optional, Type
+
+import torch
+import torch.nn.functional as F
+
+from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
+from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
+from vllm.model_executor.layers.parameters import SparseParameter
+from magic_wand import (CompressedStorageFormat,
+                        SparseSemiStructuredStorageFormat)
+
+
+class SparseW16A16LinearMethod(LinearMethodBase):
+    """Linear method for Sparse W16A16.
+
+    Args:
+        sparsity_config: The sparse config.
+    """
+    storage_format_cls: Type[CompressedStorageFormat] = None
+
+    def __init__(self, sparsity_config: SparsityConfig,
+                 storage_format_cls: Type[CompressedStorageFormat]):
+        self.sparsity_config = sparsity_config
+        self.storage_format_cls = storage_format_cls
+
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        weight = SparseParameter(shape=torch.Size(
+            (output_size_per_partition, input_size_per_partition)),
+                                 dtype=params_dtype,
+                                 storage_format_cls=self.storage_format_cls)
+
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+
+        return {"weight": weight}
+
+    def apply_weights(
+        self,
+        weights: Dict[str, Any],
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        sparse_weight = weights["weight"]
+
+        if self.storage_format_cls == SparseSemiStructuredStorageFormat:
+            output = F.linear(x, sparse_weight, bias)
+            return output
+        else:
+
+            # Standard matrix multiply
+            # Uncompress to dense
+            output = F.linear(x, sparse_weight.to_dense(), bias)
+            return output
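Note the dispatch split in apply_weights above: when the storage format is SparseSemiStructuredStorageFormat, the compressed weight is passed to F.linear directly, so the 2:4 kernel consumes it without materializing a dense copy; for the bitmask format the weight is decompressed with to_dense() first and the matmul runs dense. In other words, only the semi-structured path is expected to gain inference speed from sparsity at this point, while the bitmask path primarily saves weight-storage memory.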
