diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index a1c7f8de9e6fe..e266946b0d9e0 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -12,7 +12,7 @@ namespace onnxruntime { * It matches following pattern: * Pad * | - * Conv/MaxPool + * Conv/MaxPool/AveragePool */ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { // if Pad has input axis, don't fuse it. @@ -28,6 +28,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log const Node& child_node = *node.OutputNodesBegin(); if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) && !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) { return false; } diff --git a/onnxruntime/core/optimizer/pad_fusion.h b/onnxruntime/core/optimizer/pad_fusion.h index a1b6978a83d1e..ca05d219b7e2c 100644 --- a/onnxruntime/core/optimizer/pad_fusion.h +++ b/onnxruntime/core/optimizer/pad_fusion.h @@ -8,7 +8,7 @@ namespace onnxruntime { /* * This fusion submerges a Pad operator to it's child - * Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition() + * Conv or MaxPool or AveragePool operator, if and only if PadFusion::SatisfyCondition() * is true. */ class PadFusion : public RewriteRule {