Skip to content

Commit

Permalink
Support more OpenAI API test (#916)
Browse files Browse the repository at this point in the history
  • Loading branch information
yichuan520030910320 authored Aug 4, 2024
1 parent bb66cc4 commit d53dcf9
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 59 deletions.
13 changes: 9 additions & 4 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def post_init(self):
for element in parallel_sample_num_list
)
if parallel_sample_num > 1 and (not all_equal):
## TODO cope with the case that the parallel_sample_num is different for different samples
# TODO: handle the case where parallel_sample_num differs across samples
raise ValueError(
"The parallel_sample_num should be the same for all samples in sample params."
)
Expand All @@ -103,14 +103,19 @@ def post_init(self):
if parallel_sample_num != 1:
# parallel sampling +1 represents the original prefill stage
num = parallel_sample_num + 1
if isinstance(self.text, List):
## suppot batch operation
if isinstance(self.text, list):
# support batch operation
self.batch_size = len(self.text)
num = num * len(self.text)
elif isinstance(self.input_ids, list) and isinstance(
self.input_ids[0], list
):
self.batch_size = len(self.input_ids)
num = num * len(self.input_ids)
else:
self.batch_size = 1
else:
## support select operation
# support select operation
num = len(self.text) if self.text is not None else len(self.input_ids)
self.batch_size = num

Expand Down
41 changes: 28 additions & 13 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,9 @@ async def generate_request(self, obj: GenerateReqInput, request=None):
async def _handle_single_request(
self, obj, request, index=None, is_cache_for_prefill=False
):
if not is_cache_for_prefill:
not_use_index = not (index is not None)
if not is_cache_for_prefill: # The normal case with a single prompt
not_use_index = index is None

rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
input_ids = (
Expand Down Expand Up @@ -182,14 +183,27 @@ async def _handle_single_request(
top_logprobs_num = (
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
)
else:
if isinstance(obj.text, list):
input_text = obj.text[index]
rid = obj.rid[index]
else: # A prefill request to cache the common prompt for parallel sampling
if obj.text is not None:
if isinstance(obj.text, list):
input_text = obj.text[index]
rid = obj.rid[index]
else:
input_text = obj.text
rid = obj.rid[0]
input_ids = self.tokenizer.encode(input_text)
else:
input_text = obj.text
rid = obj.rid[0]
input_ids = self.tokenizer.encode(input_text)
input_text = None
if isinstance(obj.input_ids, list) and isinstance(
obj.input_ids[0], list
):
# when obj.input_ids is List[List[int]]
input_ids = obj.input_ids[index]
rid = obj.rid[index]
else:
input_ids = obj.input_ids
rid = obj.rid[0]

sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
pixel_values, image_hash, image_size = await self._get_pixel_values(
Expand Down Expand Up @@ -240,11 +254,11 @@ async def _handle_batch_request(self, obj: GenerateReqInput, request):
):
if input_id_result is not None:
input_id_result.append(input_id)
pass
if len(input_id_result) > 1 and input_id_result is not None:
if input_id_result is not None and len(input_id_result) > 1:
obj.input_ids = input_id_result
elif input_id_result is not None:
obj.input_ids = input_id_result[0]

# First send out all requests
for i in range(batch_size):
for j in range(parallel_sample_num):
Expand All @@ -264,11 +278,12 @@ async def _handle_batch_request(self, obj: GenerateReqInput, request):
input_text = None
input_ids = obj.input_ids[i]
else:
assert obj.input_ids is not None
if batch_size == 1:
input_text = obj.text
input_text = None
input_ids = obj.input_ids
else:
input_text = obj.text[i]
input_text = None
input_ids = obj.input_ids[i]
sampling_params = self._get_sampling_params(obj.sampling_params[index])
pixel_values, image_hash, image_size = await self._get_pixel_values(
Expand Down
Loading

0 comments on commit d53dcf9

Please sign in to comment.