Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PatternLang] Add If pattern #7282

Merged
merged 5 commits into from
Jan 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,21 @@ The next example is matching function nodes with a specific attribute:
f = relay.Function([x, y], x + y).with_attr("Composite", "add")
assert pattern.match(f)

A Relay ``If`` expression can be matched if all of its condition, true branch and false branch
are matched:

.. code-block:: python

def test_match_if():
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

x = relay.var("x")
y = relay.var("y")
cond = x < y

assert pat.match(relay.expr.If(cond, x, y))

Matching Diamonds and Post-Dominator Graphs
*******************************************
Expand Down Expand Up @@ -294,6 +309,7 @@ The high level design is to introduce a language of patterns for now we propose
| is_op(op_name)
| is_tuple()
| is_tuple_get_item(pattern, index = None)
| is_if(cond, tru, fls)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
| FunctionPattern(params, body)
Expand Down
20 changes: 20 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,26 @@ class TupleGetItemPatternNode : public DFPatternNode {
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode);
};

class IfPatternNode : public DFPatternNode {
public:
DFPattern cond, true_branch, false_branch;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.IfPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(IfPatternNode, DFPatternNode);
};

class IfPattern : public DFPattern {
public:
TVM_DLL IfPattern(DFPattern cond, DFPattern then_clause, DFPattern else_clause);
TVM_DEFINE_OBJECT_REF_METHODS(IfPattern, DFPattern, IfPatternNode);
};

class TupleGetItemPattern : public DFPattern {
public:
TVM_DLL TupleGetItemPattern(DFPattern tuple, int index);
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
Expand All @@ -116,6 +117,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
Expand Down Expand Up @@ -144,6 +146,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,29 @@ def is_tuple_get_item(tuple_value: "DFPattern", index: Optional[int] = None) ->
return TupleGetItemPattern(tuple_value, index)


def is_if(cond, true_branch, false_branch):
"""
Syntatic sugar for creating an IfPattern.

Parameters
----------
cond: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the condition of If.

true_branch: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the true branch of If.

false_branch: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the false branch of If.

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting pattern.
"""
return IfPattern(cond, true_branch, false_branch)


def wildcard() -> "DFPattern":
"""
Syntatic sugar for creating a WildcardPattern.
Expand Down Expand Up @@ -536,6 +559,26 @@ def __init__(
self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body)


@register_df_node
class IfPattern(DFPattern):
"""A patern matching a Relay If.

Parameters
----------
cond: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the condition of If.

true_branch: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the true branch of If.

false_branch: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the false branch of If.
"""

def __init__(self, cond: "DFPattern", true_branch: "DFPattern", false_branch: "DFPattern"):
self.__init_handle_by_constructor__(ffi.IfPattern, cond, true_branch, false_branch)


@register_df_node
class TuplePattern(DFPattern):
"""A patern matching a Relay Tuple.
Expand Down
12 changes: 12 additions & 0 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
Expand Down Expand Up @@ -407,6 +408,17 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e
return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const IfPatternNode* op, const Expr& expr) {
if (const auto* if_node = expr.as<IfNode>()) {
auto cond = if_node->cond;
auto true_branch = if_node->true_branch;
auto false_branch = if_node->false_branch;
return VisitDFPattern(op->cond, cond) && VisitDFPattern(op->true_branch, true_branch) &&
VisitDFPattern(op->false_branch, false_branch);
}
return false;
}

Expr InferType(const Expr& expr) {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
Expand Down
22 changes: 22 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")";
});

IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>();
n->cond = std::move(cond);
n->true_branch = std::move(true_branch);
n->false_branch = std::move(false_branch);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(IfPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.IfPattern")
.set_body_typed([](DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
return IfPattern(cond, true_branch, false_branch);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IfPatternNode*>(ref.get());
p->stream << "IfPattern(" << node->cond << ", " << node->true_branch << ", "
<< node->false_branch << ")";
});

TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) {
ObjectPtr<TuplePatternNode> n = make_object<TuplePatternNode>();
n->fields = std::move(fields);
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) {
}
}

void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* op) {
VisitDFPattern(op->cond);
VisitDFPattern(op->true_branch);
VisitDFPattern(op->false_branch);
}

void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
}
}

void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->cond, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->true_branch, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->false_branch, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override {
VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand Down
38 changes: 38 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ def test_AttrPattern():
assert op.attrs["TOpPattern"] == K_ELEMWISE


def test_IfPattern():
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

assert isinstance(pat, IfPattern)
assert isinstance(pat.cond, CallPattern)
assert isinstance(pat.true_branch, VarPattern)
assert isinstance(pat.false_branch, VarPattern)


## MATCHER TESTS


Expand Down Expand Up @@ -198,6 +209,30 @@ def test_no_match_func():
assert not func_pattern.match(relay.Function([x, y], x - y))


def test_match_if():
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

x = relay.var("x")
y = relay.var("y")
cond = x < y

assert pat.match(relay.expr.If(cond, x, y))


def test_no_match_if():
x = is_var("x")
y = is_var("y")
pat = is_if(is_op("less")(x, y), x, y)

x = relay.var("x")
y = relay.var("y")

assert not pat.match(relay.expr.If(x > y, x, y))
assert not pat.match(relay.expr.If(x < y, y, x))


def test_match_option():
x = relay.var("x")
w = relay.var("w")
Expand Down Expand Up @@ -1512,3 +1547,6 @@ def test_partition_constant_embedding():
test_partition_option()
test_match_match()
test_partition_constant_embedding()
test_IfPattern()
test_match_if()
test_no_match_if()