Skip to content

Commit

Permalink
Optimize schedule (sgl-project#1339)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 authored Sep 5, 2024
1 parent 62f15ee commit ab4a83b
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 8 deletions.
110 changes: 105 additions & 5 deletions python/sglang/srt/managers/policy_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,24 @@ class PrefillAdder:
def __init__(
self,
tree_cache: BasePrefixCache,
running_batch: ScheduleBatch,
new_token_ratio: float,
rem_total_tokens: int,
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
):
self.tree_cache = tree_cache
self.running_batch = running_batch
self.new_token_ratio = new_token_ratio
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.total_tokens = rem_total_tokens
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens

self.req_states = None
self.can_run_list = []
self.new_inflight_req = None
self.log_hit_tokens = 0
Expand All @@ -136,16 +142,14 @@ def no_remaining_tokens(self):
)
)

def remove_running_tokens(
self, running_batch: ScheduleBatch, new_token_ratio: float
):
def remove_running_tokens(self, running_batch: ScheduleBatch):
self.rem_total_tokens -= sum(
[
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* new_token_ratio
* self.new_token_ratio
for r in running_batch.reqs
]
)
Expand All @@ -161,7 +165,29 @@ def _prefill_one_req(
self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len

def add_inflight_req_ignore_eos(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)

self._prefill_one_req(
0,
req.extend_input_len,
(
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
if not truncated
else 0
),
)

# Return if chunked prefill not finished
return req if truncated else None

def add_inflight_req(self, req: Req):
if req.sampling_params.ignore_eos:
return self.add_inflight_req_ignore_eos(req)

truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
Expand Down Expand Up @@ -190,7 +216,81 @@ def _lock_node(self, last_node: TreeNode):
delta = self.tree_cache.dec_lock_ref(last_node)
self.rem_total_tokens += delta

def add_one_req_ignore_eos(self, req: Req):
def get_req_state(r):
new_token_ratio = (
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
)
tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
r.output_ids
)
tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

if tokens_left > 0:
return (tokens_left, tokens_occupied)

return None

if self.req_states is None:
self.req_states = []
if self.running_batch is not None:
for r in self.running_batch.reqs:
state = get_req_state(r)
if state is not None:
self.req_states.append(state)
for r in self.can_run_list:
state = get_req_state(r)
if state is not None:
self.req_states.append(state)
state = get_req_state(req)
if state is not None:
self.req_states.append(state)

self.req_states.sort(key=lambda x: x[0])
else:
state = get_req_state(req)
if state is not None:
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
if tokens_left >= state[0]:
self.req_states.insert(i, state)
break
else:
self.req_states.append(state)

tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
decode_steps = (
self.req_states[i + 1][0]
if i + 1 < len(self.req_states)
else tokens_left
)
bs = len(self.req_states) - i
if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
return False
tokens_freed += tokens_occupied

if req.extend_input_len <= self.rem_chunk_tokens:
self.can_run_list.append(req)
self._prefill_one_req(
0,
req.extend_input_len,
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
# Chunked prefill
trunc_len = self.rem_chunk_tokens
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req)
self.new_inflight_req = req
self._prefill_one_req(0, trunc_len, 0)

return True

def add_one_req(self, req: Req):
if req.sampling_params.ignore_eos and self.tree_cache.disable:
return self.add_one_req_ignore_eos(req)

total_tokens = req.extend_input_len + min(
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
)
Expand Down Expand Up @@ -233,4 +333,4 @@ def add_one_req(self, req: Req):
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)

return True
return True and not self.no_remaining_tokens()
21 changes: 18 additions & 3 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def __init__(
)
self.new_token_ratio = self.min_new_token_ratio
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.do_not_get_new_batch = False

def exposed_step(self, recv_reqs: List):
try:
Expand Down Expand Up @@ -253,7 +254,13 @@ def exposed_step(self, recv_reqs: List):

@torch.inference_mode()
def forward_step(self):
new_batch = self.get_new_prefill_batch()
if self.current_inflight_req is not None:
self.do_not_get_new_batch = False

new_batch = (
self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
)
self.do_not_get_new_batch = False

if new_batch is not None:
# Run a new prefill batch
Expand Down Expand Up @@ -409,14 +416,16 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:

adder = PrefillAdder(
self.tree_cache,
self.running_batch,
self.new_token_ratio,
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens,
self.chunked_prefill_size,
num_mixed_running,
)

if self.running_batch is not None:
adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
adder.remove_running_tokens(self.running_batch)

has_inflight = self.current_inflight_req is not None
if self.current_inflight_req is not None:
Expand All @@ -428,11 +437,12 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
)

for req in self.waiting_queue:
if adder.no_remaining_tokens():
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req)
if (
not res
or adder.no_remaining_tokens()
or running_bs + len(adder.can_run_list) >= self.max_running_requests
):
break
Expand Down Expand Up @@ -700,6 +710,7 @@ def forward_decode_batch(self, batch: ScheduleBatch):
next_token_ids = next_token_ids.tolist()

# Check finish condition
has_finished = False
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
Expand All @@ -712,6 +723,7 @@ def forward_decode_batch(self, batch: ScheduleBatch):

if req.finished():
self.tree_cache.cache_finished_req(req)
has_finished = True

if req.return_logprob:
req.output_token_logprobs.append(
Expand All @@ -720,6 +732,9 @@ def forward_decode_batch(self, batch: ScheduleBatch):
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

if not has_finished:
self.do_not_get_new_batch = True

self.handle_finished_requests(batch)

def handle_finished_requests(self, batch: ScheduleBatch):
Expand Down

0 comments on commit ab4a83b

Please sign in to comment.