Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] added erf support #4991

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/layer/arm/unaryop_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,23 @@ struct unary_op_trunc
#endif // __ARM_NEON
};

struct unary_op_erf
{
float func(const float& x) const
{
return (float)erf(x);
}
#if __ARM_NEON
float32x4_t func_pack4(const float32x4_t& x) const
{
float32x4_t norm_x_vec = x / vsqrtq_f32(vdupq_n_f32(2.0f));
float32x4_t erf_approx = vmovq_n_f32(1.0f);
float32x4_t tanh_x = tanh_ps(vmulq_f32(pi, norm_x_vec));
return vsubq_f32(erf_approx, vmulq_f32(0.5f, tanh_x));
}
#endif // __ARM_NEON
};

} // namespace UnaryOp_arm_functor

int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
Expand Down Expand Up @@ -550,6 +567,9 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);

return 0;
}

Expand Down Expand Up @@ -686,6 +706,9 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_TRUNC)
return unary_op_inplace_bf16s<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace_bf16s<unary_op_erf>(bottom_top_blob, opt);

return 0;
}
#endif // NCNN_BF16
Expand Down
11 changes: 11 additions & 0 deletions src/layer/unaryop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ struct unary_op_trunc
}
};

struct unary_op_erf
{
float operator()(const float& x) const
{
return (float)erf(x);
}
};

int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
if (op_type == Operation_ABS)
Expand Down Expand Up @@ -280,6 +288,9 @@ int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);

return 0;
}

Expand Down
3 changes: 2 additions & 1 deletion src/layer/unaryop.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class UnaryOp : public Layer
Operation_TANH = 16,
Operation_LOG10 = 17,
Operation_ROUND = 18,
Operation_TRUNC = 19
Operation_TRUNC = 19,
Operation_ERF = 20
};

public:
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/unaryop.comp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void main()
if (op_type == 17) res = log(v) * afp(0.434294481903);
if (op_type == 18) res = round(v);
if (op_type == 19) res = trunc(v);
if (op_type == 20) res = erf(v);

#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/unaryop_pack4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void main()
if (op_type == 17) res = log(v) * afp(0.434294481903);
if (op_type == 18) res = round(v);
if (op_type == 19) res = trunc(v);
if (op_type == 20) res = erf(v);

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
5 changes: 5 additions & 0 deletions src/layer/vulkan/shader/unaryop_pack8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ void main()
res[0] = trunc(v[0]);
res[1] = trunc(v[1]);
}
if (op_type == 20)
{
res[0] = erf(v[0]);
res[1] = erf(v[1]);
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
30 changes: 30 additions & 0 deletions src/layer/x86/unaryop_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <math.h>

#if __SSE2__
#include <xmmintrin.h>
#include <emmintrin.h>
#include "sse_mathfun.h"
#if __SSE4_1__
Expand Down Expand Up @@ -642,6 +643,32 @@ struct unary_op_trunc
#endif // __SSE2__
};

struct unary_op_erf
{
float func(const float& x) const
{
return (float)erf(x);
}
#if __SSE2__
__m128 func_pack4(const __m128& x) const
{
return _mm_erf_ps(x);
}
#if __AVX__
__m256 func_pack8(const __m256& x) const
{
return _mm256_erf_ps(x);
}
#if __AVX512F__
__m512 func_pack16(const __m512& x) const
{
return _mm512_erf_ps(x);
}
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__
};

} // namespace UnaryOp_x86_functor

int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
Expand Down Expand Up @@ -707,6 +734,9 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);

return 0;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_unaryop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "layer/unaryop.h"
#include "testutil.h"

#define OP_TYPE_MAX 20
#define OP_TYPE_MAX 21

static int op_type = 0;

Expand Down
9 changes: 9 additions & 0 deletions tools/onnx/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3714,6 +3714,10 @@ int main(int argc, char** argv)
{
fprintf(pp, "%-16s", "EmbedLayerNormalization");
}
else if (op == "Erf")
{
fprintf(pp, "%-16s", "UnaryOp");
}
else if (op == "Exp")
{
fprintf(pp, "%-16s", "UnaryOp");
Expand Down Expand Up @@ -4510,6 +4514,11 @@ int main(int argc, char** argv)

fwrite_tensor_proto_data(B, bp);
}
else if (op == "Erf")
{
int op_type = 20;
fprintf(pp, " 0=%d", op_type);
}
else if (op == "Exp")
{
int op_type = 7;
Expand Down