diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index f0653523850..30c15286040 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -175,6 +175,21 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag( blocked_16n_B_layout_tag) : memory_desc_matches_one_of_tag(B_md, plain_tensor_layout_tag, transposed_tensor_layout_tag, acbd, adbc); + + // For cases when the weights tensor is transposed but has + // 'dim_size == 1', we can ignore transposition and compute as a plain + // format tensor. This removes the need of allocating a scratchpad for + // copy_B. + if (transposed_tensor_layout_tag == bgmmc.wei_tag) { + memory_desc_t B_md_plain; + const status_t status + = dnnl_memory_desc_init_by_tag(&B_md_plain, B_md.ndims, + B_md.dims, B_md.data_type, plain_tensor_layout_tag); + if (status != status::success) return status; + if (status == status::success && B_md_plain == B_md) + bgmmc.wei_tag = plain_tensor_layout_tag; + } + if (format_tag::undef == bgmmc.wei_tag) return status::unimplemented; }