Refactor to have BackEnd classes
gsolard committed Jul 30, 2024
1 parent c2747b7 commit 3aff1e8
Showing 11 changed files with 371 additions and 320 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -6,8 +6,7 @@
 ![Build & Tests](https://github.com/France-Travail/benchmark_llm_serving/actions/workflows/build_and_tests.yaml/badge.svg)
 ![Wheel setup](https://github.com/France-Travail/benchmark_llm_serving/actions/workflows/wheel.yaml/badge.svg)
 
-benchmark_llm_serving is a script aimed at benchmarking the serving API of LLMs. For now, it is focused on LLMs served via [vllm](https://github.com/vllm-project/vllm) and more specifically via [happy-vllm](https://github.com/France-Travail/happy_vllm) which is an API layer on vLLM adding new endpoints and permitting a configuration via environment variables.
-
+benchmark_llm_serving is a script aimed at benchmarking the serving API of LLMs. For now, two backends are possible: `mistral` and [vLLM](https://github.com/vllm-project/vllm) (via [happy-vllm](https://github.com/France-Travail/happy_vllm), an API layer on vLLM that adds new endpoints and allows configuration via environment variables).
 ## Installation
 
 It is advised to clone the repository in order to get the datasets used for the benchmarks (you can find them in `src/benchmark_llm_serving/datasets`) and build it from source:
268 changes: 158 additions & 110 deletions src/benchmark_llm_serving/backends.py
@@ -3,149 +3,197 @@
 from benchmark_llm_serving.io_classes import QueryOutput, QueryInput
 
 
-IMPLEMENTED_BACKENDS = "'happy_vllm', 'mistral'"
-
-
-def get_payload(query_input: QueryInput, args: argparse.Namespace) -> dict:
-    """Gets the payload to give to the model
-
-    Args:
-        query_input (QueryInput) : The query input to use
-        args (argparse.Namespace) : The cli args
-
-    Returns:
-        dict : The payload
-    """
-    temperature = 0
-    repetition_penalty = 1.2
-    if args.backend == "happy_vllm":
-        return {"prompt": query_input.prompt,
-                "model": args.model,
-                "max_tokens": args.output_length,
-                "min_tokens": args.output_length,
-                "temperature": temperature,
-                "repetition_penalty": repetition_penalty,
-                "stream": True,
-                "stream_options": {"include_usage": True}
-                }
-    elif args.backend == "mistral":
-        return {"messages": [{"role": "user", "content": query_input.prompt}],
-                "model": args.model,
-                "max_tokens": args.output_length,
-                "min_tokens": args.output_length,
-                "temperature": temperature,
-                "stream": True
-                }
-    else:
-        raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}")
-
-
-def test_chunk_validity(chunk: str, args: argparse.Namespace) -> bool:
-    """Tests if the chunk is valid or should not be considered.
-
-    Args:
-        chunk (str) : The chunk to consider
-        args (argparse.Namespace) : The cli args
-
-    Returns:
-        bool : Whether the chunk is valid or not
-    """
-    if args.backend in ["happy_vllm"]:
-        return True
-    elif args.backend in ["mistral"]:
-        if chunk[:4] == "tok-":
-            return False
-        else:
-            return True
-    else:
-        raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}")
-
-
-def get_completions_headers(args: argparse.Namespace) -> dict:
-    """Gets the headers (depending on the backend) to use for the request
-
-    Args:
-        args (argparse.Namespace) : The cli args
-
-    Returns:
-        dict: The headers
-    """
-    if args.backend in ["happy_vllm"]:
-        return {}
-    elif args.backend == "mistral":
-        return {"Accept": "application/json",
-                "Content-Type": "application/json"}
-    else:
-        raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}")
-
-
-def decode_remove_response_prefix(chunk_bytes: bytes, args: argparse.Namespace) -> str:
-    """Removes the prefix in the response of a model and converts the bytes in str
-
-    Args:
-        chunk_bytes (bytes) : The chunk received
-        args (argparse.Namespace) : The cli args
-
-    Returns:
-        str : The decoded string without the prefix
-    """
-    chunk = chunk_bytes.decode("utf-8")
-    if args.backend in ["happy_vllm", "mistral"]:
-        return chunk.removeprefix("data: ")
-    else:
-        raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}")
-
-
-def check_end_of_stream(chunk: str, args: argparse.Namespace) -> bool:
-    """Checks if this is the last chunk of the stream
-
-    Args:
-        chunk (str) : The chunk to test
-        args (argparse.Namespace) : The cli args
-
-    Returns:
-        bool : Whether it is the last chunk of the stream
-    """
-    if args.backend in ["happy_vllm", "mistral"]:
-        return chunk == "[DONE]"
-    else:
-        raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}")
-
-
-def get_newly_generated_text(json_chunk: dict, args: argparse.Namespace) -> str:
-    """Gets the newly generated text
-
-    Args:
-        json_chunk (dict) : The chunk containing the generated text
-        args (argparse.Namespace) : The cli args
-
-    Returns:
-        str : The newly generated text
-    """
-    if args.backend == "happy_vllm":
-        if len(json_chunk['choices']):
-            data = json_chunk['choices'][0]['text']
-            return data
-    elif args.backend == "mistral":
-        if len(json_chunk['choices']):
-            data = json_chunk['choices'][0]['delta']["content"]
-            return data
-    else:
-        raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}")
-    return ""
-
-
-def add_prompt_length(json_chunk: dict, output: QueryOutput, args: argparse.Namespace) -> None:
-    """Add the prompt length to the QueryOutput
-
-    Args:
-        json_chunk (dict) : The chunk containing the prompt length
-        args (argparse.Namespace) : The cli args
-    """
-    if args.backend in ["happy_vllm", 'mistral']:
-        if "usage" in json_chunk:
-            if json_chunk['usage'] is not None:
-                output.prompt_length = json_chunk['usage']['prompt_tokens']
-    else:
-        raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}")
+class BackEnd():
+
+    TEMPERATURE = 0
+    REPETITION_PENALTY = 1.2
+
+    def __init__(self, backend_name: str, chunk_prefix: str = "data: ", last_chunk: str = "[DONE]"):
+        self.backend_name = backend_name
+        self.chunk_prefix = chunk_prefix
+        self.last_chunk = last_chunk
+
+    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
+        """Gets the payload to give to the model
+
+        Args:
+            query_input (QueryInput) : The query input to use
+            args (argparse.Namespace) : The cli args
+
+        Returns:
+            dict : The payload
+        """
+        raise NotImplemented("The subclass should implement this method")
+
+    def get_newly_generated_text(self, json_chunk: str) -> str:
+        """Gets the newly generated text
+
+        Args:
+            json_chunk (dict) : The chunk containing the generated text
+
+        Returns:
+            str : The newly generated text
+        """
+        raise NotImplemented("The subclass should implement this method")
+
+    def test_chunk_validity(self, chunk: str) -> bool:
+        """Tests if the chunk is valid or should not be considered.
+
+        Args:
+            chunk (str) : The chunk to consider
+
+        Returns:
+            bool : Whether the chunk is valid or not
+        """
+        return True
+
+    def get_completions_headers(self) -> dict:
+        """Gets the headers (depending on the backend) to use for the request
+
+        Returns:
+            dict: The headers
+        """
+        return {}
+
+    def remove_response_prefix(self, chunk: str) -> str:
+        """Removes the prefix in the response of a model
+
+        Args:
+            chunk (str) : The chunk received
+
+        Returns:
+            str : The string without the prefix
+        """
+        return chunk.removeprefix(self.chunk_prefix)
+
+    def check_end_of_stream(self, chunk: str) -> bool:
+        """Checks whether this is the last chunk of the stream
+
+        Args:
+            chunk (str) : The chunk to test
+
+        Returns:
+            bool : Whether it is the last chunk of the stream
+        """
+        return chunk == self.last_chunk
+
+    def add_prompt_length(self, json_chunk: dict, output: QueryOutput) -> None:
+        """Add the prompt length to the QueryOutput if the key "usage" is in the chunk
+
+        Args:
+            json_chunk (dict) : The chunk containing the prompt length
+            output (QueryOutput) : The output
+        """
+        if "usage" in json_chunk:
+            if json_chunk['usage'] is not None:
+                output.prompt_length = json_chunk['usage']['prompt_tokens']
+
+
+class BackendHappyVllm(BackEnd):
+
+    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
+        """Gets the payload to give to the model
+
+        Args:
+            query_input (QueryInput) : The query input to use
+            args (argparse.Namespace) : The cli args
+
+        Returns:
+            dict : The payload
+        """
+        return {"prompt": query_input.prompt,
+                "model": args.model,
+                "max_tokens": args.output_length,
+                "min_tokens": args.output_length,
+                "temperature": self.TEMPERATURE,
+                "repetition_penalty": self.REPETITION_PENALTY,
+                "stream": True,
+                "stream_options": {"include_usage": True}
+                }
+
+    def get_newly_generated_text(self, json_chunk: str) -> str:
+        """Gets the newly generated text
+
+        Args:
+            json_chunk (dict) : The chunk containing the generated text
+
+        Returns:
+            str : The newly generated text
+        """
+        if len(json_chunk['choices']):
+            data = json_chunk['choices'][0]['text']
+            return data
+        else:
+            return ""
+
+
+class BackEndMistral(BackEnd):
+
+    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
+        """Gets the payload to give to the model
+
+        Args:
+            query_input (QueryInput) : The query input to use
+            args (argparse.Namespace) : The cli args
+
+        Returns:
+            dict : The payload
+        """
+        return {"messages": [{"role": "user", "content": query_input.prompt}],
+                "model": args.model,
+                "max_tokens": args.output_length,
+                "min_tokens": args.output_length,
+                "temperature": self.TEMPERATURE,
+                "stream": True
+                }
+
+    def test_chunk_validity(self, chunk: str) -> bool:
+        """Tests if the chunk is valid or should not be considered.
+
+        Args:
+            chunk (str) : The chunk to consider
+
+        Returns:
+            bool : Whether the chunk is valid or not
+        """
+        if chunk[:4] == "tok-":
+            return False
+        else:
+            return True
+
+    def get_completions_headers(self) -> dict:
+        """Gets the headers (depending on the backend) to use for the request
+
+        Returns:
+            dict: The headers
+        """
+        return {"Accept": "application/json",
+                "Content-Type": "application/json"}
+
+    def get_newly_generated_text(self, json_chunk: str) -> str:
+        """Gets the newly generated text
+
+        Args:
+            json_chunk (dict) : The chunk containing the generated text
+
+        Returns:
+            str : The newly generated text
+        """
+        if len(json_chunk['choices']):
+            data = json_chunk['choices'][0]['delta']["content"]
+            return data
+        else:
+            return ""
+
+
+def get_backend(backend_name: str) -> BackEnd:
+    implemented_backends = ["mistral", "happy_vllm"]
+    if backend_name not in implemented_backends:
+        raise ValueError(f"The specified backend {backend_name} is not implemented. Please use one of the following : {implemented_backends}")
+    if backend_name == "happy_vllm":
+        return BackendHappyVllm(backend_name, chunk_prefix="data: ", last_chunk="[DONE]")
+    if backend_name == "mistral":
+        return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]")
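
The snippet below is an illustrative sketch, not part of this commit, of how a caller might drive the new BackEnd interface while reading a streamed completion. It assumes an aiohttp-style response whose .content yields the raw SSE byte chunks; the names response and output stand in for whatever the benchmark's query function actually provides.

import json

async def consume_stream(response, backend, output):
    # Illustrative only: accumulate the text generated across a streamed response
    generated_text = ""
    async for chunk_bytes in response.content:
        chunk = chunk_bytes.decode("utf-8").strip()
        # Strip the "data: " prefix configured on the backend
        chunk = backend.remove_response_prefix(chunk)
        # Skip empty keep-alive lines and backend-specific noise (e.g. the mistral "tok-..." chunks)
        if not chunk or not backend.test_chunk_validity(chunk):
            continue
        # Stop on the "[DONE]" sentinel
        if backend.check_end_of_stream(chunk):
            break
        json_chunk = json.loads(chunk)
        generated_text += backend.get_newly_generated_text(json_chunk)
        # Fills output.prompt_length when the chunk carries a "usage" block
        backend.add_prompt_length(json_chunk, output)
    return generated_text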
13 changes: 8 additions & 5 deletions src/benchmark_llm_serving/bench_suite.py
@@ -12,6 +12,7 @@
 from benchmark_llm_serving import utils
 from benchmark_llm_serving.io_classes import QueryInput
 from benchmark_llm_serving.make_readmes import make_readme
+from benchmark_llm_serving.backends import get_backend, BackEnd
 from benchmark_llm_serving.make_graphs import draw_and_save_graphs
 from benchmark_llm_serving.benchmark import launch_benchmark, augment_dataset
 from benchmark_llm_serving.utils_args import get_parser_base_arguments, add_arguments_to_parser
@@ -143,6 +144,8 @@ def main():
     for input_length in input_lengths:
         for output_length in output_lengths:
             input_output_lengths.append((input_length, output_length))
 
+    backend = get_backend(args.backend)
+
     # Launch the benchmark for prompt ingestion speed
     now = utils.get_now()
@@ -162,13 +165,13 @@ def main():
         logger.info(f"{now} Benchmark for the prompt ingestion speed : instance {i} ")
         args.output_file = os.path.join(raw_results_folder, f"prompt_ingestion_{i}.json")
         dataset = add_prefixes_to_dataset(datasets[args.prompt_length], 4)
-        launch_benchmark(args, dataset, suite_id)
+        launch_benchmark(args, dataset, suite_id, backend=backend)
         now = utils.get_now()
         logger.info(f"{now} Benchmark for the prompt ingestion speed : instance {i} : DONE")
     now = utils.get_now()
     logger.info(f"{now} Benchmark for the prompt ingestion speed : DONE")
 
-    if args.backend == "happy_vllm":
+    if backend.backend_name == "happy_vllm":
         # Launch the benchmark for the KV cache profile
         now = utils.get_now()
         logger.info(f"{now} Beginning the benchmarks for the KV cache profile")
@@ -185,7 +188,7 @@ def main():
             now = utils.get_now()
             dataset = add_prefixes_to_dataset(datasets[args.prompt_length], 4)
             logger.info(f"{now} Beginning the benchmark for the KV cache profile, input length : {input_length}, output_length : {output_length}")
-            launch_benchmark(args, dataset, suite_id)
+            launch_benchmark(args, dataset, suite_id, backend=backend)
            now = utils.get_now()
            logger.info(f"{now} Benchmark for the KV cache profile, input length : {input_length}, output_length : {output_length} : DONE")
        now = utils.get_now()
@@ -214,7 +217,7 @@ def main():
        now = utils.get_now()
        logger.info(f"{now} Benchmarks for the generation speed, input length : {input_length}, output_length : {output_length}, nb_requests : {nb_constant_requests}")
        dataset = add_prefixes_to_dataset(datasets[args.prompt_length], 4)
-       launch_benchmark(args, dataset, suite_id)
+       launch_benchmark(args, dataset, suite_id, backend=backend)
        now = utils.get_now()
        logger.info(f"{now} Benchmarks for the generation speed, input length : {input_length}, output_length : {output_length}, nb_requests : {nb_constant_requests} : DONE")
        current_timestamp = datetime.now().timestamp()
@@ -228,7 +231,7 @@ def main():
    now = utils.get_now()
    logger.info(f"{now} Drawing graphs")
    draw_and_save_graphs(output_folder, speed_threshold=args.speed_threshold, gpu_name=args.gpu_name,
-                        min_number_of_valid_queries=args.min_number_of_valid_queries, backend=args.backend)
+                        min_number_of_valid_queries=args.min_number_of_valid_queries, backend=backend)
    now = utils.get_now()
    logger.info(f"{now} Drawing graphs : DONE")
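
Since bench_suite.py now resolves the backend once via get_backend(args.backend) and threads it through launch_benchmark, supporting a further API would presumably only require a new BackEnd subclass and one more branch in get_backend. The sketch below is hypothetical and not part of the commit; the class name, the "openai" identifier and the payload fields are assumptions made for illustration.

class BackEndOpenAI(BackEnd):

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        # Chat-completions style payload; the exact fields accepted by the target API are assumptions
        return {"messages": [{"role": "user", "content": query_input.prompt}],
                "model": args.model,
                "max_tokens": args.output_length,
                "temperature": self.TEMPERATURE,
                "stream": True}

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        # Same chat-delta layout as the mistral backend
        if len(json_chunk['choices']):
            return json_chunk['choices'][0]['delta'].get("content") or ""
        return ""

# and, alongside the existing branches of get_backend:
#     if backend_name == "openai":
#         return BackEndOpenAI(backend_name, chunk_prefix="data: ", last_chunk="[DONE]")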