Skip to content

Commit

Permalink
Fix layout handling for fused input/output transposes
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanNovoselov committed Jun 13, 2024
1 parent 451fe36 commit 99983aa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ std::shared_ptr<BrgemmCompiledKernel> BrgemmKernelExecutor::compile_kernel(const
}

void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, std::shared_ptr<BrgemmKernelConfig>& config) const {
auto get_projected_subtensor = [](const snippets::lowered::PortDescriptorPtr& desc) {
auto shape = desc->get_shape();
auto get_projected_input_subtensor = [](const snippets::lowered::PortDescriptorPtr& desc) {
// Note: this helper is valid for input ports only; for an output port's shape use get_preordered_vdims() instead of get_planar_vdims()
auto shape = snippets::utils::get_planar_vdims(desc->get_shape(), desc->get_layout());
auto subtensor = desc->get_subtensor();
// Note: Scalar is a special case, so it's easier to prepend shapes than to handle it explicitly
if (shape.size() == 1)
Expand All @@ -133,15 +134,15 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
OV_CPU_JIT_EMITTER_ASSERT(input_pds.size() == 2 && output_pds.size() == 1, "Invalid number of in/out port descriptors");
// Update runtime-defined config fields:
// Matrix A (first input)
config->LDA = DIM_CAST(snippets::utils::get_out_leading_dim(input_pds[0]->get_shape(), m_layouts[0]));
config->K = DIM_CAST(*get_projected_subtensor(input_pds[0]).rbegin());
config->LDA = DIM_CAST(snippets::utils::get_in_leading_dim(input_pds[0]->get_shape(), m_layouts[0]));
const auto& in0_subtensor = get_projected_input_subtensor(input_pds[0]);
config->K = DIM_CAST(*in0_subtensor.rbegin());
config->M = DIM_CAST(*++in0_subtensor.rbegin());
// Matrix B (second input)
config->LDB = DIM_CAST(snippets::utils::get_out_leading_dim(input_pds[1]->get_shape(), m_layouts[1]));
config->LDB = DIM_CAST(snippets::utils::get_in_leading_dim(input_pds[1]->get_shape(), m_layouts[1]));
config->N = DIM_CAST(*get_projected_input_subtensor(input_pds[1]).rbegin());
// Matrix C (output)
config->LDC = DIM_CAST(snippets::utils::get_out_leading_dim(output_pds[0]->get_shape(), m_layouts[2]));
const auto& out_subtensor = get_projected_subtensor(output_pds[0]);
config->N = DIM_CAST(*out_subtensor.rbegin());
config->M = DIM_CAST(*++out_subtensor.rbegin());
}

void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* desc, call_args* args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
// N blocking is disabled by default when the N dimension is dynamic
if (N_dim.is_dynamic())
return snippets::utils::get_dynamic_value<size_t>();
return input_1_precision == ov::element::f32 || N_dim.is_dynamic() ? 64 : N_dim.get_length();
return input_1_precision == ov::element::f32 ? 64 : N_dim.get_length();
};

const auto brgemm_in0_dims = snippets::utils::get_planar_pshape(brgemm->input(0));
Expand Down

0 comments on commit 99983aa

Please sign in to comment.