@@ -146,6 +146,129 @@ public:
146
146
// Returns the address of the 32-bit word at index (row / 8) * width + column
// in `data`.
// NOTE(review): the row / 8 term suggests eight rows share one word per
// column (4-bit values packed along the row axis) — confirm against the
// enclosing class, whose definition is not visible here.
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column)
{
    const int word_index = row / 8 * width + column;
    return &data[word_index];
}
147
147
};
148
148
149
// Read-only view over a row-major matrix of 2-bit values packed sixteen per
// 32-bit word. Column c of a row occupies bits [2*(c%16), 2*(c%16)+1] of word
// (row * width / 16 + c / 16). `height` is stored but not used by the
// accessors; no bounds checking is performed.
class MatrixView_q2_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Returns the single 2-bit value at (row, column) as an int in [0, 3].
    __device__ __forceinline__ int item(int row, int column) const
    {
        const uint32_t word = data[row * width / 16 + column / 16];
        const int bit_offset = (column & 0x0f) * 2;
        return (word >> bit_offset) & 0x03;
    }

    // Unpacks two consecutive values starting at (row, column) into items.
    // Both values are taken from the same word: lanes that would fall past
    // bit 31 read as zero rather than spilling into the next word.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        const int bit_offset = (column & 0x0f) * 2;
        uint32_t bits = data[row * width / 16 + column / 16] >> bit_offset;
        for (int i = 0; i < 2; ++i)
        {
            items[i] = bits & 0x03;
            bits >>= 2;
        }
    }

    // Unpacks four consecutive values starting at (row, column) into items.
    // Same single-word behaviour as item2.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const int bit_offset = (column & 0x0f) * 2;
        uint32_t bits = data[row * width / 16 + column / 16] >> bit_offset;
        for (int i = 0; i < 4; ++i)
        {
            items[i] = bits & 0x03;
            bits >>= 2;
        }
    }
};
184
+
185
// Read-only view over a row-major matrix of 3-bit values packed 32 per group
// of three 32-bit words (96 bits). Within a group, value k (k = column % 32)
// starts at bit 3*k, so value 10 straddles words 0/1 and value 21 straddles
// words 1/2. `height` is stored but not used by the accessors; no bounds
// checking is performed.
class MatrixView_q3_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Returns the single 3-bit value at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        const int lane = column & 0x1f;
        // First word that holds any bit of the requested value.
        const uint32_t* word = data + (row * width * 3 / 32 + column * 3 / 32);

        // Values split across two words: stitch the high bits from the next word.
        if (lane == 10) return (word[0] >> 30) | ((word[1] << 2) & 0x4);
        if (lane == 21) return (word[0] >> 31) | ((word[1] << 1) & 0x6);

        // Values fully inside one word: bit offset relative to that word.
        int bit_offset;
        if (lane < 10)
            bit_offset = lane * 3;
        else if (lane < 21)
            bit_offset = lane * 3 - 32;
        else
            bit_offset = lane * 3 - 64;

        return (word[0] >> bit_offset) & 0x07;
    }

    // Unpacks four consecutive 3-bit values starting at (row, column).
    // The branch table is laid out for columns that are multiples of 4
    // (lanes 0, 4, ..., 28); lanes 8 and 20 are the ones whose 12 bits
    // straddle a word boundary.
    // NOTE(review): other lanes (e.g. 5-7) would produce a negative shift
    // count here — callers are presumably expected to pass aligned columns.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const int lane = column & 0x1f;
        const uint32_t* word = data + (row * width / 32 * 3 + column * 3 / 32);

        uint32_t bits;
        if (lane == 8)
            bits = (word[0] >> 24) | ((word[1] & 0x0f) << 8);
        else if (lane == 20)
            bits = (word[0] >> 28) | ((word[1] & 0xff) << 4);
        else if (lane <= 4)
            bits = word[0] >> (lane * 3);
        else if (lane <= 16)
            bits = word[0] >> (lane * 3 - 32);
        else
            bits = word[0] >> (lane * 3 - 64);

        for (int i = 0; i < 4; ++i)
        {
            items[i] = bits & 0x07;
            bits >>= 3;
        }
    }
};
235
+
236
// Read-only view over a row-major matrix of 8-bit values packed four per
// 32-bit word. Column c of a row occupies bits [8*(c%4), 8*(c%4)+7] of word
// (row * width / 4 + c / 4). `height` is stored but not used by the
// accessors; no bounds checking is performed.
class MatrixView_q8_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Returns the single 8-bit value at (row, column) as an int in [0, 255].
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x03) * 8;
        return (data[row * width / 4 + column / 4] >> shift) & 0xff;
    }

    // Unpacks two consecutive values starting at (row, column) into items.
    // Both values come from the same word: a lane that would fall past
    // bit 31 reads as zero rather than spilling into the next word.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        int shift = (column & 0x03) * 8;
        uint32_t d = data[row * width / 4 + column / 4] >> shift;
        items[0] = d & 0xff;
        items[1] = (d >> 8) & 0xff;
    }

    // Unpacks four consecutive values starting at (row, column) into items.
    // When column is a multiple of 4 this returns the whole word's four
    // bytes. For other columns the lanes past the end of the word read as
    // zero, matching item2.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        // Each lane is 8 bits wide, so the in-word offset is 8 bits per
        // column (the previous factor of 2 was the 2-bit view's lane width
        // and sliced values across byte boundaries for unaligned columns).
        int shift = (column & 0x03) * 8;
        uint32_t d = data[row * width / 4 + column / 4] >> shift;
        items[0] = d & 0xff;
        items[1] = (d >> 8) & 0xff;
        items[2] = (d >> 16) & 0xff;
        items[3] = (d >> 24) & 0xff;
    }
};
271
+
149
272
} // namespace gptq
150
273
} // namespace vllm
151
274
#endif
0 commit comments