
[Fix/Inference] Fix GQA Triton and Support Llama3 #5624

1 change: 1 addition & 0 deletions colossalai/inference/batch_bucket.py
@@ -386,6 +386,7 @@ def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
                 seq_id, seq = next(seqs_iter)
                 assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
                 seq.output_token_id = seq.output_token_id[:-n_tokens]
+                seq.revoke_finished_status()
                 self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
 
     def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
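Rolling back the finished flag together with the revoked tokens matters for speculative decoding: if a rejected draft token was EOS (or pushed the output past its budget), the sequence would already have been marked finished and would otherwise never resume. A self-contained stand-in that shows the intended behavior (the ToySequence class below is illustrative, not the real Sequence):

from enum import Enum, auto

class Status(Enum):
    RUNNING = auto()
    FINISHED = auto()

class ToySequence:
    """Minimal stand-in for Sequence, just enough to show why revoking tokens must also reset the status."""
    def __init__(self):
        self.status = Status.RUNNING
        self.output_token_id = []

    def mark_finished(self):
        self.status = Status.FINISHED

    def revoke_finished_status(self):
        if self.status == Status.FINISHED:
            self.status = Status.RUNNING

seq = ToySequence()
seq.output_token_id = [11, 22, 2]        # the last drafted token happened to be EOS
seq.mark_finished()

n_rejected = 2                           # verification rejected the last 2 draft tokens
seq.output_token_id = seq.output_token_id[:-n_rejected]
seq.revoke_finished_status()             # without this, the sequence would stay "finished"
assert seq.status == Status.RUNNING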
18 changes: 16 additions & 2 deletions colossalai/inference/core/engine.py
@@ -518,7 +518,13 @@ def generate(
         """
         with torch.inference_mode():
             if prompts is not None or prompts_token_ids is not None:
-                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
+                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+                self.add_request(
+                    request_ids=request_ids,
+                    prompts=prompts,
+                    prompts_token_ids=prompts_token_ids,
+                    **gen_config_dict,
+                )
 
             output_seqs_list = []
             total_tokens_list = []
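With this change, whatever is set on the GenerationConfig passed to generate() is expanded into keyword arguments for add_request(), so per-call options such as max_new_tokens or max_length actually reach request creation. A usage sketch under assumptions (the model path, config values, and the exact InferenceEngine constructor arguments are illustrative, not taken from this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

inference_config = InferenceConfig(max_batch_size=4, max_input_len=256, max_output_len=128)
engine = InferenceEngine(model, tokenizer, inference_config)

# GenerationConfig.to_dict() yields {"max_new_tokens": 64, ...}; the patched generate()
# forwards that dict as **kwargs to add_request(), which uses it to size each Sequence.
outputs = engine.generate(
    prompts=["Introduce some landmarks in Beijing"],
    generation_config=GenerationConfig(max_new_tokens=64),
)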
@@ -573,6 +579,7 @@ def add_request(
         request_ids: List[int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        **kwargs,
     ) -> None:
         """
         Add requests.
@@ -629,6 +636,13 @@ def add_request(
             else:
                 prompt = prompts[i]
 
+            max_length = kwargs.get("max_length", None)
+            max_new_tokens = kwargs.get("max_new_tokens", None)
+            if max_length is None and max_new_tokens is None:
+                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
+            elif max_length is not None:
+                max_new_tokens = max_length - len(prompts_token_ids[i])
+
             sequence = Sequence(
                 request_id,
                 prompt,
@@ -637,7 +651,7 @@
                 None,
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
-                self.inference_config.max_output_len,
+                max_output_len=max_new_tokens,
             )
             self.request_handler.add_sequence(sequence)
 
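Inside add_request(), the output budget of each sequence is now resolved from those kwargs instead of always using inference_config.max_output_len. The precedence can be summarized with a small standalone helper (the function name, the single collapsed default, and the numbers are illustrative only):

from typing import List, Optional

def resolve_max_new_tokens(
    prompt_token_ids: List[int],
    max_length: Optional[int] = None,
    max_new_tokens: Optional[int] = None,
    default_max_output_len: int = 256,
) -> int:
    # max_length, when given, is converted into an output-only budget and takes precedence;
    # otherwise an explicit max_new_tokens is used; otherwise fall back to the engine default.
    if max_length is None and max_new_tokens is None:
        return default_max_output_len
    if max_length is not None:
        return max_length - len(prompt_token_ids)
    return max_new_tokens

assert resolve_max_new_tokens([1, 2, 3]) == 256
assert resolve_max_new_tokens([1, 2, 3], max_length=10) == 7
assert resolve_max_new_tokens([1, 2, 3], max_new_tokens=32) == 32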
9 changes: 5 additions & 4 deletions colossalai/inference/core/request_handler.py
@@ -314,10 +314,11 @@ def update_seq_finished(self, sequence: Sequence, generation_config: GenerationC
 
     def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
         for seq in batch.seqs_li:
-            if (
-                seq.output_token_id[-1] == generation_config.eos_token_id
-                or seq.output_len >= generation_config.max_length
-            ):
+            max_length = generation_config.max_length
+            max_new_tokens = generation_config.max_new_tokens
+            if max_length is not None:
+                max_new_tokens = max_length - seq.input_len
+            if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
                 seq.mark_finished()
 
     def check_unfinished_seqs(self) -> bool:
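update_batch_finished previously compared output_len (generated tokens only) against generation_config.max_length (prompt plus generated tokens), so sequences with long prompts could be marked finished too early; the fix converts max_length into a per-sequence output budget first. The check in isolation (the helper and the toy numbers are illustrative, not part of the PR):

def is_finished(last_token_id, output_len, input_len, eos_token_id, max_length=None, max_new_tokens=None):
    # Convert a total-length cap into an output-only budget, as update_batch_finished now does.
    if max_length is not None:
        max_new_tokens = max_length - input_len
    return last_token_id == eos_token_id or output_len >= max_new_tokens

# With max_length=32 and a 20-token prompt, the sequence may still generate up to 12 tokens:
assert not is_finished(last_token_id=7, output_len=11, input_len=20, eos_token_id=2, max_length=32)
assert is_finished(last_token_id=7, output_len=12, input_len=20, eos_token_id=2, max_length=32)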
21 changes: 13 additions & 8 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -38,7 +38,7 @@ class KVCacheManager:
     The block table after block allocation might be:
     | 0 | 1 | 2 | -1 | -1 | -1 |
     Then the logical blocks with id 0, 1, and 2, are allocated for this sequence,
-    and the physical caches, each with size of `block_size * head_num * head_size * elem_size` for a single layer,
+    and the physical caches, each with size of `block_size * kv_head_num * head_size * elem_size` for a single layer,
     corresponding to these blocks will be used to read/write KV Caches in kernels.
 
     For a batch of sequences, the block tables after allocation might be:
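For orientation, a sequence's block table maps logical block slots to physical block ids (with -1 meaning unallocated), and each physical block now stores a (kv_head_num, block_size, head_size) slab per layer. A toy lookup consistent with the docstring example above (shapes are illustrative):

import torch

block_table = torch.tensor([0, 1, 2, -1, -1, -1])   # the docstring example
k_cache = torch.zeros(8, 8, 16, 128)                # (num_blocks, kv_head_num, block_size, head_size)

token_position = 35                                 # with block_size=16: logical block 2, slot 3
block_id = block_table[token_position // 16]
slot_in_block = token_position % 16
assert block_id.item() == 2
k_vectors = k_cache[block_id, :, slot_in_block, :]  # per-KV-head key vectors for this position: [8, 128]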
@@ -64,9 +64,12 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
-        self.head_num //= self.tp_size
+        assert (
+            self.kv_head_num % self.tp_size == 0
+        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
+        self.kv_head_num //= self.tp_size
         self.beam_width = config.beam_width
         self.max_batch_size = config.max_batch_size
         self.max_input_length = config.max_input_len
@@ -80,19 +83,21 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
 
         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
-        # if verbose:
-        #     self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+        alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+        self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
         self._kv_caches = self._init_device_caches(alloc_shape)
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
             * self.num_layers
             * 2
             * self.num_blocks
             * self.block_size
-            * self.head_num
+            * self.kv_head_num
             * self.head_size
         )
+        self.logger.info(
+            f"Allocated {self.total_physical_cache_size_in_bytes / GIGABYTE:.2f} GB of KV cache on device {self.device}."
+        )
         # Logical cache blocks allocation
         self._available_blocks = self.num_blocks
         self._cache_blocks = tuple(self._init_logical_caches())
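Sizing the cache by kv_head_num instead of head_num is the core GQA fix: grouped-query models only have num_key_value_heads K/V heads, so allocating num_attention_heads per block wastes memory and mismatches the Triton kernels. A back-of-the-envelope check using the published Llama-3-8B attention hyper-parameters (32 layers, 32 query heads, 8 KV heads, hidden size 4096); block_size, num_blocks, and fp16 are illustrative choices:

# Llama-3-8B-style settings (num_blocks and block_size are illustrative).
num_layers = 32
head_num = 32          # num_attention_heads
kv_head_num = 8        # num_key_value_heads (GQA)
head_size = 4096 // head_num
block_size = 16
num_blocks = 2048
elem_size = 2          # fp16

def kv_cache_bytes(heads: int) -> int:
    # K and V caches across all layers: 2 * layers * blocks * block_size * heads * head_size * elem_size
    return 2 * num_layers * num_blocks * block_size * heads * head_size * elem_size

print(kv_cache_bytes(head_num) / 2**30)     # 16.0 GiB if sized by attention heads
print(kv_cache_bytes(kv_head_num) / 2**30)  # 4.0 GiB when sized by KV heads: a 4x saving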
@@ -453,7 +458,7 @@ def _init_logical_caches(self):
         """
         assert self._kv_caches is not None and len(self._kv_caches[0]) > 0
         blocks = []
-        physical_block_size = self.elem_size_in_bytes * self.block_size * self.head_num * self.head_size
+        physical_block_size = self.elem_size_in_bytes * self.block_size * self.kv_head_num * self.head_size
         k_ptrs = [
             self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers)
         ]
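With the cache tensor laid out as (num_blocks, kv_head_num, block_size, head_size), each logical block owns a contiguous slab of elem_size * kv_head_num * block_size * head_size bytes, so a block's physical address follows from pointer arithmetic on the base tensor. A hedged sketch of that relationship in plain PyTorch, independent of the manager class (shapes are illustrative):

import torch

num_blocks, kv_head_num, block_size, head_size = 8, 8, 16, 128
k_cache = torch.zeros(num_blocks, kv_head_num, block_size, head_size, dtype=torch.float16)

physical_block_size = k_cache.element_size() * kv_head_num * block_size * head_size
for block_id in range(num_blocks):
    # Each logical block's slab starts exactly one block-stride after the previous one.
    assert k_cache[block_id].data_ptr() == k_cache.data_ptr() + block_id * physical_block_size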
7 changes: 4 additions & 3 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -447,9 +447,9 @@ def __init__(
                 attn_qproj_w.dist_layout
             ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
         else:
-            self.q_proj_weight = attn_qproj_w
-            self.k_proj_weight = attn_kproj_w
-            self.v_proj_weight = attn_vproj_w
+            self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
+            self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
+            self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
 
     @staticmethod
     def from_native_module(
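In the non-sharded path the q/k/v weights are now stored transposed, i.e. with shape [hidden_size, out_features], and wrapped in nn.Parameter, so the projection can be applied to activations directly without transposing on every forward call. A small sketch of that layout assumption (the shapes are illustrative and the module's forward code is not reproduced here):

import torch
import torch.nn as nn

hidden_size, kv_out = 4096, 1024          # e.g. 8 KV heads * 128 head_size
linear = nn.Linear(hidden_size, kv_out, bias=False)

# Store the weight as [hidden_size, kv_out] so the projection is a plain mm.
k_proj_weight = nn.Parameter(linear.weight.data.transpose(0, 1).contiguous())

hidden_states = torch.randn(8, hidden_size)
out = torch.mm(hidden_states, k_proj_weight)               # [8, kv_out]
ref = nn.functional.linear(hidden_states, linear.weight)   # same result from the un-transposed weight
assert torch.allclose(out, ref, atol=1e-4)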
@@ -638,6 +638,7 @@ def forward(
                 mid_output=fd_inter_tensor.mid_output,
                 mid_output_lse=fd_inter_tensor.mid_output_lse,
                 sm_scale=sm_scale,
+                kv_group_num=self.num_key_value_groups,
                 q_len=q_len,
             )
 
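Passing kv_group_num tells the flash-decoding kernel how many query heads share one KV head (num_key_value_groups = num_attention_heads // num_key_value_heads), which is what makes the Triton path correct for GQA models such as Llama-3. A minimal reference sketch of the head-mapping idea in plain PyTorch, not the Triton kernel itself (shapes are illustrative):

import torch

num_q_heads, num_kv_heads, head_size = 32, 8, 128
kv_group_num = num_q_heads // num_kv_heads   # 4 query heads share each KV head

q = torch.randn(num_q_heads, head_size)               # one decoding token, all query heads
k_cache = torch.randn(num_kv_heads, 17, head_size)    # 17 cached positions per KV head

# Query head h attends against the KV head it is grouped with: h // kv_group_num.
scores = torch.stack(
    [q[h] @ k_cache[h // kv_group_num].T for h in range(num_q_heads)]
)  # [num_q_heads, 17]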
8 changes: 8 additions & 0 deletions colossalai/inference/struct.py
@@ -117,6 +117,14 @@ def check_finish(self) -> bool:
 
         return False
 
+    def revoke_finished_status(self) -> None:
+        """
+        Revoke the finished status of the sequence.
+        This is only used by speculative decoding for now.
+        """
+        if RequestStatus.is_finished(self.status):
+            self.status = RequestStatus.RUNNING
+
     def __hash__(self):
         return hash(self.request_id)
 