[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)
* [fix] Fix GQA handling in the flash-decoding Triton kernel call

* Fix KV cache allocation shape

* Fix the rotary embedding Triton kernel for GQA

* Fix sequence max length assignment

* Rework sequence max length logic

* Fix scheduling and speculative decoding

* Skip tests when required imports are unavailable

* Fix pytest: skip on ImportError instead of failing

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yuanheng-zhao and pre-commit-ci[bot] authored Apr 23, 2024
1 parent ccf7279 commit 5d4c1fe
Showing 9 changed files with 188 additions and 199 deletions.
1 change: 1 addition & 0 deletions colossalai/inference/batch_bucket.py
@@ -386,6 +386,7 @@ def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
                 seq_id, seq = next(seqs_iter)
                 assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
                 seq.output_token_id = seq.output_token_id[:-n_tokens]
+                seq.revoke_finished_status()
                 self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens

     def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
18 changes: 16 additions & 2 deletions colossalai/inference/core/engine.py
@@ -518,7 +518,13 @@ def generate(
         """
         with torch.inference_mode():
             if prompts is not None or prompts_token_ids is not None:
-                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
+                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+                self.add_request(
+                    request_ids=request_ids,
+                    prompts=prompts,
+                    prompts_token_ids=prompts_token_ids,
+                    **gen_config_dict,
+                )

             output_seqs_list = []
             total_tokens_list = []
@@ -573,6 +579,7 @@ def add_request(
         request_ids: List[int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        **kwargs,
     ) -> None:
         """
         Add requests.
@@ -629,6 +636,13 @@
             else:
                 prompt = prompts[i]

+            max_length = kwargs.get("max_length", None)
+            max_new_tokens = kwargs.get("max_new_tokens", None)
+            if max_length is None and max_new_tokens is None:
+                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
+            elif max_length is not None:
+                max_new_tokens = max_length - len(prompts_token_ids[i])
+
             sequence = Sequence(
                 request_id,
                 prompt,
@@ -637,7 +651,7 @@
                 None,
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
-                self.inference_config.max_output_len,
+                max_output_len=max_new_tokens,
             )
             self.request_handler.add_sequence(sequence)

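The hunks above forward per-request generation settings into `add_request` and derive a per-sequence output budget from them. The precedence can be summarized in a standalone sketch; `resolve_max_new_tokens` and its default arguments are illustrative names, not part of the codebase, and the engine-level defaults stand in for `generation_config.max_new_tokens` and `inference_config.max_output_len`.

```python
# Sketch of the max_output_len resolution added in add_request (names are
# illustrative; the real code reads kwargs forwarded from GenerationConfig).
from typing import Optional


def resolve_max_new_tokens(
    prompt_len: int,
    max_length: Optional[int] = None,              # budget for prompt + generated tokens
    max_new_tokens: Optional[int] = None,          # budget for generated tokens only
    default_max_new_tokens: Optional[int] = None,  # stands in for generation_config.max_new_tokens
    default_max_output_len: int = 256,             # stands in for inference_config.max_output_len
) -> int:
    if max_length is not None:
        # An explicit total length wins: subtract the prompt to get the output budget.
        return max_length - prompt_len
    if max_new_tokens is not None:
        return max_new_tokens
    # Neither was given per request: fall back to the engine-level defaults.
    return default_max_new_tokens or default_max_output_len


# A 12-token prompt with max_length=50 leaves room for 38 generated tokens.
assert resolve_max_new_tokens(prompt_len=12, max_length=50) == 38
assert resolve_max_new_tokens(prompt_len=12) == 256
```
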
9 changes: 5 additions & 4 deletions colossalai/inference/core/request_handler.py
@@ -314,10 +314,11 @@ def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):

     def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
         for seq in batch.seqs_li:
-            if (
-                seq.output_token_id[-1] == generation_config.eos_token_id
-                or seq.output_len >= generation_config.max_length
-            ):
+            max_length = generation_config.max_length
+            max_new_tokens = generation_config.max_new_tokens
+            if max_length is not None:
+                max_new_tokens = max_length - seq.input_len
+            if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
                 seq.mark_finished()

     def check_unfinished_seqs(self) -> bool:
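The same conversion appears here when deciding whether a sequence is finished. A small worked example, using stand-in objects and made-up numbers rather than the real `Sequence` and `GenerationConfig` classes:

```python
# Illustrative stand-ins only; the numbers are invented for the example.
from types import SimpleNamespace

gen_cfg = SimpleNamespace(max_length=50, max_new_tokens=None, eos_token_id=2)
seq = SimpleNamespace(input_len=10, output_len=40, output_token_id=[7, 7, 7])

# Same derivation as update_batch_finished: a 50-token total budget minus a
# 10-token prompt allows at most 40 generated tokens.
max_new_tokens = gen_cfg.max_new_tokens
if gen_cfg.max_length is not None:
    max_new_tokens = gen_cfg.max_length - seq.input_len

finished = seq.output_token_id[-1] == gen_cfg.eos_token_id or seq.output_len >= max_new_tokens
assert max_new_tokens == 40 and finished  # budget exhausted, so the sequence would be marked finished
```
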
21 changes: 13 additions & 8 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -38,7 +38,7 @@ class KVCacheManager:
     The block table after block allocation might be:
     | 0 | 1 | 2 | -1 | -1 | -1 |
     Then the logical blocks with id 0, 1, and 2, are allocated for this sequence,
-    and the physical caches, each with size of `block_size * head_num * head_size * elem_size` for a single layer,
+    and the physical caches, each with size of `block_size * kv_head_num * head_size * elem_size` for a single layer,
     corresponding to these blocks will be used to read/write KV Caches in kernels.
     For a batch of sequences, the block tables after allocation might be:
@@ -64,9 +64,12 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
-        self.head_num //= self.tp_size
+        assert (
+            self.kv_head_num % self.tp_size == 0
+        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
+        self.kv_head_num //= self.tp_size
         self.beam_width = config.beam_width
         self.max_batch_size = config.max_batch_size
         self.max_input_length = config.max_input_len
@@ -80,19 +83,21 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
-        # if verbose:
-        #     self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+        alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+        self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
         self._kv_caches = self._init_device_caches(alloc_shape)
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
             * self.num_layers
             * 2
             * self.num_blocks
             * self.block_size
-            * self.head_num
+            * self.kv_head_num
             * self.head_size
         )
+        self.logger.info(
+            f"Allocated {self.total_physical_cache_size_in_bytes / GIGABYTE:.2f} GB of KV cache on device {self.device}."
+        )
         # Logical cache blocks allocation
         self._available_blocks = self.num_blocks
         self._cache_blocks = tuple(self._init_logical_caches())
@@ -453,7 +458,7 @@ def _init_logical_caches(self):
         """
         assert self._kv_caches is not None and len(self._kv_caches[0]) > 0
         blocks = []
-        physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size
+        physical_block_size = self.elem_size_in_bytes * self.block_size * self.kv_head_num * self.head_size
         k_ptrs = [
             self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
         ]
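Sizing the physical cache by `kv_head_num` instead of `head_num` is what makes the allocation GQA-aware. A back-of-the-envelope sketch of the effect, using illustrative Llama-3-8B-like numbers (32 query heads, 8 KV heads, fp16) that are not taken from the diff:

```python
# Mirrors the total_physical_cache_size_in_bytes formula; config values are illustrative.
num_layers, head_num, kv_head_num, head_size = 32, 32, 8, 128
block_size, num_blocks, elem_size = 16, 2048, 2  # elem_size=2 assumes fp16/bf16


def kv_cache_bytes(heads: int) -> int:
    # The factor of 2 accounts for the separate K and V caches per layer.
    return elem_size * num_layers * 2 * num_blocks * block_size * heads * head_size


mha_bytes = kv_cache_bytes(head_num)     # sized by query heads (pre-fix behaviour)
gqa_bytes = kv_cache_bytes(kv_head_num)  # sized by KV heads (this commit)
print(f"{mha_bytes / 2**30:.1f} GiB vs {gqa_bytes / 2**30:.1f} GiB")  # 16.0 GiB vs 4.0 GiB
```

With 32 query heads sharing 8 KV heads, the GQA-sized allocation is 4x smaller for the same number of blocks.
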
7 changes: 4 additions & 3 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -447,9 +447,9 @@ def __init__(
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
         else:
-            self.q_proj_weight = attn_qproj_w
-            self.k_proj_weight = attn_kproj_w
-            self.v_proj_weight = attn_vproj_w
+            self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
+            self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
+            self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())

     @staticmethod
     def from_native_module(
@@ -638,6 +638,7 @@ def forward(
                 mid_output=fd_inter_tensor.mid_output,
                 mid_output_lse=fd_inter_tensor.mid_output_lse,
                 sm_scale=sm_scale,
+                kv_group_num=self.num_key_value_groups,
                 q_len=q_len,
             )

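Passing `kv_group_num` tells the flash-decoding kernel how many query heads share each KV head (for standard MHA the ratio is 1). A hedged sketch of that mapping, with illustrative shapes and none of the Triton specifics:

```python
# Sketch of the query-head -> KV-head mapping that kv_group_num encodes.
import torch

num_attention_heads, num_key_value_heads, head_size = 32, 8, 128
kv_group_num = num_attention_heads // num_key_value_heads  # 4 query heads per KV head

q = torch.randn(num_attention_heads, head_size)        # one token's query, per query head
k_cache = torch.randn(num_key_value_heads, head_size)  # one cached K vector per KV head

# Query head i reads from KV head i // kv_group_num.
kv_idx = torch.arange(num_attention_heads) // kv_group_num
scores = (q * k_cache[kv_idx]).sum(dim=-1) / head_size**0.5  # per-head attention logits
assert scores.shape == (num_attention_heads,)
```
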
8 changes: 8 additions & 0 deletions colossalai/inference/struct.py
@@ -117,6 +117,14 @@ def check_finish(self) -> bool:

         return False

+    def revoke_finished_status(self) -> None:
+        """
+        Revoke the finished status of the sequence.
+        This is only used by speculative decoding for now.
+        """
+        if RequestStatus.is_finished(self.status):
+            self.status = RequestStatus.RUNNING
+
     def __hash__(self):
         return hash(self.request_id)

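`revoke_finished_status` complements `BatchBucket.revoke_batch_tokens` above: when speculative decoding rejects drafted tokens, the trimmed sequence must also lose any finished flag those tokens set (for example, a drafted EOS). A minimal sketch of that rollback flow, using stand-in classes that only mirror the names in this diff:

```python
# Stand-in classes for illustration; not the real colossalai.inference.struct module.
from enum import Enum, auto


class RequestStatus(Enum):
    RUNNING = auto()
    COMPLETED = auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        return status == RequestStatus.COMPLETED


class ToySequence:
    def __init__(self, tokens):
        self.output_token_id = list(tokens)
        self.status = RequestStatus.RUNNING

    def revoke_finished_status(self) -> None:
        # Same body as the method added above.
        if RequestStatus.is_finished(self.status):
            self.status = RequestStatus.RUNNING


# The drafter speculated tokens ending in EOS, so the sequence got marked finished.
seq = ToySequence([5, 9, 2])
seq.status = RequestStatus.COMPLETED

# The verifier rejects the last 2 drafted tokens: trim them and un-finish the sequence.
n_rejected = 2
seq.output_token_id = seq.output_token_id[:-n_rejected]
seq.revoke_finished_status()
assert seq.output_token_id == [5] and seq.status == RequestStatus.RUNNING
```
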
(The remaining 3 changed files were not loaded in this view.)