From 8fe62812c9154d07ecd48b529a986f3c1488d4e4 Mon Sep 17 00:00:00 2001
From: nihui
Date: Fri, 18 Oct 2024 21:32:18 +0800
Subject: [PATCH] arm neon optimization for layernorm fp32/bf16s/fp16s (#5746)

---
 src/layer/arm/layernorm_arm.cpp         | 517 ++++++++++++++++++++++++
 src/layer/arm/layernorm_arm.h           |  40 ++
 src/layer/arm/layernorm_arm_asimdhp.cpp | 351 ++++++++++++++++
 src/net.cpp                             |   2 +-
 tests/testutil.cpp                      |   2 +-
 5 files changed, 910 insertions(+), 2 deletions(-)
 create mode 100644 src/layer/arm/layernorm_arm.cpp
 create mode 100644 src/layer/arm/layernorm_arm.h
 create mode 100644 src/layer/arm/layernorm_arm_asimdhp.cpp

diff --git a/src/layer/arm/layernorm_arm.cpp b/src/layer/arm/layernorm_arm.cpp
new file mode 100644
index 00000000000..4c49a5e76b7
--- /dev/null
+++ b/src/layer/arm/layernorm_arm.cpp
@@ -0,0 +1,517 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "layernorm_arm.h"
+
+#if __ARM_NEON
+#include <arm_neon.h>
+#include "neon_mathfun.h"
+#endif // __ARM_NEON
+
+#include "arm_usability.h"
+#include "cpu.h"
+
+namespace ncnn {
+
+LayerNorm_arm::LayerNorm_arm()
+{
+#if __ARM_NEON
+    support_packing = true;
+#if NCNN_ARM82
+    support_fp16_storage = cpu_support_arm_asimdhp();
+#endif
+#endif // __ARM_NEON
+
+#if NCNN_BF16
+    support_bf16_storage = true;
+#endif
+}
+
+static void layernorm(float* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack)
+{
+    const int size = elemcount * elempack;
+
+#if __ARM_NEON
+    float32x4_t _mean = vdupq_n_f32(0.f);
+#endif // __ARM_NEON
+    float mean = 0.f;
+    {
+        const float* ptr0 = ptr;
+
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vld1q_f32(ptr0);
+            _mean = vaddq_f32(_mean, _p);
+            ptr0 += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            mean += ptr0[0];
+            ptr0++;
+        }
+    }
+
+#if __ARM_NEON
+    if (elempack == 4)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        _mean = div_ps(_mean, _elemcount);
+    }
+#endif // __ARM_NEON
+    if (elempack == 1)
+    {
+#if __ARM_NEON
+#if __aarch64__
+        mean += vaddvq_f32(_mean);
+#else
+        float32x2_t _s2 = vadd_f32(vget_low_f32(_mean), vget_high_f32(_mean));
+        _s2 = vpadd_f32(_s2, _s2);
+        mean += vget_lane_f32(_s2, 0);
+#endif
+#endif // __ARM_NEON
+
+        mean = mean / elemcount;
+#if __ARM_NEON
+        _mean = vdupq_n_f32(mean);
+#endif // __ARM_NEON
+    }
+
+#if __ARM_NEON
+    float32x4_t _var = vdupq_n_f32(0.f);
+#endif // __ARM_NEON
+    float var = 0.f;
+    {
+        const float* ptr0 = ptr;
+
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vld1q_f32(ptr0);
+            _p = vsubq_f32(_p, _mean);
+            _var = vmlaq_f32(_var, _p, _p);
+            ptr0 += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            float v = ptr0[0] - mean;
+            var += v * v;
+            ptr0++;
+        }
+    }
+
+#if __ARM_NEON
+    if (elempack == 4)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        float32x4_t _eps = vdupq_n_f32(eps);
+        _var = div_ps(_var, _elemcount);
+        _var = vaddq_f32(_var, _eps);
+        float32x4_t _rsqrt_var = vrsqrteq_f32(_var);
+        _rsqrt_var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var, _rsqrt_var), _rsqrt_var), _rsqrt_var);
+        _var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var, _rsqrt_var), _rsqrt_var), _rsqrt_var);
+        _mean = vmulq_f32(_mean, _var);
+        _mean = vnegq_f32(_mean);
+    }
+#endif // __ARM_NEON
+    if (elempack == 1)
+    {
+#if __ARM_NEON
+#if __aarch64__
+        var += vaddvq_f32(_var);
+#else
+        float32x2_t _s2 = vadd_f32(vget_low_f32(_var), vget_high_f32(_var));
+        _s2 = vpadd_f32(_s2, _s2);
+        var += vget_lane_f32(_s2, 0);
+#endif
+#endif // __ARM_NEON
+
+        var = 1.f / sqrtf(var / elemcount + eps);
+        mean = -mean * var;
+#if __ARM_NEON
+        _var = vdupq_n_f32(var);
+        _mean = vdupq_n_f32(mean);
+#endif // __ARM_NEON
+    }
+
+    if (gamma_ptr && beta_ptr)
+    {
+        int i = 0;
+#if __ARM_NEON
+        if (elempack == 4)
+        {
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = vld1q_f32(ptr);
+                float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
+                float32x4_t _beta = vdupq_n_f32(beta_ptr[0]);
+                _p = vmlaq_f32(_mean, _p, _var);
+                _p = vmlaq_f32(_beta, _p, _gamma);
+                vst1q_f32(ptr, _p);
+                ptr += 4;
+                gamma_ptr += 1;
+                beta_ptr += 1;
+            }
+        }
+        if (elempack == 1)
+        {
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = vld1q_f32(ptr);
+                float32x4_t _gamma = vld1q_f32(gamma_ptr);
+                float32x4_t _beta = vld1q_f32(beta_ptr);
+                _p = vmlaq_f32(_mean, _p, _var);
+                _p = vmlaq_f32(_beta, _p, _gamma);
+                vst1q_f32(ptr, _p);
+                ptr += 4;
+                gamma_ptr += 4;
+                beta_ptr += 4;
+            }
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            ptr[0] = (ptr[0] * var + mean) * gamma_ptr[0] + beta_ptr[0];
+            ptr++;
+            gamma_ptr++;
+            beta_ptr++;
+        }
+    }
+    else
+    {
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vld1q_f32(ptr);
+            _p = vmlaq_f32(_mean, _p, _var);
+            vst1q_f32(ptr, _p);
+            ptr += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            ptr[0] = ptr[0] * var + mean;
+            ptr++;
+        }
+    }
+}
+
+int LayerNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
+{
+    int elembits = bottom_top_blob.elembits();
+
+#if NCNN_ARM82
+    if (support_fp16_storage && opt.use_fp16_storage && elembits == 16)
+        return forward_inplace_fp16s(bottom_top_blob, opt);
+#endif
+
+#if NCNN_BF16
+    if (opt.use_bf16_storage && elembits == 16)
+        return forward_inplace_bf16s(bottom_top_blob, opt);
+#endif
+
+    const int dims = bottom_top_blob.dims;
+    const int w = bottom_top_blob.w;
+    const int h = bottom_top_blob.h;
+    const int channels = bottom_top_blob.c;
+    const int elempack = bottom_top_blob.elempack;
+
+    if (dims == 1)
+    {
+        // assert affine_size == w
+
+        float* ptr = bottom_top_blob;
+        layernorm(ptr, gamma_data, beta_data, eps, w * elempack, 1);
+    }
+
+    if (dims == 2)
+    {
+        // assert affine_size == w
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = 0; i < h; i++)
+        {
+            float* ptr = bottom_top_blob.row(i);
+            layernorm(ptr, gamma_data, beta_data, eps, w, elempack);
+        }
+    }
+
+    if (dims == 3)
+    {
+        if (affine_size == w)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                for (int i = 0; i < h; i++)
+                {
+                    float* ptr = bottom_top_blob.channel(q).row(i);
+                    layernorm(ptr, gamma_data, beta_data, eps, w, elempack);
+                }
+            }
+        }
+        else // if (affine_size == w * h)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                float* ptr = bottom_top_blob.channel(q);
+                layernorm(ptr, gamma_data, beta_data, eps, w * h, elempack);
+            }
+        }
+    }
+
+    return 0;
+}
+
+#if NCNN_BF16
+static void layernorm_bf16s(unsigned short* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack)
+{
+    const int size = elemcount * elempack;
+
+#if __ARM_NEON
+    float32x4_t _mean = vdupq_n_f32(0.f);
+#endif // __ARM_NEON
+    float mean = 0.f;
+    {
+        const unsigned short* ptr0 = ptr;
+
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = bfloat2float(vld1_u16(ptr0));
+            _mean = vaddq_f32(_mean, _p);
+            ptr0 += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            mean += bfloat16_to_float32(ptr0[0]);
+            ptr0++;
+        }
+    }
+
+#if __ARM_NEON
+    if (elempack == 4)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        _mean = div_ps(_mean, _elemcount);
+    }
+#endif // __ARM_NEON
+    if (elempack == 1)
+    {
+#if __ARM_NEON
+#if __aarch64__
+        mean += vaddvq_f32(_mean);
+#else
+        float32x2_t _s2 = vadd_f32(vget_low_f32(_mean), vget_high_f32(_mean));
+        _s2 = vpadd_f32(_s2, _s2);
+        mean += vget_lane_f32(_s2, 0);
+#endif
+#endif // __ARM_NEON
+
+        mean = mean / elemcount;
+#if __ARM_NEON
+        _mean = vdupq_n_f32(mean);
+#endif // __ARM_NEON
+    }
+
+#if __ARM_NEON
+    float32x4_t _var = vdupq_n_f32(0.f);
+#endif // __ARM_NEON
+    float var = 0.f;
+    {
+        const unsigned short* ptr0 = ptr;
+
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = bfloat2float(vld1_u16(ptr0));
+            _p = vsubq_f32(_p, _mean);
+            _var = vmlaq_f32(_var, _p, _p);
+            ptr0 += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            float v = bfloat16_to_float32(ptr0[0]) - mean;
+            var += v * v;
+            ptr0++;
+        }
+    }
+
+#if __ARM_NEON
+    if (elempack == 4)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        float32x4_t _eps = vdupq_n_f32(eps);
+        _var = div_ps(_var, _elemcount);
+        _var = vaddq_f32(_var, _eps);
+        float32x4_t _rsqrt_var = vrsqrteq_f32(_var);
+        _rsqrt_var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var, _rsqrt_var), _rsqrt_var), _rsqrt_var);
+        _var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var, _rsqrt_var), _rsqrt_var), _rsqrt_var);
+        _mean = vmulq_f32(_mean, _var);
+        _mean = vnegq_f32(_mean);
+    }
+#endif // __ARM_NEON
+    if (elempack == 1)
+    {
+#if __ARM_NEON
+#if __aarch64__
+        var += vaddvq_f32(_var);
+#else
+        float32x2_t _s2 = vadd_f32(vget_low_f32(_var), vget_high_f32(_var));
+        _s2 = vpadd_f32(_s2, _s2);
+        var += vget_lane_f32(_s2, 0);
+#endif
+#endif // __ARM_NEON
+
+        var = 1.f / sqrtf(var / elemcount + eps);
+        mean = -mean * var;
+#if __ARM_NEON
+        _var = vdupq_n_f32(var);
+        _mean = vdupq_n_f32(mean);
+#endif // __ARM_NEON
+    }
+
+    if (gamma_ptr && beta_ptr)
+    {
+        int i = 0;
+#if __ARM_NEON
+        if (elempack == 4)
+        {
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = bfloat2float(vld1_u16(ptr));
+                float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
+                float32x4_t _beta = vdupq_n_f32(beta_ptr[0]);
+                _p = vmlaq_f32(_mean, _p, _var);
+                _p = vmlaq_f32(_beta, _p, _gamma);
+                vst1_u16(ptr, float2bfloat(_p));
+                ptr += 4;
+                gamma_ptr += 1;
+                beta_ptr += 1;
+            }
+        }
+        if (elempack == 1)
+        {
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = bfloat2float(vld1_u16(ptr));
+                float32x4_t _gamma = vld1q_f32(gamma_ptr);
+                float32x4_t _beta = vld1q_f32(beta_ptr);
+                _p = vmlaq_f32(_mean, _p, _var);
+                _p = vmlaq_f32(_beta, _p, _gamma);
+                vst1_u16(ptr, float2bfloat(_p));
+                ptr += 4;
+                gamma_ptr += 4;
+                beta_ptr += 4;
+            }
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            float v = bfloat16_to_float32(ptr[0]);
+            ptr[0] = float32_to_bfloat16((v * var + mean) * gamma_ptr[0] + beta_ptr[0]);
+            ptr++;
+            gamma_ptr++;
+            beta_ptr++;
+        }
+    }
+    else
+    {
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = bfloat2float(vld1_u16(ptr));
+            _p = vmlaq_f32(_mean, _p, _var);
+            vst1_u16(ptr, float2bfloat(_p));
+            ptr += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            float v = bfloat16_to_float32(ptr[0]);
+            ptr[0] = float32_to_bfloat16(v * var + mean);
+            ptr++;
+        }
+    }
+}
+
+int LayerNorm_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
+{
+    const int dims = bottom_top_blob.dims;
+    const int w = bottom_top_blob.w;
+    const int h = bottom_top_blob.h;
+    const int channels = bottom_top_blob.c;
+    const int elempack = bottom_top_blob.elempack;
+
+    if (dims == 1)
+    {
+        // assert affine_size == w
+
+        unsigned short* ptr = bottom_top_blob;
+        layernorm_bf16s(ptr, gamma_data, beta_data, eps, w * elempack, 1);
+    }
+
+    if (dims == 2)
+    {
+        // assert affine_size == w
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = 0; i < h; i++)
+        {
+            unsigned short* ptr = bottom_top_blob.row<unsigned short>(i);
+            layernorm_bf16s(ptr, gamma_data, beta_data, eps, w, elempack);
+        }
+    }
+
+    if (dims == 3)
+    {
+        if (affine_size == w)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                for (int i = 0; i < h; i++)
+                {
+                    unsigned short* ptr = bottom_top_blob.channel(q).row<unsigned short>(i);
+                    layernorm_bf16s(ptr, gamma_data, beta_data, eps, w, elempack);
+                }
+            }
+        }
+        else // if (affine_size == w * h)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                unsigned short* ptr = bottom_top_blob.channel(q);
+                layernorm_bf16s(ptr, gamma_data, beta_data, eps, w * h, elempack);
+            }
+        }
+    }
+
+    return 0;
+}
+#endif // NCNN_BF16
+
+} // namespace ncnn
diff --git a/src/layer/arm/layernorm_arm.h b/src/layer/arm/layernorm_arm.h
new file mode 100644
index 00000000000..d3bcac1b276
--- /dev/null
+++ b/src/layer/arm/layernorm_arm.h
@@ -0,0 +1,40 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#ifndef LAYER_LAYERNORM_ARM_H
+#define LAYER_LAYERNORM_ARM_H
+
+#include "layernorm.h"
+
+namespace ncnn {
+
+class LayerNorm_arm : public LayerNorm
+{
+public:
+    LayerNorm_arm();
+
+    virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
+
+protected:
+#if NCNN_ARM82
+    int forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const;
+#endif
+#if NCNN_BF16
+    int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const;
+#endif
+};
+
+} // namespace ncnn
+
+#endif // LAYER_LAYERNORM_ARM_H
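Note: the fp32 and bf16 kernels above share one trick worth calling out. Instead of dividing by sqrt(var/elemcount + eps), they take the coarse vrsqrteq_f32 estimate and refine it with two Newton-Raphson steps (vrsqrtsq_f32 computes (3 - a*b) / 2), then pre-fold the mean as -mean * rsqrt so the per-element work collapses to a single fused multiply-add. A minimal scalar sketch of that refinement for readers checking the math; newton_rsqrt_step is a hypothetical helper, not part of this patch:

// Scalar model of the vrsqrteq_f32 + 2x vrsqrtsq_f32 sequence used above.
// One step is y' = y * (3 - x*y*y) / 2, which is exactly what
// vmulq_f32(vrsqrtsq_f32(vmulq_f32(_x, _y), _y), _y) computes per lane.
#include <math.h>
#include <stdio.h>

static float newton_rsqrt_step(float x, float y)
{
    return y * (3.f - x * y * y) * 0.5f;
}

int main()
{
    const float x = 2.5f + 1e-5f; // stands in for var/elemcount + eps
    float y = 0.5f;               // stands in for the rough hardware estimate
    y = newton_rsqrt_step(x, y);  // first refinement
    y = newton_rsqrt_step(x, y);  // second refinement
    printf("refined %.9f  exact %.9f\n", y, 1.f / sqrtf(x));
    return 0;
}
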
diff --git a/src/layer/arm/layernorm_arm_asimdhp.cpp b/src/layer/arm/layernorm_arm_asimdhp.cpp
new file mode 100644
index 00000000000..1b746707dc8
--- /dev/null
+++ b/src/layer/arm/layernorm_arm_asimdhp.cpp
@@ -0,0 +1,351 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "layernorm_arm.h"
+
+#if __ARM_NEON
+#include <arm_neon.h>
+#endif // __ARM_NEON
+
+#include "arm_usability.h"
+
+namespace ncnn {
+
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+static void layernorm_fp16s(__fp16* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack)
+{
+    const int size = elemcount * elempack;
+
+    float32x4_t _mean0 = vdupq_n_f32(0.f);
+    float32x4_t _mean1 = vdupq_n_f32(0.f);
+    float mean = 0.f;
+    {
+        const __fp16* ptr0 = ptr;
+
+        int i = 0;
+        for (; i + 7 < size; i += 8)
+        {
+            float16x8_t _p = vld1q_f16(ptr0);
+            float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+            float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+            _mean0 = vaddq_f32(_mean0, _p0);
+            _mean1 = vaddq_f32(_mean1, _p1);
+            ptr0 += 8;
+        }
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr0));
+            _mean0 = vaddq_f32(_mean0, _p);
+            ptr0 += 4;
+        }
+        for (; i < size; i++)
+        {
+            mean += (float)ptr0[0];
+            ptr0++;
+        }
+    }
+
+    if (elempack == 8)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        _mean0 = vdivq_f32(_mean0, _elemcount);
+        _mean1 = vdivq_f32(_mean1, _elemcount);
+    }
+    if (elempack == 4)
+    {
+        _mean0 = vaddq_f32(_mean0, _mean1);
+
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        _mean0 = vdivq_f32(_mean0, _elemcount);
+        _mean1 = _mean0;
+    }
+    if (elempack == 1)
+    {
+        _mean0 = vaddq_f32(_mean0, _mean1);
+        mean += vaddvq_f32(_mean0);
+
+        mean = mean / elemcount;
+        _mean0 = vdupq_n_f32(mean);
+        _mean1 = _mean0;
+    }
+
+    float32x4_t _var0 = vdupq_n_f32(0.f);
+    float32x4_t _var1 = vdupq_n_f32(0.f);
+    float var = 0.f;
+    {
+        const __fp16* ptr0 = ptr;
+
+        int i = 0;
+        for (; i + 7 < size; i += 8)
+        {
+            float16x8_t _p = vld1q_f16(ptr0);
+            float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+            float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+            _p0 = vsubq_f32(_p0, _mean0);
+            _p1 = vsubq_f32(_p1, _mean1);
+            _var0 = vmlaq_f32(_var0, _p0, _p0);
+            _var1 = vmlaq_f32(_var1, _p1, _p1);
+            ptr0 += 8;
+        }
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr0));
+            _p = vsubq_f32(_p, _mean0);
+            _var0 = vmlaq_f32(_var0, _p, _p);
+            ptr0 += 4;
+        }
+        for (; i < size; i++)
+        {
+            float v = (float)ptr0[0] - mean;
+            var += v * v;
+            ptr0++;
+        }
+    }
+
+    if (elempack == 8)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        float32x4_t _eps = vdupq_n_f32(eps);
+        _var0 = vdivq_f32(_var0, _elemcount);
+        _var1 = vdivq_f32(_var1, _elemcount);
+        _var0 = vaddq_f32(_var0, _eps);
+        _var1 = vaddq_f32(_var1, _eps);
+        float32x4_t _rsqrt_var0 = vrsqrteq_f32(_var0);
+        float32x4_t _rsqrt_var1 = vrsqrteq_f32(_var1);
+        _rsqrt_var0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var0, _rsqrt_var0), _rsqrt_var0), _rsqrt_var0);
+        _rsqrt_var1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var1, _rsqrt_var1), _rsqrt_var1), _rsqrt_var1);
+        _var0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var0, _rsqrt_var0), _rsqrt_var0), _rsqrt_var0);
+        _var1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var1, _rsqrt_var1), _rsqrt_var1), _rsqrt_var1);
+        _mean0 = vmulq_f32(_mean0, _var0);
+        _mean1 = vmulq_f32(_mean1, _var1);
+        _mean0 = vnegq_f32(_mean0);
+        _mean1 = vnegq_f32(_mean1);
+    }
+    if (elempack == 4)
+    {
+        _var0 = vaddq_f32(_var0, _var1);
+
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        float32x4_t _eps = vdupq_n_f32(eps);
+        _var0 = vdivq_f32(_var0, _elemcount);
+        _var0 = vaddq_f32(_var0, _eps);
+        float32x4_t _rsqrt_var = vrsqrteq_f32(_var0);
+        _rsqrt_var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var0, _rsqrt_var), _rsqrt_var), _rsqrt_var);
+        _var0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var0, _rsqrt_var), _rsqrt_var), _rsqrt_var);
+        _var1 = _var0;
+        _mean0 = vmulq_f32(_mean0, _var0);
+        _mean0 = vnegq_f32(_mean0);
+        _mean1 = _mean0;
+    }
+    if (elempack == 1)
+    {
+        _var0 = vaddq_f32(_var0, _var1);
+        var += vaddvq_f32(_var0);
+
+        var = 1.f / sqrtf(var / elemcount + eps);
+        mean = -mean * var;
+        _var0 = vdupq_n_f32(var);
+        _var1 = _var0;
+        _mean0 = vdupq_n_f32(mean);
+        _mean1 = _mean0;
+    }
+
+    if (gamma_ptr && beta_ptr)
+    {
+        int i = 0;
+        if (elempack == 8)
+        {
+            for (; i + 7 < size; i += 8)
+            {
+                float16x8_t _p = vld1q_f16(ptr);
+                float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+                float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+                float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
+                float32x4_t _beta = vdupq_n_f32(beta_ptr[0]);
+                _p0 = vmlaq_f32(_mean0, _p0, _var0);
+                _p1 = vmlaq_f32(_mean1, _p1, _var1);
+                _p0 = vmlaq_f32(_beta, _p0, _gamma);
+                _p1 = vmlaq_f32(_beta, _p1, _gamma);
+                _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1));
+                vst1q_f16(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 1;
+                beta_ptr += 1;
+            }
+        }
+        if (elempack == 4)
+        {
+            for (; i + 7 < size; i += 8)
+            {
+                float16x8_t _p = vld1q_f16(ptr);
+                float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+                float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+                float32x4_t _gamma0 = vdupq_n_f32(gamma_ptr[0]);
+                float32x4_t _gamma1 = vdupq_n_f32(gamma_ptr[1]);
+                float32x4_t _beta0 = vdupq_n_f32(beta_ptr[0]);
+                float32x4_t _beta1 = vdupq_n_f32(beta_ptr[1]);
+                _p0 = vmlaq_f32(_mean0, _p0, _var0);
+                _p1 = vmlaq_f32(_mean1, _p1, _var1);
+                _p0 = vmlaq_f32(_beta0, _p0, _gamma0);
+                _p1 = vmlaq_f32(_beta1, _p1, _gamma1);
+                _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1));
+                vst1q_f16(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 2;
+                beta_ptr += 2;
+            }
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
+                float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
+                float32x4_t _beta = vdupq_n_f32(beta_ptr[0]);
+                _p = vmlaq_f32(_mean0, _p, _var0);
+                _p = vmlaq_f32(_beta, _p, _gamma);
+                vst1_f16(ptr, vcvt_f16_f32(_p));
+                ptr += 4;
+                gamma_ptr += 1;
+                beta_ptr += 1;
+            }
+        }
+        if (elempack == 1)
+        {
+            for (; i + 7 < size; i += 8)
+            {
+                float16x8_t _p = vld1q_f16(ptr);
+                float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+                float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+                float32x4_t _gamma0 = vld1q_f32(gamma_ptr);
+                float32x4_t _gamma1 = vld1q_f32(gamma_ptr + 4);
+                float32x4_t _beta0 = vld1q_f32(beta_ptr);
+                float32x4_t _beta1 = vld1q_f32(beta_ptr + 4);
+                _p0 = vmlaq_f32(_mean0, _p0, _var0);
+                _p1 = vmlaq_f32(_mean1, _p1, _var1);
+                _p0 = vmlaq_f32(_beta0, _p0, _gamma0);
+                _p1 = vmlaq_f32(_beta1, _p1, _gamma1);
+                _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1));
+                vst1q_f16(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 8;
+                beta_ptr += 8;
+            }
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
+                float32x4_t _gamma = vld1q_f32(gamma_ptr);
+                float32x4_t _beta = vld1q_f32(beta_ptr);
+                _p = vmlaq_f32(_mean0, _p, _var0);
+                _p = vmlaq_f32(_beta, _p, _gamma);
+                vst1_f16(ptr, vcvt_f16_f32(_p));
+                ptr += 4;
+                gamma_ptr += 4;
+                beta_ptr += 4;
+            }
+        }
+        for (; i < size; i++)
+        {
+            float v = (float)ptr[0];
+            ptr[0] = (__fp16)((v * var + mean) * gamma_ptr[0] + beta_ptr[0]);
+            ptr++;
+            gamma_ptr++;
+            beta_ptr++;
+        }
+    }
+    else
+    {
+        int i = 0;
+        for (; i + 7 < size; i += 8)
+        {
+            float16x8_t _p = vld1q_f16(ptr);
+            float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+            float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+            _p0 = vmlaq_f32(_mean0, _p0, _var0);
+            _p1 = vmlaq_f32(_mean1, _p1, _var1);
+            _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1));
+            vst1q_f16(ptr, _p);
+            ptr += 8;
+        }
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
+            _p = vmlaq_f32(_mean0, _p, _var0);
+            vst1_f16(ptr, vcvt_f16_f32(_p));
+            ptr += 4;
+        }
+        for (; i < size; i++)
+        {
+            float v = (float)ptr[0];
+            ptr[0] = (__fp16)(v * var + mean);
+            ptr++;
+        }
+    }
+}
+
+int LayerNorm_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
+{
+    const int dims = bottom_top_blob.dims;
+    const int w = bottom_top_blob.w;
+    const int h = bottom_top_blob.h;
+    const int channels = bottom_top_blob.c;
+    const int elempack = bottom_top_blob.elempack;
+
+    if (dims == 1)
+    {
+        // assert affine_size == w
+
+        __fp16* ptr = bottom_top_blob;
+        layernorm_fp16s(ptr, gamma_data, beta_data, eps, w * elempack, 1);
+    }
+
+    if (dims == 2)
+    {
+        // assert affine_size == w
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = 0; i < h; i++)
+        {
+            __fp16* ptr = bottom_top_blob.row<__fp16>(i);
+            layernorm_fp16s(ptr, gamma_data, beta_data, eps, w, elempack);
+        }
+    }
+
+    if (dims == 3)
+    {
+        if (affine_size == w)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                for (int i = 0; i < h; i++)
+                {
+                    __fp16* ptr = bottom_top_blob.channel(q).row<__fp16>(i);
+                    layernorm_fp16s(ptr, gamma_data, beta_data, eps, w, elempack);
+                }
+            }
+        }
+        else // if (affine_size == w * h)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                __fp16* ptr = bottom_top_blob.channel(q);
+                layernorm_fp16s(ptr, gamma_data, beta_data, eps, w * h, elempack);
+            }
+        }
+    }
+
+    return 0;
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+} // namespace ncnn
diff --git a/src/net.cpp b/src/net.cpp
index 32b5b2abd60..904e14cb2f7 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -707,7 +707,7 @@ int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Optio
     if (elembits == 16)
     {
 #if NCNN_ARM82
-        if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic)
+        if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic && layer->support_fp16_storage)
             dst_elempack = 8;
         else if (elemcount % 4 == 0)
             dst_elempack = 4;
diff --git a/tests/testutil.cpp b/tests/testutil.cpp
index 837043cb754..ffc12bccfa3 100644
--- a/tests/testutil.cpp
+++ b/tests/testutil.cpp
@@ -406,7 +406,7 @@ static int convert_to_optimal_layout(const ncnn::Mat& a, ncnn::Mat& a4, const nc
     if (elembits == 16)
     {
 #if NCNN_ARM82
-        if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic)
+        if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic && op->support_fp16_storage)
             dst_elempack = 8;
         else if (elemcount % 4 == 0)
             dst_elempack = 4;
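Note: for anyone cross-checking these kernels, each layernorm* helper normalizes one group of elemcount values in place using the biased variance, y = (x - mean) / sqrt(var + eps) * gamma + beta; with elempack > 1 several interleaved groups advance in lockstep, which is why gamma/beta step by one scalar per vector in the packed branches. A minimal scalar reference for the elempack == 1 case, mirroring the a = 1/sqrt(var + eps), b = -mean * a folding used by all three dtype paths; layernorm_ref is a hypothetical test helper, not part of this patch:

// Scalar reference for one normalization group (elempack == 1).
#include <math.h>

static void layernorm_ref(float* ptr, const float* gamma, const float* beta,
                          float eps, int elemcount)
{
    float mean = 0.f;
    for (int i = 0; i < elemcount; i++)
        mean += ptr[i];
    mean /= elemcount;

    float var = 0.f; // biased: divided by N, matching the kernels above
    for (int i = 0; i < elemcount; i++)
    {
        float v = ptr[i] - mean;
        var += v * v;
    }
    var /= elemcount;

    const float a = 1.f / sqrtf(var + eps);
    const float b = -mean * a; // fold the mean so the loop is one FMA
    for (int i = 0; i < elemcount; i++)
    {
        float v = ptr[i] * a + b;
        ptr[i] = (gamma && beta) ? v * gamma[i] + beta[i] : v;
    }
}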