Skip to content

Commit

Permalink
[Benchmark] Refactor sample_requests in benchmark_throughput (vllm-pr…
Browse files Browse the repository at this point in the history
…oject#3613)

Co-authored-by: Roger Wang <ywang@roblox.com>
  • Loading branch information
gty111 and ywang96 authored Apr 4, 2024
1 parent 819a309 commit b778200
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,23 @@ def sample_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
if fixed_output_len is not None:
output_len = fixed_output_len
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Shuffle the dataset.
random.shuffle(dataset)

# Filter out too long sequences.
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
Expand All @@ -53,9 +54,7 @@ def sample_requests(
continue
filtered_dataset.append((prompt, prompt_len, output_len))

# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
return filtered_dataset


def run_vllm(
Expand Down

0 comments on commit b778200

Please sign in to comment.