This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit ab469e5

Add bias support for sparse layers (#25)
1 parent 47280b4 commit ab469e5

1 file changed: +17 −9 lines

vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py (+17 −9)
@@ -58,24 +58,32 @@ def apply_weights(
             assert not w.has_compressed_data
             output = F.linear(x, w.uncompressed_data, bias)
         elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
-            assert bias is None
             w_encap = w.compressed_data.encapsulated_torch_sparse_tensor
             out_shape = (x.shape[:-1] + (w_encap.shape[0], ))
             reshaped_x, valid_rows_range = pad_tensor_to_multiple(
                 x.reshape(-1, x.shape[-1]), 8)
+            if bias is None:
+                bias = torch.nn.Parameter(
+                    torch.zeros(
+                        (w_encap.shape[0], ),
+                        dtype=reshaped_x.dtype,
+                        device=reshaped_x.device,
+                    ))
             output = F.linear(
-                reshaped_x, w_encap,
-                torch.nn.Parameter(torch.zeros((w_encap.shape[0], ))).to(
-                    reshaped_x.dtype).to(reshaped_x.device)).contiguous()
-            output = extract_valid_rows(output, valid_rows_range)
-            return output.reshape(out_shape)
+                reshaped_x,
+                w_encap,
+                bias,
+            ).contiguous()
+            output = extract_valid_rows(output,
+                                        valid_rows_range).reshape(out_shape)
         elif self.storage_format_cls == SparseBEGemmStorageFormat:
-            assert bias is None
             assert w.compress_transposed
             out_shape = (x.shape[:-1] + (w.shape[0], ))
             reshaped_x = x.reshape(-1, x.shape[-1])
-            y = be_ds_gemm(reshaped_x, w.compressed_data)
-            return y.reshape(out_shape)
+            output = be_ds_gemm(reshaped_x,
+                                w.compressed_data).reshape(out_shape)
+            if bias is not None:
+                output = output + bias
         else:
             # Standard matrix multiply
             # Uncompress to dense
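
For reference, here is a minimal runnable sketch of what the semi-structured branch does after this change, using a dense weight as a stand-in for the encapsulated semi-structured sparse tensor. pad_rows_to_multiple and take_valid_rows are hypothetical re-creations of the repo's pad_tensor_to_multiple and extract_valid_rows helpers (only their call sites appear in the diff), and the zero-bias default mirrors the new "if bias is None" block above.

import torch
import torch.nn.functional as F

# Hypothetical stand-ins for pad_tensor_to_multiple / extract_valid_rows:
# pad the flattened row count up to a multiple of m before the GEMM,
# then slice the original rows back out afterwards.
def pad_rows_to_multiple(t: torch.Tensor, m: int):
    rows = t.shape[0]
    padded_rows = -(-rows // m) * m  # round up to the next multiple of m
    if padded_rows == rows:
        return t, (0, rows)
    padding = t.new_zeros((padded_rows - rows, t.shape[1]))
    return torch.cat([t, padding], dim=0), (0, rows)

def take_valid_rows(t: torch.Tensor, valid_range):
    start, end = valid_range
    return t[start:end]

def sparse_semi_structured_linear(x, w_encap, bias=None):
    # Keep all leading dims of x; the last dim becomes out_features.
    out_shape = x.shape[:-1] + (w_encap.shape[0],)
    reshaped_x, valid_range = pad_rows_to_multiple(
        x.reshape(-1, x.shape[-1]), 8)
    # The commit's zero-bias default: this kernel path is always handed a
    # bias tensor, so a missing bias becomes zeros of length out_features.
    if bias is None:
        bias = torch.zeros(w_encap.shape[0],
                           dtype=reshaped_x.dtype,
                           device=reshaped_x.device)
    output = F.linear(reshaped_x, w_encap, bias).contiguous()
    return take_valid_rows(output, valid_range).reshape(out_shape)

x = torch.randn(3, 5, 128)
w = torch.randn(256, 128)  # dense stand-in for the sparse weight
y = sparse_semi_structured_linear(x, w)
assert y.shape == (3, 5, 256)

Note that the two sparse branches handle bias differently: the semi-structured path folds it into the single F.linear call (hence the zero default), while the BEGemm path applies it as a separate elementwise add only when a bias is present.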
