Skip to content

Commit

Permalink
add lower pass test to compare graph and loop
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed May 31, 2024
1 parent c575707 commit 015a5ae
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/common/snippets/tests/include/lir_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,24 @@ void create_and_add_unified_loop_info(const std::shared_ptr<ov::snippets::lowere
const std::vector<ov::snippets::lowered::LoopPort>& entries,
const std::vector<ov::snippets::lowered::LoopPort>& exits,
bool add_default_handlers = true);
/**
* @brief Creates unified loop info based on provided entry and exit points, and adds it to the linear_ir's loops map.
* Meanwhile set loop id to expr range [loop_begin_pos, loop_end_pos).
* @attention This helper wraps LoopManager::mark_loop method, which also marks expressions with the corresponding loop info
* @param linear_ir linear_ir in which loop info should be added
* @param loop_begin_pos begin expr postion in this loop
* @param loop_end_pos end expr postion in this loop
* @param entries entry points of loop
* @param exits exit points of loop
*/
void create_and_add_unified_loop_info(const std::shared_ptr<ov::snippets::lowered::LinearIR>& linear_ir,
ov::snippets::lowered::LinearIR::constExprIt loop_begin_pos,
ov::snippets::lowered::LinearIR::constExprIt loop_end_pos,
size_t work_amount,
size_t increment,
const std::vector<ov::snippets::lowered::LoopPort>& entries,
const std::vector<ov::snippets::lowered::LoopPort>& exits,
bool add_default_handlers = true);
} // namespace snippets
} // namespace test
} // namespace ov
12 changes: 12 additions & 0 deletions src/common/snippets/tests/src/lir_test_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ void create_and_add_unified_loop_info(const LinearIRPtr& linear_ir,
loop_manager->mark_loop(linear_ir->begin(), linear_ir->begin(), work_amount, increment, entries, exits, set_default_handlers);
}

void create_and_add_unified_loop_info(const LinearIRPtr& linear_ir,
ov::snippets::lowered::LinearIR::constExprIt loop_begin_pos,
ov::snippets::lowered::LinearIR::constExprIt loop_end_pos,
size_t work_amount,
size_t increment,
const std::vector<LoopPort>& entries,
const std::vector<LoopPort>& exits,
bool set_default_handlers) {
const auto& loop_manager = linear_ir->get_loop_manager();
loop_manager->mark_loop(loop_begin_pos, loop_end_pos, work_amount, increment, entries, exits, set_default_handlers);
}

} // namespace snippets
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "lir_test_utils.hpp"

#include "openvino/opsets/opset10.hpp"
#include "snippets/lowered/pass/extract_loop_invariants.hpp"

namespace ov {
namespace test {
namespace snippets {

using namespace ov::snippets::lowered;
using namespace ov::snippets::lowered::pass;

class ExtractLoopInvariantsTest : public LoweredPassTestsF {
public:
ExtractLoopInvariantsTest() : LoweredPassTestsF() {
comparator.enable(LIRComparator::LIRCmpValues::LOOP_INDICES);
comparator.enable(LIRComparator::LIRCmpValues::PORT_DESCRIPTORS);
comparator.enable(LIRComparator::LIRCmpValues::PORT_CONNECTORS);
comparator.enable(LIRComparator::LIRCmpValues::LOOP_MANAGER);
}

void SetUp() override {
pipeline.register_pass<ExtractLoopInvariants>();
}
};

TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariants) {
size_t vector_size = 16;
const auto input_precision = ov::element::f32;
const ov::Shape scalar_shape{1};
const ov::Shape input_shape0{1};
const ov::Shape input_shape1{512};
/*
* Param0 Scalar
* \ /
* Multiply(loopBegin)
* |
* Broadcast Param1
* \ /
* Substract(loopBeginRef)
* |
* Store
* |
* Result(LoopEnd and LoopEndRef)
*/
{
auto param0 = linear_ir->push_node<ov::opset10::Parameter>(input_precision, input_shape0);
auto param1 = linear_ir->push_node<ov::opset10::Parameter>(input_precision, input_shape1);
auto scalar = linear_ir->push_node<ov::snippets::op::Scalar>(input_precision, scalar_shape, 3.8f);
auto multiply = linear_ir->push_node<ov::opset10::Multiply>(param0.second, scalar.second);
init_expr_descriptors(*multiply.first, {{1}, {1}, {1}}, {{0}, {0}, {0}});
auto broadcastmove = linear_ir->push_node<ov::snippets::op::BroadcastMove>(multiply.second, 256);
auto sub = linear_ir->push_node<ov::opset10::Subtract>(param1.second, broadcastmove.second);
auto result = linear_ir->push_node<ov::opset10::Result>(sub.second);
init_expr_descriptors(*sub.first, {{512}, {1}, {512}}, {{0}, {0}, {0}});
auto begin = linear_ir->find(*scalar.first);
auto end = linear_ir->find(*result.first);
create_and_add_unified_loop_info(linear_ir, begin, end, 512, vector_size,
{LoopPort((*multiply.first)->get_input_port(0)), LoopPort((*sub.first)->get_input_port(0))},
{LoopPort((*sub.first)->get_output_port(0))});
}
{
auto param0 = linear_ir_ref->push_node<ov::opset10::Parameter>(input_precision, input_shape0);
auto param1 = linear_ir_ref->push_node<ov::opset10::Parameter>(input_precision, input_shape1);
auto scalar = linear_ir_ref->push_node<ov::snippets::op::Scalar>(input_precision, scalar_shape, 3.8f);
auto multiply = linear_ir_ref->push_node<ov::opset10::Multiply>(param0.second, scalar.second);
init_expr_descriptors(*multiply.first, {{1}, {1}, {1}}, {{0}, {0}, {0}});
auto broadcastmove = linear_ir_ref->push_node<ov::snippets::op::BroadcastMove>(multiply.second, 256);
auto sub = linear_ir_ref->push_node<ov::opset10::Subtract>(param1.second, broadcastmove.second);
auto result = linear_ir_ref->push_node<ov::opset10::Result>(sub.second);
init_expr_descriptors(*sub.first, {{512}, {1}, {512}}, {{0}, {0}, {0}});
auto begin = linear_ir_ref->find(*sub.first);
auto end = linear_ir_ref->find(*result.first);
create_and_add_unified_loop_info(linear_ir_ref, begin, end, 512, vector_size,
{LoopPort((*sub.first)->get_input_port(0)), LoopPort((*sub.first)->get_input_port(1))},
{LoopPort((*sub.first)->get_output_port(0))});
}
}
} // namespace snippets
} // namespace test
} // namespace ov

0 comments on commit 015a5ae

Please sign in to comment.