-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CPU][ARM64] Implemented JIT Emitter for Eltwise SoftPlus Operation #29242
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good job 👍🏼
Left some comments
const TReg vmm_neg_mask(aux_vec_idxs[7]); // mask to indicate whether n is negative | ||
|
||
h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_min_f")); // load min allowed value | ||
h->fmaxnm(vmm_dst.s, vmm_src.s, vmm_aux0.s); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There might be situation when in_vec_idxs[0] == out_vec_idxs[0]
- the same src and dst register.
In this case after this line vmm_src
will contain not original values and we won't be able to use this register (for example, on L2782 the register may have incorrect original data).
May I ask you to handle such possible cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see. Will first moving to an auxillary register from the source and then applying the fmaxnm
operation be enough in this case?
} | ||
|
||
size_t jit_softplus_emitter::get_aux_vecs_count() const { | ||
return 8; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since part of aux vec registers are passed to exp_emitter
, I think we need to write some function with exp_emitter->get_aux_vecs_count()
calling here? If exp_emitter
requires (or will require) more than 8 registers, some vector registers will be spilled (saved on stack) during exp_emitter->emit_code()
call.
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I initially thought to put exp_emitter->get_aux_vecs_count()+4
, as exp_emitter->get_aux_vecs_count()
is returning 4. I tried with using more than 9 aux registers (few solely for exp_emitter for its own manipulation), but it won't let me allocate that many. I noticed the same in the elu_emitter
where the get get_aux_vecs_count()
function for it returns max(exp_emitter->get_aux_vecs_count()+1ull, 2ull)
. I'm wondering if I can apply the same logic while allocating my auxillary registers.
Like below,
const TReg vmm_aux0(aux_vec_idxs[0]);
const TReg vmm_aux1(aux_vec_idxs[1]);
const TReg vmm_aux2(aux_vec_idxs[2]);
const TReg vmm_aux3(aux_vec_idxs[3]);
const TReg vmm_aux4(aux_vec_idxs[exp_aux_count]);
const TReg vmm_aux5(aux_vec_idxs[exp_aux_count + 1]);
const TReg vmm_mask(aux_vec_idxs[exp_aux_count + 2]);
const TReg vmm_neg_mask(aux_vec_idxs[exp_aux_count + 3]);
Is this approach correct?
h->fmul(vmm_dst.s, vmm_aux0.s, vmm_aux2.s); | ||
h->fadd(vmm_dst.s, vmm_dst.s, vmm_aux4.s); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we use FMA instruction here (fmla
)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I will modify it. Thanks
const TReg vmm_mask(aux_vec_idxs[6]); | ||
const TReg vmm_neg_mask(aux_vec_idxs[7]); // mask to indicate whether n is negative |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have only one mask register? I see that vmm_neg_mask
is free (after L2778) when we initialize vmm_mask
on L2782. Looks like we can reuse vmm_neg_mask
in the code part which handle big values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, will update that
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa> | ||
void jit_softplus_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, | ||
const std::vector<size_t>& out_vec_idxs) const { | ||
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May I ask you to add please brief description of the implemented logic? What is the formula?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implemented logic manipulates the base formula softplus(x) = ln(1+e^x) = ln(1+e^r*2^n)
into two separate formula for positive x and negative x,
- For positive we use
softplus(x) = nln2 + ln((2^-(n-1)+2e^r)/2)
to compute the log approximation - For negative we use
softplus(x) = ln(1+e^x) = ln(1+e^r*2^n) = ln2 + ln(1 + (e^r*2^(n-1) - 0.5))
, to compute the log approximation
h->fsub(vmm_aux0.s, vmm_aux2.s, vmm_aux4.s); // (e^r*2^(n-1) - 0.5) | ||
|
||
// Log approximation of (1 + (e^r*2^(n-1) - 0.5)) | ||
h->ld1r(vmm_aux5.s, table_val2("log_pol6")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please elaborate why we have to calculate two log approximation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The base formula, softplus(x) = ln(1+e^x) = ln(1+e^r*2^n)
, where n
is the quotient and r
is the remainder when divided by ln2
. We have two cases one for positive value of n
, i.e., x >= 0
and for negative value of n
, i.e. for x < 0.
-
For positive values we finally arrive at the formula;
softplus(x) = nln2 + ln((2^-(n-1)+2e^r)/2)
, here we can calculate the log approximation since forn > 0
,2^-(n-1)
will be small and will be in desirable range. The problem occurs when n is negative, which makes the2^-(n-1)
value very large and cannot be approximated with 8th order polynomial -
For negative values we can inherently use this formula
softplus(x) = ln(1+e^x) = ln(1+e^r*2^n)
, i have done these manipulations as shown in the comment in the code to gety
inln(1+y)
in between [-0.5, 0] where we can accurately approximate for negativen
Details:
jit_softplus_emitter
derived class for element wise softplus operationAlgorithm::EltwiseSoftRelu
, inexecutors/aarch64
as one of the supported algorithmsget_supported_precisions
andcreate_eltwise_emitters
inkernel/aarch64
utils::ActivationTypes::SoftPlus
injit
kernel check in the tests inactivation.cpp
Tests:
Passed all local tests using
./bin/arm64/Release/ov_cpu_func_tests --gtest_filter="*smoke*Activation*SoftPlus*"
Tickets:
CC: @a-sidorova