From 5903ed1d9adb79632d0740830a4de0f6fd7fc3cc Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 1 Jul 2022 13:33:01 -0700 Subject: [PATCH 1/3] [Collage] PartitionRule (though without CombinePartitionRule) See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md. (Special thanks to Matthew Barrett for authoring partition_rule_test.cc and suggesting a PR partitioning strategy.) Collage uses a small 'combinator library' of PartitionRule to describe how candidate partitions can be extracted from a model for measurement and comparison. This introduces most of that machinery, however we defer the all-important 'CombinePartitionRule' to the next PR. Thus the rules at this stage can only express the sorts of DFPattern-based rules we find in most BYOC integrations, and cannot describe rules more traditionally associated with operator fusion. Based on #11981. --- src/relay/collage/candidate_partition.cc | 258 ++++++++++++ src/relay/collage/candidate_partition.h | 174 ++++++++ src/relay/collage/candidate_set.cc | 76 ++++ src/relay/collage/candidate_set.h | 99 +++++ src/relay/collage/cost.cc | 45 +++ src/relay/collage/cost.h | 103 +++++ src/relay/collage/partition_rule.cc | 372 ++++++++++++++++++ src/relay/collage/partition_rule.h | 355 +++++++++++++++++ src/relay/collage/partition_spec.cc | 87 ++++ src/relay/collage/partition_spec.h | 120 ++++++ .../cpp/relay/collage/partition_rule_test.cc | 272 +++++++++++++ 11 files changed, 1961 insertions(+) create mode 100644 src/relay/collage/candidate_partition.cc create mode 100644 src/relay/collage/candidate_partition.h create mode 100644 src/relay/collage/candidate_set.cc create mode 100644 src/relay/collage/candidate_set.h create mode 100644 src/relay/collage/cost.cc create mode 100644 src/relay/collage/cost.h create mode 100644 src/relay/collage/partition_rule.cc create mode 100644 src/relay/collage/partition_rule.h create mode 100644 src/relay/collage/partition_spec.cc create mode 100644 
src/relay/collage/partition_spec.h create mode 100644 tests/cpp/relay/collage/partition_rule_test.cc diff --git a/src/relay/collage/candidate_partition.cc b/src/relay/collage/candidate_partition.cc new file mode 100644 index 000000000000..9cccdf96d5a4 --- /dev/null +++ b/src/relay/collage/candidate_partition.cc @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_partition.cc + * \brief A potential partition in the Collage search. 
+ */ + +#include "./candidate_partition.h" + +#include + +#include "./candidate_set.h" +#include "./partition_rule.h" +#include "./partition_spec.h" +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +TVM_REGISTER_NODE_TYPE(CandidatePartitionNode); + +void CandidatePartitionNode::VisitAttrs(AttrVisitor* v) { + v->Visit("rule_name", &rule_name_); + v->Visit("sub_graph", &sub_graph_); + v->Visit("spec", &spec_); + // TODO(mbs): cost_ +} + +PartitionSpec CandidatePartitionNode::partition_spec() const { + return Downcast(spec_); +} + +std::string CandidatePartitionNode::partition_spec_name() const { + return Downcast(spec_)->spec_name_; +} + +Target CandidatePartitionNode::target() const { return Downcast(spec_)->target_; } + +std::string CandidatePartitionNode::ToSummary(const DataflowGraph& dataflow_graph) const { + std::ostringstream os; + os << sub_graph_->label_; + os << " | ("; + bool first = true; + for (PostDfsIndex index : sub_graph_->input_) { + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); + if (CanInline(sub_expr)) { + continue; + } + if (first) { + first = false; + } else { + os << ", "; + } + os << PrettyPrint(sub_expr->checked_type()); + } + os << ") -> ("; + first = true; + for (PostDfsIndex index : sub_graph_->exit_) { + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); + if (CanInline(sub_expr)) { + continue; + } + if (first) { + first = false; + } else { + os << ", "; + } + os << PrettyPrint(sub_expr->checked_type()); + } + os << ") | "; + os << sub_graph_->inside_.ToString(); + os << " | "; + os << partition_spec_name(); + os << " | "; + os << cost_.ToString(); + return os.str(); +} + +std::string CandidatePartitionNode::ToString() const { + std::ostringstream os; + os << "{rule_name=" << rule_name_; + os << ",sub_graph=" << sub_graph_->ToString(); + os << ",spec_name=" << partition_spec_name(); + if (!cost_.is_unknown()) { + os << ",cost=" << cost_.ToString(); + } + os << "}"; + return 
os.str(); +} + +CandidatePartition::CandidatePartition(String rule_name, SubGraph sub_graph, + ObjectRef /* actually PartitionSpec */ spec, Cost cost) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_graph_ = std::move(sub_graph); + node->spec_ = std::move(spec); + node->cost_ = cost; + data_ = std::move(node); +} + +CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name) { + if (rule_name == candidate->rule_name_) { + return candidate; + } + auto* node = candidate.CopyOnWrite(); + node->rule_name_ = std::move(rule_name); + return GetRef(node); +} + +CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph) { + if (sub_graph == candidate->sub_graph_) { + return candidate; + } + auto* node = candidate.CopyOnWrite(); + node->sub_graph_ = std::move(sub_graph); + return GetRef(node); +} + +bool CandidatePartition::operator<(const CandidatePartition& that) const { + // Order lexicographically on sub-graphs. + if (*get()->sub_graph_.get() < *that->sub_graph_.get()) { + return true; + } + if (*that->sub_graph_.get() < *get()->sub_graph_.get()) { + return false; + } + // Break ties by rule name. 
+ return get()->rule_name_ < that->rule_name_; +} + +bool CandidatePartition::AreTouching(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const { + return get()->spec_ == that->spec_ && + get()->sub_graph_.AreTouching(dataflow_graph, that->sub_graph_); +} + +CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const { + ICHECK_EQ(get()->spec_, that->spec_); + return CandidatePartition(UnionLabels(get()->rule_name_, that->rule_name_), + get()->sub_graph_.DisjointUnion(dataflow_graph, that->sub_graph_), + get()->spec_, get()->cost_ + that->cost_); +} + +/*static*/ +CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, + std::vector candidates) { + ICHECK_GT(candidates.size(), 1); + CandidatePartition result = candidates.front(); + for (size_t i = 1; i < candidates.size(); ++i) { + result = result.DisjointUnion(dataflow_graph, candidates[i]); + } + return result; +} + +/*static*/ +Expr CandidatePartition::ParallelRewrite(const DataflowGraph& dataflow_graph, + const std::vector& candidates) { + std::vector sub_graphs; + sub_graphs.reserve(candidates.size()); + for (const auto& candidate : candidates) { + sub_graphs.emplace_back(candidate->sub_graph_); + } + return SubGraph::ParallelRewrite(dataflow_graph, sub_graphs); +} + +/*static*/ +std::vector CandidatePartition::MaxCoalesce( + const DataflowGraph& dataflow_graph, std::vector candidates) { + VLOG(1) << "Running MaxCoalesce over " << candidates.size() << " candidates"; + // This is an eager version of using the simple (kOpaque, kOpaque) combiner. + + // Switch to set representation. + CandidateSet result_set(std::move(candidates)); + + // Until fixed point... 
+ size_t num_rounds = 0; + while (result_set.PrepareForNextRound()) { + VLOG_CONTEXT << "round " << ++num_rounds; + VLOG(1) << "checking " << result_set.size() << " candidates (" << result_set.first_new_index() + << " existing)"; + IndexSet removed_this_round(result_set.size()); // over candidate indexes! + + // Build map from post-dfs indices to the indices of candidates with corresponding entry node. + // NOTE: the index set is over candidate indices not post-dfs indices! + std::vector entry_map(dataflow_graph.size(), IndexSet(result_set.size())); + for (size_t i = 0; i < result_set.size(); ++i) { + CandidatePartition candidate = result_set.at(i); + for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) { + entry_map[entry_index].Add(i); + } + } + + for (size_t i = 0; i < result_set.size(); ++i) { + if (removed_this_round[i]) { + // Already merged. + continue; + } + CandidatePartition upstream = result_set.at(i); + // Narrow our search to just those candidates which could touch. + IndexSet possible_downstream(result_set.size()); // over candidate indexes! + for (PostDfsIndex output_index : upstream->sub_graph_->output_) { + possible_downstream = possible_downstream | entry_map[output_index]; + } + for (size_t j : possible_downstream) { + if (removed_this_round[j]) { + // Already merged. + continue; + } + if (i == j) { + // Ignore self. + continue; + } + CandidatePartition downstream = result_set.at(j); + if (!upstream.AreTouching(dataflow_graph, downstream)) { + continue; + } + CandidatePartition new_candidate = upstream.DisjointUnion(dataflow_graph, downstream); + VLOG(2) << "Merging upstream candidate " << upstream->ToString() + << " and downstream candidate " << downstream->ToString() << " to yield " + << new_candidate->ToString(); + result_set.Add(dataflow_graph, new_candidate); + result_set.Remove(upstream); + removed_this_round.Add(i); + result_set.Remove(downstream); + removed_this_round.Add(j); + } + } + } + + // Restore canonical order. 
+ result_set.sort(); + + VLOG(1) << "MaxCoalesce produced " << result_set.size() << " candidates"; + return result_set.MovedCurrentCandidates(); +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/candidate_partition.h b/src/relay/collage/candidate_partition.h new file mode 100644 index 000000000000..8267c3efcb2b --- /dev/null +++ b/src/relay/collage/candidate_partition.h @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_partition.h + * \brief A potential partition in the Collage search. + */ + +#ifndef TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ +#define TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ + +#include +#include + +#include +#include +#include + +#include "./cost.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +class PartitionSpec; + +/*! + * \brief A candidate partition w.r.t. the overall Relay model. + * + * We represent the partition as a sub-graph. 
This means not only can we represent the scope + * of Relay sub-expressions intended for a particular partition (or kernel), but we can also + * represent various conventions for encoding how the operators within the partition should be + * tagged for downstream processing. + */ +class CandidatePartitionNode : public Object { + public: + CandidatePartitionNode() = default; + + /*! + * \brief Combination of all the partition rule names which produced this candidate. + * For debugging and explainability. + */ + String rule_name_; + + /*! + * \brief The sub-graph of the overall expression matched by the partition rule. + */ + SubGraph sub_graph_; + + /*! + * \brief The partition specification which produced this candidate. + */ + ObjectRef /* actually PartitionSpec */ spec_; + + /*! + * \brief The (cached) cost of the partition. + * + * Initially Cost::Unknown, calculated and cached by EstimateCost. + */ + mutable Cost cost_ = Cost::Unknown(); + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns the partition specification which produced this candidate. + */ + PartitionSpec partition_spec() const; + + /*! + * \brief Returns the name of the partition specification which produced this candidate. + */ + std::string partition_spec_name() const; + + /*! + * \brief Returns the target of the partition specification which produced this candidate. + */ + Target target() const; + + /*! + * \brief Returns a brief description of candidate suitable for debugging output. 
+ */ + std::string ToSummary(const DataflowGraph& dataflow_graph) const; + + std::string ToString() const; + + static constexpr const char* _type_key = "relay.collage.CandidatePartition"; + TVM_DECLARE_FINAL_OBJECT_INFO(CandidatePartitionNode, Object); +}; + +class CandidatePartition : public ObjectRef { + public: + CandidatePartition(String rule_name, SubGraph sub_graph, + ObjectRef /* actually PartitionSpec */ spec, Cost cost = Cost::Unknown()); + + bool operator<(const CandidatePartition& that) const; + + /*! + * \brief Returns true if this and \p that candidate are disjoint, have the same (or no) target, + * and touch. This does not imply the \p DisjointUnion of this and that will be valid. For + * example, the result may be too deep or have too many outputs. + */ + bool AreTouching(const DataflowGraph& dataflow_graph, const CandidatePartition& that) const; + + /*! + * \brief Returns the disjoint union of this and \p that. + */ + CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph, + const CandidatePartition& that) const; + + /*! + * \brief Returns the disjoint union of all \p candidates. + */ + static CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph, + std::vector candidates); + + /*! + * \brief Returns the root expression of \p dataflow_graph rewritten to apply all the partitions + * implied by \p candidates. The candidates can be in any order but must be disjoint. + */ + static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, + const std::vector& candidates); + + /*! + * Eagerly merge all touching candidates for the same target. The candidates must be disjoint + * and have their Targets filled in. This is typically called on the optimal list of candidate + * partitions found by the Collage search in order to remove unnecessary partition boundaries. 
+ * Ideally the search would never produce such candidates however to keep the search space + * manageable Collage may only consider candidate partitions up to a particular depth. + */ + static std::vector MaxCoalesce(const DataflowGraph& dataflow_graph, + std::vector candidates); + + TVM_DEFINE_OBJECT_REF_METHODS(CandidatePartition, ObjectRef, CandidatePartitionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CandidatePartitionNode); +}; + +CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name); +CandidatePartition WithTarget(CandidatePartition candidate, Target target); +CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph); + +struct CandidatePartitionHash { + size_t operator()(const CandidatePartition& candidate) const { + return candidate->sub_graph_->hash(); + } +}; + +struct CandidatePartitionEquals { + bool operator()(const CandidatePartition& left, const CandidatePartition& right) const { + return *left->sub_graph_.get() == *right->sub_graph_.get(); + } +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ diff --git a/src/relay/collage/candidate_set.cc b/src/relay/collage/candidate_set.cc new file mode 100644 index 000000000000..2c2a7eaf8d54 --- /dev/null +++ b/src/relay/collage/candidate_set.cc @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_set.cc + * \brief Collects a set of candidate partitions. + */ + +#include "./candidate_set.h" + +namespace tvm { +namespace relay { +namespace collage { + +CandidateSet::CandidateSet(std::vector candidates_to_add) + : candidates_to_add_(std::move(candidates_to_add)) { + for (const auto& candidate : candidates_to_add_) { + seen_.emplace(candidate); + } +} + +void CandidateSet::Add(const DataflowGraph& dataflow_graph, + const CandidatePartition& new_candidate) { + VLOG(2) << "adding " << new_candidate->ToString(); + if (seen_.count(new_candidate)) { + VLOG(2) << "already seen candidate, ignoring"; + return; + } + seen_.emplace(new_candidate); + candidates_to_add_.emplace_back(new_candidate); +} + +void CandidateSet::Remove(const CandidatePartition& old_candidate) { + ICHECK(seen_.count(old_candidate)); + VLOG(2) << "removing " << old_candidate->ToString(); + candidates_to_remove_.emplace_back(old_candidate); +} + +bool CandidateSet::PrepareForNextRound() { + size_t init_size = current_candidates_.size(); + for (const auto& candidate_to_remove : candidates_to_remove_) { + current_candidates_.erase( + std::remove(current_candidates_.begin(), current_candidates_.end(), candidate_to_remove), + current_candidates_.end()); + } + size_t num_removed = init_size - current_candidates_.size(); + candidates_to_remove_.clear(); + first_new_index_ = current_candidates_.size(); + for (const auto& new_candidate : candidates_to_add_) { + current_candidates_.push_back(new_candidate); + } + size_t num_added = 
candidates_to_add_.size(); + candidates_to_add_.clear(); + VLOG(1) << "removed " << num_removed << " and added " << num_added << " candidates"; + return num_removed + num_added > 0; +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/candidate_set.h b/src/relay/collage/candidate_set.h new file mode 100644 index 000000000000..4cb2c40e9500 --- /dev/null +++ b/src/relay/collage/candidate_set.h @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_set.h + * \brief Collects a set of candidate partitions. + */ + +#ifndef TVM_RELAY_COLLAGE_CANDIDATE_SET_H_ +#define TVM_RELAY_COLLAGE_CANDIDATE_SET_H_ + +#include +#include +#include +#include + +#include "./candidate_partition.h" +#include "./dataflow_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Holds a vector of current candidates and the additions/removals to apply to them. + */ +struct CandidateSet { + CandidateSet() = default; + + explicit CandidateSet(std::vector candidates_to_add); + + /*! + * \brief Schedule \p new_candidate for addition before the next round (unless it is not valid). 
+ */ + void Add(const DataflowGraph& dataflow_graph, const CandidatePartition& new_candidate); + + /*! \brief Schedule \p old_candidate for removal before the next round. */ + void Remove(const CandidatePartition& old_candidate); + + /*! + * \brief Update \p current_candidates and \p first_new_index. Return false if no + * candidates were added or removed, in which case we have reached a fixed point. + */ + bool PrepareForNextRound(); + + size_t size() const { return current_candidates_.size(); } + + CandidatePartition operator[](size_t i) const { + ICHECK_LT(i, current_candidates_.size()); + return current_candidates_[i]; + } + CandidatePartition at(size_t i) const { return (*this)[i]; } + + size_t first_new_index() const { return first_new_index_; } + + void sort() { std::sort(current_candidates_.begin(), current_candidates_.end()); } + + std::vector MovedCurrentCandidates() { + return std::move(current_candidates_); + } + + private: + /*! + * \brief Index of first candidate in current_candidates added in last round. This can be used to + * avoid considering candidates or candidate combinations which have already been considered in an + * earlier round. + */ + size_t first_new_index_ = 0; + /*! \brief Candidates gathered in previous rounds. */ + std::vector current_candidates_; + /*! \brief New candidates gathered in the current round. */ + std::vector candidates_to_add_; + /*! \brief Existing candidates to remove before starting the next round. */ + std::vector candidates_to_remove_; + /*! \brief Which candidates have been seen so far and should not be added again. 
*/ + std::unordered_set seen_; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_CANDIDATE_SET_H_ diff --git a/src/relay/collage/cost.cc b/src/relay/collage/cost.cc new file mode 100644 index 000000000000..ae2eb8600ebd --- /dev/null +++ b/src/relay/collage/cost.cc @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/cost.cc + * \brief Represents the estimated cost of a candidate partition. + */ + +#include "./cost.h" + +namespace tvm { +namespace relay { +namespace collage { + +std::string Cost::ToString() const { + if (is_invalid()) { + return "invalid"; + } else if (is_unknown()) { + return "unknown"; + } else if (value_ == 0.0) { + return "0"; + } else { + return std::to_string(value_ * 1e6) + "us"; + } +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/cost.h b/src/relay/collage/cost.h new file mode 100644 index 000000000000..8ae276d22078 --- /dev/null +++ b/src/relay/collage/cost.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/cost.h + * \brief Represents the estimated cost of a candidate partition. + */ +#ifndef TVM_RELAY_COLLAGE_COST_H_ +#define TVM_RELAY_COLLAGE_COST_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief The assumed cost for a candidate partition. Generally average execution time in seconds. + * However other cost functions are possible, for example to introduce a penalty for high memory + * use, etc. + */ +class Cost { + public: + Cost() = delete; + + static Cost Zero() { return Cost(0.0); } + + /*! + * \brief Returns the distinguished 'invalid' cost signaling a candidate partition is not + * supported by the intended target, for example because the sub-graph has an unsupported operator + * or the intermediate memory required exceeds some system limit. + */ + static Cost Invalid() { return Cost(std::numeric_limits::infinity()); } + + bool is_invalid() const { return std::isinf(value_) && value_ > 0.0; } + + /*! + * \brief Returns the distinguished 'unknown' cost, signaling fixed priorities should be used to + * choose the best partitions. 
This can be used to disable tuning and fallback to fixed rules, + * much as TVM will use an un-tuned kernel if no tuning records are available. + */ + static Cost Unknown() { return Cost(std::numeric_limits::quiet_NaN()); } + + bool is_unknown() const { return std::isnan(value_); } + + /*! \brief Returns cost with given finite, non-negative value. */ + static Cost Value(double value) { + ICHECK(!std::isnan(value) && !std::isinf(value) && value >= 0.0); + return Cost(value); + } + + bool is_value() const { return !std::isnan(value_) && !std::isinf(value_); } + + /*! \brief Return true if the less-than relation is defined for this and that. */ + bool are_comparable(Cost that) const { return !std::isnan(value_) && !std::isnan(that.value_); } + + /*! \brief Returns sum of this and that. */ + Cost operator+(Cost that) const { return Cost(value_ + that.value_); } + + /*! \brief Returns difference of this and that. */ + Cost operator-(Cost that) const { return Cost(value_ - that.value_); } + + /*! \brief Returns true if this is cheaper than that, assuming they are comparable. */ + bool operator<(Cost that) const { return value_ < that.value_; } + + std::string ToString() const; + + private: + explicit Cost(double value) : value_(value) {} + + /*! + * \brief Non-negative value or: + * - +inf if candidate partition is not feasible. + * - NaN if candidate partition has an unknown cost (priority may be used to break ties). + */ + double value_ = 0.0; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_COST_H_ diff --git a/src/relay/collage/partition_rule.cc b/src/relay/collage/partition_rule.cc new file mode 100644 index 000000000000..1cedbfc9d72c --- /dev/null +++ b/src/relay/collage/partition_rule.cc @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_rule.cc + * \brief Compositional partitioning rules. + */ + +#include "./partition_rule.h" + +#include + +#include "./partition_rule.h" +#include "./partition_spec.h" +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +TVM_REGISTER_NODE_TYPE(PartitionRuleNode); + +void PartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector PartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + ICHECK(false) << "PartitionRuleNode::AllCandidates should be overridden in sub-class"; + return {}; +} + +std::string PartitionRuleNode::ToString() const { return ToDoc().str(); } + +Doc PartitionRuleNode::ToDoc() const { + Doc doc; + doc << GetTypeKey() << "(" << Doc::NewLine(2); + std::vector body_items; + AppendBodyItems(&body_items); + doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine(); + doc << ")"; + return doc; +} + +void PartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + body_items->emplace_back(); + body_items->back() << "rule_name=" << Doc::StrLiteral(rule_name_); +} + +PartitionRule::PartitionRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = 
std::move(rule_name); + data_ = std::move(node); +} + +bool DefaultPatternPredicate(const Expr& matched_sub_expr) { return true; } + +TVM_REGISTER_NODE_TYPE(DFPatternPartitionRuleNode); + +void DFPatternPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector DFPatternPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running DFPatternPartitionRule(" << rule_name_ << ")"; + std::vector result; + DFPatternMatcher matcher(&dataflow_graph.indexed_graph()); + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); + if (!matcher.Match(pattern_, sub_expr)) { + continue; + } + if (!predicate_(sub_expr)) { + VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") has failing predicate"; + continue; + } + IndexSet inside = MatcherToIndexSet(matcher); + OpPatternKind kind; + String label; + std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); + String rule_name = rule_name_.empty() ? 
sub_graph->label_ : rule_name_; + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); + VLOG(2) << "DFPatternPartitionRule(" << rule_name_ << ") yields " << candidate->ToString(); + result.emplace_back(std::move(candidate)); + } + VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void DFPatternPartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "pattern=" << PrettyPrint(pattern_); +} + +DFPatternPartitionRule::DFPatternPartitionRule(String rule_name, DFPattern pattern, + TPatternPredicate predicate) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->pattern_ = std::move(pattern); + node->predicate_ = std::move(predicate); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(CompositePartitionRuleNode); + +void CompositePartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector CompositePartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running CompositePartitionRule(" << rule_name_ << ") over " << candidates.size() + << " sub-candidates"; + std::vector result; + FunctionAttrsMap attrs; + attrs.Set(attr::kComposite, rule_name_); + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs); + CandidatePartition new_candidate = WithSubGraph( + WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph)); + VLOG(2) << "CompositePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "CompositePartitionRule(" << 
rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void CompositePartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); +} + +CompositePartitionRule::CompositePartitionRule(String rule_name, PartitionRule sub_rule) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(PrimitivePartitionRuleNode); + +void PrimitivePartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector PrimitivePartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running PrimitivePartitionRule(" << rule_name_ << ") over " << candidates.size() + << " sub-candidates"; + std::vector result; + FunctionAttrsMap attrs; + attrs.Set(attr::kPrimitive, Integer(1)); + if (spec->target_.IsExternalCodegen()) { + // The spec name will be the target kind name which is 1:1 with the "Compiler" attribute name. 
+ attrs.Set(attr::kCompiler, spec->spec_name_); + } + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs); + CandidatePartition new_candidate = WithSubGraph( + WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph)); + VLOG(2) << "PrimitivePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "PrimitivePartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void PrimitivePartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); +} + +PrimitivePartitionRule::PrimitivePartitionRule(String rule_name, PartitionRule sub_rule) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(UnionPartitionRuleNode); + +void UnionPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector UnionPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector result; + for (const auto& sub_rule : sub_rules_) { + std::vector candidates = sub_rule->AllCandidates(dataflow_graph, spec); + for (auto& candidate : candidates) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); + VLOG(2) << "UnionPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + } + VLOG(1) << "UnionPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates"; + return result; +} + 
+void UnionPartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + for (const auto& sub_rule : sub_rules_) { + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule->ToDoc(); + } +} + +UnionPartitionRule::UnionPartitionRule(String rule_name, Array sub_rules) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rules_ = std::move(sub_rules); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(OpCallByKindPartitionRuleNode); + +void OpCallByKindPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector OpCallByKindPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running OpCallByKindPartitionRule(" << rule_name_ << ")"; + std::vector result; + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + auto node = dataflow_graph.index_to_node(index); + Expr sub_expr = node->ref(); + if (sub_expr->IsInstance()) { + OpPatternKind kind; + String label; + std::tie(kind, label) = SubExprKindAndLabel(sub_expr); + if (kind <= kOutEWiseFusable) { + IndexSet inside(dataflow_graph.size(), {index}); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); + String rule_name = NestLabels(rule_name_, sub_graph->label_); + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); + VLOG(2) << "OpCallByKindPartitionRule(" << rule_name_ << ") yields " + << candidate->ToString(); + result.emplace_back(std::move(candidate)); + } + } + } + VLOG(1) << "OpCallByKindPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void OpCallByKindPartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); +} + +OpCallByKindPartitionRule::OpCallByKindPartitionRule(String rule_name) { + auto node = 
runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(OnlyValidPartitionRuleNode); + +void OnlyValidPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector OnlyValidPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running OnlyValidPartitionRule(" << rule_name_ << ") over " << candidates.size() + << " sub-candidates"; + std::vector result; + for (auto& candidate : candidates) { + if (!candidate->sub_graph_->IsValid(dataflow_graph, config_)) { + VLOG(2) << "Ignoring invalid candidate " << candidate->ToString(); + continue; + } + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); + VLOG(2) << "OnlyValidPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "OnlyValidPartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void OnlyValidPartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); + body_items->emplace_back(); + body_items->back() << "config=" << config_.ToString(); +} + +OnlyValidPartitionRule::OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, + const SubGraphConfig& config) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + node->config_ = config; + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(HostPartitionRuleNode); + +void HostPartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector 
HostPartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + VLOG(1) << "running HostPartitionRule(" << rule_name_ << ")"; + std::vector result; + for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { + if (MustBeLowered(dataflow_graph.index_to_node(index)->ref())) { + continue; + } + IndexSet inside(dataflow_graph.size(), {index}); + OpPatternKind kind; + String label; + std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + SubGraph sub_graph(dataflow_graph, std::move(inside), kind, label); + String rule_name = NestLabels(rule_name_, sub_graph->label_); + // We'll use a zero cost for the candidate since we'll never want to actually estimate the cost + // of this 'partition'. + CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec, Cost::Zero()); + VLOG(2) << "HostPartitionRule(" << rule_name_ << ") yields " << candidate->ToString(); + result.push_back(candidate); + } + VLOG(1) << "HostPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates"; + return result; +} + +void HostPartitionRuleNode::AppendBodyItems(std::vector* body_items) const {} + +HostPartitionRule::HostPartitionRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/partition_rule.h b/src/relay/collage/partition_rule.h new file mode 100644 index 000000000000..13f5c0b01d31 --- /dev/null +++ b/src/relay/collage/partition_rule.h @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_rule.h + * \brief Compositional partitioning rules. + */ + +#ifndef TVM_RELAY_COLLAGE_PARTITION_RULE_H_ +#define TVM_RELAY_COLLAGE_PARTITION_RULE_H_ + +#include +#include + +#include +#include + +#include "../../printer/doc.h" +#include "./candidate_partition.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Type of function to check if a matched sub-expression should be accepted by a rule. This + * can be used to, eg, reject operators of unsupported shape or dtype, or otherwise implement rules + * which are difficult to express in the dataflow pattern language directly. + */ +using TPatternPredicate = TypedPackedFunc; + +/*! + * \brief The default pattern predicate. Always returns true. + */ +bool DefaultPatternPredicate(const Expr& matched_sub_expr); + +/*! + * \brief Base class of all partition rules. + * + * A \p PartitionRule describes how to find a set of \p CandidatePartitions for a \p DataflowGraph. + * The candidates are allowed to overlap, and ultimately it is the job of the Collage searcher to + * find a selection of candidates which covers the whole Relay expression without overlap. Partition + * rules are paired with their \p Target and other 'top level' configuration in a \p PartitionSpec. 
+ * + * We provide a set of 'base' partition rules which produce candidates from the dataflow graph + * directly. We also provide a set of 'combinator' partition rules which can produce new candidates + * from the results of an arbitrary sub-rule or sub-rules. By mixing these base and combinator + * rules we can express a wide variety of partition strategies and encoding conventions. + * + * There may be many thousands of candidates in flight during the Collage search. We take care to + * defer constructing or rewriting Relay expressions until absolutely necessary. We only pay for + * extracting a function to represent a candidate when we need to measure its cost. And we only + * pay for rewriting the overall Relay expression to commit to a partitioning when the Collage + * search has completed. + * + * The base rules implemented so far: + * - \p DFPatternPartitionRule: Given a \p DFPattern and expression predicate, produces a candidate + * for every sub-graph matched by the pattern and predicate. Unlike the \p PatternRewriter, + * candidates are free to overlap. Used to bring BYOC patterns into the Collage framework. + * - \p OpCallByKindPartitionRule: Uses the "TOpPattern" attribute provided for every Relay + * operator to produce a candidate for every call to a 'fusable Relay operator'. Used to + * look ahead to how TVM will fuse sub-graphs. + * + * The combinator rules implemented so far: + * - \p CompositePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped + * by a "Composite" function. The "Composite" name is taken from the rule name. Used to indicate + * Relay operators (or groups of Relay operators) should be mapped to target-specific operators, + * both for BYOC and TVM external library integrations. + * - \p PrimitivePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped + * by a "Primitive" function, possibly with an additional "Compiler" attribute. Used to + * delineate a partition (or kernel). 
+ * - \p UnionPartitionRule: Simply unions all the candidates from all sub-rules together. Used to + * combine individual \p DFPatternPartitionRules. + * - \p OnlyValidPartitionRule: Given a \p SubGraphConfig, ignores candidates with 'invalid' + * sub-graphs. Used to limit the maximum candidate depth, the number of independent outputs, + * and whether intermediate 'taps' are allowed. + * - \p HostPartitionRule: Produces candidates for all Relay expressions which could be + * 'left behind' for execution by the host (eg on the VM). This rule lets us simplify the + * overall Collage search algorithm. + * + * (Though not yet implemented, we'd like to allow a combinator rule which will union candidates + * based on their 'anchor' operators. This can be used to implement 'vertical' and 'horizontal' + * partition on more primitive candidates. Note that the \p SubGraph machinery supports + * multiple-input and -output sub-graphs and their validation, so horizontal partition is easy + * to implement.) + */ +class PartitionRuleNode : public Object { + public: + /*! + * \brief A unique (over all rules for the same target) name for the rule. Rule names are + * combined and captured with \p CandidatePartition rule names for debuggability and + * explainability. Some rules will copy the rule name into function attributes. + * + */ + String rule_name_; + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns all the possible candidate partitions according to this rule for the overall + * expression corresponding to \p dataflow_graph. The candidates will generally have unknown + * target and cost: the target will be filled in by the \p PartitionSpec, while the cost will + * be filled in lazily. 
+ */ + virtual std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const; + + std::string ToString() const; + Doc ToDoc() const; + + protected: + virtual void AppendBodyItems(std::vector* body_items) const; + + public: + static constexpr const char* _type_key = "relay.collage.PartitionRule"; + static constexpr const uint32_t _type_child_slots = 10; + TVM_DECLARE_BASE_OBJECT_INFO(PartitionRuleNode, Object); +}; + +class PartitionRule : public ObjectRef { + public: + explicit PartitionRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(PartitionRule, ObjectRef, PartitionRuleNode); +}; + +/*! + * \brief Partition rule which fires on all sub-expressions matching a dataflow-pattern and pattern + * predicate. It is valid for matching candidates to overlap. + */ +class DFPatternPartitionRuleNode : public PartitionRuleNode { + public: + /*! + * \brief Relay pattern. + */ + DFPattern pattern_; + + /*! + * \brief Predicate on matched sub-expression to decide if partition rule should fire. + */ + TPatternPredicate predicate_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.DFPatternPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(DFPatternPartitionRuleNode, PartitionRuleNode); +}; + +class DFPatternPartitionRule : public PartitionRule { + public: + DFPatternPartitionRule(String rule_name, DFPattern pattern, + TPatternPredicate predicate = DefaultPatternPredicate); + + TVM_DEFINE_OBJECT_REF_METHODS(DFPatternPartitionRule, PartitionRule, DFPatternPartitionRuleNode); +}; + +/*! + * \brief Partition rule which wraps candidates within a function with the "Composite" attribute + * bound to the given rule name. 
+ * + * This is the standard way by which operators or operator groups are tagged as being supported + * by a particular externally provided function. It is up to the BYOC lowering function to + * recognize the "Composite" name and emit the appropriate code or call. + */ +class CompositePartitionRuleNode : public PartitionRuleNode { + public: + /*! \brief The sub-partition rule. */ + PartitionRule sub_rule_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.CompositePartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(CompositePartitionRuleNode, PartitionRuleNode); +}; + +class CompositePartitionRule : public PartitionRule { + public: + CompositePartitionRule(String rule_name, PartitionRule sub_rule); + + TVM_DEFINE_OBJECT_REF_METHODS(CompositePartitionRule, PartitionRule, CompositePartitionRuleNode); +}; + +/*! + * \brief Partition rule which wraps candidates within a function with the "Primitive" attribute + * bound to 1. If the partition spec target(s) have the "compiler" attribute then that name is + * also added to the function as a "Compiler" attribute. + * + * This is the standard way by which sub-graphs are marked as being in a 'partition' whose + * compilation will be managed by an external BYOC toolchain. It can also be used to mark + * sub-graphs for lowering to a single kernel by the built-in TVM lowering machinery. + */ +class PrimitivePartitionRuleNode : public PartitionRuleNode { + public: + /*! \brief The sub-partition rule. 
*/ + PartitionRule sub_rule_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.PrimitivePartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimitivePartitionRuleNode, PartitionRuleNode); +}; + +class PrimitivePartitionRule : public PartitionRule { + public: + PrimitivePartitionRule(String rule_name, PartitionRule sub_rule); + + TVM_DEFINE_OBJECT_REF_METHODS(PrimitivePartitionRule, PartitionRule, PrimitivePartitionRuleNode); +}; + +/*! + * \brief Partition rule which simply unions all matches from all sub-partition rules. + * + * This can be used to combine the results of a set of, eg, DFPatternPartitionRules. + */ +class UnionPartitionRuleNode : public PartitionRuleNode { + public: + Array sub_rules_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.UnionPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnionPartitionRuleNode, PartitionRuleNode); +}; + +class UnionPartitionRule : public PartitionRule { + public: + UnionPartitionRule(String rule_name, Array sub_rules); + + TVM_DEFINE_OBJECT_REF_METHODS(UnionPartitionRule, PartitionRule, UnionPartitionRuleNode) +}; + +/*! + * \brief Partition rule which places calls to Relay operators with a "TOpPattern" attribute of + * \p kOutEWiseFusable or less in their own singleton sub-graph. No other Relay sub-expressions + * (such as tuples or tuple projection) are selected, and it is up to outer partition rules to + * account for them. 
+ */ +class OpCallByKindPartitionRuleNode : public PartitionRuleNode { + public: + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + static constexpr const char* _type_key = "relay.collage.OpCallByKindPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(OpCallByKindPartitionRuleNode, PartitionRuleNode); +}; + +class OpCallByKindPartitionRule : public PartitionRule { + public: + explicit OpCallByKindPartitionRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(OpCallByKindPartitionRule, PartitionRule, + OpCallByKindPartitionRuleNode); +}; + +/*! + * \brief Partition rule which keeps only candidates from the sub-rule whose sub-graphs are valid + * w.r.t. the given \p SubGraphConfig. + */ +class OnlyValidPartitionRuleNode : public PartitionRuleNode { + public: + PartitionRule sub_rule_; + SubGraphConfig config_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + public: + static constexpr const char* _type_key = "relay.collage.OnlyValidPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(OnlyValidPartitionRuleNode, PartitionRuleNode); +}; + +class OnlyValidPartitionRule : public PartitionRule { + public: + OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, const SubGraphConfig& config); + + TVM_DEFINE_OBJECT_REF_METHODS(OnlyValidPartitionRule, PartitionRule, OnlyValidPartitionRuleNode); +}; + +/*! + * \brief Partition rule which selects nodes which can be 'left behind' to be executed by the host + * (eg on the VM). This includes most of the 'interstitial' Relay constructs, such as let bindings, + * operators on references, calls to non-operator functions, and so on. 
It can also include the + * construction of and projection from tuples which may not be supported within a partition. + */ +class HostPartitionRuleNode : public PartitionRuleNode { + public: + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + public: + static constexpr const char* _type_key = "relay.collage.HostPartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(HostPartitionRuleNode, PartitionRuleNode); +}; + +class HostPartitionRule : public PartitionRule { + public: + explicit HostPartitionRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(HostPartitionRule, PartitionRule, HostPartitionRuleNode); +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_PARTITION_RULE_H_ diff --git a/src/relay/collage/partition_spec.cc b/src/relay/collage/partition_spec.cc new file mode 100644 index 000000000000..b2095d0a594e --- /dev/null +++ b/src/relay/collage/partition_spec.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/relay/collage/partition_spec.cc + * \brief Combine a \p PartitionRule with a \p Target. + */ + +#include "./partition_spec.h" + +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace collage { + +String DefaultValidateSubGraphFunc(const Function& function) { return String(); } + +TVM_REGISTER_NODE_TYPE(PartitionSpecNode); + +void PartitionSpecNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector PartitionSpecNode::AllCandidates( + const DataflowGraph& dataflow_graph) const { + std::vector result; + // Make sure the target is in scope for inspection by any predicates in + // DFPatternPartitionRuleNode rules. + With target_scope(target_); + // Gather all the candidates. + std::vector candidates = + rule_->AllCandidates(dataflow_graph, GetRef(this)); + // Update the rules names. + for (const auto& candidate : candidates) { + ICHECK_EQ(candidate->spec_, GetRef(this)); + String rule_name = NestLabels(spec_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(candidate, std::move(rule_name)); + result.emplace_back(std::move(new_candidate)); + } + return result; +} + +std::string PartitionSpecNode::ToString() const { + Doc doc; + doc << "PartitionSpec(" << Doc::NewLine(2); + std::vector body_items; + body_items.emplace_back(); + body_items.back() << "spec_name=" << Doc::StrLiteral(spec_name_); + body_items.emplace_back(); + body_items.back() << "target=" << target_->ToDebugString(); + body_items.emplace_back(); + body_items.back() << "rule=" << rule_->ToDoc(); + doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine(); + doc << ")"; + return doc.str(); +} + +PartitionSpec::PartitionSpec(String spec_name, Target target, PartitionRule rule, + TValidateSubGraphFunc validate_sub_graph_func) { + auto node = runtime::make_object(); + node->spec_name_ = std::move(spec_name); + node->target_ = std::move(target); + node->rule_ = std::move(rule); + node->validate_sub_graph_func_ = 
std::move(validate_sub_graph_func); + data_ = std::move(node); +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/partition_spec.h b/src/relay/collage/partition_spec.h new file mode 100644 index 000000000000..e8ce64c68468 --- /dev/null +++ b/src/relay/collage/partition_spec.h @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/partition_spec.h + * \brief Combine a \p PartitionRule with a \p Target. + */ + +#ifndef TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ +#define TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ + +#include +#include +#include + +#include +#include + +#include "./partition_rule.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Type of functions for checking the validity of partitions before they proceed to lowering + * and codegen. The argument is the function extracted from the overall expression to represent + * the partition. The result is a non-empty error message string if the candidate should be + * rejected. + */ +using TValidateSubGraphFunc = TypedPackedFunc; + +/*! + * \brief The default validation function. 
Always returns the empty string, ie no error. + */ +String DefaultValidateSubGraphFunc(const Function& function); + +/*! + * \brief Pairs a \p PartitionRule with the \p Target it is to be used for. + */ +class PartitionSpecNode : public Object { + public: + /*! + * \brief Specification name to distinguish this spec from all others. Typically the BYOC + * 'compiler' name, "tvm", or "host". + */ + String spec_name_; + + /*! + * \brief The target all candidate partitions should be compiled for. + * + * It's tempting to support multiple targets here. Eg the partitioning rules for + * TVM are the same irrespective of whether the target is "cuda" or "llvm", so it would make + * sense to build the candidate partitions first without committing to any target, then 'stamp' + * them for each target as the final step. + * + * However, we want to make sure any predicate in \p DFPatternPartitionRuleNode instances + * can have access to the current target instance. Eg the predicate may need to consult + * build-time configuration to decide what operators, shapes etc are actually supported. + * That implies the specific target is known when the candidate partitions are being constructed. + * + * So for now we'll just force each spec to have exactly one target. + */ + Target target_; + + /*! + * \brief The partition rule to use to gather candidates. + */ + PartitionRule rule_; + + /*! + * \brief The validation function to apply to each candidate's extracted function before + * proceeding to lowering/codegen. + */ + TValidateSubGraphFunc validate_sub_graph_func_ = DefaultValidateSubGraphFunc; + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns all the candidate partitions found by this specification. The candidates + * will be for a specific target, but will not yet have an extracted function or cost. 
+ */ + std::vector AllCandidates(const DataflowGraph& dataflow_graph) const; + + std::string ToString() const; + + static constexpr const char* _type_key = "relay.collage.PartitionSpec"; + TVM_DECLARE_FINAL_OBJECT_INFO(PartitionSpecNode, Object); +}; + +class PartitionSpec : public ObjectRef { + public: + PartitionSpec(String spec_name, Target target, PartitionRule rule, + TValidateSubGraphFunc validate_sub_graph_func = DefaultValidateSubGraphFunc); + + TVM_DEFINE_OBJECT_REF_METHODS(PartitionSpec, ObjectRef, PartitionSpecNode); +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ diff --git a/tests/cpp/relay/collage/partition_rule_test.cc b/tests/cpp/relay/collage/partition_rule_test.cc new file mode 100644 index 000000000000..48cf478e09bb --- /dev/null +++ b/tests/cpp/relay/collage/partition_rule_test.cc @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "../../../src/relay/collage/partition_rule.h" + +#include +#include +#include +#include + +#include "../../../src/relay/collage/partition_spec.h" + +namespace tvm { +namespace relay { +namespace { + +IRModule TestIRModule() { + constexpr const char* kModel = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 3 + %1 = nn.relu(%0); // 4 + nn.relu(%1) // 5 + } + )"; + return parser::ParseModule("string", kModel); +} + +std::vector MakeCandidates( + const collage::DataflowGraph& graph, const runtime::String rule_name, + const collage::PartitionSpec& spec, const std::vector> index_sets) { + std::vector candidate_partitions; + for (const auto& indexes : index_sets) { + auto subgraph = collage::SubGraph(graph, collage::IndexSet(graph.size(), indexes)); + auto candidate = collage::CandidatePartition(rule_name, subgraph, spec); + candidate_partitions.emplace_back(std::move(candidate)); + } + return candidate_partitions; +} + +TEST(PartitionRule, DFPatternSingleOp) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto rule = collage::DFPatternPartitionRule("relu_pattern", pattern); + auto expected_candidates = MakeCandidates(graph, "relu_pattern", spec, {{4}, {5}}); + + auto candidates = rule->AllCandidates(graph, spec); + + ICHECK_EQ(candidates.size(), 2); + for (size_t i = 0; i < candidates.size(); i++) { + ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + } + } +} + +TEST(PartitionRule, DFPatternOverlap) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + auto pattern = + 
IsOp("nn.relu")({IsOp("nn.relu")({IsWildcard()}) || IsOp("abs")({IsWildcard()})}); + auto rule = collage::DFPatternPartitionRule("relu+abs_pattern", pattern); + auto expected_candidates = MakeCandidates(graph, "relu+abs_pattern", spec, {{3, 4}, {4, 5}}); + + auto candidates = rule->AllCandidates(graph, spec); + + ICHECK_EQ(candidates.size(), 2); + for (size_t i = 0; i < candidates.size(); i++) { + ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + } + } +} + +TEST(PartitionRule, Composite) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + constexpr const char* kExpectedMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); + %1 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="composite") { + nn.relu(%FunctionVar_01) + }; + %2 = %1(%0); + %3 = fn (%FunctionVar_0: Tensor[(10, 10), float32], Composite="composite") { + nn.relu(%FunctionVar_0) + }; + %3(%2) + } + )"; + auto expected_expr = + Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); + auto composite_rule = collage::CompositePartitionRule("composite", df_rule); + + auto candidates = composite_rule->AllCandidates(graph, spec); + auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); + ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + } +} + +TEST(PartitionRule, PrimitiveTVM) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + constexpr const char* kExpectedMod = R"( + 
#[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); + %1 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Primitive=1) { + nn.relu(%FunctionVar_01) + }; + %2 = %1(%0); + %3 = fn (%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1) { + nn.relu(%FunctionVar_0) + }; + %3(%2) + } + )"; + auto expected_expr = + Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); + auto primitive_rule = collage::PrimitivePartitionRule("primitive", df_rule); + + auto candidates = primitive_rule->AllCandidates(graph, spec); + auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); + ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + } +} + +TVM_REGISTER_TARGET_KIND("test_ext_codegen", kDLCUDA) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + +TEST(PartitionRule, PrimitiveExternal) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("test_ext_codegen"); + auto spec = collage::PartitionSpec("test_ext_codegen", target, {}); + + { + constexpr const char* kExpectedMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); + %1 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Primitive=1, Compiler="test_ext_codegen") { + nn.relu(%FunctionVar_01) + }; + %2 = %1(%0); + %3 = fn (%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="test_ext_codegen") { + nn.relu(%FunctionVar_0) + }; + %3(%2) + } + )"; + auto expected_expr = + Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); + auto primitive_rule = collage::PrimitivePartitionRule("primitive", df_rule); + + 
auto candidates = primitive_rule->AllCandidates(graph, spec); + auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); + ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + } +} + +TEST(PartitionRule, Union) { + IRModule ir_mod = TestIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + auto abs_pattern = IsOp("abs")({IsWildcard()}); + auto abs_rule = collage::DFPatternPartitionRule("abs_pattern", abs_pattern); + auto relu_pattern = IsOp("nn.relu")({IsWildcard()}); + auto relu_rule = collage::DFPatternPartitionRule("relu_pattern", relu_pattern); + auto union_rule = collage::UnionPartitionRule("union", {abs_rule, relu_rule}); + + auto abs_candidates = MakeCandidates(graph, "abs_pattern", spec, {{3}}); + auto relu_candidates = MakeCandidates(graph, "relu_pattern", spec, {{4}, {5}}); + + std::vector expected_candidates; + expected_candidates.insert(expected_candidates.end(), abs_candidates.begin(), + abs_candidates.end()); + expected_candidates.insert(expected_candidates.end(), relu_candidates.begin(), + relu_candidates.end()); + + auto candidates = union_rule->AllCandidates(graph, spec); + + ICHECK_EQ(candidates.size(), expected_candidates.size()); + for (size_t i = 0; i < candidates.size(); i++) { + ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + } + } +} + +TEST(PartitionRule, OpCallByKind) { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 4 + %1 = add(%0, %x); // 5 + shape_of(%1) // 6 + } + )"; + auto main = Downcast(parser::ParseModule("string", kMod)->Lookup("main")); + auto graph = collage::DataflowGraph(main); + Target target("llvm"); + auto spec = collage::PartitionSpec("test_spec", target, {}); + + { + auto rule = 
collage::OpCallByKindPartitionRule("op_call_by_kind"); + auto expected_candidates = MakeCandidates(graph, "op_call_by_kind", spec, {{4}, {5}}); + + auto candidates = rule->AllCandidates(graph, spec); + + ICHECK_EQ(candidates.size(), expected_candidates.size()); + for (size_t i = 0; i < candidates.size(); i++) { + ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + } + } +} + +} // namespace +} // namespace relay +} // namespace tvm From bbac9c86b82171032f9c68505c0ba91e405a1424 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 11 Jul 2022 17:26:34 -0700 Subject: [PATCH 2/3] - Backport improvements to partition_rule_test.cc --- .../cpp/relay/collage/partition_rule_test.cc | 227 ++++++++++-------- 1 file changed, 129 insertions(+), 98 deletions(-) diff --git a/tests/cpp/relay/collage/partition_rule_test.cc b/tests/cpp/relay/collage/partition_rule_test.cc index 48cf478e09bb..fab34cd3d32d 100644 --- a/tests/cpp/relay/collage/partition_rule_test.cc +++ b/tests/cpp/relay/collage/partition_rule_test.cc @@ -23,15 +23,39 @@ #include #include #include +#include #include "../../../src/relay/collage/partition_spec.h" namespace tvm { namespace relay { +namespace collage { namespace { -IRModule TestIRModule() { - constexpr const char* kModel = R"( +Constant MakeConstant(std::initializer_list shape) { + return Constant(runtime::NDArray::Empty(shape, DataType::Float(32), {kDLCPU, 0})); +} + +Function MakeTestFunction( + const std::string& mod_text, + std::initializer_list> constant_shapes) { + Array constants; + for (const auto& shape : constant_shapes) { + constants.push_back(MakeConstant(shape)); + } + Map> metatable; + metatable.Set("relay.Constant", constants); + IRModule mod = parser::ParseModule("string", mod_text, {}, metatable); + mod = transform::CapturePostDfsIndexInSpans()(mod); + auto func = Downcast(mod->Lookup("main")); + LOG(INFO) << "------- input function -------"; + LOG(INFO) << PrettyPrint(func); + LOG(INFO) << 
"------------------------------"; + return func; +} + +Function StandardTestFunction() { + constexpr const char* kMod = R"( #[version = "0.0.5"] def @main(%x: Tensor[(10, 10), float32]) { %0 = abs(%x); // 3 @@ -39,72 +63,100 @@ IRModule TestIRModule() { nn.relu(%1) // 5 } )"; - return parser::ParseModule("string", kModel); + return MakeTestFunction(kMod, /*constant_shapes=*/{}); } -std::vector MakeCandidates( - const collage::DataflowGraph& graph, const runtime::String rule_name, - const collage::PartitionSpec& spec, const std::vector> index_sets) { - std::vector candidate_partitions; +std::vector ActualCandidates(const DataflowGraph& graph, const Function& func, + const PartitionSpec& spec, + const PartitionRule& rule) { + auto candidates = rule->AllCandidates(graph, spec); + LOG(INFO) << "--------- actual candidates -------------"; + for (const auto& candidate : candidates) { + LOG(INFO) << candidate->ToString(); + } + LOG(INFO) << "-----------------------------------------"; + return candidates; +} + +std::vector ExpectedCandidates( + const DataflowGraph& graph, const runtime::String rule_name, const PartitionSpec& spec, + const std::vector> index_sets) { + std::vector candidate_partitions; for (const auto& indexes : index_sets) { - auto subgraph = collage::SubGraph(graph, collage::IndexSet(graph.size(), indexes)); - auto candidate = collage::CandidatePartition(rule_name, subgraph, spec); + auto subgraph = SubGraph(graph, IndexSet(graph.size(), indexes)); + auto candidate = CandidatePartition(rule_name, subgraph, spec); candidate_partitions.emplace_back(std::move(candidate)); } return candidate_partitions; } +void AssertEqual(const std::vector& actual, + const std::vector& expected) { + ASSERT_EQ(actual.size(), expected.size()); + std::set actual_set(actual.begin(), actual.end()); + std::set expected_set(expected.begin(), + expected.end()); + ASSERT_EQ(actual_set.size(), expected_set.size()); + for (const auto& actual_candidate : actual_set) { + 
ASSERT_EQ(expected_set.count(actual_candidate), 1); + } +} + TEST(PartitionRule, DFPatternSingleOp) { - IRModule ir_mod = TestIRModule(); - auto main = Downcast(ir_mod->Lookup("main")); - auto graph = collage::DataflowGraph(main); + auto func = StandardTestFunction(); + auto graph = DataflowGraph(func); Target target("llvm"); - auto spec = collage::PartitionSpec("test_spec", target, {}); + auto spec = PartitionSpec("test_spec", target, {}); { auto pattern = IsOp("nn.relu")({IsWildcard()}); - auto rule = collage::DFPatternPartitionRule("relu_pattern", pattern); - auto expected_candidates = MakeCandidates(graph, "relu_pattern", spec, {{4}, {5}}); + auto rule = DFPatternPartitionRule("relu_pattern", pattern); + auto expected_candidates = ExpectedCandidates(graph, "relu_pattern", spec, {{4}, {5}}); - auto candidates = rule->AllCandidates(graph, spec); + auto candidates = ActualCandidates(graph, func, spec, rule); ICHECK_EQ(candidates.size(), 2); for (size_t i = 0; i < candidates.size(); i++) { - ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); + ICHECK(CandidatePartitionEquals()(candidates[i], expected_candidates[i])); } } } TEST(PartitionRule, DFPatternOverlap) { - IRModule ir_mod = TestIRModule(); - auto main = Downcast(ir_mod->Lookup("main")); - auto graph = collage::DataflowGraph(main); + auto func = StandardTestFunction(); + auto graph = DataflowGraph(func); Target target("llvm"); - auto spec = collage::PartitionSpec("test_spec", target, {}); + auto spec = PartitionSpec("test_spec", target, {}); { auto pattern = IsOp("nn.relu")({IsOp("nn.relu")({IsWildcard()}) || IsOp("abs")({IsWildcard()})}); - auto rule = collage::DFPatternPartitionRule("relu+abs_pattern", pattern); - auto expected_candidates = MakeCandidates(graph, "relu+abs_pattern", spec, {{3, 4}, {4, 5}}); + auto rule = DFPatternPartitionRule("relu+abs_pattern", pattern); - auto candidates = rule->AllCandidates(graph, spec); + auto candidates = ActualCandidates(graph, func, 
spec, rule); - ICHECK_EQ(candidates.size(), 2); - for (size_t i = 0; i < candidates.size(); i++) { - ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); - } + auto expected_candidates = + ExpectedCandidates(graph, "relu+abs_pattern", spec, {{3, 4}, {4, 5}}); + AssertEqual(candidates, expected_candidates); } } TEST(PartitionRule, Composite) { - IRModule ir_mod = TestIRModule(); - auto main = Downcast(ir_mod->Lookup("main")); - auto graph = collage::DataflowGraph(main); + auto func = StandardTestFunction(); + auto graph = DataflowGraph(func); Target target("llvm"); - auto spec = collage::PartitionSpec("test_spec", target, {}); + auto spec = PartitionSpec("test_spec", target, {}); { + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); + auto composite_rule = CompositePartitionRule("composite", df_rule); + + auto candidates = ActualCandidates(graph, func, spec, composite_rule); + auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); + constexpr const char* kExpectedMod = R"( #[version = "0.0.5"] def @main(%x: Tensor[(10, 10), float32]) { @@ -119,28 +171,26 @@ TEST(PartitionRule, Composite) { %3(%2) } )"; - auto expected_expr = - Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); - auto pattern = IsOp("nn.relu")({IsWildcard()}); - auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); - auto composite_rule = collage::CompositePartitionRule("composite", df_rule); - - auto candidates = composite_rule->AllCandidates(graph, spec); - auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); - - ICHECK_EQ(candidates.size(), 2); + Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{}); ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); } } TEST(PartitionRule, PrimitiveTVM) { - IRModule ir_mod = TestIRModule(); - auto main = 
Downcast(ir_mod->Lookup("main")); - auto graph = collage::DataflowGraph(main); + auto func = StandardTestFunction(); + auto graph = DataflowGraph(func); Target target("llvm"); - auto spec = collage::PartitionSpec("test_spec", target, {}); + auto spec = PartitionSpec("test_spec", target, {}); { + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); + auto primitive_rule = PrimitivePartitionRule("primitive", df_rule); + + auto candidates = ActualCandidates(graph, func, spec, primitive_rule); + auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); constexpr const char* kExpectedMod = R"( #[version = "0.0.5"] def @main(%x: Tensor[(10, 10), float32]) { @@ -155,16 +205,7 @@ TEST(PartitionRule, PrimitiveTVM) { %3(%2) } )"; - auto expected_expr = - Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); - auto pattern = IsOp("nn.relu")({IsWildcard()}); - auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); - auto primitive_rule = collage::PrimitivePartitionRule("primitive", df_rule); - - auto candidates = primitive_rule->AllCandidates(graph, spec); - auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); - - ICHECK_EQ(candidates.size(), 2); + Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{}); ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); } } @@ -173,13 +214,20 @@ TVM_REGISTER_TARGET_KIND("test_ext_codegen", kDLCUDA) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); TEST(PartitionRule, PrimitiveExternal) { - IRModule ir_mod = TestIRModule(); - auto main = Downcast(ir_mod->Lookup("main")); - auto graph = collage::DataflowGraph(main); + auto func = StandardTestFunction(); + auto graph = DataflowGraph(func); Target target("test_ext_codegen"); - auto spec = collage::PartitionSpec("test_ext_codegen", target, {}); + auto spec = 
PartitionSpec("test_ext_codegen", target, {}); { + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); + auto primitive_rule = PrimitivePartitionRule("primitive", df_rule); + + auto candidates = ActualCandidates(graph, func, spec, primitive_rule); + auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates); + + ICHECK_EQ(candidates.size(), 2); constexpr const char* kExpectedMod = R"( #[version = "0.0.5"] def @main(%x: Tensor[(10, 10), float32]) { @@ -194,49 +242,35 @@ TEST(PartitionRule, PrimitiveExternal) { %3(%2) } )"; - auto expected_expr = - Downcast(parser::ParseModule("string", kExpectedMod)->Lookup("main")); - auto pattern = IsOp("nn.relu")({IsWildcard()}); - auto df_rule = collage::DFPatternPartitionRule("relu_pattern", pattern); - auto primitive_rule = collage::PrimitivePartitionRule("primitive", df_rule); - - auto candidates = primitive_rule->AllCandidates(graph, spec); - auto rewrite_expr = collage::CandidatePartition::ParallelRewrite(graph, candidates); - - ICHECK_EQ(candidates.size(), 2); + Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{}); ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); } } TEST(PartitionRule, Union) { - IRModule ir_mod = TestIRModule(); - auto main = Downcast(ir_mod->Lookup("main")); - auto graph = collage::DataflowGraph(main); + auto func = StandardTestFunction(); + auto graph = DataflowGraph(func); Target target("llvm"); - auto spec = collage::PartitionSpec("test_spec", target, {}); + auto spec = PartitionSpec("test_spec", target, {}); { auto abs_pattern = IsOp("abs")({IsWildcard()}); - auto abs_rule = collage::DFPatternPartitionRule("abs_pattern", abs_pattern); + auto abs_rule = DFPatternPartitionRule("abs_pattern", abs_pattern); auto relu_pattern = IsOp("nn.relu")({IsWildcard()}); - auto relu_rule = collage::DFPatternPartitionRule("relu_pattern", relu_pattern); - auto union_rule = collage::UnionPartitionRule("union", 
{abs_rule, relu_rule}); + auto relu_rule = DFPatternPartitionRule("relu_pattern", relu_pattern); + auto union_rule = UnionPartitionRule("union", {abs_rule, relu_rule}); + + auto abs_candidates = ExpectedCandidates(graph, "abs_pattern", spec, {{3}}); + auto relu_candidates = ExpectedCandidates(graph, "relu_pattern", spec, {{4}, {5}}); - auto abs_candidates = MakeCandidates(graph, "abs_pattern", spec, {{3}}); - auto relu_candidates = MakeCandidates(graph, "relu_pattern", spec, {{4}, {5}}); + auto candidates = ActualCandidates(graph, func, spec, union_rule); - std::vector expected_candidates; + std::vector expected_candidates; expected_candidates.insert(expected_candidates.end(), abs_candidates.begin(), abs_candidates.end()); expected_candidates.insert(expected_candidates.end(), relu_candidates.begin(), relu_candidates.end()); - - auto candidates = union_rule->AllCandidates(graph, spec); - - ICHECK_EQ(candidates.size(), expected_candidates.size()); - for (size_t i = 0; i < candidates.size(); i++) { - ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); - } + AssertEqual(candidates, expected_candidates); } } @@ -249,24 +283,21 @@ TEST(PartitionRule, OpCallByKind) { shape_of(%1) // 6 } )"; - auto main = Downcast(parser::ParseModule("string", kMod)->Lookup("main")); - auto graph = collage::DataflowGraph(main); + auto func = MakeTestFunction(kMod, {}); + auto graph = DataflowGraph(func); Target target("llvm"); - auto spec = collage::PartitionSpec("test_spec", target, {}); + auto spec = PartitionSpec("test_spec", target, {}); { - auto rule = collage::OpCallByKindPartitionRule("op_call_by_kind"); - auto expected_candidates = MakeCandidates(graph, "op_call_by_kind", spec, {{4}, {5}}); + auto rule = OpCallByKindPartitionRule("op_call_by_kind"); + auto candidates = ActualCandidates(graph, func, spec, rule); - auto candidates = rule->AllCandidates(graph, spec); - - ICHECK_EQ(candidates.size(), expected_candidates.size()); - for (size_t i = 0; i < 
candidates.size(); i++) { - ICHECK(collage::CandidatePartitionEquals()(candidates[i], expected_candidates[i])); - } + auto expected_candidates = ExpectedCandidates(graph, "op_call_by_kind", spec, {{4}, {5}}); + AssertEqual(candidates, expected_candidates); } } } // namespace +} // namespace collage } // namespace relay } // namespace tvm From 06446ce30c1546b17c729d9853863a71711c34cf Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 11 Jul 2022 17:29:51 -0700 Subject: [PATCH 3/3] - Oops --- src/relay/collage/candidate_partition.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/relay/collage/candidate_partition.h b/src/relay/collage/candidate_partition.h index 8267c3efcb2b..1265087f475f 100644 --- a/src/relay/collage/candidate_partition.h +++ b/src/relay/collage/candidate_partition.h @@ -167,6 +167,12 @@ struct CandidatePartitionEquals { } }; +struct CandidatePartitionCompare { + bool operator()(const CandidatePartition& left, const CandidatePartition& right) const { + return *left->sub_graph_.get() < *right->sub_graph_.get(); + } +}; + } // namespace collage } // namespace relay } // namespace tvm