@@ -58,24 +58,32 @@ def apply_weights(
58
58
assert not w .has_compressed_data
59
59
output = F .linear (x , w .uncompressed_data , bias )
60
60
elif self .storage_format_cls == SparseSemiStructuredStorageFormat :
61
- assert bias is None
62
61
w_encap = w .compressed_data .encapsulated_torch_sparse_tensor
63
62
out_shape = (x .shape [:- 1 ] + (w_encap .shape [0 ], ))
64
63
reshaped_x , valid_rows_range = pad_tensor_to_multiple (
65
64
x .reshape (- 1 , x .shape [- 1 ]), 8 )
65
+ if bias is None :
66
+ bias = torch .nn .Parameter (
67
+ torch .zeros (
68
+ (w_encap .shape [0 ], ),
69
+ dtype = reshaped_x .dtype ,
70
+ device = reshaped_x .device ,
71
+ ))
66
72
output = F .linear (
67
- reshaped_x , w_encap ,
68
- torch .nn .Parameter (torch .zeros ((w_encap .shape [0 ], ))).to (
69
- reshaped_x .dtype ).to (reshaped_x .device )).contiguous ()
70
- output = extract_valid_rows (output , valid_rows_range )
71
- return output .reshape (out_shape )
73
+ reshaped_x ,
74
+ w_encap ,
75
+ bias ,
76
+ ).contiguous ()
77
+ output = extract_valid_rows (output ,
78
+ valid_rows_range ).reshape (out_shape )
72
79
elif self .storage_format_cls == SparseBEGemmStorageFormat :
73
- assert bias is None
74
80
assert w .compress_transposed
75
81
out_shape = (x .shape [:- 1 ] + (w .shape [0 ], ))
76
82
reshaped_x = x .reshape (- 1 , x .shape [- 1 ])
77
- y = be_ds_gemm (reshaped_x , w .compressed_data )
78
- return y .reshape (out_shape )
83
+ output = be_ds_gemm (reshaped_x ,
84
+ w .compressed_data ).reshape (out_shape )
85
+ if bias is not None :
86
+ output = output + bias
79
87
else :
80
88
# Standard matrix multiply
81
89
# Uncompress to dense
0 commit comments