Add support for OpenAI API: offline batch (file) processing #699

Merged: 11 commits, Jul 29, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -4,6 +4,6 @@ repos:
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: stable
    rev: 24.4.2
    hooks:
      - id: black
86 changes: 86 additions & 0 deletions examples/usage/openai_batch_chat.py
@@ -0,0 +1,86 @@
import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
    def __init__(self, api_key):
        # client = OpenAI(api_key=api_key)
        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

        self.client = client

    def process_batch(self, input_file_path, endpoint, completion_window):

        # Upload the input file
        with open(input_file_path, "rb") as file:
            uploaded_file = self.client.files.create(file=file, purpose="batch")

        # Create the batch job
        batch_job = self.client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # Monitor the batch job status
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)  # Wait for 3 seconds before checking the status again
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = self.client.batches.retrieve(batch_job.id)

        # Check the batch job status and errors
        if batch_job.status == "failed":
            print(f"Batch job failed with status: {batch_job.status}")
            print(f"Batch job errors: {batch_job.errors}")
            return None

        # If the batch job is completed, process the results
        if batch_job.status == "completed":

            # print result of batch job
            print("batch", batch_job.request_counts)

            result_file_id = batch_job.output_file_id
            # Retrieve the file content from the server
            file_response = self.client.files.content(result_file_id)
            result_content = file_response.read()  # Read the content of the file

            # Save the content to a local file
            result_file_name = "batch_job_chat_results.jsonl"
            with open(result_file_name, "wb") as file:
                file.write(result_content)  # Write the binary content to the file
            # Load data from the saved JSONL file
            results = []
            with open(result_file_name, "r", encoding="utf-8") as file:
                for line in file:
                    json_object = json.loads(
                        line.strip()
                    )  # Parse each line as a JSON object
                    results.append(json_object)

            return results
        else:
            print(f"Batch job failed with status: {batch_job.status}")
            return None


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)

# Process the batch job
input_file_path = "input.jsonl"
endpoint = "/v1/chat/completions"
completion_window = "24h"

# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)

# Print the results
print(results)
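
For reference, input.jsonl is expected to contain one request per line in the OpenAI Batch API format (a custom_id, an HTTP method, the target url, and the request body). A minimal sketch with illustrative values; the custom_id values, messages, and max_tokens below are placeholders, not part of this PR:

{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "default", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "List 3 countries and their capitals."}], "max_tokens": 64}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "default", "messages": [{"role": "user", "content": "Tell me a short story about a robot."}], "max_tokens": 64}}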
86 changes: 86 additions & 0 deletions examples/usage/openai_batch_complete.py
@@ -0,0 +1,86 @@
import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
    def __init__(self, api_key):
        # client = OpenAI(api_key=api_key)
        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

        self.client = client

    def process_batch(self, input_file_path, endpoint, completion_window):

        # Upload the input file
        with open(input_file_path, "rb") as file:
            uploaded_file = self.client.files.create(file=file, purpose="batch")

        # Create the batch job
        batch_job = self.client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # Monitor the batch job status
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)  # Wait for 3 seconds before checking the status again
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = self.client.batches.retrieve(batch_job.id)

        # Check the batch job status and errors
        if batch_job.status == "failed":
            print(f"Batch job failed with status: {batch_job.status}")
            print(f"Batch job errors: {batch_job.errors}")
            return None

        # If the batch job is completed, process the results
        if batch_job.status == "completed":

            # print result of batch job
            print("batch", batch_job.request_counts)

            result_file_id = batch_job.output_file_id
            # Retrieve the file content from the server
            file_response = self.client.files.content(result_file_id)
            result_content = file_response.read()  # Read the content of the file

            # Save the content to a local file
            result_file_name = "batch_job_complete_results.jsonl"
            with open(result_file_name, "wb") as file:
                file.write(result_content)  # Write the binary content to the file
            # Load data from the saved JSONL file
            results = []
            with open(result_file_name, "r", encoding="utf-8") as file:
                for line in file:
                    json_object = json.loads(
                        line.strip()
                    )  # Parse each line as a JSON object
                    results.append(json_object)

            return results
        else:
            print(f"Batch job failed with status: {batch_job.status}")
            return None


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)

# Process the batch job
input_file_path = "input_complete.jsonl"
endpoint = "/v1/completions"
completion_window = "24h"

# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)

# Print the results
print(results)
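
Likewise, input_complete.jsonl is expected to hold one Batch API request per line, with the body carrying /v1/completions parameters such as prompt and max_tokens. A minimal sketch with illustrative values:

{"custom_id": "request-1", "method": "POST", "url": "/v1/completions", "body": {"model": "default", "prompt": "Once upon a time, there was a little", "max_tokens": 32}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/completions", "body": {"model": "default", "prompt": "The name of the famous soccer player is", "max_tokens": 32}}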
37 changes: 37 additions & 0 deletions examples/usage/openai_parallel_sample.py
@@ -13,6 +13,17 @@
print(response)


# Text completion
response = client.completions.create(
    model="default",
    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
    n=1,
    temperature=0.8,
    max_tokens=32,
)
print(response)


# Text completion
response = client.completions.create(
    model="default",
@@ -24,6 +35,17 @@
print(response)


# Text completion
response = client.completions.create(
    model="default",
    prompt=["The name of the famous soccer player is"],
    n=1,
    temperature=0.8,
    max_tokens=128,
)
print(response)


# Text completion
response = client.completions.create(
    model="default",
@@ -60,6 +82,21 @@
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=64,
    logprobs=True,
    n=1,
)
print(response)


# Chat completion
response = client.chat.completions.create(
    model="default",
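
The collapsed portion of this file contains further requests. For context, a minimal sketch of a parallel-sampling request (n > 1), the case targeted by the parallel_sample_num handling below; the prompt and parameter values are illustrative:

# Chat completion with parallel sampling (n > 1); illustrative values
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "Write a one-line poem about the sea."},
    ],
    temperature=0.8,
    max_tokens=32,
    n=2,
)
print(response)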
22 changes: 20 additions & 2 deletions python/sglang/srt/managers/io_struct.py
@@ -79,8 +79,26 @@ def post_init(self):
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
else:

parallel_sample_num = self.sampling_params.get("n", 1)
parallel_sample_num_list = []
if isinstance(self.sampling_params, dict):
parallel_sample_num = self.sampling_params.get("n", 1)
elif isinstance(self.sampling_params, list):
for sp in self.sampling_params:
parallel_sample_num = sp.get("n", 1)
parallel_sample_num_list.append(parallel_sample_num)
parallel_sample_num = max(parallel_sample_num_list)
all_equal = all(
element == parallel_sample_num
for element in parallel_sample_num_list
)
if parallel_sample_num > 1 and (not all_equal):
## TODO cope with the case that the parallel_sample_num is different for different samples
raise ValueError(
"The parallel_sample_num should be the same for all samples in sample params."
)
else:
parallel_sample_num = 1
self.parallel_sample_num = parallel_sample_num

if parallel_sample_num != 1:
# parallel sampling +1 represents the original prefill stage
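
In short, post_init now accepts sampling_params either as a single dict or as a per-request list, takes the maximum n across the list, and rejects mixed values. A rough sketch of the intended behavior with hypothetical inputs:

# sampling_params = {"n": 2}              -> parallel_sample_num = 2
# sampling_params = [{"n": 2}, {"n": 2}]  -> parallel_sample_num = 2
# sampling_params = [{"n": 1}, {"n": 1}]  -> parallel_sample_num = 1
# sampling_params = [{"n": 2}, {"n": 3}]  -> ValueError (mixed n values are not supported yet)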
24 changes: 13 additions & 11 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -84,6 +84,7 @@ def __init__(
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)

if server_args.context_length is not None:
self.context_len = server_args.context_length
else:
@@ -152,31 +153,33 @@ async def _handle_single_request(
self, obj, request, index=None, is_cache_for_prefill=False
):
if not is_cache_for_prefill:
rid = obj.rid if index is None else obj.rid[index]
input_text = obj.text if index is None else obj.text[index]
not_use_index = not (index is not None)
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
input_ids = (
self.tokenizer.encode(input_text)
if obj.input_ids is None
else obj.input_ids
)
if index is not None and obj.input_ids:
if not not_use_index and obj.input_ids:
input_ids = obj.input_ids[index]

self._validate_input_length(input_ids)

sampling_params = self._get_sampling_params(
obj.sampling_params if index is None else obj.sampling_params[index]
obj.sampling_params if not_use_index else obj.sampling_params[index]
)
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if index is None else obj.image_data[index]
obj.image_data if not_use_index else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if index is None else obj.return_logprob[index]
obj.return_logprob if not_use_index else obj.return_logprob[index]
)
logprob_start_len = (
obj.logprob_start_len if index is None else obj.logprob_start_len[index]
obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
)
top_logprobs_num = (
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
)
else:
if isinstance(obj.text, list):
@@ -224,7 +227,7 @@ async def _handle_single_request(

async def _handle_batch_request(self, obj: GenerateReqInput, request):
batch_size = obj.batch_size
parallel_sample_num = obj.sampling_params[0].get("n", 1)
parallel_sample_num = obj.parallel_sample_num

if parallel_sample_num != 1:
# Send prefill requests to cache the common input
@@ -241,15 +244,14 @@ async def _handle_batch_request(self, obj: GenerateReqInput, request):
obj.input_ids = input_id_result
elif input_id_result is not None:
obj.input_ids = input_id_result[0]

# First send out all requests
for i in range(batch_size):
for j in range(parallel_sample_num):
if j == 0 and parallel_sample_num != 1:
continue
index = i * parallel_sample_num + j
if parallel_sample_num != 1:
# Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
# Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
index += batch_size - 1 - i
rid = obj.rid[index]
if parallel_sample_num == 1:
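
To illustrate the index arithmetic above: the rid list appears to reserve the first batch_size entries for the prefill (cache) requests, and each sample maps to index j + i * (parallel_sample_num - 1) + batch_size - 1. A worked example, assuming batch_size = 2 and parallel_sample_num = 3:

# rid layout for batch_size = 2, parallel_sample_num = 3 (6 rids total)
# rid[0], rid[1]                  -> prefill requests for prompts 0 and 1
# i = 0: j = 1 -> index 2, j = 2 -> index 3   (samples of prompt 0)
# i = 1: j = 1 -> index 4, j = 2 -> index 5   (samples of prompt 1)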