Skip to content

Commit

Permalink
extract invariants pass
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Apr 29, 2024
1 parent 2ce0e4a commit 4e018b8
Show file tree
Hide file tree
Showing 6 changed files with 449 additions and 17 deletions.
23 changes: 23 additions & 0 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,29 @@ class LoopManager {
* @param expr the target expression
*/
void update_loop_ports(const ExpressionPtr& expr);
/**
* @brief Insert Loop ports for one Loop.
* The method inserts ports at end of loop ports. It does not respect ports order, so sort_loop_ports is recommended where necessary.
* @param loop_id the target Loop ID
* @param target_ports vector of the ports need insert
* @param is_entry True if these ports are input and insert to loop entry points, otherwise - output and insert to exit point
*/
void insert_loop_ports(size_t loop_id, const std::vector<ExpressionPort>& target_ports, bool is_entry = true);
/**
* @brief Delete Loop ports for one Loop.
* The method delete ports directly. It does not respect ports order, so sort_loop_ports is recommended where necessary.
* @param loop_id the target Loop ID
* @param target_ports vector of the ports need delete
* @param is_entry True if these ports are input and delete from loop entry points, otherwise - output and delete from exit point
*/
void delete_loop_ports(size_t loop_id, const std::vector<ExpressionPort>& target_ports, bool is_entry = true);
/**
* @brief Check if a expression port is in Loop Ports.
* @param loop_ports the Loop Ports
* @param target_port the Expression Port
* @return True if Expression Port is in Loop Ports, otherwise return false.
*/
bool is_loop_port(const std::vector<LoopPort>& loop_ports, const ExpressionPort& target_port);
/**
* @brief Sort Loop Ports by expression locations in Linear IR
* @param loop_begin_pos the first expression iterator of the Loop
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface ExtractLoopInvariants
* @brief Extract the exprs that produce identical result in loop iteration to outside of loop
* @ingroup snippets
*/
class ExtractLoopInvariants : public Pass {
public:
OPENVINO_RTTI("ExtractLoopInvariants", "Pass")
ExtractLoopInvariants() = default;
bool run(LinearIR& linear_ir) override;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
38 changes: 37 additions & 1 deletion src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ std::pair<LinearIR::constExprIt, LinearIR::constExprIt> LoopManager::get_loop_bo
const std::vector<LoopPort>& entries,
const std::vector<LoopPort>& exits) {
OPENVINO_ASSERT(!entries.empty(), "Loop must have entry points");
OPENVINO_ASSERT(!exits.empty(), "Loop must have entry points");
OPENVINO_ASSERT(!exits.empty(), "Loop must have exit points");

const auto& entry_expr = entries.front().expr_port->get_expr();
auto loop_begin_pos = linear_ir.find(entry_expr);
Expand Down Expand Up @@ -420,6 +420,42 @@ void LoopManager::update_loop_ports(const ExpressionPtr& expr) {
}
}

void LoopManager::insert_loop_ports(size_t loop_id, const std::vector<ExpressionPort>& target_ports, bool is_entry) {
const auto& loop_info = get_loop_info(loop_id);
auto ports = is_entry ? loop_info->get_entry_points() : loop_info->get_exit_points();
for (size_t i = 0; i < target_ports.size(); i++) {
// if already in loop ports, skip
const auto& target_port = target_ports[i];
if (is_loop_port(ports, target_port))
continue;

ports.push_back(target_port);
}
is_entry ? loop_info->set_entry_points(ports) : loop_info->set_exit_points(ports);
}

void LoopManager::delete_loop_ports(size_t loop_id, const std::vector<ExpressionPort>& target_ports, bool is_entry) {
const auto& loop_info = get_loop_info(loop_id);
auto ports = is_entry ? loop_info->get_entry_points() : loop_info->get_exit_points();
for (size_t i = 0; i < target_ports.size(); i++) {
// if not in loop ports, skip
const auto& target_port = target_ports[i];
auto port_it = std::find_if(ports.begin(), ports.end(),
[&target_port](const LoopPort& point) { return *point.expr_port.get() == target_port; });
if (port_it == ports.end())
continue;

ports.erase(port_it);
}
is_entry ? loop_info->set_entry_points(ports) : loop_info->set_exit_points(ports);
}

bool LoopManager::is_loop_port(const std::vector<LoopPort>& loop_ports, const ExpressionPort& target_port) {
auto port_it = std::find_if(loop_ports.begin(), loop_ports.end(),
[&](const LoopPort& point) { return *point.expr_port.get() == target_port; });
return port_it != loop_ports.end();
}

void LoopManager::expression_replacement(LinearIR::constExprIt new_expr_begin, LinearIR::constExprIt new_expr_end, const ExpressionPtr& decomposed_expr,
size_t loop_id, const std::vector<ExpressionPort>& entries, const std::vector<ExpressionPort>& exits) {
for (auto it = new_expr_begin; it!= new_expr_end; ++it) {
Expand Down
Loading

0 comments on commit 4e018b8

Please sign in to comment.