Skip to content

Commit

Permalink
[Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision…
Browse files Browse the repository at this point in the history
… Flag To Rotary Embedding (#5461)

* Support FP16/BF16 Flash Attention 2

* fix bugs in test_kv_cache_memcpy.py

* add context_kv_cache_memcpy_kernel.cu

* rm typename MT

* add tail process

* add high_precision

* add high_precision to config.py

* rm unused code

* change the comment for the high_precision parameter

* update test_rotary_embdding_unpad.py

* fix vector_copy_utils.h

* add comment for self.high_precision when using float32
  • Loading branch information
isky-cd authored Mar 25, 2024
1 parent 7ff42cc commit 87079cf
Show file tree
Hide file tree
Showing 15 changed files with 549 additions and 137 deletions.
7 changes: 6 additions & 1 deletion colossalai/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class InferenceConfig:
pp_size (int): Pipeline parallel size, defaults to 1.
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
high_precision (Optional[bool]): Whether to use float32 for the underlying calculations on half-precision (float16/bfloat16) data to achieve higher precision, defaults to False. It is forced to False when `dtype` is float32.
"""

# NOTE: arrange configs according to their importance and frequency of usage
Expand Down Expand Up @@ -89,6 +89,7 @@ class InferenceConfig:
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False

def __post_init__(self):
self._verify_config()
Expand All @@ -108,6 +109,10 @@ def _verify_config(self) -> None:
self.dtype in _ALLOWED_DTYPES
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"

# no up-casting is needed when the data type is already float32, so disable high_precision
if self.dtype == torch.float32:
self.high_precision = False

# check distributed
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
Expand Down
2 changes: 2 additions & 0 deletions colossalai/inference/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = inference_config.to_generation_config(self.model_config)
self.high_precision = inference_config.high_precision
model = model.eval()
model = model.cuda()
model.to(self.dtype)
Expand Down Expand Up @@ -297,6 +298,7 @@ def step(self) -> List[str]:
batch,
self.k_cahce,
self.v_cache,
self.high_precision,
)

if self.inference_config.pad_input:
Expand Down
176 changes: 110 additions & 66 deletions colossalai/inference/modeling/models/nopadding_llama.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion examples/inference/benchmark_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def benchmark_inference(args):

data = data_gen(mbsz, args.seq_len)

data = data.tolist()
if args.mode == "colossalai" or args.mode == "vllm":
data = data.tolist()

generation_config = GenerationConfig(
pad_token_id=tokenizer.pad_token_id,
Expand Down
17 changes: 17 additions & 0 deletions extensions/csrc/common/micros.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}

// Two-level dispatch: first on a runtime boolean, then on the tensor dtype.
//
// HIGH_PRECISION is evaluated once in a boolean context (any non-zero value
// selects the high-precision branch). Inside each branch a compile-time
// constant named `high_precision` is in scope, so the statement passed through
// __VA_ARGS__ can use it as a template argument. The dtype dispatch itself is
// delegated to DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...).
//
// A plain if/else is used instead of `switch` on a bool: switching on a
// boolean triggers -Wswitch-bool and needs a `default` label that a bool
// argument can never reach.
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
                                                           TYPE, NAME, ...) \
  if (HIGH_PRECISION) {                                                     \
    constexpr bool high_precision = true;                                   \
    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);                \
  } else {                                                                  \
    constexpr bool high_precision = false;                                  \
    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);                \
  }

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
Expand Down
13 changes: 13 additions & 0 deletions extensions/csrc/common/mp_type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,18 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float;
};

template <bool high_precision, typename scalar_t>
struct ScalarTypeTrait;

template <typename T>
struct ScalarTypeTrait<true, T> {
using Type = typename MPTypeTrait<T>::Type;
};

template <typename T>
struct ScalarTypeTrait<false, T> {
using Type = T;
};

} // namespace common
} // namespace colossalAI
195 changes: 195 additions & 0 deletions extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "utils/vector_copy_utils.h"
#include "../common/micros.h"

// Copies the key/value vectors of one token into the block-structured KV cache.
//
// Indexing used by this kernel:
//   * blockIdx.x is the token position inside its sequence and blockIdx.y is
//     the sequence index.
//   * The threads of a block walk over the head_num * head_dim elements of
//     that token, VecSize elements at a time.
//   * The source row is (cu_seqlens[seq] + token) with row strides key_stride /
//     value_stride (in elements).
//   * The destination offset corresponds to a contiguous cache laid out as
//     [num_blocks, head_num, block_size, head_dim]; the cache block is looked
//     up in block_tables[seq, token / block_size].
//
// Preconditions for VecSize > 1: every vectorized access must be suitably
// aligned, and head_dim should be a multiple of VecSize so a vector never
// straddles two heads (the host launcher falls back to VecSize == 1 otherwise).
// NOTE(review): copy_vector comes from utils/vector_copy_utils.h and is assumed
// to copy VecSize contiguous elements from its second to its first argument.
//
// `batch_size` is part of the signature but is not read by the kernel.
template<typename scalar_t, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
    const scalar_t* __restrict__ key,
    const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache,
    scalar_t* __restrict__ value_cache,
    const int* __restrict__ sequence_lengths,
    const int* __restrict__ cu_seqlens,
    const int* __restrict__ block_tables,
    const int head_num,
    const int head_dim,
    const int block_size,
    const int batch_size,
    const int block_table_stride,
    const int64_t key_stride,
    const int64_t value_stride
)
{
    const int seq_token_id = blockIdx.x;
    const int seq_id = blockIdx.y;

    // Check the sequence length first so that positions past the end of a
    // sequence never touch the block table.
    if (seq_token_id >= sequence_lengths[seq_id]) {
        return;
    }

    const int block_id =
        block_tables[static_cast<int64_t>(seq_id) * block_table_stride + seq_token_id / block_size];

    // A negative entry means there is no cache block to write to.
    if (block_id < 0) {
        return;
    }

    const int block_offset = seq_token_id % block_size;
    const int hidden_size = head_num * head_dim;
    const int64_t total_token_id = static_cast<int64_t>(cu_seqlens[seq_id]) + seq_token_id;

    // Loop-invariant offsets. Everything is computed in 64-bit:
    // block_id * hidden_size * block_size exceeds the int range for large caches.
    const int64_t key_row = total_token_id * key_stride;
    const int64_t value_row = total_token_id * value_stride;
    const int64_t head_span = static_cast<int64_t>(block_size) * head_dim;
    const int64_t cache_row = static_cast<int64_t>(block_id) * hidden_size * block_size
                            + static_cast<int64_t>(block_offset) * head_dim;

    int i = threadIdx.x * VecSize;

    // Main loop: whole vectors of VecSize elements.
    for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
        const int head_id = i / head_dim;
        const int head_offset = i % head_dim;
        const int64_t target_id = cache_row + head_id * head_span + head_offset;

        copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_row + i);
        copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_row + i);
    }

    // Tail process: element-wise copy of whatever does not fill a whole vector.
    // Only the thread whose index stopped inside the last partial vector
    // enters this loop.
    for (; i < hidden_size; ++i) {
        const int head_id = i / head_dim;
        const int head_offset = i % head_dim;
        const int64_t target_id = cache_row + head_id * head_span + head_offset;

        key_cache[target_id] = key[key_row + i];
        value_cache[target_id] = value[value_row + i];
    }
}

// Host-side launcher for context_kv_cache_memcpy_kernel.
//
// Launch configuration: grid = (max_seq_len_in_batch, batch_size), one block
// per (token position, sequence) pair, with min(head_num * head_dim / vec_size,
// 512) threads per block, on the current CUDA stream. No dynamic shared memory.
//
// The vector width is chosen with get_vec_size (utils/vector_copy_utils.h) and
// dropped to 1 whenever a vectorized access could be misaligned or straddle a
// head. An empty problem (no tokens to place, empty batch, or zero hidden size)
// is a no-op instead of an invalid zero-sized launch.
template<typename scalar_t>
void apply_context_kv_cache_memcpy(
    at::Tensor& key,                // [num_tokens, head_num, head_dim]
    at::Tensor& value,              // [num_tokens, head_num, head_dim]
    at::Tensor& key_cache,          // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& value_cache,        // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& sequence_lengths,   // [batch_size]
    at::Tensor& cu_seqlens,         // [batch_size + 1]
    at::Tensor& block_tables,       // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
    const int head_num = key.size(1);
    const int head_dim = key.size(2);
    const int block_size = key_cache.size(2);
    const int batch_size = block_tables.size(0);

    const int64_t key_stride = key.stride(0);
    const int64_t value_stride = value.stride(0);
    const int block_table_stride = block_tables.stride(0);

    const int hidden_size = head_num * head_dim;

    // A zero-sized grid or block is rejected by the CUDA runtime; there is
    // simply nothing to copy in that case.
    if (hidden_size <= 0 || batch_size <= 0 || max_seq_len_in_batch <= 0) {
        return;
    }

    // Both source tensors are read with the same vector width, so take the
    // narrower of the two.
    int vec_size = std::min(get_vec_size<scalar_t>(key), get_vec_size<scalar_t>(value));

    if (head_dim % vec_size != 0 || key_stride % vec_size != 0 || value_stride % vec_size != 0) {
        // Disable vectorized loading optimization when head_dim or a row stride
        // is not divisible by VecSize: a vector would otherwise cross a head
        // boundary or start at a misaligned row.
        vec_size = 1;
    }

    const int thread_nums = hidden_size / vec_size;

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    dim3 grid(max_seq_len_in_batch, batch_size);
    dim3 block(std::min(thread_nums, 512));

// One launch expression shared by every supported vector width.
#define LAUNCH_CONTEXT_KV_CACHE_MEMCPY_KERNEL(VEC_SIZE)                              \
    context_kv_cache_memcpy_kernel<scalar_t, VEC_SIZE><<<grid, block, 0, stream>>>( \
        key.data_ptr<scalar_t>(),                                                    \
        value.data_ptr<scalar_t>(),                                                  \
        key_cache.data_ptr<scalar_t>(),                                              \
        value_cache.data_ptr<scalar_t>(),                                            \
        sequence_lengths.data_ptr<int>(),                                            \
        cu_seqlens.data_ptr<int>(),                                                  \
        block_tables.data_ptr<int>(),                                                \
        head_num,                                                                    \
        head_dim,                                                                    \
        block_size,                                                                  \
        batch_size,                                                                  \
        block_table_stride,                                                          \
        key_stride,                                                                  \
        value_stride)

    switch (vec_size) {
        case 1:
            LAUNCH_CONTEXT_KV_CACHE_MEMCPY_KERNEL(1);
            break;
        case 2:
            LAUNCH_CONTEXT_KV_CACHE_MEMCPY_KERNEL(2);
            break;
        case 4:
            LAUNCH_CONTEXT_KV_CACHE_MEMCPY_KERNEL(4);
            break;
        default:
            AT_ERROR("Unsupported vectorized size ", vec_size);
            break;
    }

#undef LAUNCH_CONTEXT_KV_CACHE_MEMCPY_KERNEL

    // Kernel launches do not return a status; pick up launch-configuration
    // errors here.
    AT_CUDA_CHECK(cudaGetLastError());
}

// Public entry point: picks the element type from `key` and forwards all
// arguments unchanged to apply_context_kv_cache_memcpy<scalar_t>.
// `scalar_t` is expected to be supplied by DISPATCH_FLOAT_HALF_AND_BFLOAT
// (see ../common/micros.h) for the dtype it matches.
void context_kv_cache_memcpy(
    at::Tensor& key,                // [num_tokens, head_num, head_dim]
    at::Tensor& value,              // [num_tokens, head_num, head_dim]
    at::Tensor& key_cache,          // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& value_cache,        // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& sequence_lengths,   // [batch_size]
    at::Tensor& cu_seqlens,         // [batch_size + 1]
    at::Tensor& block_tables,       // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        key.scalar_type(), "context_kv_cache_memcpy",
        apply_context_kv_cache_memcpy<scalar_t>(
            key, value,
            key_cache, value_cache,
            sequence_lengths, cu_seqlens, block_tables,
            max_seq_len_in_batch);)
}
17 changes: 16 additions & 1 deletion extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ __global__ void decode_kv_cache_memcpy_kernel(
return ;
}

for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
int i = threadIdx.x * VecSize;

for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
Expand All @@ -43,6 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}

for (; i < hidden_size; ++i ) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;

key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}

}

template<typename scalar_t>
Expand Down
Loading

0 comments on commit 87079cf

Please sign in to comment.