Skip to content

Commit

Permalink
gpu: jit: conv: normalize group layouts
Browse files Browse the repository at this point in the history
  • Loading branch information
dyoussif authored and vpirogov committed Oct 21, 2022
1 parent 7ba3c40 commit 4e84474
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/gpu/jit/conv/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ status_t conv_config_t::init_data_layouts(convolution_pd_t *conv_pd) {
// Normalize layouts: add group dimension for all layouts and reduce/fuse
// spatial dimensions when applicable.
normalize_conv_layouts(src_layout, wei_layout, dst_layout, bia_layout,
with_groups, g, is_dw, reduced_dim, fuse_spatial,
with_groups, g, ic, oc, is_dw, reduced_dim, fuse_spatial,
/*add_groups=*/true);

return status::success;
Expand Down
14 changes: 8 additions & 6 deletions src/gpu/jit/conv/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,20 +719,22 @@ std::vector<dim_t> normalize_conv_dims(std::vector<dim_t> &dims,
}

void normalize_conv_layouts(layout_t &src_layout, layout_t &wei_layout,
layout_t &dst_layout, layout_t &bia_layout, bool with_groups,
int groups, bool is_dw, int reduced_dim, bool fuse_spatial,
layout_t &dst_layout, layout_t &bia_layout, bool with_groups, int g,
int ic, int oc, bool is_dw, int reduced_dim, bool fuse_spatial,
bool add_groups) {
src_layout = normalize_conv_layout(src_layout, /*with_groups=*/false,
groups, is_dw, reduced_dim, fuse_spatial, add_groups,
g > 1 ? src_layout.dim(1) / ic : 1, is_dw, reduced_dim,
fuse_spatial, add_groups,
/*is_wei=*/false);
wei_layout = normalize_conv_layout(wei_layout, with_groups, groups, is_dw,
wei_layout = normalize_conv_layout(wei_layout, with_groups, g, is_dw,
reduced_dim, /*fuse_spatial=*/false, add_groups, /*is_wei=*/true);
dst_layout = normalize_conv_layout(dst_layout, /*with_groups=*/false,
groups, is_dw, reduced_dim, fuse_spatial, add_groups,
g > 1 ? dst_layout.dim(1) / oc : 1, is_dw, reduced_dim,
fuse_spatial, add_groups,
/*is_wei=*/false);
if (add_groups && !bia_layout.is_empty()) {
ir_assert(bia_layout.ndims() == 1) << bia_layout;
bia_layout = split_dimension(bia_layout, 0, groups);
bia_layout = split_dimension(bia_layout, 0, g);
}
}

Expand Down
11 changes: 6 additions & 5 deletions src/gpu/jit/conv/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1953,16 +1953,17 @@ layout_t normalize_conv_layout(const layout_t &_layout, bool with_groups,
bool add_groups, bool is_wei);

void normalize_conv_layouts(layout_t &src_layout, layout_t &wei_layout,
layout_t &dst_layout, layout_t &bia_layout, bool with_groups,
int groups, bool is_dw, int reduced_dim, bool fuse_spatial,
layout_t &dst_layout, layout_t &bia_layout, bool with_groups, int g,
int ic, int oc, bool is_dw, int reduced_dim, bool fuse_spatial,
bool add_groups);

inline void normalize_conv_layouts(layout_t &src_layout, layout_t &wei_layout,
layout_t &dst_layout, bool with_groups, int groups, bool is_dw,
int reduced_dim, bool fuse_spatial, bool add_groups) {
layout_t &dst_layout, bool with_groups, int g, int ic, int oc,
bool is_dw, int reduced_dim, bool fuse_spatial, bool add_groups) {
layout_t bia_layout;
normalize_conv_layouts(src_layout, wei_layout, dst_layout, bia_layout,
with_groups, groups, is_dw, reduced_dim, fuse_spatial, add_groups);
with_groups, g, ic, oc, is_dw, reduced_dim, fuse_spatial,
add_groups);
}

} // namespace jit
Expand Down

0 comments on commit 4e84474

Please sign in to comment.