Skip to content

Commit

Permalink
Addressed PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
sumitsays committed Aug 7, 2024
1 parent 3432a67 commit 82dd841
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions onnxruntime/core/optimizer/pad_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace onnxruntime {

bool VerifyNonCastChild(const Node& child_node) {
bool VerifyNotCastChild(const Node& child_node) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
Expand Down Expand Up @@ -39,7 +39,7 @@ bool VerifyNonCastChild(const Node& child_node) {
return true;
}

void Update_Pad_Attribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());

Expand Down Expand Up @@ -113,9 +113,9 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
if (graph.NodeProducesGraphOutput(child_node)) {
return false;
}
return VerifyNonCastChild(*child_node.OutputNodesBegin());
return VerifyNotCastChild(*child_node.OutputNodesBegin());
} else {
return VerifyNonCastChild(child_node);
return VerifyNotCastChild(child_node);
}
}

Expand Down Expand Up @@ -146,13 +146,10 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
}

Node& child_node = *graph.GetNode(pad_node.OutputNodesBegin()->Index());
if (child_node.OpType() == "Cast") {
// We don't need to cast the pad_constant_value because this fusion requires that constant_pad_value
// to be zero. See PadFusion::SatisfyCondition for details.
Update_Pad_Attribute(*graph.GetNode(child_node.OutputNodesBegin()->Index()), pads_values, pads_size);
} else {
Update_Pad_Attribute(child_node, pads_values, pads_size);
}
// We don't need to cast the pad_constant_value because this fusion requires that constant_pad_value
// to be zero. See PadFusion::SatisfyCondition for details.
Node& target_padding_node = (child_node.OpType() == "Cast") ? *graph.GetNode(child_node.OutputNodesBegin()->Index()) : child_node;

Check warning on line 151 in onnxruntime/core/optimizer/pad_fusion.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/optimizer/pad_fusion.cc:151: Lines should be <= 120 characters long [whitespace/line_length] [2]
UpdatePaddingAttribute(target_padding_node, pads_values, pads_size);

graph_utils::RemoveNodeOutputEdges(graph, pad_node);
graph_utils::ReplaceNodeInput(child_node, 0, *pad_node.MutableInputDefs()[0]);
Expand Down

0 comments on commit 82dd841

Please sign in to comment.