From b01fcf8bcac01e796b177091156828d69d8ebc93 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 23 Jul 2020 16:08:49 -0700 Subject: [PATCH] address comments --- include/tvm/auto_scheduler/compute_dag.h | 55 ++++--- include/tvm/auto_scheduler/loop_state.h | 4 +- python/tvm/autotvm/task/relay_integration.py | 1 + src/auto_scheduler/compute_dag.cc | 159 +++++++++---------- src/auto_scheduler/utils.h | 18 +++ tests/cpp/auto_scheduler_test.cc | 18 +-- 6 files changed, 132 insertions(+), 123 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index de158915d2668..b9c1f9e8c45ca 100644 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -51,22 +51,27 @@ namespace auto_scheduler { class AccessAnalyzerNode : public Object { public: template - using OperationMap = std::unordered_map; + using OperationMap = std::unordered_map; /*! \brief Map an operation to all operations it reads from. - * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/ + * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses + * The inner vector represents the indices of multi-dimensional access.*/ OperationMap>>> read_from; /*! \brief Map an operation to all operations it is read by. - * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/ + * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses + * The inner vector represents the indices of multi-dimensional access.*/ OperationMap>>> read_by; /*! \brief Store the number of common outer iterators for operation pairs that have * read-write relations. */ OperationMap> num_common_outer_iterators; - /*! \brief Store whether the operation is injective */ - OperationMap is_injective; - /*! \brief Store whether the operation is strictly-inlineable */ + /*! \brief Store whether the operation is an op with only simple access. + * (e.g., injective, broadcast and elementwise ops without reduction) */ + OperationMap is_simple_access; + /*! \brief Store whether the operation is strictly-inlineable + * (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations) */ OperationMap is_strict_inlineable; - /*! \brief Store whether the operation needs multi-level tiling */ + /*! \brief Store whether the operation needs multi-level tiling + * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */ OperationMap needs_multi_level_tiling; /*! \brief Store whether the operation is an output operation */ OperationMap is_output; @@ -86,22 +91,25 @@ class AccessAnalyzer : public ObjectRef { explicit AccessAnalyzer(const Array& tensors); /*! - * \brief Return whether this operation needs multi-level tiling + * \brief Return whether this operation is an injective operation + * (e.g., injective, broadcast and elementwise ops without reduction) * \param op The operation */ - TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const; + TVM_DLL bool IsSimpleAccess(const te::Operation& op) const; /*! - * \brief Return whether this operation is an injective operation + * \brief Return whether this operation is strictly inlinable + * (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations) * \param op The operation */ - TVM_DLL bool IsInjective(const te::Operation& op) const; + TVM_DLL bool IsStrictInlineable(const te::Operation& op) const; /*! - * \brief Return whether this operation is strictly inlinable + * \brief Return whether this operation needs multi-level tiling + * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) * \param op The operation */ - TVM_DLL bool IsStrictInlineable(const te::Operation& op) const; + TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const; /*! * \brief Return whether this operation is an output op @@ -113,33 +121,30 @@ class AccessAnalyzer : public ObjectRef { * \brief Get all consumers of on operation * \param state The current loop state * \param op The operation - * \param consumers The return consumer set + * \return The set of consumers * \note This function propagates the relation for inlined ops */ - TVM_DLL void GetConsumers( - const State& state, const te::Operation& op, - std::unordered_set* consumers) const; + TVM_DLL std::unordered_set GetConsumers( + const State& state, const te::Operation& op) const; /*! * \brief Get all producers of on operation * \param state The current loop state * \param op The operation - * \param producers The return producer set + * \return The set of producers * \note This function propagates the relation for inlined ops */ - TVM_DLL void GetProducers( - const State& state, const te::Operation& op, - std::unordered_set* producers) const; + TVM_DLL std::unordered_set GetProducers( + const State& state, const te::Operation& op) const; /*! * \brief Get all direct producers of on operation * \param op The operation - * \param producers The return producer set + * \return The set of direct producers * \note This function DOES NOT propagate the relation for inlined ops */ - TVM_DLL void GetDirectProducers( - const te::Operation& op, - std::unordered_set* producers) const; + TVM_DLL std::unordered_set GetDirectProducers( + const te::Operation& op) const; /*! * \brief Get the number of common outer iterators. diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 00de0c801568a..4e9cb9bd7d20a 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -159,7 +159,7 @@ using IterKey = std::pair; */ class AttachMapNode : public Object { public: - struct key_hash : public std::function { + struct IterKeyHash { std::size_t operator()(const IterKey& k) const { return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); } @@ -168,7 +168,7 @@ class AttachMapNode : public Object { /*! \brief A Map to store the mapping of stage to its attached iterator. */ std::unordered_map stage_to_attach_iter; /*! \brief A Map to store the mapping of iterator to the stage attached to it. */ - std::unordered_map, key_hash> iter_to_attached_stages; + std::unordered_map, IterKeyHash> iter_to_attached_stages; static constexpr const char* _type_key = "auto_scheduler.AttachMap"; TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 9a43f2f1ad95e..70f32eb81a75d 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -26,6 +26,7 @@ import tvm from .task import create from .topi_integration import TaskExtractEnv +from .dispatcher import FallbackContext logger = logging.getLogger('autotvm') diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 92239e51101a6..d81bed21c30f3 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -38,6 +38,7 @@ #include #include "utils.h" +#include "../arith/pattern_match.h" namespace tvm { namespace auto_scheduler { @@ -119,7 +120,7 @@ Array TopoSortOps(const Array& tensors) { } // Extract all tensor accesses in an expr -class TensorAccessExtractor : public StmtExprVisitor { +class ReadAccessExtractor : public StmtExprVisitor { public: void Extract(PrimExpr expr) { this->VisitExpr(expr); } @@ -131,8 +132,8 @@ class TensorAccessExtractor : public StmtExprVisitor { } void VisitExpr_(const ProducerLoadNode* op) final { - buf_accesses[Downcast(op->producer)->op].emplace_back(op->indices.begin(), - op->indices.end()); + read_access[Downcast(op->producer)->op].emplace_back(op->indices.begin(), + op->indices.end()); StmtExprVisitor::VisitExpr_(op); } @@ -146,28 +147,31 @@ class TensorAccessExtractor : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - OperationMap>> buf_accesses; + // All read accesses to all operations + // The innermost vector stores mulit-dimentional indices. + // The middle vector stores possible multiple accesses + OperationMap>> read_access; + // Whether this expression has branch bool has_branch{false}; }; -// Returns whether the expr equals to the var with a const shift +// Returns whether the expr equals to the var with an optional const shift bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { - if (auto pv = expr.as()) { - return pv == var.get(); - } else if (auto padd = expr.as()) { - return ((padd->a.get() == var.get() && padd->b->IsInstance()) || - (padd->b.get() == var.get() && padd->a->IsInstance())); - } else if (auto psub = expr.as()) { - return ((psub->a.get() == var.get() && psub->b->IsInstance()) || - (psub->b.get() == var.get() && psub->a->IsInstance())); - } else { - return false; + arith::PVar x; + arith::PVar c; + + if (((x + c).Match(expr) || (x - c).Match(expr) || (c + x).Match(expr) || x.Match(expr)) && + x.Eval().same_as(var)) { + return true; } + return false; } -// Return whether the access is injective -bool IsInjective(const te::Operation& op, const std::vector& index, bool* axis_missing, - bool* axis_duplicated, bool* same_order) { +// Return whether the access to an operation is a simple access +// (i.e. all index is just a variable with an optional constant shift) +// For example, A[i][j], A[i+1][j] are simple accesses but A[i][j+i] is not. +bool IsSimpleAccess(const te::Operation& op, const std::vector& indices, bool* axis_missing, + bool* axis_duplicated, bool* same_order) { auto cop = op.as(); if (cop == nullptr) { return false; @@ -176,7 +180,7 @@ bool IsInjective(const te::Operation& op, const std::vector& index, bo std::vector index_to_var_idx; std::vector var_idx_ct(cop->axis.size(), 0); - for (const auto& expr : index) { + for (const auto& expr : indices) { if (!is_const_int(expr)) { bool found = false; for (size_t i = 0; i < cop->axis.size(); ++i) { @@ -214,7 +218,7 @@ bool IsInjective(const te::Operation& op, const std::vector& index, bo } // Gather all VarNodes in an expr -static void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { +void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { PostOrderVisit(expr, [&vars](const ObjectRef& node) { if (const VarNode* op = node.as()) { vars->insert(op); @@ -223,7 +227,7 @@ static void GatherVars(const PrimExpr& expr, std::unordered_set* } // Check whether an expr has expensive operations (e.g. exp) -static bool HasExpensiveOp(const PrimExpr& expr) { +bool HasExpensiveOp(const PrimExpr& expr) { bool found = false; PostOrderVisit(expr, [&found](const ObjectRef& node) { if (const CallNode* op = node.as()) { @@ -239,28 +243,28 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { auto node = make_object(); OperationMap has_branch; - // get all ops + // Get all ops in topological order node->ops_topo_order = TopoSortOps(tensors); arith::Analyzer analyzer; - // build read & write access map + // Build read & write access map for (const auto& op : node->ops_topo_order) { if (op->IsInstance()) { node->read_from[op] = OperationMap>>(); } else if (auto cop = op.as()) { - TensorAccessExtractor extractor; + ReadAccessExtractor extractor; for (const auto& exp : cop->body) { extractor.Extract(exp); } // read_by and read_from map - for (const auto& iter : extractor.buf_accesses) { + for (const auto& iter : extractor.read_access) { std::vector>& accesses = node->read_by[iter.first][op]; accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); } - node->read_from[op] = std::move(extractor.buf_accesses); + node->read_from[op] = std::move(extractor.read_access); has_branch[op] = extractor.has_branch; // compute number of common outer iterators @@ -278,15 +282,15 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { break; } - bool direct_access = true; + bool injective = true; for (const auto& access : access_list) { if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) { - direct_access = false; + injective = false; break; } } - if (!direct_access) { + if (!injective) { break; } } @@ -299,24 +303,24 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { } } - // do some static analysis + // Do some static analysis on ComputeOps for (const auto& op : node->ops_topo_order) { if (op->IsInstance()) { - node->is_injective[op] = true; + node->is_simple_access[op] = true; node->needs_multi_level_tiling[op] = false; node->is_strict_inlineable[op] = false; node->is_output[op] = false; - } else if (auto pop = op.as()) { + } else if (auto cop = op.as()) { // check whether this op is element-wise and strict-inlineable bool is_injective = true; bool is_strict_inlineable = true; bool axis_missing, axis_duplicated, same_order; for (const auto& pair : node->read_from[op]) { - const std::vector>& access = pair.second; - for (const auto& index : access) { - if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated, - &same_order)) { + const std::vector>& access_list = pair.second; + for (const auto& access : access_list) { + if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated, + &same_order)) { is_injective = false; is_strict_inlineable = false; break; @@ -330,42 +334,40 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { break; } } - if (has_branch[op]) { - is_strict_inlineable = false; - } // don't strictly inline expensive op (e.g. exp) bool has_expensive_op = false; - for (const auto& expr : pop->body) { + for (const auto& expr : cop->body) { has_expensive_op |= HasExpensiveOp(expr); } + if (has_expensive_op || has_branch[op]) { + is_strict_inlineable = false; + } - node->is_injective[op] = is_injective; - node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op; + node->is_simple_access[op] = is_injective; + node->is_strict_inlineable[op] = is_strict_inlineable; // check whether the op needs multi-level tiling bool needs_multi_level_tiling = false; int n_missing = 0; for (const auto& pair : node->read_from[op]) { - const std::vector>& access = pair.second; + const std::vector>& access_list = pair.second; std::unordered_set vars; - for (const std::vector& indices : access) { - for (const PrimExpr& expr : indices) { + for (const std::vector& access : access_list) { + for (const PrimExpr& expr : access) { GatherVars(expr, &vars); } } - bool missing = false; - for (const auto& axis : pop->axis) { + + for (const auto& axis : cop->axis) { if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { - missing = true; + n_missing++; + break; } } - if (missing) { - n_missing++; - } - if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) { + if (n_missing >= 2 || (n_missing >= 1 && !cop->reduce_axis.empty())) { needs_multi_level_tiling = true; break; } @@ -373,7 +375,7 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { node->needs_multi_level_tiling[op] = needs_multi_level_tiling; - // check whether is output + // check whether the op is output node->is_output[op] = node->read_by[op].empty(); } else { LOG(FATAL) << "Invalid op" << op; @@ -391,16 +393,15 @@ bool AccessAnalyzer::IsOutput(const te::Operation& op) const { return operator->()->is_output.at(op); } -bool AccessAnalyzer::IsInjective(const te::Operation& op) const { - return operator->()->is_injective.at(op); +bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const { + return operator->()->is_simple_access.at(op); } bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const { return operator->()->is_strict_inlineable.at(op); } -void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, - OperationSet* consumers) const { +OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const { OperationSet inlined_ops; for (const auto& stage : state->stages) { if (stage->compute_at == ComputeAtKind::kInlined) { @@ -408,30 +409,31 @@ void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, } } + OperationSet consumers; std::function collect; collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { for (const auto& iter : operator->()->read_by.at(op)) { if (inlined_ops.count(iter.first)) { collect(iter.first); } else { - consumers->insert(iter.first); + consumers.insert(iter.first); } } }; - consumers->clear(); collect(op); + return consumers; } -void AccessAnalyzer::GetDirectProducers(const te::Operation& op, OperationSet* producers) const { - producers->clear(); +OperationSet AccessAnalyzer::GetDirectProducers(const te::Operation& op) const { + OperationSet producers; for (const auto& iter : operator->()->read_from.at(op)) { - producers->insert(iter.first); + producers.insert(iter.first); } + return producers; } -void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, - OperationSet* producers) const { +OperationSet AccessAnalyzer::GetProducers(const State& state, const te::Operation& op) const { OperationSet inlined_ops; for (const auto& stage : state->stages) { if (stage->compute_at == ComputeAtKind::kInlined) { @@ -439,19 +441,20 @@ void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, } } + OperationSet producers; std::function collect; collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) { for (const auto& iter : operator->()->read_from.at(op)) { if (inlined_ops.count(iter.first)) { collect(iter.first); } else { - producers->insert(iter.first); + producers.insert(iter.first); } } }; - producers->clear(); collect(op); + return producers; } int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op, @@ -478,24 +481,6 @@ int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op, return meet ? ret : 0; } -// Return whether two int arrays are elementwise-equal -bool IntArrayEqual(const Array& arr1, const Array& arr2) { - if (arr1.size() != arr2.size()) { - return false; - } - - for (size_t i = 0; i < arr1.size(); ++i) { - auto int1 = arr1[i].as(); - auto int2 = arr2[i].as(); - CHECK(int1 != nullptr); - CHECK(int2 != nullptr); - if (int1->value != int2->value) { - return false; - } - } - return true; -} - bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const { te::Operation cur_op = op; @@ -508,7 +493,7 @@ bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, } te::Operation next_op = map.begin()->first; - // Check condition 1: has the same output size + // Check condition 1: They have the same output size auto p_cur = cur_op.as(); auto p_next = next_op.as(); if (p_cur == nullptr || p_next == nullptr) { @@ -527,12 +512,12 @@ bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, } } - // Check condition 2: read is elementwise + // Check condition 2: The read is elementwise const std::vector> reads = map.begin()->second; bool is_injective, axis_missing, axis_duplicated, same_order; for (const auto& read : reads) { is_injective = - auto_scheduler::IsInjective(next_op, read, &axis_missing, &axis_duplicated, &same_order); + auto_scheduler::IsSimpleAccess(next_op, read, &axis_missing, &axis_duplicated, &same_order); if (!is_injective || axis_missing || axis_duplicated || !same_order) { return false; } diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index de800da13b64f..da5032e11c97a 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -128,6 +128,24 @@ inline std::vector IntArrayToVector( return out; } +/*! \brief Return whether two int arrays are elementwise-equal */ +inline bool IntArrayEqual(const Array& arr1, const Array& arr2) { + if (arr1.size() != arr2.size()) { + return false; + } + + for (size_t i = 0; i < arr1.size(); ++i) { + auto int1 = arr1[i].as(); + auto int2 = arr2[i].as(); + CHECK(int1 != nullptr); + CHECK(int2 != nullptr); + if (int1->value != int2->value) { + return false; + } + } + return true; +} + /********** Utilities for TVM Containers / ByteArray **********/ /*! \brief Compute mean of a FloatImm array */ inline double FloatArrayMean(const Array& float_array) { diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index f21fe1f5c57b3..85266057548c2 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -82,13 +82,13 @@ TEST(ComputeDAG, AccessAnalyzer) { } } - std::set is_injective = {data, padding, kernel, bias, bias_add, - bn_scale, bn_mul, bn_offset, bn_add, relu}; + std::set is_simple_access = {data, padding, kernel, bias, bias_add, + bn_scale, bn_mul, bn_offset, bn_add, relu}; for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) { - if (is_injective.count(stage_id)) { - CHECK(dag->access_analyzer.IsInjective(dag->ops[stage_id])); + if (is_simple_access.count(stage_id)) { + CHECK(dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id])); } else { - CHECK(!dag->access_analyzer.IsInjective(dag->ops[stage_id])); + CHECK(!dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id])); } } @@ -125,7 +125,7 @@ TEST(ComputeDAG, AccessAnalyzer) { {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul}, {bn_mul, bn_add}, {bn_offset, bn_add}, {bn_add, relu}}; for (const auto& pair : consumer_list) { - dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &op_set); + op_set = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op); CHECK_EQ(op_set.size(), 1); CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op); } @@ -136,7 +136,7 @@ TEST(ComputeDAG, AccessAnalyzer) { {bn_add, {bn_mul, bn_offset}}, {relu, {bn_add}}}; for (const auto& pair : producer_list) { - dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &op_set); + op_set = dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op); CHECK_EQ(op_set.size(), pair.second.size()); for (const auto& target : pair.second) { CHECK(op_set.count(s0->stages[target]->op)); @@ -151,7 +151,7 @@ TEST(ComputeDAG, AccessAnalyzer) { { std::vector> consumer_list = {{data, conv}, {kernel, conv}, {conv, relu}}; for (const auto& pair : consumer_list) { - dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &op_set); + op_set = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op); CHECK_EQ(op_set.size(), 1); CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op); } @@ -162,7 +162,7 @@ TEST(ComputeDAG, AccessAnalyzer) { {bn_add, {bn_mul, bn_offset}}, {relu, {bn_add}}}; for (const auto& pair : producer_list) { - dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op, &op_set); + op_set = dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op); CHECK_EQ(op_set.size(), pair.second.size()); for (const auto& target : pair.second) { CHECK(op_set.count(s0->stages[target]->op));