[Inference/Feat] Add convert_fp8 op for fp8 test in the future (#5706)
* add convert_fp8 op for fp8 test in the future

* rerun ci
Courtesy-Xs authored May 10, 2024
1 parent bfad393 commit 50104ab
Showing 5 changed files with 197 additions and 10 deletions.
127 changes: 127 additions & 0 deletions extensions/csrc/kernel/cuda/convert_fp8_kernel.cu
@@ -0,0 +1,127 @@
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

#include <cmath>

#include "common/micros.h"
#include "utils/vec_copy.h"
#include "funcs/cast_functor.h"


using colossalAI::cuda::utils::copy;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::funcs::CastFunctor;

template <typename InT, typename OutT, int VecSize>
__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail)
{
  // Note: here `numel` counts vectors of VecSize elements, not scalars; `tail`
  // is the number of trailing scalars that do not fill a whole vector.
  int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  const int64_t grid_size = blockDim.x * gridDim.x;
  if (idx > numel + tail) {
    return;
  }

  // Grid-stride loop over the vectorized portion: each iteration casts and
  // copies one vector of VecSize elements.
  for (int64_t i = idx; i < numel; i += grid_size) {
    copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize);
  }
  // Tail process: the leftover scalars are cast one by one.
  if (threadIdx.x == 0) {
    for (int i = 0; i < tail; ++i) {
      outs_data[i + numel * VecSize] = CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]);
    }
  }
}
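
To see what one launch of this kernel computes, here is a small NumPy emulation of the same indexing. It is a rough sketch only: the helper name and the float16 stand-in for the fp8 output buffer are invented for illustration, and the per-element cast is approximated with astype.

import numpy as np

def emulate_convert_fp8_kernel(ins_data, outs_data, vec_numel, tail, vec_size, block_dim=256, grid_dim=4):
    # Host-side emulation of the grid-stride loop above (illustrative only).
    grid_size = block_dim * grid_dim
    for block in range(grid_dim):
        for thread in range(block_dim):
            idx = thread + block * block_dim
            # Vectorized portion: one chunk of vec_size elements per iteration.
            for i in range(idx, vec_numel, grid_size):
                chunk = ins_data[i * vec_size:(i + 1) * vec_size]
                outs_data[i * vec_size:(i + 1) * vec_size] = chunk.astype(outs_data.dtype)
            # Tail: thread 0 of each block casts the leftover scalars one by one.
            if thread == 0:
                for t in range(tail):
                    outs_data[vec_numel * vec_size + t] = ins_data[vec_numel * vec_size + t]

src = np.arange(11, dtype=np.float32)
dst = np.empty_like(src, dtype=np.float16)  # stand-in for the uint8 fp8 buffer
emulate_convert_fp8_kernel(src, dst, vec_numel=11 // 4, tail=11 % 4, vec_size=4)
assert np.allclose(dst.astype(np.float32), src)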

template <typename InT, typename OutT>
void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  const int kVecSize = get_vec_size<InT>(input);
  const int kNumel = torch::numel(input);

  const int kVecNumel = (kNumel >> static_cast<int>(std::log2(kVecSize)));
  const int kTail = kNumel & (kVecSize - 1);
  int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(grid_size);
  dim3 block(256);

#define _(VEC_SIZE)                                        \
  convert_fp8_kernel<InT, OutT, VEC_SIZE>                  \
      <<<grid, block, 0, stream>>>(                        \
          reinterpret_cast<const InT*>(input.data_ptr()),  \
          reinterpret_cast<OutT*>(output.data_ptr()),      \
          kVecNumel, kTail)

  switch (kVecSize) {
    case 1:
      _(1);
      break;
    case 2:
      _(2);
      break;
    case 4:
      _(4);
      break;
  }
#undef _

  AT_CUDA_CHECK(cudaGetLastError());
}
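
The launch-size arithmetic above can be checked on the host with a few lines of Python. This is only a sketch of the math; the real vector width comes from get_vec_size, whose selection logic is not shown here, so a power-of-two vec_size is passed in by hand.

import math

def launch_dims(numel: int, vec_size: int, block: int = 256):
    # Mirrors kVecNumel / kTail / grid_size in apply_convert_fp8 (vec_size must be a power of two).
    vec_numel = numel >> int(math.log2(vec_size))  # number of full vectors
    tail = numel & (vec_size - 1)                  # leftover scalar elements
    grid = (vec_numel + block - 1) // block if vec_numel else 1
    return vec_numel, tail, grid

print(launch_dims(21504, 4))  # (5376, 0, 21): a 42 x 8 x 64 tensor splits into full vectors only
print(launch_dims(11, 4))     # (2, 3, 1): 11 elements give 2 vectors of 4 plus a tail of 3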

void convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
                  output.scalar_type() == at::ScalarType::Byte,
              "Data type of input or output should be torch.uint8 for convert_fp8!");
  TORCH_CHECK(input.scalar_type() != output.scalar_type(),
              "Data type of input and output should not be the same!");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
                  input.scalar_type() == at::ScalarType::Float ||
                  input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of input!");
  TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte ||
                  output.scalar_type() == at::ScalarType::Float ||
                  output.scalar_type() == at::ScalarType::Half ||
                  output.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of output!");
  TORCH_CHECK(input.sizes() == output.sizes(),
              "Shape of input and output should be the same!");

#define _(InT, OutT) apply_convert_fp8<InT, OutT>(input, output)

  if (input.scalar_type() == at::ScalarType::Byte) {
    if (output.scalar_type() == at::ScalarType::Float) {
      _(uint8_t, float);
    } else if (output.scalar_type() == at::ScalarType::Half) {
      _(uint8_t, half);
    } else if (output.scalar_type() == at::ScalarType::BFloat16) {
      _(uint8_t, __nv_bfloat16);
    }
  } else {
    if (input.scalar_type() == at::ScalarType::Float) {
      _(float, uint8_t);
    } else if (input.scalar_type() == at::ScalarType::Half) {
      _(half, uint8_t);
    } else if (input.scalar_type() == at::ScalarType::BFloat16) {
      _(__nv_bfloat16, uint8_t);
    }
  }

#undef _
}
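
For reference, the dispatch rules enforced by the checks and the if/else ladder above can be restated compactly: exactly one of the two tensors must be torch.uint8 (the fp8 byte storage), the other one of float32/float16/bfloat16, and the shapes must match. The small helper below is a hypothetical Python restatement of those rules, not part of the extension.

import torch

_REAL_DTYPES = {torch.float32, torch.float16, torch.bfloat16}

def check_convert_fp8_args(inp: torch.Tensor, out: torch.Tensor) -> None:
    # Hypothetical mirror of the TORCH_CHECKs in convert_fp8.
    assert inp.dtype == torch.uint8 or out.dtype == torch.uint8, "one side must be torch.uint8 (fp8 storage)"
    assert inp.dtype != out.dtype, "input and output dtypes must differ"
    assert inp.dtype == torch.uint8 or inp.dtype in _REAL_DTYPES, "unsupported input dtype"
    assert out.dtype == torch.uint8 or out.dtype in _REAL_DTYPES, "unsupported output dtype"
    assert inp.shape == out.shape, "shapes of input and output must match"

# Accepted: float16 -> fp8 bytes.  Rejected (raises): float16 -> float32.
check_convert_fp8_args(torch.empty(4, dtype=torch.float16), torch.empty(4, dtype=torch.uint8))
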
17 changes: 7 additions & 10 deletions extensions/csrc/kernel/cuda/utils/vec_copy.h
@@ -1,9 +1,6 @@

#pragma once

-#include <cuda_fp16.h>
-#include <stdint.h>

#include "common/vec_type_traits.h"
#include "funcs/cast_functor.h"

@@ -12,9 +9,9 @@ namespace cuda {
namespace utils {

// Note(LiuYang): Deprecated
-template <typename T, int vec_size>
+template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}

@@ -34,17 +31,17 @@ __device__ __inline__ void copy_zero_vector(T *dst) {
  *(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
}

-template <typename SrcT, typename DstT, int vec_size>
+template <typename SrcT, typename DstT, int VecSize>
__device__ __inline__ void copy(const SrcT *src, DstT *dst) {
-  using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type;
-  using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type;
+  using SrcVT = typename common::VecTypeTrait<SrcT, VecSize>::Type;
+  using DstVT = typename common::VecTypeTrait<DstT, VecSize>::Type;
  *(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
      *(reinterpret_cast<const SrcVT *>(src)));
}

-template <typename T, int vec_size>
+template <typename T, int VecSize>
__device__ __inline__ void copy(const T *src, T *dst) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}
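
The renamed copy<SrcT, DstT, VecSize> template is the primitive the FP8 kernel calls per loop iteration: one vector load of VecSize source elements, one cast through CastFunctor, one vector store. A rough chunk-wise analogue in Python (not a bit-level model of the CUDA vector types) looks like this:

import torch

def chunk_cast_copy(src: torch.Tensor, dst: torch.Tensor, vec_size: int) -> None:
    # Illustrative analogue of copy<SrcT, DstT, VecSize>: cast one chunk at a time.
    assert src.numel() == dst.numel() and src.numel() % vec_size == 0
    for i in range(0, src.numel(), vec_size):
        dst[i:i + vec_size] = src[i:i + vec_size].to(dst.dtype)

src = torch.randn(8, dtype=torch.float32)
dst = torch.empty(8, dtype=torch.float16)
chunk_cast_copy(src, dst, vec_size=4)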

5 changes: 5 additions & 0 deletions extensions/pybind/inference/inference.cpp
@@ -75,6 +75,8 @@ void flash_decoding_attention(
    torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
    const c10::optional<torch::Tensor>& alibi_slopes, float scale);

void convert_fp8(torch::Tensor& input, torch::Tensor& output);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
@@ -102,4 +104,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("flash_decoding_attention", &flash_decoding_attention,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");

m.def("convert_fp8", &convert_fp8,
"Convert input to fp8 output or convert fp8 input to output.");
}
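
Once this binding is compiled into the inference extension, the op is reachable from Python through the kernel loader, exactly as the new test below exercises it. A minimal round-trip sketch, assuming a CUDA device and a successful extension build:

import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()

x = torch.randn(16, 128, dtype=torch.float16, device="cuda")

x_fp8 = torch.empty_like(x, dtype=torch.uint8)   # fp8 values stored as raw bytes
inference_ops.convert_fp8(x, x_fp8)              # float16 -> fp8

x_back = torch.empty_like(x)
inference_ops.convert_fp8(x_fp8, x_back)         # fp8 -> float16
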
1 change: 1 addition & 0 deletions extensions/pybind/inference/inference_ops_cuda.py
@@ -17,6 +17,7 @@ def sources_files(self):
"kernel/cuda/rms_layernorm_kernel.cu",
"kernel/cuda/get_cos_and_sin_kernel.cu",
"kernel/cuda/flash_decoding_attention_kernel.cu",
"kernel/cuda/convert_fp8_kernel.cu",
]
] + [self.pybind_abs_path("inference/inference.cpp")]
return ret
57 changes: 57 additions & 0 deletions tests/test_infer/test_kernels/cuda/test_convert_fp8.py
@@ -0,0 +1,57 @@
import random

import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]


@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, so its related test is skipped for now!")
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [64, 80, 96, 112, 128, 256])
@pytest.mark.parametrize("block_size", [8, 16, 32])
@pytest.mark.parametrize("num_blocks", [1024, 10000])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
@pytest.mark.parametrize("seed", [0])
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = get_current_device()

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    inference_ops.convert_fp8(cache, cache_fp8)

    converted_cache = torch.empty_like(cache)
    inference_ops.convert_fp8(cache_fp8, converted_cache)

    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)


if __name__ == "__main__":
    test_fp8_conversion(8, 64, 8, 1024, torch.half, 0)
