Skip to content

Commit

Permalink
cpu: x64: simplify alpha and beta parameters in brgconv postops kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
kwiersch authored and tprimak committed Jan 9, 2023
1 parent d28f2c1 commit 4761ee9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_brgemm_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
bcfg->LDD = (is_init && jcp.use_buffer) ? jcp.LDC : jcp.LDD;
bcfg->dt_c = (!is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // inp
bcfg->dt_d = (is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // out
bcfg->alpha = is_init ? 0 : 1;
bcfg->alpha = !is_init && IMPLICATION(jcp.with_sum, jcp.use_buffer);
bcfg->beta = is_init ? 0 : 1;
CHECK(safe_ptr_assign(kernels_po_[ker_idx],
new jit_brgemm_kernel_post_ops(jcp, *bcfg, *_pd->attr())));
Expand Down
72 changes: 31 additions & 41 deletions src/cpu/x64/jit_brgemm_post_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
brg.attr->post_ops_,
memory_desc_wrapper(brg.dst_md))) {

if ((jcp.with_sum && brg.beta != 0)
|| ((jcp.with_binary || jcp.with_eltwise) && brg.alpha != 0)) {
if (brg.beta != 0) {
static constexpr bool preserve_gpr = true;
static constexpr bool preserve_vmm = true;
static constexpr bool use_exact_tail_scalar_bcast = false;
Expand All @@ -306,7 +305,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
k_tail_mask, use_exact_tail_scalar_bcast};
const binary_injector::static_params_t bsp {this->param1, rhs_sp};

const bool save_state = (brg.alpha != 0) && jcp.with_eltwise;
const bool save_state = jcp.with_eltwise;
const auto &reserved_eltwise_gpr = reg_reserved_eltwise;
const auto reserved_eltwise_maskr = Xbyak::Opmask(1);

Expand Down Expand Up @@ -509,7 +508,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
};

if (jcp.with_sum && brg.beta != 0) {
if (jcp.with_sum) {
postops_injector_->set_lambda_injector(
primitive_kind::sum, sum_injector);
}
Expand Down Expand Up @@ -537,7 +536,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
void apply_comp(int m_block, int n_block, int tail = 0) {
auto k_mask = (tail == 0) ? k_full_mask : k_tail_mask;

if (brg.alpha != 0 && brg.zp_type_a != brgemm_broadcast_t::none) {
if (brg.zp_type_a != brgemm_broadcast_t::none) {
auto zmm_zp_a_val = Xbyak::Zmm(30);
mov(reg_zp_a_val, ptr[rsp + reg_zp_a_val_offs_]);
vpbroadcastd(zmm_zp_a_val, reg_zp_a_val.cvt32());
Expand All @@ -558,7 +557,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
}

if (brg.alpha != 0 && brg.req_s8s8_compensation) {
if (brg.req_s8s8_compensation) {
mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]);
for (int n = 0; n < n_block; n++) {
auto zmm_comp = Xbyak::Zmm(31);
Expand Down Expand Up @@ -593,30 +592,20 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
const auto vector
= [=](int m, int n) { return Xbyak::Zmm(m * n_block + n); };
auto k_mask = (tail == 0) ? k_full_mask : k_tail_mask;
const auto &p = attr.post_ops_;
const int sum_idx = p.find(primitive_kind::sum);
const auto maybe_req_comp = brg.is_int8 && brg.alpha != 0
const auto maybe_req_comp = brg.is_int8 && brg.beta != 0
&& (brg.req_s8s8_compensation
|| brg.zp_type_a != brgemm_broadcast_t::none);

// brg.alpha == 0 means no read from input, no bias, no eltwise - just
// initialize registers by zero at the beginning of kernel
// brg.beta == 0 means no sum - just registers write to output
// brg.alpha == 0 means initialize registers, 1 means read from input
// brg.beta == 0 means skip postwork, 1 means do postwork
// maybe_req_comp == true -> convert accumulated values to f32 after apply
// compensation to avoid the lost of accuracy when converting s32 to f32
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
if (brg.alpha == 0) {
if (sum_idx != -1 && brg.beta != 0) {
// if sum then have to init zmm each time
vpxord(vector(m, n), vector(m, n), vector(m, n));
}
} else if (!IMPLICATION(jcp.with_sum, jcp.use_buffer)) {
if (sum_idx != -1 && brg.beta != 0) {
// if sum without buffer then have to init vmm each time
uni_vpxor(vector(m, n), vector(m, n), vector(m, n));
}
} else {
if (brg.alpha == 0 && brg.beta != 0) {
// if postwork then have to init vmm each time
uni_vpxor(vector(m, n), vector(m, n), vector(m, n));
} else if (brg.alpha != 0) {
auto inp_addr = ptr[aux_reg_in
+ inp_typesize_ * (m * brg.LDC + n * brg.ld_block)];
cvt2ps(inp_dt_, vector(m, n), inp_addr, true, false, k_mask,
Expand All @@ -626,7 +615,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

if (maybe_req_comp) maybe_apply_comp(m_block, n_block, tail);

if (brg.alpha != 0 && jcp.with_bias) {
if (brg.beta != 0 && jcp.with_bias) {
for (int n = 0; n < n_block; n++) {
auto zmm_bias = Xbyak::Zmm(31);
auto bias_addr = ptr[aux_reg_bias
Expand All @@ -638,7 +627,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
}

if (brg.alpha != 0) {
if (brg.beta != 0) {
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
const Xbyak::Zmm zmm
Expand All @@ -652,7 +641,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

if (postops_injector_) inject_attr_postops(m_block, n_block, tail);

if (brg.alpha != 0 && brg.zp_type_c != brgemm_broadcast_t::none) {
if (brg.beta != 0 && brg.zp_type_c != brgemm_broadcast_t::none) {
mov(aux_reg_zp_c_values, ptr[rsp + aux_reg_zp_c_values_offs_]);
auto zmm_zp_c = Xbyak::Zmm(31);
if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
Expand Down Expand Up @@ -693,7 +682,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

if (out_dt_ == data_type::bf16) {
Xbyak::Ymm ymm = Xbyak::Ymm(zmm.getIdx());
if (brg.alpha != 0 || (sum_idx != -1 && brg.beta != 0)) {
if (brg.beta != 0) {
if (brg.is_bf16_emu)
bf16_emu_->vcvtneps2bf16(ymm, zmm);
else
Expand All @@ -702,7 +691,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
const Xbyak::Ymm r_ymm = ymm_mask(ymm, true, true, k_mask);
vmovdqu16(addr, r_ymm);
} else {
if (brg.alpha != 0 || (sum_idx != -1 && brg.beta != 0)) {
if (brg.beta != 0) {
saturate_f32(zmm, zmm_lbound, zmm_ubound, brg.dt_d);
if (out_dt_ != data_type::f32) vcvtps2dq(zmm, zmm);
}
Expand All @@ -721,8 +710,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

void loop_by_N(int m_block, int nb2, int nb2_tail, int nb_tail) {

if (brg.alpha) {
mov(aux_reg_in, reg_in);
if (brg.alpha) { mov(aux_reg_in, reg_in); }
if (brg.beta != 0) {
if (jcp.with_bias) mov(aux_reg_bias, reg_bias);
if (brg.zp_type_c != brgemm_broadcast_t::none) {
mov(aux_reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]);
Expand All @@ -748,7 +737,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
add(aux_reg_out, out_typesize_ * oc_l_offset);
if (brg.alpha != 0) {
add(aux_reg_in, inp_typesize_ * oc_l_offset);

}
if (brg.beta != 0) {
if (jcp.with_bias)
add(aux_reg_bias, bia_typesize_ * oc_l_offset);
if (brg.zp_type_c != brgemm_broadcast_t::none) {
Expand Down Expand Up @@ -779,6 +769,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
add(aux_reg_out, out_typesize_ * oc_l_offset);
if (brg.alpha != 0) {
add(aux_reg_in, inp_typesize_ * oc_l_offset);
}
if (brg.beta != 0) {
if (jcp.with_bias)
add(aux_reg_bias, bia_typesize_ * oc_l_offset);
if (brg.zp_type_c != brgemm_broadcast_t::none) {
Expand All @@ -805,8 +797,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
if (nb_tail > 0) {
apply_post_ops(m_block, 1, nb_tail);

if (brg.alpha != 0) {
add(aux_reg_in, inp_typesize_ * (nb_tail));
if (brg.alpha != 0) { add(aux_reg_in, inp_typesize_ * (nb_tail)); }
if (brg.beta != 0) {
if (jcp.with_bias) add(aux_reg_bias, bia_typesize_ * (nb_tail));
if (brg.zp_type_c != brgemm_broadcast_t::none) {
mov(aux_reg_zp_c_values,
Expand Down Expand Up @@ -859,8 +851,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
mov(reg_mask, tail_mask);
kmovq(k_tail_mask, reg_mask);

if (brg.alpha != 0) {
mov(reg_in, ptr[param1 + GET_OFF(ptr_in)]);
if (brg.alpha != 0) { mov(reg_in, ptr[param1 + GET_OFF(ptr_in)]); }
if (brg.beta != 0) {
mov(reg_scales, ptr[param1 + GET_OFF(ptr_scales)]);
mov(reg_apply_comp, ptr[param1 + GET_OFF(apply_comp)]);
mov(ptr[rsp + reg_apply_comp_offs_], reg_apply_comp);
Expand All @@ -884,10 +876,9 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
mov(reg_out, ptr[param1 + GET_OFF(ptr_out)]);

// brg.alpha == 0 means no read from input, no bias, no eltwise - just
// initialize registers by zero
// brg.beta == 0 means no sum - just registers write to output
if (brg.alpha == 0) {
// brg.alpha == 0 means initialize registers, 1 means read from input
// brg.beta == 0 means skip postwork, 1 means do postwork
if (brg.alpha == 0 && brg.beta == 0) {
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
auto zmm = Xbyak::Zmm(m * n_block + n);
Expand All @@ -908,8 +899,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

postamble();

if (brg.alpha != 0 && jcp.with_eltwise)
postops_injector_->prepare_table();
if (postops_injector_) postops_injector_->prepare_table();
}
};

Expand Down

0 comments on commit 4761ee9

Please sign in to comment.