[Inference] Optimize the style of the tail-processing code and the macro definition logic. #5519

Merged: 1 commit, Mar 28, 2024
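This PR takes the benchmark backend mode from the command line instead of hard-coding it, rewrites DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION as a plain if/else rather than a switch over a bool, collapses the two ScalarTypeTrait partial specializations into a single std::conditional, and gives both KV-cache memcpy kernels an Aligned template parameter so the scalar tail loop is only compiled in when head_dim is not divisible by the vector size. The per-vec-size kernel launches are also factored into helper macros.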
2 changes: 1 addition & 1 deletion examples/inference/run_benchmark.sh
@@ -2,7 +2,7 @@ ROOT=$(realpath $(dirname $0))
echo $ROOT
PY_SCRIPT=${ROOT}/benchmark_llama.py
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
mode="colossalai"
mode=$1

mkdir -p logs

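With this change the benchmark backend is no longer hard-coded; it is read from the script's first positional argument, so an invocation would look something like bash run_benchmark.sh colossalai (the previous default value).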
23 changes: 8 additions & 15 deletions extensions/csrc/common/micros.h
@@ -56,21 +56,14 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
TYPE, NAME, ...) \
switch (HIGH_PRECISION) { \
case false: { \
const bool high_precision = false; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
break; \
} \
case true: { \
const bool high_precision = true; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
break; \
} \
default: \
AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
TYPE, NAME, ...) \
if (HIGH_PRECISION) { \
const bool high_precision = true; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
} else { \
const bool high_precision = false; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
}

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
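The rewritten macro binds the runtime flag to a branch-local const bool so that the dispatched body can still use high_precision as a compile-time constant (for example as a template or std::conditional argument); and since a genuine bool can only be true or false, the old default branch with AT_ERROR was unreachable anyway, which is presumably why the switch became a plain if/else. A minimal, self-contained sketch of that pattern, assuming nothing about the rest of micros.h; DISPATCH_BOOL and the body passed to it are illustrative stand-ins, not part of this PR:

#include <iostream>
#include <type_traits>

// Re-binding the runtime flag as a branch-local const bool lets the body use
// high_precision in constant expressions, mirroring the macro above.
#define DISPATCH_BOOL(FLAG, ...)        \
  if (FLAG) {                           \
    const bool high_precision = true;   \
    __VA_ARGS__;                        \
  } else {                              \
    const bool high_precision = false;  \
    __VA_ARGS__;                        \
  }

int main() {
  bool flag = true;  // runtime value, e.g. parsed from user input
  DISPATCH_BOOL(flag, {
    // high_precision is a compile-time constant inside each branch.
    using AccT = std::conditional<high_precision, double, float>::type;
    std::cout << "accumulate in " << (high_precision ? "double" : "float")
              << ", sizeof(AccT) = " << sizeof(AccT) << "\n";
  });
  return 0;
}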
16 changes: 5 additions & 11 deletions extensions/csrc/common/mp_type_traits.h
@@ -27,17 +27,11 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float;
};

template <bool high_precision, typename scalar_t>
struct ScalarTypeTrait;

template <typename T>
struct ScalarTypeTrait<true, T> {
using Type = typename MPTypeTrait<T>::Type;
};

template <typename T>
struct ScalarTypeTrait<false, T> {
using Type = T;
template <bool high_precision, typename T>
struct ScalarTypeTrait {
using Type =
typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
T>::type;
};

} // namespace common
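The collapsed trait keeps the same behaviour as the two removed partial specializations: high_precision = true selects MPTypeTrait's promoted type, false passes the scalar type through unchanged. A standalone sketch with static_asserts, where Half is a stand-in for at::Half so the snippet compiles without PyTorch headers:

#include <type_traits>

struct Half {};  // stand-in for at::Half

// Stand-in trait; the real MPTypeTrait specializes float/half/bfloat16.
template <typename T> struct MPTypeTrait { using Type = T; };
template <> struct MPTypeTrait<Half> { using Type = float; };

template <bool high_precision, typename T>
struct ScalarTypeTrait {
  using Type = typename std::conditional<high_precision,
                                         typename MPTypeTrait<T>::Type,
                                         T>::type;
};

// true selects the promoted type, false keeps the original scalar type.
static_assert(std::is_same<ScalarTypeTrait<true, Half>::Type, float>::value,
              "high precision promotes Half to float");
static_assert(std::is_same<ScalarTypeTrait<false, Half>::Type, Half>::value,
              "low precision keeps the input type");

int main() { return 0; }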
133 changes: 60 additions & 73 deletions extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
@@ -4,7 +4,7 @@
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"

template<typename scalar_t, int VecSize>
template<typename scalar_t, bool Aligned, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
@@ -55,17 +55,19 @@ __global__ void context_kv_cache_memcpy_kernel(
}

// tail process
for (; i < hidden_size; ++i ) {
head_id = i / head_dim;
head_offset = i % head_dim;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;

key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
if (!Aligned) {
for (; i < hidden_size; ++i ) {
head_id = i / head_dim;
head_offset = i % head_dim;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;

key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}
}

}
@@ -93,76 +95,61 @@ void apply_context_kv_cache_memcpy(

int vec_size = get_vec_size<scalar_t>(key);

bool aligned = true;
if (head_dim % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
vec_size = 1;
aligned = false;
}

int thread_nums = head_num * head_dim / vec_size;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(max_seq_len_in_batch, batch_size);
dim3 block(std::min(thread_nums, 512));

switch (vec_size) {
case 1:
context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
case 2:
context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
case 4:
context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
do { \
context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
key.data_ptr<scalar_t>(), \
value.data_ptr<scalar_t>(), \
key_cache.data_ptr<scalar_t>(), \
value_cache.data_ptr<scalar_t>(), \
sequence_lengths.data_ptr<int>(), \
cu_seqlens.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
head_num, \
head_dim, \
block_size, \
batch_size, \
block_table_stride, \
key_stride, \
value_stride \
); \
} while(0)

#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \
do { \
switch (vec_size) { \
case 1: \
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \
break; \
case 2: \
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \
break; \
case 4: \
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \
break; \
default: \
AT_ERROR("Unsupported vectorized size ", vec_size); \
break; \
} \
} while(0)


if (aligned) {
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
}
else {
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
}

AT_CUDA_CHECK(cudaGetLastError());
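The launcher above maps the runtime aligned / vec_size pair onto compile-time template arguments, so the aligned instantiations carry no tail loop at all. A condensed, self-contained CUDA sketch of the same pattern (a simplified row copy, not the repository's kernel; row_copy_kernel and launch_row_copy are hypothetical names):

#include <cuda_runtime.h>

// One block copies one row; each thread moves VecSize contiguous elements.
// When Aligned is true, the scalar tail loop is compiled out entirely.
template <typename T, bool Aligned, int VecSize>
__global__ void row_copy_kernel(const T* __restrict__ src, T* __restrict__ dst,
                                int row_len) {
  const T* src_row = src + static_cast<long long>(blockIdx.x) * row_len;
  T* dst_row = dst + static_cast<long long>(blockIdx.x) * row_len;

  // Vectorized part: strided over the row in chunks of VecSize.
  for (int i = threadIdx.x * VecSize; i + VecSize <= row_len;
       i += blockDim.x * VecSize) {
#pragma unroll
    for (int j = 0; j < VecSize; ++j) dst_row[i + j] = src_row[i + j];
  }

  // Scalar tail, only needed when row_len % VecSize != 0.
  if (!Aligned) {
    for (int i = row_len - row_len % VecSize + threadIdx.x; i < row_len;
         i += blockDim.x) {
      dst_row[i] = src_row[i];
    }
  }
}

// Host side: mirror the aligned check used by the launcher above.
void launch_row_copy(const float* src, float* dst, int rows, int row_len,
                     cudaStream_t stream) {
  constexpr int kVecSize = 4;
  dim3 grid(rows), block(128);
  if (row_len % kVecSize == 0) {
    row_copy_kernel<float, true, kVecSize>
        <<<grid, block, 0, stream>>>(src, dst, row_len);
  } else {
    row_copy_kernel<float, false, kVecSize>
        <<<grid, block, 0, stream>>>(src, dst, row_len);
  }
}

Note that the PR additionally falls back to vec_size = 1 whenever head_dim is not divisible by the vector width; the sketch keeps a fixed width and relies on the tail loop instead, purely for brevity.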
124 changes: 57 additions & 67 deletions extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -4,7 +4,7 @@
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"

template<typename scalar_t, int VecSize>
template<typename scalar_t, bool Aligned, int VecSize>
__global__ void decode_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
@@ -45,17 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}

for (; i < hidden_size; ++i ) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;

key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
if (!Aligned) {
for (; i < hidden_size; ++i ) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;

key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}
}

}
@@ -80,70 +82,58 @@ void apply_decode_kv_cache_memcpy(

int vec_size = get_vec_size<scalar_t>(key);

bool aligned = true;
if (head_dim % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
vec_size = 1;
aligned = false;
}

int thread_nums = head_num * head_dim / vec_size;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));

switch (vec_size) {
case 1:
decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
case 2:
decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
case 4:
decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
do { \
decode_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
key.data_ptr<scalar_t>(), \
value.data_ptr<scalar_t>(), \
key_cache.data_ptr<scalar_t>(), \
value_cache.data_ptr<scalar_t>(), \
sequence_lengths.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
head_num, \
head_dim, \
block_size, \
key_stride, \
value_stride, \
block_table_stride \
); \
} while(0)

#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned, __vec_size) \
do { \
switch (__vec_size) { \
case 1: \
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \
break; \
case 2: \
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \
break; \
case 4: \
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \
break; \
default: \
AT_ERROR("Unsupported vectorized size ", __vec_size); \
break; \
} \
} while(0)

if (aligned) {
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true, vec_size);
}
else {
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false, vec_size);
}

AT_CUDA_CHECK(cudaGetLastError());
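Both families of launch macros wrap their bodies in do { ... } while (0). A small standalone illustration of why that matters (REPORT_TWICE is just an example macro, unrelated to this PR): the wrapper turns a multi-statement expansion into a single statement, so the macro composes correctly with the un-braced if/else used at the call sites above.

#include <cstdio>

// With do { ... } while (0) the expansion is one statement, so the trailing
// semicolon at the call site terminates it cleanly before the else.
#define REPORT_TWICE(msg)     \
  do {                        \
    std::printf("%s\n", msg); \
    std::printf("%s\n", msg); \
  } while (0)

int main() {
  bool aligned = false;
  if (aligned)
    REPORT_TWICE("aligned path");
  else
    REPORT_TWICE("tail path");  // a bare { ... } block here would break the else
  return 0;
}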