From f430a5a4c883ef846f938f571020565d41719e9c Mon Sep 17 00:00:00 2001 From: Tomasz Czeszun Date: Tue, 11 Oct 2022 13:21:32 -0700 Subject: [PATCH] cpu: x64: brdgmm dw conv: enable per-tensor binary po --- src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp | 31 ++++++------------------ src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp | 10 +++----- src/cpu/x64/jit_brdgmm_dw_conv.cpp | 4 +-- 3 files changed, 13 insertions(+), 32 deletions(-) diff --git a/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp b/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp index 136f476880a..3a5bca652c9 100644 --- a/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp @@ -49,7 +49,8 @@ jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t(const brgemm_t &abrd) static const bcast_set_t enabled_bcast_strategy = {broadcasting_strategy_t::scalar, - broadcasting_strategy_t::per_oc}; + broadcasting_strategy_t::per_oc, + broadcasting_strategy_t::no_broadcast}; const binary_injector::rhs_arg_static_params_t rhs_sp { static_cast(vmm_b().getIdx()), r14, r15, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), @@ -63,11 +64,9 @@ jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t(const brgemm_t &abrd) injector::jit_uni_postops_injector_t>( this, brg.attr->post_ops_, bsp); - using namespace dnnl::impl::cpu::binary_injector_utils; - std::tie(with_binary_per_oc_bcast_) - = bcast_strategies_present_tup(brg.attr->post_ops_.entry_, - dst_md_wrapper, broadcasting_strategy_t::per_oc); - handle_binary_po_offset_ = with_binary_per_oc_bcast_; + with_binary_non_scalar_bcast_ + = binary_injector::any_binary_postop_rhs_non_scalar_broadcast( + brg.attr->post_ops_, dst_md_wrapper); } if (brg.is_bf16_emu) bf16_emu_ = utils::make_unique(this, @@ -123,11 +122,6 @@ void jit_brdgmm_kernel_base_t::read_params() { } if (brg.with_binary) mov(ptr[rsp + abi_param1_offs_], param1); - - if (with_binary_per_oc_bcast_) { - mov(reg_tmp, ptr[param1 + GET_OFF(oc_logical_off)]); - mov(ptr[rsp + reg_binary_postops_oc_l_offs_], reg_tmp); - } } void jit_brdgmm_kernel_base_t::load_accumulators(int m_blocks, int n_blocks) { @@ -205,8 +199,7 @@ void jit_brdgmm_kernel_base_t::apply_post_ops( if (brg.with_binary) { mov(reg_binary_params, ptr[rsp + abi_param1_offs_]); - if (handle_binary_po_offset_) { - mov(reg_binary_po_stack_frame, rsp); + if (with_binary_non_scalar_bcast_) { for_(int m_i = 0; m_i < m_blocks; m_i++) for (int n_i = 0; n_i < n_blocks; n_i++) { @@ -230,8 +223,8 @@ void jit_brdgmm_kernel_base_t::apply_post_ops( const injector_utils::conditional_register_preserve_guard_t register_guard_sum_scale( - (handle_binary_po_offset_) && p_sum_scale_reg_set, this, - {reg_ptr_sum_scale}); + (with_binary_non_scalar_bcast_) && p_sum_scale_reg_set, + this, {reg_ptr_sum_scale}); const injector_utils::conditional_register_preserve_guard_t register_guard_sum_zp(p_sum_zp_reg_set, this, {reg_ptr_sum_zp}); @@ -673,10 +666,6 @@ void jit_brdgmm_kernel_base_t::compute_loop() { add(reg_a_offset, n_loop_step * brg.typesize_A); add(reg_aux_C, n_loop_step * brg.typesize_C); add(reg_aux_D, n_loop_step * brg.typesize_D); - if (with_binary_per_oc_bcast_) { - add(qword[rsp + reg_binary_postops_oc_l_offs_], - n_loop_step); - } } if (do_loop_n) { @@ -711,10 +700,6 @@ void jit_brdgmm_kernel_base_t::compute_loop() { add(reg_a_offset, A_offset(m_blocks, -n_loop_offset)); add(reg_aux_C, C_offset(m_blocks, -n_loop_offset)); add(reg_aux_D, D_offset(m_blocks, -n_loop_offset)); - if (with_binary_per_oc_bcast_) { - add(qword[rsp + reg_binary_postops_oc_l_offs_], - oc_logical_offset(-n_loop_offset)); - } } if (do_loop_m) { diff --git a/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp b/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp index d3e17f6901a..5c6b942514a 100644 --- a/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp +++ b/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp @@ -83,7 +83,6 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { const reg64_t reg_total_padding = reg_table_base; const reg64_t reg_aux_bias = reg_table_base; const reg64_t reg_aux_scales = reg_table_base; - const reg64_t reg_binary_po_stack_frame = reg_BS_loop; const reg64_t reg_binary_params = abi_param1; // default for binary ops const reg64_t reg_ptr_sum_scale = reg_aux_A_vpad_top; const reg64_t reg_ptr_sum_zp = reg_aux_A_vpad_bottom; @@ -110,12 +109,9 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { constexpr static int reg_A_offs_ = 24; // brgemm_strd constexpr static int reg_B_offs_ = 32; // brgemm_strd constexpr static int abi_param1_offs_ = 40; - constexpr static int reg_binary_postops_oc_l_offs_ = 48; - constexpr static int reg_data_C_ptr_offs_ = 56; - constexpr static int stack_space_needed_ = 64; + constexpr static int stack_space_needed_ = 48; - bool handle_binary_po_offset_ = false; - bool with_binary_per_oc_bcast_ = false; + bool with_binary_non_scalar_bcast_ = false; inline int M() { return brg.bcast_dim; }; inline int N() { return brg.load_dim; }; @@ -204,4 +200,4 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { } // namespace impl } // namespace dnnl -#endif \ No newline at end of file +#endif diff --git a/src/cpu/x64/jit_brdgmm_dw_conv.cpp b/src/cpu/x64/jit_brdgmm_dw_conv.cpp index e7097785b38..e5797c790c3 100644 --- a/src/cpu/x64/jit_brdgmm_dw_conv.cpp +++ b/src/cpu/x64/jit_brdgmm_dw_conv.cpp @@ -65,8 +65,8 @@ bool post_ops_ok(jit_brdgmm_conv_conf_t &jcp, const primitive_attr_t &attr, {sum, eltwise, binary}, post_ops, &dst_d, false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/, false /*sum_requires_zp_zero*/, - {broadcasting_strategy_t::per_oc, - broadcasting_strategy_t::scalar})); + {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar, + broadcasting_strategy_t::no_broadcast})); } status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {