Refactor to take into account whether the backend implements the /metrics endpoint
gsolard committed Jul 31, 2024
1 parent a477301 commit a37c375
Showing 7 changed files with 75 additions and 52 deletions.
34 changes: 30 additions & 4 deletions src/benchmark_llm_serving/backends.py
@@ -8,10 +8,11 @@ class BackEnd():
TEMPERATURE = 0
REPETITION_PENALTY = 1.2

def __init__(self, backend_name: str, chunk_prefix: str = "data: ", last_chunk: str = "[DONE]"):
def __init__(self, backend_name: str, chunk_prefix: str = "data: ", last_chunk: str = "[DONE]", metrics_endpoint_exists: bool = True):
self.backend_name = backend_name
self.chunk_prefix = chunk_prefix
self.last_chunk = last_chunk
self.metrics_endpoint_exists = metrics_endpoint_exists

def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
"""Gets the payload to give to the model
@@ -35,6 +36,17 @@ def get_newly_generated_text(self, json_chunk: dict) -> str:
str : The newly generated text
"""
raise NotImplementedError("The subclass should implement this method")

def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
"""Gets the useful metrics from the parsed output of the /metrics endpoint
Args:
metrics_dict (dict) : The parsed output of the /metrics endpoint
Returns:
dict : The useful metrics
"""
raise NotImplementedError("The subclass should implement this method if metrics_endpoint_exists")

def test_chunk_validity(self, chunk: str) -> bool:
"""Tests if the chunk is valid or should not be considered.
@@ -90,7 +102,6 @@ def add_prompt_length(self, json_chunk: dict, output: QueryOutput) -> None:
output.prompt_length = json_chunk['usage']['prompt_tokens']



class BackendHappyVllm(BackEnd):

def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
@@ -127,6 +138,21 @@ def get_newly_generated_text(self, json_chunk: dict) -> str:
return data
else:
return ""

def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
"""Gets the useful metrics from the parsed output of the /metrics endpoint
Args:
metrics_dict (dict) : The parsed output of the /metrics endpoint
Returns:
dict : The useful metrics
"""
metrics = {}
metrics['num_requests_running'] = metrics_dict['vllm:num_requests_running'][0]['value']
metrics['num_requests_waiting'] = metrics_dict['vllm:num_requests_waiting'][0]['value']
metrics['gpu_cache_usage_perc'] = metrics_dict['vllm:gpu_cache_usage_perc'][0]['value']
return metrics


class BackEndMistral(BackEnd):
@@ -194,7 +220,7 @@ def get_backend(backend_name: str) -> BackEnd:
if backend_name not in implemented_backends:
raise ValueError(f"The specified backend {backend_name} is not implemented. Please use one of the following : {implemented_backends}")
if backend_name == "happy_vllm":
return BackendHappyVllm(backend_name, chunk_prefix="data: ", last_chunk="[DONE]")
return BackendHappyVllm(backend_name, chunk_prefix="data: ", last_chunk="[DONE]", metrics_endpoint_exists=True)
if backend_name == "mistral":
return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]")
return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]", metrics_endpoint_exists=False)
return BackEnd("not_implemented")
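
For illustration only (not part of this commit): a minimal sketch of how a hypothetical extra backend could plug into this refactor. A backend whose server exposes /metrics passes metrics_endpoint_exists=True and overrides get_metrics_from_metrics_dict; one that does not passes False, and get_live_metrics (see utils_metrics.py below) then skips the endpoint entirely. The class name and the "engine:" metric keys are invented for the example.

from benchmark_llm_serving.backends import BackEnd


class BackEndExample(BackEnd):
    """Hypothetical backend, used only to illustrate the new hook."""

    def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
        # Only called by get_live_metrics when metrics_endpoint_exists is True;
        # the metric names below are placeholders, not real server metrics
        return {
            "num_requests_running": metrics_dict["engine:num_requests_running"][0]["value"],
            "num_requests_waiting": metrics_dict["engine:num_requests_waiting"][0]["value"],
        }


# Backend with a /metrics endpoint
with_metrics = BackEndExample("example", chunk_prefix="data: ", last_chunk="[DONE]",
                              metrics_endpoint_exists=True)
# Backend without one: no override needed, the flag alone is enough
without_metrics = BackEnd("example_no_metrics", metrics_endpoint_exists=False)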
10 changes: 5 additions & 5 deletions src/benchmark_llm_serving/query_profiles/constant_number.py
@@ -102,7 +102,7 @@ async def get_benchmark_results_constant_number(queries_dataset: List[QueryInput
for _ in range(args.n_workers)]
# Query the /metrics endpoint for one second before adding queries to the queue
for i in range(int(1/args.step_live_metrics)):
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)
start_queries_timestamp = datetime.now().timestamp()
# Add the queries to the queue
@@ -113,22 +113,22 @@
if continue_condition(current_timestamp, start_queries_timestamp, args, count_query):
# While the queue is full, periodically query the /metrics endpoint
while queue.full():
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await queue.put(query_input)
count_query += 1
if current_timestamp - start_queries_timestamp >= args.max_duration:
now = get_now()
logger.info(f"{now} Max duration {args.max_duration}s has been reached")
# Wait for all enqueued items to be processed and during this time, periodically query the /metrics endpoint
while not queue.empty():
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)
await queue.join()
# Query the /metrics endpoint for one second after the queries finished
for i in range(int(1/args.step_live_metrics)):
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)

# The workers are now idly waiting for the next queue item and we
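
The same polling pattern recurs in each query profile touched by this commit (constant_number.py above, growing_requests.py and scheduled_requests.py below): fire get_live_metrics roughly every args.step_live_metrics seconds, now passing the backend object instead of args. A hypothetical helper, not part of the repository, condensing that pattern for readability:

import asyncio

from benchmark_llm_serving.utils_metrics import get_live_metrics


async def poll_metrics(duration: float, step: float, session, metrics_url, all_live_metrics, backend) -> None:
    """Queries the /metrics endpoint roughly every `step` seconds for `duration` seconds."""
    for _ in range(int(duration / step)):
        # get_live_metrics checks backend.metrics_endpoint_exists itself and
        # does nothing when the endpoint is not implemented
        asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
        await asyncio.sleep(step)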
6 changes: 3 additions & 3 deletions src/benchmark_llm_serving/query_profiles/growing_requests.py
@@ -64,7 +64,7 @@ async def get_benchmark_results_growing_requests(queries_dataset: List[QueryInpu
async with aiohttp.ClientSession(connector=connector) as session:
# Query the /metrics endpoint for one second before launching the first queries
for i in range(int(1/args.step_live_metrics)):
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)
start_queries_timestamp = datetime.now().timestamp()
# For a number of queries
@@ -80,14 +80,14 @@ async def get_benchmark_results_growing_requests(queries_dataset: List[QueryInpu
nb_queries_launched += n
# While we wait for the tasks to be done, we query the /metrics endpoint
while not tasks_are_done(tasks):
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)
if current_timestamp - start_queries_timestamp >= args.max_duration:
now = get_now()
logger.info(f"{now} Max duration {args.max_duration}s has been reached")
# Query the /metrics endpoint for one second after the launched queries are done
for i in range(int(1/args.step_live_metrics)):
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)

return results, all_live_metrics
8 changes: 4 additions & 4 deletions src/benchmark_llm_serving/query_profiles/scheduled_requests.py
@@ -138,7 +138,7 @@ async def get_benchmark_results_scheduled_requests(queries_dataset: List[QueryIn
async with aiohttp.ClientSession(connector=connector) as session:
# Query the /metrics endpoint for one second before launching the first queries
for i in range(int(1/args.step_live_metrics)):
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)
start_queries_timestamp = datetime.now().timestamp()
# Add the initial timestamp to the queries
@@ -164,17 +164,17 @@ async def get_benchmark_results_scheduled_requests(queries_dataset: List[QueryIn
logger.info(f"{now} {current_query_index_to_launch} requests in total have been launched")
tasks += [asyncio.create_task(query_function(query_input, session, completions_url, results, args, backend)) for query_input in queries_to_launch]

asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)

# Once all queries have been sent, we still query the /metrics endpoint
# Until all the queries are done
while not tasks_are_done(tasks):
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)

# Query the /metrics endpoint for one second after the launched queries are done
for i in range(int(1/args.step_live_metrics)):
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args))
asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, backend))
await asyncio.sleep(args.step_live_metrics)
return results, all_live_metrics
27 changes: 5 additions & 22 deletions src/benchmark_llm_serving/utils_metrics.py
@@ -26,41 +26,24 @@ def parse_metrics_response(response_text: str) -> dict:
return results


def get_live_vllm_metrics(response_text: str) -> dict:
"""From the text of the response of the /metrics endpoint, gets the metrics relevant to us
Args:
response_text (str) : The string obtained from querying the /metrics endpoint
Returns:
dict : A dictionary containing the percentage of KV cache used, the number of running requests
and the number of waiting requests
"""
vllm_metrics = {}
parsed_response = parse_metrics_response(response_text)
vllm_metrics['num_requests_running'] = parsed_response['vllm:num_requests_running'][0]['value']
vllm_metrics['num_requests_waiting'] = parsed_response['vllm:num_requests_waiting'][0]['value']
vllm_metrics['gpu_cache_usage_perc'] = parsed_response['vllm:gpu_cache_usage_perc'][0]['value']
return vllm_metrics


async def get_live_metrics(session: aiohttp.ClientSession, metrics_url: str, all_live_metrics: List[dict],
args: argparse.Namespace) -> None:
backend) -> None:
"""Queries the /metrics endpoint, gets the live metrics and add them to the list all_live_metrics
Args:
session (aiohttp.ClientSession) : The aiohttp session
metrics_url (str) : The url to the /metrics endpoint
all_live_metrics (list) : The list to which we add the live metrics results
args (argparse.Namespace) : The CLI args
backend (BackEnd) : The backend
"""
tmp_list = []
if args.backend == "happy_vllm":
if backend.metrics_endpoint_exists:
async with session.get(url=metrics_url) as response:
if response.status == 200:
async for chunk_bytes in response.content:
tmp_list.append(chunk_bytes.decode('utf-8'))
live_metrics = get_live_vllm_metrics("".join(tmp_list))
parsed_metrics = parse_metrics_response("".join(tmp_list))
live_metrics = backend.get_metrics_from_metrics_dict(parsed_metrics)
timestamp = datetime.now().timestamp()
live_metrics['timestamp'] = timestamp
all_live_metrics.append(live_metrics.copy())
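
To make the new flow in get_live_metrics concrete, a rough sketch of the round trip is shown here; the raw metric lines and values mirror the test fixture used below, and the exact shape returned by parse_metrics_response is an assumption consistent with how it is indexed above.

from benchmark_llm_serving.backends import get_backend
from benchmark_llm_serving.utils_metrics import parse_metrics_response

# Prometheus-style text as served by the /metrics endpoint (labels shortened for the example)
raw = (
    'vllm:num_requests_running{model_name="my-model"} 2.0\n'
    'vllm:num_requests_waiting{model_name="my-model"} 1.0\n'
    'vllm:gpu_cache_usage_perc{model_name="my-model"} 4.2\n'
)

parsed = parse_metrics_response(raw)
# Assumed shape: {'vllm:num_requests_running': [{'labels': {...}, 'value': 2.0}], ...}

backend = get_backend("happy_vllm")  # metrics_endpoint_exists is True for this backend
live = backend.get_metrics_from_metrics_dict(parsed)
# Expected, per the tests below:
# {'num_requests_running': 2.0, 'num_requests_waiting': 1.0, 'gpu_cache_usage_perc': 4.2}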
23 changes: 23 additions & 0 deletions tests/test_backends.py
@@ -1,10 +1,21 @@
import os
import pytest
import argparse
from pathlib import Path

from benchmark_llm_serving import backends
from benchmark_llm_serving import utils_metrics
from benchmark_llm_serving.io_classes import QueryOutput, QueryInput


def get_metrics_response():
current_directory = Path(os.path.dirname(os.path.realpath(__file__)))
metrics_response_file = current_directory / "data" / "metrics_response.txt"
with open(metrics_response_file, 'r') as txt_file:
metrics_response = txt_file.read()
return metrics_response


def test_get_backend():
# happy_vllm backend
backend = backends.get_backend("happy_vllm")
@@ -112,6 +123,18 @@ def test_backend_happy_vllm_add_prompt_length():
assert output.prompt_length == 0


def test_backend_happy_vllm_get_metrics_from_metrics_dict():
backend = backends.get_backend("happy_vllm")
metrics_response = get_metrics_response()
parsed_metrics = utils_metrics.parse_metrics_response(metrics_response)
live_metrics = backend.get_metrics_from_metrics_dict(parsed_metrics)
assert isinstance(live_metrics, dict)
assert set(live_metrics) == {"num_requests_running", "num_requests_waiting", "gpu_cache_usage_perc"}
assert live_metrics['num_requests_running'] == pytest.approx(2.0)
assert live_metrics['num_requests_waiting'] == pytest.approx(1.0)
assert live_metrics['gpu_cache_usage_perc'] == pytest.approx(4.2)


def test_backend_mistral_get_payload():
backend = backends.get_backend("mistral")
prompts_list = ["Hey. How are you?", "Fine, you ?"]
19 changes: 5 additions & 14 deletions tests/test_utils_metrics.py
@@ -7,6 +7,7 @@
from aioresponses import aioresponses

from benchmark_llm_serving import utils_metrics
from benchmark_llm_serving.backends import get_backend


def get_metrics_response():
@@ -28,31 +29,21 @@ def test_parse_metrics_response():
assert 'value' in parsed_metrics_response[metric][0]


def test_get_live_vllm_metrics():
metrics_response = get_metrics_response()
live_vllm_metrics = utils_metrics.get_live_vllm_metrics(metrics_response)
assert isinstance(live_vllm_metrics, dict)
assert set(live_vllm_metrics) == {"num_requests_running", "num_requests_waiting", "gpu_cache_usage_perc"}
assert live_vllm_metrics['num_requests_running'] == pytest.approx(2.0)
assert live_vllm_metrics['num_requests_waiting'] == pytest.approx(1.0)
assert live_vllm_metrics['gpu_cache_usage_perc'] == pytest.approx(4.2)


@pytest.mark.asyncio()
async def test_get_live_metrics():
# backend happy_vllm
backend = get_backend("happy_vllm")
metrics_response = get_metrics_response()
all_live_metrics = []
nb_query = 10
args = argparse.Namespace(backend="happy_vllm")
with aioresponses() as mocked:
for i in range(nb_query):
new_metrics_response = metrics_response.replace('vllm:num_requests_running{model_name="/home/data/models/Meta-Llama-3-8B-Instruct"} 2.0',
f'vllm:num_requests_running{{model_name="/home/data/models/Meta-Llama-3-8B-Instruct"}} {i}.0')
mocked.get('my_url', status=200, body=new_metrics_response)
session = aiohttp.ClientSession()
for i in range(nb_query):
results = await utils_metrics.get_live_metrics(session, 'my_url', all_live_metrics, args)
results = await utils_metrics.get_live_metrics(session, 'my_url', all_live_metrics, backend)
await session.close()
assert len(all_live_metrics) == nb_query
for i in range(nb_query):
@@ -64,17 +55,17 @@ async def test_get_live_metrics():
assert all_live_metrics[i]['timestamp'] != all_live_metrics[i-1]['timestamp']

# backend mistral
backend = get_backend("mistral")
metrics_response = get_metrics_response()
all_live_metrics = []
nb_query = 10
args = argparse.Namespace(backend="mistral")
with aioresponses() as mocked:
for i in range(nb_query):
new_metrics_response = metrics_response.replace('vllm:num_requests_running{model_name="/home/data/models/Meta-Llama-3-8B-Instruct"} 2.0',
f'vllm:num_requests_running{{model_name="/home/data/models/Meta-Llama-3-8B-Instruct"}} {i}.0')
mocked.get('my_url', status=200, body=new_metrics_response)
session = aiohttp.ClientSession()
for i in range(nb_query):
results = await utils_metrics.get_live_metrics(session, 'my_url', all_live_metrics, args)
results = await utils_metrics.get_live_metrics(session, 'my_url', all_live_metrics, backend)
await session.close()
assert len(all_live_metrics) == 0
