Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] [ARM64] jit exp #22937

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,137 @@ void jit_equal_emitter::register_table_entries() {
push_arg_entry_of("one", 0x3f800000, true);
}

/// EXPONENT ///
// Node-based constructor: forwards the host generator, the ISA and the node to
// the jit_emitter base, with the execution precision derived from the node by
// get_arithmetic_binary_exec_precision().
jit_exp_emitter::jit_exp_emitter(
        dnnl::impl::cpu::aarch64::jit_generator* host,
        dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
        const std::shared_ptr<ov::Node>& node)
    : jit_emitter(host,
                  host_isa,
                  node,
                  get_arithmetic_binary_exec_precision(node)) {
    // NOTE(review): presumably this is what pulls the constants declared in
    // register_table_entries() into the emitter's table - confirm in jit_emitter.
    prepare_table();
}

// Precision-based constructor: used when no graph node is available and the
// caller states the execution precision explicitly.
jit_exp_emitter::jit_exp_emitter(
        dnnl::impl::cpu::aarch64::jit_generator* host,
        dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
        const ov::element::Type exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    // NOTE(review): presumably this is what pulls the constants declared in
    // register_table_entries() into the emitter's table - confirm in jit_emitter.
    prepare_table();
}

// The exponent is a unary operation: exactly one input vector is consumed.
size_t jit_exp_emitter::get_inputs_count() const {
    return 1;
}

// Four auxiliary vector registers are requested: emit_isa reads
// aux_vec_idxs[0] through aux_vec_idxs[3] (three scratch registers and a mask).
size_t jit_exp_emitter::get_aux_vecs_count() const {
    return 4;
}

// One auxiliary general-purpose register is requested.
// NOTE(review): presumably needed to address the constant table - confirm.
size_t jit_exp_emitter::get_aux_gprs_count() const {
    return 1;
}

void jit_exp_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

// Emits the instruction sequence that computes exp(x) for every f32 lane of
// the input vector register and leaves the result in the output register.
//
// in_vec_idxs[0]  - index of the source vector register
// out_vec_idxs[0] - index of the destination vector register
// aux_vec_idxs[0..3] are used as scratch registers (the last one as a mask).
//
// Scheme: clamp x to [ln(FLT_MIN), ln(FLT_MAX)], split it as x = n*ln2 + r
// with n = floor(x*log2(e) + 0.5), evaluate a 5th-degree polynomial for
// exp(r) with Horner's scheme, then scale by 2^n (done as 2 * 2^(n-1)).
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_exp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
// NOTE(review): the outer `if` repeats the assert's own condition and is
// redundant; only f32 execution precision is accepted.
if (exec_prc_ != ov::element::f32) {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());
}

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
const TReg vmm_src(in_vec_idxs[0]);
const TReg vmm_dst(out_vec_idxs[0]);
// Note the mapping: aux1 -> idx 0, aux2 -> idx 1, aux0 -> idx 2.
const TReg vmm_aux1(aux_vec_idxs[0]);
const TReg vmm_aux2(aux_vec_idxs[1]);
const TReg vmm_aux0(aux_vec_idxs[2]);

// holds the "input is above ln(FLT_MIN)" lane mask
const TReg vmm_mask(aux_vec_idxs[3]);

// clamp from above: dst = min(src, ln(FLT_MAX))
h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_max_f"));
h->fmin(vmm_dst.s, vmm_src.s, vmm_aux0.s);
h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_min_f"));

// mask is set for lanes where src > ln(FLT_MIN); all other lanes are
// zeroed in the output further below
h->fcmgt(vmm_mask.s, vmm_src.s, vmm_aux0.s);

// clamp from below: dst = max(dst, ln(FLT_MIN)); aux1 keeps a copy of the clamped x
h->fmax(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
h->mov(vmm_aux1.b16, vmm_dst.b16);

// calculate exp(x)
// fx = x * log2ef + 0.5 (aux2 starts as 0.5 and accumulates x * log2ef)
h->ld1r(vmm_aux0.s, table_val2("exp_log2ef"));
h->ld1r(vmm_aux2.s, table_val2("half"));
h->fmla(vmm_aux2.s, vmm_dst.s, vmm_aux0.s);

// tmp = floorf(fx)
h->frintm(vmm_aux2.s, vmm_aux2.s);

// keep vmm_dst = fx for further computations
h->mov(vmm_dst.b16, vmm_aux2.b16);

// x = x - fx * ln2 (aux1 now holds the reduced argument r)
h->ld1r(vmm_aux0.s, table_val2("ln2f"));
h->fmls(vmm_aux1.s, vmm_aux2.s, vmm_aux0.s);

// 2^n is not computed directly here, because n can reach 128 and 2^128 is
// not representable in fp32. To get around this, 2 * 2^(n-1) * exp(r) is
// computed instead of 2^n * exp(r), because both 2^127 and 2 are
// representable in fp32.

// compute 2^(n-1): subtract one, convert to integer, add the exponent bias
// and shift the result into the fp32 exponent field
h->ld1r(vmm_aux0.s, table_val2("one"));
h->fsub(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
h->fcvtzs(vmm_aux2.s, vmm_dst.s);

h->ld1r(vmm_aux0.s, table_val2("exponent_bias"));
h->add(vmm_aux2.s, vmm_aux2.s, vmm_aux0.s);

const int n_mantissa_bits = 23;
h->sqshl(vmm_aux2.s, vmm_aux2.s, n_mantissa_bits);

// zero the scale factor in lanes whose input was not above log(FLT_MIN)
h->and_(vmm_aux2.b16, vmm_mask.b16, vmm_aux2.b16);

// compute polynomial with Horner's scheme (r is in vmm_aux1):
// dst = pol4 + r * pol5
h->ld1r(vmm_aux0.s, table_val2("exp_pol5"));
h->ld1r(vmm_dst.s, table_val2("exp_pol4"));
h->fmla(vmm_dst.s, vmm_aux1.s, vmm_aux0.s);

// aux0 = pol3 + r * dst
h->ld1r(vmm_aux0.s, table_val2("exp_pol3"));
h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s);

// dst = pol2 + r * aux0
h->ld1r(vmm_dst.s, table_val2("exp_pol2"));
h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s);

// aux0 = pol1 + r * dst
h->ld1r(vmm_aux0.s, table_val2("exp_pol1"));
h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s);

// dst = 1 + r * aux0
h->ld1r(vmm_dst.s, table_val2("one"));
h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s);

// y = y * 2^n, done in two steps: y * 2^(n-1), then * 2
h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux2.s);
h->ld1r(vmm_aux0.s, table_val2("two"));
h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
}

// Registers the constants that emit_isa loads by name through table_val2().
// All values are raw 32-bit patterns: IEEE-754 single-precision encodings,
// except "exponent_bias", which is the plain integer 127.
void jit_exp_emitter::register_table_entries() {
    struct table_constant {
        const char* name;
        unsigned int bits;
    };

    // Registration order is kept exactly as in the original sequence of calls.
    static const table_constant constants[] = {
        {"exp_ln_flt_max_f", 0x42b17218},  // ~88.7228, upper clamp bound
        {"exp_ln_flt_min_f", 0xc2aeac50},  // ~-87.3365, lower clamp bound
        {"exp_log2ef", 0x3fb8aa3b},        // ~1.442695
        {"one", 0x3f800000},               // 1.0
        {"two", 0x40000000},               // 2.0
        {"half", 0x3f000000},              // 0.5
        {"ln2f", 0x3f317218},              // ~0.693147
        {"exponent_bias", 0x0000007f},     // 127 (integer)
        {"exp_pol1", 0x3f7ffffb},          // ~0.9999997
        {"exp_pol2", 0x3efffee3},          // ~0.49999
        {"exp_pol3", 0x3e2aad40},          // ~0.16668
        {"exp_pol4", 0x3d2b9d0d},          // ~0.04190
        {"exp_pol5", 0x3c07cfce},          // ~0.00829
    };

    for (const auto& constant : constants) {
        push_arg_entry_of(constant.name, constant.bits, true);
    }
}

// Returns the precision combinations this emitter accepts. The node argument
// is not inspected; the result is always the single combination {f32, f32}.
// NOTE(review): two precisions are listed although get_inputs_count() is 1 -
// confirm whether {f32} alone is what the precision helper expects.
std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    std::set<std::vector<element::Type>> supported;
    supported.insert(std::vector<element::Type>{element::f32, element::f32});
    return supported;
}

/// MUL_ADD ///
jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,33 @@ class jit_equal_emitter : public jit_emitter {
void register_table_entries() override;
};

// JIT emitter producing code for the element-wise exponent, exp(x).
// The implementation accepts f32 execution precision and the ASIMD ISA only.
class jit_exp_emitter : public jit_emitter {
public:
// Creates the emitter with an explicitly given execution precision
// (f32 by default).
jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

// Creates the emitter for a graph node; the execution precision is derived
// from the node.
jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

// Number of input vector registers consumed (1).
size_t get_inputs_count() const override;

// Number of auxiliary vector registers required (4).
size_t get_aux_vecs_count() const override;

// Number of auxiliary general-purpose registers required (1).
size_t get_aux_gprs_count() const override;

// Registers the named constants (clamp bounds, log2(e), ln2, polynomial
// coefficients, ...) that the generated code loads.
void register_table_entries() override;

// Precision combinations supported by the emitter; the node argument is
// optional and is not inspected by the implementation.
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
// Dispatches to emit_isa for the host ISA; throws for anything but ASIMD.
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

// ISA-specific code emission for exp(x).
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_mul_add_emitter : public jit_emitter {
public:
jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ bool JitEltwiseExecutor::isSupported(
Algorithm::EltwiseClamp,
Algorithm::EltwiseDivide,
Algorithm::EltwiseEqual,
Algorithm::EltwiseExp,
Algorithm::EltwiseMultiply,
Algorithm::EltwiseMulAdd,
Algorithm::EltwisePowerStatic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseClamp, ov::intel_cpu::aarch64::jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, ov::intel_cpu::aarch64::jit_divide_emitter),
OV_CASE(Algorithm::EltwiseEqual, ov::intel_cpu::aarch64::jit_equal_emitter),
OV_CASE(Algorithm::EltwiseExp, ov::intel_cpu::aarch64::jit_exp_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter),
Expand Down Expand Up @@ -769,13 +770,13 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter),
OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter),
OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter),
OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter));

if (precisions.empty())
OPENVINO_THROW("Unsupported operation type for Eltwise emitter");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// 09 | dst
// 10 | aux
// 11 | aux
// 12-15 | [not used]
// 12 | aux
// 13 | aux
// 14 | aux
// 15 | [not used]
// 16 | src
// 17 | src
// 18 | src
Expand Down Expand Up @@ -201,7 +204,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
}

inline TReg get_aux_vmm(const uint32_t idx) {
if (idx > 2) {
if (idx > 4) {
OPENVINO_THROW("aux vector register " + std::to_string(idx) + " is not supported");
}
return TReg(10 + idx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,14 @@ void ActivationLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targe
uint32_t range = 0;
int32_t resolution = 0;

if (activationType == utils::ActivationTypes::Exp && netPrecision == ov::element::bf16) {
startFrom = 0;
range = 2;
if (activationType == utils::ActivationTypes::Exp) {
if (netPrecision == ov::element::bf16) {
startFrom = 0;
range = 2;
} else {
startFrom = -10;
range = 25;
}
resolution = 32768;
} else if (activationType == utils::ActivationTypes::Acosh) {
startFrom = 2;
Expand Down Expand Up @@ -149,6 +154,7 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
#if defined(OPENVINO_ARCH_ARM64)
if ((element_type == ov::element::f32) &&
((activation_type == utils::ActivationTypes::Clamp) ||
(activation_type == utils::ActivationTypes::Exp) ||
(activation_type == utils::ActivationTypes::Relu))) {
return "jit";
}
Expand Down
Loading