Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Snippets] extract invariants pass #24044

Merged
36 changes: 36 additions & 0 deletions src/common/snippets/include/snippets/lowered/loop_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,16 @@ class LoopInfo {
const char* get_type_name() const {
return get_type_info().name;
}
/**
* @brief Return true if expression port is a loop port
* @param expr_port - expression port to check
*/
bool is_loop_port(const ExpressionPort& expr_port);
/**
* @brief Return loop port of an expression port
* @param expr_port - expression port.
*/
const LoopPort& get_loop_port(const ExpressionPort& expr_port);
IvanNovoselov marked this conversation as resolved.
Show resolved Hide resolved

protected:
/**
Expand Down Expand Up @@ -324,6 +334,15 @@ class UnifiedLoopInfo : public LoopInfo {
*/
void replace_with_new_ports(const ExpressionPort& actual_port, const std::vector<ExpressionPort>& target_ports) override;

/**
* @brief Remove remove_ports and add add_ports to the current LoopPort.
* This function removes ports directly and adds ports at the end of current LoopPort, caller is responsible to
* sort the LoopPort after LoopPort being updated according to execution order of the expressions.
* Note: all port in remove_ports and add_ports should have the same type.
* @param remove_ports need to be removed
* @param add_ports need to be added
*/
void update_loop_ports(const std::vector<ExpressionPort>& remove_ports, const std::vector<ExpressionPort>& add_ports);
/**
* @brief Iterates through all LoopPortDesc and call `caller` for each of them
* @param caller - function that called for each LoopPortDesc
Expand Down Expand Up @@ -374,6 +393,23 @@ class UnifiedLoopInfo : public LoopInfo {
* - Consistency of ports and descriptors
*/
void validate() const;
/**
* @brief Remove the current LoopPort that contains ExpressionPort.
* Note: If there is no LoopPort with ExpressionPort `ports`, does nothing.
* This function removes ports directly, caller is responsible to sort the LoopPort after updated
* according to execution order of the expressions.
* Note: all port in ports should have the same type.
* @param ports need to be removed
*/
void remove_loop_ports(const std::vector<ExpressionPort>& ports);
/**
* @brief Add ports to the current LoopPort.
* This function adds ports in end of current LoopPort vector, caller is responsible to
* sort the LoopPort after updated according to execution order of the expressions.
* Note: all port in ports should have the same type.
* @param ports need to be added
*/
void add_loop_ports(const std::vector<ExpressionPort>& ports);

SpecificIterationHandlers m_handlers = {};
std::vector<LoopPortDesc> m_input_port_descs = {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
//

#pragma once

#include <openvino/core/node.hpp>
#include <openvino/opsets/opset1.hpp>

Expand Down Expand Up @@ -218,7 +217,7 @@ class LoopManager {
* @param loop_end_pos the next iterator after the last expression
* @param loop_id target Loop ID
*/
void sort_loop_ports(LinearIR::constExprIt& loop_begin_pos, LinearIR::constExprIt& loop_end_pos, size_t loop_id);
void sort_loop_ports(const LinearIR::constExprIt& loop_begin_pos, const LinearIR::constExprIt& loop_end_pos, size_t loop_id);
/**
* @brief When the previous expression was replaced with new expressions (decomposition), the method updates the corresponding Loop.
* If ports of decomposed expression were the Loop ports, these Loop ports may be updated by parameters `entries` and `exits`
Expand Down Expand Up @@ -276,7 +275,6 @@ class LoopManager {
*/
bool reorder_identifiers(const std::map<size_t, size_t>& loop_id_map);

private:
/**
* @brief Add new Loop Info to the map
* @param loop target loop info
Expand All @@ -288,6 +286,8 @@ class LoopManager {
* @param index the target index of Loop
*/
void remove_loop_info(size_t index);

private:
/**
* @brief Find expression ports in bounds that are connected to consumers or parent that aren't in these bounds
* @param loop_begin_pos the first expression iterator of the Loop
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"
#include "snippets/lowered/loop_manager.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface ExtractLoopInvariants
* @brief Extracts expressions that produce identical result on every loop iteration outside of the loop's body.
* This extraction is to remove repeated computation, not cover constant subgraph extraction.
* @ingroup snippets
*/
class ExtractLoopInvariants : public RangedPass {
public:
OPENVINO_RTTI("ExtractLoopInvariants", "RangedPass")
ExtractLoopInvariants() = default;
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
4 changes: 4 additions & 0 deletions src/common/snippets/include/snippets/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ inline size_t get_output_dim_idx(const std::vector<size_t>& layout, size_t dim_i
// dim_idx starts from the layout end
size_t get_dim_idx(const lowered::ExpressionPort& port, size_t dim_idx);

// get stride on dimenison of dim_idx
// given shape [a,b,c,d], the stride is [b*c*d, c*d, d, 1]
int64_t get_stride(size_t dim_idx, const VectorDims& shape);

/* ----- Shape `getters` ----- */
/**
* @brief Returns a dense shape after applying the order.
Expand Down
60 changes: 58 additions & 2 deletions src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ std::vector<LoopPort>::iterator LoopInfo::find_loop_port(const LoopPort& loop_po
auto& ports = loop_port.expr_port->get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
const auto it = std::find_if(ports.begin(), ports.end(),
[&loop_port](const LoopPort& port) { return port == loop_port; });
OPENVINO_ASSERT(it != ports.end(), "Failed update_loop_port: existing loop port has not been found");
OPENVINO_ASSERT(it != ports.end(), "Failed find_loop_port: existing loop port has not been found");
return it;
}

Expand All @@ -110,6 +110,17 @@ std::vector<LoopPort>::iterator LoopInfo::find_loop_port(const ExpressionPort& e
return it;
}

bool LoopInfo::is_loop_port(const ExpressionPort& expr_port) {
const auto& loop_port_it = find_loop_port(expr_port);
const auto& ports = expr_port.get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
IvanNovoselov marked this conversation as resolved.
Show resolved Hide resolved
return loop_port_it != ports.end();
}

const LoopPort& LoopInfo::get_loop_port(const ExpressionPort& expr_port) {
OPENVINO_ASSERT(is_loop_port(expr_port), "Failed get_loop_port: expr_port is not a loop port");
return *find_loop_port(expr_port);
}

void LoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector<LoopPort>& target_ports) {
auto& ports = actual_port.expr_port->get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
auto port_it = find_loop_port(actual_port);
Expand Down Expand Up @@ -259,7 +270,7 @@ void order(const std::vector<size_t>& new_order, std::vector<T>& values) {
"Failed to sort values: `new_order` must contain new indexes for ALL values");
std::vector<T> ordered_values(values.size());
for (size_t i = 0; i < values.size(); ++i) {
ordered_values[new_order[i]] = values[i];
ordered_values[i] = values[new_order[i]];
chenhu-wang marked this conversation as resolved.
Show resolved Hide resolved
}
values = std::move(ordered_values);
}
Expand Down Expand Up @@ -314,6 +325,51 @@ void UnifiedLoopInfo::replace_with_new_ports(const ExpressionPort& actual_port,
validate();
}

void UnifiedLoopInfo::update_loop_ports(const std::vector<ExpressionPort>& actual_ports, const std::vector<ExpressionPort>& target_ports) {
add_loop_ports(target_ports);
remove_loop_ports(actual_ports);
validate();
}

void UnifiedLoopInfo::remove_loop_ports(const std::vector<ExpressionPort>& ports) {
if (ports.empty())
return;
bool is_input = ports[0].get_type() == ExpressionPort::Input;
IvanNovoselov marked this conversation as resolved.
Show resolved Hide resolved
auto& loop_ports = is_input ? m_input_ports : m_output_ports;
auto& loop_ports_desc = is_input ? m_input_port_descs : m_output_port_descs;
for (size_t i = 0; i < ports.size(); i++) {
OPENVINO_ASSERT(is_input ? (ports[i].get_type() == ExpressionPort::Input) : (ports[i].get_type() == ExpressionPort::Output),
"ports in remove_loop_ports have different type.");
auto port_it = find_loop_port(ports[i]);
// if not in loop ports, skip
if (port_it == loop_ports.end())
continue;

loop_ports.erase(port_it);
auto dist = std::distance(loop_ports.begin(), port_it);
loop_ports_desc.erase(loop_ports_desc.begin() + dist);
}
}

void UnifiedLoopInfo::add_loop_ports(const std::vector<ExpressionPort>& ports) {
if (ports.empty())
return;
bool is_input = ports[0].get_type() == ExpressionPort::Input;
IvanNovoselov marked this conversation as resolved.
Show resolved Hide resolved
auto& loop_ports = is_input ? m_input_ports : m_output_ports;
auto& loop_ports_desc = is_input ? m_input_port_descs : m_output_port_descs;
size_t loop_dim_idx = get_dim_idx();
for (size_t i = 0; i < ports.size(); i++) {
OPENVINO_ASSERT(is_input ? (ports[i].get_type() == ExpressionPort::Input) : (ports[i].get_type() == ExpressionPort::Output),
"ports in add_loop_ports have different type.");
// if already in loop ports, skip
auto loop_port = find_loop_port(ports[i]);
if (loop_port != loop_ports.end())
continue;
loop_ports.push_back(LoopPort(ports[i], true, loop_dim_idx));
loop_ports_desc.push_back(LoopPortDesc());
}
}

ExpandedLoopInfo::ExpandedLoopInfo(size_t work_amount, size_t increment,
const std::vector<LoopPort>& entries, const std::vector<LoopPort>& exits,
std::vector<int64_t> ptr_increments, std::vector<int64_t> final_offsets, std::vector<int64_t> data_sizes,
Expand Down
11 changes: 2 additions & 9 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,8 @@ std::pair<LinearIR::constExprIt, LinearIR::constExprIt> LoopManager::get_loop_bo
}

LoopPort LoopManager::get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id) {
auto get_loop_port = [&](const std::vector<LoopPort>& ports) {
auto it = std::find_if(ports.cbegin(), ports.cend(), [&](const LoopPort& p) { return *p.expr_port == expr_port; });
if (it == ports.cend())
OPENVINO_THROW("Expression has not been found among loop ports. Loop id: " + std::to_string(loop_id));
return *it;
};
const auto& loop_info = get_loop_info(loop_id);
return expr_port.get_type() == ExpressionPort::Input ? get_loop_port(loop_info->get_input_ports())
: get_loop_port(loop_info->get_output_ports());
return loop_info->get_loop_port(expr_port);
}

void LoopManager::get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
Expand Down Expand Up @@ -397,7 +390,7 @@ void LoopManager::expression_replacement(LinearIR::constExprIt new_expr_begin, L
}
}

void LoopManager::sort_loop_ports(LinearIR::constExprIt& loop_begin_pos, LinearIR::constExprIt& loop_end_pos, size_t loop_id) {
void LoopManager::sort_loop_ports(const LinearIR::constExprIt& loop_begin_pos, const LinearIR::constExprIt& loop_end_pos, size_t loop_id) {
// [113536] Update this logic please, when expression numeration will be implemented
const auto& loop_info = get_loop_info<UnifiedLoopInfo>(loop_id);
const auto& loop_entries = loop_info->get_input_ports();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,6 @@ bool DefineBufferClusters::are_buffer_neighbours(const ExpressionPtr& up, const

void DefineBufferClusters::parse_memory_access_op(const ExpressionPtr& expr) {
const auto ma = std::dynamic_pointer_cast<modifier::MemoryAccess>(expr->get_node());
if (!ma->is_full_memory_access_op(expr->get_node()))
return;
v-Golubev marked this conversation as resolved.
Show resolved Hide resolved
// TODO: Some full MemoryAccess ops can have inplace inputs and outputs in general.
// Need to add mechanism of inplace ports using MemoryAccess::PortDescriptor::inplace
for (const auto& input : expr->get_input_port_connectors()) {
Expand Down
Loading
Loading