Skip to content

Commit

Permalink
gpu: jit: conv: allow blocked format tags for grouped conv
Browse files Browse the repository at this point in the history
  • Loading branch information
dyoussif authored and vpirogov committed Oct 21, 2022
1 parent 53532a9 commit 7ba3c40
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/gpu/jit/conv/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ status_t conv_config_t::init_common_blocking() {
// Set base blocks to align kernel blocking with layout blocking.
if (is_fwd) {
bh->set_base_iter_block("mb", src_layout.inner_block(0));
int src_g_blk = src_layout.inner_block(1);
int wei_g_blk = wei_layout.inner_block(0);
int src_g_blk = is_dw ? src_layout.inner_block(1) : 1;
int wei_g_blk = is_dw ? wei_layout.inner_block(0) : 1;
bh->set_base_iter_block("g", src_g_blk, wei_g_blk);
int src_ic_blk = src_layout.inner_block(2);
int wei_ic_blk = wei_layout.inner_block(2);
Expand Down Expand Up @@ -572,22 +572,23 @@ struct nc_block_t {
static nc_block_t get_default_blocking(type_t type, bool is_dw, int n,
int c, int g, bool is_input, bool is_small_ic) {
bool is_small_ic_input
= (type.size() <= 2 && is_input && !is_dw && is_small_ic);
= (type.size() <= 2 && is_input && g == 1 && is_small_ic);
auto c_block = [&]() {
// Special case for small input channel shapes with dpas.
if (is_small_ic_input) {
int packed_dword_elems = 4 / type.size();
return std::max(packed_dword_elems, utils::rnd_up_pow2(c));
}
auto default_c_blk = type.size() == 1 ? 32 : 16;
auto blk_dim = is_dw ? g : c;
auto blk_dim = is_dw ? g : g * c;
return pick_block_rnd_up(blk_dim, default_c_blk);
}();

// Non-depthwise convolutions currently require channel is a multiple of
// c_block. If that implementation restriction is removed, this logic
// could be removed.
if (g > 1 && !is_dw && c % c_block != 0) c_block = 1;
if (g > 1 && !is_dw && c % c_block != 0 && c_block % c != 0)
c_block = 1;

auto n_block = [&]() {
auto default_n_blk = type.size() < 4 ? 32 : 16;
Expand Down Expand Up @@ -816,8 +817,8 @@ status_t conv_config_t::init_data_layouts(convolution_pd_t *conv_pd) {

// If src/dst is nhwc then set the other one with any to nhwc too (except
// 1st convolution).
bool is_small_ic_non_dw = is_small_ic() && !is_dw;
bool is_small_oc_non_dw = is_small_oc() && !is_dw;
bool is_small_ic_non_dw = is_small_ic() && g == 1;
bool is_small_oc_non_dw = is_small_oc() && g == 1;
bool propagate_nhwc = (matches_tag(src_md, "axb") && !is_small_ic_non_dw)
|| matches_tag(dst_md, "axb");
if (propagate_nhwc) {
Expand Down Expand Up @@ -1250,7 +1251,9 @@ void conv_config_t::init_bwd_d_optimize_strided(int iw_thr_blk) {

void conv_config_t::init_use_ow_kw_grf_cache() {
use_ow_kw_grf_cache = false;
if (!is_fwd || !is_small_ic() || kw < 3 || is_dw_large_mb()) return;
if (!is_fwd || !is_small_ic() || (is_small_ic() && !is_dw && g > 1)
|| kw < 3 || is_dw_large_mb())
return;
if (is_dp_fma()) return;
if (fuse_spatial) return;

Expand Down

0 comments on commit 7ba3c40

Please sign in to comment.