Skip to content

Commit

Permalink
common: concat: make non-dim-1-over-axis md to decide on dst format
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin authored and vpirogov committed Oct 17, 2022
1 parent feb614d commit 2a60ade
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/common/concat_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,18 @@ struct concat_pd_t : public primitive_desc_t {
if (status != status::success) {
for (int i = 0; i < n_; ++i) {
const memory_desc_wrapper src_d(src_mds_[i]);
if (src_d.is_blocking_desc() && src_d.is_plain()
&& src_d.nelems() > 0) {
// Dim of `1` may tweak a destination format leading to
// sub-optimal performance. Limit it to an axis case to allow
// case like a:a->ab or a:ab->ab to work properly.
// TODO: update the whole logic to getting string tags of
// sources but discarding dims of one. If ndims of any source
// coincides with dst ndims, use that tag (if they are same).
// If dst has +1 ndim (due to concat dim), use slices as dense
// layers inside a dst, which means axis should be the least
// dense dimension.
const bool axis_dim_has_one = src_d.dims()[concat_dim()] == 1;
if (!axis_dim_has_one && src_d.is_blocking_desc()
&& src_d.is_plain() && src_d.nelems() > 0) {
status = memory_desc_init_by_blocking_desc(dst_md_,
memory_desc_wrapper(src_mds_[i]).blocking_desc());
if (status == status::success) return status;
Expand Down

0 comments on commit 2a60ade

Please sign in to comment.