forked from FlagOpen/FlagScale
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
691 additions
and
596 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,19 @@ | ||
engine: | ||
model: BAAI/Aquila-7B/ | ||
tokenizer: BAAI/Aquila-7B/ | ||
llm: | ||
model: xxxx | ||
trust_remote_code: true | ||
tensor_parallel_size: 1 | ||
pipeline_parallel_size: 1 | ||
gpu_memory_utilization: 0.6 | ||
dtype: bfloat16 | ||
seed: 1234 | ||
|
||
data: | ||
generate: | ||
prompts: [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
# prompts_path: null | ||
top_p: 0.95 | ||
top_k: 100 | ||
max_tokens: 7 | ||
temperature: 0.9 | ||
sampling: | ||
top_p: 0.95 | ||
temperature: 0.8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,140 +1,63 @@ | ||
import os | ||
import yaml | ||
import argparse | ||
from typing import List, Union | ||
|
||
import torch | ||
|
||
from transformers import AutoTokenizer, LlamaForCausalLM, GenerationConfig | ||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams | ||
|
||
from arguments import parse_args | ||
|
||
|
||
def process_requests(prompts: List[str], | ||
engine: LLMEngine, | ||
sampling_params: SamplingParams): | ||
"""Continuously process a list of prompts and handle the outputs.""" | ||
request_id = 0 | ||
while prompts: | ||
prompt = prompts.pop(0) | ||
engine.add_request(str(request_id), prompt, sampling_params) | ||
request_id += 1 | ||
from omegaconf import OmegaConf, ListConfig | ||
from vllm import LLM, SamplingParams | ||
|
||
|
||
def get_config():
    """Load the run configuration named on the command line.

    Parses the required ``--config-path`` option from ``sys.argv``, reads that
    YAML file with ``yaml.safe_load`` and wraps the parsed content with
    ``OmegaConf.create``.

    Returns:
        The OmegaConf object created from the parsed YAML content.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config-path",
        type=str,
        required=True,
        help="Path to the configuration YAML file",
    )
    args = parser.parse_args()

    # Decode as UTF-8 explicitly so loading does not depend on the platform's
    # locale encoding.
    with open(args.config_path, "r", encoding="utf-8") as file:
        config_dict = yaml.safe_load(file)

    # Wrap the plain parsed data in an OmegaConf container.
    return OmegaConf.create(config_dict)
|
||
|
||
def get_prompts(prompts):
    """Return the prompts to generate from as a plain list.

    Args:
        prompts: Either the path of an existing text file holding one prompt
            per line, or a ``list`` / OmegaConf ``ListConfig`` of prompts.

    Returns:
        A new ``list`` of prompts. When read from a file, each line is
        stripped of surrounding whitespace and blank lines are skipped so
        they do not turn into empty prompts.

    Raises:
        ValueError: If ``prompts`` is neither an existing file path nor a
            list-like of prompts.
    """
    if isinstance(prompts, str) and os.path.isfile(prompts):
        # Decode as UTF-8 explicitly: prompts may contain non-ASCII text and
        # the locale default encoding differs between platforms.
        with open(prompts, "r", encoding="utf-8") as file:
            stripped_lines = (line.strip() for line in file)
            return [line for line in stripped_lines if line]

    # Copy into a built-in list so callers always receive a real ``list``
    # rather than an OmegaConf container.
    if isinstance(prompts, list) or isinstance(prompts, ListConfig):
        return list(prompts)

    raise ValueError("Prompts should be either a list of strings or a path to a file containing a list of strings.")
|
||
outputs: List[Union[RequestOutput]] = [] | ||
while engine.has_unfinished_requests(): | ||
step_outputs = engine.step() | ||
for output in step_outputs: | ||
if output.finished: | ||
outputs.append(output) | ||
|
||
outputs = sorted(outputs, key=lambda x: int(x.request_id)) | ||
return outputs | ||
def inference(): | ||
# Get the configuration. | ||
config = get_config() | ||
|
||
# Get the prompts. | ||
prompts = get_prompts(config.generate.prompts) | ||
|
||
def inference(args: argparse.Namespace, prompts: List[str]): | ||
"""Initialize the LLMEngine""" | ||
engine_args = EngineArgs.from_cli_args(args) | ||
llm_engine = LLMEngine.from_engine_args(engine_args) | ||
# Create a sampling params object. | ||
sampling_args = config.get("sampling", {}) | ||
sampling_params = SamplingParams(**sampling_args) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True) | ||
llm_engine.tokenizer.tokenizer = tokenizer | ||
# Create an LLM. | ||
llm_args = config.get("llm", {}) | ||
model = llm_args.pop("model", None) | ||
assert model is not None | ||
llm = LLM(model, **llm_args) | ||
|
||
"""Initialize the SamplingParams""" | ||
sampling_params = SamplingParams( | ||
n=args.n, | ||
best_of=args.best_of, | ||
frequency_penalty=args.frequency_penalty, | ||
repetition_penalty=args.repetition_penalty, | ||
temperature=args.temperature, | ||
top_p=args.top_p, | ||
top_k=args.top_k, | ||
min_p=args.min_p, | ||
seed=args.seed, | ||
use_beam_search=args.use_beam_search, | ||
length_penalty=args.length_penalty, | ||
early_stopping=args.early_stopping, | ||
stop=args.stop, | ||
stop_token_ids=args.stop_token_ids, | ||
include_stop_str_in_output=args.include_stop_str_in_output, | ||
ignore_eos=args.ignore_eos, | ||
max_tokens=args.max_tokens, | ||
min_tokens=args.min_tokens, | ||
logprobs=args.logprobs, | ||
prompt_logprobs=args.prompt_logprobs, | ||
detokenize=args.detokenize, | ||
skip_special_tokens=args.skip_special_tokens, | ||
spaces_between_special_tokens=args.spaces_between_special_tokens, | ||
# logits_processors=, | ||
# truncate_prompt_tokens=, | ||
) | ||
# Generate texts from the prompts. | ||
outputs = llm.generate(prompts, sampling_params) | ||
|
||
outputs = process_requests(prompts, llm_engine, sampling_params) | ||
# Print the outputs. | ||
for output in outputs: | ||
print("\n") | ||
print("="*50) | ||
print("=> RequestOutput:", output) | ||
token_ids = output.outputs[0].token_ids | ||
print("=> generated text:", tokenizer.decode(token_ids)) | ||
|
||
|
||
def generate(args: argparse.Namespace, prompts: List[str]):
    """Sample a continuation for each prompt with transformers and print it.

    Loads ``args.model`` as a ``LlamaForCausalLM`` in bfloat16 on CUDA and the
    tokenizer named by ``args.tokenizer``, then generates one sampled
    continuation per prompt and prints the prompt and the decoded new tokens.

    Args:
        args: Parsed options; reads ``model``, ``tokenizer``, ``max_tokens``,
            ``temperature``, ``top_k`` and ``top_p``.
        prompts: Prompt strings to complete, processed one at a time.
    """
    model = LlamaForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
    ).to('cuda')
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)

    # The sampling settings do not depend on the prompt, so build them (and
    # resolve the special-token ids) once instead of on every iteration.
    generation_config = GenerationConfig(
        do_sample=True,
        eos_token_id=tokenizer.convert_tokens_to_ids('<|extra_204|>'),
        pad_token_id=tokenizer.convert_tokens_to_ids('<|endoftext|>'),
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )

    for prompt in prompts:
        print("\n")
        print("="*50)
        print("=> prompt:", prompt)

        # Add a leading batch dimension of size one and move to the model's device.
        input_ids = tokenizer.encode_plus(prompt)["input_ids"]
        tokens = torch.tensor(input_ids)[None,].to(model.device)
        input_length = len(tokens[0])

        out = model.generate(
            tokens,
            generation_config,
            return_dict_in_generate=True,
            output_scores=True,
        )

        # Keep only the newly generated ids; the returned sequence starts with the prompt.
        out_ids = out["sequences"][0][input_length:].tolist()
        out_text = tokenizer.decode(out_ids)
        print("=> generated text:", out_text)
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
|
||
prompts = [] | ||
if args.prompts_path is not None: | ||
with open(args.prompts_path, "r") as f: | ||
while True: | ||
prompt = f.readline() | ||
if not prompt: | ||
break | ||
prompts.append(prompt[:-1]) # remove the last '\n' of prompt | ||
elif len(args.prompts) > 1: | ||
prompts = args.prompts | ||
else: | ||
raise ValueError("Pleace set right prompts_path or prompts data.") | ||
|
||
""" | ||
vllm inference | ||
""" | ||
inference(args, prompts) | ||
|
||
""" | ||
transformers inference | ||
""" | ||
# generate(args, prompts) | ||
# Run the inference | ||
inference() |
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from abc import ABC, abstractmethod | ||
from omegaconf import DictConfig | ||
from enum import Enum | ||
|
||
|
||
class JobStatus(Enum):
    """Coarse job states; each member's value is a human-readable label."""

    # The job is currently running.
    RUNNING = "Running"

    # The job is in between states: stopping or starting.
    TRANSITIONAL = "Transitional (Stopping or Starting)"

    # The job has completed or was never started; the two share one state.
    COMPLETED_OR_IDLE = "Completed or Not Started"
|
||
|
||
class RunnerBase(ABC):
    """Abstract base class for runners constructed from a ``DictConfig``.

    Subclasses must implement ``run``; ``stop`` is an optional hook whose
    default implementation does nothing.
    """

    def __init__(self, config: DictConfig):
        # Keep the configuration on the instance for use by subclasses.
        self.config = config

    @abstractmethod
    def run(self, *args, **kwargs):
        """Run the job; concrete subclasses must override this method."""
        raise NotImplementedError()

    def stop(self, *args, **kwargs):
        """Stop the job; optional to override, does nothing by default."""
        return None
Oops, something went wrong.