From d5e549d50d7eb117be6d011ad8920de47ba054d0 Mon Sep 17 00:00:00 2001
From: yichuan520030910320
Date: Mon, 22 Jul 2024 10:51:56 +0000
Subject: [PATCH 1/8] finish the functional version of the OpenAI batch and files API

---
 examples/usage/openai_batch.py                |  98 +++++
 .../sglang/srt/managers/tokenizer_manager.py  |  28 +-
 python/sglang/srt/openai_api/adapter.py       | 353 ++++++++++++++++--
 python/sglang/srt/openai_api/protocol.py      |  46 +++
 python/sglang/srt/server.py                   |  32 ++
 5 files changed, 514 insertions(+), 43 deletions(-)
 create mode 100644 examples/usage/openai_batch.py

diff --git a/examples/usage/openai_batch.py b/examples/usage/openai_batch.py
new file mode 100644
index 0000000000..fb0c9c343f
--- /dev/null
+++ b/examples/usage/openai_batch.py
@@ -0,0 +1,98 @@
+from openai import OpenAI
+import openai
+import time
+import json
+import os
+
+class OpenAIBatchProcessor:
+    def __init__(self, api_key):
+        # client = OpenAI(api_key=api_key)
+        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
+
+        self.client = client
+
+    def process_batch(self, input_file_path, endpoint, completion_window):
+
+        ## # Chat completion
+        response = self.client.chat.completions.create(
+            model="default",
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": "List 3 countries and their capitals and culture."},
+            ],
+            temperature=0.8,
+            max_tokens=64,
+        )
+        print(response)
+
+
+        # Upload the input file
+        with open(input_file_path, "rb") as file:
+            uploaded_file = self.client.files.create(
+                file=file,
+                purpose="batch"
+            )
+        print('file response:', uploaded_file)
+        print('file id:', uploaded_file.id)
+
+        # Create the batch job
+        batch_job = self.client.batches.create(
+            input_file_id=uploaded_file.id,
+            endpoint=endpoint,
+            completion_window=completion_window
+        )
+
+        print('batch job:', batch_job)
+
+        # Monitor the batch job status
+        while batch_job.status not in ["completed", "failed", "cancelled"]:
+            time.sleep(3)  # Wait for 3 seconds before checking the status again
+            print(f"Batch job status: {batch_job.status}...trying again in 3 seconds...")
+            batch_job = self.client.batches.retrieve(batch_job.id)
+
+
+
+
+        # If the batch job is completed, process the results
+        if batch_job.status == "completed":
+
+            # print result of batch job
+            print('batch', batch_job.request_counts)
+
+            result_file_id = batch_job.output_file_id
+            # Retrieve the file content from the server
+            file_response = self.client.files.content(result_file_id)
+            result_content = file_response.read()  # Read the content of the file
+
+            # Save the content to a local file
+            result_file_name = "/home/ubuntu/my_sglang_dev/sglang/examples/usage/batch_job_results.jsonl"
+            with open(result_file_name, "wb") as file:
+                file.write(result_content)  # Write the binary content to the file
+            print('read result:', result_content)
+            # Load data from the saved JSONL file
+            results = []
+            with open(result_file_name, "r", encoding="utf-8") as file:
+                for line in file:
+                    json_object = json.loads(line.strip())  # Parse each line as a JSON object
+                    results.append(json_object)
+
+            return results
+        else:
+            print(f"Batch job failed with status: {batch_job.status}")
+            return None
+
+# Initialize the OpenAIBatchProcessor
+api_key = os.environ.get("OPENAI_API_KEY")
+processor = OpenAIBatchProcessor(api_key)
+
+# Process the batch job
+input_file_path = "/home/ubuntu/playsglang/input.jsonl"
+endpoint = "/v1/chat/completions"
+completion_window = "24h"
+
+# Process the batch job
+results = 
processor.process_batch(input_file_path, endpoint, completion_window) + +# Print the results +print(results) + diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index f6cc8677c3..a5d0c0390e 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -55,6 +55,8 @@ def __init__( model_overide_args: dict = None, ): self.server_args = server_args + print('tokenizer port:', port_args.tokenizer_port) + print('controller port:', port_args.controller_port) context = zmq.asyncio.Context(2) self.recv_from_detokenizer = context.socket(zmq.PULL) @@ -122,6 +124,8 @@ async def generate_request(self, obj: GenerateReqInput, request=None): obj.post_init() is_single = obj.is_single + print('is_single:', is_single) + print('obj:', obj) if is_single: async for response in self._handle_single_request(obj, request): @@ -151,33 +155,36 @@ async def _handle_single_request(self, obj, request, index=None, is_prefill=Fals logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] else: - rid = obj.rid if index is None else obj.rid[index] - input_text = obj.text if index is None else obj.text[index] + use_index = index is not None + rid = obj.rid if not use_index else obj.rid[index] + input_text = obj.text if not use_index else obj.text[index] input_ids = ( self.tokenizer.encode(input_text) if obj.input_ids is None else obj.input_ids ) - if index is not None and obj.input_ids: + if use_index and obj.input_ids: input_ids = obj.input_ids[index] self._validate_input_length(input_ids) + sampling_params = self._get_sampling_params( - obj.sampling_params if index is None else obj.sampling_params[index] + obj.sampling_params if not use_index else obj.sampling_params[index] ) pixel_values, image_hash, image_size = await self._get_pixel_values( - obj.image_data if index is None else obj.image_data[index] + obj.image_data if not use_index else obj.image_data[index] ) return_logprob = ( - obj.return_logprob if index is None else obj.return_logprob[index] + obj.return_logprob if not use_index else obj.return_logprob[index] ) logprob_start_len = ( - obj.logprob_start_len if index is None else obj.logprob_start_len[index] + obj.logprob_start_len if not use_index else obj.logprob_start_len[index] ) top_logprobs_num = ( - obj.top_logprobs_num if index is None else obj.top_logprobs_num[index] + obj.top_logprobs_num if not use_index else obj.top_logprobs_num[index] ) + tokenized_obj = TokenizedGenerateReqInput( rid, input_text, @@ -191,6 +198,7 @@ async def _handle_single_request(self, obj, request, index=None, is_prefill=Fals top_logprobs_num, obj.stream, ) + print('tokenized_obj:', tokenized_obj) self.send_to_router.send_pyobj(tokenized_obj) event = asyncio.Event() @@ -231,7 +239,7 @@ async def _handle_batch_request(self, obj, request): continue index = i * parallel_sample_num + j if parallel_sample_num != 1: - # Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1 + # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1 index += batch_size - 1 - i rid = obj.rid[index] if parallel_sample_num == 1: @@ -335,6 +343,7 @@ async def _get_pixel_values(self, image_data): async def _wait_for_response(self, event, state, obj, rid, request): while True: try: + print('wait for response:') await asyncio.wait_for(event.wait(), timeout=4) 
except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): @@ -348,6 +357,7 @@ async def _wait_for_response(self, event, state, obj, rid, request): obj.top_logprobs_num, obj.return_text_in_logprobs, ) + print('out in wait for response:', out) if self.server_args.log_requests and state.finished: logger.info(f"in={obj.text}, out={out}") diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index ebb95ea241..71f2e0a154 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -3,9 +3,11 @@ import asyncio import json import os +import uuid +import time from http import HTTPStatus -from fastapi import Request +from fastapi import Request, UploadFile, HTTPException from fastapi.responses import JSONResponse, StreamingResponse from sglang.srt.conversation import ( @@ -32,10 +34,30 @@ ErrorResponse, LogProbs, UsageInfo, + FileRequest, + FileResponse, + BatchRequest, + BatchResponse, ) +from pydantic import ValidationError +from typing import Optional, Dict + chat_template_name = None +# In-memory storage for batch jobs and files +batch_storage: Dict[str, BatchResponse] = {} +file_id_request: Dict[str, FileRequest] = {} +file_id_response: Dict[str, FileResponse] = {} +## map file id to file path in SGlang backend +file_id_storage: Dict[str, str] = {} + + +# backend storage directory +storage_dir = "/home/ubuntu/my_sglang_dev/sglang/python/sglang/srt/openai_api/sglang_oai_storage" + + + def create_error_response( message: str, @@ -90,9 +112,204 @@ def load_chat_template_for_openai_api(chat_template_arg): else: chat_template_name = chat_template_arg +async def v1_files_create(file: UploadFile, purpose: str): + try: + # Read the file content + file_content = await file.read() + + # Create an instance of RequestBody + request_body = FileRequest(file=file_content, purpose=purpose) + + # Save the file to the sglang_oai_storage directory + os.makedirs(storage_dir, exist_ok=True) + file_id = f"backend_input_file-{uuid.uuid4()}" + filename = f"{file_id}.jsonl" + file_path = os.path.join(storage_dir, filename) + + print('file id in creat:', file_id) + + with open(file_path, "wb") as f: + f.write(request_body.file) + + # add info to global file map + file_id_request[file_id] = request_body + file_id_storage[file_id] = file_path + + # Return the response in the required format + response = FileResponse( + id=file_id, + bytes=len(request_body.file), + created_at=int(time.time()), + filename=file.filename, + purpose=request_body.purpose + ) + file_id_response[file_id] = response + + return response + except ValidationError as e: + return {"error": "Invalid input", "details": e.errors()} + + +async def v1_batches(tokenizer_manager, raw_request: Request): + try: + # Parse the JSON body + body = await raw_request.json() + + # Create an instance of BatchRequest + batch_request = BatchRequest(**body) + + # Generate a unique batch ID + batch_id = f"batch_{uuid.uuid4()}" + + # Create an instance of BatchResponse + batch_response = BatchResponse( + id=batch_id, + endpoint=batch_request.endpoint, + input_file_id=batch_request.input_file_id, + completion_window=batch_request.completion_window, + created_at=int(time.time()), + metadata=batch_request.metadata + ) + + batch_storage[batch_id] = batch_response + + # Start processing the batch asynchronously + asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request)) + + # Return the initial batch_response + return batch_response + + + except 
ValidationError as e: + return {"error": "Invalid input", "details": e.errors()} + except Exception as e: + return {"error": str(e)} + +async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest): + try: + print('batch_id in process_batch in SGlang backend:', batch_id) + # Update the batch status to "in_progress" + batch_storage[batch_id].status = "in_progress" + batch_storage[batch_id].in_progress_at = int(time.time()) + + # Retrieve the input file content + input_file_request = file_id_request.get(batch_request.input_file_id) + if not input_file_request: + raise ValueError("Input file not found") + + # Parse the JSONL file and process each request + input_file_path = file_id_storage.get(batch_request.input_file_id) + with open(input_file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + total_requests = len(lines) + completed_requests = 0 + failed_requests = 0 + + all_ret = [] + for line in lines: + request_data = json.loads(line) + if batch_storage[batch_id].endpoint == "/v1/chat/completions": + adapted_request, request = v1_chat_generate_request(request_data, tokenizer_manager,from_file=True) + elif batch_storage[batch_id].endpoint == "/v1/completions": + pass + + + try: + + print('adapted_request in SGlang in batch:', adapted_request) + print('request_data in SGlang in batch:', request_data) + + ret = await tokenizer_manager.generate_request(adapted_request).__anext__() + print('ret in SGlang:', ret) + + + if not isinstance(ret, list): + ret = [ret] + + response = v1_chat_generate_response(request, ret, to_file=True) + response_json = { + "id": f"batch_req_{uuid.uuid4()}", + "custom_id": request_data.get("custom_id"), + "response": response, + "error": None + } + all_ret.append(response_json) + + completed_requests += 1 + print('success in SGlang:', ret) + except Exception as e: + error_json = { + "id": f"batch_req_{uuid.uuid4()}", + "custom_id": request_data.get("custom_id"), + "response": None, + "error": {"message": str(e)} + } + all_ret.append(error_json) + failed_requests += 1 + continue + print('all_ret in SGlang:', all_ret) + + + # Write results to a new file + output_file_id = f"backend_result_file-{uuid.uuid4()}" + output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl") + print('output file id in SGlang:', output_file_id) + with open(output_file_path, "w", encoding="utf-8") as f: + for ret in all_ret: + f.write(json.dumps(ret) + "\n") + + # Update batch response with output file information + batch_storage[batch_id].output_file_id = output_file_id + file_id_storage[output_file_id] = output_file_path + # Update batch status to "completed" + batch_storage[batch_id].status = "completed" + batch_storage[batch_id].completed_at = int(time.time()) + batch_storage[batch_id].request_counts = { + "total": total_requests, + "completed": completed_requests, + "failed": failed_requests + } + + except Exception as e: + print('error in SGlang:', e) + # Update batch status to "failed" + batch_storage[batch_id].status = "failed" + batch_storage[batch_id].failed_at = int(time.time()) + batch_storage[batch_id].errors = {"message": str(e)} + + +async def v1_retrieve_batch(batch_id: str): + # Retrieve the batch job from the in-memory storage + batch_response = batch_storage.get(batch_id) + if batch_response is None: + raise HTTPException(status_code=404, detail="Batch not found") + + return batch_response + +async def v1_retrieve_file(file_id: str): + # Retrieve the batch job from the in-memory storage + file_response = 
file_id_response.get(file_id) + if file_response is None: + raise HTTPException(status_code=404, detail="File not found") + return file_response + +async def v1_retrieve_file_content(file_id: str): + file_pth= file_id_storage.get(file_id) + if not file_pth or not os.path.exists(file_pth): + raise HTTPException(status_code=404, detail="File not found") + + def iter_file(): + with open(file_pth, mode="rb") as file_like: + yield from file_like + + return StreamingResponse(iter_file(), media_type="application/octet-stream") async def v1_completions(tokenizer_manager, raw_request: Request): + print('raw request in v1_completions of adapter.py', raw_request) request_json = await raw_request.json() + print('in v1_completions of adapter.py') + print(request_json) request = CompletionRequest(**request_json) adapted_request = GenerateReqInput( @@ -253,11 +470,35 @@ async def generate_stream_resp(): return response - -async def v1_chat_completions(tokenizer_manager, raw_request: Request): - request_json = await raw_request.json() - request = ChatCompletionRequest(**request_json) - +def v1_chat_generate_request(request_json, tokenizer_manager, from_file=False): + if from_file: + body = request_json["body"] + request_data = { + "messages": body["messages"], + "model": body["model"], + "frequency_penalty": body.get("frequency_penalty", 0.0), + "logit_bias": body.get("logit_bias", None), + "logprobs": body.get("logprobs", False), + "top_logprobs": body.get("top_logprobs", None), + "max_tokens": body.get("max_tokens", 16), + "n": body.get("n", 1), + "presence_penalty": body.get("presence_penalty", 0.0), + "response_format": body.get("response_format", None), + "seed": body.get("seed", None), + "stop": body.get("stop", []), + "stream": body.get("stream", False), + "temperature": body.get("temperature", 0.7), + "top_p": body.get("top_p", 1.0), + "user": body.get("user", None), + "regex": body.get("regex", None) + } + request = ChatCompletionRequest(**request_data) + ## TODO collect custom id for reorder + else: + request = ChatCompletionRequest(**request_json) + + print('request messages in v1_chat_completions:', request.messages) + # Prep the data needed for the underlying GenerateReqInput: # - prompt: The full prompt string. # - stop: Custom stop tokens. 
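For context, each line of the uploaded batch input file (e.g. the input.jsonl used by the example script) carries one request under a "body" key; the handler above only reads "custom_id" and "body", while the "method" and "url" fields are assumed here to follow the usual OpenAI batch-file layout. A minimal sketch for producing such a file:

import json

# Hypothetical requests for illustration: only "custom_id" and "body" are read
# back by process_batch()/v1_chat_generate_request(); "method" and "url" follow
# the assumed OpenAI batch-file convention.
requests = [
    {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "List 3 countries and their capitals."},
            ],
            "max_tokens": 64,
        },
    },
    {
        "custom_id": "request-2",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "Tell me a short joke."}],
            "max_tokens": 32,
        },
    },
]

# Write one JSON object per line, as expected by the /v1/files + /v1/batches flow.
with open("input.jsonl", "w", encoding="utf-8") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")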
@@ -302,7 +543,75 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): }, stream=request.stream, ) + return adapted_request, request + + +def v1_chat_generate_response(request, ret, to_file=False): + choices = [] + total_prompt_tokens = 0 + total_completion_tokens = 0 + + for idx, ret_item in enumerate(ret): + prompt_tokens = ret_item["meta_info"]["prompt_tokens"] + completion_tokens = ret_item["meta_info"]["completion_tokens"] + + if to_file: + choice_data = { + "index": idx, + "message": {"role": "assistant", "content": ret_item["text"]}, + "logprobs": None, + "finish_reason": ret_item["meta_info"]["finish_reason"], + } + else: + choice_data = ChatCompletionResponseChoice( + index=idx, + message=ChatMessage(role="assistant", content=ret_item["text"]), + finish_reason=ret_item["meta_info"]["finish_reason"], + ) + + choices.append(choice_data) + total_prompt_tokens = prompt_tokens + total_completion_tokens += completion_tokens + + if to_file: + response = { + "status_code": 200, + "request_id": ret[0]["meta_info"]["id"], + "body": { + "id": ret[0]["meta_info"]["id"], + "object": "chat.completion", + "created": int(time.time()), + "model": request.model, + "choices": choices, + "usage": { + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_prompt_tokens + total_completion_tokens, + }, + "system_fingerprint": None + } + } + else: + response = ChatCompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + choices=choices, + usage=UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, + ), + ) + return response + +async def v1_chat_completions(tokenizer_manager, raw_request: Request): + request_json = await raw_request.json() + + print('request json in v1_chat_completions:', request_json) + + adapted_request, request = v1_chat_generate_request(request_json, tokenizer_manager, from_file=False) + if adapted_request.stream: async def generate_stream_resp(): @@ -355,6 +664,8 @@ async def generate_stream_resp(): # Non-streaming response. 
try: + print('adapted_request in v1_chat_completions:', adapted_request) + print('raw_request in v1_chat_completions:', raw_request) ret = await tokenizer_manager.generate_request( adapted_request, raw_request ).__anext__() @@ -363,34 +674,8 @@ async def generate_stream_resp(): if not isinstance(ret, list): ret = [ret] - choices = [] - total_prompt_tokens = 0 - total_completion_tokens = 0 - - for idx, ret_item in enumerate(ret): - prompt_tokens = ret_item["meta_info"]["prompt_tokens"] - completion_tokens = ret_item["meta_info"]["completion_tokens"] - - choice_data = ChatCompletionResponseChoice( - index=idx, - message=ChatMessage(role="assistant", content=ret_item["text"]), - finish_reason=ret_item["meta_info"]["finish_reason"], - ) - - choices.append(choice_data) - total_prompt_tokens = prompt_tokens - total_completion_tokens += completion_tokens - - response = ChatCompletionResponse( - id=ret[0]["meta_info"]["id"], - model=request.model, - choices=choices, - usage=UsageInfo( - prompt_tokens=total_prompt_tokens, - completion_tokens=total_completion_tokens, - total_tokens=total_prompt_tokens + total_completion_tokens, - ), - ) + + response = v1_chat_generate_response(request, ret) return response diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index b91179203d..da0d282d39 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -45,6 +45,52 @@ class UsageInfo(BaseModel): completion_tokens: Optional[int] = 0 +class FileRequest(BaseModel): + # https://platform.openai.com/docs/api-reference/files/create + file: bytes # The File object (not file name) to be uploaded + purpose: str = "batch" # The intended purpose of the uploaded file, default is "batch" + +class FileResponse(BaseModel): + id: str + object: str = "file" + bytes: int + created_at: int + filename: str + purpose: str + + +class BatchRequest(BaseModel): + input_file_id: str # The ID of an uploaded file that contains requests for the new batch + endpoint: str # The endpoint to be used for all requests in the batch + completion_window: str # The time frame within which the batch should be processed + metadata: Optional[dict] = None # Optional custom metadata for the batch + +class BatchResponse(BaseModel): + id: str + object: str = "batch" + endpoint: str + errors: Optional[dict] = None + input_file_id: str + completion_window: str + status: str = "validating" + output_file_id: Optional[str] = None + error_file_id: Optional[str] = None + created_at: int + in_progress_at: Optional[int] = None + expires_at: Optional[int] = None + finalizing_at: Optional[int] = None + completed_at: Optional[int] = None + failed_at: Optional[int] = None + expired_at: Optional[int] = None + cancelling_at: Optional[int] = None + cancelled_at: Optional[int] = None + request_counts: dict = { + "total": 0, + "completed": 0, + "failed": 0 + } + metadata: Optional[dict] = None + class CompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index ac62f89ae6..1ceff5f70f 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -43,6 +43,11 @@ load_chat_template_for_openai_api, v1_chat_completions, v1_completions, + v1_files_create, + v1_batches, + v1_retrieve_batch, + v1_retrieve_file, + v1_retrieve_file_content ) from sglang.srt.openai_api.protocol import ModelCard, ModelList from 
sglang.srt.server_args import PortArgs, ServerArgs @@ -55,6 +60,7 @@ set_ulimit, ) from sglang.utils import get_exception_traceback +from fastapi import FastAPI, Request, Form, UploadFile, File logger = logging.getLogger(__name__) @@ -138,6 +144,32 @@ async def openai_v1_completions(raw_request: Request): async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, raw_request) +@app.post("/v1/files") +async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): + print("openai_v1_files") + return await v1_files_create(file, purpose) + +## support /v1/batches +@app.post("/v1/batches") +async def openai_v1_batches(raw_request: Request): + return await v1_batches(tokenizer_manager,raw_request) + +@app.get("/v1/batches/{batch_id}") +async def retrieve_batch(batch_id: str): + return await v1_retrieve_batch(batch_id) + +@app.get("/v1/files/{file_id}") +async def retrieve_file(file_id: str): + print("openai_v1_files retrieve") + # https://platform.openai.com/docs/api-reference/files/retrieve + return await v1_retrieve_file(file_id) + +## for "GET /v1/files/backend_result_file-29c7b5de-8ca6-4e91-9142-ce157a967475/content +@app.get("/v1/files/{file_id}/content") +async def retrieve_file_content(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve-contents + return await v1_retrieve_file_content(file_id) + @app.get("/v1/models") def available_models(): From 12e5492433e7654b118cd5f3fcb64c95f50556df Mon Sep 17 00:00:00 2001 From: yichuan520030910320 Date: Mon, 22 Jul 2024 17:59:46 +0000 Subject: [PATCH 2/8] finish all of the component including files batch API --- .../{openai_batch.py => openai_batch_chat.py} | 55 +++++------- examples/usage/openai_batch_complete.py | 85 +++++++++++++++++++ 2 files changed, 106 insertions(+), 34 deletions(-) rename examples/usage/{openai_batch.py => openai_batch_chat.py} (64%) create mode 100644 examples/usage/openai_batch_complete.py diff --git a/examples/usage/openai_batch.py b/examples/usage/openai_batch_chat.py similarity index 64% rename from examples/usage/openai_batch.py rename to examples/usage/openai_batch_chat.py index fb0c9c343f..7e85051860 100644 --- a/examples/usage/openai_batch.py +++ b/examples/usage/openai_batch_chat.py @@ -4,6 +4,7 @@ import json import os + class OpenAIBatchProcessor: def __init__(self, api_key): # client = OpenAI(api_key=api_key) @@ -12,81 +13,68 @@ def __init__(self, api_key): self.client = client def process_batch(self, input_file_path, endpoint, completion_window): - - ## # Chat completion - response = self.client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals and culture."}, - ], - temperature=0.8, - max_tokens=64, - ) - print(response) - # Upload the input file with open(input_file_path, "rb") as file: - uploaded_file = self.client.files.create( - file=file, - purpose="batch" - ) - print('file response:', uploaded_file) - print('file id:', uploaded_file.id) + uploaded_file = self.client.files.create(file=file, purpose="batch") # Create the batch job batch_job = self.client.batches.create( input_file_id=uploaded_file.id, endpoint=endpoint, - completion_window=completion_window + completion_window=completion_window, ) - - print('batch job:', batch_job) # Monitor the batch job status while batch_job.status not in ["completed", "failed", "cancelled"]: time.sleep(3) # Wait for 3 seconds 
before checking the status again - print(f"Batch job status: {batch_job.status}...trying again in 3 seconds...") + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." + ) batch_job = self.client.batches.retrieve(batch_job.id) + # Check the batch job status and errors + if batch_job.status == "failed": + print(f"Batch job failed with status: {batch_job.status}") + print(f"Batch job errors: {batch_job.errors}") + return None - - # If the batch job is completed, process the results if batch_job.status == "completed": - + # print result of batch job - print('batch', batch_job.request_counts) - + print("batch", batch_job.request_counts) + result_file_id = batch_job.output_file_id # Retrieve the file content from the server file_response = self.client.files.content(result_file_id) result_content = file_response.read() # Read the content of the file # Save the content to a local file - result_file_name = "/home/ubuntu/my_sglang_dev/sglang/examples/usage/batch_job_results.jsonl" + result_file_name = "batch_job_chat_results.jsonl" with open(result_file_name, "wb") as file: file.write(result_content) # Write the binary content to the file - print('read result:', result_content) # Load data from the saved JSONL file results = [] with open(result_file_name, "r", encoding="utf-8") as file: for line in file: - json_object = json.loads(line.strip()) # Parse each line as a JSON object + json_object = json.loads( + line.strip() + ) # Parse each line as a JSON object results.append(json_object) return results else: print(f"Batch job failed with status: {batch_job.status}") return None - + + # Initialize the OpenAIBatchProcessor api_key = os.environ.get("OPENAI_API_KEY") processor = OpenAIBatchProcessor(api_key) # Process the batch job -input_file_path = "/home/ubuntu/playsglang/input.jsonl" +input_file_path = "input.jsonl" endpoint = "/v1/chat/completions" completion_window = "24h" @@ -95,4 +83,3 @@ def process_batch(self, input_file_path, endpoint, completion_window): # Print the results print(results) - diff --git a/examples/usage/openai_batch_complete.py b/examples/usage/openai_batch_complete.py new file mode 100644 index 0000000000..4be64056ca --- /dev/null +++ b/examples/usage/openai_batch_complete.py @@ -0,0 +1,85 @@ +from openai import OpenAI +import openai +import time +import json +import os + + +class OpenAIBatchProcessor: + def __init__(self, api_key): + # client = OpenAI(api_key=api_key) + client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + + self.client = client + + def process_batch(self, input_file_path, endpoint, completion_window): + + # Upload the input file + with open(input_file_path, "rb") as file: + uploaded_file = self.client.files.create(file=file, purpose="batch") + + # Create the batch job + batch_job = self.client.batches.create( + input_file_id=uploaded_file.id, + endpoint=endpoint, + completion_window=completion_window, + ) + + # Monitor the batch job status + while batch_job.status not in ["completed", "failed", "cancelled"]: + time.sleep(3) # Wait for 3 seconds before checking the status again + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." 
+ ) + batch_job = self.client.batches.retrieve(batch_job.id) + + # Check the batch job status and errors + if batch_job.status == "failed": + print(f"Batch job failed with status: {batch_job.status}") + print(f"Batch job errors: {batch_job.errors}") + return None + + # If the batch job is completed, process the results + if batch_job.status == "completed": + + # print result of batch job + print("batch", batch_job.request_counts) + + result_file_id = batch_job.output_file_id + # Retrieve the file content from the server + file_response = self.client.files.content(result_file_id) + result_content = file_response.read() # Read the content of the file + + # Save the content to a local file + result_file_name = "batch_job_complete_results.jsonl" + with open(result_file_name, "wb") as file: + file.write(result_content) # Write the binary content to the file + # Load data from the saved JSONL file + results = [] + with open(result_file_name, "r", encoding="utf-8") as file: + for line in file: + json_object = json.loads( + line.strip() + ) # Parse each line as a JSON object + results.append(json_object) + + return results + else: + print(f"Batch job failed with status: {batch_job.status}") + return None + + +# Initialize the OpenAIBatchProcessor +api_key = os.environ.get("OPENAI_API_KEY") +processor = OpenAIBatchProcessor(api_key) + +# Process the batch job +input_file_path = "input_complete.jsonl" +endpoint = "/v1/completions" +completion_window = "24h" + +# Process the batch job +results = processor.process_batch(input_file_path, endpoint, completion_window) + +# Print the results +print(results) From d2270cab110b6f5e0de40f5e396f644249cbaa64 Mon Sep 17 00:00:00 2001 From: yichuan520030910320 Date: Mon, 22 Jul 2024 18:34:11 +0000 Subject: [PATCH 3/8] finish all of the component including files batch API --- examples/usage/openai_batch_chat.py | 7 ++++--- examples/usage/openai_batch_complete.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/usage/openai_batch_chat.py b/examples/usage/openai_batch_chat.py index 7e85051860..cffa50c67b 100644 --- a/examples/usage/openai_batch_chat.py +++ b/examples/usage/openai_batch_chat.py @@ -1,8 +1,9 @@ -from openai import OpenAI -import openai -import time import json import os +import time + +import openai +from openai import OpenAI class OpenAIBatchProcessor: diff --git a/examples/usage/openai_batch_complete.py b/examples/usage/openai_batch_complete.py index 4be64056ca..3cf2ede0b1 100644 --- a/examples/usage/openai_batch_complete.py +++ b/examples/usage/openai_batch_complete.py @@ -1,8 +1,9 @@ -from openai import OpenAI -import openai -import time import json import os +import time + +import openai +from openai import OpenAI class OpenAIBatchProcessor: From 9510bc2c2962926c3f7bd682716a7d2a5064ae7b Mon Sep 17 00:00:00 2001 From: yichuan520030910320 Date: Mon, 22 Jul 2024 18:45:51 +0000 Subject: [PATCH 4/8] finish all of the component including files batch API new --- .pre-commit-config.yaml | 2 +- .../sglang/srt/managers/tokenizer_manager.py | 9 +- python/sglang/srt/openai_api/adapter.py | 359 +++++++++--------- python/sglang/srt/openai_api/protocol.py | 17 +- python/sglang/srt/server.py | 23 +- python/sglang/srt/server_args.py | 7 + 6 files changed, 222 insertions(+), 195 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 393c999d28..2fa1254a66 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,6 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - 
rev: stable + rev: 24.4.2 hooks: - id: black diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a5d0c0390e..d58523d927 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -55,8 +55,6 @@ def __init__( model_overide_args: dict = None, ): self.server_args = server_args - print('tokenizer port:', port_args.tokenizer_port) - print('controller port:', port_args.controller_port) context = zmq.asyncio.Context(2) self.recv_from_detokenizer = context.socket(zmq.PULL) @@ -71,6 +69,7 @@ def __init__( trust_remote_code=server_args.trust_remote_code, model_overide_args=model_overide_args, ) + if server_args.context_length is not None: self.context_len = server_args.context_length else: @@ -124,8 +123,6 @@ async def generate_request(self, obj: GenerateReqInput, request=None): obj.post_init() is_single = obj.is_single - print('is_single:', is_single) - print('obj:', obj) if is_single: async for response in self._handle_single_request(obj, request): @@ -184,7 +181,6 @@ async def _handle_single_request(self, obj, request, index=None, is_prefill=Fals obj.top_logprobs_num if not use_index else obj.top_logprobs_num[index] ) - tokenized_obj = TokenizedGenerateReqInput( rid, input_text, @@ -198,7 +194,6 @@ async def _handle_single_request(self, obj, request, index=None, is_prefill=Fals top_logprobs_num, obj.stream, ) - print('tokenized_obj:', tokenized_obj) self.send_to_router.send_pyobj(tokenized_obj) event = asyncio.Event() @@ -343,7 +338,6 @@ async def _get_pixel_values(self, image_data): async def _wait_for_response(self, event, state, obj, rid, request): while True: try: - print('wait for response:') await asyncio.wait_for(event.wait(), timeout=4) except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): @@ -357,7 +351,6 @@ async def _wait_for_response(self, event, state, obj, rid, request): obj.top_logprobs_num, obj.return_text_in_logprobs, ) - print('out in wait for response:', out) if self.server_args.log_requests and state.finished: logger.info(f"in={obj.text}, out={out}") diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 71f2e0a154..1628c8fa93 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -3,12 +3,14 @@ import asyncio import json import os -import uuid import time +import uuid from http import HTTPStatus +from typing import Dict, Optional -from fastapi import Request, UploadFile, HTTPException +from fastapi import HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import ValidationError from sglang.srt.conversation import ( Conversation, @@ -19,6 +21,8 @@ ) from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.openai_api.protocol import ( + BatchRequest, + BatchResponse, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, @@ -32,17 +36,12 @@ CompletionStreamResponse, DeltaMessage, ErrorResponse, - LogProbs, - UsageInfo, FileRequest, FileResponse, - BatchRequest, - BatchResponse, + LogProbs, + UsageInfo, ) -from pydantic import ValidationError -from typing import Optional, Dict - chat_template_name = None # In-memory storage for batch jobs and files @@ -54,9 +53,7 @@ # backend storage directory -storage_dir = "/home/ubuntu/my_sglang_dev/sglang/python/sglang/srt/openai_api/sglang_oai_storage" - - +storage_dir = None def 
create_error_response( @@ -112,55 +109,54 @@ def load_chat_template_for_openai_api(chat_template_arg): else: chat_template_name = chat_template_arg -async def v1_files_create(file: UploadFile, purpose: str): + +async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None): try: + global storage_dir + if file_storage_pth: + storage_dir = file_storage_pth # Read the file content file_content = await file.read() - + # Create an instance of RequestBody request_body = FileRequest(file=file_content, purpose=purpose) - + # Save the file to the sglang_oai_storage directory os.makedirs(storage_dir, exist_ok=True) file_id = f"backend_input_file-{uuid.uuid4()}" filename = f"{file_id}.jsonl" file_path = os.path.join(storage_dir, filename) - - print('file id in creat:', file_id) - + with open(file_path, "wb") as f: f.write(request_body.file) - + # add info to global file map file_id_request[file_id] = request_body file_id_storage[file_id] = file_path - + # Return the response in the required format - response = FileResponse( - id=file_id, - bytes=len(request_body.file), - created_at=int(time.time()), - filename=file.filename, - purpose=request_body.purpose - ) + response = FileResponse( + id=file_id, + bytes=len(request_body.file), + created_at=int(time.time()), + filename=file.filename, + purpose=request_body.purpose, + ) file_id_response[file_id] = response - + return response except ValidationError as e: return {"error": "Invalid input", "details": e.errors()} - + async def v1_batches(tokenizer_manager, raw_request: Request): try: - # Parse the JSON body body = await raw_request.json() - - # Create an instance of BatchRequest + batch_request = BatchRequest(**body) - - # Generate a unique batch ID + batch_id = f"batch_{uuid.uuid4()}" - + # Create an instance of BatchResponse batch_response = BatchResponse( id=batch_id, @@ -168,115 +164,116 @@ async def v1_batches(tokenizer_manager, raw_request: Request): input_file_id=batch_request.input_file_id, completion_window=batch_request.completion_window, created_at=int(time.time()), - metadata=batch_request.metadata + metadata=batch_request.metadata, ) - + batch_storage[batch_id] = batch_response - + # Start processing the batch asynchronously asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request)) - + # Return the initial batch_response return batch_response - - + except ValidationError as e: return {"error": "Invalid input", "details": e.errors()} except Exception as e: return {"error": str(e)} + async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest): try: - print('batch_id in process_batch in SGlang backend:', batch_id) # Update the batch status to "in_progress" batch_storage[batch_id].status = "in_progress" batch_storage[batch_id].in_progress_at = int(time.time()) - + # Retrieve the input file content input_file_request = file_id_request.get(batch_request.input_file_id) if not input_file_request: raise ValueError("Input file not found") - + # Parse the JSONL file and process each request input_file_path = file_id_storage.get(batch_request.input_file_id) with open(input_file_path, "r", encoding="utf-8") as f: lines = f.readlines() - + total_requests = len(lines) completed_requests = 0 failed_requests = 0 - + all_ret = [] + end_point = batch_storage[batch_id].endpoint for line in lines: request_data = json.loads(line) - if batch_storage[batch_id].endpoint == "/v1/chat/completions": - adapted_request, request = v1_chat_generate_request(request_data, 
tokenizer_manager,from_file=True) - elif batch_storage[batch_id].endpoint == "/v1/completions": - pass - + if end_point == "/v1/chat/completions": + adapted_request, request = v1_chat_generate_request( + request_data, tokenizer_manager, from_file=True + ) + elif end_point == "/v1/completions": + adapted_request, request = v1_generate_request( + request_data, from_file=True + ) try: - - print('adapted_request in SGlang in batch:', adapted_request) - print('request_data in SGlang in batch:', request_data) - - ret = await tokenizer_manager.generate_request(adapted_request).__anext__() - print('ret in SGlang:', ret) - - + ret = await tokenizer_manager.generate_request( + adapted_request + ).__anext__() + if not isinstance(ret, list): ret = [ret] - - response = v1_chat_generate_response(request, ret, to_file=True) + if end_point == "/v1/chat/completions": + response = v1_chat_generate_response(request, ret, to_file=True) + else: + response = v1_generate_response(request, ret, to_file=True) + response_json = { "id": f"batch_req_{uuid.uuid4()}", "custom_id": request_data.get("custom_id"), "response": response, - "error": None + "error": None, } all_ret.append(response_json) - + completed_requests += 1 - print('success in SGlang:', ret) except Exception as e: error_json = { "id": f"batch_req_{uuid.uuid4()}", "custom_id": request_data.get("custom_id"), "response": None, - "error": {"message": str(e)} + "error": {"message": str(e)}, } all_ret.append(error_json) failed_requests += 1 continue - print('all_ret in SGlang:', all_ret) - - + # Write results to a new file output_file_id = f"backend_result_file-{uuid.uuid4()}" + global storage_dir output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl") - print('output file id in SGlang:', output_file_id) with open(output_file_path, "w", encoding="utf-8") as f: for ret in all_ret: f.write(json.dumps(ret) + "\n") - + # Update batch response with output file information - batch_storage[batch_id].output_file_id = output_file_id + retrieve_batch = batch_storage[batch_id] + retrieve_batch.output_file_id = output_file_id file_id_storage[output_file_id] = output_file_path # Update batch status to "completed" - batch_storage[batch_id].status = "completed" - batch_storage[batch_id].completed_at = int(time.time()) - batch_storage[batch_id].request_counts = { + retrieve_batch.status = "completed" + retrieve_batch.completed_at = int(time.time()) + retrieve_batch.request_counts = { "total": total_requests, "completed": completed_requests, - "failed": failed_requests + "failed": failed_requests, } - + except Exception as e: - print('error in SGlang:', e) + print("error in SGlang:", e) # Update batch status to "failed" - batch_storage[batch_id].status = "failed" - batch_storage[batch_id].failed_at = int(time.time()) - batch_storage[batch_id].errors = {"message": str(e)} + retrieve_batch = batch_storage[batch_id] + retrieve_batch.status = "failed" + retrieve_batch.failed_at = int(time.time()) + retrieve_batch.errors = {"message": str(e)} async def v1_retrieve_batch(batch_id: str): @@ -284,9 +281,10 @@ async def v1_retrieve_batch(batch_id: str): batch_response = batch_storage.get(batch_id) if batch_response is None: raise HTTPException(status_code=404, detail="Batch not found") - + return batch_response + async def v1_retrieve_file(file_id: str): # Retrieve the batch job from the in-memory storage file_response = file_id_response.get(file_id) @@ -294,8 +292,9 @@ async def v1_retrieve_file(file_id: str): raise HTTPException(status_code=404, detail="File not 
found") return file_response + async def v1_retrieve_file_content(file_id: str): - file_pth= file_id_storage.get(file_id) + file_pth = file_id_storage.get(file_id) if not file_pth or not os.path.exists(file_pth): raise HTTPException(status_code=404, detail="File not found") @@ -305,12 +304,13 @@ def iter_file(): return StreamingResponse(iter_file(), media_type="application/octet-stream") -async def v1_completions(tokenizer_manager, raw_request: Request): - print('raw request in v1_completions of adapter.py', raw_request) - request_json = await raw_request.json() - print('in v1_completions of adapter.py') - print(request_json) - request = CompletionRequest(**request_json) + +def v1_generate_request(request_json, from_file=False): + if from_file: + body = request_json["body"] + request = CompletionRequest(**body) + else: + request = CompletionRequest(**request_json) adapted_request = GenerateReqInput( text=request.prompt, @@ -330,6 +330,95 @@ async def v1_completions(tokenizer_manager, raw_request: Request): return_text_in_logprobs=True, stream=request.stream, ) + return adapted_request, request + + +def v1_generate_response(request, ret, to_file=False): + choices = [] + + for idx, ret_item in enumerate(ret): + text = ret_item["text"] + + if request.echo: + text = request.prompt + text + + if request.logprobs: + if request.echo: + prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"] + prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"] + else: + prefill_token_logprobs = None + prefill_top_logprobs = None + + logprobs = to_openai_style_logprobs( + prefill_token_logprobs=prefill_token_logprobs, + prefill_top_logprobs=prefill_top_logprobs, + decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"], + decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"], + ) + else: + logprobs = None + + if to_file: + ## to make the choise data json serializable + choice_data = { + "index": idx, + "text": text, + "logprobs": logprobs, + "finish_reason": ret_item["meta_info"]["finish_reason"], + } + else: + choice_data = CompletionResponseChoice( + index=idx, + text=text, + logprobs=logprobs, + finish_reason=ret_item["meta_info"]["finish_reason"], + ) + + choices.append(choice_data) + + if to_file: + response = { + "status_code": 200, + "request_id": ret[0]["meta_info"]["id"], + "body": { + ## remain the same but if needed we can change that + "id": ret[0]["meta_info"]["id"], + "object": "text_completion", + "created": int(time.time()), + "model": request.model, + "choices": choices, + "usage": { + "prompt_tokens": ret[0]["meta_info"]["prompt_tokens"], + "completion_tokens": sum( + item["meta_info"]["completion_tokens"] for item in ret + ), + "total_tokens": ret[0]["meta_info"]["prompt_tokens"] + + sum(item["meta_info"]["completion_tokens"] for item in ret), + }, + "system_fingerprint": None, + }, + } + else: + response = CompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + choices=choices, + usage=UsageInfo( + prompt_tokens=ret[0]["meta_info"]["prompt_tokens"], + completion_tokens=sum( + item["meta_info"]["completion_tokens"] for item in ret + ), + total_tokens=ret[0]["meta_info"]["prompt_tokens"] + + sum(item["meta_info"]["completion_tokens"] for item in ret), + ), + ) + return response + + +async def v1_completions(tokenizer_manager, raw_request: Request): + request_json = await raw_request.json() + adapted_request, request = v1_generate_request(request_json, from_file=False) if adapted_request.stream: @@ -420,85 +509,18 @@ 
async def generate_stream_resp(): if not isinstance(ret, list): ret = [ret] - choices = [] - - for idx, ret_item in enumerate(ret): - text = ret_item["text"] - - if request.echo: - text = request.prompt + text - - if request.logprobs: - if request.echo: - prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"] - prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"] - else: - prefill_token_logprobs = None - prefill_top_logprobs = None - - logprobs = to_openai_style_logprobs( - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"], - decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"], - ) - else: - logprobs = None - - choice_data = CompletionResponseChoice( - index=idx, - text=text, - logprobs=logprobs, - finish_reason=ret_item["meta_info"]["finish_reason"], - ) - - choices.append(choice_data) - - response = CompletionResponse( - id=ret[0]["meta_info"]["id"], - model=request.model, - choices=choices, - usage=UsageInfo( - prompt_tokens=ret[0]["meta_info"]["prompt_tokens"], - completion_tokens=sum( - item["meta_info"]["completion_tokens"] for item in ret - ), - total_tokens=ret[0]["meta_info"]["prompt_tokens"] - + sum(item["meta_info"]["completion_tokens"] for item in ret), - ), - ) + response = v1_generate_response(request, ret) return response + def v1_chat_generate_request(request_json, tokenizer_manager, from_file=False): if from_file: body = request_json["body"] - request_data = { - "messages": body["messages"], - "model": body["model"], - "frequency_penalty": body.get("frequency_penalty", 0.0), - "logit_bias": body.get("logit_bias", None), - "logprobs": body.get("logprobs", False), - "top_logprobs": body.get("top_logprobs", None), - "max_tokens": body.get("max_tokens", 16), - "n": body.get("n", 1), - "presence_penalty": body.get("presence_penalty", 0.0), - "response_format": body.get("response_format", None), - "seed": body.get("seed", None), - "stop": body.get("stop", []), - "stream": body.get("stream", False), - "temperature": body.get("temperature", 0.7), - "top_p": body.get("top_p", 1.0), - "user": body.get("user", None), - "regex": body.get("regex", None) - } - request = ChatCompletionRequest(**request_data) - ## TODO collect custom id for reorder + request = ChatCompletionRequest(**body) else: request = ChatCompletionRequest(**request_json) - - print('request messages in v1_chat_completions:', request.messages) - + # Prep the data needed for the underlying GenerateReqInput: # - prompt: The full prompt string. # - stop: Custom stop tokens. 
@@ -556,6 +578,7 @@ def v1_chat_generate_response(request, ret, to_file=False): completion_tokens = ret_item["meta_info"]["completion_tokens"] if to_file: + ## to make the choise data json serializable choice_data = { "index": idx, "message": {"role": "assistant", "content": ret_item["text"]}, @@ -588,8 +611,8 @@ def v1_chat_generate_response(request, ret, to_file=False): "completion_tokens": total_completion_tokens, "total_tokens": total_prompt_tokens + total_completion_tokens, }, - "system_fingerprint": None - } + "system_fingerprint": None, + }, } else: response = ChatCompletionResponse( @@ -607,11 +630,11 @@ def v1_chat_generate_response(request, ret, to_file=False): async def v1_chat_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() - - print('request json in v1_chat_completions:', request_json) - - adapted_request, request = v1_chat_generate_request(request_json, tokenizer_manager, from_file=False) - + + adapted_request, request = v1_chat_generate_request( + request_json, tokenizer_manager, from_file=False + ) + if adapted_request.stream: async def generate_stream_resp(): @@ -664,8 +687,6 @@ async def generate_stream_resp(): # Non-streaming response. try: - print('adapted_request in v1_chat_completions:', adapted_request) - print('raw_request in v1_chat_completions:', raw_request) ret = await tokenizer_manager.generate_request( adapted_request, raw_request ).__anext__() @@ -674,7 +695,7 @@ async def generate_stream_resp(): if not isinstance(ret, list): ret = [ret] - + response = v1_chat_generate_response(request, ret) return response diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index da0d282d39..c2c35f18e0 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -48,7 +48,10 @@ class UsageInfo(BaseModel): class FileRequest(BaseModel): # https://platform.openai.com/docs/api-reference/files/create file: bytes # The File object (not file name) to be uploaded - purpose: str = "batch" # The intended purpose of the uploaded file, default is "batch" + purpose: str = ( + "batch" # The intended purpose of the uploaded file, default is "batch" + ) + class FileResponse(BaseModel): id: str @@ -60,11 +63,14 @@ class FileResponse(BaseModel): class BatchRequest(BaseModel): - input_file_id: str # The ID of an uploaded file that contains requests for the new batch + input_file_id: ( + str # The ID of an uploaded file that contains requests for the new batch + ) endpoint: str # The endpoint to be used for all requests in the batch completion_window: str # The time frame within which the batch should be processed metadata: Optional[dict] = None # Optional custom metadata for the batch + class BatchResponse(BaseModel): id: str object: str = "batch" @@ -84,13 +90,10 @@ class BatchResponse(BaseModel): expired_at: Optional[int] = None cancelling_at: Optional[int] = None cancelled_at: Optional[int] = None - request_counts: dict = { - "total": 0, - "completed": 0, - "failed": 0 - } + request_counts: dict = {"total": 0, "completed": 0, "failed": 0} metadata: Optional[dict] = None + class CompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 1ceff5f70f..e7f4689f7c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -23,7 +23,7 @@ import requests import uvicorn import uvloop -from 
fastapi import FastAPI, Request +from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi.responses import JSONResponse, Response, StreamingResponse from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint @@ -41,13 +41,13 @@ from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, + v1_batches, v1_chat_completions, v1_completions, v1_files_create, - v1_batches, v1_retrieve_batch, v1_retrieve_file, - v1_retrieve_file_content + v1_retrieve_file_content, ) from sglang.srt.openai_api.protocol import ModelCard, ModelList from sglang.srt.server_args import PortArgs, ServerArgs @@ -60,7 +60,6 @@ set_ulimit, ) from sglang.utils import get_exception_traceback -from fastapi import FastAPI, Request, Form, UploadFile, File logger = logging.getLogger(__name__) @@ -144,27 +143,30 @@ async def openai_v1_completions(raw_request: Request): async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, raw_request) + @app.post("/v1/files") async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - print("openai_v1_files") - return await v1_files_create(file, purpose) + return await v1_files_create( + file, purpose, tokenizer_manager.server_args.file_storage_pth + ) + -## support /v1/batches @app.post("/v1/batches") async def openai_v1_batches(raw_request: Request): - return await v1_batches(tokenizer_manager,raw_request) + return await v1_batches(tokenizer_manager, raw_request) + @app.get("/v1/batches/{batch_id}") async def retrieve_batch(batch_id: str): return await v1_retrieve_batch(batch_id) + @app.get("/v1/files/{file_id}") async def retrieve_file(file_id: str): - print("openai_v1_files retrieve") # https://platform.openai.com/docs/api-reference/files/retrieve return await v1_retrieve_file(file_id) -## for "GET /v1/files/backend_result_file-29c7b5de-8ca6-4e91-9142-ce157a967475/content + @app.get("/v1/files/{file_id}/content") async def retrieve_file_content(file_id: str): # https://platform.openai.com/docs/api-reference/files/retrieve-contents @@ -222,6 +224,7 @@ def launch_server( if server_args.chat_template: # TODO: replace this with huggingface transformers template load_chat_template_for_openai_api(server_args.chat_template) + _set_global_server_args(server_args) # Allocate ports diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 264985fb5f..054f114fb7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -44,6 +44,7 @@ class ServerArgs: # Other api_key: str = "" + file_storage_pth: str = "SGlang_storage" # Data parallelism dp_size: int = 1 @@ -260,6 +261,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.api_key, help="Set API key of the server.", ) + parser.add_argument( + "--file-storage-pth", + type=str, + default=ServerArgs.file_storage_pth, + help="The path of the file storage in backend.", + ) # Data parallelism parser.add_argument( From af67c106fe43191cf5b81ad79cfb63e2059f0a64 Mon Sep 17 00:00:00 2001 From: yichuan520030910320 Date: Sun, 28 Jul 2024 17:35:50 +0000 Subject: [PATCH 5/8] finish alactuall batch --- examples/usage/openai_parallel_sample.py | 37 ++ python/sglang/srt/managers/io_struct.py | 27 +- .../sglang/srt/managers/tokenizer_manager.py | 20 +- python/sglang/srt/openai_api/adapter.py | 359 ++++++++++-------- 4 files changed, 273 insertions(+), 170 deletions(-) diff --git 
a/examples/usage/openai_parallel_sample.py b/examples/usage/openai_parallel_sample.py index d2d1e406ff..0d3a372b4d 100644 --- a/examples/usage/openai_parallel_sample.py +++ b/examples/usage/openai_parallel_sample.py @@ -13,6 +13,17 @@ print(response) +# Text completion +response = client.completions.create( + model="default", + prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", + n=1, + temperature=0.8, + max_tokens=32, +) +print(response) + + # Text completion response = client.completions.create( model="default", @@ -24,6 +35,17 @@ print(response) +# Text completion +response = client.completions.create( + model="default", + prompt=["The name of the famous soccer player is"], + n=1, + temperature=0.8, + max_tokens=128, +) +print(response) + + # Text completion response = client.completions.create( model="default", @@ -60,6 +82,21 @@ ) print(response) +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0.8, + max_tokens=64, + logprobs=True, + n=1, +) +print(response) + + # Chat completion response = client.chat.completions.create( model="default", diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8875994f14..99fc631995 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -40,7 +40,10 @@ def post_init(self): self.text is not None and self.input_ids is not None ): raise ValueError("Either text or input_ids should be provided.") - if self.sampling_params.get("n", 1) != 1: + if ( + isinstance(self.sampling_params, dict) + and self.sampling_params.get("n", 1) != 1 + ): is_single = False else: if self.text is not None: @@ -61,8 +64,26 @@ def post_init(self): if self.top_logprobs_num is None: self.top_logprobs_num = 0 else: - - parallel_sample_num = self.sampling_params.get("n", 1) + parallel_sample_num_list = [] + if isinstance(self.sampling_params, dict): + parallel_sample_num = self.sampling_params.get("n", 1) + elif isinstance(self.sampling_params, list): + for sp in self.sampling_params: + parallel_sample_num = sp.get("n", 1) + parallel_sample_num_list.append(parallel_sample_num) + parallel_sample_num = max(parallel_sample_num_list) + all_equal = all( + element == parallel_sample_num + for element in parallel_sample_num_list + ) + if parallel_sample_num > 1 and (not all_equal): + ## TODO cope with the case that the parallel_sample_num is different for different samples + raise ValueError( + "The parallel_sample_num should be the same for all samples in sample params." 
+ ) + else: + parallel_sample_num = 1 + self.parallel_sample_num = parallel_sample_num if parallel_sample_num != 1: # parallel sampling +1 represents the original prefill stage diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d58523d927..0e9af90087 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -152,33 +152,33 @@ async def _handle_single_request(self, obj, request, index=None, is_prefill=Fals logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] else: - use_index = index is not None - rid = obj.rid if not use_index else obj.rid[index] - input_text = obj.text if not use_index else obj.text[index] + not_use_index = not (index is not None) + rid = obj.rid if not_use_index else obj.rid[index] + input_text = obj.text if not_use_index else obj.text[index] input_ids = ( self.tokenizer.encode(input_text) if obj.input_ids is None else obj.input_ids ) - if use_index and obj.input_ids: + if not not_use_index and obj.input_ids: input_ids = obj.input_ids[index] self._validate_input_length(input_ids) sampling_params = self._get_sampling_params( - obj.sampling_params if not use_index else obj.sampling_params[index] + obj.sampling_params if not_use_index else obj.sampling_params[index] ) pixel_values, image_hash, image_size = await self._get_pixel_values( - obj.image_data if not use_index else obj.image_data[index] + obj.image_data if not_use_index else obj.image_data[index] ) return_logprob = ( - obj.return_logprob if not use_index else obj.return_logprob[index] + obj.return_logprob if not_use_index else obj.return_logprob[index] ) logprob_start_len = ( - obj.logprob_start_len if not use_index else obj.logprob_start_len[index] + obj.logprob_start_len if not_use_index else obj.logprob_start_len[index] ) top_logprobs_num = ( - obj.top_logprobs_num if not use_index else obj.top_logprobs_num[index] + obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index] ) tokenized_obj = TokenizedGenerateReqInput( @@ -210,7 +210,7 @@ async def _handle_single_request(self, obj, request, index=None, is_prefill=Fals async def _handle_batch_request(self, obj, request): batch_size = obj.batch_size - parallel_sample_num = obj.sampling_params[0].get("n", 1) + parallel_sample_num = obj.parallel_sample_num if parallel_sample_num != 1: ## send prefill requests diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 1628c8fa93..14be3d0104 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -6,7 +6,7 @@ import time import uuid from http import HTTPStatus -from typing import Dict, Optional +from typing import Dict, List, Optional from fastapi import HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse @@ -203,49 +203,51 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe all_ret = [] end_point = batch_storage[batch_id].endpoint + file_request_list = [] + all_requests = [] for line in lines: request_data = json.loads(line) + file_request_list.append(request_data) + body = request_data["body"] if end_point == "/v1/chat/completions": - adapted_request, request = v1_chat_generate_request( - request_data, tokenizer_manager, from_file=True - ) + all_requests.append(ChatCompletionRequest(**body)) elif end_point == "/v1/completions": - adapted_request, request = v1_generate_request( - 
request_data, from_file=True - ) - - try: - ret = await tokenizer_manager.generate_request( - adapted_request - ).__anext__() - - if not isinstance(ret, list): - ret = [ret] - if end_point == "/v1/chat/completions": - response = v1_chat_generate_response(request, ret, to_file=True) - else: - response = v1_generate_response(request, ret, to_file=True) - - response_json = { - "id": f"batch_req_{uuid.uuid4()}", - "custom_id": request_data.get("custom_id"), - "response": response, - "error": None, - } - all_ret.append(response_json) - - completed_requests += 1 - except Exception as e: - error_json = { - "id": f"batch_req_{uuid.uuid4()}", - "custom_id": request_data.get("custom_id"), - "response": None, - "error": {"message": str(e)}, - } - all_ret.append(error_json) - failed_requests += 1 - continue - + all_requests.append(CompletionRequest(**body)) + if end_point == "/v1/chat/completions": + adapted_request, request = v1_chat_generate_request( + all_requests, tokenizer_manager + ) + elif end_point == "/v1/completions": + adapted_request, request = v1_generate_request(all_requests) + try: + ret = await tokenizer_manager.generate_request(adapted_request).__anext__() + if not isinstance(ret, list): + ret = [ret] + if end_point == "/v1/chat/completions": + responses = v1_chat_generate_response(request, ret, to_file=True) + else: + responses = v1_generate_response(request, ret, to_file=True) + + except Exception as e: + error_json = { + "id": f"batch_req_{uuid.uuid4()}", + "custom_id": request_data.get("custom_id"), + "response": None, + "error": {"message": str(e)}, + } + all_ret.append(error_json) + failed_requests += len(file_request_list) + + for idx, response in enumerate(responses): + ## the batch_req here can be changed to be named within a batch granularity + response_json = { + "id": f"batch_req_{uuid.uuid4()}", + "custom_id": file_request_list[idx].get("custom_id"), + "response": response, + "error": None, + } + all_ret.append(response_json) + completed_requests += 1 # Write results to a new file output_file_id = f"backend_result_file-{uuid.uuid4()}" global storage_dir @@ -305,32 +307,49 @@ def iter_file(): return StreamingResponse(iter_file(), media_type="application/octet-stream") -def v1_generate_request(request_json, from_file=False): - if from_file: - body = request_json["body"] - request = CompletionRequest(**body) - else: - request = CompletionRequest(**request_json) +def v1_generate_request(all_requests): + + texts = [] + sampling_params_list = [] + + for request in all_requests: + texts.append(request.prompt) + sampling_params_list.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "stop": request.stop, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "regex": request.regex, + "n": request.n, + "ignore_eos": request.ignore_eos, + } + ) + if len(all_requests) > 1 and request.n > 1: + raise ValueError( + "Batch operation is not supported for completions from files" + ) + + if len(all_requests) == 1: + texts = texts[0] + sampling_params_list = sampling_params_list[0] adapted_request = GenerateReqInput( - text=request.prompt, - sampling_params={ - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "stop": request.stop, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "regex": request.regex, - "n": request.n, - "ignore_eos": request.ignore_eos, - }, - 
return_logprob=request.logprobs is not None and request.logprobs > 0, - top_logprobs_num=request.logprobs if request.logprobs is not None else 0, + text=texts, + sampling_params=sampling_params_list, + return_logprob=all_requests[0].logprobs is not None + and all_requests[0].logprobs > 0, + top_logprobs_num=( + all_requests[0].logprobs if all_requests[0].logprobs is not None else 0 + ), return_text_in_logprobs=True, - stream=request.stream, + stream=all_requests[0].stream, ) - return adapted_request, request + if len(all_requests) == 1: + return adapted_request, all_requests[0] + return adapted_request, all_requests def v1_generate_response(request, ret, to_file=False): @@ -338,12 +357,22 @@ def v1_generate_response(request, ret, to_file=False): for idx, ret_item in enumerate(ret): text = ret_item["text"] + echo = False + if isinstance(request, List) and request[idx].echo: + echo = True + text = request[idx].prompt + text - if request.echo: + elif (not isinstance(request, List)) and request.echo: + echo = True text = request.prompt + text - if request.logprobs: - if request.echo: + logprobs = False + if isinstance(request, List) and request[idx].logprobs: + logprobs = True + elif (not isinstance(request, List)) and request.logprobs: + logprobs = True + if logprobs: + if echo: prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"] prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"] else: @@ -362,7 +391,7 @@ def v1_generate_response(request, ret, to_file=False): if to_file: ## to make the choise data json serializable choice_data = { - "index": idx, + "index": 0, "text": text, "logprobs": logprobs, "finish_reason": ret_item["meta_info"]["finish_reason"], @@ -378,27 +407,29 @@ def v1_generate_response(request, ret, to_file=False): choices.append(choice_data) if to_file: - response = { - "status_code": 200, - "request_id": ret[0]["meta_info"]["id"], - "body": { - ## remain the same but if needed we can change that - "id": ret[0]["meta_info"]["id"], - "object": "text_completion", - "created": int(time.time()), - "model": request.model, - "choices": choices, - "usage": { - "prompt_tokens": ret[0]["meta_info"]["prompt_tokens"], - "completion_tokens": sum( - item["meta_info"]["completion_tokens"] for item in ret - ), - "total_tokens": ret[0]["meta_info"]["prompt_tokens"] - + sum(item["meta_info"]["completion_tokens"] for item in ret), + responses = [] + for i, choice in enumerate(choices): + response = { + "status_code": 200, + "request_id": ret[i]["meta_info"]["id"], + "body": { + ## remain the same but if needed we can change that + "id": ret[i]["meta_info"]["id"], + "object": "text_completion", + "created": int(time.time()), + "model": request[i].model, + "choices": choice, + "usage": { + "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"], + "completion_tokens": ret[i]["meta_info"]["completion_tokens"], + "total_tokens": ret[i]["meta_info"]["prompt_tokens"] + + ret[i]["meta_info"]["completion_tokens"], + }, + "system_fingerprint": None, }, - "system_fingerprint": None, - }, - } + } + responses.append(response) + return responses else: response = CompletionResponse( id=ret[0]["meta_info"]["id"], @@ -418,7 +449,8 @@ def v1_generate_response(request, ret, to_file=False): async def v1_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() - adapted_request, request = v1_generate_request(request_json, from_file=False) + all_requests = [CompletionRequest(**request_json)] + adapted_request, request = 
v1_generate_request(all_requests) if adapted_request.stream: @@ -514,58 +546,67 @@ async def generate_stream_resp(): return response -def v1_chat_generate_request(request_json, tokenizer_manager, from_file=False): - if from_file: - body = request_json["body"] - request = ChatCompletionRequest(**body) - else: - request = ChatCompletionRequest(**request_json) - - # Prep the data needed for the underlying GenerateReqInput: - # - prompt: The full prompt string. - # - stop: Custom stop tokens. - # - image_data: None or a list of image strings (URLs or base64 strings). - # None skips any image processing in GenerateReqInput. - if not isinstance(request.messages, str): - # Apply chat template and its stop strings. - if chat_template_name is None: - prompt = tokenizer_manager.tokenizer.apply_chat_template( - request.messages, tokenize=False, add_generation_prompt=True - ) +def v1_chat_generate_request(all_requests, tokenizer_manager): + + texts = [] + sampling_params_list = [] + image_data_list = [] + for request in all_requests: + # Prep the data needed for the underlying GenerateReqInput: + # - prompt: The full prompt string. + # - stop: Custom stop tokens. + # - image_data: None or a list of image strings (URLs or base64 strings). + # None skips any image processing in GenerateReqInput. + if not isinstance(request.messages, str): + # Apply chat template and its stop strings. + if chat_template_name is None: + prompt = tokenizer_manager.tokenizer.apply_chat_template( + request.messages, tokenize=False, add_generation_prompt=True + ) + stop = request.stop + image_data = None + else: + conv = generate_chat_conv(request, chat_template_name) + prompt = conv.get_prompt() + image_data = conv.image_data + stop = conv.stop_str or [] + if request.stop: + if isinstance(request.stop, str): + stop.append(request.stop) + else: + stop.extend(request.stop) + else: + # Use the raw prompt and stop strings if the messages is already a string. + prompt = request.messages stop = request.stop image_data = None - else: - conv = generate_chat_conv(request, chat_template_name) - prompt = conv.get_prompt() - image_data = conv.image_data - stop = conv.stop_str or [] - if request.stop: - if isinstance(request.stop, str): - stop.append(request.stop) - else: - stop.extend(request.stop) - else: - # Use the raw prompt and stop strings if the messages is already a string. 
- prompt = request.messages - stop = request.stop - image_data = None - + texts.append(prompt) + sampling_params_list.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "stop": stop, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "regex": request.regex, + "n": request.n, + } + ) + image_data_list.append(image_data) + if len(all_requests) == 1: + texts = texts[0] + sampling_params_list = sampling_params_list[0] + image_data = image_data_list[0] adapted_request = GenerateReqInput( - text=prompt, + text=texts, image_data=image_data, - sampling_params={ - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "stop": stop, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "regex": request.regex, - "n": request.n, - }, + sampling_params=sampling_params_list, stream=request.stream, ) - return adapted_request, request + if len(all_requests) == 1: + return adapted_request, all_requests[0] + return adapted_request, all_requests def v1_chat_generate_response(request, ret, to_file=False): @@ -578,9 +619,9 @@ def v1_chat_generate_response(request, ret, to_file=False): completion_tokens = ret_item["meta_info"]["completion_tokens"] if to_file: - ## to make the choise data json serializable + ## to make the choice data json serializable choice_data = { - "index": idx, + "index": 0, "message": {"role": "assistant", "content": ret_item["text"]}, "logprobs": None, "finish_reason": ret_item["meta_info"]["finish_reason"], @@ -595,25 +636,31 @@ def v1_chat_generate_response(request, ret, to_file=False): choices.append(choice_data) total_prompt_tokens = prompt_tokens total_completion_tokens += completion_tokens - if to_file: - response = { - "status_code": 200, - "request_id": ret[0]["meta_info"]["id"], - "body": { - "id": ret[0]["meta_info"]["id"], - "object": "chat.completion", - "created": int(time.time()), - "model": request.model, - "choices": choices, - "usage": { - "prompt_tokens": total_prompt_tokens, - "completion_tokens": total_completion_tokens, - "total_tokens": total_prompt_tokens + total_completion_tokens, + responses = [] + + for i, choice in enumerate(choices): + response = { + "status_code": 200, + "request_id": ret[i]["meta_info"]["id"], + "body": { + ## remain the same but if needed we can change that + "id": ret[i]["meta_info"]["id"], + "object": "chat.completion", + "created": int(time.time()), + "model": request[i].model, + "choices": choice, + "usage": { + "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"], + "completion_tokens": ret[i]["meta_info"]["completion_tokens"], + "total_tokens": ret[i]["meta_info"]["prompt_tokens"] + + ret[i]["meta_info"]["completion_tokens"], + }, + "system_fingerprint": None, }, - "system_fingerprint": None, - }, - } + } + responses.append(response) + return responses else: response = ChatCompletionResponse( id=ret[0]["meta_info"]["id"], @@ -625,15 +672,13 @@ def v1_chat_generate_response(request, ret, to_file=False): total_tokens=total_prompt_tokens + total_completion_tokens, ), ) - return response + return response async def v1_chat_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() - - adapted_request, request = v1_chat_generate_request( - request_json, tokenizer_manager, from_file=False - ) + all_requests = [ChatCompletionRequest(**request_json)] + adapted_request, request = 
v1_chat_generate_request(all_requests, tokenizer_manager) if adapted_request.stream: From 3db9d03600ea95532adc8f525346d395295cdf34 Mon Sep 17 00:00:00 2001 From: yichuan520030910320 Date: Sun, 28 Jul 2024 17:48:07 +0000 Subject: [PATCH 6/8] solve conflict again --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a54cec3920..0e9af90087 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -514,4 +514,4 @@ def get_pixel_values( pixel_values = pixel_values.astype(np.float16) return pixel_values, image_hash, image.size except Exception: - print("Exception in TokenizerManager:\n" + get_exception_traceback()) \ No newline at end of file + print("Exception in TokenizerManager:\n" + get_exception_traceback()) From 512aff99bb869bd50c576093a29c4c37890f031f Mon Sep 17 00:00:00 2001 From: yichuan520030910320 Date: Mon, 29 Jul 2024 08:30:21 +0000 Subject: [PATCH 7/8] store meta data only in mem --- python/sglang/srt/openai_api/adapter.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 2aefc16f1a..d4e4e9cef5 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -59,9 +59,16 @@ chat_template_name = None + +class FileMetadata: + def __init__(self, filename: str, purpose: str): + self.filename = filename + self.purpose = purpose + + # In-memory storage for batch jobs and files batch_storage: Dict[str, BatchResponse] = {} -file_id_request: Dict[str, FileRequest] = {} +file_id_request: Dict[str, FileMetadata] = {} file_id_response: Dict[str, FileResponse] = {} ## map file id to file path in SGlang backend file_id_storage: Dict[str, str] = {} @@ -146,7 +153,7 @@ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str f.write(request_body.file) # add info to global file map - file_id_request[file_id] = request_body + file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose) file_id_storage[file_id] = file_path # Return the response in the required format From 5090891a67f1a498083652cd688f902f6d9159df Mon Sep 17 00:00:00 2001 From: yichuan520030910320 Date: Mon, 29 Jul 2024 10:07:04 +0000 Subject: [PATCH 8/8] fix small bug about logprobs when solving conflict --- python/sglang/srt/openai_api/adapter.py | 39 ++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index d4e4e9cef5..5fa75f1b88 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -331,11 +331,15 @@ def iter_file(): def v1_generate_request(all_requests): - texts = [] + prompts = [] sampling_params_list = [] - + first_prompt_type = type(all_requests[0].prompt) for request in all_requests: - texts.append(request.prompt) + prompt = request.prompt + assert ( + type(prompt) == first_prompt_type + ), "All prompts must be of the same type in file input settings" + prompts.append(prompt) sampling_params_list.append( { "temperature": request.temperature, @@ -355,11 +359,20 @@ def v1_generate_request(all_requests): ) if len(all_requests) == 1: - texts = texts[0] + prompt = prompts[0] sampling_params_list = sampling_params_list[0] + if isinstance(prompts, str) 
or isinstance(prompts[0], str): + prompt_kwargs = {"text": prompt} + else: + prompt_kwargs = {"input_ids": prompt} + else: + if isinstance(prompts[0], str): + prompt_kwargs = {"text": prompts} + else: + prompt_kwargs = {"input_ids": prompts} adapted_request = GenerateReqInput( - text=texts, + **prompt_kwargs, sampling_params=sampling_params_list, return_logprob=all_requests[0].logprobs is not None and all_requests[0].logprobs > 0, @@ -401,17 +414,17 @@ def v1_generate_response(request, ret, to_file=False): logprobs = True if logprobs: if echo: - prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"] - prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"] + input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] + input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"] else: - prefill_token_logprobs = None - prefill_top_logprobs = None + input_token_logprobs = None + input_top_logprobs = None logprobs = to_openai_style_logprobs( - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"], - decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"], + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], ) else: logprobs = None
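
With these changes in place, process_batch() consumes an OpenAI-style batch input file: each JSONL line is parsed, its "body" becomes a ChatCompletionRequest or CompletionRequest depending on the batch endpoint, all bodies are folded into a single GenerateReqInput, and one result record per request (keyed by the line's "custom_id") is written to the output file. A minimal sketch of preparing such an input file follows; the file name, custom_id values, and prompts are hypothetical, and the commented result shape mirrors the response_json dicts assembled in process_batch().

import json

# Two chat-completion requests for a batch created with endpoint="/v1/chat/completions".
# process_batch() above only reads "custom_id" and "body" from each line; "method" and
# "url" follow the OpenAI batch-file convention but are not interpreted by this code.
batch_lines = [
    {
        "custom_id": "request-1",  # hypothetical id used to match results back to requests
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "Name one ocean."}],
            "max_tokens": 16,
        },
    },
    {
        "custom_id": "request-2",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "Name one continent."}],
            "max_tokens": 16,
        },
    },
]

# Write one JSON object per line, as expected by the batch file reader.
with open("input.jsonl", "w", encoding="utf-8") as f:
    for line in batch_lines:
        f.write(json.dumps(line) + "\n")

# Each line of the resulting output file is shaped like the response_json dict built
# in process_batch(), for example:
# {"id": "batch_req_<uuid>", "custom_id": "request-1",
#  "response": {"status_code": 200, "request_id": "...", "body": {...}}, "error": null}

The same file can then be uploaded through the new POST /v1/files endpoint (purpose defaults to "batch") and referenced by its returned file id when creating the batch via POST /v1/batches.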