Skip to content

Commit 3b42cc3

Browse files
chu-tianxiang authored and jimpang committed
Add Support for 2/3/8-bit GPTQ Quantization Models (vllm-project#2330)
1 parent 4bb5344 commit 3b42cc3

File tree

8 files changed

+1736
-229
lines changed

8 files changed

+1736
-229
lines changed

csrc/ops.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@ torch::Tensor gptq_gemm(
9898
torch::Tensor b_gptq_qzeros,
9999
torch::Tensor b_gptq_scales,
100100
torch::Tensor b_g_idx,
101-
bool use_exllama);
101+
bool use_exllama,
102+
int bit);
102103

103104
void gptq_shuffle(
104105
torch::Tensor q_weight,
105-
torch::Tensor q_perm);
106+
torch::Tensor q_perm,
107+
int bit);
106108

107109
void moe_align_block_size(
108110
torch::Tensor topk_ids,

csrc/quantization/gptq/matrix_view.cuh

+123
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,129 @@ public:
146146
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
147147
};
148148

149+
// Read-only view over a matrix of 2-bit values packed sixteen per uint32_t
// word. The word holding (row, column) is data[row * width / 16 + column / 16];
// within that word the value sits at bit offset (column % 16) * 2.
class MatrixView_q2_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Returns the single 2-bit value stored at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        const uint32_t word = data[row * width / 16 + column / 16];
        const int bit_offset = (column & 0x0f) * 2;
        return (word >> bit_offset) & 0x03;
    }

    // Writes the values at column and column + 1 into items.
    // Only one word is read: bits past the end of that word come out as zero.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        const int bit_offset = (column & 0x0f) * 2;
        const uint32_t bits = data[row * width / 16 + column / 16] >> bit_offset;
        for (int k = 0; k < 2; ++k)
        {
            items[k] = (bits >> (2 * k)) & 0x03;
        }
    }

    // Writes the values at column .. column + 3 into items.
    // Only one word is read: bits past the end of that word come out as zero.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const int bit_offset = (column & 0x0f) * 2;
        const uint32_t bits = data[row * width / 16 + column / 16] >> bit_offset;
        for (int k = 0; k < 4; ++k)
        {
            items[k] = (bits >> (2 * k)) & 0x03;
        }
    }
};
184+
185+
// Read-only view over a matrix of 3-bit values packed as a continuous bit
// stream: 32 values occupy 3 consecutive uint32_t words. Within each group of
// 32 columns, column 10 straddles words 0/1 and column 21 straddles words 1/2.
// NOTE(review): item() and item4() compute the row offset differently
// (row * width * 3 / 32 vs. row * width / 32 * 3); they agree when width is a
// multiple of 32 - presumably a precondition, confirm against callers.
class MatrixView_q3_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Returns the single 3-bit value stored at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        const int lane = column & 0x1f;
        // First word that contains bits of the requested value.
        const uint32_t* word = data + (row * width * 3 / 32 + column * 3 / 32);

        // Values that straddle a word boundary are stitched from two words.
        if (lane == 10) {
            return (word[0] >> 30) | ((word[1] << 2) & 0x4);
        }
        if (lane == 21) {
            return (word[0] >> 31) | ((word[1] << 1) & 0x6);
        }

        // Bit position inside the 96-bit group, rebased to the containing word.
        const int word_base = (lane < 10) ? 0 : ((lane < 21) ? 32 : 64);
        return (word[0] >> (lane * 3 - word_base)) & 0x07;
    }

    // Writes the values at column .. column + 3 into items.
    // Lanes 8 and 20 (column % 32) straddle a word boundary and read two words.
    // Lanes 5-7 and 17-19 would produce a negative shift count, so column is
    // expected to be a multiple of 4.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        const int lane = column & 0x1f;
        const uint32_t* word = data + (row * width / 32 * 3 + column * 3 / 32);

        uint32_t packed;
        if (lane == 8) {
            packed = (word[0] >> 24) | ((word[1] & 0x0f) << 8);
        } else if (lane == 20) {
            packed = (word[0] >> 28) | ((word[1] & 0xff) << 4);
        } else {
            const int word_base = (lane <= 4) ? 0 : ((lane <= 16) ? 32 : 64);
            packed = word[0] >> (lane * 3 - word_base);
        }

        items[0] = packed & 0x07;
        items[1] = (packed >> 3) & 0x07;
        items[2] = (packed >> 6) & 0x07;
        items[3] = (packed >> 9) & 0x07;
    }
};
235+
236+
// Read-only view over a matrix of 8-bit values packed four per uint32_t word.
// The word holding (row, column) is data[row * width / 4 + column / 4]; within
// that word the value sits at bit offset (column % 4) * 8.
class MatrixView_q8_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Returns the single 8-bit value stored at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x03) * 8;
        return (data[row * width / 4 + column / 4] >> shift) & 0xff;
    }

    // Writes the values at column and column + 1 into items.
    // Only one word is read: bits past the end of that word come out as zero.
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        int shift = (column & 0x03) * 8;
        uint32_t d = data[row * width / 4 + column / 4] >> shift;
        items[0] = d & 0xff;
        items[1] = (d >> 8) & 0xff;
    }

    // Writes the values at column .. column + 3 into items.
    // Only one word is read, so all four results are real values only when
    // column is a multiple of 4; otherwise the trailing entries are zero.
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        // Each value is 8 bits wide, so the in-word offset is (column % 4) * 8.
        // The previous "* 2" was carried over from the 2-bit view and returned
        // misaligned bytes for any column that is not a multiple of 4.
        int shift = (column & 0x03) * 8;
        uint32_t d = data[row * width / 4 + column / 4] >> shift;
        items[0] = d & 0xff;
        items[1] = (d >> 8) & 0xff;
        items[2] = (d >> 16) & 0xff;
        items[3] = (d >> 24) & 0xff;
    }
};
271+
149272
} // namespace gptq
150273
} // namespace vllm
151274
#endif

0 commit comments

Comments
 (0)