Update skip layer norm #22719

Merged
merged 21 commits into from
Nov 12, 2024
145 changes: 78 additions & 67 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -46,24 +46,13 @@ void ComputeJob(
const T* gamma_data,
const T* beta_data,
const T* bias_data,
IAllocatorUniquePtr<float>& skip_float_uptr,
IAllocatorUniquePtr<float>& gamma_float_uptr,
IAllocatorUniquePtr<float>& beta_float_uptr,
IAllocatorUniquePtr<float>& bias_float_uptr,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
T* output_data,
T* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
ORT_UNUSED_PARAMETER(skip_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(gamma_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(beta_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(alloc);

T* skip_input_bias_add_output_data) {
auto offset = task_idx * hidden_size;
const T* p_input = input_data + offset;
const T* p_skip = skip_data + (offset % skip_size);
@@ -110,13 +99,11 @@ void ComputeJob(
void ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const MLFloat16* gamma_data,
const MLFloat16* beta_data,
const MLFloat16* bias_data,
IAllocatorUniquePtr<float>& skip_float_uptr,
IAllocatorUniquePtr<float>& gamma_float_uptr,
IAllocatorUniquePtr<float>& beta_float_uptr,
IAllocatorUniquePtr<float>& bias_float_uptr,
const float* prepacked_skip_fp32_data,
const float* gamma_float_ptr,
const float* beta_float_ptr,
const float* bias_float_ptr,
float* output_float_ptr,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
@@ -127,7 +114,6 @@ void ComputeJob(
AllocatorPtr alloc) {
auto offset = task_idx * hidden_size;
const MLFloat16* p_input = input_data + offset;
const MLFloat16* p_skip = skip_data + (offset % skip_size);
MLFloat16* p_output = output_data + offset;
MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;

@@ -138,26 +124,19 @@ void ComputeJob(
IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);

if (!skip_float_uptr) {
IAllocatorUniquePtr<float> skip_float_uptr = nullptr;
if (prepacked_skip_fp32_data == nullptr && skip_data) {
const MLFloat16* p_skip = skip_data + (offset % skip_size);
skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
}

if (bias_data && !bias_float_uptr) {
bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
}

IAllocatorUniquePtr<float> output_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
float* output_float_ptr = output_float_uptr.get();

const float* input_float_ptr = input_float_uptr.get();
const float* skip_float_ptr = skip_float_uptr.get();
const float* bias_float_ptr = bias_float_uptr.get();
const float* skip_float_ptr = prepacked_skip_fp32_data ? prepacked_skip_fp32_data : skip_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
float val = input_float_ptr[h] + skip_float_ptr[h];

if (bias_float_uptr) {
if (bias_float_ptr) {
val += bias_float_ptr[h];
}

@@ -177,22 +156,10 @@ void ComputeJob(
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

if (!gamma_float_uptr) {
gamma_float_uptr = std::move(input_float_uptr); // overwrite input with gamma values, since they have the same size
MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems);
}

if (beta_data && !beta_float_uptr) {
beta_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems);
}

const float* gamma_float_ptr = gamma_float_uptr.get();
const float* beta_float_ptr = beta_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
} else if (nullptr == beta_float_uptr) {
} else if (nullptr == beta_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
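[Note for reviewers] The per-row arithmetic in this hunk is unchanged by the PR; only where the half-to-float conversions happen moves. A standalone sketch of that arithmetic, in case it helps sanity-check the hunk — plain std::vector, no ORT types, and the helper name SkipLayerNormRow is made up for illustration:

#include <cmath>
#include <cstdio>
#include <vector>

// One row of skip layer norm:
//   val  = input + skip (+ bias)
//   mean = sum(val) / H
//   rms  = sqrt(sum(val^2) / H - mean^2 + epsilon)   // standard
//   rms  = sqrt(sum(val^2) / H + epsilon)            // simplified
//   out  = (val - mean) / rms * gamma + beta          // standard
//   out  = val / rms * gamma                          // simplified
void SkipLayerNormRow(const std::vector<float>& input, const std::vector<float>& skip,
                      const std::vector<float>& gamma, const std::vector<float>& beta,
                      float epsilon, bool simplified, std::vector<float>& out) {
  const size_t h_size = input.size();
  out.resize(h_size);
  double mean = 0.0, mean_square = 0.0;
  for (size_t h = 0; h < h_size; ++h) {
    float val = input[h] + skip[h];  // bias, when present, is added here as well
    out[h] = val;
    mean += val;
    mean_square += static_cast<double>(val) * val;
  }
  mean /= static_cast<double>(h_size);
  const double rms = simplified ? std::sqrt(mean_square / h_size + epsilon)
                                : std::sqrt(mean_square / h_size - mean * mean + epsilon);
  for (size_t h = 0; h < h_size; ++h) {
    out[h] = simplified
                 ? out[h] / static_cast<float>(rms) * gamma[h]
                 : (out[h] - static_cast<float>(mean)) / static_cast<float>(rms) * gamma[h] + beta[h];
  }
}

int main() {
  std::vector<float> x{1.f, 2.f, 3.f, 4.f}, s{0.5f, 0.5f, 0.5f, 0.5f};
  std::vector<float> g(4, 1.f), b(4, 0.f), y;
  SkipLayerNormRow(x, s, g, b, 1e-5f, /*simplified=*/false, y);
  for (float v : y) std::printf("%f\n", v);
  return 0;
}

In the kernel itself, the pre-normalization sum is also written to the optional skip_input_bias_add_output tensor when that output is requested.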
@@ -218,18 +185,23 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I

template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) {
: OpKernel(op_kernel_info),
prepacked_skip_fp32_size_(0),
prepacked_skip_fp32_data_(nullptr),
prepacked_gamma_fp32_data_(nullptr),
prepacked_beta_fp32_data_(nullptr),
prepacked_bias_fp32_data_(nullptr) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}

template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const Tensor* input = p_ctx->Input<Tensor>(0);
const Tensor* skip = p_ctx->Input<Tensor>(1);
const Tensor* gamma = p_ctx->Input<Tensor>(2);
const Tensor* beta = p_ctx->Input<Tensor>(3);
const Tensor* bias = p_ctx->Input<Tensor>(4);
const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
Tensor* output = p_ctx->Output(0, input->Shape());
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
@@ -238,19 +210,21 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
size_t input_dims_size = input_dims.size();
int hidden_size = static_cast<int>(input_dims[input_dims_size - 1]);

ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs<Tensor>(input,
skip,
gamma,
beta,
bias,
hidden_size,
input_dims_size));
ORT_RETURN_IF_ERROR(skip_layer_norm_helper::CheckPotentiallyPrepackedInputs<Tensor>(input,
skip,
gamma,
beta,
bias,
hidden_size,
input_dims_size,
prepacked_skip_fp32_data_ != nullptr,
prepacked_gamma_fp32_data_ != nullptr));

int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);

const T* input_data = input->Data<T>();
const T* skip_data = skip->Data<T>();
const T* gamma_data = gamma->Data<T>();
const T* skip_data = skip == nullptr ? nullptr : skip->Data<T>();
const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data<T>();
const T* beta_data = beta == nullptr ? nullptr : beta->Data<T>();
const T* bias_data = bias == nullptr ? nullptr : bias->Data<T>();

@@ -259,17 +233,53 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

const int64_t& skip_size = skip->Shape().Size();
const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_;

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

IAllocatorUniquePtr<float> output_fp32;
IAllocatorUniquePtr<float> gamma_fp32;
IAllocatorUniquePtr<float> beta_fp32;
IAllocatorUniquePtr<float> bias_fp32;

if constexpr (std::is_same_v<T, MLFloat16>) {
const size_t num_elems = static_cast<size_t>(hidden_size);

output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);

if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) {
gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
}

if (prepacked_beta_fp32_data_ == nullptr && beta_data) {
beta_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems);
}

if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
}
}

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_,
bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data, alloc);
if constexpr (std::is_same_v<T, MLFloat16>) {
ComputeJob(input_data, skip_data,
prepacked_skip_fp32_data_.get(),
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
output_fp32.get(),
task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data, alloc);
} else {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size,
epsilon_, simplified, output_data, skip_input_bias_add_output_data);
}
},
0);
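[Note for reviewers] With this restructuring, gamma/beta/bias are converted to fp32 once per Compute call, before TryBatchParallelFor dispatches tasks, instead of being converted lazily inside the task body and cached in mutable members shared by all tasks. A standalone sketch of the resulting pattern — hypothetical names, and a fake half-to-float routine standing in for MlasConvertHalfToFloatBuffer:

#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

// Stand-in for MlasConvertHalfToFloatBuffer; uint16_t plays the role of a half.
static void FakeHalfToFloat(const uint16_t* src, float* dst, size_t n) {
  for (size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
}

int main() {
  const size_t hidden = 8, rows = 4;
  std::vector<uint16_t> gamma_fp16(hidden, 2);

  // Convert the fp16 parameter exactly once, before the parallel loop.
  std::vector<float> gamma_fp32(hidden);
  FakeHalfToFloat(gamma_fp16.data(), gamma_fp32.data(), hidden);

  std::vector<float> row_sums(rows, 0.f);
  std::vector<std::thread> tasks;
  for (size_t r = 0; r < rows; ++r) {
    tasks.emplace_back([&gamma_fp32, &row_sums, r] {
      float s = 0.f;
      for (float g : gamma_fp32) s += g;  // tasks only read the shared buffer
      row_sums[r] = s;                    // each task writes its own slot
    });
  }
  for (auto& t : tasks) t.join();
  for (float s : row_sums) std::printf("%f\n", s);
  return 0;
}

Each task only reads the shared fp32 buffers and writes its own output region, so no task mutates kernel state during the parallel loop.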

Expand All @@ -284,13 +294,14 @@ Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx

is_packed = false;
if (input_idx == 1) { // skip
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed);
prepacked_skip_fp32_size_ = tensor.Shape().Size();
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed);
} else if (input_idx == 2) { // gamma
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
} else if (input_idx == 3) { // beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
} else if (input_idx == 4) { // bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}

return Status::OK();
9 changes: 5 additions & 4 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -21,10 +21,11 @@ class SkipLayerNorm final : public OpKernel {

private:
float epsilon_;
mutable IAllocatorUniquePtr<float> skip_fp32_;
mutable IAllocatorUniquePtr<float> gamma_fp32_;
mutable IAllocatorUniquePtr<float> beta_fp32_;
mutable IAllocatorUniquePtr<float> bias_fp32_;
int64_t prepacked_skip_fp32_size_;
IAllocatorUniquePtr<float> prepacked_skip_fp32_data_;
IAllocatorUniquePtr<float> prepacked_gamma_fp32_data_;
IAllocatorUniquePtr<float> prepacked_beta_fp32_data_;
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
};

} // namespace contrib
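[Note for reviewers] The header change mirrors the .cc change: the mutable, lazily-filled fp32 caches become plain prepacked_* members written only in PrePack, plus prepacked_skip_fp32_size_ because the skip tensor's size is no longer readable from the input once it has been prepacked. A minimal sketch of that ownership model — class and method shapes are illustrative only, not the kernel's actual interface:

#include <cstdint>
#include <vector>

class PrepackedNormKernel {
 public:
  // Called once when the weight is a constant initializer: convert and cache.
  void PrePack(const std::vector<uint16_t>& skip_fp16) {
    prepacked_skip_fp32_size_ = static_cast<int64_t>(skip_fp16.size());
    prepacked_skip_fp32_data_.resize(skip_fp16.size());
    for (size_t i = 0; i < skip_fp16.size(); ++i) {
      // Stand-in for MlasConvertHalfToFloatBuffer.
      prepacked_skip_fp32_data_[i] = static_cast<float>(skip_fp16[i]);
    }
  }

  // Const: only reads the prepacked copy; a runtime-input fallback is elided.
  float SumOfSkip() const {
    float s = 0.f;
    for (float v : prepacked_skip_fp32_data_) s += v;
    return s;
  }

 private:
  int64_t prepacked_skip_fp32_size_ = 0;
  std::vector<float> prepacked_skip_fp32_data_;
};

int main() {
  PrepackedNormKernel k;
  k.PrePack({1, 2, 3});
  return static_cast<int>(k.SumOfSkip());  // 6
}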