Skip to content

Commit

Permalink
[BugFix][Relay] skip leaf args when matching 'path' part for dominato…
Browse files Browse the repository at this point in the history
…r pattern (#16983)

* [BugFix][Relay] skip leaf args when matching 'path' part for dominator pattern

* add testcase
  • Loading branch information
wanghuibin0 authored Jun 21, 2024
1 parent 36b9535 commit e6bfaf8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,17 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex

// Recursively find the Dominator parent along all inputs paths.
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
// utilities
auto is_leaf_node = [](const Expr& expr) {
return expr.as<ConstantNode>() || expr.as<VarNode>();
};

// logic
auto call_node = expr.as<CallNode>();
auto index_node = expr_to_node(expr);
size_t arg_counter{0};
for (auto node : index_node->inputs_) {
if (!(call_node && node->ref() == call_node->op)) {
if (!(call_node && (node->ref() == call_node->op || is_leaf_node(node->ref())))) {
arg_counter += 1;
memoize_ = true;
if (!VisitDFPattern(op->parent, node->ref())) {
Expand Down
24 changes: 23 additions & 1 deletion tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# convention.
K_ELEMWISE = 0
K_BROADCAST = 1

K_INJECTIVE = 2

## NODE TESTS
def test_expr_pattern():
Expand Down Expand Up @@ -696,6 +696,28 @@ def test_match_dominator():
assert diamond.match(out)


def test_match_dominator2():
# Pattern
conv2d_pat = is_op("nn.conv2d")(wildcard(), wildcard())
eltwise_pat = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(None)
broadcast_pat = (wildcard().has_attr({"TOpPattern": K_BROADCAST}))(None)
path_pat = eltwise_pat | broadcast_pat
injective_pat = (wildcard().has_attr({"TOpPattern": K_INJECTIVE}))(wildcard())
pattern = injective_pat.dominates(conv2d_pat, path_pat)

# Graph
inp = relay.var("input")
weight = relay.var("weight")
bias = relay.var("bias")
conv2d = relay.op.nn.conv2d(inp, weight)
bias_add = relay.op.nn.bias_add(conv2d, bias)
relu = relay.op.nn.relu(bias_add)
reshape = relay.op.reshape(relu, newshape=[-1, 2, 8])

# Check
assert pattern.match(reshape)


def test_not_match_dominator():
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
Expand Down

0 comments on commit e6bfaf8

Please sign in to comment.