Skip to content

Commit

Permalink
Alexandra comments apply
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed May 28, 2024
1 parent 266e760 commit ccbc9c7
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 213 deletions.
38 changes: 19 additions & 19 deletions src/common/snippets/include/snippets/lowered/loop_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,26 +297,13 @@ class UnifiedLoopInfo : public LoopInfo {
void replace_with_new_ports(const ExpressionPort& actual_port, const std::vector<ExpressionPort>& target_ports) override;

/**
* @brief Remove the current LoopPort that contains ExpressionPort.
* Note: If there is no LoopPort with ExpressionPort `ports`, does nothing.
* This function remove directly without respect ports order, caller is responsible for the order by sort.
* @param ports need to be removed
*/
void update_loop_ports(const std::vector<ExpressionPort>& actual_ports, const std::vector<ExpressionPort>& target_ports, bool is_input);
/**
* @brief Remove the current LoopPort that contains ExpressionPort.
* Note: If there is no LoopPort with ExpressionPort `ports`, does nothing.
* This function remove directly without respect ports order, caller is responsible for the order by sort.
* @param ports need to be removed
* @brief Remove remove_ports and add add_ports to the current LoopPort.
* This function remove and add directly without respect ports order, caller is responsible for the order by sort.
* @param remove_ports need to be removed
* @param add_ports need to be added
* @param is_input true if update input port
*/
void remove_loop_ports(const std::vector<ExpressionPort>& ports, bool is_input);
/**
* @brief Add ports to the current LoopPort.
* This function add in back of current LoopPort without respect ports order, caller is responsible for the order by sort.
* @param ports need to be added
*/
void add_loop_ports(const std::vector<ExpressionPort>& ports, bool is_input);

void update_loop_ports(const std::vector<ExpressionPort>& remove_ports, const std::vector<ExpressionPort>& add_ports, bool is_input);
/**
* @brief Iterates through all LoopPortDesc and call `caller` for each of them
* @param caller - function that called for each LoopPortDesc
Expand Down Expand Up @@ -367,6 +354,19 @@ class UnifiedLoopInfo : public LoopInfo {
* - Consistency of ports and descriptors
*/
void validate() const;
/**
* @brief Remove the current LoopPort that contains ExpressionPort.
* Note: If there is no LoopPort with ExpressionPort `ports`, does nothing.
* This function remove directly without respect ports order, caller is responsible for the order by sort.
* @param ports need to be removed
*/
void remove_loop_ports(const std::vector<ExpressionPort>& ports, bool is_input);
/**
* @brief Add ports to the current LoopPort.
* This function add in back of current LoopPort without respect ports order, caller is responsible for the order by sort.
* @param ports need to be added
*/
void add_loop_ports(const std::vector<ExpressionPort>& ports, bool is_input);

SpecificIterationHandlers m_handlers = {};
std::vector<LoopPortDesc> m_input_port_descs = {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include "pass.hpp"
#include "snippets/lowered/loop_manager.hpp"

namespace ov {
namespace snippets {
Expand All @@ -16,11 +17,19 @@ namespace pass {
* @brief Extract the exprs that produce identical result in loop iteration to outside of loop
* @ingroup snippets
*/
class ExtractLoopInvariants : public Pass {
class ExtractLoopInvariants : public RangedPass {
public:
OPENVINO_RTTI("ExtractLoopInvariants", "Pass")
ExtractLoopInvariants() = default;
bool run(LinearIR& linear_ir) override;
OPENVINO_RTTI("ExtractLoopInvariants", "RangedPass")
ExtractLoopInvariants();
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

private:
bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info);
void extract_expr(const ExpressionPtr& expr, LinearIR& linear_ir,
LinearIR::constExprIt& inner_loop_begin_pos, LinearIR::constExprIt& inner_loop_end_pos);
void update_loop_ports(const ExpressionPtr& expr, const LoopManagerPtr& loop_manager, size_t inner_loop_id,
LinearIR::constExprIt& inner_loop_begin_pos, LinearIR::constExprIt& inner_loop_end_pos);
bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir);
};

} // namespace pass
Expand Down
Loading

0 comments on commit ccbc9c7

Please sign in to comment.