diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index c9ee9e2cb760d..12af1ac784209 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -46,24 +46,13 @@ void ComputeJob( const T* gamma_data, const T* beta_data, const T* bias_data, - IAllocatorUniquePtr& skip_float_uptr, - IAllocatorUniquePtr& gamma_float_uptr, - IAllocatorUniquePtr& beta_float_uptr, - IAllocatorUniquePtr& bias_float_uptr, ptrdiff_t task_idx, int hidden_size, int64_t skip_size, float epsilon, bool simplified, T* output_data, - T* skip_input_bias_add_output_data, - AllocatorPtr alloc) { - ORT_UNUSED_PARAMETER(skip_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(gamma_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(beta_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(alloc); - + T* skip_input_bias_add_output_data) { auto offset = task_idx * hidden_size; const T* p_input = input_data + offset; const T* p_skip = skip_data + (offset % skip_size); @@ -110,13 +99,11 @@ void ComputeJob( void ComputeJob( const MLFloat16* input_data, const MLFloat16* skip_data, - const MLFloat16* gamma_data, - const MLFloat16* beta_data, - const MLFloat16* bias_data, - IAllocatorUniquePtr& skip_float_uptr, - IAllocatorUniquePtr& gamma_float_uptr, - IAllocatorUniquePtr& beta_float_uptr, - IAllocatorUniquePtr& bias_float_uptr, + const float* prepacked_skip_fp32_data, + const float* gamma_float_ptr, + const float* beta_float_ptr, + const float* bias_float_ptr, + float* output_float_ptr, ptrdiff_t task_idx, int hidden_size, int64_t skip_size, @@ -127,7 +114,6 @@ void ComputeJob( AllocatorPtr alloc) { auto offset = task_idx * hidden_size; const MLFloat16* p_input = input_data + offset; - const MLFloat16* p_skip = skip_data + (offset % skip_size); MLFloat16* p_output 
= output_data + offset; MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset; @@ -138,26 +124,19 @@ void ComputeJob( IAllocatorUniquePtr input_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems); - if (!skip_float_uptr) { + IAllocatorUniquePtr skip_float_uptr = nullptr; + if (prepacked_skip_fp32_data == nullptr && skip_data) { + const MLFloat16* p_skip = skip_data + (offset % skip_size); skip_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems); } - if (bias_data && !bias_float_uptr) { - bias_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems); - } - - IAllocatorUniquePtr output_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - float* output_float_ptr = output_float_uptr.get(); - const float* input_float_ptr = input_float_uptr.get(); - const float* skip_float_ptr = skip_float_uptr.get(); - const float* bias_float_ptr = bias_float_uptr.get(); + const float* skip_float_ptr = prepacked_skip_fp32_data ? 
prepacked_skip_fp32_data : skip_float_uptr.get(); for (size_t h = 0; h < num_elems; h++) { float val = input_float_ptr[h] + skip_float_ptr[h]; - if (bias_float_uptr) { + if (bias_float_ptr) { val += bias_float_ptr[h]; } @@ -177,22 +156,10 @@ void ComputeJob( mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon); } - if (!gamma_float_uptr) { - gamma_float_uptr = std::move(input_float_uptr); // overwrite input with gamma values, since they have the same size - MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems); - } - - if (beta_data && !beta_float_uptr) { - beta_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems); - } - - const float* gamma_float_ptr = gamma_float_uptr.get(); - const float* beta_float_ptr = beta_float_uptr.get(); for (size_t h = 0; h < num_elems; h++) { if (simplified) { output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h]; - } else if (nullptr == beta_float_uptr) { + } else if (nullptr == beta_float_ptr) { output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h]; } else { output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h]; @@ -218,7 +185,12 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I template SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) - : OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) { + : OpKernel(op_kernel_info), + prepacked_skip_fp32_size_(0), + prepacked_skip_fp32_data_(nullptr), + prepacked_gamma_fp32_data_(nullptr), + prepacked_beta_fp32_data_(nullptr), + prepacked_bias_fp32_data_(nullptr) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } @@ -226,10 +198,10 @@ SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) template Status 
SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); - const Tensor* skip = p_ctx->Input(1); - const Tensor* gamma = p_ctx->Input(2); - const Tensor* beta = p_ctx->Input(3); - const Tensor* bias = p_ctx->Input(4); + const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input(1); + const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input(2); + const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input(3); + const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(4); Tensor* output = p_ctx->Output(0, input->Shape()); // For inferencing, we support one more optional output which is the sum of the input and skip tensors Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape()); @@ -238,19 +210,21 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { size_t input_dims_size = input_dims.size(); int hidden_size = static_cast(input_dims[input_dims_size - 1]); - ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, - skip, - gamma, - beta, - bias, - hidden_size, - input_dims_size)); + ORT_RETURN_IF_ERROR(skip_layer_norm_helper::CheckPotentiallyPrepackedInputs(input, + skip, + gamma, + beta, + bias, + hidden_size, + input_dims_size, + prepacked_skip_fp32_data_ != nullptr, + prepacked_gamma_fp32_data_ != nullptr)); int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1); const T* input_data = input->Data(); - const T* skip_data = skip->Data(); - const T* gamma_data = gamma->Data(); + const T* skip_data = skip == nullptr ? nullptr : skip->Data(); + const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data(); const T* beta_data = beta == nullptr ? nullptr : beta->Data(); const T* bias_data = bias == nullptr ? 
nullptr : bias->Data(); @@ -259,17 +233,53 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { // For inferencing, we support one more optional output which is the sum of the input and skip tensors T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData(); - const int64_t& skip_size = skip->Shape().Size(); + const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_; AllocatorPtr alloc; ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); + IAllocatorUniquePtr output_fp32; + IAllocatorUniquePtr gamma_fp32; + IAllocatorUniquePtr beta_fp32; + IAllocatorUniquePtr bias_fp32; + + if constexpr (std::is_same_v) { + const size_t num_elems = static_cast(hidden_size); + + output_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + + if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) { + gamma_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems); + } + + if (prepacked_beta_fp32_data_ == nullptr && beta_data) { + beta_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems); + } + + if (prepacked_bias_fp32_data_ == nullptr && bias_data) { + bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); + } + } + concurrency::ThreadPool::TryBatchParallelFor( p_ctx->GetOperatorThreadPool(), static_cast(task_count), [&](ptrdiff_t task_idx) { - ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_, - bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, - skip_input_bias_add_output_data, alloc); + if constexpr (std::is_same_v) { + ComputeJob(input_data, skip_data, + prepacked_skip_fp32_data_.get(), + prepacked_gamma_fp32_data_ ? 
prepacked_gamma_fp32_data_.get() : gamma_fp32.get(), + prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(), + prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(), + output_fp32.get(), + task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, + skip_input_bias_add_output_data, alloc); + } else { + ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size, + epsilon_, simplified, output_data, skip_input_bias_add_output_data); + } }, 0); @@ -284,13 +294,14 @@ Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx is_packed = false; if (input_idx == 1) { // skip - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed); + prepacked_skip_fp32_size_ = tensor.Shape().Size(); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed); } else if (input_idx == 2) { // gamma - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed); } else if (input_idx == 3) { // beta - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed); } else if (input_idx == 4) { // bias - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index d904c14857437..e725f648fe275 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -21,10 +21,11 @@ class SkipLayerNorm final : public OpKernel { private: float epsilon_; - mutable IAllocatorUniquePtr skip_fp32_; - mutable IAllocatorUniquePtr gamma_fp32_; - mutable IAllocatorUniquePtr beta_fp32_; - mutable 
IAllocatorUniquePtr bias_fp32_; + int64_t prepacked_skip_fp32_size_; + IAllocatorUniquePtr prepacked_skip_fp32_data_; + IAllocatorUniquePtr prepacked_gamma_fp32_data_; + IAllocatorUniquePtr prepacked_beta_fp32_data_; + IAllocatorUniquePtr prepacked_bias_fp32_data_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h index 6271f822287e6..4c901f5650dbd 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h @@ -11,14 +11,10 @@ namespace onnxruntime { namespace contrib { namespace skip_layer_norm_helper { +namespace { + template -Status CheckInputs(const T* input, - const T* skip, - const T* gamma, - const T* beta, - const T* bias, - int hidden_size_check, - size_t input_dims_size_check) { +Status CheckSkip(const T* input, const T* skip, size_t input_dims_size_check) { const auto& input_dims_check = input->Shape().GetDims(); const auto& skip_dims_check = skip->Shape().GetDims(); size_t skip_dims_size_check = skip_dims_check.size(); @@ -33,49 +29,150 @@ Status CheckInputs(const T* input, "skip is expected to have same shape as input or, a batch size of 1 or no batch size when input has 3 dimensions"); } - if (input_dims_size_check != 3 && input_dims_size_check != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); - } - if (skip_dims_check[skip_dims_size_check - 1] != input_dims_check[input_dims_size_check - 1] || skip_dims_check[skip_dims_size_check - 2] != input_dims_check[input_dims_size_check - 2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last two dimensions of skip needs to be same as input"); } + return Status::OK(); +} + +template +Status CheckGamma(const T* gamma, int hidden_size_check) { const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } + if (gamma_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of gamma and input does not match"); } + return Status::OK(); +} + +template +Status CheckBeta(const T* beta, int hidden_size_check) { if (nullptr != beta) { const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } + if (beta_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of beta and input does not match"); } } + return Status::OK(); +} + +template +Status CheckBias(const T* bias, int hidden_size_check) { if (nullptr != bias) { const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bias is expected to have 1 dimension, got ", bias_dims.size()); } + if (bias_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of bias and input does not match"); } } + + return Status::OK(); +} + +} // anonymous namespace + +template +Status CheckInputs(const T* input, + const T* skip, + const T* gamma, + const T* beta, + const T* bias, + int hidden_size_check, + size_t input_dims_size_check) { + if (input_dims_size_check != 3 && input_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); + } + + auto status = CheckSkip(input, skip, input_dims_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckGamma(gamma, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckBeta(beta, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = 
CheckBias(bias, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + return Status::OK(); +} + +template +Status CheckPotentiallyPrepackedInputs(const T* input, + const T* skip, + const T* gamma, + const T* beta, + const T* bias, + int hidden_size_check, + size_t input_dims_size_check, + bool prepacked_skip, + bool prepacked_gamma) { + if (input_dims_size_check != 3 && input_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); + } + + if (nullptr != skip) { + auto status = CheckSkip(input, skip, input_dims_size_check); + if (status != Status::OK()) { + return status; + } + } else if (!prepacked_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "skip is expected but not provided"); + } + + if (nullptr != gamma) { + auto status = CheckGamma(gamma, hidden_size_check); + if (status != Status::OK()) { + return status; + } + } else if (!prepacked_gamma) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected but not provided"); + } + + auto status = CheckBeta(beta, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckBias(bias, hidden_size_check); + if (status != Status::OK()) { + return status; + } + return Status::OK(); } diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index b9ca55073d411..4e8d1b9f016f0 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -194,6 +194,32 @@ static void RunTest( } } +TEST(SkipLayerNormTest, SkipLayerNormPrePack) { + OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 2; + std::vector input_skip_output_dims = {batch_size, sequence_length, hidden_size}; + std::vector 
gamma_beta_bias_dims = {hidden_size}; + test.AddInput("x", input_skip_output_dims, ToFloat16({1.f, 1.f, 1.f, 1.f})); + test.AddInput("skip", input_skip_output_dims, ToFloat16({1.f, 1.f, 1.f, 1.f})); + test.AddInput("gamma", gamma_beta_bias_dims, ToFloat16({1.f, 1.f}), true); + test.AddInput("beta", gamma_beta_bias_dims, ToFloat16({1.f, 1.f}), true); + test.AddOutput("output", input_skip_output_dims, ToFloat16({ + 1.f, + 1.f, + 1.f, + 1.f, + })); + + // TRT, DNNL, OpenVINO, NNAPI and QNN don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider}); +} + TEST(SkipLayerNormTest, SkipLayerNormNullInput) { int batch_size = 1; int sequence_length = 0;