[Inference/Feat] Add convert_fp8 op for fp8 test in the future (#5706)
* add convert_fp8 op for fp8 test in the future
* rerun ci
1 parent bfad393 · commit 50104ab
Showing 5 changed files with 197 additions and 10 deletions.
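The new convert_fp8 op converts in both directions between a floating-point tensor (float32/float16/bfloat16) and fp8 values stored byte-for-byte in a uint8 tensor of the same shape, writing into a preallocated output. A minimal round-trip sketch of how the op is called (assuming a CUDA device and a built ColossalAI inference extension; it mirrors the test added below):

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()

x = torch.randn(16, 128, dtype=torch.float16, device="cuda")

# Quantize: fp16 -> fp8 bytes in a uint8 tensor of the same shape.
x_fp8 = torch.empty_like(x, dtype=torch.uint8)
inference_ops.convert_fp8(x, x_fp8)

# Dequantize: fp8 bytes -> fp16 again.
x_roundtrip = torch.empty_like(x)
inference_ops.convert_fp8(x_fp8, x_roundtrip)

# fp8 is lossy, so compare with loose tolerances.
assert torch.allclose(x, x_roundtrip, atol=1e-3, rtol=0.1)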
@@ -0,0 +1,127 @@
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

#include <cmath>

#include "common/micros.h"
#include "utils/vec_copy.h"
#include "funcs/cast_functor.h"

using colossalAI::cuda::utils::copy;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::funcs::CastFunctor;

// numel counts VecSize-wide vector chunks; tail counts the leftover scalar elements.
template <typename InT, typename OutT, int VecSize>
__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail)
{
  int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  const int64_t grid_size = blockDim.x * gridDim.x;
  if (idx > numel + tail) {
    return;
  }

  // Grid-stride loop over the vectorized body: each iteration copies VecSize
  // elements, casting InT -> OutT element-wise via CastFunctor.
  for (int64_t i = idx; i < numel; i += grid_size) {
    copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize);
  }
  // Tail process: convert the remaining (< VecSize) elements one by one.
  // A single thread suffices and avoids redundant writes from every block.
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    for (int i = 0; i < tail; ++i) {
      outs_data[i + numel * VecSize] = CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]);
    }
  }
}

template <typename InT, typename OutT>
void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  const int kVecSize = get_vec_size<InT>(input);
  const int kNumel = torch::numel(input);

  // kVecSize is a power of two, so the shift divides and the mask takes the remainder.
  const int kVecNumel = (kNumel >> static_cast<int>(std::log2(kVecSize)));
  const int kTail = kNumel & (kVecSize - 1);
  int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(grid_size);
  dim3 block(256);

// Instantiate and launch the kernel for the vector size selected at runtime.
#define _(VEC_SIZE)                                        \
  convert_fp8_kernel<InT, OutT, VEC_SIZE>                  \
      <<<grid, block, 0, stream>>>(                        \
          reinterpret_cast<const InT*>(input.data_ptr()),  \
          reinterpret_cast<OutT*>(output.data_ptr()),      \
          kVecNumel,                                       \
          kTail)

  switch (kVecSize) {
    case 1:
      _(1);
      break;
    case 2:
      _(2);
      break;
    case 4:
      _(4);
      break;
  }
#undef _
  AT_CUDA_CHECK(cudaGetLastError());
}

void convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || output.scalar_type() == at::ScalarType::Byte,
              "Data type of input or output should be torch.uint8 for convert_fp8!");
  TORCH_CHECK(input.scalar_type() != output.scalar_type(),
              "Data types of input and output should not be the same!");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
                  input.scalar_type() == at::ScalarType::Float ||
                  input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of input!");
  TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte ||
                  output.scalar_type() == at::ScalarType::Float ||
                  output.scalar_type() == at::ScalarType::Half ||
                  output.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of output!");
  TORCH_CHECK(input.sizes() == output.sizes(), "Shape of input and output should be the same!");

#define _(InT, OutT) apply_convert_fp8<InT, OutT>(input, output)

  // A uint8 input already holds fp8 bytes, so dequantize to the requested
  // floating dtype; otherwise quantize the floating input to fp8 bytes.
  if (input.scalar_type() == at::ScalarType::Byte) {
    if (output.scalar_type() == at::ScalarType::Float) {
      _(uint8_t, float);
    } else if (output.scalar_type() == at::ScalarType::Half) {
      _(uint8_t, half);
    } else if (output.scalar_type() == at::ScalarType::BFloat16) {
      _(uint8_t, __nv_bfloat16);
    }
  } else {
    if (input.scalar_type() == at::ScalarType::Float) {
      _(float, uint8_t);
    } else if (input.scalar_type() == at::ScalarType::Half) {
      _(half, uint8_t);
    } else if (input.scalar_type() == at::ScalarType::BFloat16) {
      _(__nv_bfloat16, uint8_t);
    }
  }

#undef _
}
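apply_convert_fp8 splits the flattened element count into a vectorized body handled by the grid-stride loop and a scalar tail handled by a single thread, using power-of-two shift/mask arithmetic. A small Python sketch of the same split, with illustrative numbers only:

def split_workload(numel: int, vec_size: int):
    # vec_size is a power of two, so a shift divides and a mask takes the remainder.
    vec_numel = numel >> (vec_size.bit_length() - 1)  # same as numel // vec_size
    tail = numel & (vec_size - 1)                     # same as numel % vec_size
    return vec_numel, tail

# 42 elements with 4-wide vector copies: 10 chunks cover 40 elements, 2 remain.
assert split_workload(42, 4) == (10, 2)

# One 256-thread block per 256 vector chunks, with at least one block launched.
vec_numel, tail = split_workload(42, 4)
grid_size = (vec_numel + 255) // 256 if vec_numel else 1
assert grid_size == 1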
@@ -0,0 +1,57 @@
import random

import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]


@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, so we skip its related test for now!")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", [1024, 10000])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", [0])
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = get_current_device()

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # Quantize: floating-point cache -> fp8 values stored as uint8 bytes.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    inference_ops.convert_fp8(cache, cache_fp8)

    # Dequantize back and check the round-trip error stays within tolerance.
    converted_cache = torch.empty_like(cache)
    inference_ops.convert_fp8(cache_fp8, converted_cache)

    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)


if __name__ == "__main__":
    test_fp8_conversion(8, 64, 8, 1024, torch.half, 0)
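For reference, recent PyTorch versions expose native fp8 dtypes that can serve as a rough error baseline for a custom fp8 round trip. A hedged sketch, assuming PyTorch >= 2.1 and an e4m3-style format (the exact fp8 format the op uses depends on the CastFunctor specialization, which is not shown in this diff):

import torch

x = torch.empty(42, dtype=torch.float16, device="cuda").uniform_(-224.0, 224.0)

# Native round trip through torch.float8_e4m3fn as a rough error reference.
x_ref = x.to(torch.float8_e4m3fn).to(torch.float16)
print((x - x_ref).abs().max())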