Skip to content

Commit

Permalink
[CPU] [ARM64] jit tanh (#23202)
Browse files Browse the repository at this point in the history
### Details:
 - *[CPU] [AARCH64] jit tanh*

### Tickets:
 - *CVS-134539*
  • Loading branch information
eshoguli authored Apr 5, 2024
1 parent ba42be1 commit 5ca08a7
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ void jit_clamp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const s
}

std::set<std::vector<element::Type>> jit_clamp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
return {{element::f32}};
}

/// DIVIDE ///
Expand Down Expand Up @@ -388,7 +388,7 @@ void jit_exp_emitter::register_table_entries() {
}

std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
return {{element::f32}};
}

/// MUL_ADD ///
Expand Down Expand Up @@ -532,7 +532,7 @@ void jit_power_static_emitter::register_table_entries() {
}

std::set<std::vector<element::Type>> jit_power_static_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
return {{element::f32}};
}

void jit_power_static_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
Expand Down Expand Up @@ -852,7 +852,7 @@ void jit_sigmoid_emitter::emit_data() const {
}

std::set<std::vector<element::Type>> jit_sigmoid_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
return {{element::f32}};
}

/// SUBTRACT ///
Expand Down Expand Up @@ -893,6 +893,80 @@ std::set<std::vector<element::Type>> jit_subtract_emitter::get_supported_precisi
return {{element::f32, element::f32}};
}

/// TANH ///
jit_tanh_emitter::jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
prepare_table();
sigmoid_emitter = std::make_unique<jit_sigmoid_emitter>(h, host_isa, node);
}

jit_tanh_emitter::jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
prepare_table();
sigmoid_emitter = std::make_unique<jit_sigmoid_emitter>(h, host_isa, exec_prc);
}

size_t jit_tanh_emitter::get_inputs_count() const { return 1; }

size_t jit_tanh_emitter::get_aux_vecs_count() const {
return sigmoid_emitter->get_aux_vecs_count() + 1;
}

size_t jit_tanh_emitter::get_aux_gprs_count() const {
return sigmoid_emitter->get_aux_gprs_count() + 1;
}

void jit_tanh_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_tanh_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src = TReg(in_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);

TReg aux = TReg(aux_vec_idxs.back());

h->ld1r(aux.s, table_val2("two"));
h->uni_fmul(aux.s, src.s, aux.s);

sigmoid_emitter->emit_code(
{ aux.getIdx() },
out_vec_idxs,
aux_vec_idxs,
aux_gpr_idxs);

h->ld1r(aux.s, table_val2("two"));
h->uni_fmul(dst.s, aux.s, dst.s);
h->ld1r(aux.s, table_val2("one"));
h->uni_fsub(dst.s, dst.s, aux.s);
}

void jit_tanh_emitter::register_table_entries() {
push_arg_entry_of("one", 0x3f800000, true);
push_arg_entry_of("two", 0x40000000, true);
}

void jit_tanh_emitter::emit_data() const {
jit_emitter::emit_data();
sigmoid_emitter->emit_data();
}

std::set<std::vector<element::Type>> jit_tanh_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32}};
}

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,36 @@ class jit_subtract_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_tanh_emitter : public jit_emitter {
public:
jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type exec_prc = ov::element::f32);

jit_tanh_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;

size_t get_aux_vecs_count() const override;

size_t get_aux_gprs_count() const override;

void register_table_entries() override;

void emit_data() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
std::unique_ptr<jit_sigmoid_emitter> sigmoid_emitter;

void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

} // namespace aarch64
} // namespace intel_cpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ void jit_emitter::store_context(
// 2.1. store pair registers
int prev_reg_idx = -1;
size_t ignore_registers_count = 0;
for (size_t reg_idx = 0; reg_idx < vec_regs.size(); reg_idx++) {
for (const auto reg_idx : vec_regs) {
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
ignore_registers_count++;
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ bool JitEltwiseExecutor::isSupported(
Algorithm::EltwiseRelu,
Algorithm::EltwiseSelect,
Algorithm::EltwiseSigmoid,
Algorithm::EltwiseSubtract);
Algorithm::EltwiseSubtract,
Algorithm::EltwiseTanh);
if (!is_supported) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,8 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter),
OV_CASE(Algorithm::EltwiseSelect, ov::intel_cpu::aarch64::jit_select_emitter),
OV_CASE(Algorithm::EltwiseSigmoid, ov::intel_cpu::aarch64::jit_sigmoid_emitter),
OV_CASE(Algorithm::EltwiseSubtract, ov::intel_cpu::aarch64::jit_subtract_emitter));
OV_CASE(Algorithm::EltwiseSubtract, ov::intel_cpu::aarch64::jit_subtract_emitter),
OV_CASE(Algorithm::EltwiseTanh, ov::intel_cpu::aarch64::jit_tanh_emitter));

if (!ctx.emitter)
OPENVINO_THROW("Unsupported operation type '" + algToString(data.algo) + "' for Eltwise emitter");
Expand Down Expand Up @@ -780,7 +781,8 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter),
OV_CASE(Algorithm::EltwiseSigmoid, jit_sigmoid_emitter),
OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter));
OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter),
OV_CASE(Algorithm::EltwiseTanh, jit_tanh_emitter));
if (precisions.empty())
OPENVINO_THROW("Unsupported operation type for Eltwise emitter");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,15 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// 13 | aux
// 14 | aux
// 15 | aux
// 16 | src
// 16 | aux
// 17 | src
// 18 | src
// 19 | src
// 20 | src
// 21 | src
// 22 | src
// 23-31 | [not used]
// 23 | src
// 24-31 | [not used]


TReg vmm_dst {9};
Expand All @@ -193,18 +194,18 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
if (idx > MAX_ELTWISE_INPUTS) {
OPENVINO_THROW("source vector register " + std::to_string(idx) + " is not supported");
}
return TReg(16 + idx);
return TReg(17 + idx);
}

inline SReg get_scl_reg(const uint32_t idx) {
if (idx > MAX_ELTWISE_INPUTS) {
OPENVINO_THROW("source scalar register " + std::to_string(idx) + " is not supported");
}
return SReg(16 + idx);
return SReg(17 + idx);
}

inline TReg get_aux_vmm(const uint32_t idx) {
if (idx > 5) {
if (idx > 6) {
OPENVINO_THROW("aux vector register " + std::to_string(idx) + " is not supported");
}
return TReg(10 + idx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
((activation_type == utils::ActivationTypes::Clamp) ||
(activation_type == utils::ActivationTypes::Exp) ||
(activation_type == utils::ActivationTypes::Relu) ||
(activation_type == utils::ActivationTypes::Sigmoid))) {
(activation_type == utils::ActivationTypes::Sigmoid) ||
(activation_type == utils::ActivationTypes::Tanh))) {
return "jit";
}

Expand Down

0 comments on commit 5ca08a7

Please sign in to comment.