This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit db726e6

Merge pull request #1 from afeldman-nm/enc_dec_t5
T5 enc/dec example file; linting/formatting
2 parents: 2fb6905 + 42a6e2b


47 files changed: +4666 −278 lines. (Large commit; only some of the changed files are shown below.)

README.md (+1)

@@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
 - Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
 - StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
+- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
 - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)

 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

csrc/ops.h (+13 −2)

@@ -86,6 +86,15 @@ torch::Tensor awq_dequantize(
   int split_k_iters,
   int thx,
   int thy);
+
+torch::Tensor marlin_gemm(
+  torch::Tensor& a,
+  torch::Tensor& b_q_weight,
+  torch::Tensor& b_scales,
+  torch::Tensor& workspace,
+  int64_t size_m,
+  int64_t size_n,
+  int64_t size_k);
 #endif

 void squeezellm_gemm(
@@ -100,11 +109,13 @@ torch::Tensor gptq_gemm(
   torch::Tensor b_gptq_qzeros,
   torch::Tensor b_gptq_scales,
   torch::Tensor b_g_idx,
-  bool use_exllama);
+  bool use_exllama,
+  int bit);

 void gptq_shuffle(
   torch::Tensor q_weight,
-  torch::Tensor q_perm);
+  torch::Tensor q_perm,
+  int bit);

 void moe_align_block_size(
   torch::Tensor topk_ids,

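For orientation, here is a minimal sketch of how the new `marlin_gemm` declaration relates to its operand shapes: `size_m`/`size_n`/`size_k` describe C[m, n] = A[m, k] × dequant(B_q)[k, n]. The wrapper name `marlin_gemm_inferred`, the assumed shapes of `a` and `b_scales`, and the checks are illustrative assumptions, not code from this commit:

```cpp
// Hypothetical wrapper (not part of this commit) that infers marlin_gemm's
// size arguments from its tensor operands.
#include <torch/extension.h>

// Declaration as added to csrc/ops.h by this commit.
torch::Tensor marlin_gemm(
  torch::Tensor& a,
  torch::Tensor& b_q_weight,
  torch::Tensor& b_scales,
  torch::Tensor& workspace,
  int64_t size_m,
  int64_t size_n,
  int64_t size_k);

torch::Tensor marlin_gemm_inferred(
    torch::Tensor& a, torch::Tensor& b_q_weight,
    torch::Tensor& b_scales, torch::Tensor& workspace) {
  // Assumed shapes: a is an [m, k] fp16 activation matrix; b_scales' last
  // dimension gives the output width n.
  int64_t size_m = a.size(0);
  int64_t size_k = a.size(1);
  int64_t size_n = b_scales.size(-1);
  TORCH_CHECK(a.is_cuda(), "marlin_gemm expects CUDA tensors");
  return marlin_gemm(a, b_q_weight, b_scales, workspace,
                     size_m, size_n, size_k);
}
```
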
csrc/pybind.cpp (+3 −1)

@@ -52,11 +52,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &rotary_embedding,
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

-  // Quantization ops
+  // Quantization ops
 #ifndef USE_ROCM
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
+
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

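The new `bit` argument on `gptq_gemm` and `gptq_shuffle` suggests the GPTQ path now selects kernels by weight bit width rather than assuming 4-bit. A hedged sketch of that dispatch pattern; `launch_gptq_kernel` and `dispatch_gptq_by_bit` are hypothetical names, not symbols from this commit:

```cpp
// Sketch (assumption, not from this commit) of the kind of dispatch an
// explicit `bit` argument enables: one entry point routing to kernels
// specialized per bit width.
#include <torch/extension.h>

// Hypothetical per-width launcher; the real kernels live in the GPTQ CUDA
// sources and take many more arguments.
template <int BITS>
void launch_gptq_kernel() { /* launch a kernel templated on BITS */ }

void dispatch_gptq_by_bit(int bit) {
  switch (bit) {
    case 2: launch_gptq_kernel<2>(); break;  // decodes via MatrixView_q2_row
    case 3: launch_gptq_kernel<3>(); break;  // decodes via MatrixView_q3_row
    case 4: launch_gptq_kernel<4>(); break;  // the previous, implicit default
    case 8: launch_gptq_kernel<8>(); break;  // decodes via MatrixView_q8_row
    default: TORCH_CHECK(false, "Unsupported GPTQ bit width: ", bit);
  }
}
```
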
csrc/quantization/gptq/matrix_view.cuh (+123)

@@ -146,6 +146,129 @@ public:
     __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
 };

+class MatrixView_q2_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        return (data[row * width / 16 + column / 16] >> shift) & 0x03;
+    }
+
+    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        uint32_t d = data[row * width / 16 + column / 16] >> shift;
+        items[0] = d & 0x03;
+        items[1] = (d >> 2) & 0x03;
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        uint32_t d = data[row * width / 16 + column / 16] >> shift;
+        items[0] = d & 0x03;
+        items[1] = (d >> 2) & 0x03;
+        items[2] = (d >> 4) & 0x03;
+        items[3] = (d >> 6) & 0x03;
+    }
+};
+
+class MatrixView_q3_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int z_w = column * 3 / 32;
+        int z_mod = column & 0x1f;
+
+        if (z_mod == 10) {
+            return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
+        } else if (z_mod == 21) {
+            return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
+        } else if (z_mod < 10) {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
+        } else if (z_mod < 21) {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
+        } else {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
+        }
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x1f);
+        uint32_t d;
+        if (shift <= 4) {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
+        } else if (shift == 8) {
+            d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
+        } else if (shift <= 16) {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
+        } else if (shift == 20) {
+            d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
+        } else {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
+        }
+        items[0] = d & 0x07;
+        items[1] = (d >> 3) & 0x07;
+        items[2] = (d >> 6) & 0x07;
+        items[3] = (d >> 9) & 0x07;
+    }
+};
+
+class MatrixView_q8_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int shift = (column & 0x03) * 8;
+        return (data[row * width / 4 + column / 4] >> shift) & 0xff;
+    }
+
+    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
+    {
+        int shift = (column & 0x03) * 8;
+        uint32_t d = data[row * width / 4 + column / 4] >> shift;
+        items[0] = d & 0xff;
+        items[1] = (d >> 8) & 0xff;
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x03) * 2;
+        uint32_t d = data[row * width / 4 + column / 4] >> shift;
+        items[0] = d & 0xff;
+        items[1] = (d >> 8) & 0xff;
+        items[2] = (d >> 16) & 0xff;
+        items[3] = (d >> 24) & 0xff;
+    }
+};
+
 } // namespace gptq
 } // namespace vllm
 #endif
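The new view classes assume exllama-style row-major packing: 16 2-bit values, or 4 8-bit values, per uint32_t word. Below is a self-contained host-side sketch (my illustration, not code from this commit; the packing loop and names are mine) that packs a row of 2-bit values and reads them back with exactly the shift arithmetic of MatrixView_q2_row::item:

```cpp
// Host-side demo of the 2-bit layout MatrixView_q2_row::item() decodes:
// 16 values per word; value c lives at bits [2*(c % 16), 2*(c % 16) + 2)
// of word c / 16.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int width = 32;                              // columns (multiple of 16)
  std::vector<int> vals(width);
  for (int c = 0; c < width; ++c) vals[c] = c % 4;   // sample 2-bit values

  // Pack: mirror of the layout item() expects.
  std::vector<uint32_t> packed(width / 16, 0);
  for (int c = 0; c < width; ++c)
    packed[c / 16] |= uint32_t(vals[c]) << ((c & 0x0f) * 2);

  // Unpack with the arithmetic of MatrixView_q2_row::item (row = 0).
  for (int c = 0; c < width; ++c) {
    int shift = (c & 0x0f) * 2;
    int v = (packed[0 * width / 16 + c / 16] >> shift) & 0x03;
    assert(v == vals[c]);
  }
  return 0;
}
```

The 8-bit view works the same way with four values per word and 8-bit shifts; its item4 extracts all four bytes of one word, so it is presumably meant to be called with column aligned to 4, where the shift term is zero.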

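The 3-bit layout is the irregular one: 32 values span three uint32_t words, so the values at column % 32 == 10 and == 21 straddle word boundaries, which is why MatrixView_q3_row::item special-cases them. A host-side sketch verifying that arithmetic (again my illustration under an assumed dense packing, not commit code; `q3_item` is a hypothetical helper):

```cpp
// Demo: pack 32 3-bit values densely into three 32-bit words, then decode
// them with the same logic as MatrixView_q3_row::item.
#include <cassert>
#include <cstdint>

// Decode column c from one 96-bit group, mirroring MatrixView_q3_row::item
// with row = 0 and width = 32 (so the row offset term drops out).
static int q3_item(const uint32_t w[3], int c) {
  int z_w = c * 3 / 32;
  int z_mod = c & 0x1f;
  if (z_mod == 10) return (w[z_w] >> 30) | ((w[z_w + 1] << 2) & 0x4);
  if (z_mod == 21) return (w[z_w] >> 31) | ((w[z_w + 1] << 1) & 0x6);
  if (z_mod < 10)  return (w[z_w] >> (z_mod * 3)) & 0x07;
  if (z_mod < 21)  return (w[z_w] >> (z_mod * 3 - 32)) & 0x07;
  return (w[z_w] >> (z_mod * 3 - 64)) & 0x07;
}

int main() {
  uint32_t w[3] = {0, 0, 0};
  int vals[32];
  for (int c = 0; c < 32; ++c) vals[c] = (c * 5) % 8;  // sample 3-bit values

  // Pack bit by bit: value c occupies absolute bits [3c, 3c + 3) of the group.
  for (int c = 0; c < 32; ++c)
    for (int b = 0; b < 3; ++b)
      if ((vals[c] >> b) & 1) w[(3 * c + b) / 32] |= 1u << ((3 * c + b) % 32);

  for (int c = 0; c < 32; ++c) assert(q3_item(w, c) == vals[c]);
  return 0;
}
```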