[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… #5663

Merged
23 changes: 17 additions & 6 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -90,9 +90,18 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

# Physical cache allocation
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches(alloc_shape)
if config.use_cuda_kernel:
x = 16 // torch.tensor([], dtype=config.dtype).element_size()
kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
self.logger.info(
f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
)
self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape)
else:
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape)
self.total_physical_cache_size_in_bytes = (
self.elem_size_in_bytes
* self.num_layers
@@ -479,7 +488,9 @@ def _init_logical_caches(self):
blocks.append(cache_block)
return blocks

def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
def _init_device_caches(
self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initialize the physical cache on the device.

For each layer of the model, we allocate two tensors for key and value respectively,
@@ -488,6 +499,6 @@ def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tenso
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
return k_cache, v_cache
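
Note on the new K-cache layout above: on the CUDA-kernel path the key cache is now allocated as (num_blocks, kv_head_num, head_size // x, block_size, x), where x is the number of elements that fit into 16 bytes of the cache dtype, while the value cache keeps the original 4-D shape. Below is a minimal sketch (not code from this PR) of how the two layouts relate; the permute-based mapping is an assumption about the correspondence, while the shapes themselves are taken from the diff.

```python
import torch

# Hedged sketch: relate the old 4-D K-cache layout to the new 5-D layout used by the CUDA kernels.
dtype = torch.float16
num_blocks, kv_head_num, block_size, head_size = 4, 2, 16, 128

# x = number of elements per 16-byte segment: 8 for fp16/bf16, 4 for fp32.
x = 16 // torch.tensor([], dtype=dtype).element_size()

k_old = torch.randn(num_blocks, kv_head_num, block_size, head_size, dtype=dtype)

# Assumed mapping: split head_size into (head_size // x, x) chunks and swap the chunk axis
# with block_size, so the x elements of one chunk for a token sit in one contiguous segment.
k_new = k_old.reshape(num_blocks, kv_head_num, block_size, head_size // x, x).permute(0, 1, 3, 2, 4).contiguous()
assert k_new.shape == (num_blocks, kv_head_num, head_size // x, block_size, x)

# Round-trip back to the old layout: both views hold the same values.
k_back = k_new.permute(0, 1, 3, 2, 4).reshape(num_blocks, kv_head_num, block_size, head_size)
assert torch.equal(k_back, k_old)
```
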
46 changes: 31 additions & 15 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -230,6 +230,7 @@ def forward(
alibi_slopes=self.alibi_slopes,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
@@ -252,6 +253,21 @@ def forward(
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
self.alibi_slopes,
sm_scale,
)
attn_output = output_tensor
else:
if not is_verifier and not self.use_alibi_attn:
decoding_fused_rotary_embedding(
@@ -275,21 +291,21 @@ def forward(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)

attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=self.alibi_slopes,
sm_scale=sm_scale,
q_len=q_len,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=self.alibi_slopes,
sm_scale=sm_scale,
q_len=q_len,
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
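
Note: the Baichuan decode path above now forwards self.alibi_slopes directly into the CUDA flash_decoding_attention op (the LLaMA path passes None instead). For context, here is a hedged sketch of the standard ALiBi slope computation from Press et al.; the helper name is illustrative and not necessarily the one this repository uses.

```python
import math

import torch


def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """Standard ALiBi slopes: for 2^k heads, slope_i = (2 ** (-8 / num_heads)) ** (i + 1)."""
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest_power_of_2)
    slopes = torch.pow(base, torch.arange(1, closest_power_of_2 + 1, device=device, dtype=torch.float32))
    if closest_power_of_2 != num_heads:
        # Non-power-of-two head counts take the odd-indexed slopes of the 2x grid (Press et al., 2022).
        extra_base = 2.0 ** (-4.0 / closest_power_of_2)
        num_extra = num_heads - closest_power_of_2
        extra = torch.pow(extra_base, torch.arange(1, 2 * num_extra + 1, 2, device=device, dtype=torch.float32))
        slopes = torch.cat([slopes, extra], dim=0)
    return slopes


# e.g. 8 heads -> [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```
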
67 changes: 31 additions & 36 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -98,15 +98,8 @@ def llama_model_forward(
"""
block_tables = inputmetadata.block_tables
sequence_lengths = inputmetadata.sequence_lengths
batch_size = inputmetadata.batch_size
kv_seq_len = inputmetadata.kv_seq_len

# NOTE: After testing, the performance of this configuration is relatively good. With updates
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
# selection should be conducted.
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False

# NOTE (yuanheng-zhao): for now, only triton kernels support verification process
# during speculative-decoding (`q_len > 1`)
# We will explicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
@@ -575,6 +568,7 @@ def forward(
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
@@ -592,20 +586,21 @@ def forward(
block_tables,
high_precision,
)
# inference_ops.flash_decoding_attention(
# output_tensor,
# query_states,
# k_cache,
# v_cache,
# sequence_lengths,
# block_tables,
# block_size,
# kv_seq_len,
# fd_inter_tensor.mid_output,
# fd_inter_tensor.mid_output_lse,
# sm_scale,
# )
# attn_output = output_tensor
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
None,
sm_scale,
)
attn_output = output_tensor
else:
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
@@ -627,21 +622,21 @@ def forward(
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
kv_group_num=self.num_key_value_groups,
q_len=q_len,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
kv_group_num=self.num_key_value_groups,
q_len=q_len,
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)
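
Note: after this change both the LLaMA and Baichuan decode paths call the CUDA flash_decoding_attention op with positional arguments (writing into the preallocated output_tensor) and fall back to the Triton kernel otherwise. The sketch below juxtaposes the two entry points, using only the argument orders visible in the diffs above; the wrapper itself is illustrative and not part of this PR.

```python
# Illustrative only: how the CUDA op and the Triton kernel line up after this change.
# Assumes the CUDA extension is loaded the same way as in the modeling files.
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import flash_decoding_attention

inference_ops = InferenceOpsLoader().load()


def decode_attention(use_cuda_kernel, query_states, k_cache, v_cache, sequence_lengths,
                     block_tables, block_size, kv_seq_len, fd_inter_tensor, output_tensor,
                     sm_scale, alibi_slopes=None, kv_group_num=1, q_len=1):
    if use_cuda_kernel:
        # The CUDA op writes into output_tensor; alibi_slopes is None for LLaMA and a
        # per-head slope tensor for Baichuan's ALiBi attention.
        inference_ops.flash_decoding_attention(
            output_tensor, query_states, k_cache, v_cache, sequence_lengths, block_tables,
            block_size, kv_seq_len, fd_inter_tensor.mid_output, fd_inter_tensor.mid_output_lse,
            alibi_slopes, sm_scale,
        )
        return output_tensor
    # The Triton kernel takes keyword arguments and returns the attention output.
    return flash_decoding_attention(
        q=query_states, k_cache=k_cache, v_cache=v_cache, kv_seq_len=sequence_lengths,
        block_tables=block_tables, block_size=block_size, max_seq_len_in_batch=kv_seq_len,
        output=output_tensor, mid_output=fd_inter_tensor.mid_output,
        mid_output_lse=fd_inter_tensor.mid_output_lse, alibi_slopes=alibi_slopes,
        sm_scale=sm_scale, kv_group_num=kv_group_num, q_len=q_len,
    )
```
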
Changes to a benchmark script for the flash decoding attention op (file name not shown in this view)
@@ -20,7 +20,7 @@
configs = [
triton.testing.Benchmark(
x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
x_vals=[2**i for i in range(3, 8)],
x_vals=[2**i for i in range(2, 8)],
line_arg="provider",
line_vals=[
"vllm_paged_decoding_attention",
@@ -113,6 +113,8 @@ def benchmark_flash_decoding_attention(
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
alibi_slopes = None
kv_scale = 1.0

mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
@@ -136,6 +138,7 @@ def benchmark_flash_decoding_attention(
max_seq_len_across_batch,
alibi_slopes,
"auto",
kv_scale,
)
elif provider == "triton_flash_decoding_attention":
fn = lambda: flash_decoding_attention(
@@ -164,6 +167,7 @@ def benchmark_flash_decoding_attention(
max_seq_len_across_batch,
mid_output,
mid_output_lse,
alibi_slopes,
sm_scale,
)
else:
Changes to a benchmark script for fused rotary embedding and KV-cache copy (file name not shown in this view)
@@ -2,7 +2,11 @@

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
from tests.test_infer.test_ops.triton.kernel_utils import (
mock_alloc_block_table_and_kvcache_v2,
mock_alloc_block_table_and_kvcache_v3,
mock_alloc_single_token,
)

inference_ops = InferenceOpsLoader().load()

@@ -68,11 +72,17 @@ def benchmark_rotary_emb(
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
x = 16 // torch.tensor([], dtype=dtype).element_size()
new_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
new_k_cache = torch.zeros(size=new_cache_shape, dtype=dtype, device="cuda")

past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v2(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
_ = mock_alloc_block_table_and_kvcache_v3(
k, v, new_k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
new_q = torch.randn_like(new_k)
new_v = torch.randn_like(new_k)
@@ -94,12 +104,12 @@ def benchmark_rotary_emb(
)
elif provider == "no_fused_cuda_rotary_emb_func":
fn = lambda: [
inference_ops.rotary_embedding(new_q, new_k, cos, sin),
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
]
elif provider == "fused_cuda_rotary_emb_func":
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
)
else:
raise ValueError("Undefined provider")
Changes to a benchmark script for KV-cache copy (file name not shown in this view)
@@ -4,6 +4,7 @@
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data

try:
@@ -68,6 +69,9 @@ def benchmark_kvcache_copy(
elif provider == "triton_copy_func":
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
elif provider == "cuda_copy_func":
_, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype
)
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)