From a2b9285204199b1bb8784ae8ff70641b5387cf2c Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 17 Aug 2021 23:16:54 -0700 Subject: [PATCH] [Relay] Extract dataflow matcher data structure into header (#8774) * extract dataflow matcher data structure into a header file * lint * lint --- src/relay/ir/dataflow_matcher.cc | 615 +++++++++++---------------- src/relay/ir/dataflow_matcher_impl.h | 164 +++++++ 2 files changed, 417 insertions(+), 362 deletions(-) create mode 100644 src/relay/ir/dataflow_matcher_impl.h diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 5ce06d9fefaa..d7f130f2796d 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -29,50 +29,12 @@ #include -#include "indexed_graph.h" +#include "dataflow_matcher_impl.h" namespace tvm { namespace relay { // Pattern Matcher - -class DominatorMatcher; - -class DFPatternMatcher : public DFPatternFunctor { - public: - explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} - bool Match(const DFPattern& pattern, const Expr& expr); - Map> GetMemo() { return Map>(memo_); } - const IndexedGraph expr_graph_; - - protected: - bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; - bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; - - void ClearMap(size_t watermark); - bool MatchesPath(const DominatorPatternNode* op, const Expr& expr); - bool DominatesParent(const DominatorPatternNode* op, const Expr& expr); - - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; - std::vector matched_nodes_; - bool memoize_ = true; -}; - bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { memo_.clear(); matched_nodes_.clear(); @@ -542,304 +504,251 @@ bool MatchPattern(DFPattern pattern, Expr expr) { TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); -/*! - * \brief PatternGrouper does pre-rewriting pattern matching and analysis - * - * This class creates a number of groups of matched expressions, ensures they don't overlap, and - * returns them to the caller for post-analysis rewriting. - * - * This is primarily needed to support the post-dominator analysis required for dominator pattern - * matching. - */ -class PatternGrouper { +/*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform + * group overlap analysis */ +class MatchExtractor : public ExprMutator { public: - /*! \brief Internal Group class for storing analysis */ - struct Group { - Expr root_node; - int gid; - Map> matched_nodes; - std::string name; - Function function; - Array args; - }; - - /*! \brief Return the group assignments of expressions */ - const std::unordered_map& GetGIDAssignments() { - return gid_assignments_; + explicit MatchExtractor( + const std::unordered_map& inputs) + : inputs_(inputs) {} + const std::unordered_map& GetMemo() { + return this->memo_; } - /*! \brief Group expressions that match the pattern */ - const std::unordered_map& GroupMatches(const DFPattern& pattern, const Expr& pre) { - groups_.clear(); - gid_assignments_.clear(); + const std::string& GetName() { return name_; } - pattern_ = pattern; - pattern_graph_ = CreateIndexedGraph(pattern_); - auto matcher = DFPatternMatcher(pre); - matcher_ = &matcher; - this->VisitExprs(); - return this->groups_; + protected: + Expr VisitExpr(const Expr& pre) override { + if (inputs_.count(pre)) { + return inputs_.at(pre); + } + return ExprMutator::VisitExpr(pre); } + Expr VisitExpr_(const TupleNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Tuple_"; + return out; + }; + Expr VisitExpr_(const FunctionNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Function"; + return out; + }; + Expr VisitExpr_(const CallNode* call_node) override { + auto out = ExprMutator::VisitExpr_(call_node); + if (auto operation = call_node->op.as()) { + name_ += operation->name + "_"; + } else { + name_ += "Call_"; + } + return out; + }; + Expr VisitExpr_(const LetNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Let_"; + return out; + }; + Expr VisitExpr_(const IfNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "If_"; + return out; + }; + Expr VisitExpr_(const TupleGetItemNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "TupleGetItem" + std::to_string(op->index) + "_"; + return out; + }; + Expr VisitExpr_(const MatchNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Match_"; + return out; + }; + std::string name_; + const std::unordered_map inputs_; +}; - protected: - /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs - * - * If we traverse the graph in post-order, we can run into situtations where a small subgraph will - * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in - * the graph may also match the pattern. With post-order traversal, we mark the smaller subgraph - * as matched and fail to catch the larger subgraph. This problem is fixed by using pre-order - * traversal. - */ - void VisitExprs() { - std::unordered_set pre_partitioned; - for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { - size_t index = i - 1; - Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; - if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped - if (auto op = current.as()) { - if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { - pre_partitioned.insert(current); - PostOrderVisit(op->body, - [&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); }); - } - } - if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) { - CreateGroup(current); +/*! \brief Group expressions that match the pattern */ +const std::unordered_map& PatternGrouper::GroupMatches( + const DFPattern& pattern, const Expr& pre) { + groups_.clear(); + gid_assignments_.clear(); + + pattern_ = pattern; + pattern_graph_ = CreateIndexedGraph(pattern_); + auto matcher = DFPatternMatcher(pre); + matcher_ = &matcher; + this->VisitExprs(); + return this->groups_; +} + +void PatternGrouper::VisitExprs() { + std::unordered_set pre_partitioned; + for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { + size_t index = i - 1; + Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; + if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped + if (auto op = current.as()) { + if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { + pre_partitioned.insert(current); + PostOrderVisit(op->body, + [&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); }); } } + if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) { + CreateGroup(current); + } } } - /*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform - * group overlap analysis */ - class MatchExtractor : public ExprMutator { - public: - explicit MatchExtractor( - const std::unordered_map& inputs) - : inputs_(inputs) {} - const std::unordered_map& GetMemo() { - return this->memo_; - } - const std::string& GetName() { return name_; } +} - protected: - Expr VisitExpr(const Expr& pre) override { - if (inputs_.count(pre)) { - return inputs_.at(pre); +void PatternGrouper::CreateGroup(const Expr& expr) { + int var_number = 0; + + auto node_map = matcher_->GetMemo(); + // Get fuzzy patterns + std::unordered_set fuzzy_matches; + for (auto node : pattern_graph_.topological_order_) { + // Don't treat fuzzy Dominator patterns input variables for partition + if (auto op = node->ref_.as()) { + for (auto fuzzy_op : {op->parent, op->path}) { + for (auto match : node_map[fuzzy_op]) { + fuzzy_matches.insert(match); + } } - return ExprMutator::VisitExpr(pre); } - Expr VisitExpr_(const TupleNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "Tuple_"; - return out; - }; - Expr VisitExpr_(const FunctionNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "Function"; - return out; - }; - Expr VisitExpr_(const CallNode* call_node) override { - auto out = ExprMutator::VisitExpr_(call_node); - if (auto operation = call_node->op.as()) { - name_ += operation->name + "_"; - } else { - name_ += "Call_"; + // Don't treat Function params or body as input variables for partition + if (node->ref_.as()) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + auto graph = CreateIndexedGraph(match.as()->body); + for (auto node : graph.topological_order_) { + fuzzy_matches.insert(node->ref_); + } } - return out; - }; - Expr VisitExpr_(const LetNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "Let_"; - return out; - }; - Expr VisitExpr_(const IfNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "If_"; - return out; - }; - Expr VisitExpr_(const TupleGetItemNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "TupleGetItem" + std::to_string(op->index) + "_"; - return out; - }; - Expr VisitExpr_(const MatchNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "Match_"; - return out; - }; - std::string name_; - const std::unordered_map inputs_; - }; + } + } - /*! \brief Create a group based on a matched expression */ - void CreateGroup(const Expr& expr) { - int var_number = 0; - - auto node_map = matcher_->GetMemo(); - // Get fuzzy patterns - std::unordered_set fuzzy_matches; - for (auto node : pattern_graph_.topological_order_) { - // Don't treat fuzzy Dominator patterns input variables for partition - if (auto op = node->ref_.as()) { - for (auto fuzzy_op : {op->parent, op->path}) { - for (auto match : node_map[fuzzy_op]) { - fuzzy_matches.insert(match); - } - } + // Create input variables + Group group; + group.root_node = expr; + group.matched_nodes = node_map; + + std::unordered_map inputs; + Array params; + + for (auto node : pattern_graph_.topological_order_) { + auto make_input = [&](const Expr& input) { + if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && + input.as() == nullptr && !EmbedConst(input, node->ref_)) { + inputs[input] = + Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), + NullValue()); + group.args.push_back(input); + params.push_back(inputs[input]); + var_number++; } - // Don't treat Function params or body as input variables for partition - if (node->ref_.as()) { + }; + auto tuple = node->ref_.as(); + auto call = node->ref_.as(); + if (tuple && !tuple->fields.defined()) { + if (node_map.count(node->ref_)) { auto matches = node_map[node->ref_]; for (auto match : matches) { - auto graph = CreateIndexedGraph(match.as()->body); - for (auto node : graph.topological_order_) { - fuzzy_matches.insert(node->ref_); + for (auto input : match.as()->fields) { + make_input(input); } } } - } - - // Create input variables - Group group; - group.root_node = expr; - group.matched_nodes = node_map; - - std::unordered_map inputs; - Array params; - - for (auto node : pattern_graph_.topological_order_) { - auto make_input = [&](const Expr& input) { - if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && - input.as() == nullptr && !EmbedConst(input, node->ref_)) { - inputs[input] = - Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), - NullValue()); - group.args.push_back(input); - params.push_back(inputs[input]); - var_number++; - } - }; - auto tuple = node->ref_.as(); - auto call = node->ref_.as(); - if (tuple && !tuple->fields.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; - for (auto match : matches) { - for (auto input : match.as()->fields) { - make_input(input); - } - } - } - } else if (call && !call->args.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; - for (auto match : matches) { - for (auto input : match.as()->args) { - make_input(input); - } + } else if (call && !call->args.defined()) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + for (auto input : match.as()->args) { + make_input(input); } } - } else if (node->inputs_.size() == 0) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; - for (auto match : matches) { - make_input(match); - } + } + } else if (node->inputs_.size() == 0) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + make_input(match); } } } + } - graph_number_++; - - // Extract a Function. Used in Partition directly, - // used to determine Group overlap in other passes - auto extractor = MatchExtractor(inputs); - auto body = extractor.Mutate(expr); - - group.function = Function(params, body, NullValue(), Array()); - group.name = extractor.GetName(); - // Check to make sure we aren't overlapping with another group or creating an invalid fusion - // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the - // pattern with the input FunctionVar* Variables. The resulting memoization map will only - // contain nodes in the expression that matched the pattern. If a non-input node of the pattern - // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a - // situation where we try to rewrite the same node twice in the second rewriting or parition - // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants - // because they exist more globally outside of the fusion. - // Similiarly, if interior nodes in a group are used outside of the group fusing to a single - // output would create an invalid graph tranformation, so we block the creation of such groups. - auto memo = extractor.GetMemo(); - for (auto kv : memo) { - // Check to ensure that this node isn't an input or a global - if (inputs.count(kv.first) == 0 && kv.first.as() == nullptr && - kv.first.as() == nullptr && kv.first.as() == nullptr) { - if (gid_assignments_.count(kv.first) != 0) { - // check to see if the node is use in other groups - // Exit due to overlapping partitions - return; - } else if (kv.second != body) { - // if the node isn't the output of the group - auto node = matcher_->expr_graph_.node_map_.at(kv.first); - for (auto* output : node->outputs_) { - // and the node is used by nodes outside of the group - if (memo.count(output->ref_) == 0 && - !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { - // Exit because nodes in this pattern's body are used outside the pattern - // fusing it would be invalid - return; - } + graph_number_++; + + // Extract a Function. Used in Partition directly, + // used to determine Group overlap in other passes + auto extractor = MatchExtractor(inputs); + auto body = extractor.Mutate(expr); + + group.function = Function(params, body, NullValue(), Array()); + group.name = extractor.GetName(); + // Check to make sure we aren't overlapping with another group or creating an invalid fusion + // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the + // pattern with the input FunctionVar* Variables. The resulting memoization map will only + // contain nodes in the expression that matched the pattern. If a non-input node of the pattern + // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a + // situation where we try to rewrite the same node twice in the second rewriting or parition + // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants + // because they exist more globally outside of the fusion. + // Similiarly, if interior nodes in a group are used outside of the group fusing to a single + // output would create an invalid graph tranformation, so we block the creation of such groups. + auto memo = extractor.GetMemo(); + for (auto kv : memo) { + // Check to ensure that this node isn't an input or a global + if (inputs.count(kv.first) == 0 && kv.first.as() == nullptr && + kv.first.as() == nullptr && kv.first.as() == nullptr) { + if (gid_assignments_.count(kv.first) != 0) { + // check to see if the node is use in other groups + // Exit due to overlapping partitions + return; + } else if (kv.second != body) { + // if the node isn't the output of the group + auto node = matcher_->expr_graph_.node_map_.at(kv.first); + for (auto* output : node->outputs_) { + // and the node is used by nodes outside of the group + if (memo.count(output->ref_) == 0 && + !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { + // Exit because nodes in this pattern's body are used outside the pattern + // fusing it would be invalid + return; } } } } - // Assign Group Ids - group.gid = ++gid_; - for (auto kv : extractor.GetMemo()) { - gid_assignments_[kv.first] = gid_; - } + } + // Assign Group Ids + group.gid = ++gid_; + for (auto kv : extractor.GetMemo()) { + gid_assignments_[kv.first] = gid_; + } + + // Save Group + groups_[group.gid] = std::move(group); +} - // Save Group - groups_[group.gid] = std::move(group); - } - - /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or - * lifting them into the function arguments. - * - * The rules depend on what pattern the ConstantNode matched. - * - * The basic rules are: - * If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant - * in the partitioned function. If the constant matched an AltPattern, recursively check the - * matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc), - * lift the constant into the arguments of the partitioned function. - */ - bool EmbedConst(const Expr& expr, const DFPattern pattern) { - bool embed = false; - if (expr.as()) { - if (pattern.as() != nullptr) { +bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) { + bool embed = false; + if (expr.as()) { + if (pattern.as() != nullptr) { + embed = true; + } else if (auto expr_pat = pattern.as()) { + if (expr_pat->expr.as()) { embed = true; - } else if (auto expr_pat = pattern.as()) { - if (expr_pat->expr.as()) { - embed = true; - } - } else if (auto alt_pat = pattern.as()) { - if (matcher_->Match(alt_pat->left, expr)) { - embed = EmbedConst(expr, alt_pat->left); - } else { - embed = EmbedConst(expr, alt_pat->right); - } + } + } else if (auto alt_pat = pattern.as()) { + if (matcher_->Match(alt_pat->left, expr)) { + embed = EmbedConst(expr, alt_pat->left); + } else { + embed = EmbedConst(expr, alt_pat->right); } } - return embed; } - // Internal State - DFPattern pattern_; - std::unordered_map groups_; - std::unordered_map gid_assignments_; - DFPatternMatcher* matcher_ = nullptr; - IndexedGraph pattern_graph_; - int gid_ = 0; - int graph_number_ = 0; -}; + return embed; +} // Rewrite @@ -858,72 +767,54 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") return DFPatternCallback(pattern, function, require_type); }); -/*! - * \brief PatternRewriter rewrites the expression by finding matches and allowing user callback - * function to rewrite those matches - * - * The class uses PatternGrouper to support the dominator pattern. - */ -class PatternRewriter : protected MixedModeMutator { - public: - PatternRewriter(IRModule mod) : mod_(mod) {} - /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the - * callbacks until it stops changing */ - Expr Rewrite(const Array& callbacks, const Expr& pre) { - auto post = pre; - auto last = post; - // rewrite the graph until it stops changing to make sure all rewrites are complete - int count = 0; - bool equal = true; - static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); - ICHECK(structural_equal) << "node.StructuralEqual is not registered."; - do { - last = post; - for (auto callback : callbacks) { - callback_ = callback; - if (callback_->require_type) { - post = InferTypeWithModule(post, mod_); - } - auto grouper = PatternGrouper(); - groups_ = grouper.GroupMatches(callback_->pattern, post); - gid_assignments_ = grouper.GetGIDAssignments(); - memo_.clear(); - post = this->VisitExpr(post); - count++; - } - equal = (*structural_equal)(last, post, false, true); - } while (!equal && count < 100); - if (count >= 100) { - LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; +Expr PatternRewriter::Rewrite(const Array& callbacks, const Expr& pre) { + auto post = pre; + auto last = post; + // rewrite the graph until it stops changing to make sure all rewrites are complete + int count = 0; + bool equal = true; + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + ICHECK(structural_equal) << "node.StructuralEqual is not registered."; + do { + last = post; + for (auto callback : callbacks) { + callback_ = callback; + if (callback_->require_type) { + post = InferTypeWithModule(post, mod_); + } + auto grouper = PatternGrouper(); + groups_ = grouper.GroupMatches(callback_->pattern, post); + gid_assignments_ = grouper.GetGIDAssignments(); + memo_.clear(); + post = this->VisitExpr(post); + count++; } - return post; + equal = (*structural_equal)(last, post, false, true); + } while (!equal && count < 100); + if (count >= 100) { + LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; } + return post; +} - protected: - Expr DispatchVisitExpr(const Expr& pre) override { - auto post = MixedModeMutator::DispatchVisitExpr(pre); - if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { - // Convert the pre-rewrite node map to a post-rewrite node map - auto group = groups_[gid_assignments_[pre]]; - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map; - for (auto kv : group.matched_nodes) { - Array tmp; - for (size_t i = 0; i < kv.second.size(); ++i) { - tmp.push_back(this->memo_[kv.second[i]]); - } - node_map.insert({kv.first, tmp}); +Expr PatternRewriter::DispatchVisitExpr(const Expr& pre) { + auto post = MixedModeMutator::DispatchVisitExpr(pre); + if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { + // Convert the pre-rewrite node map to a post-rewrite node map + auto group = groups_[gid_assignments_[pre]]; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map; + for (auto kv : group.matched_nodes) { + Array tmp; + for (size_t i = 0; i < kv.second.size(); ++i) { + tmp.push_back(this->memo_[kv.second[i]]); } - // run the user callback function - return callback_->function(pre, post, Map>(node_map)); + node_map.insert({kv.first, tmp}); } - return post; + // run the user callback function + return callback_->function(pre, post, Map>(node_map)); } - - IRModule mod_; - DFPatternCallback callback_; - std::unordered_map groups_; - std::unordered_map gid_assignments_; -}; + return post; +} Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod) { return PatternRewriter(mod).Rewrite(callbacks, expr); diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h new file mode 100644 index 000000000000..d993d4720e4e --- /dev/null +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher_impl.h + * \brief The auxiliary data structure for dataflow matcher. + */ +#ifndef TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_ +#define TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_ + +#include +#include +#include + +#include +#include +#include + +#include "indexed_graph.h" + +namespace tvm { +namespace relay { + +class DFPatternMatcher : public DFPatternFunctor { + public: + explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} + bool Match(const DFPattern& pattern, const Expr& expr); + Map> GetMemo() { return Map>(memo_); } + const IndexedGraph expr_graph_; + + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; + bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; + + void ClearMap(size_t watermark); + bool MatchesPath(const DominatorPatternNode* op, const Expr& expr); + bool DominatesParent(const DominatorPatternNode* op, const Expr& expr); + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; + std::vector matched_nodes_; + bool memoize_ = true; +}; + +/*! + * \brief PatternGrouper does pre-rewriting pattern matching and analysis + * + * This class creates a number of groups of matched expressions, ensures they don't overlap, and + * returns them to the caller for post-analysis rewriting. + * + * This is primarily needed to support the post-dominator analysis required for dominator pattern + * matching. + */ +class PatternGrouper { + public: + /*! \brief Internal Group class for storing analysis */ + struct Group { + Expr root_node; + int gid; + Map> matched_nodes; + std::string name; + Function function; + Array args; + }; + + /*! \brief Return the group assignments of expressions */ + inline const std::unordered_map& GetGIDAssignments() { + return gid_assignments_; + } + /*! \brief Group expressions that match the pattern */ + const std::unordered_map& GroupMatches(const DFPattern& pattern, const Expr& pre); + + protected: + /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs + * + * If we traverse the graph in post-order, we can run into situtations where a small subgraph will + * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in + * the graph may also match the pattern. With post-order traversal, we mark the smaller subgraph + * as matched and fail to catch the larger subgraph. This problem is fixed by using pre-order + * traversal. + */ + void VisitExprs(); + + /*! \brief Create a group based on a matched expression */ + void CreateGroup(const Expr& expr); + + /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or + * lifting them into the function arguments. + * + * The rules depend on what pattern the ConstantNode matched. + * + * The basic rules are: + * If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant + * in the partitioned function. If the constant matched an AltPattern, recursively check the + * matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc), + * lift the constant into the arguments of the partitioned function. + */ + bool EmbedConst(const Expr& expr, const DFPattern pattern); + // Internal State + DFPattern pattern_; + std::unordered_map groups_; + std::unordered_map gid_assignments_; + DFPatternMatcher* matcher_ = nullptr; + IndexedGraph pattern_graph_; + int gid_ = 0; + int graph_number_ = 0; +}; + +/*! + * \brief PatternRewriter rewrites the expression by finding matches and allowing user callback + * function to rewrite those matches + * + * The class uses PatternGrouper to support the dominator pattern. + */ +class PatternRewriter : protected MixedModeMutator { + public: + explicit PatternRewriter(IRModule mod) : mod_(mod) {} + /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the + * callbacks until it stops changing */ + virtual Expr Rewrite(const Array& callbacks, const Expr& pre); + + protected: + virtual Expr DispatchVisitExpr(const Expr& pre); + + IRModule mod_; + DFPatternCallback callback_; + std::unordered_map groups_; + std::unordered_map gid_assignments_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_