simplify scheduling for token generation
tohtana committed Jan 15, 2024
1 parent 6786737 commit 94ebae1
Showing 1 changed file with 31 additions and 12 deletions.
mii/batching/ragged_batching.py: 31 additions & 12 deletions
@@ -55,9 +55,7 @@ def __init__(self, inference_engine, tokenizer, model_config):
         self.buffer = deque()
         self.scheduled_length = 0
         self.scheduled_seq_num = 0
-        self.scheduled_req_blocks = torch.zeros(inference_engine.n_kv_cache_groups,
-                                                dtype=torch.int32,
-                                                device="cpu")
+        self.scheduled_req_blocks = 0
 
         # TODO: we will need to prune self._post_processors for long running deployments
         self._post_processors = {}
@@ -175,7 +173,7 @@ def _reset_scheduler_bookkeeping(self) -> None:
         self.scheduled_requests = RequestBatch()
         self.scheduled_length = 0
         self.scheduled_seq_num = 0
-        self.scheduled_req_blocks.zero_()
+        self.scheduled_req_blocks = 0
 
     @sync_debug
     def _process_logits(
@@ -227,12 +225,33 @@ def _generate_output(self, r: Request) -> bool:
             for output in outputs:
                 self.result_queues[r.tid].put_nowait(output)
 
-    def _do_schedule_requests(self, requests: List[Request]) -> None:
-        free_blocks = self.inference_engine.free_blocks
+    def _schedule_token_gen(self, requests: List[Request]) -> None:
+        free_blocks = self.inference_engine.free_blocks.min().item()
+        conf_manager = self.inference_engine._config.state_manager
+
+        num_schedulable = min(len(requests), conf_manager.max_ragged_sequence_count)
+        num_schedulable = min(num_schedulable, conf_manager.max_ragged_batch_size)
+
+        for r in requests[:num_schedulable]:
+            block_capacity = self.inference_engine.get_remaining_block_capacity(r.uid)
+            # We can schedule token generation if the last block has a capacity
+            if block_capacity > 0:
+                self.scheduled_length += 1
+                self.scheduled_requests.append(r)
+            else:
+                # We need a new block
+                if free_blocks > 0:
+                    free_blocks -= 1
+                    self.scheduled_length += 1
+                    self.scheduled_req_blocks += 1
+                    self.scheduled_requests.append(r)
+
+    def _schedule_prompts(self, requests: List[Request]) -> None:
+        free_blocks = self.inference_engine.free_blocks.min().item()
         conf_manager = self.inference_engine._config.state_manager
 
         for r in requests:
-            if free_blocks.min().item() == 0:
+            if free_blocks == 0:
                 break
 
             if r.max_length <= r.seq_length:
Expand Down Expand Up @@ -286,22 +305,22 @@ def schedule_requests(self) -> None:
r = self.request_queue.get_nowait()
self.buffer.append(r)

# Run next token generation first
next_token_gen_reqs = []
prompt_reqs = []

for r in self.buffer:
if r.is_flush_request:
self.scheduled_requests.append(r)
else:
if len(r.input_tokens) == 1:
next_token_gen_reqs.append(r)
if r.num_generated_tokens > 0:
if r.max_length > r.seq_length:
next_token_gen_reqs.append(r)
else:
prompt_reqs.append(r)

# We want to process next token generation first
self._do_schedule_requests(next_token_gen_reqs)
self._do_schedule_requests(prompt_reqs)
self._schedule_token_gen(next_token_gen_reqs)
self._schedule_prompts(prompt_reqs)

if len(self.buffer) > 0 and len(self.scheduled_requests) == 0:
print(

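The heart of the change is visible in the new _schedule_token_gen path: per-step block bookkeeping now uses plain integers (free_blocks, scheduled_req_blocks) instead of a per-KV-cache-group tensor, and a decode request is admitted either because its last KV-cache block still has spare capacity or because a free block can be claimed for it. Below is a minimal, self-contained sketch of that decision loop; ToyEngine, ToyRequest, schedule_token_gen, and max_batch are hypothetical stand-ins for the real MII engine, request, and config objects, with get_remaining_block_capacity mirroring the engine call used in the commit.

from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class ToyRequest:
    # Hypothetical stand-in for mii's Request: only the field the decision needs.
    uid: int


@dataclass
class ToyEngine:
    # Hypothetical stand-in for the ragged inference engine: a free-block count
    # plus, per request, how many token slots remain in its latest KV-cache block.
    free_blocks: int
    remaining_capacity: Dict[int, int]

    def get_remaining_block_capacity(self, uid: int) -> int:
        return self.remaining_capacity[uid]


def schedule_token_gen(requests: List[ToyRequest], engine: ToyEngine,
                       max_batch: int) -> Tuple[List[ToyRequest], int]:
    """Pick which decode requests run this step (mirrors _schedule_token_gen)."""
    scheduled: List[ToyRequest] = []
    new_blocks = 0                    # analogue of scheduled_req_blocks
    free_blocks = engine.free_blocks  # plain int, no tensor bookkeeping

    for r in requests[:max_batch]:
        if engine.get_remaining_block_capacity(r.uid) > 0:
            # The last KV block still has room: schedule without a new block.
            scheduled.append(r)
        elif free_blocks > 0:
            # Last block is full: claim one of the free blocks for this request.
            free_blocks -= 1
            new_blocks += 1
            scheduled.append(r)
        # Otherwise the request stays buffered until a block frees up.
    return scheduled, new_blocks


if __name__ == "__main__":
    engine = ToyEngine(free_blocks=1, remaining_capacity={0: 3, 1: 0, 2: 0})
    reqs = [ToyRequest(uid=u) for u in range(3)]
    scheduled, new_blocks = schedule_token_gen(reqs, engine, max_batch=8)
    print([r.uid for r in scheduled], new_blocks)  # -> [0, 1] 1

Running the sketch schedules requests 0 and 1 and reports one newly claimed block, while request 2 waits for a later step. Keeping these counters as plain ints also removes the per-request free_blocks.min().item() check and the scheduled_req_blocks.zero_() reset seen on the deleted lines, which appears to be the simplification the commit title refers to.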