[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching

* add batch bucket for batching

* revise RunningList struct in handler

* add kvcache/batch funcs for compatibility

* use new batching methods

* fix indexing bugs

* revise abort logic

* use cpu seq lengths/block tables

* rm unused attr in Sequence

* fix type conversion/default arg

* add and revise pytests

* revise pytests, rm unused tests

* rm unused statements

* fix pop finished indexing issue

* fix: use index in batch when retrieving inputs/update seqs

* use dict instead of odict in batch struct

* arg type hinting

* fix make compress

* refine comments

* fix: pop_n_seqs to pop the first n seqs

* add check in request handler

* remove redundant conversion

* fix test for request handler

* fix pop method in batch bucket

* fix prefill adding
yuanheng-zhao authored Feb 19, 2024
1 parent 8c69deb · commit b21aac5
Showing 11 changed files with 905 additions and 115 deletions.
449 changes: 449 additions & 0 deletions colossalai/inference/batch_bucket.py

Large diff not rendered.
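Note: since the 449-line batch_bucket.py diff is not rendered above, the snippet below is only a rough, hypothetical sketch of the idea a "batch bucket" captures, inferred from how BatchBucket is used in request_handler.py further down: one object owns the padded, batch-level bookkeeping (block tables and sequence lengths as tensors) instead of each Sequence carrying its own block table. All names, shapes, and method bodies here are illustrative and do not reproduce the real implementation.

from typing import Dict

import torch


class MiniBatchBucket:
    """Toy stand-in for the new BatchBucket: batch-level block tables and
    sequence lengths live here, not on the individual Sequence objects."""

    def __init__(self, max_batch_size: int, max_blocks_per_seq: int) -> None:
        self.max_batch_size = max_batch_size
        # One row per batch slot; -1 marks an unassigned KV-cache block.
        self.block_tables = torch.full((max_batch_size, max_blocks_per_seq), -1, dtype=torch.int64)
        self.seq_lengths = torch.zeros(max_batch_size, dtype=torch.int64)
        self._rows: Dict[int, int] = {}  # request_id -> row index

    @property
    def is_empty(self) -> bool:
        return not self._rows

    @property
    def current_batch_size(self) -> int:
        return len(self._rows)

    @property
    def available_batch_size(self) -> int:
        return self.max_batch_size - len(self._rows)

    def add_seq(self, request_id: int, prompt_len: int) -> int:
        """Reserve a batch row for a new sequence and record its prompt length."""
        assert self.available_batch_size > 0, "batch bucket is full"
        row = len(self._rows)
        self._rows[request_id] = row
        self.seq_lengths[row] = prompt_len
        return row

    def append_batch_tokens(self, new_tokens: torch.Tensor) -> None:
        """Bump every live sequence's length by one after a decoding step
        (the real implementation would also need to record the sampled token ids)."""
        assert new_tokens.numel() == self.current_batch_size
        self.seq_lengths[: self.current_batch_size] += 1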

2 changes: 1 addition & 1 deletion colossalai/inference/config.py
@@ -109,7 +109,7 @@ def _verify_config(self) -> None:
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"

# check distributed
assert (
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
# check prompt template
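Note: the relaxed assertion above now accepts a single-process run (no initialized process group and tp_size * pp_size == 1) without querying the world size. A standalone restatement of the same check, with a made-up helper name:

import torch.distributed as dist


def verify_parallel_sizes(tp_size: int, pp_size: int) -> None:
    # Single-process case: no process group initialized and no parallelism requested.
    if not dist.is_initialized() and tp_size * pp_size == 1:
        return
    # Distributed case: the parallel degrees must cover the whole process group.
    assert tp_size * pp_size == dist.get_world_size(), (
        f"TP size({tp_size}) * PP size({pp_size}) should be equal to "
        f"the global world size ({dist.get_world_size()})"
    )


verify_parallel_sizes(1, 1)  # passes without torch.distributed being initialized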
10 changes: 1 addition & 9 deletions colossalai/inference/core/engine.py
@@ -42,7 +42,7 @@ class InferenceEngine:
def __init__(
self,
model: nn.Module,
tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
@@ -254,20 +254,12 @@ def add_request(
else:
prompt = prompts[i]

max_blocks_per_sequence = (
self.inference_config.max_input_len
+ self.inference_config.max_output_len
+ self.inference_config.block_size
- 1
) // self.inference_config.block_size
block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
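Note: the deleted block above computed each sequence's block-table capacity with a ceiling division and allocated the block table inside the engine; after this refactor the engine no longer builds per-sequence block tables (allocation moves behind the cache manager and batch bucket, as seen in request_handler.py below). The removed formula, restated on its own:

def max_blocks_per_sequence(max_input_len: int, max_output_len: int, block_size: int) -> int:
    # Ceiling division: enough KV-cache blocks to hold the longest possible
    # sequence, i.e. prompt tokens plus generated tokens.
    return (max_input_len + max_output_len + block_size - 1) // block_size


assert max_blocks_per_sequence(256, 128, 16) == 24  # 384 tokens at 16 tokens per block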
200 changes: 117 additions & 83 deletions colossalai/inference/core/request_handler.py
@@ -1,15 +1,16 @@
from typing import List
from typing import Dict, List, Union

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger

__all__ = ["RunningList", "RequestHandler"]
@@ -24,45 +25,79 @@ class RunningList:
Args:
prefill_ratio: (float) A ratio for determing whether to perform prefill or not.
prefill: (List) List that contains default inputs, defaults to [].
_prefill (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
_decoding (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
"""

def __init__(self, prefill_ratio: str, prefill: List[Sequence] = None):
def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None:
self.prefill_ratio = prefill_ratio
self.decoding: List[Sequence] = []
self.prefill: List[Sequence] = prefill if prefill is not None else []
self._decoding: Dict[int, Sequence] = dict()
self._prefill: Dict[int, Sequence] = (
dict({seq.request_id: seq for seq in self._prefill}) if prefill is not None else dict()
)

def append(self, seq: Sequence):
# add seq to prefilling list first.
self.prefill.append(seq)

def find_seq(self, request_id):
for seq in self.decoding:
if request_id == seq.request_id:
return seq
for seq in self.prefill:
if request_id == seq.request_id:
return seq
return None
@property
def decoding(self):
return list(self._decoding.values())

@property
def prefill(self):
return list(self._prefill.values())

@property
def prefill_seq_num(self):
return len(self._prefill)

@property
def decoding_seq_num(self):
return len(self._decoding)

@property
def total_seq_num(self):
return self.prefill_seq_num + self.decoding_seq_num

def remove(self, seq: Sequence):
if seq in self.decoding:
self.decoding.remove(seq)
elif seq in self.prefill:
self.prefill.remove(seq)
def append(self, seq: Sequence):
assert (seq.request_id not in self._prefill) and (
seq.request_id not in self._decoding
), f"Sequence uid {seq.request_id} already exists."
self._prefill[seq.request_id] = seq

def extend(self, seqs: List[Sequence]):
for seq in seqs:
self._prefill[seq.request_id] = seq

def find_seq(self, request_id) -> Union[Sequence, None]:
seq = None
if request_id in self._decoding:
seq = self._decoding[request_id]
elif request_id in self._prefill:
seq = self._prefill[request_id]
return seq

def remove(self, seq: Sequence) -> None:
if seq.request_id in self._decoding:
self._decoding.pop(seq.request_id)
elif seq.request_id in self._prefill:
self._prefill.pop(seq.request_id)
else:
raise ValueError(f"sequence {seq.request_id} is not in running list")
raise ValueError(f"Sequence {seq.request_id} is not in running list")

def ready_for_prefill(self):
if not self.decoding:
return len(self.prefill) > 0
return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
if not self._decoding:
return len(self._prefill) > 0
return len(self._prefill) / len(self._decoding) >= self.prefill_ratio

def is_empty(self):
return not self.decoding and not self.prefill
return not self._decoding and not self._prefill

def total_seq_num(self):
return len(self.decoding) + len(self.prefill)
def mark_prefill_running(self) -> None:
for seq_id in self._prefill:
self._prefill[seq_id].mark_running()

def move_prefill_to_decoding(self, seq_ids: List[int]) -> None:
for seq_id in seq_ids:
assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list"
self._decoding[seq_id] = self._prefill.pop(seq_id)


class RequestHandler:
@@ -110,25 +145,27 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo

# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
device=device,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
device=device,
)

def _init_cache(self, model_config):
@@ -159,40 +196,39 @@ def schedule(self):
remove_list.append(seq)
break

# stop feeding new sequence into running list to assure
if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
break
num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
remove_list.extend(lst[:num_seqs_to_add])
self.running_list.extend(lst[:num_seqs_to_add])

# Try to allocate cache blocks for the sequence.
if (
self.cache_manager.check_allocation(seq)
and (len(self.running_list.prefill) + len(self.running_list.decoding))
< self.max_batch_size # There some bugs in continous batching, so we disable it here.
):
# If succeed, add the sequence to running list.
remove_list.append(seq)
self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
for seq in remove_list:
lst.remove(seq)

if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
self.prefill_batch.add_seqs(self.running_list.prefill)
return self.prefill_batch
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)

if not self.running_batch.is_empty:
for seq in self.running_batch.sequences_set:
recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
if recycle:
for seq in self.running_list.prefill[:num_seqs_to_add]:
seq.mark_running()
# allocate blocks for the prefill batch
self.prefill_bb.add_seqs(
self.running_list.prefill[:num_seqs_to_add],
alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
)

return self.prefill_bb

if not self.running_bb.is_empty:
seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables(
self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size
)
if seqs_ids_to_recycle:
seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle)
for seq in seqs_to_recycle:
seq.recycle()
self.running_batch.del_seq(seq)
self.running_list.remove(seq)
self.waiting_list[-1].append(seq)
# the recycled sequences are handled with highest priority.

return self.running_batch
return self.running_bb
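Note: the rewritten schedule() above admits work under two clamps: at most max_batch_size - total_seq_num sequences move from the waiting list into the running list, and at most available_batch_size of the ready prefill sequences are packed into the prefill bucket. Restated with plain integers (helper names are illustrative only):

def admit_from_waiting(num_waiting: int, num_running: int, max_batch_size: int) -> int:
    # Step 1: how many waiting sequences may join the running list this step.
    return min(num_waiting, max_batch_size - num_running)


def pack_into_prefill(num_prefill_ready: int, free_bucket_slots: int) -> int:
    # Step 2: how many of the ready prefill sequences fit into the prefill bucket.
    return min(num_prefill_ready, free_bucket_slots)


# e.g. 5 waiting, 6 already running, max batch size 8, 3 free prefill slots:
assert admit_from_waiting(5, 6, 8) == 2
assert pack_into_prefill(2, 3) == 2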

def add_sequence(self, req: Sequence):
"""
Expand All @@ -213,7 +249,7 @@ def abort_sequence(self, request_id: str):
seq.mark_aborted()
self.waiting_list[priority].remove(seq)
elif seq.status.is_running():
self.cache_manager.free_block_table(seq.block_table)
self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
self.running_list.remove(seq)
else:
try:
@@ -242,7 +278,7 @@ def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)

return sample_tokens

@@ -273,27 +309,25 @@ def search_tokens(self, generation_config: GenerationConfig, logits):

# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
if not self.prefill_batch.is_empty:
self.prefill_batch.update_batch_tokens(sample_tokens)
if not self.prefill_bb.is_empty:
self.prefill_bb.append_batch_tokens(sample_tokens)
else:
self.running_batch.update_batch_tokens(sample_tokens)
self.running_bb.append_batch_tokens(sample_tokens)

def update(self):
"""
Update current running list and done list
"""
if not self.prefill_batch.is_empty:
self.running_list.decoding.extend(self.running_list.prefill)
self.running_batch.add_seqs(self.running_list.prefill)
self.running_list.prefill.clear()
self.prefill_batch.clear_batch()

finish_seqs = self.running_batch.fliter_batch()

for seq in finish_seqs:
if not self.prefill_bb.is_empty:
self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids)
self.running_bb.merge(self.prefill_bb)
# clear the prefill batch without assigning a free_block_tables_fn
# since we want to reuse the memory recorded on the block tables
self.prefill_bb.clear(free_block_tables_fn=None)

finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
for seq in finished_seqs:
self.running_list.remove(seq)
self.cache_manager.free_block_table(seq.block_table)

self.done_list.extend(finish_seqs)
self.done_list.extend(finished_seqs)

return finish_seqs
return finished_seqs
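Note: the hand-off in update() is order-sensitive: the prefill sequences' uids are first moved to the decoding side of the running list, the prefill bucket is merged into the running bucket, and only then is the prefill bucket cleared without a free_block_tables_fn, because its cache blocks now belong to the running bucket. A toy model of that ownership transfer (plain dicts and sets, purely illustrative):

# request_id -> allocated KV-cache block ids, plus a pool of free block ids
free_blocks = {4, 5, 6, 7}
prefill_rows = {0: [0, 1], 1: [2, 3]}
running_rows = {}

# merge: the running bucket takes over the rows and, implicitly, their blocks
running_rows.update(prefill_rows)

# clear(free_block_tables_fn=None): drop the prefill rows but free nothing
prefill_rows.clear()
assert free_blocks == {4, 5, 6, 7}

# blocks are only recycled when a sequence finishes and is popped
free_blocks.update(running_rows.pop(0))
assert free_blocks == {0, 1, 4, 5, 6, 7}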