Skip to content

Commit

Permalink
[CPU] [ARM64] jit subtract (openvinotoolkit#23285)
Browse files Browse the repository at this point in the history
### Details:
 - *[CPU] [AARCH64] jit subtract*

### Tickets:
 - *CVS-134748*
  • Loading branch information
eshoguli authored Mar 10, 2024
1 parent 090bf92 commit e0472dd
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,46 @@ void jit_relu_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const st
h->fmaxnm(dst.s, src.s, tmp.s);
}

/// SUBTRACT ///
jit_subtract_emitter::jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
}

jit_subtract_emitter::jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {
}

size_t jit_subtract_emitter::get_inputs_count() const { return 2; }

void jit_subtract_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_subtract_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src0 = TReg(in_vec_idxs[0]);
TReg src1 = TReg(in_vec_idxs[1]);
TReg dst = TReg(out_vec_idxs[0]);

h->uni_fsub(dst.s, src0.s, src1.s);
}

std::set<std::vector<element::Type>> jit_subtract_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
}

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,28 @@ class jit_relu_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_subtract_emitter : public jit_emitter {
public:
jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};


} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ bool JitEltwiseExecutor::isSupported(
Algorithm::EltwiseMultiply,
Algorithm::EltwiseMulAdd,
Algorithm::EltwisePowerStatic,
Algorithm::EltwiseRelu);
Algorithm::EltwiseRelu,
Algorithm::EltwiseSubtract);
if (!is_supported) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,8 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter),
OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter));
OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter),
OV_CASE(Algorithm::EltwiseSubtract, ov::intel_cpu::aarch64::jit_subtract_emitter));

if (!ctx.emitter)
OPENVINO_THROW("Unsupported operation type '" + algToString(data.algo) + "' for Eltwise emitter");
Expand Down Expand Up @@ -655,7 +656,8 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter));
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter));

if (precisions.empty())
OPENVINO_THROW("Unsupported operation type for Eltwise emitter");
Expand Down

0 comments on commit e0472dd

Please sign in to comment.