Add mistral backend #5
Merged
Changes from all commits (11 commits):

- e7543f7 Fix some mypy errors (gsolard)
- 6923e8a Delete a useless line (gsolard)
- 7be07d0 Change the query function to reflect changes in vLLM (gsolard)
- 8561d89 Add mistral backend (gsolard)
- 45d8330 Added docstrings and fixed the existing tests (gsolard)
- e5080fa Added new unit tests (gsolard)
- c2747b7 Fix a mypy error (gsolard)
- 3aff1e8 Refactor to have BackEnd classes (gsolard)
- 2a7d233 Fix mypy errors (gsolard)
- a477301 Update readme (gsolard)
- a37c375 Refactor to take into account /metrics endpoint which can be implemented (gsolard)
@@ -1,5 +1,6 @@
# MODEL=
# BASE_URL=
# MODEL_NAME=
HOST="localhost"
PORT=8000
@@ -1,7 +1,7 @@
 aiohttp==3.9.5
 prometheus_client==0.20.0
-matplotlib==3.8.4
-pydantic==2.7.1
-pydantic-settings==2.2.1
-requests==2.32.0
+matplotlib==3.9.1
+pydantic==2.8.2
+pydantic-settings==2.3.4
+requests==2.32.3
 mdutils==1.6.0
@@ -0,0 +1,226 @@
import argparse

from benchmark_llm_serving.io_classes import QueryOutput, QueryInput


class BackEnd():

    TEMPERATURE = 0
    REPETITION_PENALTY = 1.2

    def __init__(self, backend_name: str, chunk_prefix: str = "data: ", last_chunk: str = "[DONE]", metrics_endpoint_exists: bool = True):
        self.backend_name = backend_name
        self.chunk_prefix = chunk_prefix
        self.last_chunk = last_chunk
        self.metrics_endpoint_exists = metrics_endpoint_exists

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model

        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args

        Returns:
            dict : The payload
        """
        raise NotImplementedError("The subclass should implement this method")

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text

        Args:
            json_chunk (dict) : The chunk containing the generated text

        Returns:
            str : The newly generated text
        """
        raise NotImplementedError("The subclass should implement this method")

    def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
        """Gets the useful metrics from the parsed output of the /metrics endpoint

        Args:
            metrics_dict (dict) : The parsed output of the /metrics endpoint

        Returns:
            dict : The useful metrics
        """
        raise NotImplementedError("The subclass should implement this method if metrics_endpoint_exists")

    def test_chunk_validity(self, chunk: str) -> bool:
        """Tests whether the chunk is valid or should be discarded.

        Args:
            chunk (str) : The chunk to consider

        Returns:
            bool : Whether the chunk is valid
        """
        return True

    def get_completions_headers(self) -> dict:
        """Gets the headers (depending on the backend) to use for the request

        Returns:
            dict : The headers
        """
        return {}

    def remove_response_prefix(self, chunk: str) -> str:
        """Removes the prefix from the response of a model

        Args:
            chunk (str) : The chunk received

        Returns:
            str : The string without the prefix
        """
        return chunk.removeprefix(self.chunk_prefix)

    def check_end_of_stream(self, chunk: str) -> bool:
        """Checks whether this is the last chunk of the stream

        Args:
            chunk (str) : The chunk to test

        Returns:
            bool : Whether it is the last chunk of the stream
        """
        return chunk == self.last_chunk

    def add_prompt_length(self, json_chunk: dict, output: QueryOutput) -> None:
        """Adds the prompt length to the QueryOutput if the key "usage" is in the chunk

        Args:
            json_chunk (dict) : The chunk containing the prompt length
            output (QueryOutput) : The output
        """
        if "usage" in json_chunk:
            if json_chunk['usage'] is not None:
                output.prompt_length = json_chunk['usage']['prompt_tokens']


class BackendHappyVllm(BackEnd):

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model

        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args

        Returns:
            dict : The payload
        """
        return {"prompt": query_input.prompt,
                "model": args.model,
                "max_tokens": args.output_length,
                "min_tokens": args.output_length,
                "temperature": self.TEMPERATURE,
                "repetition_penalty": self.REPETITION_PENALTY,
                "stream": True,
                "stream_options": {"include_usage": True}
                }

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text

        Args:
            json_chunk (dict) : The chunk containing the generated text

        Returns:
            str : The newly generated text
        """
        if len(json_chunk['choices']):
            return json_chunk['choices'][0]['text']
        else:
            return ""

    def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
        """Gets the useful metrics from the parsed output of the /metrics endpoint

        Args:
            metrics_dict (dict) : The parsed output of the /metrics endpoint

        Returns:
            dict : The useful metrics
        """
        metrics = {}
        metrics['num_requests_running'] = metrics_dict['vllm:num_requests_running'][0]['value']
        metrics['num_requests_waiting'] = metrics_dict['vllm:num_requests_waiting'][0]['value']
        metrics['gpu_cache_usage_perc'] = metrics_dict['vllm:gpu_cache_usage_perc'][0]['value']
        return metrics


class BackEndMistral(BackEnd):

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model

        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args

        Returns:
            dict : The payload
        """
        return {"messages": [{"role": "user", "content": query_input.prompt}],
                "model": args.model,
                "max_tokens": args.output_length,
                "min_tokens": args.output_length,
                "temperature": self.TEMPERATURE,
                "stream": True
                }

    def test_chunk_validity(self, chunk: str) -> bool:
        """Tests whether the chunk is valid or should be discarded.

        Args:
            chunk (str) : The chunk to consider

        Returns:
            bool : Whether the chunk is valid
        """
        if chunk[:4] == "tok-":
            return False
        else:
            return True

    def get_completions_headers(self) -> dict:
        """Gets the headers (depending on the backend) to use for the request

        Returns:
            dict : The headers
        """
        return {"Accept": "application/json",
                "Content-Type": "application/json"}

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text

        Args:
            json_chunk (dict) : The chunk containing the generated text

        Returns:
            str : The newly generated text
        """
        if len(json_chunk['choices']):
            return json_chunk['choices'][0]['delta']["content"]
        else:
            return ""


def get_backend(backend_name: str) -> BackEnd:
    implemented_backends = ["mistral", "happy_vllm"]
    if backend_name not in implemented_backends:
        raise ValueError(f"The specified backend {backend_name} is not implemented. Please use one of the following: {implemented_backends}")
    if backend_name == "happy_vllm":
        return BackendHappyVllm(backend_name, chunk_prefix="data: ", last_chunk="[DONE]", metrics_endpoint_exists=True)
    if backend_name == "mistral":
        return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]", metrics_endpoint_exists=False)
    return BackEnd("not_implemented")
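The chunk-handling flow the class above defines (strip `chunk_prefix`, stop at `last_chunk`, pull text out of `choices`) can be replayed standalone. The sketch below mirrors the happy_vllm chunk shape; it is an illustration of the protocol, not the package's API:

```python
import json

CHUNK_PREFIX = "data: "   # SSE prefix stripped by remove_response_prefix
LAST_CHUNK = "[DONE]"     # sentinel recognised by check_end_of_stream

def parse_stream(lines: list) -> str:
    """Replays the BackEnd chunk flow on raw SSE lines and joins the text."""
    generated = []
    for line in lines:
        chunk = line.removeprefix(CHUNK_PREFIX)
        if chunk == LAST_CHUNK:
            break                      # end of stream
        payload = json.loads(chunk)
        if len(payload["choices"]):    # empty choices means nothing new
            generated.append(payload["choices"][0]["text"])
    return "".join(generated)

stream = [
    'data: {"choices": [{"text": "Hello"}]}',
    'data: {"choices": [{"text": " world"}]}',
    "data: [DONE]",
]
print(parse_stream(stream))  # Hello world
```

The Mistral variant differs only in where the text lives (`choices[0]["delta"]["content"]`) and in discarding chunks starting with `tok-` via `test_chunk_validity`.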
maxDavid40 marked this conversation as resolved.
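`get_metrics_from_metrics_dict` consumes a parsed /metrics output shaped like `{metric_name: [{"value": ...}]}`. A minimal parser turning raw Prometheus exposition text into that shape might look like the following sketch (an assumption here: label sets are simply dropped, which suffices for the three gauges read above):

```python
def parse_metrics(text: str) -> dict:
    """Parses Prometheus exposition text into {name: [{"value": float}]}."""
    metrics: dict = {}
    for line in text.splitlines():
        if not line or line.startswith("#"):
            continue  # skip blanks and HELP/TYPE comment lines
        name_part, _, value = line.rpartition(" ")
        name = name_part.split("{", 1)[0]  # drop any label set
        metrics.setdefault(name, []).append({"value": float(value)})
    return metrics

sample = """# HELP vllm:num_requests_running Number of requests currently running
vllm:num_requests_running 2.0
vllm:num_requests_waiting 0.0
vllm:gpu_cache_usage_perc 0.35"""

print(parse_metrics(sample)["vllm:num_requests_running"][0]["value"])  # 2.0
```

This shape explains the `metrics_endpoint_exists=False` flag for the Mistral backend: with no /metrics endpoint there is nothing to parse, so the method is only required when the flag is true.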
Review comment: Why is this class defined in the first place?

Reply: In order to force the classes which inherit from BackEnd to declare certain methods, and to declare common methods as well.
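The pattern described in this reply (base-class methods that raise until a subclass overrides them) can also be expressed with Python's `abc` module, which rejects incomplete subclasses at instantiation time rather than at call time. A minimal sketch with hypothetical names, not the PR's code:

```python
from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def get_payload(self) -> dict:
        """Subclasses must implement this."""

    def shared(self) -> str:
        # Common behaviour inherited as-is by every subclass
        return "shared behaviour"

class Impl(Base):
    def get_payload(self) -> dict:
        return {"ok": True}

print(Impl().get_payload())  # {'ok': True}
# Base() would raise TypeError: can't instantiate an abstract class
```

A trade-off of the PR's `NotImplementedError` approach is that `get_metrics_from_metrics_dict` can legitimately stay unimplemented when `metrics_endpoint_exists` is false, which `abstractmethod` would not allow.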