Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Pad/AveragePool fusion #23190

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions onnxruntime/core/optimizer/pad_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace onnxruntime {

bool VerifyNotCastChild(const Node& child_node) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
return false;
}
Expand All @@ -31,11 +31,32 @@ bool VerifyNotCastChild(const Node& child_node) {
return false;
}

if (child_node.OpType() == "AveragePool") {
// in case there's already padding and count_include_pad is 0, fusion can't be performed
auto has_pad = false;
if (child_node.GetAttributes().find("pads") != child_node.GetAttributes().end()) {
auto const& pads_values = child_node.GetAttributes().at("pads").ints();
if (!pads_values.empty()) {
has_pad = std::any_of(pads_values.begin(), pads_values.end(), [](int64_t value) { return value != 0; });
}
}
if (has_pad && child_node.GetAttributes().find("count_include_pad") != child_node.GetAttributes().end()) {
if (child_node.GetAttributes().at("count_include_pad").i() == 0) {
return false;
}
}
}

return true;
}

void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) {
auto reset_pads = true;
if (child_node.GetAttributes().find("pads") != child_node.GetAttributes().end()) {
/* pads can be empty, overwrite pads attribute in this case */
reset_pads = child_node.GetAttributes().at("pads").ints().empty();
}
if (reset_pads) {
std::vector<int64_t> pads(pads_size - 4, 0);
child_node.AddAttribute("pads", pads);
}
Expand All @@ -49,6 +70,10 @@ void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_v
uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
}

if (child_node.OpType() == "AveragePool") {
child_node.AddAttribute("count_include_pad", static_cast<int64_t>(1));
}
}
/*
* Before:
Expand Down
122 changes: 122 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,128 @@ TEST_F(GraphTransformationTests, FusePadWithMaxPoolOpsetLessThan11) {
}
}

TEST_F(GraphTransformationTests, FusePadWithAvgPool) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

std::vector<int64_t> expected_pads;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "Pad") {
auto const& pads_proto = node.GetAttributes().at("pads").ints();
gsl::span<const int64_t> pads_values = gsl::make_span(pads_proto.data(), pads_proto.size());
expected_pads.resize(pads_values.size() - 4);
for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) {
expected_pads[index] = pads_values[pads_index];
expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)];
}
}
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 0);
ASSERT_EQ(op_to_count["AveragePool"], 1);

for (auto& node : graph.Nodes()) {
if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
auto const& count_include_pad = node.GetAttributes().at("count_include_pad");
ASSERT_NE(count_include_pad.i(), 0) << "fusion should ensure count_include_pad!=0";
ASSERT_EQ(child_pads.size(), static_cast<int32_t>(expected_pads.size()))
<< "fusion should produce the same size of pads integer as the AvgPool node";
for (uint32_t index = 0; index < expected_pads.size(); index++) {
ASSERT_EQ(expected_pads[index], child_pads.Get(index))
<< "fusion does not produce correct padding value";
}
}
}
}

TEST_F(GraphTransformationTests, FusePadWithAvgPoolWithPad) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool_with_pad.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

std::vector<int64_t> expected_pads;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "Pad") {
auto const& pads_proto = node.GetAttributes().at("pads").ints();
gsl::span<const int64_t> pads_values = gsl::make_span(pads_proto.data(), pads_proto.size());
expected_pads.resize(pads_values.size() - 4);

for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) {
expected_pads[index] = pads_values[pads_index];
expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)];
}
} else if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
for (uint32_t index = 0; index < expected_pads.size(); index++) {
expected_pads[index] += child_pads.Get(index);
}
}
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 0);
ASSERT_EQ(op_to_count["AveragePool"], 1);

for (auto& node : graph.Nodes()) {
if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
auto const& count_include_pad = node.GetAttributes().at("count_include_pad");
ASSERT_NE(count_include_pad.i(), 0) << "fusion should ensure count_include_pad!=0";
ASSERT_EQ(child_pads.size(), static_cast<int32_t>(expected_pads.size()))
<< "fusion should produce the same size of pads integer as the AvgPool node";
for (uint32_t index = 0; index < expected_pads.size(); index++) {
ASSERT_EQ(expected_pads[index], child_pads.Get(index))
<< "fusion does not produce correct padding value";
}
}
}
}

// should not fuse
TEST_F(GraphTransformationTests, FusePadWithAvgPoolWithPadNoInclude) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool_with_pad-nofuse.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 1);
ASSERT_EQ(op_to_count["AveragePool"], 1);
}

TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";

Expand Down
58 changes: 58 additions & 0 deletions onnxruntime/test/testdata/transform/fusion/fuse-pad-avgpool-gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from pathlib import Path
Fixed Show fixed Hide fixed

import numpy as np
import onnx

HERE = Path(__file__).parent.resolve(strict=True)
TEST = False

if TEST:
import onnxruntime
Dismissed Show dismissed Hide dismissed


def generate_fuse_pad_avgpool():
parameters = {
"fuse-pad-avgpool": (
{},
[[1.333333, 2.333333, 1.777778], [3., 5., 3.666667], [2.666667, 4.333333, 3.111111]],
),
"fuse-pad-avgpool_with_pad": (
{"pads": [1, 1, 0, 0], "count_include_pad": 1},
[[0.111111, 0.333333, 0.666667, 0.555556], [0.555556, 1.333333, 2.333333, 1.777778], [1.333333, 3., 5., 3.666667], [1.222222, 2.666667, 4.333333, 3.111111]],
),
"fuse-pad-avgpool_with_pad-nofuse": (
{"pads": [1, 1, 0, 0]},
[[0.25, 0.5, 1., 0.833333], [0.833333, 1.333333, 2.333333, 1.777778], [2., 3., 5., 3.666667], [1.833333, 2.666667, 4.333333, 3.111111]]
),
}
for name in parameters:
model_path = HERE / f"{name}.onnx"
input_ = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 1, 3, 3))
pad = onnx.helper.make_node("Pad", ["input"], ["tp"], mode="constant", pads=[0, 0, 1, 1, 0, 0, 1, 1])
pool = onnx.helper.make_node("AveragePool", ["tp"], ["output"], kernel_shape=[3, 3], **parameters[name][0])
nodes = [pad, pool]
output_shape = (1, 1, 3, 3) if name == "fuse-pad-avgpool" else (1, 1, 4, 4)
output_ = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, output_shape)
graph = onnx.helper.make_graph(nodes, name, [input_], [output_])
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 7)])
onnx.checker.check_model(model)
onnx.save_model(model, model_path)
if TEST:
input_array = np.array([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=np.float32)
expected = np.array(parameters[name][1], dtype=np.float32)
session_options = onnxruntime.SessionOptions()
session_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
session = onnxruntime.InferenceSession(model_path, session_options)
out = session.run(["output"], {"input": input_array})
actual = out[0].squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=0.0)
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
session = onnxruntime.InferenceSession(model_path, session_options)
out = session.run(["output"], {"input": input_array})
actual = out[0].squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=0.0)


if __name__ == "__main__":
generate_fuse_pad_avgpool()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading