Skip to content

Commit

Permalink
add vec_type_trait implementation (#5473)
Browse files Browse the repository at this point in the history
  • Loading branch information
Courtesy-Xs authored Mar 19, 2024
1 parent b96557b commit 7ff42cc
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 113 deletions.
12 changes: 4 additions & 8 deletions extensions/csrc/common/mp_type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,22 @@ namespace colossalAI {
namespace common {

// MPTypeTrait maps an input scalar type T to the type used for internal
// ("mixed-precision") computation. The primary template defaults to float;
// see the explicit specializations below for the supported scalar types.
template <typename T>
struct MPTypeTrait {
  using Type = float;
};

template <>
class MPTypeTrait<float> {
public:
struct MPTypeTrait<float> {
using Type = float;
};

template <>
class MPTypeTrait<at::Half> {
public:
struct MPTypeTrait<at::Half> {
using Type = float;
};

template <>
class MPTypeTrait<at::BFloat16> {
public:
struct MPTypeTrait<at::BFloat16> {
using Type = float;
};

Expand Down
1 change: 0 additions & 1 deletion extensions/csrc/cuda/activation_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "../common/micros.h"
#include "../common/mp_type_traits.h"
#include "utils/gpu_launch_config.h"

template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
Expand Down
75 changes: 73 additions & 2 deletions extensions/csrc/cuda/utils/vec_type_traits.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,82 @@
#pragma once

#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>

#include <cfloat>

namespace colossalAI {
namespace cuda {
namespace utils {

// VecTypeTrait maps (scalar type T, elements-per-access VecSize) to a raw
// vector type whose byte width equals VecSize * sizeof(T). Copy helpers
// reinterpret_cast element pointers to this type to issue one wide memory
// transaction instead of VecSize scalar ones. The primary template is empty:
// unsupported (T, VecSize) pairs fail at compile time with "no member Type".
template <typename T, int VecSize>
struct VecTypeTrait {};

// VecSize == 1: move the scalar itself; no wider type needed.
template <typename T>
struct VecTypeTrait<T, 1> {
  using Type = T;
};

// The specializations below are chosen purely by byte width, not numeric
// meaning: the vector type is only a container for a reinterpret-cast copy.
// For every entry except <float, 8>, sizeof(Type) == VecSize * sizeof(T).

// c10::BFloat16 (2 bytes each): 2 -> 4B, 4 -> 8B, 8 -> 16B.
template <>
struct VecTypeTrait<c10::BFloat16, 2> {
  using Type = float;
};

template <>
struct VecTypeTrait<c10::BFloat16, 4> {
  using Type = float2;
};

template <>
struct VecTypeTrait<c10::BFloat16, 8> {
  using Type = float4;
};

// c10::Half (2 bytes each): same widths as BFloat16.
template <>
struct VecTypeTrait<c10::Half, 2> {
  using Type = float;
};

template <>
struct VecTypeTrait<c10::Half, 4> {
  using Type = float2;
};

template <>
struct VecTypeTrait<c10::Half, 8> {
  using Type = float4;
};

// float (4 bytes each): 2 -> 8B, 4 -> 16B.
template <>
struct VecTypeTrait<float, 2> {
  using Type = float2;
};

template <>
struct VecTypeTrait<float, 4> {
  using Type = float4;
};

// NOTE: <float, 8> spans 32 bytes but float4 covers only 16 — the widest
// memory transaction is 128 bits. Users of this entry must issue two
// accesses; see the copy_vector<float, 8> specialization.
template <>
struct VecTypeTrait<float, 8> {
  using Type = float4;
};

// uint8_t (1 byte each): 2 -> 2B (half), 4 -> 4B (half2), 8 -> 8B (float2).
// The fp16 types here are byte containers only, never interpreted as numbers.
template <>
struct VecTypeTrait<uint8_t, 2> {
  using Type = half;
};

template <>
struct VecTypeTrait<uint8_t, 4> {
  using Type = half2;
};

template <>
struct VecTypeTrait<uint8_t, 8> {
  using Type = float2;
};

} // namespace utils
} // namespace cuda
Expand Down
120 changes: 18 additions & 102 deletions extensions/csrc/cuda/utils/vector_copy_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,117 +5,28 @@
#include <cuda_fp16.h>
#include <stdint.h>

#include <cfloat>
#include "vec_type_traits.h"

// Copies VecSize contiguous elements of type T from src to dst using a single
// vectorized load/store. Both pointers are reinterpreted as
// VecTypeTrait<T, VecSize>::Type, so both must be aligned to that vector
// type's size (VecSize * sizeof(T) bytes) — see get_vec_size below, which
// picks VecSize from the tensor's address alignment.
template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
  using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
  // Note(LiuYang): Here static_cast can't be used for cast between two
  // pointer types, so reinterpret_cast is required. The source-side cast
  // must keep the const qualifier: reinterpret_cast cannot cast away const.
  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}

// Specialization for 8 floats (32 bytes). Since the maximum memory alignment
// length is 128 bits, one access cannot move 32 bytes, so we issue two
// 16-byte float4 transactions. dst, src, dst + 4 and src + 4 must all be
// 16-byte aligned.
template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
  *(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
  *(reinterpret_cast<float4 *>(dst + 4)) =
      *(reinterpret_cast<const float4 *>(src + 4));
}

// Writes VecSize zero elements of type T to dst with a single vectorized
// store. dst must be aligned to VecSize * sizeof(T) bytes.
template <typename T, int VecSize>
__device__ __inline__ void copy_zero_vector(T *dst) {
  using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
  // List-initialization: first component is 0.0, any remaining components of
  // a vector type (e.g. float4) are value-initialized to zero.
  *(reinterpret_cast<VT *>(dst)) = {0.0};
}

template <typename T>
Expand All @@ -126,6 +37,11 @@ int get_vec_size(const torch::Tensor &tensor) {

const int vec_size = max_aligned_size / sizeof(T) / 8;

// Note(LiuYang): Performance of situation of which
// vec_size equals to 8 need to be profiled in the future
// if (address % (dtype_size * 8) == 0) {
// return std::min(8, vec_size);
// }
if (address % (dtype_size * 4) == 0) {
return std::min(4, vec_size);
} else if (address % (dtype_size * 2) == 0) {
Expand Down

0 comments on commit 7ff42cc

Please sign in to comment.