Skip to content

Commit

Permalink
apply Alexandra comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed May 17, 2024
1 parent 38b21b4 commit 253ba95
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 367 deletions.
36 changes: 29 additions & 7 deletions src/common/snippets/include/snippets/lowered/loop_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ class LoopInfo {
return get_type_info().name;
}

/**
* @brief Find LoopPort in input and output ports
* @param loop_port target port
* @return iterator of the corresponding collection
*/
template<typename T>
std::vector<LoopPort>::iterator find_loop_port(const T& loop_port);

protected:
/**
* @brief Helper to clone Loop ports using `ExpressionMap`
Expand All @@ -148,13 +156,6 @@ class LoopInfo {
* @return vector with new cloned loop ports
*/
static std::vector<LoopPort> clone_loop_ports(const ExpressionMap& expr_map, const std::vector<LoopPort>& loop_ports);
/**
* @brief Find LoopPort in input and output ports
* @param loop_port target port
* @return iterator of the corresponding collection
*/
template<typename T>
std::vector<LoopPort>::iterator find_loop_port(const T& loop_port);

size_t m_work_amount = 0;
size_t m_increment = 0;
Expand Down Expand Up @@ -295,6 +296,27 @@ class UnifiedLoopInfo : public LoopInfo {
*/
void replace_with_new_ports(const ExpressionPort& actual_port, const std::vector<ExpressionPort>& target_ports) override;

/**
* @brief Remove the current LoopPort that contains ExpressionPort.
* Note: If there is no LoopPort with ExpressionPort `ports`, does nothing.
* This function remove directly without respect ports order, caller is responsible for the order by sort.
* @param ports need to be removed
*/
void update_loop_ports(const std::vector<ExpressionPort>& actual_ports, const std::vector<ExpressionPort>& target_ports, bool is_input);
/**
* @brief Remove the current LoopPort that contains ExpressionPort.
* Note: If there is no LoopPort with ExpressionPort `ports`, does nothing.
* This function remove directly without respect ports order, caller is responsible for the order by sort.
* @param ports need to be removed
*/
void remove_loop_ports(const std::vector<ExpressionPort>& ports, bool is_input);
/**
* @brief Add ports to the current LoopPort.
* This function add in back of current LoopPort without respect ports order, caller is responsible for the order by sort.
* @param ports need to be added
*/
void add_loop_ports(const std::vector<ExpressionPort>& ports, bool is_input);

/**
* @brief Iterates through all LoopPortDesc and call `caller` for each of them
* @param caller - function that called for each LoopPortDesc
Expand Down
34 changes: 6 additions & 28 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,29 +210,6 @@ class LoopManager {
* @param expr the target expression
*/
void update_loop_ports(const ExpressionPtr& expr);
/**
* @brief Insert Loop ports for one Loop.
* The method inserts ports at end of loop ports. It does not respect ports order, so sort_loop_ports is recommended where necessary.
* @param loop_id the target Loop ID
* @param target_ports vector of the ports need insert
* @param is_entry True if these ports are input and insert to loop entry points, otherwise - output and insert to exit point
*/
void insert_loop_ports(size_t loop_id, const std::vector<ExpressionPort>& target_ports, bool is_entry = true);
/**
* @brief Delete Loop ports for one Loop.
* The method delete ports directly. It does not respect ports order, so sort_loop_ports is recommended where necessary.
* @param loop_id the target Loop ID
* @param target_ports vector of the ports need delete
* @param is_entry True if these ports are input and delete from loop entry points, otherwise - output and delete from exit point
*/
void delete_loop_ports(size_t loop_id, const std::vector<ExpressionPort>& target_ports, bool is_entry = true);
/**
* @brief Check if a expression port is in Loop Ports.
* @param loop_ports the Loop Ports
* @param target_port the Expression Port
* @return True if Expression Port is in Loop Ports, otherwise return false.
*/
bool is_loop_port(const std::vector<LoopPort>& loop_ports, const ExpressionPort& target_port);
/**
* @brief Sort Unified Loop Ports by expression locations in Linear IR
* @param loop_begin_pos the first expression iterator of the Loop
Expand Down Expand Up @@ -297,18 +274,19 @@ class LoopManager {
*/
bool reorder_identifiers(const std::map<size_t, size_t>& loop_id_map);

/**
* @brief Remove LoopInfo from the map
* @param index the target index of Loop
*/
void remove_loop_info(size_t index);

private:
/**
* @brief Add new Loop Info to the map
* @param loop target loop info
* @return the loop ID
*/
size_t add_loop_info(const LoopInfoPtr& loop);
/**
* @brief Remove LoopInfo from the map
* @param index the target index of Loop
*/
void remove_loop_info(size_t index);
/**
* @brief Find expression ports in bounds that are connected to consumers or parent that aren't in these bounds
* @param loop_begin_pos the first expression iterator of the Loop
Expand Down
35 changes: 35 additions & 0 deletions src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,41 @@ void UnifiedLoopInfo::replace_with_new_ports(const ExpressionPort& actual_port,
validate();
}

void UnifiedLoopInfo::update_loop_ports(const std::vector<ExpressionPort>& actual_ports, const std::vector<ExpressionPort>& target_ports, bool is_input) {
if (!actual_ports.empty())
remove_loop_ports(actual_ports, is_input);
if (!target_ports.empty())
add_loop_ports(target_ports, is_input);
validate();
}

void UnifiedLoopInfo::remove_loop_ports(const std::vector<ExpressionPort>& ports, bool is_input) {
auto& loop_ports = is_input ? m_input_ports : m_output_ports;
auto& loop_ports_desc = is_input ? m_input_port_descs : m_output_port_descs;
for (size_t i = 0; i < ports.size(); i++) {
auto port_it = find_loop_port(ports[i]);
// if not in loop ports, skip
if (port_it == loop_ports.end())
continue;

loop_ports.erase(port_it);
auto dist = std::distance(loop_ports.begin(), port_it);
loop_ports_desc.erase(loop_ports_desc.begin() + dist);
}
}

void UnifiedLoopInfo::add_loop_ports(const std::vector<ExpressionPort>& ports, bool is_input) {
auto& loop_ports = is_input ? m_input_ports : m_output_ports;
auto& loop_ports_desc = is_input ? m_input_port_descs : m_output_port_descs;
for (size_t i = 0; i < ports.size(); i++) {
// if already in loop ports, skip
if (find_loop_port(ports[i]) != loop_ports.end())
continue;
loop_ports.push_back(ports[i]);
loop_ports_desc.push_back(LoopPortDesc());
}
}

ExpandedLoopInfo::ExpandedLoopInfo(size_t work_amount, size_t increment,
const std::vector<LoopPort>& entries, const std::vector<LoopPort>& exits,
std::vector<int64_t> ptr_increments, std::vector<int64_t> final_offsets, std::vector<int64_t> data_sizes,
Expand Down
36 changes: 0 additions & 36 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,42 +373,6 @@ void LoopManager::update_loop_ports(const ExpressionPtr& expr) {
}
}

void LoopManager::insert_loop_ports(size_t loop_id, const std::vector<ExpressionPort>& target_ports, bool is_entry) {
const auto& loop_info = get_loop_info(loop_id);
auto ports = is_entry ? loop_info->get_entry_points() : loop_info->get_exit_points();
for (size_t i = 0; i < target_ports.size(); i++) {
// if already in loop ports, skip
const auto& target_port = target_ports[i];
if (is_loop_port(ports, target_port))
continue;

ports.push_back(target_port);
}
is_entry ? loop_info->set_entry_points(ports) : loop_info->set_exit_points(ports);
}

void LoopManager::delete_loop_ports(size_t loop_id, const std::vector<ExpressionPort>& target_ports, bool is_entry) {
const auto& loop_info = get_loop_info(loop_id);
auto ports = is_entry ? loop_info->get_entry_points() : loop_info->get_exit_points();
for (size_t i = 0; i < target_ports.size(); i++) {
// if not in loop ports, skip
const auto& target_port = target_ports[i];
auto port_it = std::find_if(ports.begin(), ports.end(),
[&target_port](const LoopPort& point) { return *point.expr_port.get() == target_port; });
if (port_it == ports.end())
continue;

ports.erase(port_it);
}
is_entry ? loop_info->set_entry_points(ports) : loop_info->set_exit_points(ports);
}

bool LoopManager::is_loop_port(const std::vector<LoopPort>& loop_ports, const ExpressionPort& target_port) {
auto port_it = std::find_if(loop_ports.begin(), loop_ports.end(),
[&](const LoopPort& point) { return *point.expr_port.get() == target_port; });
return port_it != loop_ports.end();
}

void LoopManager::expression_replacement(LinearIR::constExprIt new_expr_begin, LinearIR::constExprIt new_expr_end, const ExpressionPtr& decomposed_expr,
size_t loop_id, const std::vector<ExpressionPort>& entries, const std::vector<ExpressionPort>& exits) {
for (auto it = new_expr_begin; it!= new_expr_end; ++it) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ bool DefineBufferClusters::are_buffer_neighbours(const ExpressionPtr& up, const

void DefineBufferClusters::parse_memory_access_op(const ExpressionPtr& expr) {
const auto ma = std::dynamic_pointer_cast<modifier::MemoryAccess>(expr->get_node());
if (!ma->is_full_memory_access_op(expr->get_node()))
return;
// TODO: Some full MemoryAccess ops can have inplace inputs and outputs in general.
// Need to add mechanism of inplace ports using MemoryAccess::PortDescriptor::inplace
for (const auto& input : expr->get_input_port_connectors()) {
Expand Down
Loading

0 comments on commit 253ba95

Please sign in to comment.