Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support int8 KVCacheQuant and W8A8 inference in vllm #1112

Closed
wants to merge 52 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
e08acaa
add llama quant
Aug 14, 2023
387c804
change weight path
Aug 14, 2023
68cd1e0
fix weight load
Aug 15, 2023
ca088d6
merge gate and up matrix
Aug 16, 2023
6bde51e
use FTLlamaRMSNorm
Aug 17, 2023
931e51c
support bitsandbytes int8
Aug 28, 2023
c0c2a4d
llama support bnb 4bit
Aug 30, 2023
3bb6e31
support kv cache quantization
Sep 19, 2023
bc9fada
fix python code
Sep 19, 2023
976874d
merge and reformat
Sep 20, 2023
2c0c311
add int8gemm
Sep 20, 2023
27e3b4b
support int8 inference
Sep 20, 2023
e6f45ff
Reduce alpha,beta unnecessary d2h
sleepcoo Sep 21, 2023
96c10ca
fix weight load
Sep 21, 2023
4be7d83
fix weight load
Sep 22, 2023
be6f7b8
fix ln layer init
Sep 22, 2023
5ffc537
rms norm fusion
Sep 26, 2023
347397c
fix w8a8 linear
Sep 26, 2023
030a100
use same scale across tensors
Sep 26, 2023
2805edc
add ftgemm
Sep 27, 2023
06cfa3f
fix cublas linear
Sep 27, 2023
97b5c69
clean cublass gemm code
Sep 27, 2023
4d5c1a7
code clean
Sep 27, 2023
9176b1f
support generating kv quant parameters and evaluting kv quant models
Sep 27, 2023
9f872d9
modify test functions
Sep 28, 2023
892c589
fix test code
Sep 28, 2023
538947d
fix test attention
Sep 28, 2023
bf3eb58
evaluation support quant
Sep 28, 2023
a0be417
fuse dequant silu and quant
Sep 28, 2023
52af06e
fuse dequant and add residual
Sep 28, 2023
627b766
fuse dequant, add residual, rms_norm and quant
Sep 28, 2023
dfc9572
fuse dequant and pos_encoding
Sep 28, 2023
e025b66
setup for fused kernels
Sep 28, 2023
9eba3c3
fix bugs
Sep 28, 2023
1e60348
add tests for fusion kernels
Oct 9, 2023
eab850d
modify attention kernel test using pytest
Oct 12, 2023
d3735c7
fix quant parameter passing
Oct 16, 2023
3e7874c
fix uncontiguous tensor case
Oct 17, 2023
4ee29a9
add quant, dequant kernel
Oct 17, 2023
b746c0c
optimize layernorm kernel
Oct 17, 2023
8893069
support quant method in examples
Oct 17, 2023
b3bdc50
add python class DequantAddResidualI8RMSNormQuant, DequantPagedAttent…
Oct 17, 2023
219738f
add tests
Oct 17, 2023
074e86b
add w8a8linear without quant and dequant
Oct 17, 2023
d69100d
adjust code for fusion
Oct 17, 2023
3e81f3d
rm obsolete file
Oct 18, 2023
0ea256f
fix llama
Oct 19, 2023
29939aa
remove cutlass dependency
Oct 19, 2023
d8f7d5a
add sq quantized linear
Oct 24, 2023
74bd08f
rm unit test for w8a8 linear
Oct 24, 2023
e9b2fa4
adjust i8 llama weight load
Oct 24, 2023
6f88787
add fusion.py
Oct 26, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions benchmarks/benchmark_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import argparse
# import asyncio
# import json
import os
# import random
# import time
from typing import List, Tuple, Dict

# import aiohttp
import numpy as np
import pandas as pd
# from transformers import PreTrainedTokenizerBase
# from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm import LLM, SamplingParams, RequestOutput
from mmlu_template import MMLUTemplate

TEMPLATE_REGITRY = {
"mmlu": MMLUTemplate,
}


def sample_requests(
# dataset_path: str,
# num_requests: int,
# tokenizer: PreTrainedTokenizerBase,
dev_data_path: str,
test_data_path: str,
subjects: List[str],
dataset_template: str = "mmlu",
is_analyse: bool = False,
) -> Tuple[List[str], List[str], List[int]]:
# Load the dataset.
nums_questions = []
dataset = []
labels = []
template_class = TEMPLATE_REGITRY[dataset_template]
for subject in subjects:
test_dataset = pd.read_csv(os.path.join(test_data_path, subject + "_test.csv"), header=None)
nums_questions.append(len(test_dataset))
template = template_class(subject, os.path.join(dev_data_path, subject + "_dev.csv"), is_analyse)
for idx in range(len(test_dataset)):
prompt = template.getTemplate(test_dataset, idx)
dataset.append(prompt)
labels.append(test_dataset.iloc[idx, -1])
return dataset, labels, nums_questions


def run_vllm(
requests: List[str],
output_len: int,
model: str,
tokenizer: str,
kv_cache_dtype: str = "int8",
kv_quant_params_path: str = None,
tensor_parallel_size: int = 1,
seed: int = 0,
n: int = 1,
use_beam_search: bool = False,
trust_remote_code: bool = False,
quantmethod: str = None,
) -> List[RequestOutput]:
llm = LLM(
model=model,
tokenizer=tokenizer,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
kv_cache_dtype=kv_cache_dtype,
kv_quant_params_path=kv_quant_params_path,
quantization = quantmethod
)
for prompt in requests:
sampling_params = SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
)
# FIXME(woosuk): Do not use internal method.
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=sampling_params,
)

# FIXME(woosuk): Do use internal method.
return llm._run_engine(use_tqdm=True)


def evalute(
request_outputs: List[RequestOutput],
labels: List[str],
nums_questions: List[int],
subjects: List[str],
dataset_template: str = "mmlu",
) -> Dict[str, float]:
template_class = TEMPLATE_REGITRY[dataset_template]
pred = [template_class.findAnswer(r.outputs[0].text) for r in request_outputs]
ids = np.cumsum(nums_questions)
lhs = 0
accs: List[float] = []
for rhs in ids:
pred_paritition = np.array(pred[lhs: rhs])
labels_partition = np.array(labels[lhs: rhs])
acc = np.mean(pred_paritition == labels_partition)
accs.append(acc)
sub2acc = {sub: acc for sub, acc in zip(subjects, accs)}
return sub2acc


def main(args: argparse.Namespace):
subjects = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
]
dataset, labels, nums_questions = sample_requests(
args.dev_data_path,
args.test_data_path,
subjects,
is_analyse=args.is_analyse
)
request_outputs = run_vllm(
dataset,
args.output_len,
args.model,
args.tokenizer,
args.kv_cache_dtype,
args.kv_quant_params_path,
args.tensor_parallel_size,
args.seed, args.n,
args.use_beam_search,
args.trust_remote_code,
args.quantization
)
sub2acc = evalute(
request_outputs,
labels,
nums_questions,
subjects,
)
print(sub2acc)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="evaluation for quantization.")

parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument("--dev-data-path",
type=str,
default=None,
help="path to few-shot dataset")
parser.add_argument("--test-data-path",
type=str,
default=None,
help="path to test dataset")
parser.add_argument("--is-analyse",
action="store_true")
parser.add_argument("--output-len",
type=int,
default=100,
help="nums of max token for evaluation outputs")
parser.add_argument("--kv-cache-dtype",
type=str,
default="float16")
parser.add_argument("--kv-quant-params-path",
type=str,
default=None)
parser.add_argument("--quantization",
type=str,
default=None)
args = parser.parse_args()
main(args)
119 changes: 119 additions & 0 deletions benchmarks/mmlu_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pandas as pd
import json
from langchain.prompts import PromptTemplate

template = PromptTemplate(
input_variables=["question", "A", "B", "C", "D", "Answer"],
template=
"""
USER: {question}
A. {A}
B. {B}
C. {C}
D. {D} ASSISTANT: Answer: {Answer}</s>
""",
)

template_with_analyse = PromptTemplate(
input_variables=["question", "A", "B", "C", "D"],
template=
"""
Q:{question}
(A) {A} (B) {B} (C) {C} (D) {D}
A: Let's think step by step.
""",
)


def gen_prompt(train_df, subject, k=1):
prompt = "SYSTEM: The following are multiple choice questions (with answers) about {}," \
"Please select the correct answer from the options.".format(subject.replace('_', ' '))

for i in range(k):
prompt += template.format(question=train_df.iloc[i, 0],
A=train_df.iloc[i, 1],
B=train_df.iloc[i, 2],
C=train_df.iloc[i, 3],
D=train_df.iloc[i, 4],
Answer=train_df.iloc[i, 5]
)[1:-1]
return prompt


## add an abstract base class or common base class for generality
class MMLUTemplate():

def __init__(self, subject, file_path, is_analyse):
self.fiveShotTemplate = ""
self.file_path = file_path
self.subject = subject
self.choices = ["A", "B", "C", "D"]
self.is_analyse = is_analyse
self.few_shot_template = ""
if not is_analyse:
self.getFewShotBaseTemplates()
else:
self.getFewShotBaseTemplateAnalyse()

def getFewShotBaseTemplates(self, k=5):
"""few_shot模板不带分析"""
dev_df = pd.read_csv(self.file_path, header=None)

self.few_shot_template = gen_prompt(dev_df, self.subject, k)
return self.few_shot_template

def getFewShotBaseTemplateAnalyse(self):
"""few_shot模板带分析,更改json文件就行"""
mmlu_prompt = json.load(open('templates/lib_prompt/mmlu-cot.json'))
self.few_shot_template = mmlu_prompt[self.subject]
return self.few_shot_template

def getTemplate(self, test_df, i):
"""获得模板"""
if self.is_analyse:
templ = template_with_analyse.format(
question=test_df.iloc[i, 0],
A=test_df.iloc[i, 1],
B=test_df.iloc[i, 2],
C=test_df.iloc[i, 3],
D=test_df.iloc[i, 4]
)

return self.few_shot_template + "\n" + templ

else:
prompt_end = template.format(
question=test_df.iloc[i, 0],
A=test_df.iloc[i, 1],
B=test_df.iloc[i, 2],
C=test_df.iloc[i, 3],
D=test_df.iloc[i, 4],
Answer='')[1:-5]
return self.few_shot_template + prompt_end
@staticmethod
def findAnswer(res):
"""解析函数"""
# print("模型输出为:", res)
d = "NO"
for d_ in res:
if 65 <= ord(d_) <= 68:
d = d_
break
# print("答案解析为:", d)
return d

@staticmethod
def findAnwerUsingRule(res):
# print("模型输出为:", res)
result = "NO"
pattern = 'the answer is ('
try:
pred = res.lower().split(pattern)[1][0]

if 65 <= ord(pred.upper()) <= 68:
result = pred.upper()
except:
pass

# print("答案解析为:",result)
return result
33 changes: 12 additions & 21 deletions csrc/activation.cpp
Original file line number Diff line number Diff line change
@@ -1,28 +1,19 @@
#include <torch/extension.h>

void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void silu_and_mul(torch::Tensor &out, torch::Tensor &input);

void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
void gelu_new(torch::Tensor &out, torch::Tensor &input);

void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast(torch::Tensor &out, torch::Tensor &input);

void invoke_dequant_silu_and_mul_quant(torch::Tensor &out, torch::Tensor &input,
const float scale_gate,
const float scale_up,
const float scale_out);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
m.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
m.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
m.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
m.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
m.def("invoke_dequant_silu_and_mul_quant", &invoke_dequant_silu_and_mul_quant, "Dequant input, apply silu act and quant output");
}
Loading