Skip to content

Commit

Permalink
cpu: x64: brdgmm dw conv: enable per-tensor binary po
Browse files Browse the repository at this point in the history
  • Loading branch information
tczeszun authored and tprimak committed Oct 18, 2022
1 parent 8a1e959 commit f430a5a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 32 deletions.
31 changes: 8 additions & 23 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t(const brgemm_t &abrd)

static const bcast_set_t enabled_bcast_strategy
= {broadcasting_strategy_t::scalar,
broadcasting_strategy_t::per_oc};
broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::no_broadcast};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(vmm_b().getIdx()), r14, r15, preserve_gpr,
preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec),
Expand All @@ -63,11 +64,9 @@ jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t(const brgemm_t &abrd)
injector::jit_uni_postops_injector_t<avx512_core>>(
this, brg.attr->post_ops_, bsp);

using namespace dnnl::impl::cpu::binary_injector_utils;
std::tie(with_binary_per_oc_bcast_)
= bcast_strategies_present_tup(brg.attr->post_ops_.entry_,
dst_md_wrapper, broadcasting_strategy_t::per_oc);
handle_binary_po_offset_ = with_binary_per_oc_bcast_;
with_binary_non_scalar_bcast_
= binary_injector::any_binary_postop_rhs_non_scalar_broadcast(
brg.attr->post_ops_, dst_md_wrapper);
}
if (brg.is_bf16_emu)
bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
Expand Down Expand Up @@ -123,11 +122,6 @@ void jit_brdgmm_kernel_base_t::read_params() {
}

if (brg.with_binary) mov(ptr[rsp + abi_param1_offs_], param1);

if (with_binary_per_oc_bcast_) {
mov(reg_tmp, ptr[param1 + GET_OFF(oc_logical_off)]);
mov(ptr[rsp + reg_binary_postops_oc_l_offs_], reg_tmp);
}
}

void jit_brdgmm_kernel_base_t::load_accumulators(int m_blocks, int n_blocks) {
Expand Down Expand Up @@ -205,8 +199,7 @@ void jit_brdgmm_kernel_base_t::apply_post_ops(
if (brg.with_binary) {
mov(reg_binary_params, ptr[rsp + abi_param1_offs_]);

if (handle_binary_po_offset_) {
mov(reg_binary_po_stack_frame, rsp);
if (with_binary_non_scalar_bcast_) {

for_(int m_i = 0; m_i < m_blocks; m_i++)
for (int n_i = 0; n_i < n_blocks; n_i++) {
Expand All @@ -230,8 +223,8 @@ void jit_brdgmm_kernel_base_t::apply_post_ops(

const injector_utils::conditional_register_preserve_guard_t
register_guard_sum_scale(
(handle_binary_po_offset_) && p_sum_scale_reg_set, this,
{reg_ptr_sum_scale});
(with_binary_non_scalar_bcast_) && p_sum_scale_reg_set,
this, {reg_ptr_sum_scale});
const injector_utils::conditional_register_preserve_guard_t
register_guard_sum_zp(p_sum_zp_reg_set, this, {reg_ptr_sum_zp});

Expand Down Expand Up @@ -673,10 +666,6 @@ void jit_brdgmm_kernel_base_t::compute_loop() {
add(reg_a_offset, n_loop_step * brg.typesize_A);
add(reg_aux_C, n_loop_step * brg.typesize_C);
add(reg_aux_D, n_loop_step * brg.typesize_D);
if (with_binary_per_oc_bcast_) {
add(qword[rsp + reg_binary_postops_oc_l_offs_],
n_loop_step);
}
}

if (do_loop_n) {
Expand Down Expand Up @@ -711,10 +700,6 @@ void jit_brdgmm_kernel_base_t::compute_loop() {
add(reg_a_offset, A_offset(m_blocks, -n_loop_offset));
add(reg_aux_C, C_offset(m_blocks, -n_loop_offset));
add(reg_aux_D, D_offset(m_blocks, -n_loop_offset));
if (with_binary_per_oc_bcast_) {
add(qword[rsp + reg_binary_postops_oc_l_offs_],
oc_logical_offset(-n_loop_offset));
}
}

if (do_loop_m) {
Expand Down
10 changes: 3 additions & 7 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
const reg64_t reg_total_padding = reg_table_base;
const reg64_t reg_aux_bias = reg_table_base;
const reg64_t reg_aux_scales = reg_table_base;
const reg64_t reg_binary_po_stack_frame = reg_BS_loop;
const reg64_t reg_binary_params = abi_param1; // default for binary ops
const reg64_t reg_ptr_sum_scale = reg_aux_A_vpad_top;
const reg64_t reg_ptr_sum_zp = reg_aux_A_vpad_bottom;
Expand All @@ -110,12 +109,9 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
constexpr static int reg_A_offs_ = 24; // brgemm_strd
constexpr static int reg_B_offs_ = 32; // brgemm_strd
constexpr static int abi_param1_offs_ = 40;
constexpr static int reg_binary_postops_oc_l_offs_ = 48;
constexpr static int reg_data_C_ptr_offs_ = 56;
constexpr static int stack_space_needed_ = 64;
constexpr static int stack_space_needed_ = 48;

bool handle_binary_po_offset_ = false;
bool with_binary_per_oc_bcast_ = false;
bool with_binary_non_scalar_bcast_ = false;

inline int M() { return brg.bcast_dim; };
inline int N() { return brg.load_dim; };
Expand Down Expand Up @@ -204,4 +200,4 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
} // namespace impl
} // namespace dnnl

#endif
#endif
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_brdgmm_dw_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ bool post_ops_ok(jit_brdgmm_conv_conf_t &jcp, const primitive_attr_t &attr,
{sum, eltwise, binary}, post_ops, &dst_d,
false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/,
false /*sum_requires_zp_zero*/,
{broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::scalar}));
{broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar,
broadcasting_strategy_t::no_broadcast}));
}

status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {
Expand Down

0 comments on commit f430a5a

Please sign in to comment.