Add mistral backend #5

Merged
merged 11 commits into from
Aug 1, 2024
1 change: 1 addition & 0 deletions .env.example
@@ -1,5 +1,6 @@
# MODEL=
# BASE_URL=
# MODEL_NAME=
HOST="localhost"
PORT=8000

9 changes: 5 additions & 4 deletions README.md
@@ -6,7 +6,7 @@
![Build & Tests](https://github.com/France-Travail/benchmark_llm_serving/actions/workflows/build_and_tests.yaml/badge.svg)
![Wheel setup](https://github.com/France-Travail/benchmark_llm_serving/actions/workflows/wheel.yaml/badge.svg)

-benchmark_llm_serving is a script aimed at benchmarking the serving API of LLMs. For now, it is focused on LLMs served via [vllm](https://github.com/vllm-project/vllm) and more specifically via [happy-vllm](https://github.com/France-Travail/happy_vllm) which is an API layer on vLLM adding new endpoints and permitting a configuration via environment variables.
+benchmark_llm_serving is a script aimed at benchmarking the serving API of LLMs. For now, two backends are implemented: [mistral](https://docs.mistral.ai/api/) and [vLLM](https://github.com/vllm-project/vllm) (via [happy-vllm](https://github.com/France-Travail/happy_vllm), which is an API layer on vLLM that adds new endpoints and allows configuration via environment variables).

## Installation

@@ -45,8 +45,8 @@ After the bench suite ends, you obtain a folder containing :
- `prompt_ingestion_graph.png` containing the graph of the speed of prompt ingestion by the model. It is the time taken to produce the first token vs the length of the prompt. The speed is the slope of this line and is indicated in the title of the graph. The data used for this graph is contained in the `data` folder.
- `thresholds.csv` is a .csv containing, for each pair of input length/output length, the number of parallel requests such that the kv cache usage is below 100% and the speed generation is above a specified threshold (by default, 20 tokens per second)
- `total_speed_generation_graph.png` is a graph containing, for each pair of input length/output length, the total speed generation vs the number of parallel requests. So, for example, if the model can answer 10 parallel requests each with a speed of 20 tokens per second, the value on the graph will be 200 tokens per second (20 x 10). The data used for this graph is contained in the `data` folder.
-- A folder `kv_cache_profile` containing, for each couple of input length/output length, a graph showing the response of the LLMs to n requests launched at the same time. On the y-axis, you have the kv cache usage, the number of requests running and the number of requests waiting. On the x-axis, you have the time. The graph is obtained by sending one request, watching the response of the LLM then two requests, then three requests, ...
-- A folder `speed_generation` containing, for each couple of input length/output length, a graph showing the speed generation (per request) in token per second vs the number of parallel requests. The graph also shows the time to the first token generated in milliseconds and the max kv cache usage for this number of parallel requests. The corresponding data is in the `data` folder
+- If the backend is `happy_vllm`: a folder `kv_cache_profile` containing, for each pair of input length/output length, a graph showing the response of the LLM to n requests launched at the same time. On the y-axis, you have the kv cache usage, the number of requests running and the number of requests waiting. On the x-axis, you have the time. The graph is obtained by sending one request and watching the response of the LLM, then two requests, then three requests, ...
+- A folder `speed_generation` containing, for each pair of input length/output length, a graph showing the speed generation (per request) in tokens per second vs the number of parallel requests. The graph also shows the time to the first generated token in milliseconds. If the backend is `happy_vllm`, it also shows the max kv cache usage for this number of parallel requests. The corresponding data is in the `data` folder
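
The threshold rule and the total speed figure described in the list above can be sketched as follows. This is only an illustration, not the project's code: the `Measurement` class, its field names and the helper functions are hypothetical, and the sketch assumes that the value written to `thresholds.csv` is the largest number of parallel requests satisfying both conditions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Measurement:
    """One benchmark point (hypothetical structure, for illustration only)."""
    parallel_requests: int
    speed_per_request: float   # tokens per second, per request
    max_kv_cache_usage: float  # fraction between 0 and 1


def total_speed(m: Measurement) -> float:
    """Total speed generation: per-request speed times the number of parallel requests."""
    return m.speed_per_request * m.parallel_requests


def max_parallel_requests(measurements: list, speed_threshold: float = 20.0) -> Optional[int]:
    """Largest number of parallel requests keeping the kv cache usage below 100%
    and the per-request speed above the threshold (None if no point qualifies)."""
    valid = [m.parallel_requests for m in measurements
             if m.max_kv_cache_usage < 1.0 and m.speed_per_request > speed_threshold]
    return max(valid) if valid else None
```

With 10 parallel requests at 20 tokens per second each, `total_speed` gives 200 tokens per second, which matches the example in the README.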

Note that the various input lengths are "32", "1024" and "4096" to simulate small, medium and long prompts. These lengths are approximate: the actual prompts are roughly this size (and generally a bit above it). The various output lengths are 16, 128 and 1024. Contrary to the input lengths, these are exact: the model produces exactly this number of tokens.
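
As an illustration of the prompt ingestion graph mentioned earlier, the slope of a line can be estimated with an ordinary least-squares fit. The sketch below is not taken from the project: the function name `fit_slope` is hypothetical, and it assumes the fit uses the time to first token on the x-axis and the prompt length on the y-axis, so that the slope is expressed in tokens per second.

```python
def fit_slope(xs: list, ys: list) -> float:
    """Slope of the least-squares line y = a * x + b through the given points."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    covariance = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    variance = sum((x - mean_x) ** 2 for x in xs)
    return covariance / variance


# Hypothetical data: time to first token (seconds) and prompt length (tokens)
times_to_first_token = [0.5, 1.0, 2.0]
prompt_lengths = [500, 1000, 2000]
ingestion_speed = fit_slope(times_to_first_token, prompt_lengths)  # tokens per second
```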

@@ -70,9 +70,10 @@ Here is a list of the arguments:
- `min-duration-speed-generation` : For each individual script benchmarking the speed generation, if this min duration (in seconds) is reached and the target-queries-nb is also reached, the script will end (default `60`)
- `target-queries-nb-speed-generation` : For each individual script benchmarking the speed generation, if this target-queries-nb is reached and the min-duration is also reached, the script will end (default `100`)
- `min-number-of-valid-queries`: The minimal number of valid queries that should be present in a file to be considered for graph drawing (default `50`)
-- `backend` : For now, only happy_vllm is supported.
+- `backend` : Only `happy_vllm` and `mistral` are supported.
- `completions-endpoint` : The endpoint for completions (default `/v1/completions`)
- `metrics-endpoint` : The endpoint for the metrics (default `/metrics/`)
- `info-endpoint` : The info endpoint (default `/v1/info`)
- `launch-arguments-endpoint` : The endpoint for getting the launch arguments of the API (default `/v1/launch_arguments`)
- `speed-threshold` : The speed generation above which the model is considered ok (default value `20`). It is only useful when writing `thresholds.csv`
- `model-name`: The name that should be displayed on the graph (default value : `None`). If it is `None`, the name displayed will be the one of the argument `model`
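
The stopping rule described by `min-duration-speed-generation` and `target-queries-nb-speed-generation` (both conditions must hold) could be sketched like this. The function `should_stop` and its parameters are hypothetical and only mirror the wording of the argument list; the real script may organize this differently.

```python
import time
from typing import Optional


def should_stop(start_time: float, completed_queries: int,
                min_duration: float = 60.0, target_queries_nb: int = 100,
                now: Optional[float] = None) -> bool:
    """True once the minimal duration AND the target number of queries are both reached."""
    current = time.monotonic() if now is None else now
    duration_reached = (current - start_time) >= min_duration
    queries_reached = completed_queries >= target_queries_nb
    return duration_reached and queries_reached
```

Requiring both conditions means a fast model still runs for at least the minimal duration, and a slow model keeps running until enough queries have completed.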
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -15,10 +15,10 @@ requires-python = ">=3.10,<4.0"
dependencies = [
"aiohttp>=3.9.5,<4.0",
"prometheus_client>=0.20.0,<1.0",
-"matplotlib>=3.8.4,<4.0",
-"pydantic>=2.7.1,<3.0",
-"pydantic-settings>=2.2.1,<3.0",
-"requests>=2.32.0,<3.0",
+"matplotlib>=3.9.1,<4.0",
+"pydantic>=2.8.2,<3.0",
+"pydantic-settings>=2.3.4,<3.0",
+"requests>=2.32.3,<3.0",
"mdutils>=1.6.0,<2.0"
]
classifiers = [
@@ -44,7 +44,7 @@ include = ["benchmark_llm_serving*"]
bench-suite = "benchmark_llm_serving.bench_suite:main"

[project.optional-dependencies]
-test = ["httpx>=0.23,<1.0", "pytest>=8.2.0,<9.0", "pytest-cov>=5.0.0,<6.0", "mypy>=1.7.1,<2.0", "pytest-asyncio>=0.23.6,<1.0",
+test = ["httpx>=0.27,<1.0", "pytest>=8.3.2,<9.0", "pytest-cov>=5.0.0,<6.0", "mypy>=1.11.0,<2.0", "pytest-asyncio>=0.23.8,<1.0",
"aioresponses>=0.7.6,<1.0", "requests-mock>=1.12.1,<2.0"]

[tool.pytest.ini_options]
8 changes: 4 additions & 4 deletions requirements.txt
@@ -1,7 +1,7 @@
aiohttp==3.9.5
prometheus_client==0.20.0
-matplotlib==3.8.4
-pydantic==2.7.1
-pydantic-settings==2.2.1
-requests==2.32.0
+matplotlib==3.9.1
+pydantic==2.8.2
+pydantic-settings==2.3.4
+requests==2.32.3
mdutils==1.6.0
226 changes: 226 additions & 0 deletions src/benchmark_llm_serving/backends.py
@@ -0,0 +1,226 @@
import argparse

from benchmark_llm_serving.io_classes import QueryOutput, QueryInput


class BackEnd():
Collaborator

Why is this class defined in the first place?

Collaborator

To force the classes that inherit from BackEnd to declare certain methods, and to hold the methods common to all backends.
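
As a side note on this design question, Python's standard `abc` module offers one way to enforce such declarations at instantiation time rather than at call time. The sketch below is not the code of this PR: `AbstractBackEnd` and `EchoBackEnd` are hypothetical names, and only a reduced set of methods is shown.

```python
from abc import ABC, abstractmethod


class AbstractBackEnd(ABC):
    """Hypothetical variant of BackEnd built on the standard abc module."""

    def __init__(self, backend_name: str, chunk_prefix: str = "data: "):
        self.backend_name = backend_name
        self.chunk_prefix = chunk_prefix

    @abstractmethod
    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Must be implemented by every concrete backend."""

    def remove_response_prefix(self, chunk: str) -> str:
        """Common behaviour shared by all backends."""
        return chunk.removeprefix(self.chunk_prefix)


class EchoBackEnd(AbstractBackEnd):
    """Hypothetical concrete backend used only for this illustration."""

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        choices = json_chunk.get("choices", [])
        return choices[0]["text"] if choices else ""
```

With this variant, instantiating a class that lacks an abstract method raises `TypeError` immediately, whereas the explicit-raise approach only fails when the missing method is called.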


    TEMPERATURE = 0
    REPETITION_PENALTY = 1.2

    def __init__(self, backend_name: str, chunk_prefix: str = "data: ", last_chunk: str = "[DONE]", metrics_endpoint_exists: bool = True):
        self.backend_name = backend_name
        self.chunk_prefix = chunk_prefix
        self.last_chunk = last_chunk
        self.metrics_endpoint_exists = metrics_endpoint_exists

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model

        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args

        Returns:
            dict : The payload
        """
        # NotImplementedError is the exception; NotImplemented is a non-callable constant
        raise NotImplementedError("The subclass should implement this method")

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text

        Args:
            json_chunk (dict) : The chunk containing the generated text

        Returns:
            str : The newly generated text
        """
        raise NotImplementedError("The subclass should implement this method")

    def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
        """Gets the useful metrics from the parsed output of the /metrics endpoint

        Args:
            metrics_dict (dict) : The parsed output of the /metrics endpoint

        Returns:
            dict : The useful metrics
        """
        raise NotImplementedError("The subclass should implement this method if metrics_endpoint_exists")

    def test_chunk_validity(self, chunk: str) -> bool:
        """Tests if the chunk is valid or should not be considered.

        Args:
            chunk (str) : The chunk to consider

        Returns:
            bool : Whether the chunk is valid or not
        """
        return True

    def get_completions_headers(self) -> dict:
        """Gets the headers (depending on the backend) to use for the request

        Returns:
            dict: The headers

        """
        return {}

    def remove_response_prefix(self, chunk: str) -> str:
        """Removes the prefix in the response of a model

        Args:
            chunk (str) : The chunk received

        Returns:
            str : The string without the prefix
        """
        return chunk.removeprefix(self.chunk_prefix)

    def check_end_of_stream(self, chunk: str) -> bool:
        """Checks whether this is the last chunk of the stream

        Args:
            chunk (str) : The chunk to test

        Returns:
            bool : Whether it is the last chunk of the stream
        """
        return chunk == self.last_chunk

    def add_prompt_length(self, json_chunk: dict, output: QueryOutput) -> None:
        """Add the prompt length to the QueryOutput if the key "usage" is in the chunk

        Args:
            json_chunk (dict) : The chunk containing the prompt length
            output (QueryOutput) : The output
        """
        if "usage" in json_chunk:
            if json_chunk['usage'] is not None:
                output.prompt_length = json_chunk['usage']['prompt_tokens']

class BackendHappyVllm(BackEnd):

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model

        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args

        Returns:
            dict : The payload
        """
        return {"prompt": query_input.prompt,
                "model": args.model,
                "max_tokens": args.output_length,
                "min_tokens": args.output_length,
                "temperature": self.TEMPERATURE,
                "repetition_penalty": self.REPETITION_PENALTY,
                "stream": True,
                "stream_options": {"include_usage": True}
                }

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text

        Args:
            json_chunk (dict) : The chunk containing the generated text

        Returns:
            str : The newly generated text
        """
        if len(json_chunk['choices']):
            data = json_chunk['choices'][0]['text']
            return data
        else:
            return ""

    def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
        """Gets the useful metrics from the parsed output of the /metrics endpoint

        Args:
            metrics_dict (dict) : The parsed output of the /metrics endpoint

        Returns:
            dict : The useful metrics
        """
        metrics = {}
        metrics['num_requests_running'] = metrics_dict['vllm:num_requests_running'][0]['value']
        metrics['num_requests_waiting'] = metrics_dict['vllm:num_requests_waiting'][0]['value']
        metrics['gpu_cache_usage_perc'] = metrics_dict['vllm:gpu_cache_usage_perc'][0]['value']
        return metrics


class BackEndMistral(BackEnd):

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model

        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args

        Returns:
            dict : The payload
        """
        return {"messages": [{"role": "user", "content": query_input.prompt}],
                "model": args.model,
                "max_tokens": args.output_length,
                "min_tokens": args.output_length,
                "temperature": self.TEMPERATURE,
                "stream": True
                }

    def test_chunk_validity(self, chunk: str) -> bool:
        """Tests if the chunk is valid or should not be considered.

        Args:
            chunk (str) : The chunk to consider

        Returns:
            bool : Whether the chunk is valid or not
        """
        if chunk[:4] == "tok-":
            return False
        else:
            return True

    def get_completions_headers(self) -> dict:
        """Gets the headers (depending on the backend) to use for the request

        Returns:
            dict: The headers

        """
        return {"Accept": "application/json",
                "Content-Type": "application/json"}

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text

        Args:
            json_chunk (dict) : The chunk containing the generated text

        Returns:
            str : The newly generated text
        """
        if len(json_chunk['choices']):
            data = json_chunk['choices'][0]['delta']["content"]
            return data
        else:
            return ""


def get_backend(backend_name: str) -> BackEnd:
    implemented_backends = ["mistral", "happy_vllm"]
    if backend_name not in implemented_backends:
        raise ValueError(f"The specified backend {backend_name} is not implemented. Please use one of the following : {implemented_backends}")
    if backend_name == "happy_vllm":
        return BackendHappyVllm(backend_name, chunk_prefix="data: ", last_chunk="[DONE]", metrics_endpoint_exists=True)
    if backend_name == "mistral":
        return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]", metrics_endpoint_exists=False)
    return BackEnd("not_implemented")
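
To show how the chunk helpers in this file fit together, here is a standalone sketch of processing a streamed response. It does not import the module and is not the benchmark's actual loop: the function `extract_text` and the sample lines are hypothetical, and the chunk layout (a `data: ` prefix, a final `[DONE]` chunk, text under `choices[0]["text"]`) is assumed from the defaults and the happy_vllm backend shown in this diff.

```python
import json

CHUNK_PREFIX = "data: "  # same default as BackEnd.chunk_prefix
LAST_CHUNK = "[DONE]"    # same default as BackEnd.last_chunk


def extract_text(lines: list) -> str:
    """Concatenates the generated text found in a list of streamed lines."""
    pieces = []
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue  # skip the empty separator lines of the stream
        chunk = line.removeprefix(CHUNK_PREFIX)
        if chunk == LAST_CHUNK:
            break  # end of stream
        json_chunk = json.loads(chunk)
        choices = json_chunk.get("choices", [])
        if choices:  # a usage-only chunk has no choices
            pieces.append(choices[0]["text"])
    return "".join(pieces)


sample_lines = [
    'data: {"choices": [{"text": "Hel"}]}',
    '',
    'data: {"choices": [{"text": "lo"}]}',
    'data: {"choices": [], "usage": {"prompt_tokens": 3}}',
    'data: [DONE]',
]
generated = extract_text(sample_lines)
```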