
Adding MLPerf optimization to 0.6.0 (opendatahub-io#182)
* add weight padding for fp8

* add scaled act

* add scaled rms

* linter
charlifu authored Sep 12, 2024
1 parent b1c3273 commit b53c35d
Showing 13 changed files with 347 additions and 21 deletions.
42 changes: 42 additions & 0 deletions csrc/activation_kernels.cu
@@ -7,6 +7,10 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#ifdef USE_ROCM
#include "quantization/fp8/amd/hip_float8.h"
#endif

namespace vllm {

// Activation and gating kernel template.
@@ -23,6 +27,22 @@ __global__ void act_and_mul_kernel(
}
}

// Scaled activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void scaled_act_and_mul_kernel(
c10::Float8_e4m3fnuz* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d, const float scale) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
float r = ACT_FN(x) * y * scale;
out[token_idx * d + idx] = c10::Float8_e4m3fnuz(
hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
}
}
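
The out assignment above uses c10's raw-bits constructor: hip_fp8 performs the rounding and c10::Float8_e4m3fnuz merely adopts the resulting byte. A minimal standalone sketch of that idiom, assuming only the hip_fp8 helper from quantization/fp8/amd/hip_float8.h included above (the float_to_fp8 name is illustrative, not part of this change):

// hip_fp8(r) rounds the float to e4m3fnuz; .data exposes the resulting raw byte;
// from_bits() tells c10::Float8_e4m3fnuz to adopt that byte without converting again.
__device__ __forceinline__ c10::Float8_e4m3fnuz float_to_fp8(float r) {
  return c10::Float8_e4m3fnuz(hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
}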

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
@@ -69,12 +89,34 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
input.data_ptr<scalar_t>(), d); \
});

// Launch scaled activation and gating kernel.
#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fnuz>(), \
input.data_ptr<scalar_t>(), d, \
1.0 / (*scale.data_ptr<float>())); \
});

void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

void scaled_silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
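
For reference, a minimal host-side sketch of calling the new op. The example function and tensor values are illustrative, not part of this change; c10::ScalarType::Float8_e4m3fnuz is assumed to be the ATen dtype matching the kernel's output type, and the scale tensor is kept on the CPU because the launch macro above reads it on the host via *scale.data_ptr<float>():

#include <torch/torch.h>

// From csrc/ops.h (further below):
void scaled_silu_and_mul(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale);

void example_scaled_silu_and_mul() {
  const int64_t num_tokens = 4, d = 128;
  // Gate and up projections concatenated along the last dimension: [..., 2 * d].
  torch::Tensor input = torch::randn(
      {num_tokens, 2 * d}, torch::dtype(torch::kHalf).device(torch::kCUDA));
  // Per-tensor FP8 scale as a single-element float tensor, read on the host.
  torch::Tensor scale = torch::full({1}, 0.5f, torch::dtype(torch::kFloat));
  // FP8 (e4m3fnuz) output: [..., d].
  torch::Tensor out = torch::empty(
      {num_tokens, d},
      torch::dtype(c10::ScalarType::Float8_e4m3fnuz).device(torch::kCUDA));
  // Computes out = fp8_quantize(silu(input[..., :d]) * input[..., d:] / scale).
  scaled_silu_and_mul(out, input, scale);
}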

void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
203 changes: 195 additions & 8 deletions csrc/layernorm_kernels.cu
@@ -3,6 +3,7 @@
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "attention/attention_dtypes.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
@@ -13,6 +14,8 @@
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include "quantization/fp8/amd/hip_float8.h"
#include "quantization/fp8/amd/quant_utils.cuh"

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
@@ -109,11 +112,11 @@ __global__ void rms_norm_kernel(

template <typename scalar_t>
__global__ void scaled_rms_norm_kernel(
hip_fp8* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* scale, const float epsilon, const int num_tokens,
const int hidden_size, const int hidden_size_padded) {
c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float scale, const float epsilon, const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

@@ -133,9 +136,9 @@ __global__ void scaled_rms_norm_kernel(

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
x = (x * s_variance) * (float)weight[idx] / (*scale);

out[blockIdx.x * hidden_size_padded + idx] = hip_fp8(x);
float r = (x * s_variance) * weight[idx] * scale;
out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz(
hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
}
}
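
In effect, with the host wrapper further below passing 1/s for the per-tensor scale s, the updated kernel writes

$$\mathrm{out}_i = \mathrm{fp8}\!\left(\frac{x_i\, w_i}{s \cdot \sqrt{\tfrac{1}{H}\sum_{j=1}^{H} x_j^2 + \varepsilon}}\right)$$

with H the hidden size and ε the epsilon argument. Relative to the removed lines above, the reciprocal of the scale is now taken once on the host rather than dividing by *scale per element, and the quantized value is stored contiguously at stride hidden_size instead of hidden_size_padded, directly as c10::Float8_e4m3fnuz.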

@@ -379,6 +382,123 @@ fused_add_rms_norm_kernel(
}
}

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */

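// Width-8 vector types for the FP8 output path below: 8 e4m3fnuz values occupy
// 8 bytes (uint2), while 8 halfs or bfloat16s occupy 16 bytes (uint4 / bf16_8_t),
// matching the 128-bit global memory transactions used for the inputs.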
template <>
struct Vec<c10::Float8_e4m3fnuz, 8> {
using Type = uint2;
};

template <>
struct Vec<c10::Half, 8> {
using Type = uint4;
};

template <>
struct Vec<c10::BFloat16, 8> {
using Type = bf16_8_t;
};

template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
scaled_fused_add_rms_norm_kernel(
c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float scale, const int num_tokens,
const int hidden_size) {
using in_v_t = typename Vec<scalar_t, width>::Type;
using out_v_t = typename Vec<c10::Float8_e4m3fnuz, width>::Type;
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ out_v = reinterpret_cast<out_v_t*>(out);
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
out_v_t temp_quant = fp8::scaled_vec_conversion<out_v_t, in_v_t>(
*reinterpret_cast<in_v_t*>(&temp), scale);
out_v[id] = temp_quant;
}
}

/* Generic scaled_fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
scaled_fused_add_rms_norm_kernel(
c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float scale, const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float r = (x * s_variance) * (float)weight[idx] / scale;
out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz(
hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
}
}

} // namespace vllm

void rms_norm(torch::Tensor& out, // [..., hidden_size]
@@ -399,6 +519,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
});
}

void scaled_rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_rms_norm_kernel", [&] {
vllm::scaled_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fnuz>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), 1.0 / (*scale.data_ptr<float>()),
epsilon, num_tokens, hidden_size);
});
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
@@ -443,3 +583,50 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}

#define LAUNCH_SCALED_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "scaled_fused_add_rms_norm_kernel", [&] { \
vllm::scaled_fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<c10::Float8_e4m3fnuz>(), \
input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, \
*scale.data_ptr<float>(), num_tokens, hidden_size); \
});

void scaled_fused_add_rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_SCALED_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_SCALED_FUSED_ADD_RMS_NORM(0);
}
}
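
And a matching host-side sketch for the fused variant (again illustrative, not part of this change, with the same assumptions about the FP8 dtype name and the host-read scale tensor as in the scaled_silu_and_mul example above):

#include <torch/torch.h>

// From csrc/ops.h (below):
void scaled_fused_add_rms_norm(torch::Tensor& out, torch::Tensor& input,
                               torch::Tensor& residual, torch::Tensor& weight,
                               torch::Tensor& scale, double epsilon);

void example_scaled_fused_add_rms_norm() {
  const int64_t num_tokens = 8, hidden_size = 4096;
  auto half_cuda = torch::dtype(torch::kHalf).device(torch::kCUDA);
  torch::Tensor input = torch::randn({num_tokens, hidden_size}, half_cuda);
  torch::Tensor residual = torch::randn({num_tokens, hidden_size}, half_cuda);
  torch::Tensor weight = torch::ones({hidden_size}, half_cuda);
  // Per-tensor FP8 scale, dereferenced on the host by the launch macro.
  torch::Tensor scale = torch::full({1}, 0.5f, torch::dtype(torch::kFloat));
  torch::Tensor out = torch::empty(
      {num_tokens, hidden_size},
      torch::dtype(c10::ScalarType::Float8_e4m3fnuz).device(torch::kCUDA));
  // residual is updated in place to input + residual; out receives the FP8-quantized
  // RMSNorm of that sum. 16-byte-aligned pointers and hidden_size % 8 == 0 select the
  // vectorized width-8 kernel above; otherwise the generic width-0 kernel runs.
  scaled_fused_add_rms_norm(out, input, residual, weight, scale, /*epsilon=*/1e-6);
}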
11 changes: 11 additions & 0 deletions csrc/ops.h
@@ -32,6 +32,14 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);

void scaled_rms_norm(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);

void scaled_fused_add_rms_norm(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& residual, torch::Tensor& weight,
torch::Tensor& scale, double epsilon);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
@@ -44,6 +52,9 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void scaled_silu_and_mul(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
20 changes: 20 additions & 0 deletions csrc/torch_bindings.cpp
@@ -52,6 +52,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

  // Scaled activation function used in SwiGLU; writes FP8 output.
ops.def("scaled_silu_and_mul(Tensor! out, Tensor input, Tensor scale) -> ()");
ops.impl("scaled_silu_and_mul", torch::kCUDA, &scaled_silu_and_mul);

// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
@@ -89,6 +93,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

// Apply Root Mean Square (RMS) Normalization to the input tensor with scaled
// output.
ops.def(
"scaled_rms_norm(Tensor! out, Tensor input, Tensor weight, Tensor scale, "
"float epsilon) -> "
"()");
ops.impl("scaled_rms_norm", torch::kCUDA, &scaled_rms_norm);

// Fused Add and RMS Normalization with scaled output.
ops.def(
"scaled_fused_add_rms_norm(Tensor! out, Tensor input, Tensor! residual, "
"Tensor weight, "
"Tensor scale, float epsilon) -> ()");
ops.impl("scaled_fused_add_rms_norm", torch::kCUDA,
&scaled_fused_add_rms_norm);

// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
(The remaining changed files are not shown here.)
