xFasterTransformer framework support #2615

Merged · 11 commits · Nov 3, 2023
90 changes: 90 additions & 0 deletions docs/xFasterTransformer.md
@@ -0,0 +1,90 @@
# xFasterTransformer Inference Framework

FastChat integrates the customized [xFasterTransformer](https://github.com/intel/xFasterTransformer) framework to provide **faster** inference speed on Intel CPUs.

## Install xFasterTransformer

Set up the environment (please refer to [this link](https://github.com/intel/xFasterTransformer#installation) for more details):

```bash
pip install xfastertransformer
```
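
To quickly verify the installation, you can try importing the package (a sanity check, not part of the official instructions):

```bash
python -c "import xfastertransformer; print('xFasterTransformer is available')"
```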

## Prepare models

Prepare the model (please refer to [this link](https://github.com/intel/xFasterTransformer#prepare-model) for more details):
```bash
python ./tools/chatglm_convert.py -i ${HF_DATASET_DIR} -o ${OUTPUT_DIR}
```
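
For example, with hypothetical paths (the output directory is what you later pass to `--model-path`):

```bash
# Hypothetical paths: the input is a Hugging Face checkpoint, the output directory
# receives the converted weights.
export HF_DATASET_DIR=/data/chatglm2-6b-hf
export OUTPUT_DIR=/data/chatglm2-6b-cpu
python ./tools/chatglm_convert.py -i ${HF_DATASET_DIR} -o ${OUTPUT_DIR}
```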

## Parameters of xFasterTransformer
- `--enable-xft` enables xFasterTransformer in FastChat.
- `--xft-max-seq-len` sets the maximum token length the model can process; the maximum includes the input token length.
- `--xft-dtype` sets the data type xFasterTransformer uses for computation. Supported types are fp32, fp16, int8, bf16, and hybrid types such as bf16_fp16 and bf16_int8. For data type details, please refer to [this link](https://github.com/intel/xFasterTransformer/wiki/Data-Type-Support-Platform).
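
For example, all three flags can be combined in one command (the model path is a placeholder):

```bash
python3 -m fastchat.serve.cli \
    --model-path /path/to/models \
    --enable-xft \
    --xft-max-seq-len 4096 \
    --xft-dtype bf16_fp16
```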


Chat with the CLI:
```bash
# Run inference on all CPUs using float16
python3 -m fastchat.serve.cli \
--model-path /path/to/models \
--enable-xft \
--xft-dtype fp16
```
Or run with numactl on a multi-socket server for better performance:
```bash
# Run inference on NUMA node 0 with data type bf16_fp16 (the first token uses bfloat16, the rest use float16)
numactl -N 0 --localalloc \
python3 -m fastchat.serve.cli \
--model-path /path/to/models/chatglm2_6b_cpu/ \
--enable-xft \
--xft-dtype bf16_fp16
```
Or use MPI to run inference on two sockets for better performance (the `:` in the `mpirun` command line separates the per-rank commands, so each socket runs its own process):
```bash
# Run inference on NUMA nodes 0 and 1 with data type bf16_fp16 (the first token uses bfloat16, the rest use float16)
OMP_NUM_THREADS=$CORE_NUM_PER_SOCKET LD_PRELOAD=libiomp5.so mpirun \
-n 1 numactl -N 0 --localalloc \
python -m fastchat.serve.cli \
--model-path /path/to/models/chatglm2_6b_cpu/ \
--enable-xft \
--xft-dtype bf16_fp16 : \
-n 1 numactl -N 1 --localalloc \
python -m fastchat.serve.cli \
--model-path /path/to/models/chatglm2_6b_cpu/ \
--enable-xft \
--xft-dtype bf16_fp16
```
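
`$CORE_NUM_PER_SOCKET` is not set by FastChat; one way to derive it on Linux (a sketch that assumes the usual English `lscpu` output) is:

```bash
# Physical cores per socket, parsed from lscpu
CORE_NUM_PER_SOCKET=$(lscpu | awk '/^Core\(s\) per socket:/ {print $4}')
echo "Cores per socket: ${CORE_NUM_PER_SOCKET}"
```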


Start model worker:
```bash
# Load model with default configuration (max sequence length 4096, no GPU split setting).
python3 -m fastchat.serve.model_worker \
--model-path /path/to/models \
--enable-xft \
--xft-dtype bf16_fp16
```
Or run with numactl on a multi-socket server for better performance:
```bash
# Run inference on NUMA node 0 with data type bf16_fp16 (the first token uses bfloat16, the rest use float16)
numactl -N 0 --localalloc python3 -m fastchat.serve.model_worker \
--model-path /path/to/models \
--enable-xft \
--xft-dtype bf16_fp16
```
Or use MPI to run inference on two sockets for better performance:
```bash
# Run inference on NUMA nodes 0 and 1 with data type bf16_fp16 (the first token uses bfloat16, the rest use float16)
OMP_NUM_THREADS=$CORE_NUM_PER_SOCKET LD_PRELOAD=libiomp5.so mpirun \
-n 1 numactl -N 0 --localalloc python -m fastchat.serve.model_worker \
--model-path /path/to/models \
--enable-xft \
--xft-dtype bf16_fp16 : \
-n 1 numactl -N 1 --localalloc python -m fastchat.serve.model_worker \
--model-path /path/to/models \
--enable-xft \
--xft-dtype bf16_fp16
```
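
A model worker is normally used together with a FastChat controller and, optionally, the OpenAI-compatible API server. A minimal sketch of the full serving stack (standard FastChat components; default ports assumed):

```bash
# Terminal 1: controller
python3 -m fastchat.serve.controller

# Terminal 2: xFasterTransformer-backed model worker (registers with the controller)
python3 -m fastchat.serve.model_worker \
    --model-path /path/to/models \
    --enable-xft \
    --xft-dtype bf16_fp16

# Terminal 3: OpenAI-compatible REST API server
python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
```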

For more details, please refer to [this link](https://github.com/intel/xFasterTransformer#how-to-run).
27 changes: 27 additions & 0 deletions fastchat/model/model_adapter.py
@@ -34,11 +34,13 @@
from fastchat.model.model_codet5p import generate_stream_codet5p
from fastchat.model.model_falcon import generate_stream_falcon
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.model.model_xfastertransformer import generate_stream_xft
from fastchat.model.monkey_patch_non_inplace import (
replace_llama_attn_with_non_inplace_operations,
)
from fastchat.modules.awq import AWQConfig, load_awq_quantized
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model
from fastchat.modules.xfastertransformer import XftConfig, load_xft_model
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
from fastchat.utils import get_gpu_memory

@@ -170,6 +172,7 @@ def load_model(
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional[ExllamaConfig] = None,
xft_config: Optional[XftConfig] = None,
revision: str = "main",
debug: bool = False,
):
@@ -297,6 +300,9 @@ def load_model(
elif exllama_config:
model, tokenizer = load_exllama_model(model_path, exllama_config)
return model, tokenizer
elif xft_config:
model, tokenizer = load_xft_model(model_path, xft_config)
return model, tokenizer
kwargs["revision"] = revision

if dtype is not None: # Overwrite dtype if it is provided in the arguments.
@@ -344,6 +350,7 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
is_codet5p = "codet5p" in model_type
is_peft = "peft" in model_type
is_exllama = "exllama" in model_type
is_xft = "xft" in model_type

if is_chatglm:
return generate_stream_chatglm
@@ -353,6 +360,8 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
return generate_stream_codet5p
elif is_exllama:
return generate_stream_exllama
elif is_xft:
return generate_stream_xft

elif peft_share_base_weights and is_peft:
# Return a curried stream function that loads the right adapter
@@ -492,6 +501,24 @@ def add_model_args(parser):
default=None,
help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7",
)
parser.add_argument(
"--enable-xft",
action="store_true",
help="Used for xFasterTransformer Enable xFasterTransformer inference framework.",
)
parser.add_argument(
"--xft-max-seq-len",
type=int,
default=4096,
help="Used for xFasterTransformer. Max sequence length to use for xFasterTransformer framework; default 4096 sequence length.",
)
parser.add_argument(
"--xft-dtype",
type=str,
choices=["fp16", "bf16", "int8", "bf16_fp16", "bf16_int8"],
help="Override the default dtype. If not set, it will use bfloat16 for first token and float16 next tokens on CPU.",
default=None,
)


def remove_parent_directory_name(model_path):
80 changes: 80 additions & 0 deletions fastchat/model/model_xfastertransformer.py
@@ -0,0 +1,80 @@
import torch
from transformers import TextIteratorStreamer
from threading import Thread
import gc


@torch.inference_mode()
def generate_stream_xft(
model,
tokenizer,
params,
device,
context_len=8192,
stream_interval=2,
judge_sent_end=False,
):
prompt = params["prompt"]
repetition_penalty = float(params.get("repetition_penalty", 1.0))

# Unused for now; placeholder for the future.
# temperature = float(params.get("temperature", 1.0))
# top_p = float(params.get("top_p", 1.0))

max_new_tokens = int(params.get("max_new_tokens", 4096))
echo = params.get("echo", True)

inputs = tokenizer(
prompt, return_tensors="pt", padding=model.config.padding
).input_ids
input_echo_len = len(inputs[0])
max_len = max_new_tokens + input_echo_len

decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
generation_kwargs = {
"input_ids": inputs,
"streamer": streamer,
"max_length": max_len,
"num_beams": model.config.beam_width,
"length_penalty": repetition_penalty,
"num_return_sequences": model.config.num_return_sequences,
"early_stopping": model.config.early_stopping,
"eos_token_id": model.config.eos_token_id,
"pad_token_id": model.config.pad_token_id,
}

thread = Thread(target=model.model.generate, kwargs=generation_kwargs)
thread.start()
if echo:
# means keep the prompt
output = prompt
else:
output = ""
i = 0
for i, new_text in enumerate(streamer):
output += new_text
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}
output = output.strip()
if i == max_new_tokens - 1:
finish_reason = "length"
else:
finish_reason = "stop"
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
}
gc.collect()
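
A minimal usage sketch of the streaming generator above (the model path is a placeholder; `load_xft_model` is defined in `fastchat/modules/xfastertransformer.py` below):

```python
from fastchat.modules.xfastertransformer import XftConfig, load_xft_model
from fastchat.model.model_xfastertransformer import generate_stream_xft

# Load a converted xFasterTransformer model from a placeholder path.
model, tokenizer = load_xft_model("/path/to/models/chatglm2_6b_cpu/", XftConfig())

params = {"prompt": "Hello, who are you?", "max_new_tokens": 64, "echo": False}
# Each yielded dict carries the accumulated text, token usage, and finish_reason.
last = None
for last in generate_stream_xft(model, tokenizer, params, device="cpu"):
    pass
print(last["text"], last["finish_reason"])
```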
46 changes: 46 additions & 0 deletions fastchat/modules/xfastertransformer.py
@@ -0,0 +1,46 @@
from dataclasses import dataclass
import sys


@dataclass
class XftConfig:
max_seq_len: int = 4096
beam_width: int = 1
eos_token_id: int = -1
pad_token_id: int = -1
num_return_sequences: int = 1
is_encoder_decoder: bool = False
padding: bool = True
early_stopping: bool = False
data_type: str = "bf16_fp16"


class XftModel:
def __init__(self, xft_model, xft_config):
self.model = xft_model
self.config = xft_config


def load_xft_model(model_path, xft_config: XftConfig):
try:
import xfastertransformer
from transformers import AutoTokenizer
except ImportError as e:
print(f"Error: Failed to load xFasterTransformer. {e}")
sys.exit(-1)

if xft_config.data_type is None or xft_config.data_type == "":
data_type = "bf16_fp16"
else:
data_type = xft_config.data_type
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, padding_side="left", trust_remote_code=True
)
xft_model = xfastertransformer.AutoModel.from_pretrained(
model_path, dtype=data_type
)
model = XftModel(xft_model=xft_model, xft_config=xft_config)
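# NOTE: in MPI / multi-rank runs, only rank 0 returns and serves requests; the other
# ranks stay in this loop and keep calling generate() so they can participate in the
# computation driven by rank 0.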
if model.model.rank > 0:
while True:
model.model.generate()
return model, tokenizer
12 changes: 12 additions & 0 deletions fastchat/serve/cli.py
@@ -31,6 +31,7 @@
from fastchat.model.model_adapter import add_model_args
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.serve.inference import ChatIO, chat_loop
from fastchat.utils import str_to_torch_dtype
@@ -203,6 +204,16 @@ def main(args):
)
else:
exllama_config = None
if args.enable_xft:
xft_config = XftConfig(
max_seq_len=args.xft_max_seq_len,
data_type=args.xft_dtype,
)
if args.device != "cpu":
print("xFasterTransformer now is only support CPUs. Reset device to CPU")
args.device = "cpu"
else:
xft_config = None
if args.style == "simple":
chatio = SimpleChatIO(args.multiline)
elif args.style == "rich":
@@ -238,6 +249,7 @@ def main(args):
groupsize=args.awq_groupsize,
),
exllama_config=exllama_config,
xft_config=xft_config,
revision=args.revision,
judge_sent_end=args.judge_sent_end,
debug=args.debug,
4 changes: 4 additions & 0 deletions fastchat/serve/inference.py
@@ -38,6 +38,7 @@
from fastchat.modules.awq import AWQConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length


@@ -304,6 +305,7 @@ def chat_loop(
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional[ExllamaConfig] = None,
xft_config: Optional[XftConfig] = None,
revision: str = "main",
judge_sent_end: bool = True,
debug: bool = True,
@@ -321,6 +323,7 @@
gptq_config=gptq_config,
awq_config=awq_config,
exllama_config=exllama_config,
xft_config=xft_config,
revision=revision,
debug=debug,
)
@@ -329,6 +332,7 @@
model_type = str(type(model)).lower()
is_t5 = "t5" in model_type
is_codet5p = "codet5p" in model_type
is_xft = "xft" in model_type

# Hardcode T5's default repetition penalty to be 1.2
if is_t5 and repetition_penalty == 1.0: