diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp index 244d80038219b3..297164fe1d84e1 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp @@ -224,6 +224,137 @@ void jit_equal_emitter::register_table_entries() { push_arg_entry_of("one", 0x3f800000, true); } +/// EXPONENT /// +// Node-based constructor: forwards to jit_emitter and then calls prepare_table(). +// NOTE(review): the execution precision is taken from get_arithmetic_binary_exec_precision(node) although get_inputs_count() returns 1; confirm that helper is intended for single-input nodes. +jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + prepare_table(); +} + +// Precision-based constructor: forwards exec_prc to jit_emitter and then calls prepare_table(). +jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + prepare_table(); +} + +// The emitter reads a single input vector (in_vec_idxs[0]). +size_t jit_exp_emitter::get_inputs_count() const { return 1; } + +// emit_isa reads aux_vec_idxs[0] through aux_vec_idxs[3]. +size_t jit_exp_emitter::get_aux_vecs_count() const { return 4; } + +// NOTE(review): presumably the general-purpose register is needed by table_val2(); confirm against jit_emitter. +size_t jit_exp_emitter::get_aux_gprs_count() const { return 1; } + +// Emits the kernel for the ASIMD host ISA; any other ISA throws. +void jit_exp_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { + if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { + emit_isa(in_vec_idxs, out_vec_idxs); + } else { + OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel"); + } +} + +// Emits f32 exp(x) from in_vec_idxs[0] into out_vec_idxs[0]: x is clamped to the exp_ln_flt_min_f / exp_ln_flt_max_f table bounds, +// split as n * ln2 plus a remainder r, a polynomial in r is evaluated, and the result is scaled by 2 * 2^(n-1). +template +void jit_exp_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { + // NOTE(review): the assert condition is the negation of the enclosing if, so the outer check looks redundant. + if (exec_prc_ != ov::element::f32) { + OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); + } + + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + const TReg vmm_src(in_vec_idxs[0]); + const TReg vmm_dst(out_vec_idxs[0]); + const TReg vmm_aux1(aux_vec_idxs[0]); + const TReg vmm_aux2(aux_vec_idxs[1]); + const 
TReg vmm_aux0(aux_vec_idxs[2]); + + const TReg vmm_mask(aux_vec_idxs[3]); + + h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_max_f")); + h->fmin(vmm_dst.s, vmm_src.s, vmm_aux0.s); + h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_min_f")); + + // vmm_dst now holds min(x, log(FLT_MAX)). + // Build a mask that is all-ones in lanes where x > log(FLT_MIN); it is used later + // to zero the output in lanes where x is not above log(FLT_MIN). + h->fcmgt(vmm_mask.s, vmm_src.s, vmm_aux0.s); + + // clamp from below, then keep a copy of the clamped x in vmm_aux1 + h->fmax(vmm_dst.s, vmm_dst.s, vmm_aux0.s); + h->mov(vmm_aux1.b16, vmm_dst.b16); + + // calculate exp(x) + // fx = x * log2ef + 0.5 + h->ld1r(vmm_aux0.s, table_val2("exp_log2ef")); + h->ld1r(vmm_aux2.s, table_val2("half")); + h->fmla(vmm_aux2.s, vmm_dst.s, vmm_aux0.s); + + // tmp = floorf(fx) + h->frintm(vmm_aux2.s, vmm_aux2.s); + + // keep vmm_dst = floorf(fx) (the integer part n) for further computations + h->mov(vmm_dst.b16, vmm_aux2.b16); + + // x = x - fx * ln2 (vmm_aux1 becomes the remainder r) + h->ld1r(vmm_aux0.s, table_val2("ln2f")); + h->fmls(vmm_aux1.s, vmm_aux2.s, vmm_aux0.s); + + // We do not count 2^n here, because n can reach 128 and 2^128 is not + // representable by fp32, so to get around this problem, instead of computing + // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 + // and 2 are numbers representable in fp32. 
+ + // compute 2^(n-1): subtract one from n, convert to integer, add the exponent bias + // and shift the sum left by the number of fp32 mantissa bits + h->ld1r(vmm_aux0.s, table_val2("one")); + h->fsub(vmm_dst.s, vmm_dst.s, vmm_aux0.s); + h->fcvtzs(vmm_aux2.s, vmm_dst.s); + + h->ld1r(vmm_aux0.s, table_val2("exponent_bias")); + h->add(vmm_aux2.s, vmm_aux2.s, vmm_aux0.s); + + const int n_mantissa_bits = 23; + h->sqshl(vmm_aux2.s, vmm_aux2.s, n_mantissa_bits); + + // zero the scale factor in lanes where x was not above log(FLT_MIN) + h->and_(vmm_aux2.b16, vmm_mask.b16, vmm_aux2.b16); + + // compute polynomial in r (vmm_aux1): each fmla multiplies the running value by r + // and adds the next lower coefficient, from exp_pol5 down to one + h->ld1r(vmm_aux0.s, table_val2("exp_pol5")); + h->ld1r(vmm_dst.s, table_val2("exp_pol4")); + h->fmla(vmm_dst.s, vmm_aux1.s, vmm_aux0.s); + + h->ld1r(vmm_aux0.s, table_val2("exp_pol3")); + h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s); + + h->ld1r(vmm_dst.s, table_val2("exp_pol2")); + h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s); + + h->ld1r(vmm_aux0.s, table_val2("exp_pol1")); + h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s); + + h->ld1r(vmm_dst.s, table_val2("one")); + h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s); + + // y = y * 2^n, applied as y * 2^(n-1) followed by a multiplication by two + h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux2.s); + h->ld1r(vmm_aux0.s, table_val2("two")); + h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux0.s); +} + +// Constants read by emit_isa through table_val2(). Floating-point entries are fp32 bit patterns +// (0x3f800000 is 1.0f, 0x40000000 is 2.0f, 0x3f000000 is 0.5f); exponent_bias is the plain integer 127. +// NOTE(review): the trailing true argument is presumably a broadcast flag; confirm against push_arg_entry_of. +void jit_exp_emitter::register_table_entries() { + push_arg_entry_of("exp_ln_flt_max_f", 0x42b17218, true); + push_arg_entry_of("exp_ln_flt_min_f", 0xc2aeac50, true); + push_arg_entry_of("exp_log2ef", 0x3fb8aa3b, true); + push_arg_entry_of("one", 0x3f800000, true); + push_arg_entry_of("two", 0x40000000, true); + push_arg_entry_of("half", 0x3f000000, true); + push_arg_entry_of("ln2f", 0x3f317218, true); + push_arg_entry_of("exponent_bias", 0x0000007f, true); + push_arg_entry_of("exp_pol1", 0x3f7ffffb, true); + push_arg_entry_of("exp_pol2", 0x3efffee3, true); + push_arg_entry_of("exp_pol3", 0x3e2aad40, true); + push_arg_entry_of("exp_pol4", 0x3d2b9d0d, true); + push_arg_entry_of("exp_pol5", 0x3c07cfce, true); +} + +// Reports f32 as the only supported precision; the node argument is not used. +// NOTE(review): the entry lists two element types although get_inputs_count() returns 1; confirm the expected vector length. +std::set> jit_exp_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + /// 
MUL_ADD /// jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp index 58184933e3e1a7..b5e7fafa29ae55 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp @@ -112,6 +112,33 @@ class jit_equal_emitter : public jit_emitter { void register_table_entries() override; }; +// JIT emitter for the element-wise exponent operation, derived from jit_emitter. +class jit_exp_emitter : public jit_emitter { +public: + // Constructs the emitter for an explicit execution precision (f32 when omitted). + jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc = ov::element::f32); + + // Constructs the emitter from a graph node. + jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node); + + // Number of input vectors the emitter consumes. + size_t get_inputs_count() const override; + + // Number of auxiliary vector registers the emitter requests. + size_t get_aux_vecs_count() const override; + + // Number of auxiliary general-purpose registers the emitter requests. + size_t get_aux_gprs_count() const override; + + // Registers the named constants used by the generated code. + void register_table_entries() override; + + // Returns the supported precision combinations; node may be omitted (defaults to nullptr). + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + +private: + // Entry point overriding jit_emitter::emit_impl. + void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + + // ISA-templated body that emits the instruction sequence. + template + void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; +}; + class jit_mul_add_emitter : public jit_emitter { public: jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp index 112c156652e9f8..45a2e99641cadd 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp @@ -22,6 +22,7 @@ bool 
JitEltwiseExecutor::isSupported( Algorithm::EltwiseClamp, Algorithm::EltwiseDivide, Algorithm::EltwiseEqual, + Algorithm::EltwiseExp, Algorithm::EltwiseMultiply, Algorithm::EltwiseMulAdd, Algorithm::EltwisePowerStatic, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp index 1954a65317fde7..c054e073d242bd 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp @@ -613,6 +613,7 @@ std::shared_ptr jit_uni_eltwise_generic::create_eltwise_emitte OV_CASE(Algorithm::EltwiseClamp, ov::intel_cpu::aarch64::jit_clamp_emitter), OV_CASE(Algorithm::EltwiseDivide, ov::intel_cpu::aarch64::jit_divide_emitter), OV_CASE(Algorithm::EltwiseEqual, ov::intel_cpu::aarch64::jit_equal_emitter), + OV_CASE(Algorithm::EltwiseExp, ov::intel_cpu::aarch64::jit_exp_emitter), OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter), OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter), OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter), @@ -769,13 +770,13 @@ std::set> eltwise_precision_helper::get_supported_pre OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter), OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter), OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter), + OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter), OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter), OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter), OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter)); - if (precisions.empty()) OPENVINO_THROW("Unsupported operation type for Eltwise emitter"); diff --git 
a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp index fac17b28830156..71eee36187c967 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp @@ -173,7 +173,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { // 09 | dst // 10 | aux // 11 | aux - // 12-15 | [not used] + // 12 | aux + // 13 | aux + // 14 | aux + // 15 | [not used] // 16 | src // 17 | src // 18 | src @@ -201,7 +204,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { } inline TReg get_aux_vmm(const uint32_t idx) { - if (idx > 2) { + if (idx > 4) { OPENVINO_THROW("aux vector register " + std::to_string(idx) + " is not supported"); } return TReg(10 + idx); diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp index 044730aac4b009..00b6ce3d3ed32b 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp @@ -52,9 +52,14 @@ void ActivationLayerCPUTest::generate_inputs(const std::vector& targe uint32_t range = 0; int32_t resolution = 0; - if (activationType == utils::ActivationTypes::Exp && netPrecision == ov::element::bf16) { - startFrom = 0; - range = 2; + if (activationType == utils::ActivationTypes::Exp) { + if (netPrecision == ov::element::bf16) { + startFrom = 0; + range = 2; + } else { + startFrom = -10; + range = 25; + } resolution = 32768; } else if (activationType == utils::ActivationTypes::Acosh) { startFrom = 2; @@ -149,6 +154,7 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType #if defined(OPENVINO_ARCH_ARM64) 
if ((element_type == ov::element::f32) && ((activation_type == utils::ActivationTypes::Clamp) || + (activation_type == utils::ActivationTypes::Exp) || (activation_type == utils::ActivationTypes::Relu))) { return "jit"; }