Improve efficiency of scheduling and token sampling #377

Merged
merged 9 commits on Jan 18, 2024
8 changes: 4 additions & 4 deletions mii/batching/data_classes.py
@@ -225,20 +225,20 @@ def next_tokens(self) -> List[torch.Tensor]:
return [r.next_token for r in self.requests]

@property
def done_tokens(self) -> List[torch.Tensor]:
def done_tokens(self) -> List[bool]:
return [r.is_done for r in self.requests]

@next_tokens.setter
def next_tokens(self, next_tokens: List[torch.Tensor]) -> None:
def next_tokens(self, next_tokens: torch.Tensor) -> None:
assert len(next_tokens) == len(self.requests)
for idx, r in enumerate(self.requests):
r.next_token = next_tokens[idx]

@done_tokens.setter
def done_tokens(self, done_tokens: List[torch.Tensor]) -> None:
def done_tokens(self, done_tokens: torch.Tensor) -> None:
assert len(done_tokens) == len(self.requests)
for idx, r in enumerate(self.requests):
r.is_done = done_tokens[idx]
r.is_done = done_tokens[idx].item()

def to_msg_dicts(self) -> List[Dict[str, Any]]:
return [r.to_msg_dict() for r in self.requests]
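
A minimal sketch of how the reworked setters are meant to be used: post-processing now hands back one batched tensor for the whole batch, and the setters fan the values out to each request as plain Python scalars via .item(). The _Request and _RequestBatch classes below are simplified stand-ins for illustration, not the actual mii classes.

import torch
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class _Request:                      # simplified stand-in for mii's Request
    next_token: Optional[torch.Tensor] = None
    is_done: bool = False

@dataclass
class _RequestBatch:                 # simplified stand-in for mii's RequestBatch
    requests: List[_Request] = field(default_factory=list)

    @property
    def done_tokens(self) -> List[bool]:
        return [r.is_done for r in self.requests]

    @done_tokens.setter
    def done_tokens(self, done_tokens: torch.Tensor) -> None:
        # One batched CPU tensor in, one plain Python bool per request out.
        assert len(done_tokens) == len(self.requests)
        for idx, r in enumerate(self.requests):
            r.is_done = done_tokens[idx].item()

batch = _RequestBatch([_Request(), _Request()])
batch.done_tokens = torch.tensor([True, False])
print(batch.done_tokens)             # [True, False]
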
9 changes: 5 additions & 4 deletions mii/batching/generation/logit_processors.py
@@ -54,10 +54,11 @@ def forward(self, logits: torch.Tensor) -> torch.Tensor:
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = FLOAT_PAD
return logits

indices_to_remove = sorted_indices_to_remove.scatter(1,
sorted_indices,
sorted_indices_to_remove)
return logits.masked_fill(indices_to_remove, FLOAT_PAD)

def get_key(self) -> str:
return super().get_key() + f"_top_p={self.top_p}"
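
For reference, a self-contained sketch of the vectorized top-p filtering that the new scatter/masked_fill lines perform, together with the sort/cumsum steps that precede the hunk shown above. FLOAT_PAD here is a hypothetical stand-in for the padding constant used in mii, and the function is an illustration of the technique, not the actual processor class.

import torch

FLOAT_PAD = -float("inf")  # stand-in for the padding constant used in mii

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort each row in descending order and accumulate the probability mass.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens once the cumulative mass exceeds top_p, shifting by one so
    # the first token above the threshold is still kept.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the mask from sorted order back to vocabulary order with a single
    # scatter instead of looping over every row in Python.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                         sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, FLOAT_PAD)

logits = torch.randn(4, 16)           # batch of 4 rows over a 16-token vocabulary
filtered = top_p_filter(logits, top_p=0.9)
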
49 changes: 36 additions & 13 deletions mii/batching/ragged_batching.py
@@ -55,9 +55,7 @@ def __init__(self, inference_engine, tokenizer, model_config):
self.buffer = deque()
self.scheduled_length = 0
self.scheduled_seq_num = 0
self.scheduled_req_blocks = torch.zeros(inference_engine.n_kv_cache_groups,
dtype=torch.int32,
device="cpu")
self.scheduled_req_blocks = 0

# TODO: we will need to prune self._post_processors for long running deployments
self._post_processors = {}
@@ -175,7 +173,7 @@ def _reset_scheduler_bookkeeping(self) -> None:
self.scheduled_requests = RequestBatch()
self.scheduled_length = 0
self.scheduled_seq_num = 0
self.scheduled_req_blocks.zero_()
self.scheduled_req_blocks = 0

@sync_debug
def _process_logits(
@@ -194,6 +192,7 @@ def _process_logits(
running_requests,
self._post_processors)
next_tokens = next_tokens.to(torch.device("cpu"), non_blocking=False)
done_tokens = done_tokens.to(torch.device("cpu"), non_blocking=False)
return next_tokens, done_tokens

@sync_debug
@@ -226,12 +225,35 @@ def _generate_output(self, r: Request) -> bool:
for output in outputs:
self.result_queues[r.tid].put_nowait(output)

def _do_schedule_requests(self, requests: List[Request]) -> None:
def _schedule_token_gen(self, requests: List[Request]) -> None:
free_blocks = min(self.inference_engine.free_blocks)
conf_manager = self.inference_engine._config.state_manager

free_blocks = self.inference_engine.free_blocks
num_schedulable = min([
len(requests),
conf_manager.max_ragged_sequence_count,
conf_manager.max_ragged_batch_size
])

for r in requests[:num_schedulable]:
block_capacity = self.inference_engine.get_remaining_block_capacity(r.uid)
# We can schedule token generation if the last block still has capacity
if block_capacity > 0:
self.scheduled_length += 1
self.scheduled_requests.append(r)
elif free_blocks > 0:
# We need a new block
free_blocks -= 1
self.scheduled_length += 1
self.scheduled_req_blocks += 1
self.scheduled_requests.append(r)

def _schedule_prompts(self, requests: List[Request]) -> None:
free_blocks = min(self.inference_engine.free_blocks)
conf_manager = self.inference_engine._config.state_manager

for r in requests:
if free_blocks.min().item() == 0:
if free_blocks == 0:
break

if r.max_length <= r.seq_length:
@@ -285,22 +307,22 @@ def schedule_requests(self) -> None:
r = self.request_queue.get_nowait()
self.buffer.append(r)

# Run next token generation first
next_token_gen_reqs = []
prompt_reqs = []

for r in self.buffer:
if r.is_flush_request:
self.scheduled_requests.append(r)
else:
if len(r.input_tokens) == 1:
next_token_gen_reqs.append(r)
if r.num_generated_tokens > 0:
if r.max_length > r.seq_length:
next_token_gen_reqs.append(r)
else:
prompt_reqs.append(r)

# We want to process next token generation first
self._do_schedule_requests(next_token_gen_reqs)
self._do_schedule_requests(prompt_reqs)
self._schedule_token_gen(next_token_gen_reqs)
self._schedule_prompts(prompt_reqs)

if len(self.buffer) > 0 and len(self.scheduled_requests) == 0:
print(
@@ -422,7 +444,8 @@ def make_response(self,
finish_reason=finish_reason)

def put(self, uids: List[int], tokenized_input: List[torch.Tensor]) -> torch.Tensor:
return self.inference_engine.put(uids, tokenized_input)
# Call the inference engine. Schedulability checks can be skipped here because they were already performed during scheduling
return self.inference_engine.put(uids, tokenized_input, do_checks=False)

def flush(self, uids: List[int]) -> None:
for uid in uids:
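
Putting the scheduler changes together, the flow is now: classify buffered requests by whether they have already generated tokens, admit token-generation (decode) requests first since each needs at most one new KV block, then admit prompts until the free-block budget runs out. The sketch below illustrates that flow with simplified stand-ins; the real code additionally tracks scheduled_length, flush requests, and the ragged-batch limits shown in the diff, and its prompt block accounting is more involved than the one-block-per-prompt simplification used here.

from dataclasses import dataclass
from typing import Callable, List

@dataclass
class _Req:                       # simplified stand-in for mii's Request
    uid: int
    seq_length: int
    max_length: int
    num_generated_tokens: int

def schedule(buffer: List[_Req],
             free_blocks: int,
             remaining_block_capacity: Callable[[int], int]) -> List[_Req]:
    scheduled: List[_Req] = []

    # Decode requests (already generating) are scheduled before prompts.
    decode_reqs = [r for r in buffer
                   if r.num_generated_tokens > 0 and r.max_length > r.seq_length]
    prompt_reqs = [r for r in buffer if r.num_generated_tokens == 0]

    # Phase 1: token generation. A request fits in its last, partially filled
    # KV block, or else it claims exactly one fresh block.
    for r in decode_reqs:
        if remaining_block_capacity(r.uid) > 0:
            scheduled.append(r)
        elif free_blocks > 0:
            free_blocks -= 1
            scheduled.append(r)

    # Phase 2: prompts, admitted until no free KV blocks remain.
    for r in prompt_reqs:
        if free_blocks == 0:
            break
        scheduled.append(r)
        free_blocks -= 1          # simplification: assume one block per prompt

    return scheduled

# Example: two decode requests and one prompt competing for two free blocks.
reqs = [_Req(0, 10, 64, 9), _Req(1, 32, 64, 31), _Req(2, 8, 64, 0)]
print([r.uid for r in schedule(reqs, free_blocks=2,
                               remaining_block_capacity=lambda uid: 0)])
# -> [0, 1]: the decode requests consume the free blocks; the prompt waits.
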