diff --git a/src/relay/collage/candidate_partition.cc b/src/relay/collage/candidate_partition.cc new file mode 100644 index 000000000000..9cccdf96d5a4 --- /dev/null +++ b/src/relay/collage/candidate_partition.cc @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_partition.cc + * \brief A potential partition in the Collage search. + */ + +#include "./candidate_partition.h" + +#include + +#include "./candidate_set.h" +#include "./partition_rule.h" +#include "./partition_spec.h" +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +TVM_REGISTER_NODE_TYPE(CandidatePartitionNode); + +void CandidatePartitionNode::VisitAttrs(AttrVisitor* v) { + v->Visit("rule_name", &rule_name_); + v->Visit("sub_graph", &sub_graph_); + v->Visit("spec", &spec_); + // TODO(mbs): cost_ +} + +PartitionSpec CandidatePartitionNode::partition_spec() const { + return Downcast(spec_); +} + +std::string CandidatePartitionNode::partition_spec_name() const { + return Downcast(spec_)->spec_name_; +} + +Target CandidatePartitionNode::target() const { return Downcast(spec_)->target_; } + +std::string CandidatePartitionNode::ToSummary(const DataflowGraph& dataflow_graph) const { + std::ostringstream os; + os << sub_graph_->label_; + os << " | ("; + bool first = true; + for (PostDfsIndex index : sub_graph_->input_) { + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); + if (CanInline(sub_expr)) { + continue; + } + if (first) { + first = false; + } else { + os << ", "; + } + os << PrettyPrint(sub_expr->checked_type()); + } + os << ") -> ("; + first = true; + for (PostDfsIndex index : sub_graph_->exit_) { + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); + if (CanInline(sub_expr)) { + continue; + } + if (first) { + first = false; + } else { + os << ", "; + } + os << PrettyPrint(sub_expr->checked_type()); + } + os << ") | "; + os << sub_graph_->inside_.ToString(); + os << " | "; + os << partition_spec_name(); + os << " | "; + os << cost_.ToString(); + return os.str(); +} + +std::string CandidatePartitionNode::ToString() const { + std::ostringstream os; + os << "{rule_name=" << rule_name_; + os << ",sub_graph=" << sub_graph_->ToString(); + os << ",spec_name=" << partition_spec_name(); + if (!cost_.is_unknown()) { + os << ",cost=" << cost_.ToString(); + } + os << "}"; + return os.str(); +} + +CandidatePartition::CandidatePartition(String rule_name, SubGraph sub_graph, + ObjectRef /* actually PartitionSpec */ spec, Cost cost) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_graph_ = std::move(sub_graph); + node->spec_ = std::move(spec); + node->cost_ = cost; + data_ = std::move(node); +} + +CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name) { + if (rule_name == candidate->rule_name_) { + return candidate; + } + auto* node = candidate.CopyOnWrite(); + node->rule_name_ = std::move(rule_name); + return GetRef(node); +} + +CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph) { + if (sub_graph == candidate->sub_graph_) { + return candidate; + } + auto* node = candidate.CopyOnWrite(); + node->sub_graph_ = std::move(sub_graph); + return GetRef(node); +} + +bool CandidatePartition::operator<(const CandidatePartition& that) const { + // Order lexicographically on sub-graphs. + if (*get()->sub_graph_.get() < *that->sub_graph_.get()) { + return true; + } + if (*that->sub_graph_.get() < *get()->sub_graph_.get()) { + return false; + } + // Break ties by rule name. + return get()->rule_name_ < that->rule_name_; +} + +bool CandidatePartition::AreTouching(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const { + return get()->spec_ == that->spec_ && + get()->sub_graph_.AreTouching(dataflow_graph, that->sub_graph_); +} + +CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const { + ICHECK_EQ(get()->spec_, that->spec_); + return CandidatePartition(UnionLabels(get()->rule_name_, that->rule_name_), + get()->sub_graph_.DisjointUnion(dataflow_graph, that->sub_graph_), + get()->spec_, get()->cost_ + that->cost_); +} + +/*static*/ +CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, + std::vector candidates) { + ICHECK_GT(candidates.size(), 1); + CandidatePartition result = candidates.front(); + for (size_t i = 1; i < candidates.size(); ++i) { + result = result.DisjointUnion(dataflow_graph, candidates[i]); + } + return result; +} + +/*static*/ +Expr CandidatePartition::ParallelRewrite(const DataflowGraph& dataflow_graph, + const std::vector& candidates) { + std::vector sub_graphs; + sub_graphs.reserve(candidates.size()); + for (const auto& candidate : candidates) { + sub_graphs.emplace_back(candidate->sub_graph_); + } + return SubGraph::ParallelRewrite(dataflow_graph, sub_graphs); +} + +/*static*/ +std::vector CandidatePartition::MaxCoalesce( + const DataflowGraph& dataflow_graph, std::vector candidates) { + VLOG(1) << "Running MaxCoalesce over " << candidates.size() << " candidates"; + // This is an eager version of using the simple (kOpaque, kOpaque) combiner. + + // Switch to set representation. + CandidateSet result_set(std::move(candidates)); + + // Until fixed point... + size_t num_rounds = 0; + while (result_set.PrepareForNextRound()) { + VLOG_CONTEXT << "round " << ++num_rounds; + VLOG(1) << "checking " << result_set.size() << " candidates (" << result_set.first_new_index() + << " existing)"; + IndexSet removed_this_round(result_set.size()); // over candidate indexes! + + // Build map from post-dfs indices to the indices of candidates with corresponding entry node. + // NOTE: the index set is over candidate indices not post-dfs indices! + std::vector entry_map(dataflow_graph.size(), IndexSet(result_set.size())); + for (size_t i = 0; i < result_set.size(); ++i) { + CandidatePartition candidate = result_set.at(i); + for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) { + entry_map[entry_index].Add(i); + } + } + + for (size_t i = 0; i < result_set.size(); ++i) { + if (removed_this_round[i]) { + // Already merged. + continue; + } + CandidatePartition upstream = result_set.at(i); + // Narrow our search to just those candidates which could touch. + IndexSet possible_downstream(result_set.size()); // over candidate indexes! + for (PostDfsIndex output_index : upstream->sub_graph_->output_) { + possible_downstream = possible_downstream | entry_map[output_index]; + } + for (size_t j : possible_downstream) { + if (removed_this_round[j]) { + // Already merged. + continue; + } + if (i == j) { + // Ignore self. + continue; + } + CandidatePartition downstream = result_set.at(j); + if (!upstream.AreTouching(dataflow_graph, downstream)) { + continue; + } + CandidatePartition new_candidate = upstream.DisjointUnion(dataflow_graph, downstream); + VLOG(2) << "Merging upstream candidate " << upstream->ToString() + << " and downstream candidate " << downstream->ToString() << " to yield " + << new_candidate->ToString(); + result_set.Add(dataflow_graph, new_candidate); + result_set.Remove(upstream); + removed_this_round.Add(i); + result_set.Remove(downstream); + removed_this_round.Add(j); + } + } + } + + // Restore canonical order. + result_set.sort(); + + VLOG(1) << "MaxCoalesce produced " << result_set.size() << " candidates"; + return result_set.MovedCurrentCandidates(); +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/candidate_partition.h b/src/relay/collage/candidate_partition.h new file mode 100644 index 000000000000..8267c3efcb2b --- /dev/null +++ b/src/relay/collage/candidate_partition.h @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_partition.cc + * \brief A potential partition in the Collage search. + */ + +#ifndef TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ +#define TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ + +#include +#include + +#include +#include +#include + +#include "./cost.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +class PartitionSpec; + +/*! + * \brief A candidate partition w.r.t. the overall Relay model. + * + * We represent the partition as a sub-graph. This means not only can we represent the scope + * of Relay sub-expressions intended for a particular partition (or kernel), but we can also + * represent various conventions for encoding how the operators within the partition should be + * tagged for downstream processing. + */ +class CandidatePartitionNode : public Object { + public: + CandidatePartitionNode() = default; + + /*! + * \brief Combination of all the partition rule names which produced this candidate. + * For debugging and explainability. + */ + String rule_name_; + + /*! + * \brief The sub-graph of the overall expression matched by the partition rule. + */ + SubGraph sub_graph_; + + /*! + * \brief The partition specification which produced this candidate. + */ + ObjectRef /* actually PartitionSpec */ spec_; + + /*! + * \brief The (cached) cost of the partition. + * + * Initially Cost::Unknown, calculated and cached by EstimateCost. + */ + mutable Cost cost_ = Cost::Unknown(); + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns the partition specification which produced this candidate. + */ + PartitionSpec partition_spec() const; + + /*! + * \brief Returns the name of the partition specification which produced this candidate. + */ + std::string partition_spec_name() const; + + /*! + * \brief Returns the target of the partition specification which produced this candidate. + */ + Target target() const; + + /*! + * \brief Returns a brief description of candidate suitable for debugging output. + */ + std::string ToSummary(const DataflowGraph& dataflow_graph) const; + + std::string ToString() const; + + static constexpr const char* _type_key = "relay.collage.CandidatePartition"; + TVM_DECLARE_FINAL_OBJECT_INFO(CandidatePartitionNode, Object); +}; + +class CandidatePartition : public ObjectRef { + public: + CandidatePartition(String rule_name, SubGraph sub_graph, + ObjectRef /* actually PartitionSpec */ spec, Cost cost = Cost::Unknown()); + + bool operator<(const CandidatePartition& that) const; + + /*! + * \brief Returns true if this and \p that candidate are disjoint, have the same (or no) target, + * and touch. This does not imply the \p DisjointUnion of this and that will be valid. For + * example, the result may be too deep or have too many outputs. + */ + bool AreTouching(const DataflowGraph& dataflow_graph, const CandidatePartition& that) const; + + /*! + * \brief Returns the disjoint union of this and \p that. + */ + CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const; + + /*! + * \brief Returns the disjoint union of all \p candidates. + */ + static CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph, + std::vector candidates); + + /*! + * \brief Returns the root expression of \p dataflow_graph rewritten to apply all the partitions + * implied by \p candidates. The candidates can be in any order but must be disjoint. + */ + static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, + const std::vector& candidates); + + /*! + * Eagerly merge all touching candidates for the same target. The candidates must be disjoint + * and have their Targets filled in. This is typically called on the optimal list of candidate + * partitions found by the Collage search in order to remove unnecessary partition boundaries. + * Ideally the search would never produce such candidates however to keep the search space + * manageable Collage may only consider candidate partitions up to a particular depth. + */ + static std::vector MaxCoalesce(const DataflowGraph& dataflow_graph, + std::vector candidates); + + TVM_DEFINE_OBJECT_REF_METHODS(CandidatePartition, ObjectRef, CandidatePartitionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CandidatePartitionNode); +}; + +CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name); +CandidatePartition WithTarget(CandidatePartition candidate, Target target); +CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph); + +struct CandidatePartitionHash { + size_t operator()(const CandidatePartition& candidate) const { + return candidate->sub_graph_->hash(); + } +}; + +struct CandidatePartitionEquals { + bool operator()(const CandidatePartition& left, const CandidatePartition& right) const { + return *left->sub_graph_.get() == *right->sub_graph_.get(); + } +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ diff --git a/src/relay/collage/candidate_set.cc b/src/relay/collage/candidate_set.cc new file mode 100644 index 000000000000..2c2a7eaf8d54 --- /dev/null +++ b/src/relay/collage/candidate_set.cc @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_set.cc + * \brief Collects a set of candidate partitions. + */ + +#include "./candidate_set.h" + +namespace tvm { +namespace relay { +namespace collage { + +CandidateSet::CandidateSet(std::vector candidates_to_add) + : candidates_to_add_(std::move(candidates_to_add)) { + for (const auto& candidate : candidates_to_add_) { + seen_.emplace(candidate); + } +} + +void CandidateSet::Add(const DataflowGraph& dataflow_graph, + const CandidatePartition& new_candidate) { + VLOG(2) << "adding " << new_candidate->ToString(); + if (seen_.count(new_candidate)) { + VLOG(2) << "already seen candidate, ignoring"; + return; + } + seen_.emplace(new_candidate); + candidates_to_add_.emplace_back(new_candidate); +} + +void CandidateSet::Remove(const CandidatePartition& old_candidate) { + ICHECK(seen_.count(old_candidate)); + VLOG(2) << "removing " << old_candidate->ToString(); + candidates_to_remove_.emplace_back(old_candidate); +} + +bool CandidateSet::PrepareForNextRound() { + size_t init_size = current_candidates_.size(); + for (const auto& candidate_to_remove : candidates_to_remove_) { + current_candidates_.erase( + std::remove(current_candidates_.begin(), current_candidates_.end(), candidate_to_remove), + current_candidates_.end()); + } + size_t num_removed = init_size - current_candidates_.size(); + candidates_to_remove_.clear(); + first_new_index_ = current_candidates_.size(); + for (const auto& new_candidate : candidates_to_add_) { + current_candidates_.push_back(new_candidate); + } + size_t num_added = candidates_to_add_.size(); + candidates_to_add_.clear(); + VLOG(1) << "removed " << num_removed << " and added " << num_added << " candidates"; + return num_removed + num_added > 0; +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/candidate_set.h b/src/relay/collage/candidate_set.h new file mode 100644 index 000000000000..4cb2c40e9500 --- /dev/null +++ b/src/relay/collage/candidate_set.h @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_set.h + * \brief Collects a set of candidate partitions. + */ + +#ifndef TVM_RELAY_COLLAGE_CANDIDATE_SET_H_ +#define TVM_RELAY_COLLAGE_CANDIDATE_SET_H_ + +#include +#include +#include +#include + +#include "./candidate_partition.h" +#include "./dataflow_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Holds a vector of current candidates and the additions/removals to apply to them. + */ +struct CandidateSet { + CandidateSet() = default; + + explicit CandidateSet(std::vector candidates_to_add); + + /*! + * \brief Schedule \p new_candidate for addition before the next round (unless it is not valid). + */ + void Add(const DataflowGraph& dataflow_graph, const CandidatePartition& new_candidate); + + /*! \brief Schedule \p old_candidate for removal before the next round. */ + void Remove(const CandidatePartition& old_candidate); + + /*! + * \brief Update \p current_candidates and \p first_new_index. Return false if no + * new candidates were added, in which case we have reached a fixed point. + */ + bool PrepareForNextRound(); + + size_t size() const { return current_candidates_.size(); } + + CandidatePartition operator[](size_t i) const { + ICHECK_LT(i, current_candidates_.size()); + return current_candidates_[i]; + } + CandidatePartition at(size_t i) const { return (*this)[i]; } + + size_t first_new_index() const { return first_new_index_; } + + void sort() { std::sort(current_candidates_.begin(), current_candidates_.end()); } + + std::vector MovedCurrentCandidates() { + return std::move(current_candidates_); + } + + private: + /*! + * \brief Index of first candidate in current_candidates added in last round. This can be used to + * avoid considering candidates or candidate combinations which have already been considered in an + * earlier round. + */ + size_t first_new_index_ = 0; + /*! \brief Candidates gathered in previous rounds. */ + std::vector current_candidates_; + /*! \brief New candidates gathered in the current round. */ + std::vector candidates_to_add_; + /*! \brief Existing candidates to remove before starting the next round. */ + std::vector candidates_to_remove_; + /*! \brief Which candidates have been seen so far and should not be added again. */ + std::unordered_set seen_; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_CANDIDATE_SET_H_ diff --git a/src/relay/collage/cost.cc b/src/relay/collage/cost.cc new file mode 100644 index 000000000000..ae2eb8600ebd --- /dev/null +++ b/src/relay/collage/cost.cc @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/cost.cc + * \brief Represents the estimated cost of a candidate partition. + */ + +#include "./cost.h" + +namespace tvm { +namespace relay { +namespace collage { + +std::string Cost::ToString() const { + if (is_invalid()) { + return "invalid"; + } else if (is_unknown()) { + return "unknown"; + } else if (value_ == 0.0) { + return "0"; + } else { + return std::to_string(value_ * 1e6) + "us"; + } +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/cost.h b/src/relay/collage/cost.h new file mode 100644 index 000000000000..8ae276d22078 --- /dev/null +++ b/src/relay/collage/cost.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/cost.h + * \brief Represents the estimated cost of a candidate partition. + */ +#ifndef TVM_RELAY_COLLAGE_COST_H_ +#define TVM_RELAY_COLLAGE_COST_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief The assumed cost for a candidate partition. Generally average execution time in seconds. + * However other cost functions are possible, for example to introduce a penalty for high memory + * use, etc. + */ +class Cost { + public: + Cost() = delete; + + static Cost Zero() { return Cost(0.0); } + + /*! + * \brief Returns the distinguished 'invalid' cost signaling a candidate partition is not + * supported by the intended target, for example because the sub-graph has an unsupported operator + * or the intermediate memory required exceeds some system limit. + */ + static Cost Invalid() { return Cost(std::numeric_limits::infinity()); } + + bool is_invalid() const { return std::isinf(value_) && value_ > 0.0; } + + /*! + * \brief Returns the distinguished 'unknown' cost, signaling fixed priorities should be used to + * choose the best partitions. This can be used to disable tuning and fallback to fixed rules, + * much as TVM will use an un-tuned kernel if no tuning records are available. + */ + static Cost Unknown() { return Cost(std::numeric_limits::quiet_NaN()); } + + bool is_unknown() const { return std::isnan(value_); } + + /*! \brief Returns cost with given finite, non-negative value. */ + static Cost Value(double value) { + ICHECK(!std::isnan(value) && !std::isinf(value) && value >= 0.0); + return Cost(value); + } + + bool is_value() const { return !std::isnan(value_) && !std::isinf(value_); } + + /*! \brief Return true if the less-than relation is defined for this and that. */ + bool are_comparable(Cost that) const { return !std::isnan(value_) && !std::isnan(that.value_); } + + /*! \brief Returns sum of this and that. */ + Cost operator+(Cost that) const { return Cost(value_ + that.value_); } + + /*! \brief Returns difference of this and that. */ + Cost operator-(Cost that) const { return Cost(value_ - that.value_); } + + /*! \brief Returns true if this is cheaper than that, assuming they are comparable. */ + bool operator<(Cost that) const { return value_ < that.value_; } + + std::string ToString() const; + + private: + explicit Cost(double value) : value_(value) {} + + /*! + * \brief Non-negative value or: + * - +inf if candidate partition is not feasible. + * - NaN if candidate partition has an unknown cost (priority may be used to break ties). + */ + double value_ = 0.0; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_COST_H_ diff --git a/src/relay/collage/partition_rule.cc b/src/relay/collage/partition_rule.cc new file mode 100644 index 000000000000..1cedbfc9d72c --- /dev/null +++ b/src/relay/collage/partition_rule.cc @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_rule.cc + * \brief Compositional partitioning rules. + */ + +#include "./partition_rule.h" + +#include + +#include "./partition_rule.h" +#include "./partition_spec.h" +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +TVM_REGISTER_NODE_TYPE(PartitionRuleNode); + +void PartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector PartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + ICHECK(false) << "PartitionRuleNode::AllCandidates should be overridden in sub-class"; + return {}; +} + +std::string PartitionRuleNode::ToString() const { return ToDoc().str(); } + +Doc PartitionRuleNode::ToDoc() const { + Doc doc; + doc << GetTypeKey() << "(" << Doc::NewLine(2); + std::vector body_items; + AppendBodyItems(&body_items); + doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine(); + doc << ")"; + return doc; +} + +void PartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + body_items->emplace_back(); + body_items->back() << "rule_name=" << Doc::StrLiteral(rule_name_); +} + +PartitionRule::PartitionRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +bool DefaultPatternPredicate(const Expr& matched_sub_expr) { return true; } + +TVM_REGISTER_NODE_TYPE(DFPatternPartitionRuleNode); + +void DFPatternPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector DFPatternPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running DFPatternPartitionRule(" << rule_name_ << ")"; + std::vector result; + DFPatternMatcher matcher(&dataflow_graph.indexed_graph()); + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); + if (!matcher.Match(pattern_, sub_expr)) { + continue; + } + if (!predicate_(sub_expr)) { + VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") has failing predicate"; + continue; + } + IndexSet inside = MatcherToIndexSet(matcher); + OpPatternKind kind; + String label; + std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); + String rule_name = rule_name_.empty() ? sub_graph->label_ : rule_name_; + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); + VLOG(2) << "DFPatternPartitionRule(" << rule_name_ << ") yields " << candidate->ToString(); + result.emplace_back(std::move(candidate)); + } + VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void DFPatternPartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "pattern=" << PrettyPrint(pattern_); +} + +DFPatternPartitionRule::DFPatternPartitionRule(String rule_name, DFPattern pattern, + TPatternPredicate predicate) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->pattern_ = std::move(pattern); + node->predicate_ = std::move(predicate); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(CompositePartitionRuleNode); + +void CompositePartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector CompositePartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running CompositePartitionRule(" << rule_name_ << ") over " << candidates.size() + << " sub-candidates"; + std::vector result; + FunctionAttrsMap attrs; + attrs.Set(attr::kComposite, rule_name_); + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs); + CandidatePartition new_candidate = WithSubGraph( + WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph)); + VLOG(2) << "CompositePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "CompositePartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void CompositePartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); +} + +CompositePartitionRule::CompositePartitionRule(String rule_name, PartitionRule sub_rule) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(PrimitivePartitionRuleNode); + +void PrimitivePartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector PrimitivePartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running PrimitivePartitionRule(" << rule_name_ << ") over " << candidates.size() + << " sub-candidates"; + std::vector result; + FunctionAttrsMap attrs; + attrs.Set(attr::kPrimitive, Integer(1)); + if (spec->target_.IsExternalCodegen()) { + // The spec name will be the target kind name which is 1:1 with the "Compiler" attribute name. + attrs.Set(attr::kCompiler, spec->spec_name_); + } + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs); + CandidatePartition new_candidate = WithSubGraph( + WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph)); + VLOG(2) << "PrimitivePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "PrimitivePartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void PrimitivePartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); +} + +PrimitivePartitionRule::PrimitivePartitionRule(String rule_name, PartitionRule sub_rule) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(UnionPartitionRuleNode); + +void UnionPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector UnionPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector result; + for (const auto& sub_rule : sub_rules_) { + std::vector candidates = sub_rule->AllCandidates(dataflow_graph, spec); + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); + VLOG(2) << "UnionPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + } + VLOG(1) << "UnionPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates"; + return result; +} + +void UnionPartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + for (const auto& sub_rule : sub_rules_) { + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule->ToDoc(); + } +} + +UnionPartitionRule::UnionPartitionRule(String rule_name, Array sub_rules) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rules_ = std::move(sub_rules); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(OpCallByKindPartitionRuleNode); + +void OpCallByKindPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector OpCallByKindPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running OpCallByKindPartitionRule(" << rule_name_ << ")"; + std::vector result; + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + auto node = dataflow_graph.index_to_node(index); + Expr sub_expr = node->ref(); + if (sub_expr->IsInstance()) { + OpPatternKind kind; + String label; + std::tie(kind, label) = SubExprKindAndLabel(sub_expr); + if (kind <= kOutEWiseFusable) { + IndexSet inside(dataflow_graph.size(), {index}); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); + String rule_name = NestLabels(rule_name_, sub_graph->label_); + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); + VLOG(2) << "OpCallByKindPartitionRule(" << rule_name_ << ") yields " + << candidate->ToString(); + result.emplace_back(std::move(candidate)); + } + } + } + VLOG(1) << "OpCallByKindPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void OpCallByKindPartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); +} + +OpCallByKindPartitionRule::OpCallByKindPartitionRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(OnlyValidPartitionRuleNode); + +void OnlyValidPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector OnlyValidPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running OnlyValidPartitionRule(" << rule_name_ << ") over " << candidates.size() + << " sub-candidates"; + std::vector result; + for (auto& candidate : candidates) { + if (!candidate->sub_graph_->IsValid(dataflow_graph, config_)) { + VLOG(2) << "Ignoring invalid candidate " << candidate->ToString(); + continue; + } + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); + VLOG(2) << "OnlyValidPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "OnlyValidPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void OnlyValidPartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); + body_items->emplace_back(); + body_items->back() << "config=" << config_.ToString(); +} + +OnlyValidPartitionRule::OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, + const SubGraphConfig& config) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + node->config_ = config; + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(HostPartitionRuleNode); + +void HostPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector HostPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running HostPartitionRule(" << rule_name_ << ")"; + std::vector result; + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + if (MustBeLowered(dataflow_graph.index_to_node(index)->ref())) { + continue; + } + IndexSet inside(dataflow_graph.size(), {index}); + OpPatternKind kind; + String label; + std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, label); + String rule_name = NestLabels(rule_name_, sub_graph->label_); + // We'll a zero cost for the candidate since we'll never want to actually estimate the cost + // of this 'partition'. + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec, Cost::Zero()); + VLOG(2) << "HostPartitionRule(" << rule_name_ << ") yields " << candidate->ToString(); + result.push_back(candidate); + } + VLOG(1) << "HostPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates"; + return result; +} + +void HostPartitionRuleNode::AppendBodyItems(std::vector* body_items) const {} + +HostPartitionRule::HostPartitionRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/partition_rule.h b/src/relay/collage/partition_rule.h new file mode 100644 index 000000000000..13f5c0b01d31 --- /dev/null +++ b/src/relay/collage/partition_rule.h @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_rule.h + * \brief Compositional partitioning rules. + */ + +#ifndef TVM_RELAY_COLLAGE_PARTITION_RULE_H_ +#define TVM_RELAY_COLLAGE_PARTITION_RULE_H_ + +#include +#include + +#include +#include + +#include "../../printer/doc.h" +#include "./candidate_partition.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Type of function to check if a matched sub-expression should be accepted by a rule. This + * can be used to, eg, reject operators of unsupported shape or dtype, or otherwise implement rules + * which are difficult to express in the dataflow pattern language directly. + */ +using TPatternPredicate = TypedPackedFunc; + +/*! + * \brief The default pattern predicate. Always returns true. + */ +bool DefaultPatternPredicate(const Expr& matched_sub_expr); + +/*! + * \brief Base class of all partition rules. + * + * A \p PartitionRule describes how to find a set of \p CandidatePartitions for a \p DataflowGraph. + * The candidates are allowed to overlap, and ultimately it is the job of the Collage searcher to + * find a selection of candidates which covers the whole Relay expression without overlap. Partition + * rules are paired with their \p Target and other 'top level' configuration in a \p PartitionSpec. + * + * We provide a set of 'base' partition rules which produce candidates from the dataflow graph + * directly. We also provide a set of 'combinator' partition rules which can produce new candidates + * from the results of an arbitrary sub-rule or sub-rules. By mixing these base and combinator + * rules we can express a wide variety of partition strategies and encoding conventions. + * + * There may be many thousands of candidates in flight during the Collage search. We take care to + * defer constructing or rewriting Relay expressions until absolutely necessary. We only pay for + * extracting a function to represent a candidate when we need to measure it's cost. And we only + * pay for rewriting the overall Relay expression to commit to a partitioning when the Collage + * search has completed. + * + * The base rules implemented so far: + * - \p DFPatternPartitionRule: Given a \p DFPattern and expression predicate, produces a candidate + * for every sub-graph matched by the pattern and predicate. Unlike the \p PatternRewriter, + * candidates are free to overlap. Used to bring BYOC patterns into the Collage framework. + * - \p OpCallByKindPartitionRule: Uses the "TOpPattern" attribute provided for every Relay + * operator to produce a candidate for every call to a 'fusable Relay operator'. Used to + * look ahead to how TVM will fuse sub-graphs. + * + * The combinator rules implemented so far: + * - \p CompositePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped + * by a "Composite" function. The "Composite" name is taken from the rule name. Used to indicate + * Relay operators (or groups of Relay operators) should be mapped to target-specific operators, + * both for BYOC and TVM external library integrations. + * - \p PrimitivePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped + * by a "Primitive" function, possibly with an additional "Compiler" attribute. Used to + * delineate a partition (or kernel). + * - \p UnionPartitionRule: Simply unions all the candidates from all sub-rules together. Used to + * combine individual \p DFPatternPartitionRules. + * - \p OnlyValidPartitionRule: Given a \p SubGraphConfig, ignores candidates with 'invalid' + * sub-graphs. Used to limit the maximum candidate depth, the number of independent outputs, + * and whether intermediate 'taps' are allowed. + * - \p HostPartitionRule: Produces candidates for all Relay expressions which could be + * 'left behind' for execution by the host (eg on the VM). This rule lets us simplify the + * overall Collage search algorithm. + * + * (Though not yet implemented, we'd like to allow a combinator rule which will union candidate + * based on their 'anchor' operators. This can be used to implement 'vertical' and 'horizontal' + * partition on more primitive candidates. Note that the \p SubGraph machinery supports + * multiple-input and -output sub-graphs and their validation, so horizontal partition is easy + * implement.) + */ +class PartitionRuleNode : public Object { + public: + /*! + * \brief A unique (over all rules for the same target) name for the rule. Rule names are + * combined and captured with \p PartitionCandidate rule names for debuggability and + * explainability. Some rules will copy the rule name into function attributes. + * + */ + String rule_name_; + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns all the possible candidate partitions according to this rule for the overall + * expression corresponding to \p dataflow_graph. The candidates will generally have unknown + * target and cost: the target will be filled in by the \p PartitionSpec, while the cost will + * be filled in lazily. + */ + virtual std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const; + + std::string ToString() const; + Doc ToDoc() const; + + protected: + virtual void AppendBodyItems(std::vector* body_items) const; + + public: + static constexpr const char* _type_key = "relay.collage.PartitionRule"; + static constexpr const uint32_t _type_child_slots = 10; + TVM_DECLARE_BASE_OBJECT_INFO(PartitionRuleNode, Object); +}; + +class PartitionRule : public ObjectRef { + public: + explicit PartitionRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(PartitionRule, ObjectRef, PartitionRuleNode); +}; + +/*! + * \brief Partition rule which fires on all sub-expressions matching a dataflow-pattern and pattern + * predicate. It is valid for matching candidates to overlap. + */ +class DFPatternPartitionRuleNode : public PartitionRuleNode { + public: + /*! + * \brief Relay pattern. + */ + DFPattern pattern_; + + /*! + * \brief Predicate on matched sub-expression to decide if partition rule should fire. + */ + TPatternPredicate predicate_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.DFPatternPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(DFPatternPartitionRuleNode, PartitionRuleNode); +}; + +class DFPatternPartitionRule : public PartitionRule { + public: + DFPatternPartitionRule(String rule_name, DFPattern pattern, + TPatternPredicate predicate = DefaultPatternPredicate); + + TVM_DEFINE_OBJECT_REF_METHODS(DFPatternPartitionRule, PartitionRule, DFPatternPartitionRuleNode); +}; + +/*! + * \brief Partition rule which wraps candidates within a function with the "Composite" attribute + * bound to the given rule name. + * + * This is the standard way by which operators or operator groups are tagged as being supported + * by a particular externally provided function. It is up to the BYOC lowering function to + * recognize the "Composite" name and emit the appropriate code or call. + */ +class CompositePartitionRuleNode : public PartitionRuleNode { + public: + /*! \brief The sub-partition rule. */ + PartitionRule sub_rule_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.CompositePartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(CompositePartitionRuleNode, PartitionRuleNode); +}; + +class CompositePartitionRule : public PartitionRule { + public: + CompositePartitionRule(String rule_name, PartitionRule sub_rule); + + TVM_DEFINE_OBJECT_REF_METHODS(CompositePartitionRule, PartitionRule, CompositePartitionRuleNode); +}; + +/*! + * \brief Partition rule which wraps candidates within a function with the "Primitive" attribute + * bound to 1. If the partition spec target(s) have the "compiler" attribute then that name is + * also added to the function as a "Compiler" attribute. + * + * This is the standard way by which sub-graphs are marked as being in a 'partition' who's + * compilation will be managed by an external BYOC toolchain. It can also be used to mark + * sub-graphs for lowering to a single kernel by the built-in TVM lowering machinery. + */ +class PrimitivePartitionRuleNode : public PartitionRuleNode { + public: + /*! \brief The sub-partition rule. */ + PartitionRule sub_rule_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.PrimitivePartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimitivePartitionRuleNode, PartitionRuleNode); +}; + +class PrimitivePartitionRule : public PartitionRule { + public: + PrimitivePartitionRule(String rule_name, PartitionRule sub_rule); + + TVM_DEFINE_OBJECT_REF_METHODS(PrimitivePartitionRule, PartitionRule, PrimitivePartitionRuleNode); +}; + +/*! + * \brief Partition rule which simply unions all matches from all sub-partition rules. + * + * This can be used to combine the results of a set of, eg, DFPatternPartitionRules. + */ +class UnionPartitionRuleNode : public PartitionRuleNode { + public: + Array sub_rules_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.UnionPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnionPartitionRuleNode, PartitionRuleNode); +}; + +class UnionPartitionRule : public PartitionRule { + public: + UnionPartitionRule(String rule_name, Array sub_rules); + + TVM_DEFINE_OBJECT_REF_METHODS(UnionPartitionRule, PartitionRule, UnionPartitionRuleNode) +}; + +/* + *! \brief Partition rule which places calls to Relay operators with a "TOpPattern" attribute of + * \p kOutEWiseFusable or less in their own singleton sub-graph. No other Relay sub-expressions + * (such as tuples or tuple projection) are selected, and it is up to outer partition rules to + * account for them. + */ +class OpCallByKindPartitionRuleNode : public PartitionRuleNode { + public: + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.OpCallByKindPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(OpCallByKindPartitionRuleNode, PartitionRuleNode); +}; + +class OpCallByKindPartitionRule : public PartitionRule { + public: + explicit OpCallByKindPartitionRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(OpCallByKindPartitionRule, PartitionRule, + OpCallByKindPartitionRuleNode); +}; + +/*! + * \brief Partition rules which keeps only candidates from the sub-rule whose sub-groups are valid + * w.r.t. the given \p SubGraphConfig. + */ +class OnlyValidPartitionRuleNode : public PartitionRuleNode { + public: + PartitionRule sub_rule_; + SubGraphConfig config_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + public: + static constexpr const char* _type_key = "relay.collage.OnlyValidPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(OnlyValidPartitionRuleNode, PartitionRuleNode); +}; + +class OnlyValidPartitionRule : public PartitionRule { + public: + OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, const SubGraphConfig& config); + + TVM_DEFINE_OBJECT_REF_METHODS(OnlyValidPartitionRule, PartitionRule, OnlyValidPartitionRuleNode); +}; + +/*! + * \brief Partition rule which selects nodes which can be 'left behind' to be executed by the host + * (eg on the VM). This includes most of the 'interstitial' Relay constructs, such a let bindings, + * operators on references, calls to non-operator functions, and so on. It can also include the + * construction of and projection from tuples which may not be supported within a partition. + */ +class HostPartitionRuleNode : public PartitionRuleNode { + public: + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + public: + static constexpr const char* _type_key = "relay.collage.HostPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(HostPartitionRuleNode, PartitionRuleNode); +}; + +class HostPartitionRule : public PartitionRule { + public: + explicit HostPartitionRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(HostPartitionRule, PartitionRule, HostPartitionRuleNode); +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_PARTITION_RULE_H_ diff --git a/src/relay/collage/partition_spec.cc b/src/relay/collage/partition_spec.cc new file mode 100644 index 000000000000..b2095d0a594e --- /dev/null +++ b/src/relay/collage/partition_spec.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_spec.cc + * \brief Combine a \p PartitionRule with a \p Target. + */ + +#include "./partition_spec.h" + +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +String DefaultValidateSubGraphFunc(const Function& function) { return String(); } + +TVM_REGISTER_NODE_TYPE(PartitionSpecNode); + +void PartitionSpecNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector PartitionSpecNode::AllCandidates( + const DataflowGraph& dataflow_graph) const { + std::vector result; + // Make sure the target is in scope for inspection by any predicates in + // DFPatternPartitionRuleNode rules. + With target_scope(target_); + // Gather all the candidates. + std::vector candidates = + rule_->AllCandidates(dataflow_graph, GetRef(this)); + // Update the rules names. + for (const auto& candidate : candidates) { + ICHECK_EQ(candidate->spec_, GetRef(this)); + String rule_name = NestLabels(spec_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(candidate, std::move(rule_name)); + result.emplace_back(std::move(new_candidate)); + } + return result; +} + +std::string PartitionSpecNode::ToString() const { + Doc doc; + doc << "PartitionSpec(" << Doc::NewLine(2); + std::vector body_items; + body_items.emplace_back(); + body_items.back() << "spec_name=" << Doc::StrLiteral(spec_name_); + body_items.emplace_back(); + body_items.back() << "target=" << target_->ToDebugString(); + body_items.emplace_back(); + body_items.back() << "rule=" << rule_->ToDoc(); + doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine(); + doc << ")"; + return doc.str(); +} + +PartitionSpec::PartitionSpec(String spec_name, Target target, PartitionRule rule, + TValidateSubGraphFunc validate_sub_graph_func) { + auto node = runtime::make_object(); + node->spec_name_ = std::move(spec_name); + node->target_ = std::move(target); + node->rule_ = std::move(rule); + node->validate_sub_graph_func_ = std::move(validate_sub_graph_func); + data_ = std::move(node); +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/partition_spec.h b/src/relay/collage/partition_spec.h new file mode 100644 index 000000000000..e8ce64c68468 --- /dev/null +++ b/src/relay/collage/partition_spec.h @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_spec.h + * \brief Combine a \p PartitionRule with a \p Target. + */ + +#ifndef TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ +#define TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ + +#include +#include +#include + +#include +#include + +#include "./partition_rule.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Type of functions for checking the validity of partitions before they proceed to lowering + * and codegen. The argument is the function extracted from the overall expression to represent + * the partition. The result is a non-empty error message string if the candidate should be + * rejected. + */ +using TValidateSubGraphFunc = TypedPackedFunc; + +/*! + * \brief The default validation function. Always returns the empty string, ie no error. + */ +String DefaultValidateSubGraphFunc(const Function& function); + +/*! + * \brief Pairs a \p PartitionRule with one or more \p Targets it can be used for. + */ +class PartitionSpecNode : public Object { + public: + /*! + * \brief Specification name to distinguish this spec from all others. Typically the BYOC + * 'compiler' name, "tvm", or "host". + */ + String spec_name_; + + /*! + * \brief The target all candidate partitions should be compiled for. + * + * It's tempting to support multiple targets here since. Eg the partitioning rules for + * TVM are the same irrespective of whether the target is "cuda" or "llvm", so it would make + * sense to build the candidate partitions first without committing to any target, then 'stamp' + * them for each target as the final step. + * + * However, we want to make sure any predicate in \p DFPatternPartitionRuleNode instances + * can have access to the current target instance. Eg the predicate may need to consult + * build-time configuration to decide what operators, shapes etc are actually supported. + * That implies the specific target is known when the candidate partitions are being constructed. + * + * So for now we'll just force each spec to have exactly one target. + */ + Target target_; + + /*! + * \brief The partition rule to use to gather candidates. + */ + PartitionRule rule_; + + /*! + * \brief The validation function to apply to each candidate's the extracted function before + * proceeding to lowering/codegen. + */ + TValidateSubGraphFunc validate_sub_graph_func_ = DefaultValidateSubGraphFunc; + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns all the candidate partitions found by this specification. The candidates + * will be for a specific target, but will not yet have an extracted function or cost. + */ + std::vector AllCandidates(const DataflowGraph& dataflow_graph) const; + + std::string ToString() const; + + static constexpr const char* _type_key = "relay.collage.PartitionSpec"; + TVM_DECLARE_FINAL_OBJECT_INFO(PartitionSpecNode, Object); +}; + +class PartitionSpec : public ObjectRef { + public: + PartitionSpec(String spec_name, Target target, PartitionRule rule, + TValidateSubGraphFunc validate_sub_graph_func = DefaultValidateSubGraphFunc); + + TVM_DEFINE_OBJECT_REF_METHODS(PartitionSpec, ObjectRef, PartitionSpecNode); +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ diff --git a/tests/cpp/relay/collage/partition_rule_test.cc b/tests/cpp/relay/collage/partition_rule_test.cc new file mode 100644 index 000000000000..48cf478e09bb --- /dev/null +++ b/tests/cpp/relay/collage/partition_rule_test.cc @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../src/relay/collage/partition_rule.h" + +#include +#include +#include +#include + +#include "../../../src/relay/collage/partition_spec.h" + +namespace tvm { +namespace relay { +namespace { + +IRModule TestIRModule() { + constexpr const char* kModel = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 3 + %1 = nn.relu(%0); // 4 + nn.relu(%1) // 5 + } + )"; + return parser::ParseModule("string", kModel); +} + +std::vector MakeCandidates( + const collage::DataflowGraph& graph, const runtime::String rule_name, + const collage::PartitionSpec& spec, const std::vector> index_sets) { + std::vector candidate_partitions; + for (const auto& indexes : index_sets) { + auto subgraph = collage::SubGraph(graph, collage::IndexSet(graph.size(), indexes)); + auto candidate = collage::CandidatePartition(rule_name, subgraph, spec); + candidate_partitions.emplace_back(std::move(candidate)); + } + return candidate_partitions; +} + +TEST(PartitionRule, DFPatternSingleOp) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto rule = collage::DFPatternPartitionRule("relu_pattern", pattern); + auto expected_candidates = MakeCandidates(graph, "relu_pattern", spec, {{4}, {5}}); + + auto candidates = rule->AllCandidates(graph, spec); + + ICHECK_EQ(candidates.size(), 2); + for (size_t i = 0; i < candidates.size(); i++) { + ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + } + } +} + +TEST(PartitionRule, DFPatternOverlap) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + auto pattern = + IsOp("nn.relu")({IsOp("nn.relu")({IsWildcard()}) || IsOp("abs")({IsWildcard()})}); + auto rule = collage::DFPatternPartitionRule("relu+abs_pattern", pattern); + auto expected_candidates = MakeCandidates(graph, "relu+abs_pattern", spec, {{3, 4}, {4, 5}}); + + auto candidates = rule->AllCandidates(graph, spec); + + ICHECK_EQ(candidates.size(), 2); + for (size_t i = 0; i < candidates.size(); i++) { + ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + } + } +} + +TEST(PartitionRule, Composite) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + constexpr const char* kExpectedMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); + %1 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="composite") { + nn.relu(%FunctionVar_01) + }; + %2 = %1(%0); + %3 = fn (%FunctionVar_0: Tensor[(10, 10), float32], Composite="composite") { + nn.relu(%FunctionVar_0) + }; + %3(%2) + } + )"; + auto expected_expr = + Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); + auto composite_rule = collage::CompositePartitionRule("composite", df_rule); + + auto candidates = composite_rule->AllCandidates(graph, spec); + auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); + ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + } +} + +TEST(PartitionRule, PrimitiveTVM) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + constexpr const char* kExpectedMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); + %1 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Primitive=1) { + nn.relu(%FunctionVar_01) + }; + %2 = %1(%0); + %3 = fn (%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1) { + nn.relu(%FunctionVar_0) + }; + %3(%2) + } + )"; + auto expected_expr = + Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); + auto primitive_rule = collage::PrimitivePartitionRule("primitive", df_rule); + + auto candidates = primitive_rule->AllCandidates(graph, spec); + auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); + ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + } +} + +TVM_REGISTER_TARGET_KIND("test_ext_codegen", kDLCUDA) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + +TEST(PartitionRule, PrimitiveExternal) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("test_ext_codegen"); + auto spec = collage::PartitionSpec("test_ext_codegen", target, {}); + + { + constexpr const char* kExpectedMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); + %1 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Primitive=1, Compiler="test_ext_codegen") { + nn.relu(%FunctionVar_01) + }; + %2 = %1(%0); + %3 = fn (%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="test_ext_codegen") { + nn.relu(%FunctionVar_0) + }; + %3(%2) + } + )"; + auto expected_expr = + Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); + auto primitive_rule = collage::PrimitivePartitionRule("primitive", df_rule); + + auto candidates = primitive_rule->AllCandidates(graph, spec); + auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); + ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + } +} + +TEST(PartitionRule, Union) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + auto abs_pattern = IsOp("abs")({IsWildcard()}); + auto abs_rule = collage::DFPatternPartitionRule("abs_pattern", abs_pattern); + auto relu_pattern = IsOp("nn.relu")({IsWildcard()}); + auto relu_rule = collage::DFPatternPartitionRule("relu_pattern", relu_pattern); + auto union_rule = collage::UnionPartitionRule("union", {abs_rule, relu_rule}); + + auto abs_candidates = MakeCandidates(graph, "abs_pattern", spec, {{3}}); + auto relu_candidates = MakeCandidates(graph, "relu_pattern", spec, {{4}, {5}}); + + std::vector expected_candidates; + expected_candidates.insert(expected_candidates.end(), abs_candidates.begin(), + abs_candidates.end()); + expected_candidates.insert(expected_candidates.end(), relu_candidates.begin(), + relu_candidates.end()); + + auto candidates = union_rule->AllCandidates(graph, spec); + + ICHECK_EQ(candidates.size(), expected_candidates.size()); + for (size_t i = 0; i < candidates.size(); i++) { + ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + } + } +} + +TEST(PartitionRule, OpCallByKind) { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 4 + %1 = add(%0, %x); // 5 + shape_of(%1) // 6 + } + )"; + auto main = Downcast(parser::ParseModule("string", kMod)->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + auto rule = collage::OpCallByKindPartitionRule("op_call_by_kind"); + auto expected_candidates = MakeCandidates(graph, "op_call_by_kind", spec, {{4}, {5}}); + + auto candidates = rule->AllCandidates(graph, spec); + + ICHECK_EQ(candidates.size(), expected_candidates.size()); + for (size_t i = 0; i < candidates.size(); i++) { + ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + } + } +} + +} // namespace +} // namespace relay +} // namespace tvm