Refactor the runner
aoyulong committed Oct 15, 2024
1 parent d5ea05a commit eebc703
Showing 14 changed files with 691 additions and 596 deletions.
@@ -10,9 +10,11 @@ experiment:
     backend: vllm
     entrypoint: ./flagscale/inference/inference_aquila.py
   runner:
-    hostfile: /share/project/zhaoyingli/hostfile
+    hostfile: xxxx
+    cmds:
+      before_start: source /root/miniconda3/bin/activate flagscale
   envs:
-    CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
+    CUDA_VISIBLE_DEVICES: 0
     CUDA_DEVICE_MAX_CONNECTIONS: 1

 action: run
15 changes: 6 additions & 9 deletions examples/aquila/conf/inference/inference_aquila_7b.yaml
@@ -1,22 +1,19 @@
-engine:
-  model: BAAI/Aquila-7B/
-  tokenizer: BAAI/Aquila-7B/
+llm:
+  model: xxxx
   trust_remote_code: true
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
   gpu_memory_utilization: 0.6
   dtype: bfloat16
   seed: 1234

-data:
+generate:
   prompts: [
     "Hello, my name is",
     "The president of the United States is",
     "The capital of France is",
     "The future of AI is",
   ]
   # prompts_path: null
-  top_p: 0.95
-  top_k: 100
-  max_tokens: 7
-  temperature: 0.9
+sampling:
+  top_p: 0.95
+  temperature: 0.8
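
Note: the renamed sections map one-to-one onto the objects created by the refactored entrypoint further down (llm -> vllm.LLM, sampling -> vllm.SamplingParams, generate -> the prompt list). A minimal sketch of loading this file the same way get_config() does; the path is illustrative, and the xxxx model placeholder must point at a real checkpoint before the llm section is usable:

import yaml
from omegaconf import OmegaConf

with open("examples/aquila/conf/inference/inference_aquila_7b.yaml", "r") as f:
    config = OmegaConf.create(yaml.safe_load(f))

print(list(config.llm.keys()))        # ['model', 'trust_remote_code', 'tensor_parallel_size', ...]
print(list(config.generate.prompts))  # the four example prompts above
print(dict(config.sampling))          # {'top_p': 0.95, 'temperature': 0.8}
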
8 changes: 4 additions & 4 deletions flagscale/auto_tuner/tuner.py
@@ -6,8 +6,8 @@

 from omegaconf import DictConfig, OmegaConf

-from flagscale.launcher.job_status import JobStatus
-from flagscale.launcher.runner import SSHRunner
+from flagscale.runner.runner_base import JobStatus
+from flagscale.runner.runner_train import SSHTrainRunner

 from .generate import Generator
 from .platform import set_jiuding_platform_args
@@ -160,7 +160,7 @@ def tune(self):
             raise ValueError(f"No strategy can run.")
         best_task = self.generator.gen_best_task(best_strategy, self.orig_config)
         best_task.action = "run"
-        runner = SSHRunner(best_task)
+        runner = SSHTrainRunner(best_task)
         runner.run(monitor=True, interval=60)

     def need_stop(self):
@@ -213,7 +213,7 @@ def run(self, task=None):
         # Instantiate a runner and run the task
         if task is None:
             task = self.cur_task
-        self.runner = SSHRunner(task)
+        self.runner = SSHTrainRunner(task)
         self.runner.run()
         # set start time
         self.task_start_time = time.time()
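
Note: the tuner changes are purely an import migration plus the class rename; a short reference for the new paths (the commented old paths were removed in this commit):

from flagscale.runner.runner_base import JobStatus         # was: from flagscale.launcher.job_status import JobStatus
from flagscale.runner.runner_train import SSHTrainRunner   # was: from flagscale.launcher.runner import SSHRunner

# Call sites keep the same shape, e.g. in Tuner.tune():
#     runner = SSHTrainRunner(best_task)
#     runner.run(monitor=True, interval=60)
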
2 changes: 1 addition & 1 deletion flagscale/auto_tuner/utils.py
@@ -4,7 +4,7 @@
 import subprocess
 from types import SimpleNamespace

-from flagscale.launcher.runner import parse_hostfile
+from flagscale.runner.runner import parse_hostfile


 def divisible(x, y):
179 changes: 51 additions & 128 deletions flagscale/inference/inference_aquila.py
@@ -1,140 +1,63 @@
 import os
 import yaml
 import argparse
-from typing import List, Union
-
-import torch
-
-from transformers import AutoTokenizer, LlamaForCausalLM, GenerationConfig
-from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
-
-from arguments import parse_args
-
-
-def process_requests(prompts: List[str],
-                     engine: LLMEngine,
-                     sampling_params: SamplingParams):
-    """Continuously process a list of prompts and handle the outputs."""
-    request_id = 0
-    while prompts:
-        prompt = prompts.pop(0)
-        engine.add_request(str(request_id), prompt, sampling_params)
-        request_id += 1
+
+from omegaconf import OmegaConf, ListConfig
+from vllm import LLM, SamplingParams
+
+
+def get_config():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config-path", type=str, required=True, help="Path to the configuration YAML file")
+    args = parser.parse_args()
+
+    config_path = args.config_path
+    # Open the YAML file and convert it into a dictionary
+    with open(config_path, 'r') as file:
+        config_dict = yaml.safe_load(file)
+
+    # Convert the dictionary into a DictConfig
+    config = OmegaConf.create(config_dict)
+    return config
+
+
+def get_prompts(prompts):
+    print(prompts, type(prompts))
+    if isinstance(prompts, str) and os.path.isfile(prompts):
+        with open(prompts, 'r') as file:
+            return [line.strip() for line in file.readlines()]
+    elif isinstance(prompts, (list, ListConfig)):
+        return prompts
+    else:
+        raise ValueError("Prompts should be either a list of strings or a path to a file containing a list of strings.")

-    outputs: List[Union[RequestOutput]] = []
-    while engine.has_unfinished_requests():
-        step_outputs = engine.step()
-        for output in step_outputs:
-            if output.finished:
-                outputs.append(output)
-
-    outputs = sorted(outputs, key=lambda x: int(x.request_id))
-    return outputs
+def inference():
+    # Get the configuration.
+    config = get_config()

+    # Get the prompts.
+    prompts = get_prompts(config.generate.prompts)

-def inference(args: argparse.Namespace, prompts: List[str]):
-    """Initialize the LLMEngine"""
-    engine_args = EngineArgs.from_cli_args(args)
-    llm_engine = LLMEngine.from_engine_args(engine_args)
+    # Create a sampling params object.
+    sampling_args = config.get("sampling", {})
+    sampling_params = SamplingParams(**sampling_args)

-    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
-    llm_engine.tokenizer.tokenizer = tokenizer
+    # Create an LLM.
+    llm_args = config.get("llm", {})
+    model = llm_args.pop("model", None)
+    assert model is not None
+    llm = LLM(model, **llm_args)

-    """Initialize the SamplingParams"""
-    sampling_params = SamplingParams(
-        n=args.n,
-        best_of=args.best_of,
-        frequency_penalty=args.frequency_penalty,
-        repetition_penalty=args.repetition_penalty,
-        temperature=args.temperature,
-        top_p=args.top_p,
-        top_k=args.top_k,
-        min_p=args.min_p,
-        seed=args.seed,
-        use_beam_search=args.use_beam_search,
-        length_penalty=args.length_penalty,
-        early_stopping=args.early_stopping,
-        stop=args.stop,
-        stop_token_ids=args.stop_token_ids,
-        include_stop_str_in_output=args.include_stop_str_in_output,
-        ignore_eos=args.ignore_eos,
-        max_tokens=args.max_tokens,
-        min_tokens=args.min_tokens,
-        logprobs=args.logprobs,
-        prompt_logprobs=args.prompt_logprobs,
-        detokenize=args.detokenize,
-        skip_special_tokens=args.skip_special_tokens,
-        spaces_between_special_tokens=args.spaces_between_special_tokens,
-        # logits_processors=,
-        # truncate_prompt_tokens=,
-    )
+    # Generate texts from the prompts.
+    outputs = llm.generate(prompts, sampling_params)

-    outputs = process_requests(prompts, llm_engine, sampling_params)
+    # Print the outputs.
     for output in outputs:
-        print("\n")
-        print("="*50)
-        print("=> RequestOutput:", output)
-        token_ids = output.outputs[0].token_ids
-        print("=> generated text:", tokenizer.decode(token_ids))
-
-
-def generate(args: argparse.Namespace, prompts: List[str]):
-
-    model = LlamaForCausalLM.from_pretrained(
-        args.model,
-        torch_dtype=torch.bfloat16,
-        attn_implementation="flash_attention_2",
-        trust_remote_code=True
-    ).to('cuda')
-    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
-
-    for prompt in prompts:
-        print("\n")
-        print("="*50)
-        print("=> prompt:", prompt)
-        tokens = tokenizer.encode_plus(prompt)["input_ids"]
-        tokens = torch.tensor(tokens)[None,].to(model.device)
-        input_length = len(tokens[0])
-        generation_config = GenerationConfig(
-            do_sample=True,
-            eos_token_id=tokenizer.convert_tokens_to_ids('<|extra_204|>'),
-            pad_token_id=tokenizer.convert_tokens_to_ids('<|endoftext|>'),
-            max_new_tokens=args.max_tokens,
-            temperature=args.temperature,
-            top_k=args.top_k,
-            top_p=args.top_p,
-        )
-        out = model.generate(
-            tokens,
-            generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-        )
-        out_ids = out["sequences"][0][input_length:].cpu().numpy()
-        out_text = tokenizer.decode(out_ids.tolist())
-        print("=> generated text:", out_text)
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


 if __name__ == '__main__':
-    args = parse_args()
-
-    prompts = []
-    if args.prompts_path is not None:
-        with open(args.prompts_path, "r") as f:
-            while True:
-                prompt = f.readline()
-                if not prompt:
-                    break
-                prompts.append(prompt[:-1]) # remove the last '\n' of prompt
-    elif len(args.prompts) > 1:
-        prompts = args.prompts
-    else:
-        raise ValueError("Pleace set right prompts_path or prompts data.")
-
-    """
-    vllm inference
-    """
-    inference(args, prompts)
-
-    """
-    transformers inference
-    """
-    # generate(args, prompts)
+    # Run the inference
+    inference()
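
Note: the refactored script drops all of the forwarded vLLM CLI flags; a hypothetical invocation (paths are examples, not part of this commit) and the way the YAML's sampling section reaches vLLM:

# python flagscale/inference/inference_aquila.py \
#     --config-path examples/aquila/conf/inference/inference_aquila_7b.yaml
#
# Everything that used to arrive via argparse / EngineArgs.from_cli_args now
# comes from the YAML; e.g. the sampling section expands directly into:
from vllm import SamplingParams

sampling_params = SamplingParams(top_p=0.95, temperature=0.8)  # values from inference_aquila_7b.yaml
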
7 changes: 0 additions & 7 deletions flagscale/launcher/job_status.py

This file was deleted.

File renamed without changes.
22 changes: 22 additions & 0 deletions flagscale/runner/runner_base.py
@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+from omegaconf import DictConfig
+from enum import Enum
+
+
+class JobStatus(Enum):
+    RUNNING = "Running"
+    TRANSITIONAL = "Transitional (Stopping or Starting)"
+    COMPLETED_OR_IDLE = "Completed or Not Started"
+
+
+class RunnerBase(ABC):
+    def __init__(self, config: DictConfig):
+        self.config = config
+
+    @abstractmethod
+    def run(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def stop(self, *args, **kwargs):
+        """Optional method to override."""
+        pass
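
Note: RunnerBase fixes only the constructor and the abstract run(); stop() is optional to override. A hypothetical minimal subclass (not part of this commit, assumes FlagScale at this revision is importable) illustrating the contract that concrete runners such as SSHTrainRunner presumably implement:

# Hypothetical example, for illustration only.
from omegaconf import DictConfig, OmegaConf

from flagscale.runner.runner_base import JobStatus, RunnerBase


class EchoRunner(RunnerBase):
    """Toy runner that only tracks its own JobStatus."""

    def __init__(self, config: DictConfig):
        super().__init__(config)
        self.status = JobStatus.COMPLETED_OR_IDLE

    def run(self, *args, **kwargs):
        # A real runner would launch the task described by self.config here.
        print("run:", OmegaConf.to_yaml(self.config).strip())
        self.status = JobStatus.RUNNING

    def stop(self, *args, **kwargs):
        self.status = JobStatus.COMPLETED_OR_IDLE


runner = EchoRunner(OmegaConf.create({"experiment": {"exp_name": "demo"}}))
runner.run()
assert runner.status is JobStatus.RUNNING
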