Skip to content

Commit

Permalink
working on other architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
brightening-eyes committed Sep 1, 2023
1 parent 3a4c65d commit 455c3e9
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 0 deletions.
35 changes: 35 additions & 0 deletions src/layer/arm/unaryop_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,35 @@ struct unary_op_trunc
#endif // __ARM_NEON
};

// Element-wise Gauss error function, erf(x).
//
// Both entry points evaluate erf with the C library's erff(), so packed and
// unpacked layouts produce matching results for the same input values.
struct unary_op_erf
{
    // Returns erf(x) for a single value.
    float func(const float& x) const
    {
        return erff(x);
    }
#if __ARM_NEON
    // Returns erf of each of the four lanes of x.
    float32x4_t func_pack4(const float32x4_t& x) const
    {
        // Lanes are spilled to memory and evaluated one by one. The earlier
        // polynomial attempt did not compute erf (t was 1/(|x|+p) instead of
        // 1/(1+p|x|), every term used the same power of t, the "1 -" was
        // missing) and relied on a raw vrecpeq_f32 estimate.
        // TODO optimize with a NEON approximation validated against erff()
        float tmp[4];
        vst1q_f32(tmp, x);
        for (int i = 0; i < 4; i++)
        {
            tmp[i] = erff(tmp[i]);
        }
        return vld1q_f32(tmp);
    }
#endif // __ARM_NEON
};

} // namespace UnaryOp_arm_functor

int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
Expand Down Expand Up @@ -550,6 +579,9 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);

return 0;
}

Expand Down Expand Up @@ -686,6 +718,9 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_TRUNC)
return unary_op_inplace_bf16s<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace_bf16s<unary_op_erf>(bottom_top_blob, opt);

return 0;
}
#endif // NCNN_BF16
Expand Down
35 changes: 35 additions & 0 deletions src/layer/loongarch/unaryop_loongarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,38 @@ struct unary_op_trunc
#endif // __loongarch_sx
};

// Element-wise Gauss error function, erf(x).
//
// Both entry points evaluate erf with the C library's erff(), so packed and
// unpacked layouts produce matching results for the same input values.
struct unary_op_erf
{
    // Returns erf(x) for a single value.
    float func(const float& x) const
    {
        return erff(x);
    }
#if __loongarch_sx
    // Returns erf of each of the four lanes of x.
    __m128 func_pack4(const __m128& x) const
    {
        // Lanes are spilled to memory and evaluated one by one. The earlier
        // polynomial attempt did not compute erf: t was never inverted, two
        // terms subtracted where they should have multiplied, and the result
        // was scaled by x rather than by its sign.
        // TODO optimize with an LSX approximation validated against erff()
        float tmp[4];
        __lsx_vst((__m128i)x, tmp, 0);
        for (int i = 0; i < 4; i++)
        {
            tmp[i] = erff(tmp[i]);
        }
        return (__m128)__lsx_vld(tmp, 0);
    }
#endif // __loongarch_sx
};

} // namespace UnaryOp_loongarch_functor

int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
Expand Down Expand Up @@ -482,6 +514,9 @@ int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);

return 0;
}

Expand Down
74 changes: 74 additions & 0 deletions src/layer/x86/unaryop_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,77 @@ struct unary_op_trunc
#endif // __SSE2__
};

// Element-wise Gauss error function, erf(x).
//
// This functor must be named unary_op_erf: it is the type the Operation_ERF
// dispatch instantiates, and naming it unary_op_trunc would redefine the
// truncation functor declared just above it.
//
// Every entry point evaluates erf with the C library's erff(), so all packing
// widths produce matching results for the same input values.
struct unary_op_erf
{
    // Returns erf(x) for a single value.
    float func(const float& x) const
    {
        return erff(x);
    }
#if __SSE2__
    // Returns erf of each of the four lanes of x.
    __m128 func_pack4(const __m128& x) const
    {
        // Lanes are spilled to memory and evaluated one by one. The earlier
        // polynomial attempt used intrinsics that SSE does not provide
        // (_mm_sign_ps, _mm_abs_ps) and did not compute erf (t was
        // 1/(|x|+p) instead of 1/(1+p|x|), every term used the same power of
        // t, the "1 -" was missing).
        // TODO optimize with an SSE approximation validated against erff()
        float tmp[4];
        _mm_storeu_ps(tmp, x);
        for (int i = 0; i < 4; i++)
        {
            tmp[i] = erff(tmp[i]);
        }
        return _mm_loadu_ps(tmp);
    }
#if __AVX__
    // Returns erf of each of the eight lanes of x.
    __m256 func_pack8(const __m256& x) const
    {
        // same per-lane strategy as func_pack4
        float tmp[8];
        _mm256_storeu_ps(tmp, x);
        for (int i = 0; i < 8; i++)
        {
            tmp[i] = erff(tmp[i]);
        }
        return _mm256_loadu_ps(tmp);
    }
#if __AVX512F__
    // Returns erf of each of the sixteen lanes of x.
    __m512 func_pack16(const __m512& x) const
    {
        // same per-lane strategy as func_pack4
        float tmp[16];
        _mm512_storeu_ps(tmp, x);
        for (int i = 0; i < 16; i++)
        {
            tmp[i] = erff(tmp[i]);
        }
        return _mm512_loadu_ps(tmp);
    }
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__
};

} // namespace UnaryOp_x86_functor

int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
Expand Down Expand Up @@ -707,6 +778,9 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);

return 0;
}

Expand Down

0 comments on commit 455c3e9

Please sign in to comment.