Skip to content

Commit

Permalink
If an expression has two branches, and the pattern ignores one with a…
Browse files Browse the repository at this point in the history
… wildcard, allow grouping via dominator analysis (apache#7355)
  • Loading branch information
Matthew Brookhart authored and alexwong committed Feb 11, 2021
1 parent a09f459 commit 38851c2
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,8 @@ class PatternGrouper {
auto node = matcher_->expr_graph_.node_map_.at(kv.first);
for (auto* output : node->outputs_) {
// and the node is used by nodes outside of the group
if (memo.count(output->ref_) == 0) {
if (memo.count(output->ref_) == 0 &&
!matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
// Exit because nodes in this pattern's body are used outside the pattern
// fusing it would be invalid
return;
Expand Down
22 changes: 22 additions & 0 deletions src/relay/ir/indexed_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/relay/dataflow_pattern.h>

#include <memory>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <utility>
Expand Down Expand Up @@ -74,6 +75,27 @@ class IndexedGraph {
Node* dominator_parent_;
/*! \brief The nodes this node dominates */
std::vector<Node*> dominator_children_;

bool Dominates(const Node* other) {
std::stack<const Node*> stack;
std::unordered_set<const Node*> visited;
stack.push(this);
while (!stack.empty()) {
const Node* current = stack.top();
stack.pop();
for (auto node : current->dominator_children_) {
if (visited.count(node) == 0) {
if (other == node) {
return true;
} else {
stack.push(node);
}
visited.insert(node);
}
}
}
return false;
}
};
/*! \brief Construct the domination tree inside IndexedGraph */
void PostDom() {
Expand Down
71 changes: 71 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=unused-wildcard-import
import numpy as np
import pytest

import tvm
from tvm import relay
Expand Down Expand Up @@ -1470,6 +1471,76 @@ def test_partition_function():
assert tvm.ir.structural_equal(pattern.partition(expr), expr2)


def test_rewrite_function_with_fuzzy_body():
"""Allow Rewriting a function with a fuzzy body via dominator analysis"""
x = relay.var("x")
w = relay.var("w")
b = relay.var("b")

x1 = relay.var("x1")
w1 = relay.var("w1")

wc_x = wildcard()
wc_w = wildcard()
wc_b = wildcard()
wc_x1 = wildcard()
wc_w1 = wildcard()

func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
pattern = func_pattern(wc_x, wc_w) + wc_b

func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
expr = func(x, w) + b + b

class TestRewrite(DFPatternCallback):
def __init__(self):
super(TestRewrite, self).__init__()
self.pattern = pattern

def callback(self, pre, post, node_map):
return x + w

out = rewrite(TestRewrite(), expr)
assert tvm.ir.structural_equal(x + w, x + w)


@pytest.mark.skip(
"""TODO(mbrookhart): The current partitioner can't properly handle
the partitioned inputs on the fuzzy body"""
)
def test_partition_function_with_fuzzy_body():
"""
Allow Rewriting a function with a fuzzy body via dominator analysis
"""
x = relay.var("x")
w = relay.var("w")
b = relay.var("b")

x1 = relay.var("x1")
w1 = relay.var("w1")

wc_x = wildcard()
wc_w = wildcard()
wc_b = wildcard()
wc_x1 = wildcard()
wc_w1 = wildcard()

func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
pattern = func_pattern(wc_x, wc_w) + wc_b

func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
expr = func(x, w) + b + b

x2 = relay.var("x2")
w2 = relay.var("w2")
b2 = relay.var("b2")
func2 = relay.Function([x2, w2, b2], func(x2, w2) + b2).with_attr(
"PartitionedFromPattern", "FunctionCall_add_"
)
expr2 = func2(x, w, b) + b
assert tvm.ir.structural_equal(pattern.partition(expr), expr2)


def test_match_match():
add_pattern = is_op("add")(wildcard(), wildcard())

Expand Down

0 comments on commit 38851c2

Please sign in to comment.