Skip to content

Commit

Permalink
[Misc] Update ShareGPT Dataset Sampling in Serving Benchmark (vllm-pr…
Browse files Browse the repository at this point in the history
  • Loading branch information
ywang96 authored and jimpang committed Apr 25, 2024
1 parent c5cea15 commit 07c1d0a
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
from typing import AsyncGenerator, List, Optional, Tuple

import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
Expand Down Expand Up @@ -58,7 +58,11 @@ def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
Expand All @@ -68,38 +72,32 @@ def sample_sharegpt_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# some of these will be filtered out, so sample more than we need
sampled_indices = random.sample(range(len(dataset)),
int(num_requests * 1.2))
dataset = [dataset[i] for i in sampled_indices]

# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Shuffle the dataset.
random.shuffle(dataset)

# Filter out too long sequences.
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))

# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
return filtered_dataset


def sample_sonnet_requests(
Expand Down Expand Up @@ -361,13 +359,15 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
)

elif args.dataset_name == "sharegpt":
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
)

elif args.dataset_name == "sonnet":
Expand Down Expand Up @@ -524,6 +524,12 @@ def main(args: argparse.Namespace):
default=1000,
help="Number of prompts to process.",
)
parser.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
parser.add_argument(
"--sonnet-input-len",
type=int,
Expand Down

0 comments on commit 07c1d0a

Please sign in to comment.