diff --git a/src/gpu/jit/conv/config.cpp b/src/gpu/jit/conv/config.cpp index 6ef86f0338f..b4d671abe3c 100644 --- a/src/gpu/jit/conv/config.cpp +++ b/src/gpu/jit/conv/config.cpp @@ -95,8 +95,8 @@ status_t conv_config_t::init_common_blocking() { // Set base blocks to align kernel blocking with layout blocking. if (is_fwd) { bh->set_base_iter_block("mb", src_layout.inner_block(0)); - int src_g_blk = src_layout.inner_block(1); - int wei_g_blk = wei_layout.inner_block(0); + int src_g_blk = is_dw ? src_layout.inner_block(1) : 1; + int wei_g_blk = is_dw ? wei_layout.inner_block(0) : 1; bh->set_base_iter_block("g", src_g_blk, wei_g_blk); int src_ic_blk = src_layout.inner_block(2); int wei_ic_blk = wei_layout.inner_block(2); @@ -572,7 +572,7 @@ struct nc_block_t { static nc_block_t get_default_blocking(type_t type, bool is_dw, int n, int c, int g, bool is_input, bool is_small_ic) { bool is_small_ic_input - = (type.size() <= 2 && is_input && !is_dw && is_small_ic); + = (type.size() <= 2 && is_input && g == 1 && is_small_ic); auto c_block = [&]() { // Special case for small input channel shapes with dpas. if (is_small_ic_input) { @@ -580,14 +580,15 @@ struct nc_block_t { return std::max(packed_dword_elems, utils::rnd_up_pow2(c)); } auto default_c_blk = type.size() == 1 ? 32 : 16; - auto blk_dim = is_dw ? g : c; + auto blk_dim = is_dw ? g : g * c; return pick_block_rnd_up(blk_dim, default_c_blk); }(); // Non-depthwise convolutions currently require channel is a multiple of // c_block. If that implementation restriction is removed, this logic // could be removed. - if (g > 1 && !is_dw && c % c_block != 0) c_block = 1; + if (g > 1 && !is_dw && c % c_block != 0 && c_block % c != 0) + c_block = 1; auto n_block = [&]() { auto default_n_blk = type.size() < 4 ? 32 : 16; @@ -816,8 +817,8 @@ status_t conv_config_t::init_data_layouts(convolution_pd_t *conv_pd) { // If src/dst is nhwc then set the other one with any to nhwc too (except // 1st convolution). - bool is_small_ic_non_dw = is_small_ic() && !is_dw; - bool is_small_oc_non_dw = is_small_oc() && !is_dw; + bool is_small_ic_non_dw = is_small_ic() && g == 1; + bool is_small_oc_non_dw = is_small_oc() && g == 1; bool propagate_nhwc = (matches_tag(src_md, "axb") && !is_small_ic_non_dw) || matches_tag(dst_md, "axb"); if (propagate_nhwc) { @@ -1250,7 +1251,9 @@ void conv_config_t::init_bwd_d_optimize_strided(int iw_thr_blk) { void conv_config_t::init_use_ow_kw_grf_cache() { use_ow_kw_grf_cache = false; - if (!is_fwd || !is_small_ic() || kw < 3 || is_dw_large_mb()) return; + if (!is_fwd || !is_small_ic() || (is_small_ic() && !is_dw && g > 1) + || kw < 3 || is_dw_large_mb()) + return; if (is_dp_fma()) return; if (fuse_spatial) return;