lint
Signed-off-by: Liqun Fu <liqfu@microsoft.com>
liqunfu committed Nov 16, 2024
1 parent e73eaf4 commit 07cc88f
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -162,7 +162,7 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
 const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_;

 if constexpr (std::is_same_v<T, MLFloat16>) {
-const int64_t total_data_size = input->Shape().Size();
+const size_t total_data_size = static_cast<size_t>(input->Shape().Size());

 AllocatorPtr alloc;
 ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
@@ -185,14 +185,14 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {

 const size_t num_elems = static_cast<size_t>(hidden_size);

-input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(total_data_size));
+input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
 MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size);
 input_data_f = input_fp32.get();

-output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(total_data_size));
+output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
 output_data_f = output_fp32.get();

-skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(total_data_size));
+skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
 skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get();

 if (skip_data) {
@@ -234,9 +234,9 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
 epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f);
 },
 0);
-MlasConvertFloatToHalfBuffer(output_data_f, output_data, static_cast<size_t>(total_data_size));
+MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size);
 if (skip_input_bias_add_output_data != nullptr)
-MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, static_cast<size_t>(total_data_size));
+MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size);
 } else {
 concurrency::ThreadPool::TryBatchParallelFor(
 p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
@@ -262,7 +262,7 @@ Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx
 } else if (input_idx == 2) { // gamma
 ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
 } else if (input_idx == 3) {
-if (simplified) {
+if constexpr (simplified) {
 // bias
 ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
 } else {
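For reference, a minimal, self-contained sketch of the size-handling pattern this diff converges on: cast the int64_t element count from the tensor shape to size_t once, where it is produced, and reuse that value at every allocation and conversion call site instead of repeating static_cast per call. The fp16 storage type and conversion helper below are stand-ins for illustration only, not the actual MLFloat16 or MLAS API.

#include <cstddef>
#include <cstdint>
#include <vector>

using half_bits = std::uint16_t;  // stand-in for MLFloat16's storage type

// Stand-in for MlasConvertHalfToFloatBuffer: same call shape, not a real fp16 decode.
void ConvertHalfToFloatBuffer(const half_bits* src, float* dst, std::size_t count) {
  for (std::size_t i = 0; i < count; ++i) {
    dst[i] = static_cast<float>(src[i]);  // placeholder widening, not IEEE fp16 decoding
  }
}

int main() {
  const std::int64_t shape_size = 12;  // what input->Shape().Size() would report
  // Cast once, where the value is produced...
  const std::size_t total_data_size = static_cast<std::size_t>(shape_size);

  std::vector<half_bits> input_data(total_data_size, half_bits{0});
  std::vector<float> input_fp32(total_data_size);
  // ...so every size_t consumer takes it directly, with no per-call static_cast.
  ConvertHalfToFloatBuffer(input_data.data(), input_fp32.data(), total_data_size);
  return 0;
}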
