From 36fd2dd8a78b5c06a4cd3ee04b93a4f34c6b66f6 Mon Sep 17 00:00:00 2001
From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com>
Date: Mon, 11 Jul 2022 13:57:55 -0700
Subject: [PATCH] [Collage] SubGraphs (#11981)

* [Collage] SubGraphs

See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md.

Collage works in units of 'sub-graphs', which are potential partitions of the
overall Relay model. This PR introduces SubGraph (an arbitrary partitioning,
without any implication about how it is to be represented), its companion
SubSubGraph (implying a representation as a function), and some supporting
odds 'n ends.

* - make Integer <-> size_t conversion explicit
- make 'Compiler' name explicit

* - fix namespace ambiguity

* - review comments
---
 CMakeLists.txt                               |    1 +
 src/relay/collage/README.md                  |   26 +
 src/relay/collage/dataflow_graph.cc          |   48 +
 src/relay/collage/dataflow_graph.h           |   77 ++
 src/relay/collage/index_set.cc               |  231 ++++
 src/relay/collage/index_set.h                |  128 +++
 src/relay/collage/sub_graph.cc               | 1034 ++++++++++++++++++
 src/relay/collage/sub_graph.h                |  452 ++++++++
 src/relay/collage/utils.cc                   |  139 +++
 src/relay/collage/utils.h                    |   86 ++
 tests/python/relay/collage/test_sub_graph.py |  387 +++++++
 11 files changed, 2609 insertions(+)
 create mode 100644 src/relay/collage/README.md
 create mode 100644 src/relay/collage/dataflow_graph.cc
 create mode 100644 src/relay/collage/dataflow_graph.h
 create mode 100644 src/relay/collage/index_set.cc
 create mode 100644 src/relay/collage/index_set.h
 create mode 100644 src/relay/collage/sub_graph.cc
 create mode 100644 src/relay/collage/sub_graph.h
 create mode 100644 src/relay/collage/utils.cc
 create mode 100644 src/relay/collage/utils.h
 create mode 100644 tests/python/relay/collage/test_sub_graph.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 46de8f5d07fa..8dc03ee0f40e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -296,6 +296,7 @@ tvm_file_glob(GLOB_RECURSE RELAY_OP_SRCS
 )
 tvm_file_glob(GLOB_RECURSE RELAY_PASS_SRCS
   src/relay/analysis/*.cc
+  src/relay/collage/*.cc
   src/relay/transforms/*.cc
   src/relay/quantize/*.cc
 )
diff --git a/src/relay/collage/README.md b/src/relay/collage/README.md
new file mode 100644
index 000000000000..dc56496092cc
--- /dev/null
+++ b/src/relay/collage/README.md
@@ -0,0 +1,26 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+The `CollagePartition` pass for finding optimal partitionings of Relay models.
+
+See the [RFC](https://github.com/mbs-octoml/mbs-tvm-rfcs/blob/mbs-rfcs-collage/rfcs/xxxx-collage.md).
+
+Based on:
+> *Collage: Automated Integration of Deep Learning Backends*
+> Byungsoo Jeon, Sunghyun Park, Peiyuan Liao, Sheng Xu, Tianqi Chen, Zhihao Jia
+
+CAUTION: This is a prototype, do not use in prod.
diff --git a/src/relay/collage/dataflow_graph.cc b/src/relay/collage/dataflow_graph.cc
new file mode 100644
index 000000000000..b4e19a73f04d
--- /dev/null
+++ b/src/relay/collage/dataflow_graph.cc
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/dataflow_graph.cc
+ * \brief A representation of the dataflow for an overall Relay expression.
+ */
+
+#include "./dataflow_graph.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+DataflowGraph::DataflowGraph(Expr expr) : expr_(std::move(expr)) {
+  indexed_graph_ = CreateIndexedGraph(expr_);
+  downstream_map_.reserve(indexed_graph_->size());
+  for (PostDfsIndex index = 0; index < indexed_graph_->size(); ++index) {
+    const Node* node = indexed_graph_->index_to_node(index);
+    std::unordered_set<const Node*> downstream_nodes;
+    node->AccumulateDownstreamNodes(&downstream_nodes);
+    IndexSet index_set(indexed_graph_->size());
+    for (const Node* downstream_node : downstream_nodes) {
+      index_set.Add(downstream_node->index_);
+    }
+    downstream_map_.emplace_back(std::move(index_set));
+  }
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/dataflow_graph.h b/src/relay/collage/dataflow_graph.h
new file mode 100644
index 000000000000..c3c22381a889
--- /dev/null
+++ b/src/relay/collage/dataflow_graph.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/dataflow_graph.h
+ * \brief A representation of the dataflow for an overall Relay expression.
+ */
+#ifndef TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
+#define TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
+
+#include <tvm/relay/expr.h>
+
+#include <memory>
+#include <vector>
+
+#include "../ir/indexed_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
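+/*!
+ * For orientation, a minimal usage sketch (illustrative only; the expression
+ * and index below are assumed for the example rather than supplied by this
+ * patch):
+ *
+ * \code
+ * Expr expr = ...;                   // some type-checked Relay expression
+ * DataflowGraph dataflow_graph(expr);
+ * PostDfsIndex index = 0;            // any index < dataflow_graph.size()
+ * // Iterate over all nodes whose value flows (directly or indirectly) from
+ * // the node with the given post-dfs index.
+ * for (PostDfsIndex downstream_index : dataflow_graph.downstream_of(index)) {
+ *   VLOG(1) << "node " << downstream_index << " is downstream of " << index;
+ * }
+ * \endcode
+ */
+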
+/*!
+ * \brief Represents the dataflow of an overall Relay expression.
+ */
+class DataflowGraph {
+ public:
+  using Node = IndexedGraph<Expr>::Node;
+
+  explicit DataflowGraph(Expr expr);
+
+  size_t size() const { return indexed_graph_->size(); }
+  const Node* index_to_node(PostDfsIndex index) const {
+    return indexed_graph_->index_to_node(index);
+  }
+  const Node* item_to_node(const Expr& expr) const { return indexed_graph_->item_to_node(expr); }
+  const Node* item_to_node(const ExprNode* expr_node) const {
+    return indexed_graph_->item_to_node(expr_node);
+  }
+  const Expr& expr() const { return expr_; }
+  const IndexedGraph<Expr>& indexed_graph() const { return *indexed_graph_; }
+
+  const IndexSet& downstream_of(PostDfsIndex index) const {
+    ICHECK_LT(index, indexed_graph_->size());
+    return downstream_map_[index];
+  }
+
+ private:
+  /*! \brief The overall expression. */
+  Expr expr_;
+  /*! \brief The indexed graph which captures the main dataflow. */
+  std::unique_ptr<IndexedGraph<Expr>> indexed_graph_;
+  /*! \brief Map from a node's PostDfsIndex to the set of its downstream dataflow node indexes. */
+  std::vector<IndexSet> downstream_map_;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
diff --git a/src/relay/collage/index_set.cc b/src/relay/collage/index_set.cc
new file mode 100644
index 000000000000..55bec80820a4
--- /dev/null
+++ b/src/relay/collage/index_set.cc
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/index_set.cc
+ * \brief Efficient representation of a set of post-dfs indexes.
+ */
+
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+// TODO(mbs): These should operate one-word-at-a-time
+
+IndexSet::IndexSet(size_t size, const std::vector<size_t>& indexes) : bitvec_(size, false) {
+  for (size_t index : indexes) {
+    ICHECK_LT(index, bitvec_.size());
+    ICHECK(!bitvec_[index]);
+    bitvec_[index] = true;
+  }
+}
+
+IndexSet IndexSet::operator&(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  std::vector<bool> result(bitvec_.size(), false);
+  for (size_t index = 0; index < bitvec_.size(); ++index) {
+    result[index] = bitvec_[index] && that.bitvec_[index];
+  }
+  return IndexSet(result);
+}
+
+IndexSet IndexSet::operator|(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  std::vector<bool> result(bitvec_.size(), false);
+  for (size_t index = 0; index < bitvec_.size(); ++index) {
+    result[index] = bitvec_[index] || that.bitvec_[index];
+  }
+  return IndexSet(result);
+}
+
+IndexSet IndexSet::operator-(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  std::vector<bool> result(bitvec_.size());
+  for (size_t index = 0; index < bitvec_.size(); ++index) {
+    result[index] = bitvec_[index] && !that.bitvec_[index];
+  }
+  return IndexSet(result);
+}
+
+bool IndexSet::AreDisjoint(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  for (size_t index = 0; index < bitvec_.size(); index++) {
+    if (bitvec_[index] && that.bitvec_[index]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool IndexSet::IsSubset(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  for (size_t index = 0; index < bitvec_.size(); index++) {
+    if (bitvec_[index] && !that.bitvec_[index]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool IndexSet::Intersects(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  for (size_t index = 0; index < bitvec_.size(); index++) {
+    if (bitvec_[index] && that.bitvec_[index]) {
+      return true;
+    }
+  }
+  return false;
+}
+
+IndexSet IndexSet::Subst(size_t new_size, const IndexSubst& subst) const {
+  std::vector<bool> result(new_size, false);
+  for (PostDfsIndex index = 0; index < bitvec_.size(); ++index) {
+    if (!bitvec_[index]) {
+      continue;
+    }
+    auto itr = subst.find(index);
+    ICHECK(itr != subst.end());
+    PostDfsIndex new_index = itr->second;
+    ICHECK(new_index < new_size);
+    ICHECK(!result[new_index]);
+    result[new_index] = true;
+  }
+  return IndexSet(result);
+}
+
+size_t IndexSet::PopCount() const {
+  size_t n = 0;
+  for (size_t index = 0; index < bitvec_.size(); index++) {
+    if (bitvec_[index]) {
+      ++n;
+    }
+  }
+  return n;
+}
+
+bool IndexSet::IsZero() const {
+  for (size_t index = 0; index < bitvec_.size(); index++) {
+    if (bitvec_[index]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+size_t IndexSet::FirstInsideIndex() const {
+  for (size_t index = 0; index < bitvec_.size(); index++) {
+    if (bitvec_[index]) {
+      return index;
+    }
+  }
+  return bitvec_.size();
+}
+
+size_t IndexSet::LastInsideIndex() const {
+  for (size_t i = bitvec_.size(); i > 0; i--) {
+    const size_t index = i - 1;
+    if (bitvec_[index]) {
+      return index;
+    }
+  }
+  return bitvec_.size();
+}
+
+size_t IndexSet::NextIndex(size_t index) const {
+  ICHECK_LT(index, bitvec_.size());
+  for (index++; index < bitvec_.size(); index++) {
+    if (bitvec_[index]) {
+      return index;
+    }
+  }
+  return bitvec_.size();
+}
+
+size_t IndexSet::FirstOutsideIndex() const {
+  for (size_t index = 0; index < bitvec_.size(); index++) {
+    if (!bitvec_[index]) {
+      return index;
+    }
+  }
+  return bitvec_.size();
+}
+
+bool IndexSet::operator==(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  return bitvec_ == that.bitvec_;
+}
+
+bool IndexSet::operator!=(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  return bitvec_ != that.bitvec_;
+}
+
+bool IndexSet::operator<(const IndexSet& that) const {
+  ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
+  for (size_t index = 0; index < bitvec_.size(); index++) {
+    if (bitvec_[index] && !that.bitvec_[index]) {
+      return true;
+    }
+    if (!bitvec_[index] && that.bitvec_[index]) {
+      return false;
+    }
+  }
+  return false;
+}
+
+size_t IndexSet::hash() const {
+  std::hash<std::vector<bool>> h;
+  return h(bitvec_);
+}
+
+std::string IndexSet::ToString() const {
+  std::ostringstream os;
+  os << "{";
+  bool first = true;
+  for (size_t start = 0; start < bitvec_.size(); /*no-op*/) {
+    if (!bitvec_[start]) {
+      ++start;
+      continue;
+    }
+    size_t end;
+    for (end = start + 1; end < bitvec_.size() && bitvec_[end]; ++end) {
+      /*no-op*/
+    }
+    if (first) {
+      first = false;
+    } else {
+      os << ",";
+    }
+    os << start;
+    if (end > start + 2) {
+      os << ".." << (end - 1);
+      start = end;
+    } else {
+      ++start;
+    }
+  }
+  os << "}";
+  return os.str();
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/index_set.h b/src/relay/collage/index_set.h
new file mode 100644
index 000000000000..f24b695cc76c
--- /dev/null
+++ b/src/relay/collage/index_set.h
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/index_set.h
+ * \brief Efficient representation of a set of post-dfs indexes.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_INDEX_SET_H_
+#define TVM_RELAY_COLLAGE_INDEX_SET_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+using IndexSubst = std::unordered_map<size_t, size_t>;
+
+class IndexSet {
+ public:
+  IndexSet() = default;
+  explicit IndexSet(size_t size) : bitvec_(size, false) {}
+  IndexSet(size_t size, const std::vector<size_t>& indexes);
+
+  IndexSet operator&(const IndexSet& that) const;
+  IndexSet operator|(const IndexSet& that) const;
+  IndexSet operator-(const IndexSet& that) const;
+  bool AreDisjoint(const IndexSet& that) const;
+  bool IsSubset(const IndexSet& that) const;
+  bool Intersects(const IndexSet& that) const;
+
+  bool operator[](size_t index) const {
+    ICHECK_LT(index, bitvec_.size());
+    return bitvec_[index];
+  }
+
+  IndexSet& Add(size_t index) {
+    ICHECK_LT(index, bitvec_.size());
+    bitvec_[index] = true;
+    return *this;
+  }
+
+  IndexSet Subst(size_t new_size, const IndexSubst& subst) const;
+
+  size_t end_index() const { return bitvec_.size(); }
+  size_t PopCount() const;
+  bool IsZero() const;
+  size_t FirstInsideIndex() const;
+  size_t LastInsideIndex() const;
+  size_t NextIndex(size_t index) const;
+  size_t FirstOutsideIndex() const;
+  bool operator==(const IndexSet& that) const;
+  bool operator!=(const IndexSet& that) const;
+  bool operator<(const IndexSet& that) const;
+  size_t hash() const;
+  std::string ToString() const;
+
+  struct IndexSetIterator {
+    const IndexSet* set;
+    size_t i;
+
+    size_t operator*() const {
+      ICHECK_LT(i, set->end_index());
+      return i;
+    }
+
+    const IndexSetIterator& operator++() {
+      ICHECK_LT(i, set->end_index());
+      i = set->NextIndex(i);
+      return *this;
+    }
+
+    bool operator==(const IndexSetIterator& that) const {
+      ICHECK(set == that.set);
+      return i == that.i;
+    }
+
+    bool operator!=(const IndexSetIterator& that) const {
+      ICHECK(set == that.set);
+      return i != that.i;
+    }
+  };
+
+  IndexSetIterator begin() const { return IndexSetIterator{this, FirstInsideIndex()}; }
+  IndexSetIterator end() const { return IndexSetIterator{this, end_index()}; }
+
+ private:
+  explicit IndexSet(std::vector<bool> bitvec) : bitvec_(std::move(bitvec)) {}
+
+  std::vector<bool> bitvec_;
+};
+
+struct IndexSetEqual {
+  bool operator()(const IndexSet& left, const IndexSet& right) const { return left == right; }
+};
+
+struct IndexSetHash {
+  size_t operator()(const IndexSet& set) const { return set.hash(); }
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_INDEX_SET_H_
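For readers skimming the patch, a small worked sketch of the `IndexSet` algebra above (illustrative only; the include path and the surrounding check macros are assumed from the codebase, not part of this patch):

```cpp
#include "./index_set.h"

using tvm::relay::collage::IndexSet;

void IndexSetExample() {
  // Sets are fixed-size bit vectors over post-dfs indexes 0..9.
  IndexSet a(/*size=*/10, /*indexes=*/{1, 3, 4, 5, 7});
  IndexSet b(/*size=*/10, /*indexes=*/{4, 8});
  IndexSet both = a & b;    // {4}
  IndexSet either = a | b;  // {1,3,4,5,7,8}
  IndexSet only_a = a - b;  // {1,3,5,7}
  ICHECK(!a.AreDisjoint(b));                     // they share index 4
  ICHECK(both.IsSubset(a) && both.IsSubset(b));
  ICHECK_EQ(either.ToString(), "{1,3..5,7,8}");  // runs of >= 3 are abbreviated
}
```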
diff --git a/src/relay/collage/sub_graph.cc b/src/relay/collage/sub_graph.cc
new file mode 100644
index 000000000000..63edc8c079fb
--- /dev/null
+++ b/src/relay/collage/sub_graph.cc
@@ -0,0 +1,1034 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include <tvm/relay/transform.h>
+
+#include "../../support/scalars.h"
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+class Extractor;
+
+/*!
+ * \brief Helper class for rewriting expressions to replace a sub-graph according to the
+ * given extractor.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  explicit Rewriter(const Extractor* extractor) : extractor_(extractor) {}
+
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Already prepared extractor which will guide the rewrite. */
+  const Extractor* extractor_;
+};
+
+/*! \brief Helper class for extracting matched sub-graphs from the overall expression. */
+class Extractor : public ExprMutator {
+ public:
+  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
+            FunctionAttrsMap opt_attrs)
+      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
+    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
+  }
+
+  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }
+
+  /*!
+   * \brief Collect the parameters and output expressions for the function representing
+   * the sub-graph.
+   */
+  void Extract() {
+    ICHECK(!sub_graph_->IsEmpty());
+    VLOG(2) << "Extracting " << sub_graph_->ToString();
+    const bool for_function = opt_attrs_.defined();
+
+    // In reverse dataflow order...
+    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
+      PostDfsIndex index = i - 1;
+      if (!sub_graph_->inside_[index]) {
+        // Node is outside sub-graph.
+        continue;
+      }
+      VLOG(2) << "index " << index;
+      auto node = dataflow_graph_->index_to_node(index);
+      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
+        // This sub-expression is:
+        //  - inside the sub-graph and needed outside the sub-graph. So it must contribute to an
+        //    output (even if we've already visited it while constructing an output from a
+        //    downstream sub-expression).
+        //  - not yet visited, in which case it must still be considered an 'output' so it will
+        //    be evaluated for any possible side effects.
+        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
+        VLOG(2) << "index " << index << " added as output:\n"
+                << PrettyPrint(output) << "\nat " << outputs_.size();
+        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
+        outputs_.emplace_back(std::move(output));
+        output_types_.emplace_back(node->node_ref_->checked_type());
+      }
+    }
+    ICHECK(!outputs_.empty());
+
+    // Reverse the outputs so as to preserve the original evaluation order.
+    std::reverse(outputs_.begin(), outputs_.end());
+    std::reverse(output_types_.begin(), output_types_.end());
+    for (auto& kv : expr_to_output_index_) {
+      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
+    }
+
+    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
+    // outputs we'll place them in a tuple.
+    Type body_type;
+    Expr body;
+    if (outputs_.size() > 1) {
+      body_type = TupleType(output_types_);
+      body = Tuple(outputs_);
+      body->checked_type_ = body_type;
+    } else {
+      body_type = output_types_.front();
+      body = outputs_.front();
+    }
+
+    // Re-express all the nested sub-graphs in terms of the body.
+    DataflowGraph body_dataflow_graph(body);
+    std::vector<NestedSubGraph> nested_sub_graphs;
+    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
+    for (const auto& nested_sub_graph : sub_graph_->nested_sub_graphs_) {
+      nested_sub_graphs.emplace_back(nested_sub_graph.Subst(body_dataflow_graph, subst));
+    }
+
+    // Sweep backwards through the body, rewriting to account for each nested sub-graph.
+    body = NestedSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(nested_sub_graphs));
+
+    if (for_function) {
+      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
+      Array<Type> arg_types;
+      arg_types.reserve(params_.size());
+      for (const auto& param : params_) {
+        arg_types.push_back(param->checked_type());
+      }
+      extracted_ = Function(std::move(params_), std::move(body), body_type,
+                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
+      extracted_->checked_type_ =
+          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
+      body = Call(extracted_, std::move(args_));
+      body->checked_type_ = body_type;
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+        expr->checked_type_ = output_types_[kv.second];
+      } else {
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined then this will be a function.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Returns a map from original index to new index for each node inside the sub-graph. Only
+   * valid after \p Extract has made its backwards dataflow sweep.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst subst;
+    for (PostDfsIndex index : sub_graph_->inside_) {
+      auto orig_node = dataflow_graph_->index_to_node(index);
+      ICHECK_EQ(orig_node->index_, index);
+      auto itr = memo_.find(orig_node->ref());
+      ICHECK(itr != memo_.end());
+      auto new_node = new_dataflow_graph.item_to_node(itr->second);
+      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
+      subst.emplace(orig_node->index_, new_node->index_);
+    }
+    return subst;
+  }
+
+  /*! \brief Returns true if \p expr is inside the sub-graph. */
+  bool inside(const Expr& expr) {
+    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
+  }
+
+  /*!
+   * \brief Returns the variable uniquely representing \p expr, which should be
+   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
+   *
+   * It is valid for:
+   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
+   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    auto itr = expr_to_param_.find(expr.get());
+    if (itr != expr_to_param_.end()) {
+      return itr->second;
+    }
+    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
+    fresh_var->checked_type_ = expr->checked_type();
+    params_.push_back(fresh_var);
+    args_.push_back(expr);
+    expr_to_param_.emplace(expr.get(), fresh_var);
+    return fresh_var;
+  }
+
+  /*!
+   * \brief If \p expr is inside the sub-graph then return its rewritten form.
+   * If \p expr is outside the sub-graph then it must correspond to an input node.
+   *  - If opt_attrs_ is defined return the variable to represent it.
+   *  - Otherwise just return the expression directly.
+   *
+   * Should be called only on inputs to nodes which are inside the sub-graph.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    } else if (CanInline(expr)) {
+      // Implicitly include inlinable input sub-expressions.
+      return expr;
+    } else if (opt_attrs_.defined()) {
+      // Map to a function parameter.
+      return VarFor(expr);
+    } else {
+      // Stop rewriting.
+      return expr;
+    }
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+    return ExprMutator::VisitExpr_(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression. */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief The sub-graph of the above we are extracting. */
+  const SubGraphNode* sub_graph_;
+  /*! \brief Optional attributes if the sub-graph should be extracted as a function. */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
+   * one exit node then each entry will be a tuple projection.
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now represent them.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (ie the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*! \brief (If opt_attrs_ is defined) Types of original expressions corresponding to outputs_. */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  auto itr = extractor_->output_substitution().find(expr.get());
+  if (itr == extractor_->output_substitution().end()) {
+    return ExprMutator::VisitExpr(expr);
+  } else {
+    return itr->second;
+  }
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  class Visitor : public ExprFunctor<std::pair<OpPatternKind, std::string>(const Expr&)> {
+   private:
+    std::pair<OpPatternKind, std::string> VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        auto op = GetRef<Op>(op_node);
+        static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+        if (fpattern.count(op) == 0) {
+          VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+          return {kOpaque, op->name};
+        } else if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+          VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+          return {kOpaque, op->name};
+        } else {
+          OpPatternKind kind = static_cast<OpPatternKind>(fpattern[op]);
+          VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+          return {kind, op->name};
+        }
+      } else if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        Optional<Integer> opt_i =
+            function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+        if (opt_i.defined()) {
+          OpPatternKind kind = static_cast<OpPatternKind>(opt_i.value()->value);
+          VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+          return {kind, "call_prim"};
+        } else {
+          VLOG(1) << "calling function without TOpPattern, considering opaque";
+          return {kOpaque, "call_fun"};
+        }
+      } else {
+        VLOG(1) << "unsupported call, considering opaque";
+        return {kOpaque, "call_any"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      if (support::IsSimpleScalar(constant_node)) {
+        return {kElemWise, "scalar"};
+      } else {
+        return {kElemWise, "const"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
+        VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+        return {kInjective, "tuple"};
+      } else {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(
+        const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
+        VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+        return {kInjective, "proj"};
+      } else {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
+    // sub-language we should revise the returned operator kinds to match.
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const VarNode* var_node) final {
+      return {kOpaque, "%" + var_node->name_hint()};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const GlobalVarNode* global_var_node) final {
+      return {kOpaque, "@" + global_var_node->name_hint};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const OpNode* op_node) final {
+      return {kOpaque, "`" + op_node->name};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const FunctionNode* function_node) final {
+      return {kOpaque, "fn"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const LetNode* let_node) final {
+      return {kOpaque, "let"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const IfNode* if_node) final {
+      return {kOpaque, "if"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefCreateNode* ref_create_node) final {
+      return {kOpaque, "ref"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefReadNode* op) final {
+      return {kOpaque, "ref_read"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefWriteNode* op) final {
+      return {kOpaque, "ref_write"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstructorNode* op) final {
+      return {kOpaque, "`" + op->name_hint};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const MatchNode* op) final {
+      return {kOpaque, "match"};
+    }
+  };
+  return Visitor().VisitExpr(sub_expr);
+}
+
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside) {
+  std::ostringstream os;
+  bool first = true;
+  OpPatternKind max_kind = kElemWise;
+  for (PostDfsIndex index : inside) {
+    OpPatternKind sub_kind;
+    std::string sub_label;
+    std::tie(sub_kind, sub_label) = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
+    if (!sub_label.empty()) {
+      if (first) {
+        first = false;
+      } else {
+        os << "+";
+      }
+      os << sub_label;
+    }
+    max_kind = CombineKinds(max_kind, sub_kind);
+  }
+  return {max_kind, os.str()};
+}
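+
+// For intuition, a small worked example (illustrative only; the expression and
+// indexes are assumed rather than part of this patch):
+//
+// \code
+// // %0 = nn.conv2d(%data, %weight);  // TOpPattern = kOutEWiseFusable
+// // %1 = nn.relu(%0);                // TOpPattern = kElemWise
+// DataflowGraph dataflow_graph(expr);
+// IndexSet inside(dataflow_graph.size(), {conv2d_index, relu_index});
+// auto kind_and_label = SubGraphKindAndLabel(dataflow_graph, inside);
+// // kind_and_label.first  == kOutEWiseFusable  (the maximum of the two kinds)
+// // kind_and_label.second == "nn.conv2d+nn.relu"
+// \endcode
+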
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
+  IndexSet result(matcher.size());
+  for (const auto& kv : matcher.memo()) {
+    for (const auto& matched_sub_expr : kv.second) {
+      if (CanInline(matched_sub_expr)) {
+        // Trivial sub-expressions can just be included in the extracted function body
+        // when we construct it and don't need to be considered part of the sub-graph.
+        continue;
+      }
+      if (kv.first.as<WildcardPatternNode>()) {
+        // Don't consider the expressions matched by a wildcard to be part of the sub-graph.
+        continue;
+      }
+      result.Add(matcher.expr_to_node(matched_sub_expr)->index_);
+    }
+  }
+  return result;
+}
+
+std::string SubGraphConfig::ToString() const {
+  std::ostringstream os;
+  os << "{max_exits=" << max_exits;
+  os << ", allow_taps=" << allow_taps;
+  os << ", max_depth=" << max_depth;
+  os << "}";
+  return os.str();
+}
+
+TVM_REGISTER_NODE_TYPE(NestedSubGraphNode);
+
+void NestedSubGraphNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+SubGraph NestedSubGraphNode::sub_graph() const { return Downcast<SubGraph>(sub_graph_obj_); }
+
+bool NestedSubGraphNode::operator==(const NestedSubGraphNode& that) const {
+  return *sub_graph().get() == *that.sub_graph().get();
+}
+
+bool NestedSubGraphNode::operator<(const NestedSubGraphNode& that) const {
+  return *sub_graph().get() < *that.sub_graph().get();
+}
+
+size_t NestedSubGraphNode::hash() const {
+  size_t h = StructuralHash()(attrs_);
+  h ^= sub_graph()->hash() + 0x9e3779b9 + (h << 6) + (h >> 2);
+  return h;
+}
+
+std::string NestedSubGraphNode::ToString() const {
+  std::ostringstream os;
+  os << "{sub_graph=" << sub_graph()->ToString();
+  os << ", attrs=" << PrettyPrint(attrs_);
+  os << "}";
+  return os.str();
+}
+
+Function NestedSubGraphNode::Extract(const DataflowGraph& dataflow_graph) const {
+  Extractor extractor(&dataflow_graph, sub_graph().get(), attrs_);
+  extractor.Extract();
+  return Downcast<Function>(extractor.extracted());
+}
+
+Expr NestedSubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const {
+  Extractor extractor(&dataflow_graph, sub_graph().get(), attrs_);
+  extractor.Extract();
+  Rewriter rewriter(&extractor);
+  return rewriter.VisitExpr(expr);
+}
+
+NestedSubGraph::NestedSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs) {
+  auto data = runtime::make_object<NestedSubGraphNode>();
+  data->sub_graph_obj_ = std::move(sub_graph);
+  data->attrs_ = std::move(attrs);
+  data_ = std::move(data);
+}
+
+NestedSubGraph NestedSubGraph::Subst(
+    const DataflowGraph& new_dataflow_graph,
+    const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const {
+  return NestedSubGraph(get()->sub_graph().Subst(new_dataflow_graph, subst), get()->attrs_);
+}
+
+bool NestedSubGraph::TriviallyUnionable(const NestedSubGraph& that) const {
+  if (get()->attrs_.size() != that->attrs_.size()) {
+    return false;
+  }
+  for (const auto& kv : get()->attrs_) {
+    if (kv.first == "Composite") {
+      // Even if all the attributes agree we don't consider "Composite" functions to
+      // ever be unionable.
+      // TODO(mbs): Find a cleaner way to do this.
+      return false;
+    }
+    auto itr = that->attrs_.find(kv.first);
+    if (itr == that->attrs_.end()) {
+      return false;
+    }
+    if (!StructuralEqual()(kv.second, (*itr).second)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+NestedSubGraph NestedSubGraph::DisjointUnion(const DataflowGraph& dataflow_graph,
+                                             const NestedSubGraph& that) const {
+  ICHECK(TriviallyUnionable(that));
+  return NestedSubGraph(get()->sub_graph().DisjointUnion(dataflow_graph, that->sub_graph()),
+                        get()->attrs_);
+}
+
+/*static*/
+Expr NestedSubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                                     std::vector<NestedSubGraph> nested_sub_graphs) {
+  // IMPORTANT: See the corresponding comment in SubGraph::ParallelRewrite.
+  std::sort(nested_sub_graphs.begin(), nested_sub_graphs.end(),
+            [](const NestedSubGraph& left, const NestedSubGraph& right) {
+              return left->sub_graph()->last_inside_index_ > right->sub_graph()->last_inside_index_;
+            });
+
+  Expr result = expr;
+  for (const auto& nested_sub_graph : nested_sub_graphs) {
+    result = nested_sub_graph->Rewrite(dataflow_graph, result);
+  }
+  return result;
+}
+
+TVM_REGISTER_NODE_TYPE(SubGraphNode);
+
+void SubGraphNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+IndexSet SubGraphNode::Downstream(const DataflowGraph& dataflow_graph) const {
+  IndexSet downstream(dataflow_graph.size());
+  for (PostDfsIndex exit_index : exit_) {
+    downstream = downstream | dataflow_graph.downstream_of(exit_index);
+  }
+  return downstream;
+}
+
+bool SubGraphNode::IsValid(const DataflowGraph& dataflow_graph,
+                           const SubGraphConfig& config) const {
+  // Check we don't have too many exit nodes.
+  if (config.max_exits > 0 && exit_.PopCount() > config.max_exits) {
+    VLOG(1) << "Subgraph " << ToString() << " is invalid: " << exit_.PopCount()
+            << " exits exceeds maximum " << config.max_exits;
+    return false;
+  }
+
+  // Check the maximum path depth is in limit.
+  if (config.max_depth > 0 && depth_ > config.max_depth) {
+    VLOG(1) << "Subgraph " << ToString() << " is invalid: maximum depth " << depth_
+            << " exceeds limit " << config.max_depth;
+    return false;
+  }
+
+  // All inside nodes must be in the same basic block.
+  const DataflowGraph::Node* basic_block = nullptr;
+  for (PostDfsIndex index : inside_) {
+    auto node = dataflow_graph.index_to_node(index);
+    if (basic_block == nullptr) {
+      basic_block = node->basic_block_;
+    }
+    if (node->basic_block_ != basic_block) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: nodes are from different basic blocks";
+      return false;
+    }
+  }
+
+  // The nested sub-graphs must be subsets and non-overlapping.
+  IndexSet union_inside(dataflow_graph.size());
+  for (const auto& nested_sub_graph : nested_sub_graphs_) {
+    if (!nested_sub_graph->sub_graph()->inside_.AreDisjoint(union_inside)) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: nested sub-graphs overlap";
+      return false;
+    }
+    if (!nested_sub_graph->sub_graph()->inside_.IsSubset(inside_)) {
+      VLOG(1) << "Subgraph " << ToString()
+              << " is invalid: nested sub-graph is not subset of overall sub-graph";
+      return false;
+    }
+    // Accumulate the union so the disjointness check above is against all earlier
+    // nested sub-graphs rather than the (initially empty) set.
+    union_inside = union_inside | nested_sub_graph->sub_graph()->inside_;
+  }
+
+  if (!config.allow_taps) {
+    // Exit nodes cannot also contribute to inside nodes.
+    for (PostDfsIndex index : exit_) {
+      auto node = dataflow_graph.index_to_node(index);
+      if (AnyOutputInside(node)) {
+        VLOG(1) << "Subgraph " << ToString()
+                << " is invalid: inner node is 'tapped' and also contributes to output, but taps "
+                   "are disabled";
+        return false;
+      }
+    }
+  }
+
+  // Check no output would end up feeding into any entry node.
+  for (PostDfsIndex output_index : output_) {
+    if (dataflow_graph.downstream_of(output_index).Intersects(entry_)) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: output node " << output_index
+              << " feeds back into this sub-graph";
+      return false;
+    }
+  }
+
+  // Looks legit!
+  return true;
+}
+
+Function SubGraphNode::ExtractAsFunction(const DataflowGraph& dataflow_graph) const {
+  NestedSubGraph nested_sub_graph(GetRef<SubGraph>(this), FunctionAttrsMap());
+  return nested_sub_graph->Extract(dataflow_graph);
+}
+
+Expr SubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const {
+  if (nested_sub_graphs_.empty()) {
+    // Nothing to rewrite.
+    return expr;
+  }
+  Extractor extractor(&dataflow_graph, this, NullValue<FunctionAttrsMap>());
+  extractor.Extract();
+  Rewriter rewriter(&extractor);
+  return rewriter.VisitExpr(expr);
+}
+
+std::string SubGraphNode::ToString() const {
+  std::ostringstream os;
+  os << "{inside=" << inside_.ToString();
+  os << ", entry=" << entry_.ToString();
+  os << ", exit=" << exit_.ToString();
+  os << ", input=" << input_.ToString();
+  os << ", output=" << output_.ToString();
+  os << ", depth=" << depth_;
+  os << ", kind=" << KindToString(kind_);
+  if (!label_.empty()) {
+    os << ", label=" << label_;
+  }
+  for (const auto& nested_sub_graph : nested_sub_graphs_) {
+    os << ", nested_sub_graph=" << nested_sub_graph->ToString();
+  }
+  os << "}";
+  return os.str();
+}
+
+bool SubGraphNode::operator==(const SubGraphNode& that) const {
+  ICHECK_EQ(inside_.end_index(), that.inside_.end_index());
+  if (inside_ != that.inside_) {
+    return false;
+  }
+  if (nested_sub_graphs_.size() != that.nested_sub_graphs_.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < nested_sub_graphs_.size(); ++i) {
+    if (*nested_sub_graphs_[i].get() != *that.nested_sub_graphs_[i].get()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool SubGraphNode::operator<(const SubGraphNode& that) const {
+  if (first_inside_index_ < that.first_inside_index_) {
+    return true;
+  }
+  if (that.first_inside_index_ < first_inside_index_) {
+    return false;
+  }
+  return inside_ < that.inside_;
+}
+
+size_t SubGraphNode::hash() const {
+  size_t h = inside_.hash();
+  for (const auto& nested_sub_graph : nested_sub_graphs_) {
+    h ^= nested_sub_graph->hash() + 0x9e3779b9 + (h << 6) + (h >> 2);
+  }
+  return h;
+}
+
+void SubGraphNode::Init(const DataflowGraph& dataflow_graph) {
+  for (PostDfsIndex index = 0; index < inside_.end_index(); ++index) {
+    auto node = dataflow_graph.index_to_node(index);
+    if (inside_[index]) {
+      if (AnyInputOutside(node)) {
+        entry_.Add(index);
+      }
+      if (AnyOutputOutside(node) || node->is_external_) {
+        exit_.Add(index);
+      }
+    } else {
+      if (AnyInputInside(node)) {
+        output_.Add(index);
+      }
+      if (AnyOutputInside(node) && !CanInline(node->ref())) {
+        input_.Add(index);
+      }
+    }
+  }
+  depth_ = Depth(dataflow_graph);
+}
+
+size_t SubGraphNode::Depth(const DataflowGraph& dataflow_graph) const {
+  std::unordered_map<const DataflowGraph::Node*, size_t> max_depths;
+  std::vector<const DataflowGraph::Node*> stack;
+  size_t max_depth = 0;
+  // All the entry nodes have max depth 0.
+  for (PostDfsIndex index : entry_) {
+    auto node = dataflow_graph.index_to_node(index);
+    max_depths.emplace(node, 0);
+    stack.push_back(node);
+  }
+  while (!stack.empty()) {
+    const DataflowGraph::Node* node = stack.back();
+    stack.pop_back();
+    size_t next_depth = max_depths[node] + 1;
+    if (exit_[node->index_]) {
+      // If this node is external then it will have no outputs but we still wish to consider
+      // the path to the implied output as requiring one more step.
+      // Otherwise we're accounting for reaching one of the external outputs below.
+      max_depth = std::max(max_depth, next_depth);
+    }
+    for (const DataflowGraph::Node* output_node : node->outputs_) {
+      if (!inside_[output_node->index_]) {
+        continue;
+      }
+      if (max_depths.count(output_node) == 0) {
+        max_depths.emplace(output_node, next_depth);
+        stack.push_back(output_node);
+      } else if (next_depth > max_depths[output_node]) {
+        // We found a deeper path to an already expanded node. We'll expand again.
+        max_depths[output_node] = next_depth;
+        stack.push_back(output_node);
+      }
+    }
+  }
+  return max_depth;
+}
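+
+// For intuition (illustrative only): for a diamond a -> b, a -> c, b -> d, c -> d
+// entirely inside the sub-graph, with a the only entry and d an exit, the sweep
+// above assigns max depths a=0, b=c=1, d=2, and the exit d contributes one more
+// step for its implied output, so Depth() returns 3.
+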
+/*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */
+bool SubGraphNode::AnyInputOutside(const DataflowGraph::Node* node) const {
+  return std::any_of(node->inputs_.begin(), node->inputs_.end(),
+                     [this](const DataflowGraph::Node* sub_node) {
+                       return !inside_[sub_node->index_] && !CanInline(sub_node->ref());
+                     });
+}
+
+bool SubGraphNode::AnyInputInside(const DataflowGraph::Node* node) const {
+  return std::any_of(
+      node->inputs_.begin(), node->inputs_.end(),
+      [this](const DataflowGraph::Node* sub_node) { return inside_[sub_node->index_]; });
+}
+
+bool SubGraphNode::AnyOutputOutside(const DataflowGraph::Node* node) const {
+  return std::any_of(
+      node->outputs_.begin(), node->outputs_.end(),
+      [this](const DataflowGraph::Node* sub_node) { return !inside_[sub_node->index_]; });
+}
+
+bool SubGraphNode::AnyOutputInside(const DataflowGraph::Node* node) const {
+  return std::any_of(
+      node->outputs_.begin(), node->outputs_.end(),
+      [this](const DataflowGraph::Node* sub_node) { return inside_[sub_node->index_]; });
+}
+
+SubGraph::SubGraph(const DataflowGraph& dataflow_graph, IndexSet inside, OpPatternKind kind,
+                   String label, std::vector<NestedSubGraph> nested_sub_graphs) {
+  std::sort(nested_sub_graphs.begin(), nested_sub_graphs.end(),
+            [](const NestedSubGraph& left, const NestedSubGraph& right) {
+              return *left.get() < *right.get();
+            });
+  auto node = runtime::make_object<SubGraphNode>();
+  node->inside_ = std::move(inside);
+  node->first_inside_index_ = node->inside_.FirstInsideIndex();
+  node->last_inside_index_ = node->inside_.LastInsideIndex();
+  node->entry_ = IndexSet(node->inside_.end_index());
+  node->exit_ = IndexSet(node->inside_.end_index());
+  node->input_ = IndexSet(node->inside_.end_index());
+  node->output_ = IndexSet(node->inside_.end_index());
+  node->kind_ = kind;
+  node->label_ = std::move(label);
+  node->nested_sub_graphs_ = nested_sub_graphs;
+  node->Init(dataflow_graph);
+  data_ = std::move(node);
+}
+
+SubGraph::SubGraph(const DataflowGraph& dataflow_graph)
+    : SubGraph(dataflow_graph, IndexSet(dataflow_graph.size())) {}
+
+bool SubGraph::AreDisjoint(const SubGraph& that) const {
+  return get()->inside_.AreDisjoint(that->inside_);
+}
+
+namespace {
+/*! \brief Returns true if an output of \p left not in \p right ultimately flows into \p right. */
+bool FlowsInto(const DataflowGraph& dataflow_graph, const SubGraph& left, const SubGraph& right) {
+  for (PostDfsIndex output_index : left->output_) {
+    if (!right->inside_[output_index] &&
+        dataflow_graph.downstream_of(output_index).Intersects(right->entry_)) {
+      return true;
+    }
+  }
+  return false;
+}
+}  // namespace
+
+bool SubGraph::AreTouching(const DataflowGraph& dataflow_graph, const SubGraph& that) const {
+  if (!get()->inside_.AreDisjoint(that->inside_)) {
+    // Easy rejection.
+    return false;
+  }
+  if (!get()->output_.Intersects(that->entry_)) {
+    // Not touching.
+    return false;
+  }
+  if (FlowsInto(dataflow_graph, *this, that) || FlowsInto(dataflow_graph, that, *this)) {
+    // Unioning would create a cycle.
+    return false;
+  }
+  return true;
+}
+
+bool SubGraph::AreSelfContained(const SubGraph& that) const {
+  return get()->output_.IsSubset(that->entry_) && that->input_.IsSubset(get()->exit_);
+}
+
+SubGraph SubGraph::DisjointUnion(const DataflowGraph& dataflow_graph, const SubGraph& that) const {
+  ICHECK(AreDisjoint(that));
+  IndexSet inside = get()->inside_ | that->inside_;
+  std::vector<NestedSubGraph> nested_sub_graphs;
+  for (const auto& nested_sub_graph : get()->nested_sub_graphs_) {
+    nested_sub_graphs.push_back(nested_sub_graph);
+  }
+  for (const auto& nested_sub_graph : that->nested_sub_graphs_) {
+    auto existing_itr = std::find_if(nested_sub_graphs.begin(), nested_sub_graphs.end(),
+                                     [&nested_sub_graph](const NestedSubGraph& existing) {
+                                       return existing.TriviallyUnionable(nested_sub_graph);
+                                     });
+    if (existing_itr != nested_sub_graphs.end()) {
+      *existing_itr = existing_itr->DisjointUnion(dataflow_graph, nested_sub_graph);
+    } else {
+      nested_sub_graphs.push_back(nested_sub_graph);
+    }
+  }
+  return SubGraph(dataflow_graph, std::move(inside), CombineKinds(get()->kind_, that->kind_),
+                  UnionLabels(get()->label_, that->label_), std::move(nested_sub_graphs));
+}
+
+SubGraph SubGraph::WithAttrs(const DataflowGraph& dataflow_graph, FunctionAttrsMap attrs) const {
+  std::vector<NestedSubGraph> nested_sub_graphs;
+  nested_sub_graphs.push_back(NestedSubGraph(*this, attrs));
+  return SubGraph(dataflow_graph, get()->inside_, get()->kind_, get()->label_,
+                  std::move(nested_sub_graphs));
+}
+
+SubGraph SubGraph::Subst(const DataflowGraph& new_dataflow_graph, const IndexSubst& subst) const {
+  IndexSet new_inside = get()->inside_.Subst(new_dataflow_graph.size(), subst);
+  std::vector<NestedSubGraph> new_nested_sub_graphs;
+  for (const auto& nested_sub_graph : get()->nested_sub_graphs_) {
+    new_nested_sub_graphs.push_back(nested_sub_graph.Subst(new_dataflow_graph, subst));
+  }
+  return SubGraph(new_dataflow_graph, std::move(new_inside), get()->kind_, get()->label_,
+                  std::move(new_nested_sub_graphs));
+}
+
+/*static*/
+Expr SubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph,
+                               std::vector<SubGraph> sub_graphs) {
+  // IMPORTANT:
+  //  - All the sub-graphs will be w.r.t. the dataflow graph for the original expression.
+  //    Each time we call Rewrite on one of those graphs the result expression will be rewritten
+  //    from the final output back to the inputs. The inputs will then be shared with the original
+  //    expression. Thus it is safe to iteratively rewrite all the sub-graphs without redoing the
+  //    dataflow_graph and substituting indexes provided we work in reverse dataflow order.
+  //  - We rely on the dataflow_graph expression reference holding the original expression alive
+  //    so that the dataflow_graph will never contain dangling pointers (even though as per above
+  //    we'll never dereference them).
+  std::sort(sub_graphs.begin(), sub_graphs.end(), [](const SubGraph& left, const SubGraph& right) {
+    return left->last_inside_index_ > right->last_inside_index_;
+  });
+  Expr result = dataflow_graph.expr();
+  for (const auto& sub_graph : sub_graphs) {
+    result = sub_graph->Rewrite(dataflow_graph, result);
+  }
+  return result;
+}
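+
+// A minimal composition sketch (illustrative only; sub_graph_a and sub_graph_b
+// are assumed to be disjoint sub-graphs over the same dataflow_graph):
+//
+// \code
+// SubGraph unioned = sub_graph_a.DisjointUnion(dataflow_graph, sub_graph_b);
+// FunctionAttrsMap attrs;
+// attrs.Set("Compiler", String("example_target"));  // hypothetical target name
+// SubGraph wrapped = unioned.WithAttrs(dataflow_graph, attrs);
+// Expr rewritten = SubGraph::ParallelRewrite(dataflow_graph, {wrapped});
+// \endcode
+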
+/*!
+ * \brief A pass which partitions (the unique) global function in the module according to the
+ * post-dfs indexes in \p indexes. The partitioning must respect the configuration with \p max_exits
+ * and \p allow_taps.
+ *
+ * Each index is also paired with a label. A non-empty label denotes the index should also be
+ * included in a nested sub-graph which will be extracted as a function with the label as its
+ * "Composite" attribute. An empty label denotes the index should go into the overall partitioned
+ * "Compiler" function. In this way we can simulate the usual partitioning needed by external
+ * codegen integrations.
+ *
+ * This function is intended to support \p SubGraph unit tests and is not used by the regular
+ * compilation flow.
+ */
+transform::Pass PartitionForTesting(Integer max_exits, Bool allow_taps, String compiler,
+                                    Array<Integer> indexes, Array<String> labels) {
+  auto pass_func = [=](Function function, IRModule mod, transform::PassContext ctxt) {
+    ICHECK(max_exits.defined() && max_exits->value >= 0);
+    ICHECK(allow_taps.defined());
+    ICHECK(indexes.size() == labels.size());
+    VLOG(1) << "Partitioning:" << std::endl << PrettyPrint(function);
+    DataflowGraph dataflow_graph(function);
+    VLOG(1) << "Dataflow graph is:" << std::endl << dataflow_graph.indexed_graph().ToString();
+
+    // Collect the 'inside' indexes and any nested sub-graph indexes and labels.
+    std::vector<PostDfsIndex> node_indexes;
+    std::unordered_map<String, std::vector<PostDfsIndex>> nested_sub_graph_indexes;
+    node_indexes.reserve(indexes.size());
+    for (size_t i = 0; i < indexes.size(); ++i) {
+      const Integer& index = indexes[i];
+      ICHECK_GE(index->value, 0);
+      ICHECK_LT(index->value, dataflow_graph.size());
+      auto index_int = static_cast<PostDfsIndex>(index->value);
+      node_indexes.push_back(index_int);
+      const String& label = labels[i];
+      if (!label.empty()) {
+        nested_sub_graph_indexes[label].push_back(index_int);
+      }
+    }
+
+    // Build the nested sub-graphs representing the "Composite" functions (if any).
+    std::vector<NestedSubGraph> nested_sub_graphs;
+    for (const auto& kv : nested_sub_graph_indexes) {
+      FunctionAttrsMap composite_attrs;
+      composite_attrs.Set("Composite", kv.first);
+      nested_sub_graphs.emplace_back(
+          SubGraph(dataflow_graph, IndexSet(dataflow_graph.size(), kv.second)), composite_attrs);
+    }
+
+    // Build the overall sub-graph, which will include any "Composite" functions as
+    // well as any nodes without a label.
+    IndexSet inside(dataflow_graph.size(), node_indexes);
+    OpPatternKind kind;
+    String label;
+    std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside);
+    SubGraph sub_graph(dataflow_graph, inside, kind, label, std::move(nested_sub_graphs));
+
+    // Push the overall sub-graph into the final "Compiler" function.
+    FunctionAttrsMap compiler_attrs;
+    compiler_attrs.Set("Compiler", compiler);
+    NestedSubGraph overall_nested_sub_graph(sub_graph, compiler_attrs);
+    SubGraph overall_sub_graph(dataflow_graph, inside, kind, label, {overall_nested_sub_graph});
+
+    // Check the sub-graph is valid.
+    SubGraphConfig config;
+    config.max_exits = static_cast<size_t>(max_exits->value);
+    config.allow_taps = allow_taps;
+    if (overall_sub_graph->IsValid(dataflow_graph, config)) {
+      VLOG(1) << "Sub-graph " << overall_sub_graph->ToString() << " is considered valid";
+    } else {
+      VLOG(1) << "Sub-graph " << overall_sub_graph->ToString()
+              << " is NOT considered valid, not partitioning";
+      return function;
+    }
+
+    // Do the partitioning.
+    Function result = Downcast<Function>(overall_sub_graph->Rewrite(dataflow_graph, function));
+    VLOG(1) << "Extracted as:" << std::endl << PrettyPrint(result);
+
+    return result;
+  };
+  return transform::CreateFunctionPass(pass_func, /*opt_level=*/0, "PartitionForTesting", {});
+}
+
+TVM_REGISTER_GLOBAL("relay.collage.PartitionForTesting").set_body_typed(PartitionForTesting);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
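The registered global above is exercised from the Python tests added by this patch; from C++ the pass can also be driven directly. A minimal sketch (illustrative only; the module and the chosen indexes are assumptions):

```cpp
// Assumes an IRModule 'mod' whose "main" function has already been type-checked.
transform::Pass pass = PartitionForTesting(
    /*max_exits=*/Integer(1), /*allow_taps=*/Bool(false),
    /*compiler=*/"example_target",             // hypothetical target name
    /*indexes=*/{Integer(3), Integer(4)},      // post-dfs indexes to partition
    /*labels=*/{String(""), String("")});      // empty => no "Composite" wrapper
IRModule partitioned = pass(mod);
```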
diff --git a/src/relay/collage/sub_graph.h b/src/relay/collage/sub_graph.h
new file mode 100644
index 000000000000..f7d4354d5483
--- /dev/null
+++ b/src/relay/collage/sub_graph.h
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring its nested sub expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expression matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Maximum allowed sub-graph depth, or zero if no limit.
+   */
+  size_t max_depth = 0;
+
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A nested sub-graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result.
+ * Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class NestedSubGraphNode : public Object {
+ public:
+  /*! \brief The nested sub-graph. */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  SubGraph sub_graph() const;
+
+  bool operator==(const NestedSubGraphNode& that) const;
+  bool operator!=(const NestedSubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const NestedSubGraphNode& that) const;
+  size_t hash() const;
+
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this nested sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   * - All sub-graph inputs become parameters.
+   * - All sub-graph outputs become function results (either directly or as a field in a tuple).
+   * - The function has attrs_ for attributes (which may be empty).
+   * - The function body accounts for any rewrites implied by the nested sub-graph.
+   */
+  Function Extract(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this nested sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this nested sub-graph must correspond to nodes shared between \p dataflow_graph.expr()
+   * and \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  static constexpr const char* _type_key = "relay.collage.NestedSubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(NestedSubGraphNode, Object);
+};
+
+class NestedSubGraph : public ObjectRef {
+ public:
+  NestedSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs);
+
+  /*!
+   * \brief Returns copy of this nested sub-graph with all indexes substituted according to
+   * \p subst, whose range is w.r.t. \p new_dataflow_graph.
+   */
+  NestedSubGraph Subst(const DataflowGraph& new_dataflow_graph,
+                       const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
+
+  /*!
+   * \brief Returns true if this can be safely unioned with \p that.
+   */
+  bool TriviallyUnionable(const NestedSubGraph& that) const;
+
+  /*!
+   * \brief Returns the disjoint union of this and \p that nested sub-graphs, which must agree on
+   * their attributes.
+   */
+  NestedSubGraph DisjointUnion(const DataflowGraph& dataflow_graph,
+                               const NestedSubGraph& that) const;
+
+  /*!
+   * \brief Returns \p expr rewritten according to all the given nested sub-graphs. The
+   * nested sub-graphs can be given in any order, but must be disjoint.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside the nested sub-graphs must correspond to nodes shared between \p dataflow_graph.expr()
+   * and \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                              std::vector<NestedSubGraph> nested_sub_graphs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(NestedSubGraph, ObjectRef, NestedSubGraphNode);
+};
+
+using NestedSubGraphs = Array<NestedSubGraph>;
+
+/*!
+ * \brief A compact representation of a sub-graph within an (implied) overall Relay expression.
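+ *
+ * (For example: for subtract(add(%a, %b), add(%c, %d)) the sub-graph holding just the two
+ * adds is represented by their two post-dfs indexes alone; no Relay expression is built
+ * unless the sub-graph is later extracted or rewritten.)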
+ *
+ * Sub-graphs can be used to represent partitions/kernels/composite functions without having to
+ * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a
+ * function to use for measuring a partition/kernel's latency independently from 'rewriting'
+ * the overall Relay expression since only a tiny subset of candidate partitions will end up being
+ * needed after Collage has completed its search.
+ *
+ * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so we are
+ * mindful of space overhead.
+ *
+ * A sub-graph classifies every dataflow node of the overall expression as either 'inside' or
+ * 'outside' the sub-graph. Obviously not all such divisions make sense, for example it is not
+ * valid for an inside node to feed into another inside node via outside nodes. We provide the
+ * \p IsValid method to check for validity, and \p SubGraphConfig to control which validity rules
+ * apply (such as maximum depth).
+ *
+ * We generally work with the \p DataflowGraph representation of the overall Relay expression
+ * rather than the expression itself. We use the post-dfs visit index to uniquely refer to
+ * expression nodes.
+ *
+ * As well as 'inside' and 'outside' we have four other flavors of dataflow nodes, all uniquely
+ * determined from the 'inside' nodes:
+ * - 'entry' nodes are those inside with at least one dataflow input outside.
+ * - 'exit' nodes are those inside with at least one dataflow output outside, or which
+ *   are considered 'external' in the underlying dataflow graph (eg because they represent
+ *   the result of the overall function).
+ * - 'input' nodes are those outside with at least one dataflow output inside.
+ * - 'output' nodes are those outside with at least one dataflow input inside.
+ * Index sets for these are cached with the sub-graph for performance.
+ *
+ * It is valid to have multiple entry nodes (we can bind a parameter for each). It may be valid to
+ * have multiple exit nodes (we can build a tuple of all such). It may be valid to have exit nodes
+ * which also contribute to other inside nodes (ie represent a 'tap' on an intermediate result).
+ *
+ * Sub-graphs are closed under:
+ * - Disjoint union.
+ * - Wrapping by a function with given attributes (see \p NestedSubGraph above). This can be used
+ *   to encode "Composite" functions, or to represent a candidate kernel within a "Primitive"
+ *   function. (By combining 'wrapping' with 'union' we can encode, eg, 'this sub-graph should
+ *   be placed inside a primitive function which itself may have calls to composite functions'.)
+ * - Substitution, which allows a sub-graph w.r.t. one dataflow graph to be transformed to
+ *   match some other (typically smaller) dataflow graph.
+ *
+ * See the subclasses of \p PartitionRule for how sub-graphs are built and combined during Collage
+ * search.
+ *
+ * To support some of the \p OpPatternKind-based fusion rule processing we give sub-graphs
+ * a kind, which is generally the maximum of the kinds of all the operator calls appearing
+ * inside it. We also give sub-graphs a (not necessarily unique) label to help with debugging
+ * and guide the selection of global symbol names.
+ */
+class SubGraphNode : public Object {
+ public:
+  /*!
+   * \brief Which sub-expressions are inside the sub-graph (using their post-dfs indexes w.r.t.
+   * the implied DataflowGraph).
+   */
+  IndexSet inside_;
+
+  /*!
+   * \brief Index of first and last inside nodes.
+   *
+   * Cached for performance, uniquely determined by inside_.
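+   * (Ie the minimum and maximum post-dfs indexes over all nodes in inside_.)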
+   */
+  PostDfsIndex first_inside_index_ = 0;
+  PostDfsIndex last_inside_index_ = 0;
+
+  /*!
+   * \brief Which sub-expressions are entry/exit/input/output for this sub-graph.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  IndexSet entry_;
+  IndexSet exit_;
+  IndexSet input_;
+  IndexSet output_;
+
+  /*!
+   * \brief Maximum depth of any dataflow path from an entry to an output sub-expression.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  size_t depth_ = 0;
+
+  /*!
+   * \brief The \p OpPatternKind summarizing the input/output behavior of the sub-graph.
+   *
+   * A sub-graph consisting of a single Relay expression node is given kind:
+   * - For Call to a Relay operator, the "TOpPattern" attribute of that operator (provided the
+   *   call does not involve data-dependent dynamic shapes).
+   * - For Call to a Relay Function, the "TOpPattern" attribute of the function (provided it has
+   *   that attribute).
+   * - For Constants, \p kElemWise.
+   * - For Tuple and tuple projections, \p kInjective (provided all tuple fields are of tensor
+   *   type).
+   * - For all other nodes, \p kOpaque.
+   * Sub-graphs with more than one node have the maximum of the kind of each node.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  OpPatternKind kind_ = kOpaque;
+
+  /*!
+   * \brief A label for the sub-graph. Not guaranteed to be unique, but is a human-readable summary
+   * of the sub-graph which can help with debugging and guide the selection of global symbol names.
+   */
+  String label_;
+
+  /*!
+   * \brief Nested sub-graphs of this sub-graph which must be represented by functions. These must
+   * be disjoint, but it's ok for this sub-graph to have nodes not inside any nested sub-graph.
+   */
+  NestedSubGraphs nested_sub_graphs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  // TODO(mbs): 'Anchor nodes' and rules for unioning them.
+  // In FuseOps it's just the unique kEWiseFusable node, if any.
+  // I'd like to allow writing vertical fusion rules, eg if two candidates are directly
+  // connected and have nn.conv2d anchors allow their join.
+  // I'd also like to allow horizontal fusion rules, eg if two candidates are not directly
+  // connected but could be joined without producing an invalid (eg cyclic) sub-graph and have
+  // nn.conv2d anchors then do so. Come back to this.
+
+  /*! \brief Number of nodes in overall dataflow graph. */
+  size_t overall_size() const { return inside_.end_index(); }
+
+  bool IsEmpty() const { return inside_.IsZero(); }
+
+  /*! \brief Number of nodes in sub-graph. */
+  size_t Size() const { return inside_.PopCount(); }
+
+  /*!
+   * \brief Returns the dataflow nodes downstream of all exit nodes.
+   */
+  IndexSet Downstream(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns true if this sub-graph is valid. Ie:
+   * - no output of the sub-graph can flow to any input of the sub-graph (otherwise we'd end up
+   *   with a dataflow cycle when we partition).
+   * - all inputs and outputs of the sub-graph are in the same scope, ie not separated by
+   *   control flow (otherwise there'd be no consistent program point at which to eval the
+   *   partitioned function).
+   * - no more than config.max_exits exit nodes are required.
+   * - if config.allow_taps is false, no inside node has outputs to nodes both inside and
+   *   outside the sub-graph.
+   */
+  bool IsValid(const DataflowGraph& dataflow_graph, const SubGraphConfig& config) const;
+
+  /*!
+   * \brief Returns this sub-graph extracted as a stand-alone function.
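+   * (For example, a sub-graph covering both adds in subtract(add(%a, %b), add(%c, %d))
+   * extracts as a function of four parameters returning a tuple of the two sums, as in
+   * test_multi_output in test_sub_graph.py.)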
+   * The function will have no attributes, and is suitable for building and profiling
+   * by the \p CostEstimator.
+   */
+  Function ExtractAsFunction(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  std::string ToString() const;
+
+  bool operator==(const SubGraphNode& that) const;
+  bool operator!=(const SubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubGraphNode& that) const;
+  size_t hash() const;
+
+ private:
+  /*! \brief Initialize the entry/exit/input/output sets given the inside and \p dataflow_graph. */
+  void Init(const DataflowGraph& dataflow_graph);
+
+  /*! \brief Calculates and returns the maximum path depth. */
+  size_t Depth(const DataflowGraph& dataflow_graph) const;
+
+  /*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */
+  bool AnyInputOutside(const DataflowGraph::Node* node) const;
+  bool AnyInputInside(const DataflowGraph::Node* node) const;
+  bool AnyOutputOutside(const DataflowGraph::Node* node) const;
+  bool AnyOutputInside(const DataflowGraph::Node* node) const;
+
+ public:
+  static constexpr const char* _type_key = "relay.collage.SubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubGraphNode, Object);
+
+  friend class SubGraph;
+};
+
+class SubGraph : public ObjectRef {
+ public:
+  /*! \brief Primitive constructor. The following constructors are generally more convenient. */
+  SubGraph(const DataflowGraph& dataflow_graph, IndexSet inside, OpPatternKind kind = kOpaque,
+           String label = {}, std::vector<NestedSubGraph> nested_sub_graphs = {});
+
+  /*! \brief Constructs the empty sub-graph for \p dataflow_graph. */
+  explicit SubGraph(const DataflowGraph& dataflow_graph);
+
+  /*! \brief Returns true if this and that are disjoint. */
+  bool AreDisjoint(const SubGraph& that) const;
+
+  /*!
+   * \brief Returns true if:
+   * - \p this and \p that are disjoint, and
+   * - an output node of \p this coincides with an entry node of \p that, and
+   * - \p this and \p that are not obviously invalid after \p DisjointUnion
+   *   (eg because such a sub-graph would produce a cycle).
+   * Note however that the \p DisjointUnion may not necessarily be valid even with the above
+   * checks.
+   */
+  bool AreTouching(const DataflowGraph& dataflow_graph, const SubGraph& that) const;
+
+  /*!
+   * \brief Returns true if:
+   * - all the outputs of \p this are entries for \p that, and
+   * - all the inputs of \p that are exits for \p this.
+   */
+  bool AreSelfContained(const SubGraph& that) const;
+
+  /*!
+   * \brief Returns disjoint union of this and \p that sub-graphs. The result may not be valid.
+   */
+  SubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubGraph& that) const;
+
+  /*!
+   * \brief Returns copy of this sub-graph with all nodes placed inside a nested sub-graph with
+   * given attributes.
+   */
+  SubGraph WithAttrs(const DataflowGraph& dataflow_graph, FunctionAttrsMap attrs) const;
+
+  /*!
+   * \brief Returns copy of this sub-graph with all indexes substituted according to \p subst,
+   * whose range is w.r.t. \p new_dataflow_graph.
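+   * (Substitution is what lets a sub-graph built against the overall model's dataflow graph
+   * be re-expressed w.r.t. the typically much smaller dataflow graph of an extracted
+   * function.)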
+   */
+  SubGraph Subst(const DataflowGraph& new_dataflow_graph,
+                 const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
+
+  /*!
+   * \brief Returns the root expression of \p dataflow_graph rewritten according to all the
+   * given sub-graphs. The sub-graphs can be given in any order, but must be disjoint.
+   */
+  static Expr ParallelRewrite(const DataflowGraph& dataflow_graph,
+                              std::vector<SubGraph> sub_graphs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SubGraph, ObjectRef, SubGraphNode);
+};
+
+struct SubGraphEqual {
+  bool operator()(const SubGraph& left, const SubGraph& right) const {
+    return *left.get() == *right.get();
+  }
+};
+
+struct SubGraphHash {
+  size_t operator()(const SubGraph& sub_graph) const { return sub_graph->hash(); }
+};
+
+/*!
+ * \brief Pass to partition every global function according to the post-dfs indexes
+ * given in an array. Visible for testing from Python only, would never make sense to use
+ * as a generic pass!
+ */
+tvm::transform::Pass PartitionOnIndexesForTesting(Array<Integer> indexes);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_SUB_GRAPH_H_
diff --git a/src/relay/collage/utils.cc b/src/relay/collage/utils.cc
new file mode 100644
index 000000000000..03af980e8c1d
--- /dev/null
+++ b/src/relay/collage/utils.cc
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/utils.cc
+ * \brief Misc helpers.
+ */
+
+#include "./utils.h"
+
+#include "../../support/scalars.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+String GetSpecName(const Target& target) {
+  if (TargetKind::GetAttrMap<Bool>(tvm::attr::kIsExternalCodegen).get(target->kind, Bool(false))) {
+    return target->kind->name;
+  } else {
+    return std::string(kTVMSpecNamePrefix) + target->kind->name;
+  }
+}
+
+String UnionLabels(String left, String right) {
+  if (left.empty()) {
+    return right;
+  }
+  if (right.empty()) {
+    return left;
+  }
+  return left + "+" + right;
+}
+
+String NestLabels(String left, String right) {
+  if (left.empty()) {
+    return right;
+  }
+  if (right.empty()) {
+    return left;
+  }
+  if (right.size() > left.size()) {
+    std::string right_str = right;
+    if (right_str.substr(0, left.size()) == left) {
+      return right;
+    }
+  }
+  return left + "." + right;
+}
+
+std::string KindToString(OpPatternKind kind) {
+  switch (kind) {
+    case kElemWise:
+      return "E";
+    case kBroadcast:
+      return "B";
+    case kInjective:
+      return "I";
+    case kCommReduce:
+      return "R";
+    case kOutEWiseFusable:
+      return "A";
+    case kTuple:
+      return "T";
+    case kOpaque:
+      return "O";
+  }
+  return "?";
+}
+
+OpPatternKind CombineKinds(OpPatternKind left, OpPatternKind right) {
+  return std::max(left, right);
+}
+
+bool CanInline(const Expr& expr) {
+  if (expr.as<OpNode>() || expr.as<ConstructorNode>() || expr.as<VarNode>()) {
+    return true;
+  }
+  if (const auto* constant_node = expr.as<ConstantNode>()) {
+    return support::IsSimpleScalar(constant_node);
+  }
+  return false;
+}
+
+bool IsSpecialOp(const OpNode* op_node) {
+  auto op = GetRef<Op>(op_node);
+  static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
+  if (fnoncomputational.count(op) && fnoncomputational[op]) {
+    // Operator has been marked as non-computational.
+    return true;
+  }
+  // TODO(mbs): This is incomplete.
+  static auto shape_of_op_ = Op::Get("shape_of");
+  static auto vm_shape_of_op_ = Op::Get("vm.shape_of");
+  if (op == DeviceCopyOp() || op == shape_of_op_ || op == vm_shape_of_op_) {
+    // Operator is compiled away by the VM compilation flow.
+    return true;
+  }
+  return false;
+}
+
+bool MustBeLowered(const Expr& expr) {
+  if (const auto* call_node = expr.as<CallNode>()) {
+    if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+      if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+        // We've already committed to this call being to one or more operators which must be
+        // lowered.
+        return true;
+      }
+    } else if (const auto* op_node = call_node->op.as<OpNode>()) {
+      if (!IsSpecialOp(op_node)) {
+        // The VM compilation path won't rewrite this call.
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/utils.h b/src/relay/collage/utils.h
new file mode 100644
index 000000000000..4c0493cdd675
--- /dev/null
+++ b/src/relay/collage/utils.h
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/utils.h
+ * \brief Misc helpers.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_UTILS_H_
+#define TVM_RELAY_COLLAGE_UTILS_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/target/target.h>
+
+#include <string>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Distinguished partition spec names.
+ */
+constexpr const char* kTVMSpecNamePrefix = "tvm_";
+constexpr const char* kHostSpecName = "host";
+
+/*!
+ * \brief Returns the partition spec name to use for \p target. For external codegen targets the
+ * spec name is just the target kind name. For TVM native targets the spec name is of the form
+ * "tvm_<kind name>".
+ */
+String GetSpecName(const Target& target);
+
+/*!
+ * \brief Returns \p "<left>+<right>".
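+ *
+ * For example (illustrative): UnionLabels("add", "relu") is "add+relu"; an empty side
+ * yields the other label unchanged.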
+ */
+String UnionLabels(String left, String right);
+
+/*! \brief Returns \p "<outer>.<inner>". */
+String NestLabels(String outer, String inner);
+
+/*! \brief Returns abbreviation for \p kind. */
+std::string KindToString(OpPatternKind kind);
+
+/*! \brief Returns maximum of \p left and \p right. */
+OpPatternKind CombineKinds(OpPatternKind left, OpPatternKind right);
+
+/*!
+ * \brief Returns true if \p expr can be safely inlined in the body of a function extracted
+ * from a sub-graph, even if \p expr was not technically matched by the pattern which produced
+ * the sub-graph.
+ */
+bool CanInline(const Expr& expr);
+
+/*!
+ * \brief Returns true if \p op_node can be directly handled by the VM.
+ */
+bool IsSpecialOp(const OpNode* op_node);
+
+/*!
+ * \brief Returns true if the Relay expression node given by \p expr cannot be evaluated by
+ * the VM and must end up in a kernel.
+ */
+bool MustBeLowered(const Expr& expr);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_UTILS_H_
diff --git a/tests/python/relay/collage/test_sub_graph.py b/tests/python/relay/collage/test_sub_graph.py
new file mode 100644
index 000000000000..de2d974bf934
--- /dev/null
+++ b/tests/python/relay/collage/test_sub_graph.py
@@ -0,0 +1,387 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import logging
+import tvm.testing
+
+logging.basicConfig(level=logging.INFO)
+
+partition_for_testing = tvm._ffi.get_global_func("relay.collage.PartitionForTesting")
+
+
+def print_with_indexes(mod):
+    mod = tvm.relay.transform.CapturePostDfsIndexInSpans()(mod)
+    print(mod)
+
+
+def run(in_mod, expected_mod, max_outputs, allow_taps, compiler, map):
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    in_mod = tvm.relay.transform.InferType()(in_mod)
+    in_mod = tvm.relay.transform.CapturePostDfsIndexInSpans()(in_mod)
+
+    indexes = [i for l, iss in map.items() for i in iss]
+    labels = [l for l, iss in map.items() for i in iss]
+    actual_mod = partition_for_testing(max_outputs, allow_taps, compiler, indexes, labels)(in_mod)
+
+    if not tvm.ir.structural_equal(actual_mod, expected_mod, True):
+        # Print everything in full so we can see what's going on when things fail.
+        print("Input module:")
+        print(in_mod)
+        print("Expected module:")
+        print(expected_mod)
+        print("Actual module:")
+        print(actual_mod)
+        # Assert again so as to see the actual disagreeing sub-expressions.
+        tvm.ir.assert_structural_equal(actual_mod, expected_mod, map_free_vars=True)
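+
+
+# (Illustrative note: run() flattens its label->indexes map, eg {"": [6, 7], "a": [8]}
+# expands to indexes=[6, 7, 8] and labels=["", "", "a"]. The node numbers in the tests
+# below are the post-dfs indexes recorded by CapturePostDfsIndexInSpans.)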
+
+
+def test_single_op():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+                      %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
+              %0 = add(%a, %b);
+              %1 = add(%c, %d); // node 7
+              subtract(%0, %1)
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+                      %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
+              %0 = add(%a, %b);
+              %1 = (fn(%x, %y, Compiler="foo") { add(%x, %y) })(%c, %d);
+              subtract(%0, %1)
+            }
+            """
+        )
+
+    run(input(), expected(), 1, False, "foo", {"": [7]})
+
+
+def test_multi_output():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+                      %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
+              %0 = add(%a, %b); // node 6
+              %1 = add(%c, %d); // node 7
+              subtract(%0, %1)
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+                      %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
+              %0 = (fn(%w, %x, %y, %z, Compiler="foo") { (add(%y, %z), add(%w, %x)) })(%c, %d, %a, %b);
+              %1 = %0.0;
+              %2 = %0.1;
+              subtract(%1, %2)
+            }
+            """
+        )
+
+    # No rewrite since 2 outputs
+    run(input(), input(), 1, False, "foo", {"": [6, 7]})
+    # Rewrite
+    run(input(), expected(), 2, False, "foo", {"": [6, 7]})
+
+
+def test_classic_conv2d_add_relu():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32],
+                      %c: Tensor[(5, 2, 28, 28), float32], %d: Tensor[(5, 2, 28, 28), float32]) {
+              %0 = nn.conv2d(%a, %b); // node 8
+              %1 = add(%0, %c); // node 9
+              %2 = nn.relu(%1); // node 10
+              subtract(%2, %d)
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32],
+                      %c: Tensor[(5, 2, 28, 28), float32], %d: Tensor[(5, 2, 28, 28), float32]) {
+              %2 = (fn(%x, %y, %z, Compiler="foo") {
+                %0 = nn.conv2d(%x, %y);
+                %1 = add(%0, %z);
+                nn.relu(%1)
+              })(%a, %b, %c);
+              subtract(%2, %d)
+            }
+            """
+        )
+
+    run(input(), expected(), 1, False, "foo", {"": [8, 9, 10]})
+
+
+def test_diamond_single_output():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) {
+              %0 = nn.conv2d(%a, %b, padding=[0, 0, 0, 0]); // node 5
+              %1 = nn.relu(%0); // node 6
+              %2 = nn.relu(%1); // node 7
+              %3 = nn.leaky_relu(%0, alpha=0f); // node 9
+              add(%2, %3) // node 10
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) {
+              (fn (%x: Tensor[(5, 3, 32, 32), float32], %y: Tensor[(2, 3, 5, 5), float32], Compiler="foo") {
+                %0 = nn.conv2d(%x, %y, padding=[0, 0, 0, 0]);
+                %1 = nn.relu(%0);
+                %2 = nn.relu(%1);
+                %3 = nn.leaky_relu(%0, alpha=0f);
+                add(%2, %3)
+              })(%a, %b)
+            }
+            """
+        )
+
+    run(input(), expected(), 1, False, "foo", {"": [5, 6, 7, 9, 10]})
+
+
+def test_diamond_multi_output():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) {
+              %0 = nn.conv2d(%a, %b, padding=[0, 0, 0, 0]); // node 5
+              %1 = nn.relu(%0); // node 6
+              %2 = nn.relu(%1); // node 7
+              %3 = nn.leaky_relu(%0, alpha=0f); // node 9
+              add(%2, %3)
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) {
+              %4 = (fn (%x: Tensor[(5, 3, 32, 32), float32], %y: Tensor[(2, 3, 5, 5), float32], Compiler="foo") {
+                %0 = nn.conv2d(%x, %y, padding=[0, 0, 0, 0]);
+                %1 = nn.relu(%0);
+                %2 = nn.relu(%1);
+                %3 = nn.leaky_relu(%0, alpha=0f);
+                (%2, %3)
+              })(%a, %b);
+              %5 = %4.0;
+              %6 = %4.1;
+              add(%5, %6)
+            }
+            """
+        )
+
+    run(input(), expected(), 2, False, "foo", {"": [5, 6, 7, 9]})
+
+
+def test_with_tap():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) {
+              %0 = nn.conv2d(%a, %b, padding=[0, 0, 0, 0]); // node 5
+              %1 = nn.relu(%0); // node 6
+              add(%1, %0)
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 3, 32, 32), float32], %b: Tensor[(2, 3, 5, 5), float32]) {
+              %2 = (fn (%x, %y, Compiler="foo") {
+                %0 = nn.conv2d(%x, %y, padding=[0, 0, 0, 0]);
+                %1 = nn.relu(%0);
+                (%0, %1)
+              })(%a, %b);
+              %3 = %2.1;
+              %4 = %2.0;
+              add(%3, %4)
+            }
+            """
+        )
+
+    # No rewrite since has tap
+    run(input(), input(), 2, False, "foo", {"": [5, 6]})
+    # Rewrite
+    run(input(), expected(), 2, True, "foo", {"": [5, 6]})
+
+
+def test_no_cycles():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) {
+              %0 = add(%a, %b); // node 3
+              %1 = add(%0, %b);
+              add(%1, %b) // node 5
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) {
+              (fn(%x, %y, Compiler="foo") {
+                %0 = add(%x, %y);
+                %1 = add(%0, %y);
+                add(%1, %y)
+              })(%a, %b)
+            }
+            """
+        )
+
+    # No rewrite since would create cycle
+    run(input(), input(), 2, False, "foo", {"": [3, 5]})
+    # No cycle
+    run(input(), expected(), 2, False, "foo", {"": [3, 4, 5]})
+
+
+def test_labels_direct_connection():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32]) {
+              %0 = nn.relu(%a); // node 3
+              %1 = nn.relu(%0); // node 4
+              %2 = nn.relu(%1); // node 5
+              %3 = nn.relu(%1); // node 6
+              %4 = add(%2, %3); // node 7
+              %5 = nn.relu(%4); // node 8
+              %6 = nn.relu(%4); // node 9
+              %7 = add(%5, %6); // node 10
+              nn.relu(%7) // node 11
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32]) {
+              (fn(%aa: Tensor[(5, 7), float32], Compiler="foo") {
+                %0 = nn.relu(%aa);
+                %4 = (fn(%y, Composite="a") {
+                  %1 = nn.relu(%y);
+                  %2 = nn.relu(%1);
+                  %3 = nn.relu(%1);
+                  add(%2, %3)
+                })(%0);
+                %7 = (fn(%z, Composite="b") {
+                  %5 = nn.relu(%z);
+                  %6 = nn.relu(%z);
+                  add(%5, %6)
+                })(%4);
+                nn.relu(%7)
+              })(%a)
+            }
+            """
+        )
+
+    run(input(), expected(), 1, False, "foo", {"": [3, 11], "a": [4, 5, 6, 7], "b": [8, 9, 10]})
+
+
+def test_labels_nested_tap():
+    def input():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32]) {
+              %0 = nn.relu(%a); // node 3
+              %1 = nn.relu(%0); // node 4
+              %2 = nn.relu(%1); // node 5
+              %3 = nn.relu(%1); // node 6
+              %4 = add(%2, %3); // node 7
+              %5 = nn.relu(%4); // node 8
+              %6 = nn.relu(%4); // node 9
+              %7 = add(%5, %6); // node 10
+              add(%2, %7) // node 11
+            }
+            """
+        )
+
+    def expected():
+        return tvm.parser.fromtext(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32]) {
+              %0 = nn.relu(%a);
+              %9 = (fn(%x: Tensor[(5, 7), float32], Compiler="foo") {
+                %5 = (fn(%y, Composite="a") {
+                  %1 = nn.relu(%y);
+                  %2 = nn.relu(%1);
+                  %3 = nn.relu(%1);
+                  %4 = add(%2, %3);
+                  (%2, %4)
+                })(%x);
+                %8 = (fn(%z, Composite="b") {
+                  %6 = nn.relu(%z);
+                  %7 = nn.relu(%z);
+                  add(%6, %7)
+                })(%5.1);
+                (%5.0, %8)
+              })(%0);
+              add(%9.0, %9.1)
+            }
+            """
+        )
+
+    run(input(), expected(), 2, True, "foo", {"a": [4, 5, 6, 7], "b": [8, 9, 10]})
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
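+
+# (Illustrative only: the registered pass can also be driven by hand, mirroring run() above:
+#   mod = tvm.relay.transform.InferType()(mod)
+#   mod = tvm.relay.transform.CapturePostDfsIndexInSpans()(mod)
+#   mod = partition_for_testing(2, True, "foo", [6, 7], ["", ""])(mod)
+# where [6, 7] are post-dfs indexes and ["", ""] their labels.)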