From 455c3e9ba04f323a420c728d7b88f44e74fe8738 Mon Sep 17 00:00:00 2001
From: brightening-eyes
Date: Sat, 2 Sep 2023 01:19:07 +0330
Subject: [PATCH] working on other architectures

---
 src/layer/arm/unaryop_arm.cpp             | 49 ++++++++++++++
 src/layer/loongarch/unaryop_loongarch.cpp | 39 +++++++++++
 src/layer/x86/unaryop_x86.cpp             | 88 +++++++++++++++++++++++++
 3 files changed, 176 insertions(+)

diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp
index 5a054cc7c4d6..3b3714fe17bb 100644
--- a/src/layer/arm/unaryop_arm.cpp
+++ b/src/layer/arm/unaryop_arm.cpp
@@ -472,6 +472,49 @@ struct unary_op_trunc
 #endif // __ARM_NEON
 };
 
+struct unary_op_erf
+{
+    float func(const float& x) const
+    {
+        return (float)erf(x);
+    }
+#if __ARM_NEON
+    float32x4_t func_pack4(const float32x4_t& x) const
+    {
+        // Abramowitz-Stegun 7.1.26 polynomial approximation of erf()
+        float32x4_t a1 = vdupq_n_f32(0.254829592f);
+        float32x4_t a2 = vdupq_n_f32(-0.284496736f);
+        float32x4_t a3 = vdupq_n_f32(1.421413741f);
+        float32x4_t a4 = vdupq_n_f32(-1.453152027f);
+        float32x4_t a5 = vdupq_n_f32(1.061405429f);
+        float32x4_t p = vdupq_n_f32(0.3275911f);
+        float32x4_t one = vdupq_n_f32(1.f);
+        // keep the sign bit, work on |x|
+        uint32x4_t sign = vandq_u32(vreinterpretq_u32_f32(x), vdupq_n_u32(0x80000000));
+        float32x4_t x_abs = vabsq_f32(x);
+        // t = 1 / (1 + p * |x|)
+        float32x4_t d = vmlaq_f32(one, p, x_abs);
+#if __aarch64__
+        float32x4_t t = vdivq_f32(one, d);
+#else
+        float32x4_t t = vrecpeq_f32(d);
+        t = vmulq_f32(vrecpsq_f32(d, t), t);
+        t = vmulq_f32(vrecpsq_f32(d, t), t);
+#endif
+        // y = 1 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) * exp(-x * x)
+        float32x4_t y = vmlaq_f32(a4, a5, t);
+        y = vmlaq_f32(a3, y, t);
+        y = vmlaq_f32(a2, y, t);
+        y = vmlaq_f32(a1, y, t);
+        y = vmulq_f32(y, t);
+        y = vmulq_f32(y, exp_ps(vnegq_f32(vmulq_f32(x_abs, x_abs))));
+        y = vsubq_f32(one, y);
+        // restore the sign
+        return vreinterpretq_f32_u32(vorrq_u32(sign, vreinterpretq_u32_f32(y)));
+    }
+#endif // __ARM_NEON
+};
+
 } // namespace UnaryOp_arm_functor
 
 int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
@@ -550,6 +593,9 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
     if (op_type == Operation_TRUNC)
         return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);
 
+    if (op_type == Operation_ERF)
+        return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);
+
     return 0;
 }
 
@@ -686,6 +732,9 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt)
     if (op_type == Operation_TRUNC)
         return unary_op_inplace_bf16s<unary_op_trunc>(bottom_top_blob, opt);
 
+    if (op_type == Operation_ERF)
+        return unary_op_inplace_bf16s<unary_op_erf>(bottom_top_blob, opt);
+
     return 0;
 }
 #endif // NCNN_BF16
diff --git a/src/layer/loongarch/unaryop_loongarch.cpp b/src/layer/loongarch/unaryop_loongarch.cpp
index 4d4818cb5af0..d49fb2e435cf 100644
--- a/src/layer/loongarch/unaryop_loongarch.cpp
+++ b/src/layer/loongarch/unaryop_loongarch.cpp
@@ -416,6 +416,42 @@ struct unary_op_trunc
 #endif // __loongarch_sx
 };
 
+struct unary_op_erf
+{
+    float func(const float& x) const
+    {
+        return (float)erf(x);
+    }
+#if __loongarch_sx
+    __m128 func_pack4(const __m128& x) const
+    {
+        // Abramowitz-Stegun 7.1.26 polynomial approximation of erf()
+        __m128 a1 = (__m128)__lsx_vreplfr2vr_s(0.254829592f);
+        __m128 a2 = (__m128)__lsx_vreplfr2vr_s(-0.284496736f);
+        __m128 a3 = (__m128)__lsx_vreplfr2vr_s(1.421413741f);
+        __m128 a4 = (__m128)__lsx_vreplfr2vr_s(-1.453152027f);
+        __m128 a5 = (__m128)__lsx_vreplfr2vr_s(1.061405429f);
+        __m128 p = (__m128)__lsx_vreplfr2vr_s(0.3275911f);
+        __m128 one = (__m128)__lsx_vreplfr2vr_s(1.f);
+        // keep the sign bit, work on |x|
+        __m128 x_abs = (__m128)__lsx_vbitclri_w((__m128i)x, 31);
+        __m128i sign = __lsx_vxor_v((__m128i)x, (__m128i)x_abs);
+        // t = 1 / (1 + p * |x|)
+        __m128 t = __lsx_vfdiv_s(one, __lsx_vfadd_s(one, __lsx_vfmul_s(p, x_abs)));
+        // y = 1 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) * exp(-x * x)
+        __m128 y = __lsx_vfadd_s(__lsx_vfmul_s(a5, t), a4);
+        y = __lsx_vfadd_s(__lsx_vfmul_s(y, t), a3);
+        y = __lsx_vfadd_s(__lsx_vfmul_s(y, t), a2);
+        y = __lsx_vfadd_s(__lsx_vfmul_s(y, t), a1);
+        y = __lsx_vfmul_s(y, t);
+        y = __lsx_vfmul_s(y, exp_ps((__m128)__lsx_vbitrevi_w((__m128i)__lsx_vfmul_s(x_abs, x_abs), 31)));
+        y = __lsx_vfsub_s(one, y);
+        // restore the sign
+        return (__m128)__lsx_vor_v(sign, (__m128i)y);
+    }
+#endif // __loongarch_sx
+};
+
 } // namespace UnaryOp_loongarch_functor
 
 int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
@@ -482,6 +518,9 @@ int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt)
     if (op_type == Operation_TRUNC)
         return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);
 
+    if (op_type == Operation_ERF)
+        return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);
+
     return 0;
 }
 
diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp
index 8629ab2093b4..89a1d1a58029 100644
--- a/src/layer/x86/unaryop_x86.cpp
+++ b/src/layer/x86/unaryop_x86.cpp
@@ -642,6 +642,91 @@ struct unary_op_trunc
 #endif // __SSE2__
 };
 
+struct unary_op_erf
+{
+    float func(const float& x) const
+    {
+        return (float)erf(x);
+    }
+#if __SSE2__
+    __m128 func_pack4(const __m128& x) const
+    {
+        // Abramowitz-Stegun 7.1.26 polynomial approximation of erf()
+        __m128 a1 = _mm_set1_ps(0.254829592f);
+        __m128 a2 = _mm_set1_ps(-0.284496736f);
+        __m128 a3 = _mm_set1_ps(1.421413741f);
+        __m128 a4 = _mm_set1_ps(-1.453152027f);
+        __m128 a5 = _mm_set1_ps(1.061405429f);
+        __m128 p = _mm_set1_ps(0.3275911f);
+        __m128 one = _mm_set1_ps(1.f);
+        __m128 signmask = _mm_set1_ps(-0.f);
+        // keep the sign bit, work on |x|
+        __m128 s = _mm_and_ps(x, signmask);
+        __m128 x_abs = _mm_andnot_ps(signmask, x);
+        // t = 1 / (1 + p * |x|)
+        __m128 t = _mm_div_ps(one, _mm_add_ps(one, _mm_mul_ps(p, x_abs)));
+        // y = 1 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) * exp(-x * x)
+        __m128 y = _mm_add_ps(_mm_mul_ps(a5, t), a4);
+        y = _mm_add_ps(_mm_mul_ps(y, t), a3);
+        y = _mm_add_ps(_mm_mul_ps(y, t), a2);
+        y = _mm_add_ps(_mm_mul_ps(y, t), a1);
+        y = _mm_mul_ps(y, t);
+        y = _mm_mul_ps(y, exp_ps(_mm_xor_ps(signmask, _mm_mul_ps(x_abs, x_abs))));
+        y = _mm_sub_ps(one, y);
+        // restore the sign
+        return _mm_or_ps(s, y);
+    }
+#if __AVX__
+    __m256 func_pack8(const __m256& x) const
+    {
+        __m256 a1 = _mm256_set1_ps(0.254829592f);
+        __m256 a2 = _mm256_set1_ps(-0.284496736f);
+        __m256 a3 = _mm256_set1_ps(1.421413741f);
+        __m256 a4 = _mm256_set1_ps(-1.453152027f);
+        __m256 a5 = _mm256_set1_ps(1.061405429f);
+        __m256 p = _mm256_set1_ps(0.3275911f);
+        __m256 one = _mm256_set1_ps(1.f);
+        __m256 signmask = _mm256_set1_ps(-0.f);
+        __m256 s = _mm256_and_ps(x, signmask);
+        __m256 x_abs = _mm256_andnot_ps(signmask, x);
+        __m256 t = _mm256_div_ps(one, _mm256_add_ps(one, _mm256_mul_ps(p, x_abs)));
+        __m256 y = _mm256_add_ps(_mm256_mul_ps(a5, t), a4);
+        y = _mm256_add_ps(_mm256_mul_ps(y, t), a3);
+        y = _mm256_add_ps(_mm256_mul_ps(y, t), a2);
+        y = _mm256_add_ps(_mm256_mul_ps(y, t), a1);
+        y = _mm256_mul_ps(y, t);
+        y = _mm256_mul_ps(y, exp256_ps(_mm256_xor_ps(signmask, _mm256_mul_ps(x_abs, x_abs))));
+        y = _mm256_sub_ps(one, y);
+        return _mm256_or_ps(s, y);
+    }
+#if __AVX512F__
+    __m512 func_pack16(const __m512& x) const
+    {
+        __m512 a1 = _mm512_set1_ps(0.254829592f);
+        __m512 a2 = _mm512_set1_ps(-0.284496736f);
+        __m512 a3 = _mm512_set1_ps(1.421413741f);
+        __m512 a4 = _mm512_set1_ps(-1.453152027f);
+        __m512 a5 = _mm512_set1_ps(1.061405429f);
+        __m512 p = _mm512_set1_ps(0.3275911f);
+        __m512 one = _mm512_set1_ps(1.f);
+        __m512i signmask = _mm512_castps_si512(_mm512_set1_ps(-0.f));
+        __m512i s = _mm512_and_epi32(_mm512_castps_si512(x), signmask);
+        __m512 x_abs = _mm512_abs_ps(x);
+        __m512 t = _mm512_div_ps(one, _mm512_add_ps(one, _mm512_mul_ps(p, x_abs)));
+        __m512 y = _mm512_add_ps(_mm512_mul_ps(a5, t), a4);
+        y = _mm512_add_ps(_mm512_mul_ps(y, t), a3);
+        y = _mm512_add_ps(_mm512_mul_ps(y, t), a2);
+        y = _mm512_add_ps(_mm512_mul_ps(y, t), a1);
+        y = _mm512_mul_ps(y, t);
+        y = _mm512_mul_ps(y, exp512_ps(_mm512_sub_ps(_mm512_setzero_ps(), _mm512_mul_ps(x_abs, x_abs))));
+        y = _mm512_sub_ps(one, y);
+        return _mm512_castsi512_ps(_mm512_or_epi32(s, _mm512_castps_si512(y)));
+    }
+#endif // __AVX512F__
+#endif // __AVX__
+#endif // __SSE2__
+};
+
 } // namespace UnaryOp_x86_functor
 
 int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
@@ -707,6 +792,9 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
     if (op_type == Operation_TRUNC)
         return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);
 
+    if (op_type == Operation_ERF)
+        return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);
+
     return 0;
 }
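
For reference, the pack4/pack8/pack16 paths above all evaluate the same Abramowitz-Stegun 7.1.26 approximation: t = 1 / (1 + p * |x|), erf(x) ~= sign(x) * (1 - (((((a5*t + a4)*t + a3)*t + a2)*t + a1)*t) * exp(-x*x)), with a maximum absolute error of roughly 1.5e-7. Below is a minimal scalar sketch of that formula, useful for checking the vectorized results against libm; the helper name erf_approx_scalar and the comparison loop are illustrative only and not part of the patch or of ncnn.

// scalar reference for the Abramowitz-Stegun 7.1.26 approximation used by the SIMD code above
// erf_approx_scalar is a hypothetical helper for illustration, not part of ncnn
#include <math.h>
#include <stdio.h>

static float erf_approx_scalar(float x)
{
    const float a1 = 0.254829592f;
    const float a2 = -0.284496736f;
    const float a3 = 1.421413741f;
    const float a4 = -1.453152027f;
    const float a5 = 1.061405429f;
    const float p = 0.3275911f;

    float sign = x < 0.f ? -1.f : 1.f;
    float xa = fabsf(x);

    // t = 1 / (1 + p * |x|)
    float t = 1.f / (1.f + p * xa);

    // y = 1 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) * exp(-x * x)
    float poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
    float y = 1.f - poly * expf(-xa * xa);

    return sign * y;
}

int main()
{
    // compare against libm erff(); the two should agree to about 7 decimal places
    for (float x = -3.f; x <= 3.f; x += 0.5f)
        printf("x=% .2f  approx=% .7f  erff=% .7f\n", x, erf_approx_scalar(x), erff(x));
    return 0;
}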