diff --git a/src/layer/x86/avx512_mathfun.h b/src/layer/x86/avx512_mathfun.h
index 2892e3d2bf7..0513d5e1be1 100644
--- a/src/layer/x86/avx512_mathfun.h
+++ b/src/layer/x86/avx512_mathfun.h
@@ -182,6 +182,54 @@ static NCNN_FORCEINLINE __m512 exp512_ps(__m512 x)
     return y;
 }
 
+_PS512_CONST(tanh_hi, 9.0f);
+_PS512_CONST(tanh_lo, -9.0f);
+
+_PS512_CONST(cephes_tanh_p0, -2.76076847742355E-16f);
+_PS512_CONST(cephes_tanh_p1, 2.00018790482477E-13f);
+_PS512_CONST(cephes_tanh_p2, -8.60467152213735E-11f);
+_PS512_CONST(cephes_tanh_p3, 5.12229709037114E-08f);
+_PS512_CONST(cephes_tanh_p4, 1.48572235717979E-05f);
+_PS512_CONST(cephes_tanh_p5, 6.37261928875436E-04f);
+_PS512_CONST(cephes_tanh_p6, 4.89352455891786E-03f);
+
+_PS512_CONST(cephes_tanh_p7, 1.19825839466702e-06f);
+_PS512_CONST(cephes_tanh_p8, 1.18534705686654e-04f);
+_PS512_CONST(cephes_tanh_p9, 2.26843463243900e-03f);
+
+// Approximates tanh(x) for each lane as a rational function p(x) / q(x).
+// The argument is clamped to [tanh_lo, tanh_hi] = [-9, 9] first.
+// p(x) = x * P(x^2) uses cephes_tanh_p0..p6 (highest power first).
+// q(x) = Q(x^2) uses cephes_tanh_p7..p9 and then reuses cephes_tanh_p6
+// as its constant term.
+// Declared NCNN_FORCEINLINE for consistency with exp512_ps.
+static NCNN_FORCEINLINE __m512 tanh512_ps(const __m512 x)
+{
+    // clamp the argument before evaluating the two polynomials
+    __m512 value = _mm512_max_ps(*(__m512*)_ps512_tanh_lo, x);
+    value = _mm512_min_ps(*(__m512*)_ps512_tanh_hi, value);
+
+    __m512 value_squared = _mm512_mul_ps(value, value);
+
+    // numerator: Horner evaluation in x^2, followed by one multiply by x
+    __m512 p;
+    p = _mm512_fmadd_ps(value_squared, *(__m512*)_ps512_cephes_tanh_p0, *(__m512*)_ps512_cephes_tanh_p1);
+    p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p2);
+    p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p3);
+    p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p4);
+    p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p5);
+    p = _mm512_fmadd_ps(p, value_squared, *(__m512*)_ps512_cephes_tanh_p6);
+    p = _mm512_mul_ps(p, value);
+
+    // denominator: Horner evaluation in x^2
+    __m512 q;
+    q = _mm512_fmadd_ps(value_squared, *(__m512*)_ps512_cephes_tanh_p7, *(__m512*)_ps512_cephes_tanh_p8);
+    q = _mm512_fmadd_ps(q, value_squared, *(__m512*)_ps512_cephes_tanh_p9);
+    q = _mm512_fmadd_ps(q, value_squared, *(__m512*)_ps512_cephes_tanh_p6);
+
+    return _mm512_div_ps(p, q);
+}
+
_PS512_CONST(minus_cephes_DP1, -0.78515625f);
 _PS512_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f);
 _PS512_CONST(minus_cephes_DP3, -3.77489497744594108e-8f);
diff --git a/src/layer/x86/avx_mathfun.h b/src/layer/x86/avx_mathfun.h
index db28691344a..645c399e4eb 100644
--- a/src/layer/x86/avx_mathfun.h
+++ b/src/layer/x86/avx_mathfun.h
@@ -295,6 +295,54 @@ static NCNN_FORCEINLINE __m256 exp256_ps(__m256 x)
     return y;
 }
 
+_PS256_CONST(tanh_hi, 9.0f);
+_PS256_CONST(tanh_lo, -9.0f);
+
+_PS256_CONST(cephes_tanh_p0, -2.76076847742355E-16f);
+_PS256_CONST(cephes_tanh_p1, 2.00018790482477E-13f);
+_PS256_CONST(cephes_tanh_p2, -8.60467152213735E-11f);
+_PS256_CONST(cephes_tanh_p3, 5.12229709037114E-08f);
+_PS256_CONST(cephes_tanh_p4, 1.48572235717979E-05f);
+_PS256_CONST(cephes_tanh_p5, 6.37261928875436E-04f);
+_PS256_CONST(cephes_tanh_p6, 4.89352455891786E-03f);
+
+_PS256_CONST(cephes_tanh_p7, 1.19825839466702e-06f);
+_PS256_CONST(cephes_tanh_p8, 1.18534705686654e-04f);
+_PS256_CONST(cephes_tanh_p9, 2.26843463243900e-03f);
+
+// Approximates tanh(x) for each lane as a rational function p(x) / q(x).
+// The argument is clamped to [tanh_lo, tanh_hi] = [-9, 9] first.
+// p(x) = x * P(x^2) uses cephes_tanh_p0..p6 (highest power first).
+// q(x) = Q(x^2) uses cephes_tanh_p7..p9 and then reuses cephes_tanh_p6
+// as its constant term.
+// Declared NCNN_FORCEINLINE for consistency with exp256_ps.
+static NCNN_FORCEINLINE __m256 tanh256_ps(const __m256 x)
+{
+    // clamp the argument before evaluating the two polynomials
+    __m256 value = _mm256_max_ps(*(__m256*)_ps256_tanh_lo, x);
+    value = _mm256_min_ps(*(__m256*)_ps256_tanh_hi, value);
+
+    __m256 value_squared = _mm256_mul_ps(value, value);
+
+    // numerator: Horner evaluation in x^2, followed by one multiply by x
+    __m256 p;
+    p = _mm256_comp_fmadd_ps(value_squared, *(__m256*)_ps256_cephes_tanh_p0, *(__m256*)_ps256_cephes_tanh_p1);
+    p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p2);
+    p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p3);
+    p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p4);
+    p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p5);
+    p = _mm256_comp_fmadd_ps(p, value_squared, *(__m256*)_ps256_cephes_tanh_p6);
+    p = _mm256_mul_ps(p, value);
+
+    // denominator: Horner evaluation in x^2
+    __m256 q;
+    q = _mm256_comp_fmadd_ps(value_squared, *(__m256*)_ps256_cephes_tanh_p7, *(__m256*)_ps256_cephes_tanh_p8);
+    q = _mm256_comp_fmadd_ps(q, value_squared, *(__m256*)_ps256_cephes_tanh_p9);
+    q = _mm256_comp_fmadd_ps(q, value_squared, *(__m256*)_ps256_cephes_tanh_p6);
+
+    return _mm256_div_ps(p, q);
+}
+
 _PS256_CONST(minus_cephes_DP1, -0.78515625f);
 _PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f);
 _PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8f);
diff --git a/src/layer/x86/gelu_x86.cpp b/src/layer/x86/gelu_x86.cpp
new file mode 100644
index 00000000000..352d330b877
--- /dev/null
+++ b/src/layer/x86/gelu_x86.cpp
@@ -0,0 +1,154 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "gelu_x86.h"
+
+#if __SSE2__
+#include <emmintrin.h>
+#include "sse_mathfun.h"
+#if __AVX__
+#include <immintrin.h>
+#include "avx_mathfun.h"
+#if __AVX512F__
+#include "avx512_mathfun.h"
+#endif // __AVX512F__
+#endif // __AVX__
+#endif // __SSE2__
+
+namespace ncnn {
+
+GELU_x86::GELU_x86()
+{
+#if __SSE2__
+    support_packing = true;
+#endif // __SSE2__
+}
+
+int GELU_x86::create_pipeline(const Option& /*opt*/)
+{
+    if (!fast_gelu)
+    {
+        support_packing = false; // non-fast path is delegated to GELU::forward_inplace
+    }
+    return 0;
+}
+
+// In-place fast GELU: y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3))), vectorised
+// 16/8/4 floats at a time with a scalar tail; without fast_gelu it defers to GELU.
+int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
+{
+    if (!fast_gelu)
+    {
+        return GELU::forward_inplace(bottom_top_blob, opt);
+    }
+
+    const int channels = bottom_top_blob.c;
+    // floats per channel, depth included (w * h * d * elempack)
+    const int size = bottom_top_blob.w * bottom_top_blob.h * bottom_top_blob.d * bottom_top_blob.elempack;
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int q = 0; q < channels; q++)
+    {
+        float* ptr = bottom_top_blob.channel(q);
+
+        int i = 0;
+
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+        __m512 _half512 = _mm512_set1_ps(0.5f);
+        __m512 _one512 = _mm512_set1_ps(1.f);
+        __m512 _fast1c512 = _mm512_set1_ps(0.79788452f);
+        __m512 _fast2c512 = _mm512_set1_ps(0.044715f);
+        for (; i + 15 < size; i += 16)
+        {
+            __m512 _pLoad = _mm512_loadu_ps(ptr);
+
+            __m512 _cube = _mm512_mul_ps(_pLoad, _pLoad);
+            _cube = _mm512_mul_ps(_pLoad, _cube);
+
+            __m512 _blob = _mm512_mul_ps(_fast2c512, _cube);
+            _blob = _mm512_add_ps(_pLoad, _blob);
+            _blob = _mm512_mul_ps(_fast1c512, _blob);
+            _blob = tanh512_ps(_blob);
+            _blob = _mm512_add_ps(_one512, _blob);
+
+            _blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad));
+
+            _mm512_storeu_ps(ptr, _blob);
+
+            ptr += 16;
+        }
+#endif // __AVX512F__
+        __m256 _half256 = _mm256_set1_ps(0.5f);
+        __m256 _one256 = _mm256_set1_ps(1.f);
+        __m256 _fast1c256 = _mm256_set1_ps(0.79788452f);
+        __m256 _fast2c256 = _mm256_set1_ps(0.044715f);
+        for (; i + 7 < size; i += 8)
+        {
+            __m256 _pLoad = _mm256_loadu_ps(ptr);
+
+            __m256 _cube = _mm256_mul_ps(_pLoad, _pLoad);
+            _cube = _mm256_mul_ps(_pLoad, _cube);
+
+            __m256 _blob = _mm256_mul_ps(_fast2c256, _cube);
+            _blob = _mm256_add_ps(_pLoad, _blob);
+            _blob = _mm256_mul_ps(_fast1c256, _blob);
+            _blob = tanh256_ps(_blob);
+            _blob = _mm256_add_ps(_one256, _blob);
+
+            _blob = _mm256_mul_ps(_half256, _mm256_mul_ps(_blob, _pLoad));
+
+            _mm256_storeu_ps(ptr, _blob);
+
+            ptr += 8;
+        }
+#endif // __AVX__
+        __m128 _half128 = _mm_set1_ps(0.5f);
+        __m128 _one128 = _mm_set1_ps(1.f);
+        __m128 _fast1c128 = _mm_set1_ps(0.79788452f);
+        __m128 _fast2c128 = _mm_set1_ps(0.044715f);
+        for (; i + 3 < size; i += 4)
+        {
+            __m128 _pLoad = _mm_loadu_ps(ptr);
+
+            __m128 _cube = _mm_mul_ps(_pLoad, _pLoad);
+            _cube = _mm_mul_ps(_pLoad, _cube);
+
+            __m128 _blob = _mm_mul_ps(_fast2c128, _cube);
+            _blob = _mm_add_ps(_pLoad, _blob);
+            _blob = _mm_mul_ps(_fast1c128, _blob);
+            _blob = tanh_ps(_blob);
+            _blob = _mm_add_ps(_one128, _blob);
+
+            _blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad));
+
+            _mm_storeu_ps(ptr, _blob);
+
+            ptr += 4;
+        }
+#endif // __SSE2__
+        for (; i < size; i++)
+        {
+            // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
+            *ptr = 0.5f * *ptr * (1.0f + tanhf(0.79788452f * (*ptr + 0.044715f * *ptr * *ptr * *ptr)));
+
+            ptr++;
+        }
+    }
+
+    return 0;
+}
+
+} // namespace ncnn
diff --git a/src/layer/x86/gelu_x86.h b/src/layer/x86/gelu_x86.h
new file mode 100644
index 00000000000..75d821bfd45
--- /dev/null
+++ b/src/layer/x86/gelu_x86.h
@@ -0,0 +1,33 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License.
You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#ifndef LAYER_GELU_X86_H
+#define LAYER_GELU_X86_H
+
+#include "gelu.h"
+
+namespace ncnn {
+
+class GELU_x86 : virtual public GELU
+{
+public:
+    GELU_x86();
+
+    virtual int create_pipeline(const Option& opt);
+    virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_GELU_X86_H
diff --git a/src/layer/x86/sse_mathfun.h b/src/layer/x86/sse_mathfun.h
index 764e33e7976..75527d507b0 100644
--- a/src/layer/x86/sse_mathfun.h
+++ b/src/layer/x86/sse_mathfun.h
@@ -286,6 +286,53 @@ static NCNN_FORCEINLINE v4sf exp_ps(v4sf x)
     return y;
 }
 
+_PS_CONST(tanh_hi, 9.0f);
+_PS_CONST(tanh_lo, -9.0f);
+
+_PS_CONST(cephes_tanh_p0, -2.76076847742355E-16f);
+_PS_CONST(cephes_tanh_p1, 2.00018790482477E-13f);
+_PS_CONST(cephes_tanh_p2, -8.60467152213735E-11f);
+_PS_CONST(cephes_tanh_p3, 5.12229709037114E-08f);
+_PS_CONST(cephes_tanh_p4, 1.48572235717979E-05f);
+_PS_CONST(cephes_tanh_p5, 6.37261928875436E-04f);
+_PS_CONST(cephes_tanh_p6, 4.89352455891786E-03f);
+_PS_CONST(cephes_tanh_p7, 1.19825839466702e-06f);
+_PS_CONST(cephes_tanh_p8, 1.18534705686654e-04f);
+_PS_CONST(cephes_tanh_p9, 2.26843463243900e-03f);
+
+// Approximates tanh(x) for each lane as a rational function p(x) / q(x).
+// The argument is clamped to [tanh_lo, tanh_hi] = [-9, 9] first.
+// p(x) = x * P(x^2) uses cephes_tanh_p0..p6 (highest power first).
+// q(x) = Q(x^2) uses cephes_tanh_p7..p9 and then reuses cephes_tanh_p6
+// as its constant term.
+// Declared NCNN_FORCEINLINE for consistency with exp_ps.
+static NCNN_FORCEINLINE v4sf tanh_ps(const v4sf x)
+{
+    // clamp the argument before evaluating the two polynomials
+    v4sf value = _mm_max_ps(*(v4sf*)_ps_tanh_lo, x);
+    value = _mm_min_ps(*(v4sf*)_ps_tanh_hi, value);
+
+    v4sf value_squared = _mm_mul_ps(value, value);
+
+    // numerator: Horner evaluation in x^2, followed by one multiply by x
+    v4sf p;
+    p = _mm_comp_fmadd_ps(value_squared, *(v4sf*)_ps_cephes_tanh_p0, *(v4sf*)_ps_cephes_tanh_p1);
+    p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p2);
+    p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p3);
+    p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p4);
+    p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p5);
+    p = _mm_comp_fmadd_ps(p, value_squared, *(v4sf*)_ps_cephes_tanh_p6);
+    p = _mm_mul_ps(p, value);
+
+    // denominator: Horner evaluation in x^2
+    v4sf q;
+    q = _mm_comp_fmadd_ps(value_squared, *(v4sf*)_ps_cephes_tanh_p7, *(v4sf*)_ps_cephes_tanh_p8);
+    q = _mm_comp_fmadd_ps(q, value_squared, *(v4sf*)_ps_cephes_tanh_p9);
+    q = _mm_comp_fmadd_ps(q, value_squared, *(v4sf*)_ps_cephes_tanh_p6);
+
+    return _mm_div_ps(p, q);
+}
+
 _PS_CONST(minus_cephes_DP1, -0.78515625f);
 _PS_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f);
 _PS_CONST(minus_cephes_DP3, -3.77489497744594108e-8f);
diff --git a/tests/test_gelu.cpp b/tests/test_gelu.cpp
index 974079edea8..f4ac70cf8e2 100644
--- a/tests/test_gelu.cpp
+++ b/tests/test_gelu.cpp
@@ -34,6 +34,8 @@ static int test_gelu(const ncnn::Mat& a, bool fast_gelu)
 static int test_gelu_0()
 {
     return 0
+           || test_gelu(RandomMat(9, 7, 32), false)
+           || test_gelu(RandomMat(9, 7, 32), true)
            || test_gelu(RandomMat(5, 7, 24), false)
            || test_gelu(RandomMat(5, 7, 24), true)
            || test_gelu(RandomMat(7, 9, 12), false)
@@ -45,6 +47,8 @@ static int test_gelu_0()
 static int test_gelu_1()
 {
     return 0
+           || test_gelu(RandomMat(13, 32), false)
+           || test_gelu(RandomMat(13, 32), true)
            || test_gelu(RandomMat(15, 24), false)
            || test_gelu(RandomMat(15, 24), true)
            || test_gelu(RandomMat(17, 12), false)
@@ -61,7 +65,9 @@ static int test_gelu_2()
            || test_gelu(RandomMat(124), false)
            || test_gelu(RandomMat(124), true)
            || test_gelu(RandomMat(127), false)
-           || test_gelu(RandomMat(127), true);
+           || test_gelu(RandomMat(127), true)
+           || test_gelu(RandomMat(120), false)
+           || test_gelu(RandomMat(120), true);
 }
 
 int main()