diff --git a/src/relay/collage/candidate_function_cache.cc b/src/relay/collage/candidate_function_cache.cc new file mode 100644 index 000000000000..32982dc08f3d --- /dev/null +++ b/src/relay/collage/candidate_function_cache.cc @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_function_cache.cc + * \brief A cache of the unique global name and costs for partitioned functions. + */ + +#include "./candidate_function_cache.h" + +namespace tvm { +namespace relay { +namespace collage { + +CandidateFunctionCache::Entry& CandidateFunctionCache::GetEntry(const std::string& label, + const Function& function) { + auto itr = cache_.find(function); + if (itr == cache_.end()) { + String compiler = function->GetAttr(attr::kCompiler, String("tvm")).value(); + std::string global_symbol_name = name_supply_->Fresh({compiler, label}); + GlobalVar global_symbol(std::move(global_symbol_name), function->checked_type()); + itr = cache_.emplace(function, Entry(std::move(global_symbol))).first; + } + return itr->second; +} + +GlobalVar CandidateFunctionCache::GetGlobalSymbol(const Function& function) { + return GetEntry(/*label=*/"", function).global_symbol; +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/candidate_function_cache.h b/src/relay/collage/candidate_function_cache.h new file mode 100644 index 000000000000..8734f5a8e1af --- /dev/null +++ b/src/relay/collage/candidate_function_cache.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/candidate_function_cache.h + * \brief A cache of the unique global symbol name and cost for partitioned functions. + */ + +#ifndef TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_ +#define TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_ + +#include + +#include +#include +#include +#include + +#include "../transforms/compiler_function_utils.h" +#include "./cost.h" +#include "./name_supply.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief A cache of the unique global symbol and cost for functions extracted to represent + * partitions. If two functions are structurally equal (which includes equality of their "Compiler" + * attributes) then they will share the same global symbol and estimated cost. We rely on the + * function's attributes to distinguish partitions which are structurally the same graph but + * intended for different targets. + */ +class CandidateFunctionCache : public transform::GlobalSymbolCache { + public: + explicit CandidateFunctionCache(std::shared_ptr name_supply) + : name_supply_(std::move(name_supply)) {} + + struct Entry { + GlobalVar global_symbol; + Cost cost = Cost::Unknown(); // Filled in when have estimated cost. + + explicit Entry(GlobalVar global_symbol) : global_symbol(std::move(global_symbol)) {} + }; + + /*! + * \brief Returns the unique entry for \p function. If no such entry already exists, create it + * and assign it a unique global symbol name. + */ + Entry& GetEntry(const std::string& label, const Function& function); + + GlobalVar GetGlobalSymbol(const Function& function) final; + + private: + std::shared_ptr name_supply_; + std::unordered_map cache_; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_ diff --git a/src/relay/collage/candidate_partition.cc b/src/relay/collage/candidate_partition.cc index 9cccdf96d5a4..20e29a6d4027 100644 --- a/src/relay/collage/candidate_partition.cc +++ b/src/relay/collage/candidate_partition.cc @@ -24,8 +24,12 @@ #include "./candidate_partition.h" +#include #include +#include +#include "../transforms/compiler_function_utils.h" +#include "./candidate_function_cache.h" #include "./candidate_set.h" #include "./partition_rule.h" #include "./partition_spec.h" @@ -106,6 +110,102 @@ std::string CandidatePartitionNode::ToString() const { return os.str(); } +namespace { +/*! + * \brief If function's body is a call to an inlined "Primitive" function, return it. + * Otherwise return function directly. + */ +Function GetPrimitiveFunction(const Function& function) { + if (const auto* call_node = function->body.as()) { + if (const auto* function_node = call_node->op.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return GetRef(function_node); + } + } + } + return function; +} + +/*! + * \brief Eta-expand any tuple arguments of \p function. Ie rewrite: + * \code + * f(x: (t1, t2)) { ... x ... } + * \endcode + * to + * \code + * f(x_1: t1, x_2: t2) { ... (x_1, x_2) ... } + * \endcode + */ +Function EtaExpandTuples(const Function& function) { + Map subst; + Array new_params; + for (const auto& param : function->params) { + std::vector tensor_types = FlattenTupleType(param->type_annotation); + if (tensor_types.size() == 1) { + new_params.push_back(param); + } else { + Array fields; + for (size_t i = 0; i < tensor_types.size(); ++i) { + Var new_param(param->name_hint() + "_" + std::to_string(i), tensor_types[i], param->span); + new_param->checked_type_ = tensor_types[i]; + new_params.push_back(new_param); + fields.push_back(new_param); + } + Tuple new_tuple(fields); + subst.Set(param, new_tuple); + } + } + if (subst.empty()) { + return function; + } + return WithFields(function, new_params, Bind(function->body, subst)); +} + +} // namespace + +Cost CandidatePartitionNode::EstimatedCost( + const DataflowGraph& dataflow_graph, const CostEstimator& cost_estimator, + const std::shared_ptr& cache) const { + if (cost_.is_unknown()) { + VLOG_CONTEXT << "spec " << partition_spec_name(); + Function extracted_function = sub_graph_->ExtractAsFunction(dataflow_graph); + VLOG(2) << "Extracted function:" << std::endl << PrettyPrint(extracted_function); + extracted_function = EtaExpandTuples(extracted_function); + VLOG(2) << "Validating function:" << std::endl << PrettyPrint(extracted_function); + String error = partition_spec()->validate_sub_graph_func_(extracted_function); + if (!error.empty()) { + cost_ = Cost::Invalid(); + VLOG(1) << "Unable to rewrite function: " << error; + } else { + // The extracted function may be the eta-expansion of a "Primitive" function. + // If so we want the cached external name and cost to be w.r.t. that function + // rather than the outer so that we'll get a cache hit when we outline functions + // in the final program. + Function primitive_function = GetPrimitiveFunction(extracted_function); + CandidateFunctionCache::Entry& entry = + cache->GetEntry(sub_graph_->label_, primitive_function); + if (entry.cost.is_unknown()) { + IRModule mod = IRModule::FromExpr(extracted_function); + VLOG(1) << "Outlining:" << std::endl << PrettyPrint(mod); + mod = OutlineCompilerFunctions(cache)(mod); + VLOG(1) << "Estimating cost of:" << std::endl + << PrettyPrint(mod) << std::endl + << "using target " << target()->ToDebugString(); + entry.cost = cost_estimator->Estimate(mod, target(), + /*needs_tvm_tuning=*/!target().IsExternalCodegen()); + VLOG(1) << "Measured cost as " << entry.cost.ToString(); + } else { + VLOG(1) << "Reusing cost " << entry.cost.ToString() + << " cached in candidate function cache"; + } + cost_ = entry.cost; + } + } else { + VLOG(1) << "Reusing cost " << cost_.ToString() << " cached in candidate"; + } + return cost_; +} + CandidatePartition::CandidatePartition(String rule_name, SubGraph sub_graph, ObjectRef /* actually PartitionSpec */ spec, Cost cost) { auto node = runtime::make_object(); diff --git a/src/relay/collage/candidate_partition.h b/src/relay/collage/candidate_partition.h index 1265087f475f..36a23f14bc53 100644 --- a/src/relay/collage/candidate_partition.h +++ b/src/relay/collage/candidate_partition.h @@ -32,7 +32,10 @@ #include #include +#include "./candidate_function_cache.h" #include "./cost.h" +#include "./cost_estimator.h" +#include "./name_supply.h" #include "./sub_graph.h" namespace tvm { @@ -93,6 +96,13 @@ class CandidatePartitionNode : public Object { */ Target target() const; + /*! + * \brief Return the estimated cost of the candidate partition, using \p cost_estimator and + * \p cache. + */ + Cost EstimatedCost(const DataflowGraph& dataflow_graph, const CostEstimator& cost_estimator, + const std::shared_ptr& cache) const; + /*! * \brief Returns a brief description of candidate suitable for debugging output. */ diff --git a/src/relay/collage/combiner_rule.cc b/src/relay/collage/combiner_rule.cc new file mode 100644 index 000000000000..bcfef0477292 --- /dev/null +++ b/src/relay/collage/combiner_rule.cc @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/combiner_rule.cc + * \brief Helpers for the \p CombinePartitionRule + */ + +#include "./combiner_rule.h" + +#include "./partition_spec.h" + +namespace tvm { +namespace relay { +namespace collage { + +TVM_REGISTER_NODE_TYPE(SimpleCombinerRuleNode); + +void SimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +bool SimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph, + const CandidatePartition& upstream, + const CandidatePartition& downstream) const { + return false; +} + +std::string SimpleCombinerRuleNode::ToString() const { + return "SimpleCombinerRule(" + rule_name_ + ")"; +} + +SimpleCombinerRule::SimpleCombinerRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(ByKindSimpleCombinerRuleNode); + +void ByKindSimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +bool ByKindSimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph, + const CandidatePartition& upstream, + const CandidatePartition& downstream) const { + return upstream->sub_graph_->kind_ <= upstream_kind_ && + downstream->sub_graph_->kind_ <= downstream_kind_; +} + +std::string ByKindSimpleCombinerRuleNode::ToString() const { + std::ostringstream os; + os << "ByKindSimpleCombinerRule(" << rule_name_ << ")"; + return os.str(); +} + +ByKindSimpleCombinerRule::ByKindSimpleCombinerRule(OpPatternKind upstream_kind, + OpPatternKind downstream_kind) { + auto node = runtime::make_object(); + String rule_name = KindToString(upstream_kind) + "->" + KindToString(downstream_kind); + node->rule_name_ = std::move(rule_name); + node->upstream_kind_ = upstream_kind; + node->downstream_kind_ = downstream_kind; + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(CombinerRuleNode); + +void CombinerRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +void CombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {} + +std::string CombinerRuleNode::ToString() const { return "CombinerRuleNode(" + rule_name_ + ")"; } + +CombinerRule::CombinerRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(AllSimpleCombinerRuleNode); + +void AllSimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +void AllSimpleCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const { + VLOG(1) << "running AllSimpleCombinerRule(" << rule_name_ << ")"; + // Build map from post-dfs indices to the indices of candidates with corresponding entry node. + // NOTE: the index set is over candidate indices not post-dfs indices! + std::vector entry_map(ctxt->dataflow_graph->size(), + IndexSet(ctxt->candidate_set->size())); + for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) { + CandidatePartition candidate = ctxt->candidate_set->at(i); + for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) { + entry_map[entry_index].Add(i); + } + } + + for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) { + CandidatePartition upstream = ctxt->candidate_set->at(i); + // Narrow our search to just those candidates which could touch. + IndexSet possible_downstream(ctxt->candidate_set->size()); + for (PostDfsIndex output_index : upstream->sub_graph_->output_) { + possible_downstream = possible_downstream | entry_map[output_index]; + } + size_t start_j = + i < ctxt->candidate_set->first_new_index() ? ctxt->candidate_set->first_new_index() : 0; + for (size_t j : possible_downstream) { + if (i == j) { + continue; + } + if (i < start_j) { + // We already explored the cross-product of candidates [0, first_new_index), so don't + // do it again. + continue; + } + // Note that the rules are not commutative so we can't just ignore if j < i. + CandidatePartition downstream = ctxt->candidate_set->at(j); + if (ctxt->max_depth > 0 && + upstream->sub_graph_->depth_ + downstream->sub_graph_->depth_ > ctxt->max_depth) { + continue; + } + if (!upstream.AreTouching(*ctxt->dataflow_graph, downstream)) { + continue; + } + for (const auto& simple_rule : simple_rules_) { + if (simple_rule->Fires(*ctxt->dataflow_graph, upstream, downstream)) { + CandidatePartition new_candidate = + upstream.DisjointUnion(*ctxt->dataflow_graph, downstream); + VLOG(2) << "Fired " << simple_rule->rule_name_ << " on upstream candidate " + << upstream->ToString() << " and downstream candidate " << downstream->ToString() + << " to yield " << new_candidate->ToString(); + ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate); + } + } + } + } +} + +std::string AllSimpleCombinerRuleNode::ToString() const { + std::ostringstream os; + os << "AllSimpleCombinerRule(" << rule_name_; + for (const auto& simple : simple_rules_) { + os << ", " << simple->ToString(); + } + os << ")"; + return os.str(); +} + +AllSimpleCombinerRule::AllSimpleCombinerRule(String rule_name, + Array simple_rules) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->simple_rules_ = std::move(simple_rules); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(TupleArgCombinerRuleNode); + +void TupleArgCombinerRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +void TupleArgCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const { + VLOG(1) << "running TupleArgCombinerRule(" << rule_name_ << ")"; + // Build map from post-dfs index to the indices of injective candidates with corresponding entry + // node. NOTE: the index set is over candidate indices not post-dfs indices! + std::vector exit_map(ctxt->dataflow_graph->size(), + IndexSet(ctxt->candidate_set->size())); + for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) { + CandidatePartition candidate = ctxt->candidate_set->at(i); + if (candidate->sub_graph_->kind_ > kInjective) { + continue; + } + for (PostDfsIndex exit_index : candidate->sub_graph_->exit_) { + exit_map[exit_index].Add(i); + } + } + + // The two-step I -> tuple -> I rule. + // Look all possible tuple consumers... + for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) { + CandidatePartition tuple_consumer_candidate = ctxt->candidate_set->at(i); + if (tuple_consumer_candidate->sub_graph_->kind_ > kInjective) { + continue; + } + // For all possible tuples feeding into candidate... + for (PostDfsIndex input_index : tuple_consumer_candidate->sub_graph_->input_) { + auto node = ctxt->dataflow_graph->index_to_node(input_index); + Expr sub_expr = node->ref(); + const auto* tuple_node = sub_expr.as(); + if (tuple_node == nullptr) { + continue; + } + // The tuple_consumer_candidate candidate consumes (at least one) tuple, eg as an argument + // to an operator. + // eg: concatenate((field1, ..., fieldn)) + auto tuple_dataflow_node = ctxt->dataflow_graph->item_to_node(tuple_node); + + // Collect all the possible unions. There may be more than one if different candidates + // could supply the same tuple field. + std::vector> all_possible_unions; + + // Obviously we must include the consumer. + all_possible_unions.emplace_back(); + all_possible_unions.back().emplace_back(tuple_consumer_candidate); + + // We must include the tuple itself. + SubGraph tuple_sub_graph(*ctxt->dataflow_graph, + IndexSet(ctxt->dataflow_graph->size(), {node->index_}), kInjective, + "tuple"); + CandidatePartition tuple_candidate("", std::move(tuple_sub_graph), + tuple_consumer_candidate->partition_spec()); + all_possible_unions.back().emplace_back(std::move(tuple_candidate)); + + // For all tuple fields... + bool all_tuple_fields_have_producer = true; + for (auto* tuple_field_dataflow_node : tuple_dataflow_node->inputs_) { + // Collect all the candidates which could produce this tuple field. + std::vector to_appends; + size_t start_j = + i < ctxt->candidate_set->first_new_index() ? ctxt->candidate_set->first_new_index() : 0; + for (size_t j : exit_map[tuple_field_dataflow_node->index_]) { + if (i == j) { + continue; + } + if (i < start_j) { + // We already explored the cross-product of candidates [0, first_new_index), so don't + // do it again. + continue; + } + CandidatePartition tuple_field_producer = ctxt->candidate_set->at(j); + // The tuple_field_producer candidate can provide this tuple field. + // eg concatenate((..., producer, ...)) + to_appends.emplace_back(tuple_field_producer); + } + if (to_appends.empty()) { + // At least one of the tuple's fields does not have a producer candidate we can + // union in, so we need to give up. + all_tuple_fields_have_producer = false; + break; + } else { + // If to_appends = [A, B] and we already have possible unions [C, D] and [E, F] then + // the new possible unions are [C, D, A], [C, D, B], [E, F, A] and [E, F, B]. + std::vector> new_all_possible_unions; + for (const auto& to_append : to_appends) { + for (const auto& possible_union : all_possible_unions) { + new_all_possible_unions.emplace_back(possible_union); + new_all_possible_unions.back().emplace_back(to_append); + } + } + all_possible_unions = std::move(new_all_possible_unions); + } + } + + if (!all_tuple_fields_have_producer) { + continue; + } + + // Actually build the candidates which union according to all_possible_unions. + for (const auto& possible_union : all_possible_unions) { + if (possible_union.size() > 2) { + CandidatePartition new_candidate = + CandidatePartition::DisjointUnion(*ctxt->dataflow_graph, possible_union); +#if TVM_LOG_DEBUG + std::ostringstream os; + bool first = true; + for (const auto& candidate : possible_union) { + if (first) { + first = false; + } else { + os << ", "; + } + os << candidate->ToString(); + } + VLOG(2) << "Fired rule " << rule_name_ << " on {" << os.str() << "} to yield " + << new_candidate->ToString(); +#endif + ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate); + } + } + } + } +} + +std::string TupleArgCombinerRuleNode::ToString() const { + return "TupleArgCombinerRule(" + rule_name_ + ")"; +} + +TupleArgCombinerRule::TupleArgCombinerRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(TupleProjCombinerRuleNode); + +void TupleProjCombinerRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +void TupleProjCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const { + VLOG(1) << "running TupleProjCombinerRule(" << rule_name_ << ")"; + // We already explored [0, first_new_index), so don't do it again. + for (size_t i = ctxt->candidate_set->first_new_index(); i < ctxt->candidate_set->size(); ++i) { + CandidatePartition base = ctxt->candidate_set->at(i); + for (PostDfsIndex index : base->sub_graph_->output_) { + auto node = ctxt->dataflow_graph->index_to_node(index); + if (node->ref().as()) { + IndexSet index_set(ctxt->dataflow_graph->size(), {node->index_}); + SubGraph sub_graph(*ctxt->dataflow_graph, std::move(index_set), kInjective, "proj"); + CandidatePartition proj_candidate("", std::move(sub_graph), base->spec_); + CandidatePartition new_candidate = + base.DisjointUnion(*ctxt->dataflow_graph, proj_candidate); + VLOG(2) << "Fired rule " << rule_name_ << " on " << proj_candidate->ToString() << " and " + << base->ToString() << " to yield " << new_candidate->ToString(); + ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate); + } + } + } +} + +std::string TupleProjCombinerRuleNode::ToString() const { + return "TupleProjCombinerRule(" + rule_name_ + ")"; +} + +TupleProjCombinerRule::TupleProjCombinerRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(ConstantCombinerRuleNode); + +void ConstantCombinerRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +void ConstantCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const { + VLOG(1) << "running ConstantCombinerRule(" << rule_name_ << ")"; + // We already explored [0, first_new_index), so don't do it again. + for (size_t i = ctxt->candidate_set->first_new_index(); i < ctxt->candidate_set->size(); ++i) { + CandidatePartition base = ctxt->candidate_set->at(i); + IndexSet new_constants(ctxt->dataflow_graph->size()); + for (PostDfsIndex index : base->sub_graph_->input_) { + auto node = ctxt->dataflow_graph->index_to_node(index); + if (node->ref().as()) { + new_constants.Add(index); + } + } + if (!new_constants.IsZero()) { + SubGraph sub_graph(*ctxt->dataflow_graph, new_constants, kElemWise, "const"); + CandidatePartition new_const_candidate("", std::move(sub_graph), base->spec_); + CandidatePartition new_candidate = + base.DisjointUnion(*ctxt->dataflow_graph, new_const_candidate); + VLOG(2) << "Fired rule " << rule_name_ << " on " << new_const_candidate->ToString() << " and " + << base->ToString() << " to yield " << new_candidate->ToString(); + ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate); + } + } +} + +std::string ConstantCombinerRuleNode::ToString() const { + return "ConstantCombinerRule(" + rule_name_ + ")"; +} + +ConstantCombinerRule::ConstantCombinerRule(String rule_name) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + data_ = std::move(node); +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/combiner_rule.h b/src/relay/collage/combiner_rule.h new file mode 100644 index 000000000000..04ea2a9cc127 --- /dev/null +++ b/src/relay/collage/combiner_rule.h @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/combiner_rule.h + * \brief Helpers for the \p CombinePartitionRule + */ + +#ifndef TVM_RELAY_COLLAGE_COMBINER_RULE_H_ +#define TVM_RELAY_COLLAGE_COMBINER_RULE_H_ + +#include +#include + +#include + +#include "./candidate_partition.h" +#include "./candidate_set.h" +#include "./sub_graph.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief Base class for all 'simple' combiner rules. + * + * Given \p upstream and \p downstream candidates which touch, a simple combiner rule returns + * true if their union should also be considered a candidate. + */ +class SimpleCombinerRuleNode : public Object { + public: + String rule_name_; + + void VisitAttrs(AttrVisitor* v); + + virtual bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream, + const CandidatePartition& downstream) const; + + virtual std::string ToString() const; + + static constexpr const char* _type_key = "relay.collage.SimpleCombinerRule"; + static constexpr const uint32_t _type_child_slots = 1; + TVM_DECLARE_BASE_OBJECT_INFO(SimpleCombinerRuleNode, Object); +}; + +class SimpleCombinerRule : public ObjectRef { + public: + explicit SimpleCombinerRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(SimpleCombinerRule, ObjectRef, SimpleCombinerRuleNode); +}; + +/*! + * \brief A simple combiner rule which fires if the \p upstream and \p downstream candidates have + * the given \p upstream_kind and \p downstream_kind (or less) respectively. + */ +class ByKindSimpleCombinerRuleNode : public SimpleCombinerRuleNode { + public: + OpPatternKind upstream_kind_; + OpPatternKind downstream_kind_; + + void VisitAttrs(AttrVisitor* v); + + bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream, + const CandidatePartition& downstream) const override; + std::string ToString() const override; + + static constexpr const char* _type_key = "relay.collage.ByKindSimpleCombinerRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(ByKindSimpleCombinerRuleNode, SimpleCombinerRuleNode); +}; + +class ByKindSimpleCombinerRule : public SimpleCombinerRule { + public: + ByKindSimpleCombinerRule(OpPatternKind upstream_kind, OpPatternKind downstream_kind); + + TVM_DEFINE_OBJECT_REF_METHODS(ByKindSimpleCombinerRule, SimpleCombinerRule, + ByKindSimpleCombinerRuleNode); +}; + +/*! \brief Context required by CombineRuleNode::AppendAllResultsContext. */ +struct AppendAllResultsContext { + AppendAllResultsContext(const DataflowGraph* dataflow_graph, size_t max_depth, + CandidateSet* candidate_set) + : dataflow_graph(dataflow_graph), max_depth(max_depth), candidate_set(candidate_set) {} + + const DataflowGraph* dataflow_graph; + size_t max_depth; + CandidateSet* candidate_set; +}; + +/*! + * \brief Base class for all 'combiner' rules. + * + * Given the current candidate set, a combiner rule looks for opportunities to form larger + * candidates, optionally removing existing candidates in the process. + */ +class CombinerRuleNode : public Object { + public: + String rule_name_; + + void VisitAttrs(AttrVisitor* v); + + virtual void AppendAllResults(AppendAllResultsContext* ctxt) const; + virtual std::string ToString() const; + + static constexpr const char* _type_key = "relay.collage.CombinerRule"; + static constexpr const uint32_t _type_child_slots = 4; + TVM_DECLARE_BASE_OBJECT_INFO(CombinerRuleNode, Object); +}; + +class CombinerRule : public ObjectRef { + public: + explicit CombinerRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(CombinerRule, ObjectRef, CombinerRuleNode); +}; + +/*! + * \brief A combiner rule which runs one or more simple combiner rules over the current + * touching candidates. + */ +class AllSimpleCombinerRuleNode : public CombinerRuleNode { + public: + Array simple_rules_; + + void VisitAttrs(AttrVisitor* v); + + void AppendAllResults(AppendAllResultsContext* ctxt) const override; + std::string ToString() const override; + + static constexpr const char* _type_key = "relay.collage.AllSimpleCombinerRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllSimpleCombinerRuleNode, CombinerRuleNode); +}; + +class AllSimpleCombinerRule : public CombinerRule { + public: + AllSimpleCombinerRule(String rule_name, Array simple_rules); + + TVM_DEFINE_OBJECT_REF_METHODS(AllSimpleCombinerRule, CombinerRule, AllSimpleCombinerRuleNode); +}; + +/*! + * \brief A combiner rule which combines injective sub-groups which appear inside tuples which are + * themselves inputs to injective sub-groups. + */ +class TupleArgCombinerRuleNode : public CombinerRuleNode { + public: + void VisitAttrs(AttrVisitor* v); + + void AppendAllResults(AppendAllResultsContext* ctxt) const override; + std::string ToString() const override; + + static constexpr const char* _type_key = "relay.collage.TupleArgCombinerRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleArgCombinerRuleNode, CombinerRuleNode); +}; + +class TupleArgCombinerRule : public CombinerRule { + public: + explicit TupleArgCombinerRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleArgCombinerRule, CombinerRule, TupleArgCombinerRuleNode); +}; + +/*! + * \brief A combiner rule which combines tuple projection if it's an output of an injective + * group. + */ +class TupleProjCombinerRuleNode : public CombinerRuleNode { + public: + void VisitAttrs(AttrVisitor* v); + + void AppendAllResults(AppendAllResultsContext* ctxt) const override; + std::string ToString() const override; + + static constexpr const char* _type_key = "relay.collage.TupleProjCombinerRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleProjCombinerRuleNode, CombinerRuleNode); +}; + +class TupleProjCombinerRule : public CombinerRule { + public: + explicit TupleProjCombinerRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleProjCombinerRule, CombinerRule, TupleProjCombinerRuleNode); +}; + +/*! + * \brief A combiner rule which combines constants in argument positions to existing candidates. + * Note that scalars are always inlined, so this rule only combines tensor constant arguments. + */ +class ConstantCombinerRuleNode : public CombinerRuleNode { + public: + void VisitAttrs(AttrVisitor* v); + + void AppendAllResults(AppendAllResultsContext* ctxt) const override; + std::string ToString() const override; + + static constexpr const char* _type_key = "relay.collage.ConstantCombinerRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantCombinerRuleNode, CombinerRuleNode); +}; + +class ConstantCombinerRule : public CombinerRule { + public: + explicit ConstantCombinerRule(String rule_name); + + TVM_DEFINE_OBJECT_REF_METHODS(ConstantCombinerRule, CombinerRule, ConstantCombinerRuleNode); +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_COMBINER_RULE_H_ diff --git a/src/relay/collage/cost.h b/src/relay/collage/cost.h index 8ae276d22078..723c5b58ac94 100644 --- a/src/relay/collage/cost.h +++ b/src/relay/collage/cost.h @@ -71,6 +71,11 @@ class Cost { bool is_value() const { return !std::isnan(value_) && !std::isinf(value_); } + double value() const { + ICHECK(is_value()); + return value_; + } + /*! \brief Return true if the less-than relation is defined for this and that. */ bool are_comparable(Cost that) const { return !std::isnan(value_) && !std::isnan(that.value_); } diff --git a/src/relay/collage/cost_estimator.cc b/src/relay/collage/cost_estimator.cc new file mode 100644 index 000000000000..e2ea99ce9b2a --- /dev/null +++ b/src/relay/collage/cost_estimator.cc @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/cost_estimator.cc + * \brief Interface for measuring candidate partition cost. + */ + +#include "./cost_estimator.h" + +#include +#include + +namespace tvm { +namespace relay { +namespace collage { + +TVM_REGISTER_OBJECT_TYPE(CostEstimatorNode); +TVM_REGISTER_OBJECT_TYPE(MockEstimatorNode); + +CostEstimator::CostEstimator() { + auto node = make_object(); + data_ = std::move(node); +} + +Cost CostEstimatorNode::Estimate(const IRModule& mod, const Target& target, + bool needs_tvm_turning) const { + static const runtime::PackedFunc* estimate_seconds = + runtime::Registry::Get("tvm.relay.collage.estimate_seconds"); + ICHECK(estimate_seconds); + const double value = (*estimate_seconds)(mod, target, needs_tvm_turning); + if (std::isinf(value)) { + return Cost::Invalid(); + } else if (std::isnan(value)) { + return Cost::Unknown(); + } else { + return Cost::Value(value); + } +} + +/*! + * \brief Visitor to accumulate the costs of all calls to operators in an expression. + */ +class MockEstimationVisitor : private ExprVisitor { + public: + MockEstimationVisitor(double op_cost, double fusion_benefit) + : op_cost_(op_cost), fusion_benefit_(fusion_benefit) {} + + double EstimateCost(const Expr& body) { + this->VisitExpr(body); + return cost_; + } + + private: + /*! \brief The assumed baseline cost of each operator call. */ + double op_cost_; + /*! + * \brief The factor by which each operator call cost is to be changed for every other + * operator call in the same group. + */ + double fusion_benefit_; + /*! \brief The number of operator calls seen so far. */ + size_t num_ops_ = 0; + /*! \brief Accumulate overall cost. */ + double cost_ = 0.0; + + void VisitExpr_(const CallNode* call_node) final { + if (call_node->op->IsInstance()) { + cost_ += op_cost_ * pow(fusion_benefit_, num_ops_); + num_ops_++; + } + ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const FunctionNode* function_node) final { + // No "Compiler" functions can be inlined. + ICHECK(!function_node->GetAttr(attr::kCompiler).defined()); + ExprVisitor::VisitExpr_(function_node); + } +}; + +Cost MockEstimatorNode::Estimate(const IRModule& mod, const Target& target, + bool needs_tvm_tuning) const { + double op_cost = static_cast(target_costs_.at(target->kind->name)->value); + double cost = 0.0; + for (const auto& kv : mod->functions) { + if (const auto* function_node = kv.second.as()) { + auto function = GetRef(function_node); + if (kv.first->name_hint == "main") { + // Only tensor args are allowed to main. + for (const auto& param : function->params) { + ICHECK(param->type_annotation->IsInstance()); + } + } + cost += MockEstimationVisitor(op_cost, /*fusion_benefit=*/0.9).EstimateCost(function->body); + } + } + return Cost::Value(cost); +} + +MockEstimator::MockEstimator(Map target_costs) { + auto node = make_object(); + node->target_costs_ = std::move(target_costs); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("relay.collage.CostEstimator").set_body_typed([]() { return CostEstimator(); }); + +TVM_REGISTER_GLOBAL("relay.collage.MockEstimator") + .set_body_typed([](Map target_costs) { + return MockEstimator(std::move(target_costs)); + }); + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/cost_estimator.h b/src/relay/collage/cost_estimator.h new file mode 100644 index 000000000000..f433fd58401e --- /dev/null +++ b/src/relay/collage/cost_estimator.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/cost_estimator.cc + * \brief Interface for measuring candidate partition cost. + */ + +#ifndef TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_ +#define TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_ + +#include + +#include "./cost.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! + * \brief An (abstract) estimator for the cost of executing "main" in an \p IRModule representing + * a candidate partition, using the given target for lowering and codegen. + * + * Generally the implementation will compile to a \p runtime::Module (possibly on a target-specific + * worker if cross-compilation is not available), repeatedly invoke "main" with random data until + * measure variance is acceptable (on a target-specific worker), and return the summarized costs. + * + * If using a TVM native \p Target, it is possible compilation will itself invoke TVM tuning. + * + * TODO(mbs): Actually, currently not abstract so can get some local measurements. + */ +class CostEstimatorNode : public Object { + public: + /*! + * \brief Returns the estimated cost (possibly after many many minutes of training time) of + * running "main" in \p mod using \p target, which represents a possible partitioning of + * some overall Relay expression. + */ + virtual Cost Estimate(const IRModule& mod, const Target& target, bool needs_tvm_tuning) const; + + static constexpr const char* _type_key = "relay.collage.CostEstimator"; + TVM_DECLARE_BASE_OBJECT_INFO(CostEstimatorNode, Object); +}; + +class CostEstimator : public ObjectRef { + public: + CostEstimator(); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CostEstimator, ObjectRef, CostEstimatorNode); +}; + +/*! + * \brief A mock cost estimator which can determine the cost of a candidate based on both + * the candidate's target and the number of operator calls inside it. + * + * The estimator also ICHECKs the given module has all "Compiler" functions outlined and @main + * takes only tensor arguments (ie no tuple types). + * + * To support testing only. + */ +class MockEstimatorNode : public CostEstimatorNode { + public: + Cost Estimate(const IRModule& mod, const Target& target, bool needs_tvm_tuning) const override; + + static constexpr const char* _type_key = "relay.collage.MockEstimator"; + TVM_DECLARE_FINAL_OBJECT_INFO(MockEstimatorNode, CostEstimatorNode); + + protected: + friend class MockEstimator; + + /*! + * \brief Map from target kind name to assumed baseline cost (in integer seconds) for all + * operator calls. + */ + Map target_costs_; +}; + +class MockEstimator : public CostEstimator { + public: + explicit MockEstimator(Map target_costs); + + TVM_DEFINE_OBJECT_REF_METHODS(MockEstimator, CostEstimator, MockEstimatorNode); +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_ diff --git a/src/relay/collage/name_supply.cc b/src/relay/collage/name_supply.cc new file mode 100644 index 000000000000..4b7d497b0d57 --- /dev/null +++ b/src/relay/collage/name_supply.cc @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/name_supply.cc + * \brief A source of fresh variable names. + */ + +#include "./name_supply.h" + +#include +#include + +namespace tvm { +namespace relay { +namespace collage { + +namespace { +void AppendCSafe(bool* first, std::ostringstream& os, const std::string& str) { + for (size_t i = 0; i < str.size(); ++i) { + const char c = str[i]; + if (i == 0 && first && (!std::isalpha(c) && c != '_')) { + os << "_"; + } + if (c == '_' || std::isalnum(c)) { + os << c; + } else { + os << "_"; + } + *first = false; + } +} +} // namespace + +NameSupply NameSupply::MakeSubNameSupply() { + NameSupply result(prefix_); + for (const auto& kv : next_free_index_) { + result.next_free_index_.emplace(kv.first, kv.second); + } + return result; +} + +std::string NameSupply::Fresh(const std::initializer_list& hints) { + std::ostringstream os; + bool first = true; + bool need_sep = false; + if (!prefix_.empty()) { + AppendCSafe(&first, os, prefix_); + need_sep = true; + } + for (const auto& hint : hints) { + if (hint.empty()) { + continue; + } + if (need_sep) { + os << "_"; + } + AppendCSafe(&first, os, hint); + need_sep = true; + } + std::string name = os.str(); + auto itr = next_free_index_.find(name); + if (itr == next_free_index_.end()) { + next_free_index_.emplace(name, 1); + } else { + os << "_" << itr->second++; + name = os.str(); + } + return name; +} + +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/src/relay/collage/name_supply.h b/src/relay/collage/name_supply.h new file mode 100644 index 000000000000..d37023ab6f81 --- /dev/null +++ b/src/relay/collage/name_supply.h @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/name_supply.h + * \brief A source of fresh variable names. + */ + +#ifndef TVM_RELAY_COLLAGE_NAME_SUPPLY_H_ +#define TVM_RELAY_COLLAGE_NAME_SUPPLY_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace collage { + +/*! \brief A supply of fresh names. */ +class NameSupply { + public: + explicit NameSupply(std::string prefix) : prefix_(std::move(prefix)) {} + + NameSupply MakeSubNameSupply(); + + void Reserve(const std::string& existing) { next_free_index_.emplace(existing, 1); } + + std::string Fresh(const std::initializer_list& hints); + + private: + /*! \brief Prefix for all names. May be empty. */ + std::string prefix_; + /*! \brief Next unused index for variables with given basename. */ + std::unordered_map next_free_index_; +}; + +} // namespace collage +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COLLAGE_NAME_SUPPLY_H_ diff --git a/src/relay/collage/partition_rule.cc b/src/relay/collage/partition_rule.cc index 1cedbfc9d72c..e11f740acfe9 100644 --- a/src/relay/collage/partition_rule.cc +++ b/src/relay/collage/partition_rule.cc @@ -285,6 +285,66 @@ OpCallByKindPartitionRule::OpCallByKindPartitionRule(String rule_name) { data_ = std::move(node); } +TVM_REGISTER_NODE_TYPE(CombinePartitionRuleNode); + +void CombinePartitionRuleNode::VisitAttrs(AttrVisitor* v) { + // TODO(mbs) +} + +std::vector CombinePartitionRuleNode::AllCandidates( + const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { + // We'll accumulate all the candidates here, starting with those from the sub-rule. + // Once a candidate is added to this vector it is immutable. + std::vector candidates = sub_rule_->AllCandidates(dataflow_graph, spec); + VLOG(1) << "running CombinePartitionRule(" << rule_name_ << ") over " << candidates.size() + << " sub-candidates"; + CandidateSet result_set(std::move(candidates)); + + size_t num_rounds = 0; + AppendAllResultsContext ctxt(&dataflow_graph, max_depth_, &result_set); + while (result_set.PrepareForNextRound()) { + VLOG_CONTEXT << "round " << ++num_rounds; + VLOG(1) << "checking " << result_set.size() << " candidates (" << result_set.first_new_index() + << " existing)"; + for (const auto& combiner_rule : combiner_rules_) { + combiner_rule->AppendAllResults(&ctxt); + } + } + + std::vector result; + for (auto& candidate : result_set.MovedCurrentCandidates()) { + String rule_name = NestLabels(rule_name_, candidate->rule_name_); + CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); + VLOG(2) << "CombinePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); + result.emplace_back(std::move(new_candidate)); + } + VLOG(1) << "CombinePartitionRule(" << rule_name_ << ") produced " << result.size() + << " candidates"; + return result; +} + +void CombinePartitionRuleNode::AppendBodyItems(std::vector* body_items) const { + PartitionRuleNode::AppendBodyItems(body_items); + body_items->emplace_back(); + body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); + for (const auto& combiner_rule : combiner_rules_) { + body_items->emplace_back(); + body_items->back() << "combiner_rule=" << combiner_rule->ToString(); + } + body_items->emplace_back(); + body_items->back() << "max_depth=" << max_depth_; +} + +CombinePartitionRule::CombinePartitionRule(String rule_name, PartitionRule sub_rule, + Array combiner_rules, size_t max_depth_) { + auto node = runtime::make_object(); + node->rule_name_ = std::move(rule_name); + node->sub_rule_ = std::move(sub_rule); + node->combiner_rules_ = std::move(combiner_rules); + node->max_depth_ = max_depth_; + data_ = std::move(node); +} + TVM_REGISTER_NODE_TYPE(OnlyValidPartitionRuleNode); void OnlyValidPartitionRuleNode::VisitAttrs(AttrVisitor* v) { diff --git a/src/relay/collage/partition_rule.h b/src/relay/collage/partition_rule.h index 13f5c0b01d31..19e7f3ccebfb 100644 --- a/src/relay/collage/partition_rule.h +++ b/src/relay/collage/partition_rule.h @@ -33,6 +33,7 @@ #include "../../printer/doc.h" #include "./candidate_partition.h" +#include "./combiner_rule.h" #include "./sub_graph.h" namespace tvm { @@ -88,6 +89,15 @@ bool DefaultPatternPredicate(const Expr& matched_sub_expr); * delineate a partition (or kernel). * - \p UnionPartitionRule: Simply unions all the candidates from all sub-rules together. Used to * combine individual \p DFPatternPartitionRules. + * - \p CombinePartitionRule: Given a sub-rule and a list of 'combiner' rules, finds + * all possible ways of combining the sub-rule's candidates to yield even larger candidates. + * Note that the sub-rule's candidates may also be directly included in the results. The + * 'combiner' rules allow combining by \p OpPatternKinds, combining the arguments to tuples + * which themselves are arguments to Relay operator calls, and so on. This rule is intended to + * mimic the existing TVM \p FuseOps pass, though: + * i) all candidates are found rather than just the largest, ii) the starting set of candidates + * can be provided by any other rule, and iii) we rely on \p SubGraph validity checking to weed + * out infeasible candidates. * - \p OnlyValidPartitionRule: Given a \p SubGraphConfig, ignores candidates with 'invalid' * sub-graphs. Used to limit the maximum candidate depth, the number of independent outputs, * and whether intermediate 'taps' are allowed. @@ -100,6 +110,54 @@ bool DefaultPatternPredicate(const Expr& matched_sub_expr); * partition on more primitive candidates. Note that the \p SubGraph machinery supports * multiple-input and -output sub-graphs and their validation, so horizontal partition is easy * implement.) + * + * Here are some typical ways to combine \p PartitionRules for different partition/fusion + * strategies: + * + * - Classic pattern-based BYOC with \p MergeComposite/AnnotateTarget/PartitionGraph passes: + * \code + * PrimitivePartitionRule + * OnlyValidPartitionRule + * CombinePartitionRule (with join-anything combiner rule) + * UnionPartitionRule + * CompositePartitionRule(label1) + * DFPatternPartitionRule(pattern1) + * : + * CompositePartitionRule(labeln) + * DFPatternPartitionRule(patternn) + * \endcode + * + * - "Consider this library implementation for these sub-expressions", using \p DFPatterns to + * pick out which Relay operators are supported: + * \code + * OnlyValidPartitionRule + * CombinePartitionRule (with default TVM combiner rules) + * UnionPartitionRule + * OpCallByKindPartitionRule + * CompositePartitionRule(lable1) + * DFPatternPartitionRule(pattern1) + * : + * CompositePartitionRule(lablen) + * DFPatternPartitionRule(patternn) + * \endcode + * + * - Classic TVM \p FuseOps + * \code + * PrimitivePartitionRule + * OnlyValidPartitionRule + * CombinePartitionRule (with default TVM combiner rules) + * OpCallByKindPartitionRule + * \endcode + * + * - "Just fuse what I tell you to fuse", using \p DFPatterns to directly select candidates: + * \code + * PrimitivePartitionRule + * OnlyValidPartitionRule + * UnionPartitionRule + * DFPatternPartitionRule(pattern1) + * : + * DFPatternPartitionRule(patternn) + * \endcode */ class PartitionRuleNode : public Object { public: @@ -293,6 +351,80 @@ class OpCallByKindPartitionRule : public PartitionRule { OpCallByKindPartitionRuleNode); }; +/*! + * \brief Partition rule which combines sub-graphs to exploit optimizations commonly available in + * backends (including the TVM lowering backend). Those optimization rules are in turn described by + * one or more primitive \p CombinerRules. + * + * For TVM these primitive combiner rules are guided by the \p OpPatternKind associated with every + * sub-graph. That in turn is the maximum of the kind of each expression node in the sub-graph, + * using the rules: + * - Constants are \p kElemwise. + * - A call to a Relay operator has the kind of its callee. + * - Tuple construction and projection are injective provided all tuple fields are of tensor type. + * - All other sub-expressions are opaque. + * + * The available \p OpPatternKinds (and our abbreviations for them) are: + * - E: kElemWise, eg nn.relu + * - B: kBroadcast, eg add + * - I: kInjective, eg concatenate + * - R: kCommReduce, eg sum + * - A: kOutEWiseFusable, eg nn.conv2d (often called 'anchor nodes', hence the A abbreviation) + * - O: kOpaque, everything else + * (The kTuple kind is not used by this machinery.) + * + * Kinds are ordered as above from least- to most-constraining w.r.t. possible partition + * opportunities. When we write a kind abbreviation below we intend it to mean that kind *or less*. + * And when when write 'kl -> kr' we mean it to match a sub-expression of kind kr or less who's + * dataflow inputs are all of kind kl or less. + * + * We can then mimic the classic \p FuseOps TVM Pass with the following more primitive combiner + * rules: + * - Sub-groups cannot have taps. In the classic \p FuseOps pass taps are avoided by construction + * by always considering all node->dominator paths. Here we naively allow taps on all candidates, + * but reject them using SubGraph::IsValid with a SubGraphConfig with allow_taps = false. + * - Combine A -> B + * - Combine B -> R + * - Combine I -> I + * - Combine I -> tuple -> I. That is, if an I sub-graph has a tuple as input, and at least one + * tuple field can be provided by an I sub-graph exit, then both the tuple and all such fields + * may be joined. + gt* + * Note that \p FuseOps only considers the largest possible sub-graphs. However this partition rule + * considers all possibilities so as to 'make room' for other targets supplying other + * overlapping candidates. + * + * See combiner_rule.h for the more primitive combiner rules which implement the above. + */ +class CombinePartitionRuleNode : public PartitionRuleNode { + public: + /*! \brief The sub-rule supplying the initial set of candidates. */ + PartitionRule sub_rule_; + /*! \brief The more primitive rules to use to combine the candidates found by the above rule. */ + Array combiner_rules_; + /*! \brief Maximum max_depth for candidates. */ + size_t max_depth_; + + void VisitAttrs(AttrVisitor* v); + + std::vector AllCandidates(const DataflowGraph& dataflow_graph, + const PartitionSpec& spec) const override; + + void AppendBodyItems(std::vector* body_items) const override; + + public: + static constexpr const char* _type_key = "relay.collage.CombinePartitionRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(CombinePartitionRuleNode, PartitionRuleNode); +}; + +class CombinePartitionRule : public PartitionRule { + public: + CombinePartitionRule(String rule_name, PartitionRule sub_rule, Array combiner_rules, + size_t max_depth_); + + TVM_DEFINE_OBJECT_REF_METHODS(CombinePartitionRule, PartitionRule, CombinePartitionRuleNode); +}; + /*! * \brief Partition rules which keeps only candidates from the sub-rule whose sub-groups are valid * w.r.t. the given \p SubGraphConfig. diff --git a/tests/cpp/relay/collage/candidate_partition_test.cc b/tests/cpp/relay/collage/candidate_partition_test.cc new file mode 100644 index 000000000000..c4f81e18ec55 --- /dev/null +++ b/tests/cpp/relay/collage/candidate_partition_test.cc @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../src/relay/collage/candidate_partition.h" + +#include +#include +#include +#include +#include + +#include "../../../src/relay/collage/partition_spec.h" + +namespace tvm { +namespace relay { +namespace collage { +namespace { + +// NOTE: CandidatePartition::ParallelRewrite is effectively tested in partition_rule_test.cc +// so not re-tested here. The only other non-trivial code is CandidatePartition::EstimateCost + +Function MakeTestFunction(const std::string& mod_text) { + IRModule mod = parser::ParseModule("string", mod_text, {}, {}); + mod = transform::CapturePostDfsIndexInSpans()(mod); + auto func = Downcast(mod->Lookup("main")); + LOG(INFO) << "------- input function -------"; + LOG(INFO) << PrettyPrint(func); + LOG(INFO) << "------------------------------"; + return func; +} + +PartitionSpec StandardSpec() { return PartitionSpec("test_spec", Target("llvm"), {}); } + +String AlwaysInvalid(const Function& function) { return "invalid"; } + +PartitionSpec AlwaysInvalidSpec() { + return PartitionSpec("test_spec", Target("llvm"), {}, AlwaysInvalid); +} + +/*! + * \brief Returns candidate containing nodes with given \p indexes wrapped within a + * "Primitive" and "Compiler" function. + */ +CandidatePartition MakeCandidate(const DataflowGraph& graph, const PartitionSpec& spec, + const std::vector& indexes) { + IndexSet inside(graph.size(), indexes); + SubGraph inner_sub_graph(graph, inside); + FunctionAttrsMap attrs_map; + attrs_map.Set(attr::kPrimitive, Integer(1)); + attrs_map.Set(attr::kCompiler, String("llvm")); + NestedSubGraph nested_sub_graph(inner_sub_graph, attrs_map); + SubGraph outer_sub_graph(graph, inside, inner_sub_graph->kind_, inner_sub_graph->label_, + {nested_sub_graph}); + return CandidatePartition(/*rule_name=*/"", outer_sub_graph, spec); +} + +CostEstimator StandardEstimator() { + Map target_costs; + target_costs.Set("llvm", 3); + return MockEstimator(std::move(target_costs)); +} + +CostEstimator AlternateEstimator() { + Map target_costs; + target_costs.Set("llvm", 7); + return MockEstimator(std::move(target_costs)); +} + +std::shared_ptr Cache() { + return std::make_shared(std::make_shared("test")); +} + +TEST(CandidatePartition, EstimateCost_Simple) { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 3 + %1 = nn.relu(%0); // 4 + nn.relu(%1) // 5 + } + )"; + auto func = MakeTestFunction(kMod); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + auto candidate = MakeCandidate(graph, spec, {3, 4}); + auto estimator = StandardEstimator(); + auto cache = Cache(); + + { + auto cost = candidate->EstimatedCost(graph, estimator, cache); + ASSERT_TRUE(cost.is_value()); + // cost is 3 for nn.rulu plus 3 * 0.9 for the nested abs + ASSERT_EQ(cost.value(), 5.7); + } +} + +TEST(CandidatePartition, EstimateCost_AlreadyCached) { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 3 + %1 = nn.relu(%0); // 4 + nn.relu(%1) // 5 + } + )"; + auto func = MakeTestFunction(kMod); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + auto candidate = MakeCandidate(graph, spec, {3, 4}); + candidate->cost_ = Cost::Value(42.0); + auto estimator = StandardEstimator(); + auto cache = Cache(); + + { + auto cost = candidate->EstimatedCost(graph, estimator, cache); + ASSERT_TRUE(cost.is_value()); + ASSERT_EQ(cost.value(), 42.0); + } +} + +TEST(CandidatePartition, EstimateCost_Invalid) { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 3 + %1 = nn.relu(%0); // 4 + nn.relu(%1) // 5 + } + )"; + auto func = MakeTestFunction(kMod); + auto graph = DataflowGraph(func); + auto spec = AlwaysInvalidSpec(); + auto candidate = MakeCandidate(graph, spec, {3, 4}); + auto estimator = StandardEstimator(); + auto cache = Cache(); + + { + auto cost = candidate->EstimatedCost(graph, estimator, cache); + ASSERT_TRUE(cost.is_invalid()); + } +} + +TEST(CandidatePartition, EstimateCost_Cached) { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 4 + %1 = nn.relu(%0); // 5 + %2 = abs(%1); // 6 + %3 = nn.relu(%2); // 7 + add(%1, %3) // 8 + } + )"; + auto func = MakeTestFunction(kMod); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + auto candidateA = MakeCandidate(graph, spec, {4, 5}); + auto candidateB = MakeCandidate(graph, spec, {6, 7}); + auto standard_estimator = StandardEstimator(); + auto alternate_estimator = AlternateEstimator(); + auto cache = Cache(); + + { + // First candidate estimated as per usual. + auto costA = candidateA->EstimatedCost(graph, standard_estimator, cache); + ASSERT_TRUE(costA.is_value()); + ASSERT_EQ(costA.value(), 5.7); + + // Second candidate is structurally equal to first, so reuse first's cost even though + // estimator has different weights. + auto costB = candidateB->EstimatedCost(graph, alternate_estimator, cache); + ASSERT_TRUE(costB.is_value()); + ASSERT_EQ(costB.value(), costA.value()); + } +} + +TEST(CandidatePartition, EstimateCost_EtaExpandTuples) { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + %0 = abs(%x); // 3 + %1 = nn.relu(%0); // 5 + %2 = (%0, %1); // 6 + concatenate(%2) // 7 + } + )"; + auto func = MakeTestFunction(kMod); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + auto candidate = MakeCandidate(graph, spec, {7}); + auto estimator = StandardEstimator(); + auto cache = Cache(); + + { + auto cost = candidate->EstimatedCost(graph, estimator, cache); + ASSERT_TRUE(cost.is_value()); + ASSERT_EQ(cost.value(), 3); + } +} + +} // namespace +} // namespace collage +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/relay/collage/partition_rule_test.cc b/tests/cpp/relay/collage/partition_rule_test.cc index fab34cd3d32d..51a4970c7ec0 100644 --- a/tests/cpp/relay/collage/partition_rule_test.cc +++ b/tests/cpp/relay/collage/partition_rule_test.cc @@ -38,7 +38,8 @@ Constant MakeConstant(std::initializer_list shape) { Function MakeTestFunction( const std::string& mod_text, - std::initializer_list> constant_shapes) { + const std::initializer_list>& constant_shapes = + {}) { Array constants; for (const auto& shape : constant_shapes) { constants.push_back(MakeConstant(shape)); @@ -58,12 +59,73 @@ Function StandardTestFunction() { constexpr const char* kMod = R"( #[version = "0.0.5"] def @main(%x: Tensor[(10, 10), float32]) { - %0 = abs(%x); // 3 - %1 = nn.relu(%0); // 4 - nn.relu(%1) // 5 + // index, kind + %0 = abs(%x); // 3, E + %1 = nn.relu(%0); // 4, E + nn.relu(%1) // 5, E } )"; - return MakeTestFunction(kMod, /*constant_shapes=*/{}); + return MakeTestFunction(kMod); +} + +Function VariantTestFunction() { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(10, 10), float32]) { + // index, kind + %0 = abs(%x); // 4, E + %1 = add(%0, %x); // 5, E + shape_of(%1) // 6, O + } + )"; + return MakeTestFunction(kMod); +} + +Function GPT2ExtractOps() { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(1600, 768), float32]) { + // index, kind + %60 = nn.dense(%x, meta[relay.Constant][0] /*(3072, 768)*/, units=3072); // 6, A + %61 = add(%60, meta[relay.Constant][1] /*(3072)*/); // 8, B + %62 = reshape(%61, newshape=[50, 32, 3072]); // 9, I + %63 = power(%62, 3f); // 15, B + %64 = multiply(%63, 0.044715f); // 17, B + %65 = add(%62, %64); // 18, B + %66 = multiply(%65, 0.797885f); // 20, B + %67 = tanh(%66); // 21, E + %68 = multiply(%62, 0.5f); // 11, B + %69 = add(%67, 1f); // 23, B + multiply(%68, %69) // 24, B + } + )"; + return MakeTestFunction(kMod, {{3072, 768}, {3072}}); +} + +Function GPT2ExtractTuples() { + constexpr const char* kMod = R"( + #[version = "0.0.5"] + def @main(%x: Tensor[(50, 32, 2304), float32]) { + // index, kind + %19 = split(%x, indices_or_sections=[768, 1536], axis=2); // 6, I + %23 = %19.1; // 7 + %24 = reshape(%23, newshape=[50, 32, 12, 64]); // 8, I + %35 = %19.2; // 11 + %36 = reshape(%35, newshape=[50, 32, 12, 64]); // 12, I + %37 = transpose(%36, axes=[0, 2, 1, 3]); // 13, I + %855 = transpose(%24, axes=[0, 2, 1, 3]); // 9, I + %856 = expand_dims(%855, axis=0); // 10, B + %857 = expand_dims(%37, axis=0); // 14, B + %858 = (%856, %857); // 15, B + concatenate(%858) // 16, I + } + )"; + return MakeTestFunction(kMod); +} + +PartitionSpec StandardSpec(const std::string& spec_name = "test_spec", + const std::string& target = "llvm") { + return PartitionSpec(spec_name, Target(target), {}); } std::vector ActualCandidates(const DataflowGraph& graph, const Function& func, @@ -79,12 +141,12 @@ std::vector ActualCandidates(const DataflowGraph& graph, con } std::vector ExpectedCandidates( - const DataflowGraph& graph, const runtime::String rule_name, const PartitionSpec& spec, - const std::vector> index_sets) { + const DataflowGraph& graph, const PartitionSpec& spec, + const std::vector>& index_sets) { std::vector candidate_partitions; for (const auto& indexes : index_sets) { auto subgraph = SubGraph(graph, IndexSet(graph.size(), indexes)); - auto candidate = CandidatePartition(rule_name, subgraph, spec); + auto candidate = CandidatePartition(/*rule_name=*/"", subgraph, spec); candidate_partitions.emplace_back(std::move(candidate)); } return candidate_partitions; @@ -98,66 +160,53 @@ void AssertEqual(const std::vector& actual, expected.end()); ASSERT_EQ(actual_set.size(), expected_set.size()); for (const auto& actual_candidate : actual_set) { - ASSERT_EQ(expected_set.count(actual_candidate), 1); + ASSERT_EQ(expected_set.count(actual_candidate), 1) << actual_candidate->ToString(); } } +void AssertEqual(const Expr& actual, const Expr& expected) { + ASSERT_TRUE(StructuralEqual()(actual, expected)) << PrettyPrint(actual); +} + TEST(PartitionRule, DFPatternSingleOp) { auto func = StandardTestFunction(); auto graph = DataflowGraph(func); - Target target("llvm"); - auto spec = PartitionSpec("test_spec", target, {}); + auto spec = StandardSpec(); { auto pattern = IsOp("nn.relu")({IsWildcard()}); auto rule = DFPatternPartitionRule("relu_pattern", pattern); - auto expected_candidates = ExpectedCandidates(graph, "relu_pattern", spec, {{4}, {5}}); - auto candidates = ActualCandidates(graph, func, spec, rule); + auto actual_candidates = ActualCandidates(graph, func, spec, rule); - ICHECK_EQ(candidates.size(), 2); - for (size_t i = 0; i < candidates.size(); i++) { - ICHECK(CandidatePartitionEquals()(candidates[i], expected_candidates[i])); - } + auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}}); + AssertEqual(actual_candidates, expected_candidates); } } TEST(PartitionRule, DFPatternOverlap) { auto func = StandardTestFunction(); auto graph = DataflowGraph(func); - Target target("llvm"); - auto spec = PartitionSpec("test_spec", target, {}); + auto spec = StandardSpec(); { auto pattern = IsOp("nn.relu")({IsOp("nn.relu")({IsWildcard()}) || IsOp("abs")({IsWildcard()})}); auto rule = DFPatternPartitionRule("relu+abs_pattern", pattern); - auto candidates = ActualCandidates(graph, func, spec, rule); + auto actual_candidates = ActualCandidates(graph, func, spec, rule); - auto expected_candidates = - ExpectedCandidates(graph, "relu+abs_pattern", spec, {{3, 4}, {4, 5}}); - AssertEqual(candidates, expected_candidates); + auto expected_candidates = ExpectedCandidates(graph, spec, {{3, 4}, {4, 5}}); + AssertEqual(actual_candidates, expected_candidates); } } TEST(PartitionRule, Composite) { auto func = StandardTestFunction(); auto graph = DataflowGraph(func); - Target target("llvm"); - auto spec = PartitionSpec("test_spec", target, {}); - - { - auto pattern = IsOp("nn.relu")({IsWildcard()}); - auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); - auto composite_rule = CompositePartitionRule("composite", df_rule); - - auto candidates = ActualCandidates(graph, func, spec, composite_rule); - auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates); + auto spec = StandardSpec(); - ICHECK_EQ(candidates.size(), 2); - - constexpr const char* kExpectedMod = R"( + constexpr const char* kExpectedMod = R"( #[version = "0.0.5"] def @main(%x: Tensor[(10, 10), float32]) { %0 = abs(%x); @@ -171,27 +220,28 @@ TEST(PartitionRule, Composite) { %3(%2) } )"; - Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{}); - ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + Expr expected_expr = MakeTestFunction(kExpectedMod); + + { + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); + auto composite_rule = CompositePartitionRule("composite", df_rule); + + auto actual_candidates = ActualCandidates(graph, func, spec, composite_rule); + auto actual_expr = CandidatePartition::ParallelRewrite(graph, actual_candidates); + + auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}}); + AssertEqual(actual_candidates, expected_candidates); + AssertEqual(actual_expr, expected_expr); } } TEST(PartitionRule, PrimitiveTVM) { auto func = StandardTestFunction(); auto graph = DataflowGraph(func); - Target target("llvm"); - auto spec = PartitionSpec("test_spec", target, {}); - - { - auto pattern = IsOp("nn.relu")({IsWildcard()}); - auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); - auto primitive_rule = PrimitivePartitionRule("primitive", df_rule); - - auto candidates = ActualCandidates(graph, func, spec, primitive_rule); - auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates); + auto spec = StandardSpec(); - ICHECK_EQ(candidates.size(), 2); - constexpr const char* kExpectedMod = R"( + constexpr const char* kExpectedMod = R"( #[version = "0.0.5"] def @main(%x: Tensor[(10, 10), float32]) { %0 = abs(%x); @@ -205,8 +255,19 @@ TEST(PartitionRule, PrimitiveTVM) { %3(%2) } )"; - Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{}); - ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + Expr expected_expr = MakeTestFunction(kExpectedMod); + + { + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); + auto primitive_rule = PrimitivePartitionRule("primitive", df_rule); + + auto actual_candidates = ActualCandidates(graph, func, spec, primitive_rule); + auto actual_expr = CandidatePartition::ParallelRewrite(graph, actual_candidates); + + auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}}); + AssertEqual(actual_candidates, expected_candidates); + AssertEqual(actual_expr, expected_expr); } } @@ -216,19 +277,9 @@ TVM_REGISTER_TARGET_KIND("test_ext_codegen", kDLCUDA) TEST(PartitionRule, PrimitiveExternal) { auto func = StandardTestFunction(); auto graph = DataflowGraph(func); - Target target("test_ext_codegen"); - auto spec = PartitionSpec("test_ext_codegen", target, {}); - - { - auto pattern = IsOp("nn.relu")({IsWildcard()}); - auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); - auto primitive_rule = PrimitivePartitionRule("primitive", df_rule); + auto spec = StandardSpec("test_ext_codegen", "test_ext_codegen"); - auto candidates = ActualCandidates(graph, func, spec, primitive_rule); - auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates); - - ICHECK_EQ(candidates.size(), 2); - constexpr const char* kExpectedMod = R"( + constexpr const char* kExpectedMod = R"( #[version = "0.0.5"] def @main(%x: Tensor[(10, 10), float32]) { %0 = abs(%x); @@ -242,16 +293,26 @@ TEST(PartitionRule, PrimitiveExternal) { %3(%2) } )"; - Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{}); - ICHECK(StructuralEqual()(rewrite_expr, expected_expr)); + Expr expected_expr = MakeTestFunction(kExpectedMod); + + { + auto pattern = IsOp("nn.relu")({IsWildcard()}); + auto df_rule = DFPatternPartitionRule("relu_pattern", pattern); + auto primitive_rule = PrimitivePartitionRule("primitive", df_rule); + + auto actual_candidates = ActualCandidates(graph, func, spec, primitive_rule); + auto actual_expr = CandidatePartition::ParallelRewrite(graph, actual_candidates); + + auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}}); + AssertEqual(actual_candidates, expected_candidates); + AssertEqual(actual_expr, expected_expr); } } TEST(PartitionRule, Union) { auto func = StandardTestFunction(); auto graph = DataflowGraph(func); - Target target("llvm"); - auto spec = PartitionSpec("test_spec", target, {}); + auto spec = StandardSpec(); { auto abs_pattern = IsOp("abs")({IsWildcard()}); @@ -260,40 +321,391 @@ TEST(PartitionRule, Union) { auto relu_rule = DFPatternPartitionRule("relu_pattern", relu_pattern); auto union_rule = UnionPartitionRule("union", {abs_rule, relu_rule}); - auto abs_candidates = ExpectedCandidates(graph, "abs_pattern", spec, {{3}}); - auto relu_candidates = ExpectedCandidates(graph, "relu_pattern", spec, {{4}, {5}}); - - auto candidates = ActualCandidates(graph, func, spec, union_rule); + auto actual_candidates = ActualCandidates(graph, func, spec, union_rule); - std::vector expected_candidates; - expected_candidates.insert(expected_candidates.end(), abs_candidates.begin(), - abs_candidates.end()); - expected_candidates.insert(expected_candidates.end(), relu_candidates.begin(), - relu_candidates.end()); - AssertEqual(candidates, expected_candidates); + auto expected_candidates = ExpectedCandidates(graph, spec, {{3}, {4}, {5}}); + AssertEqual(actual_candidates, expected_candidates); } } TEST(PartitionRule, OpCallByKind) { - constexpr const char* kMod = R"( - #[version = "0.0.5"] - def @main(%x: Tensor[(10, 10), float32]) { - %0 = abs(%x); // 4 - %1 = add(%0, %x); // 5 - shape_of(%1) // 6 - } - )"; - auto func = MakeTestFunction(kMod, {}); + auto func = VariantTestFunction(); auto graph = DataflowGraph(func); - Target target("llvm"); - auto spec = PartitionSpec("test_spec", target, {}); + auto spec = StandardSpec(); { auto rule = OpCallByKindPartitionRule("op_call_by_kind"); - auto candidates = ActualCandidates(graph, func, spec, rule); + auto actual_candidates = ActualCandidates(graph, func, spec, rule); + + auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}}); + AssertEqual(actual_candidates, expected_candidates); + } +} + +TEST(PartitionRule, Combine_ByKind) { + auto func = GPT2ExtractOps(); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + + { + // Prime the system by picking out all 11 calls to non-opaque ops. + auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind"); + // Combine all <= kOutEWiseFusable (A) actual_candidates (ie anything) with downstream + // <= kBroadcast (B) actual_candidates (ie B or E). + Array simple_rules; + simple_rules.push_back(ByKindSimpleCombinerRule(/*upstream_kind=*/kOutEWiseFusable, + /*downstream_kind=*/kBroadcast)); + Array combiner_rules; + combiner_rules.push_back(AllSimpleCombinerRule("all_simple", std::move(simple_rules))); + // Build the overall partition rule. + auto rule = CombinePartitionRule("combine_by_kind_A_B", std::move(sub_rule), + std::move(combiner_rules), /*max_depth=*/3); + + auto actual_candidates = ActualCandidates(graph, func, spec, rule); + + // The original calls. + std::vector> expected; + expected.push_back({6}); + expected.push_back({8}); + expected.push_back({9}); + expected.push_back({11}); + expected.push_back({15}); + expected.push_back({17}); + expected.push_back({18}); + expected.push_back({20}); + expected.push_back({21}); + expected.push_back({23}); + expected.push_back({24}); + + // nn.dense (A) and the following add (B) + expected.push_back({6, 8}); + + // reshape (I) and the following power or multiply or both + expected.push_back({9, 11}); + expected.push_back({9, 15}); + expected.push_back({9, 11, 15}); + + // reshape (I) and the following power and multiply + expected.push_back({9, 15, 17}); + + // reshape (I) and everything after it to the max depth of 3 + expected.push_back({9, 11, 15, 17}); + + // pairs of broadcasts + expected.push_back({11, 24}); // multiply / multiply + expected.push_back({15, 17}); // power / multiply + expected.push_back({17, 18}); // multiply / add + expected.push_back({18, 20}); // add / multiply + expected.push_back({20, 21}); // multiply / tanh + expected.push_back({21, 23}); // tanh / add + expected.push_back({23, 24}); // add / multiply + + // triples of broadcasts + expected.push_back({15, 17, 18}); // power / multiply / add + expected.push_back({17, 18, 20}); // multiply / add / multiply + expected.push_back({18, 20, 21}); // add / multiply / tanh + expected.push_back({20, 21, 23}); // multiply / tanh / add + expected.push_back({21, 23, 24}); // tanh / add / multiply + + auto expected_candidates = ExpectedCandidates(graph, spec, expected); + AssertEqual(actual_candidates, expected_candidates); + } +} + +TEST(PartitionRule, Combine_TupleArg) { + auto func = GPT2ExtractTuples(); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + + { + // Prime the system by picking out all 8 calls to non-opaque ops. + auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind"); + // Merge args of tuples of <= injective (I) fields into the call's group. + Array combiner_rules; + combiner_rules.push_back(TupleArgCombinerRule("tuple_arg")); + // Build the overall partition rule. + auto rule = CombinePartitionRule("combine_tuple_arg", std::move(sub_rule), + std::move(combiner_rules), /*max_depth=*/3); + + auto actual_candidates = ActualCandidates(graph, func, spec, rule); + + // The original calls + std::vector> expected; + expected.push_back({6}); + expected.push_back({8}); + expected.push_back({9}); + expected.push_back({10}); + expected.push_back({12}); + expected.push_back({13}); + expected.push_back({14}); + expected.push_back({16}); + + // The concatenate((expand_dims(...), expand_dims(...)) is grouped. + expected.push_back({10, 14, 15, 16}); + + auto expected_candidates = ExpectedCandidates(graph, spec, expected); + AssertEqual(actual_candidates, expected_candidates); + } +} + +TEST(PartitionRule, Combine_TupleProj) { + auto func = GPT2ExtractTuples(); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + + { + // Prime the system by picking out all 8 calls to non-opaque ops. + auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind"); + // Merge projections from injective groups. + Array combiner_rules; + combiner_rules.push_back(TupleProjCombinerRule("tuple_proj")); + // Build the overall partition rule. + auto rule = CombinePartitionRule("combine_tuple_proj", std::move(sub_rule), + std::move(combiner_rules), /*max_depth=*/3); + + auto actual_candidates = ActualCandidates(graph, func, spec, rule); + + // The original calls + std::vector> expected; + expected.push_back({6}); + expected.push_back({8}); + expected.push_back({9}); + expected.push_back({10}); + expected.push_back({12}); + expected.push_back({13}); + expected.push_back({14}); + expected.push_back({16}); + + // split / proj 1 + expected.push_back({6, 7}); + // split / proj 2 + expected.push_back({6, 11}); + // split and both projections + expected.push_back({6, 7, 11}); + + auto expected_candidates = ExpectedCandidates(graph, spec, expected); + AssertEqual(actual_candidates, expected_candidates); + } +} + +TEST(PartitionRule, Combine_Constant) { + auto func = GPT2ExtractOps(); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + + { + // Prime the system by picking out all 11 calls to non-opaque ops. + auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind"); + // Merge constant args into injective groups + Array combiner_rules; + combiner_rules.push_back(ConstantCombinerRule("constant")); + // Build the overall partition rule. + auto rule = CombinePartitionRule("combine_constant", std::move(sub_rule), + std::move(combiner_rules), /*max_depth=*/3); + + auto actual_candidates = ActualCandidates(graph, func, spec, rule); + + // The original calls + std::vector> expected; + expected.push_back({6}); + expected.push_back({8}); + expected.push_back({9}); + expected.push_back({11}); + expected.push_back({15}); + expected.push_back({17}); + expected.push_back({18}); + expected.push_back({20}); + expected.push_back({21}); + expected.push_back({23}); + expected.push_back({24}); + + // Constant arg to nn.dense + expected.push_back({5, 6}); + + // Constant arg to add + expected.push_back({7, 8}); + + auto expected_candidates = ExpectedCandidates(graph, spec, expected); + AssertEqual(actual_candidates, expected_candidates); + } +} + +TEST(PartitionRule, Combine_Mixed) { + auto func = GPT2ExtractOps(); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + + { + // Prime the system by picking out all 11 calls to non-opaque ops. + auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind"); + + // Mimic the FuseOps rules. + Array simple_rules; + simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast)); + simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce)); + simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective)); + Array combiner_rules; + combiner_rules.push_back(AllSimpleCombinerRule("all_simple", std::move(simple_rules))); + + // Merge constant args into injective groups + combiner_rules.push_back(ConstantCombinerRule("constant")); + + // Build the overall partition rule. + auto rule = CombinePartitionRule("combine_mixed", std::move(sub_rule), + std::move(combiner_rules), /*max_depth=*/3); + + auto actual_candidates = ActualCandidates(graph, func, spec, rule); + + // The original calls + std::vector> expected; + expected.push_back({6}); + expected.push_back({8}); + expected.push_back({9}); + expected.push_back({11}); + expected.push_back({15}); + expected.push_back({17}); + expected.push_back({18}); + expected.push_back({20}); + expected.push_back({21}); + expected.push_back({23}); + expected.push_back({24}); + + // A -> B merging + expected.push_back({6, 8}); + expected.push_back({9, 11}); + expected.push_back({9, 15}); + expected.push_back({9, 11, 15}); + expected.push_back({9, 15, 17}); + expected.push_back({9, 11, 15, 17}); + expected.push_back({11, 24}); + expected.push_back({15, 17}); + expected.push_back({17, 18}); + expected.push_back({18, 20}); + expected.push_back({20, 21}); + expected.push_back({21, 23}); + expected.push_back({23, 24}); + expected.push_back({15, 17, 18}); + expected.push_back({17, 18, 20}); + expected.push_back({18, 20, 21}); + expected.push_back({20, 21, 23}); + expected.push_back({21, 23, 24}); + + // Constant args + expected.push_back({5, 6}); + expected.push_back({7, 8}); + + // B -> R + expected.push_back({8, 9}); + expected.push_back({8, 9, 11}); + expected.push_back({8, 9, 15}); + + // Constant's and A -> B + expected.push_back({5, 6, 8}); + expected.push_back({5, 6, 7, 8}); + + // Constants and B -> R + expected.push_back({7, 8, 9}); + expected.push_back({7, 8, 9, 11}); + expected.push_back({7, 8, 9, 15}); + + auto expected_candidates = ExpectedCandidates(graph, spec, expected); + AssertEqual(actual_candidates, expected_candidates); + } +} + +TEST(PartitionRule, OnlyValid) { + auto func = GPT2ExtractOps(); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); - auto expected_candidates = ExpectedCandidates(graph, "op_call_by_kind", spec, {{4}, {5}}); - AssertEqual(candidates, expected_candidates); + { + // Prime the system by picking out all 11 calls to non-opaque ops. + auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind"); + // Combine all <= kOutEWiseFusable (A) actual_candidates (ie anything) with downstream + // <= kBroadcast (B) actual_candidates (ie B or E). + Array simple_rules; + simple_rules.push_back(ByKindSimpleCombinerRule(/*upstream_kind=*/kOutEWiseFusable, + /*downstream_kind=*/kBroadcast)); + Array combiner_rules; + combiner_rules.push_back(AllSimpleCombinerRule("all_simple", std::move(simple_rules))); + auto combine_rule = CombinePartitionRule("combine_by_kind_A_B", std::move(sub_rule), + std::move(combiner_rules), /*max_depth=*/3); + // Only allow up to depth 2, no taps and 1 exit. + SubGraphConfig config; + config.allow_taps = false; + config.max_depth = 2; + config.max_exits = 1; + + // Build the overall partition rule. + auto rule = OnlyValidPartitionRule("only_valid", std::move(combine_rule), config); + + auto actual_candidates = ActualCandidates(graph, func, spec, rule); + + // The original calls. + std::vector> expected; + expected.push_back({6}); + expected.push_back({8}); + expected.push_back({9}); + expected.push_back({11}); + expected.push_back({15}); + expected.push_back({17}); + expected.push_back({18}); + expected.push_back({20}); + expected.push_back({21}); + expected.push_back({23}); + expected.push_back({24}); + + // nn.dense (A) and the following add (B) + expected.push_back({6, 8}); + + // pairs of broadcasts + expected.push_back({11, 24}); // multiply / multiply + expected.push_back({15, 17}); // power / multiply + expected.push_back({17, 18}); // multiply / add + expected.push_back({18, 20}); // add / multiply + expected.push_back({20, 21}); // multiply / tanh + expected.push_back({21, 23}); // tanh / add + expected.push_back({23, 24}); // add / multiply + + // The following candidates are filtered out because they have 2 or 3 exits: + // {9, 11}, {9, 15}, {9,11,15}, {9,15,17}, {15,17,18}, {17,18,20}, + // {18,20,21}, {20,21,23}, {21,23,24}, {9,11,15,17} + + auto expected_candidates = ExpectedCandidates(graph, spec, expected); + AssertEqual(actual_candidates, expected_candidates); + } +} + +TEST(PartitionRule, Host) { + auto func = GPT2ExtractTuples(); + auto graph = DataflowGraph(func); + auto spec = StandardSpec(); + + { + auto rule = HostPartitionRule("host"); + + auto actual_candidates = ActualCandidates(graph, func, spec, rule); + + std::vector> expected; + + // Function arg %x + expected.push_back({0}); + // Operators + expected.push_back({1}); // concatenate + expected.push_back({2}); // expand_dims + expected.push_back({3}); // transpose + expected.push_back({4}); // reshape + expected.push_back({5}); // split + // Tuple projection + expected.push_back({7}); + expected.push_back({11}); + // Tuple construction + expected.push_back({15}); + // The overall @main function + expected.push_back({17}); + + auto expected_candidates = ExpectedCandidates(graph, spec, expected); + AssertEqual(actual_candidates, expected_candidates); } }