From 603d3d25d83bf4b3af0fc69448236ea752b5d6ea Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 15 Jul 2020 16:01:00 +0800 Subject: [PATCH 01/14] Add annotation step --- python/tvm/auto_scheduler/loop_state.py | 31 +++++++ src/auto_scheduler/compute_dag.cc | 4 + src/auto_scheduler/loop_state.cc | 92 +++++++++++++++++++ src/auto_scheduler/loop_state.h | 30 ++---- src/auto_scheduler/measure_record.cc | 16 +++- src/auto_scheduler/transform_step.cc | 75 +++++++++++++++ src/auto_scheduler/transform_step.h | 54 +++++++++++ .../test_auto_scheduler_loop_state.py | 12 ++- 8 files changed, 287 insertions(+), 27 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 693a668a158e..2c52e0dd6a28 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -173,6 +173,37 @@ def fuse(self, stage, iters): self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) return res + def vectorize(self, stage, iterator): + stage_id = self._resolve_stage_id(stage) + self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator) + return res + + def parallel(self, stage, iterator): + stage_id = self._resolve_stage_id(stage) + self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator) + return res + + def unroll(self, stage, iterator, max_unroll=None): + stage_id = self._resolve_stage_id(stage) + self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, iterator, + max_unroll if max_unroll else -1) + return res + + def bind_thread(self, stage_id, iterator, thread_name): + trans_table = { + "vthread": 4, + "blockIdx.x": 5, + "threadIdx.x": 6, + "blockIdx.y": 7, + "threadIdx.y": 8, + } + thread_id = trans_table[thread_name] + stage_id = self._resolve_stage_id(stage_id) + + self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, iterator, + thread_id) + return res + def 
copy(self): """ Do deep copy of this State. """ state = State(self.state_object, self.compute_dag) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index a7abcb8a7ebf..c81aa445f85e 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -280,6 +280,8 @@ std::pair> ComputeDAG::ApplySteps( ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); } else { LOG(FATAL) << "Invalid Step"; } @@ -332,6 +334,8 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); } else if (auto ps = step.as()) { ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); } else { LOG(FATAL) << "Invalid Step"; } diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 1bfcb9ebc58a..23c0a65b9965 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -131,6 +131,54 @@ Iterator State::fuse(int stage_id, const Array& iters) { return DoFuseStep(step); } +Iterator State::vectorize(int stage_id, const Iterator& it) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = AnnotationStep( + stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + +Iterator State::parallel(int stage_id, const Iterator& it) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel); + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + +Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { + const Stage& 
stage = operator->()->stages[stage_id]; + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); + + // don't unroll if the extent is larger than max_unroll + if (max_unroll != -1 && it->range.defined()) { + if (auto imm = it->range->extent.as()) { + if (imm->value > max_unroll) { + return it; + } + } + } + + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + +Iterator State::bind_thread(int stage_id, const Iterator& it, + IteratorAnnotation thread_type) { + const Stage& stage = operator->()->stages[stage_id]; + if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadY) { + LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, " + << "kThreadX, kThreadY"; + } + AnnotationStep step = AnnotationStep( + stage_id, GetIndex(stage->iters, it), thread_type); + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + + /********** Step implementations for state **********/ void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; @@ -267,6 +315,20 @@ Iterator State::DoFuseStep(const FuseStep& step) { return new_it; } +Iterator State::DoAnnotationStep(const AnnotationStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Iterator it = stage->iters[step->iter_id]; + + CHECK(it->annotation == IteratorAnnotation::kNone); + Iterator new_it = Iterator(it->name, it->range, it->iter_kind, + step->annotation); + Stage new_stage = stage; + new_stage.CopyOnWrite()->iters.Set(step->iter_id, std::move(new_it)); + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(step->stage_id, std::move(new_stage)); + return new_it; +} + void State::DoSteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; @@ -277,6 +339,8 @@ void State::DoSteps(const ComputeDAG& dag) { DoSplitStep(GetRef(ps)); } 
else if (auto ps = step.as()) { DoFuseStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoAnnotationStep(GetRef(ps)); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -405,6 +469,34 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.vectorize(stage_id, it); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel") + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.parallel(stage_id, it); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll") + .set_body_typed([](State state, int stage_id, const Iterator& it, + int max_unroll) { + const auto& res = state.unroll(stage_id, it, max_unroll); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateBindThread") + .set_body_typed([](State state, int stage_id, const Iterator& it, + int thread_type) { + const auto& res = + state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); + return Array{state, res}; + }); + + TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 04e5304b6943..f0c9a60a5eed 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -91,30 +91,6 @@ enum class IteratorKind : int { kSpecial = 3 }; -/*! \brief The type of an iterator's annotation. */ -enum class IteratorAnnotation : int { - /*! \brief This iterator has no annotation. */ - kNone = 0, - /*! \brief This iterator has been unrolled. */ - kUnroll = 1, - /*! \brief This iterator has been vectorized. */ - kVectorize = 2, - /*! \brief This iterator has been paralleld. */ - kParallel = 3, - /*! 
\brief This iterator has been bind to vthread. */ - kVThread = 4, - /*! \brief This iterator has been bind to blockIdx.x. */ - kBlockX = 5, - /*! \brief This iterator has been bind to threadIdx.x. */ - kThreadX = 6, - /*! \brief This iterator has been bind to blockIdx.y. */ - kBlockY = 7, - /*! \brief This iterator has been bind to threadIdx.y. */ - kThreadY = 8, - /*! \brief This iterator has been mapped with a tensorize intrinsic. */ - kTensorized = 9 -}; - /*! * \brief A for loop iterator * Similar to tvm::IterVar in `include/tvm/tir/expr.h` @@ -308,6 +284,11 @@ class State : public ObjectRef { * \return The iterator result after fuse. */ Iterator fuse(int stage_id, const Array& iters); + Iterator vectorize(int stage_id, const Iterator& it); + Iterator parallel(int stage_id, const Iterator& it); + Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); + Iterator bind_thread(int stage_id, const Iterator& it, + IteratorAnnotation thread_type); TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); @@ -334,6 +315,7 @@ class State : public ObjectRef { * \return The iterator result after fuse. */ Iterator DoFuseStep(const FuseStep& step); + Iterator DoAnnotationStep(const AnnotationStep& step); /*! * \brief Common function for DoSplitStep and DoFollowSplitStep(Will be added later). 
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index f6f882edb5a2..684f6874cbae 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -101,6 +101,11 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { writer->WriteArrayItem(std::string("FU")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(IntArrayToVector(ps->fused_ids)); + } else if (auto ps = data[i].as<::tvm::auto_scheduler::AnnotationStepNode>()) { + writer->WriteArrayItem(std::string("AN")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(static_cast(ps->annotation)); } else { LOG(FATAL) << "Invalid step: " << data[i]; } @@ -114,7 +119,7 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { std::vector int_list; bool s, inner_to_outer; std::string name, scope_name, pragma_type, ti_func_name; - int stage_id, iter_id, extent; + int stage_id, iter_id, extent, ann; reader->BeginArray(); data->clear(); @@ -169,6 +174,15 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { fused_ids.push_back(i); } data->push_back(::tvm::auto_scheduler::FuseStep(stage_id, fused_ids)); + } else if (name == "AN") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&ann); + data->push_back(::tvm::auto_scheduler::AnnotationStep(stage_id, + iter_id, ::tvm::auto_scheduler::IteratorAnnotation(ann))); } else { LOG(FATAL) << "Invalid step format"; } diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 90b4db838fef..58896111eb81 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -235,5 +235,80 @@ String FuseStepNode::PrintAsPythonAPI(Array* stages, return ss.str(); } +/********** Annotation **********/ +AnnotationStep::AnnotationStep(int 
stage_id, int iter_id, + IteratorAnnotation ann) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->annotation = ann; + data_ = std::move(node); +} + +void AnnotationStepNode::ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; + + switch (annotation) { + case IteratorAnnotation::kUnroll: + stage.unroll(axes[iter_id]); break; + case IteratorAnnotation::kVectorize: + stage.vectorize(axes[iter_id]); break; + case IteratorAnnotation::kParallel: + stage.parallel(axes[iter_id]); break; + case IteratorAnnotation::kVThread: + stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break; + case IteratorAnnotation::kBlockX: + stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break; + case IteratorAnnotation::kBlockY: + stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); break; + case IteratorAnnotation::kThreadX: + stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x")); break; + case IteratorAnnotation::kThreadY: + stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break; + case IteratorAnnotation::kNone: break; + default: + LOG(FATAL) << "Invalid Annotation " << static_cast(annotation); break; + } + + stages->Set(stage_id, std::move(stage)); +} + +String AnnotationStepNode::PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + const auto& iter = (*stage_to_axes)[stage][iter_id]; + + ss << "s[" << CleanName(stage->op->name) << "]."; + switch (annotation) { + case IteratorAnnotation::kUnroll: ss << "unroll("; break; + case IteratorAnnotation::kVectorize: ss << "vectorize("; break; + case IteratorAnnotation::kParallel: ss << "parallel("; break; + case IteratorAnnotation::kVThread: + case IteratorAnnotation::kBlockX: + case IteratorAnnotation::kBlockY: + case 
IteratorAnnotation::kThreadX: + case IteratorAnnotation::kThreadY: ss << "bind("; break; + case IteratorAnnotation::kNone: break; + default: + LOG(FATAL) << "Invalid annotation " << static_cast(annotation); break; + } + ss << CleanName(iter->var->name_hint); + switch (annotation) { + case IteratorAnnotation::kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break; + case IteratorAnnotation::kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break; + case IteratorAnnotation::kBlockY: ss << ", tvm.thread_axis(\"blockIdx.y\")"; break; + case IteratorAnnotation::kThreadX: ss << ", tvm.thread_axis(\"threadIdx.x\")"; break; + case IteratorAnnotation::kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break; + default: break; + } + ss << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index d840cc009e2d..cb71a9d48b14 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -53,6 +53,30 @@ namespace auto_scheduler { typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; +/*! \brief The type of an iterator's annotation. */ +enum class IteratorAnnotation : int { + /*! \brief This iterator has no annotation. */ + kNone = 0, + /*! \brief This iterator has been unrolled. */ + kUnroll = 1, + /*! \brief This iterator has been vectorized. */ + kVectorize = 2, + /*! \brief This iterator has been parallelized. */ + kParallel = 3, + /*! \brief This iterator has been bound to vthread. */ + kVThread = 4, + /*! \brief This iterator has been bound to blockIdx.x. */ + kBlockX = 5, + /*! \brief This iterator has been bound to threadIdx.x. */ + kThreadX = 6, + /*! \brief This iterator has been bound to blockIdx.y. */ + kBlockY = 7, + /*! \brief This iterator has been bound to threadIdx.y. */ + kThreadY = 8, + /*! \brief This iterator has been mapped with a tensorize intrinsic. 
*/ + kTensorized = 9 +}; + /*! * \brief The base class of transformation steps. Each step has its corresponding tvm.te * schedule primitives. @@ -220,6 +244,36 @@ class FuseStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; +/*! + * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. + * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::unroll, te::Stage::bind) + */ +class AnnotationStepNode: public StepNode { + public: + int iter_id; + IteratorAnnotation annotation; + + void ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes) const; + + String PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes) const; + + static constexpr const char* _type_key = "auto_scheduler.AnnotationStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); +}; + +/*! + * \brief Managed reference to AnnotationStepNode. + * \sa AnnotationStepNode + */ +class AnnotationStep : public Step { + public: + AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); + + TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); +}; + } // namespace auto_scheduler } // namespace tvm diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 0801d9200275..f437a79997c2 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -26,7 +26,7 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test, conv2d_nchw_bn_relu -def test_split_fuse_reorder(): +def test_split_fuse_reorder_annotation(): A, B, C = matmul_auto_scheduler_test(512, 512, 512) dag = auto_scheduler.ComputeDAG([A, B, C]) s0 = dag.get_init_state() @@ -61,5 +61,13 @@ def test_split_fuse_reorder(): assert s1[C].iters[4].range.extent == 8 assert s1[C].iters[5].range.extent == 2 + s1.parallel(C, j1) + s1.unroll(C, j2) + s1.vectorize(C, j3) + 
s1.bind_thread(C, i1, "blockIdx.x") + s1.bind_thread(C, i2, "vthread") + s1.bind_thread(C, i3, "threadIdx.y") + + if __name__ == "__main__": - test_split_fuse_reorder() + test_split_fuse_reorder_annotation() From d81553ce2f976096e8025ebc05bab310d3f1be63 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 15 Jul 2020 20:35:37 +0800 Subject: [PATCH 02/14] Add compute_at/compute_root/compute_inline --- python/tvm/auto_scheduler/loop_state.py | 39 +++ src/auto_scheduler/compute_dag.cc | 12 + src/auto_scheduler/loop_state.cc | 249 ++++++++++++++++++ src/auto_scheduler/loop_state.h | 53 ++++ src/auto_scheduler/measure_record.cc | 35 ++- src/auto_scheduler/transform_step.cc | 85 ++++++ src/auto_scheduler/transform_step.h | 75 ++++++ src/auto_scheduler/utils.h | 17 ++ .../unittest/test_auto_scheduler_common.py | 2 +- .../test_auto_scheduler_loop_state.py | 65 +++++ 10 files changed, 630 insertions(+), 2 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 2c52e0dd6a28..468c612e3c87 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -123,6 +123,45 @@ def reorder(self, stage, order): self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) + def compute_at(self, stage_id, target_stage_id, target_iter): + """ + Parameters + ---------- + stage_id : Union[int, Operation, Tensor] + The index of source stage + target_stage_id : Union[int, Operation, Tensor] + The index of the target stage of compute_at + target_iter : Iterator + The target Iterator of compute_at + """ + stage_id = self._resolve_stage_id(stage_id) + target_stage_id = self._resolve_stage_id(target_stage_id) + + self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, + target_stage_id, target_iter) + + def compute_root(self, stage_id): + """ + Parameters + ---------- + stage_id : Union[int, Operation, Tensor] + The index of the stage to compute root + """ + 
stage_id = self._resolve_stage_id(stage_id) + + self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id) + + def compute_inline(self, stage_id): + """ + Parameters + ---------- + stage_id : Union[int, Operation, Tensor] + The index of the stage to compute inline + """ + stage_id = self._resolve_stage_id(stage_id) + + self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id) + def split(self, stage, iterator, lengths, inner_to_outer=True): """ Schedule primitive corresponds to te.split. diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index c81aa445f85e..3db3cba7b534 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -276,6 +276,12 @@ std::pair> ComputeDAG::ApplySteps( // return value, so the ApplyToSchedule is not able to be merged to single interface if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -330,6 +336,12 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const for (const auto& step : transform_steps) { if (auto ps = step.as()) { ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); } else if (auto ps = step.as()) { ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); } else if (auto ps = step.as()) { diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 
23c0a65b9965..63491a889a16 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -90,12 +90,102 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array& iters, data_ = std::move(node); } +/********** AttachMap **********/ +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, + int target_iter_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the current entry of stage + DeleteStageEntry(pnode, stage_id); + + // store the new relation + IterKey iter_key(target_stage_id, target_iter_id); + pnode->stage_to_attach_iter[stage_id] = + std::make_pair(target_stage_id, target_iter_id); + pnode->iter_to_attached_stages[iter_key].push_back(stage_id); +} + +void AttachMap::DeleteStage(int stage_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the entry of old stage + DeleteStageEntry(pnode, stage_id); +} + +void AttachMap::ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters) { + AttachMapNode* pnode = CopyOnWrite(); + + CHECK_EQ(old_iters.size(), new_iters.size()); + for (size_t i = 0; i < old_iters.size(); ++i) { + auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); + if (entry == pnode->iter_to_attached_stages.end()) { + continue; + } + + // replace iter in the value of `stage_to_attach_iter` + for (const auto& s : entry->second) { + pnode->stage_to_attach_iter[s] = new_iters[i]; + } + + // replace iter in the key of `iter_to_attached_stages` + std::vector attached_stages = std::move(entry->second); + pnode->iter_to_attached_stages.erase(entry); + pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); + } +} + +void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { + auto old_entry = pnode->stage_to_attach_iter.find(stage_id); + if (old_entry != pnode->stage_to_attach_iter.end()) { + // delete value in `iter_to_attached_stages` + auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); + DeleteItem(&entry2->second, 
stage_id); + if (entry2->second.size() == 0) { + pnode->iter_to_attached_stages.erase(entry2); + } + // delete key in `stage_to_attach_iter` + pnode->stage_to_attach_iter.erase(old_entry); + } +} + +AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { + AttachMap map = AttachMap(make_object()); + auto pmap = map.CopyOnWrite(); + for (const auto& x : operator->()->stage_to_attach_iter) { + auto key = x.first; + if (key >= start_id) { + key += offset; + } + auto value = x.second; + if (value.first >= start_id) { + value.first += offset; + } + pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); + } + for (const auto& x : operator->()->iter_to_attached_stages) { + auto key = x.first; + if (key.first >= start_id) { + key.first += offset; + } + auto value = x.second; + for (auto& i : value) { + if (i >= start_id) { + i += offset; + } + } + pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); + } + return map; +} + /********** State **********/ State::State(const Array& ops) { auto node = make_object(); for (const auto& op : ops) { node->stages.push_back(Stage(op)); } + node->attach_map = AttachMap(make_object()); node->concrete = true; data_ = std::move(node); } @@ -112,6 +202,27 @@ void State::reorder(int stage_id, const Array& order) { DoReorderStep(step); } +void State::compute_at(int stage_id, int target_stage_id, + const Iterator& target_iter) { + const Stage& target_stage = operator->()->stages[target_stage_id]; + ComputeAtStep step = ComputeAtStep( + stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); + CopyOnWrite()->transform_steps.push_back(step); + return DoComputeAtStep(step); +} + +void State::compute_root(int stage_id) { + ComputeRootStep step = ComputeRootStep(stage_id); + CopyOnWrite()->transform_steps.push_back(step); + return DoComputeRootStep(step); +} + +void State::compute_inline(int stage_id) { + ComputeInlineStep step = ComputeInlineStep(stage_id); + 
CopyOnWrite()->transform_steps.push_back(step); + return DoComputeInlineStep(step); +} + Array State::split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; @@ -191,12 +302,74 @@ void State::DoReorderStep(const ReorderStep& step) { Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); } +void State::DoComputeAtStep(const ComputeAtStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + + // after compute_at, we don't know the accurate length information any more + // If we do want to know the accurate lengths, we can call + // ComputeDAG::ReplayAndInferBound + std::vector new_iters; + for (const Iterator& it : stage->iters) { + new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, + it->annotation)); + } + + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(step->stage_id, + Stage(stage->op, stage->op_type, std::move(new_iters), ComputeAtKind::kIter, + stage->attrs)); + pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, + step->target_iter_id); +} + +void State::DoComputeRootStep(const ComputeRootStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + + // after compute_root, we don't know the accurate length information any more + // If we do want to know the accurate lengths, we can call + // ComputeDAG::ReplayAndInferBound + std::vector new_iters; + for (const Iterator& it : stage->iters) { + new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, + it->annotation)); + } + + // update attach map + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, + std::move(new_iters), ComputeAtKind::kRoot, + stage->attrs)); + pstate->attach_map.DeleteStage(step->stage_id); +} + +void State::DoComputeInlineStep(const ComputeInlineStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + + StateNode* pstate = 
CopyOnWrite(); + + // CHECK the validity of compute_inline + const auto& iter_to_attached_stages = + pstate->attach_map->iter_to_attached_stages; + for (size_t i = 0; i < stage->iters.size(); ++i) { + CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), + 0) + << "Invalid compute_inline: Because there are some other stages " + "that are attached to the target stage"; + } + + auto new_stage = pstate->stages[step->stage_id]; + new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined; + pstate->stages.Set(step->stage_id, std::move(new_stage)); + pstate->attach_map.DeleteStage(step->stage_id); +} + // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array>& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; + size_t old_iter_size = stage->iters.size(); bool concrete = true; Optional tosplit_min, tosplit_extent; @@ -258,6 +431,16 @@ Array State::DoSplitStepCommon(int stage_id, int iter_id, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); pstate->concrete &= concrete; + // we have to replace the iterators in attach map, + // these two vectors keep the replacement mapping + std::vector from_iters; + std::vector to_iters; + for (size_t i = iter_id; i < old_iter_size; ++i) { + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i + lengths.size()); + } + pstate->attach_map.ReplaceIters(from_iters, to_iters); + return outs; } @@ -268,6 +451,7 @@ Array State::DoSplitStep(const SplitStep& step) { Iterator State::DoFuseStep(const FuseStep& step) { int stage_id = step->stage_id; const Stage& stage = operator->()->stages[stage_id]; + size_t old_iter_size = static_cast(stage->iters.size()); String new_name; PrimExpr new_extent = 1; @@ -278,6 +462,17 @@ Iterator State::DoFuseStep(const FuseStep& step) { CHECK_EQ(step->fused_ids[i]->value, 
step->fused_ids[i - 1]->value + 1); } + if (i != step->fused_ids.size() - 1) { + const auto& iter_to_attached_stage = + operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair( + stage_id, step->fused_ids[i])) != iter_to_attached_stage.end()) { + LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " + << "stages. State before fusion:\n" + << *this; + } + } + const Iterator& it = stage->iters[step->fused_ids[i]]; new_name = new_name + it->name + "@"; @@ -312,6 +507,24 @@ Iterator State::DoFuseStep(const FuseStep& step) { pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + // we have to replace the iterators in attach map, + // these two vectors keep the replacement mapping + std::vector from_iters; + std::vector to_iters; + const size_t begin_id = step->fused_ids.front(), end_id = step->fused_ids.back(); + for (size_t i = 0; i < old_iter_size; ++i) { + if (i <= begin_id) { + continue; + } else if (i > end_id) { // move forward + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i - end_id + begin_id); + } else { // move to the fused id + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, begin_id); + } + } + pstate->attach_map.ReplaceIters(from_iters, to_iters); + return new_it; } @@ -335,6 +548,12 @@ void State::DoSteps(const ComputeDAG& dag) { for (const auto& step : operator->()->transform_steps) { if (auto ps = step.as()) { DoReorderStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeAtStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeRootStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeInlineStep(GetRef(ps)); } else if (auto ps = step.as()) { DoSplitStep(GetRef(ps)); } else if (auto ps = step.as()) { @@ -396,6 +615,17 @@ void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_ indent += 2; } + + if (state.defined()) { 
+ AttachMap::IterKey iter_key(stage_id, i); + auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); + if (pair != state->attach_map->iter_to_attached_stages.end()) { + for (const auto& attach_stage_id : pair->second) { + PrintStage(os, attach_stage_id, state, base_indent + indent, + delete_trivial_loop); + } + } + } } for (size_t j = 0; j < base_indent + indent; ++j) { @@ -456,6 +686,25 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") return state; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") + .set_body_typed([](State state, int stage_id, int target_stage_id, + const Iterator& target_iter) { + state.compute_at(stage_id, target_stage_id, target_iter); + return state; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot") + .set_body_typed([](State state, int stage_id) { + state.compute_root(stage_id); + return state; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline") + .set_body_typed([](State state, int stage_id) { + state.compute_inline(stage_id); + return state; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index f0c9a60a5eed..b5f1227a4c07 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -51,6 +51,9 @@ #include #include +#include +#include +#include #include "transform_step.h" @@ -193,6 +196,48 @@ class Stage : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); }; +/*! + * \brief stores the compute_at relation between stages + * This stores a bi-directional mapping from stages and iter: + * 1. Stage to its attached iterator + * 2. 
Iterator to the stage attached to it + * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages + * to query the relations + */ +class AttachMapNode: public Object { + public: + using StageKey = int; + using IterKey = std::pair; // stage_id and iter_id + + std::unordered_map stage_to_attach_iter; + std::unordered_map> iter_to_attached_stages; + + static constexpr const char* _type_key = "auto_scheduler.AttachMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); +}; + +/*! + * \brief Managed reference to AttachMapNode. + * \sa AttachMapNode + */ +class AttachMap : public ObjectRef { + public: + using StageKey = int; + using IterKey = std::pair; // stage_id and iter_id + + void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); + void DeleteStage(int stage_id); + void ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters); + AttachMap ApplyStageIdOfffset(int start_id, int offset) const; + + TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); + + private: + static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); +}; + /*! * \brief A state in the search process. * It consists of the current loop structure and a list of transformation steps used to construct @@ -205,6 +250,7 @@ class StateNode : public Object { Array stages; /*! \brief History transformation steps. */ Array transform_steps; + AttachMap attach_map; /*! * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule. 
@@ -214,6 +260,7 @@ class StateNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("stages", &stages); v->Visit("transform_steps", &transform_steps); + v->Visit("attach_map", &attach_map); v->Visit("concrete", &concrete); } @@ -267,6 +314,9 @@ class State : public ObjectRef { * \param order The expected iterator order. */ void reorder(int stage_id, const Array& order); + void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); + void compute_root(int stage_id); + void compute_inline(int stage_id); /*! * \brief Schedule primitive corresponds to te.split. * \param stage_id The index of the stage to be split. @@ -303,6 +353,9 @@ class State : public ObjectRef { * \param step A ReorderStep. */ void DoReorderStep(const ReorderStep& step); + void DoComputeAtStep(const ComputeAtStep& step); + void DoComputeRootStep(const ComputeRootStep& step); + void DoComputeInlineStep(const ComputeInlineStep& step); /*! * \brief Apply split step to current state. * \param step A SplitStep. 
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 684f6874cbae..61957ff82f33 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -89,6 +89,17 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { writer->WriteArrayItem(std::string("RE")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(IntArrayToVector(ps->after_ids)); + } else if (auto ps = data[i].as<::tvm::auto_scheduler::ComputeAtStepNode>()) { + writer->WriteArrayItem(std::string("CA")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->target_stage_id); + writer->WriteArrayItem(ps->target_iter_id); + } else if (auto ps = data[i].as<::tvm::auto_scheduler::ComputeRootStepNode>()) { + writer->WriteArrayItem(std::string("CR")); + writer->WriteArrayItem(ps->stage_id); + } else if (auto ps = data[i].as<::tvm::auto_scheduler::ComputeInlineStepNode>()) { + writer->WriteArrayItem(std::string("CI")); + writer->WriteArrayItem(ps->stage_id); } else if (auto ps = data[i].as<::tvm::auto_scheduler::SplitStepNode>()) { writer->WriteArrayItem(std::string("SP")); writer->WriteArrayItem(ps->stage_id); @@ -119,7 +130,7 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { std::vector int_list; bool s, inner_to_outer; std::string name, scope_name, pragma_type, ti_func_name; - int stage_id, iter_id, extent, ann; + int stage_id, iter_id, extent, ann, target_stage_id; reader->BeginArray(); data->clear(); @@ -140,6 +151,28 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { after_ids.push_back(i); } data->push_back(::tvm::auto_scheduler::ReorderStep(stage_id, after_ids)); + } else if (name == "CA") { + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&target_stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&iter_id); + data->push_back(::tvm::auto_scheduler::ComputeAtStep( + stage_id, 
target_stage_id, iter_id)); + } else if (name == "CR") { + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&stage_id); + data->push_back(::tvm::auto_scheduler::ComputeRootStep(stage_id)); + } else if (name == "CI") { + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&stage_id); + data->push_back(::tvm::auto_scheduler::ComputeInlineStep(stage_id)); } else if (name == "SP") { s = reader->NextArrayItem(); CHECK(s); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 58896111eb81..3852ed642122 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -82,6 +82,91 @@ String ReorderStepNode::PrintAsPythonAPI(Array* stages, return ss.str(); } +/********** Compute At **********/ +ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { + auto node = make_object(); + node->stage_id = stage_id; + node->target_stage_id = target_stage_id; + node->target_iter_id = target_iter_id; + data_ = std::move(node); +} + +void ComputeAtStepNode::ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage stage = (*stages)[stage_id]; + const IterVar& target_axis = + (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id]; + stage.compute_at((*stages)[target_stage_id], target_axis); + + stages->Set(stage_id, std::move(stage)); +} + +String ComputeAtStepNode::PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + const auto& target_stage = (*stages)[target_stage_id]; + + ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" + << CleanName(target_stage->op->name) << "], " + << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); + + ss << ")\n"; + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Compute Root **********/ +ComputeRootStep::ComputeRootStep(int stage_id) { + auto node = make_object(); 
+ node->stage_id = stage_id; + data_ = std::move(node); +} + +void ComputeRootStepNode::ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes) const { + auto stage = (*stages)[stage_id]; + stage.compute_root(); + stages->Set(stage_id, std::move(stage)); +} + +String ComputeRootStepNode::PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; + ApplyToSchedule(stages, stage_to_axes); + + return ss.str(); +} + +/********** Compute Inline **********/ +ComputeInlineStep::ComputeInlineStep(int stage_id) { + auto node = make_object(); + node->stage_id = stage_id; + data_ = std::move(node); +} + +void ComputeInlineStepNode::ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes) const { + auto stage = (*stages)[stage_id]; + stage.compute_inline(); + stages->Set(stage_id, std::move(stage)); +} + +String ComputeInlineStepNode::PrintAsPythonAPI( + Array *stages, + StageToAxesMap *stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n"; + ApplyToSchedule(stages, stage_to_axes); + + return ss.str(); +} + /********** Split **********/ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, int iter_id, diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index cb71a9d48b14..efe630560ca5 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -143,6 +143,81 @@ class ReorderStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; +/*! 
\brief Compute at step that corresponds to te::Stage::compute_at */ +class ComputeAtStepNode: public StepNode { + public: + int target_stage_id; + int target_iter_id; + + void ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes) const; + + String PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes) const; + + static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); +}; + +/*! + * \brief Managed reference to ComputeAtStepNode. + * \sa ComputeAtStepNode + */ +class ComputeAtStep : public Step { + public: + ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); +}; + +/*! \brief Compute root step that corresponds to te::Stage::compute_root */ +class ComputeRootStepNode: public StepNode { + public: + void ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes) const; + + String PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes) const; + + static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); +}; + +/*! + * \brief Managed reference to ComputeRootStepNode. + * \sa ComputeRootStepNode + */ +class ComputeRootStep : public Step { + public: + explicit ComputeRootStep(int stage_id); + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); +}; + +/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ +class ComputeInlineStepNode: public StepNode { + public: + void ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes) const; + + String PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes) const; + + static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); +}; + +/*! + * \brief Managed reference to ComputeInlineStepNode. 
+ * \sa ComputeInlineStepNode + */ +class ComputeInlineStep : public Step { + public: + explicit ComputeInlineStep(int stage_id); + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); +}; + /*! * \brief Split step that corresponds to te::Stage::split with additional * support of multiple-level of factors diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index 5637780e3991..976f2802842c 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -89,6 +89,23 @@ inline int GetIndex(const Array& array, const T& to_locate) { return -1; } +/*! \brief Delete an element in a vector */ +template +inline void DeleteItem(Array* array, const T& to_delete) { + auto iter = std::find(array->begin(), array->end(), to_delete); + if (iter != array->end()) { + array->erase(iter); + } +} + +template +inline void DeleteItem(std::vector* array, const T& to_delete) { + auto iter = std::find(array->begin(), array->end(), to_delete); + if (iter != array->end()) { + array->erase(iter); + } +} + /*! 
\brief Replace a sub-string to another sub-string in a string */ inline void StrReplace(std::string* base, const std::string& from, const std::string& to) { auto pos = base->find(from); diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py index 078e1ae8e854..fa22fdc5597c 100644 --- a/tests/python/unittest/test_auto_scheduler_common.py +++ b/tests/python/unittest/test_auto_scheduler_common.py @@ -40,7 +40,7 @@ def matmul_auto_scheduler_test_rename_0(N, M, K): C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') return [A, B, C] - +@auto_scheduler.register_workload def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, CI, H, W), name='Data') kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel') diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index f437a79997c2..fa503a20205f 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -69,5 +69,70 @@ def test_split_fuse_reorder_annotation(): s1.bind_thread(C, i3, "threadIdx.y") +def test_compute_at_root_inline(): + dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + s0 = dag.get_init_state() + + # data, padding, kernel = 0, 1, 2 + conv = s0.stage_ops[3] + # bias = 4 + bias_add = s0.stage_ops[5] + # bn_scale = 6 + bn_mul = s0.stage_ops[7] + # bn_offset = 8 + bn_add = s0.stage_ops[9] + relu = s0.stage_ops[10] + + s0.compute_inline(bn_add) + s0.compute_inline(bn_mul) + s0.compute_inline(bias_add) + s0.compute_at(conv, relu, s0[relu].iters[2]) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" 
+ \ + " for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ + " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + s0.compute_root(conv) + s0.compute_root(bn_mul) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ + " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + "for i (None)\n" + \ + " for j (None)\n" + \ + " for k (None)\n" + \ + " for l (None)\n" + \ + " Bn_mul = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + if __name__ == "__main__": test_split_fuse_reorder_annotation() + test_compute_at_root_inline() From 78db4c9f2b8a8abe6df84f831596242879a68498 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 16 Jul 2020 14:56:22 +0800 Subject: [PATCH 03/14] Doc update --- python/tvm/auto_scheduler/loop_state.py | 163 +++++++++++++----- src/auto_scheduler/loop_state.cc | 154 ++++++----------- src/auto_scheduler/loop_state.h | 91 ++++++++-- src/auto_scheduler/measure_record.cc | 16 +- src/auto_scheduler/transform_step.cc | 119 ++++++++----- src/auto_scheduler/transform_step.h | 103 ++++++++--- src/auto_scheduler/utils.h | 9 - .../test_auto_scheduler_loop_state.py | 7 +- 8 files changed, 415 insertions(+), 247 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 468c612e3c87..b334f6674366 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -119,48 +119,51 @@ def reorder(self, stage, order): order : List[Iterator] Iterators in the expected order. 
""" - stage_id = self._resolve_stage_id(stage) + self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage), + order) - self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) + def compute_at(self, stage, target_stage, target_iter): + """ Schedule primitive corresponds to te.compute_at. - def compute_at(self, stage_id, target_stage_id, target_iter): - """ Parameters ---------- - stage_id : Union[int, Operation, Tensor] - The index of source stage - target_stage_id : Union[int, Operation, Tensor] - The index of the target stage of compute_at + stage : Union[int, Operation, Tensor] + The Stage to be compute at, can be a Stage order index, Stage operation or stage + output tensor. + target_stage : Union[int, Operation, Tensor] + The target stage of compute_at, can be a Stage order index, Stage operation or stage + output tensor. target_iter : Iterator - The target Iterator of compute_at + The target Iterator of compute_at. """ - stage_id = self._resolve_stage_id(stage_id) - target_stage_id = self._resolve_stage_id(target_stage_id) + self.state_object = _ffi_api.StateComputeAt(self.state_object, + self._resolve_stage_id(stage), + self._resolve_stage_id(target_stage), + target_iter) - self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, - target_stage_id, target_iter) + def compute_root(self, stage): + """ Schedule primitive corresponds to te.compute_root. - def compute_root(self, stage_id): - """ Parameters ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to compute root + stage : Union[int, Operation, Tensor] + The Stage to be compute root, can be a Stage order index, Stage operation or stage + output tensor. 
""" - stage_id = self._resolve_stage_id(stage_id) + self.state_object = _ffi_api.StateComputeRoot(self.state_object, + self._resolve_stage_id(stage)) - self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id) + def compute_inline(self, stage): + """ Schedule primitive corresponds to te.compute_inline. - def compute_inline(self, stage_id): - """ Parameters ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to compute inline + stage : Union[int, Operation, Tensor] + The Stage to be compute inline, can be a Stage order index, Stage operation or stage + output tensor. """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id) + self.state_object = _ffi_api.StateComputeInline(self.state_object, + self._resolve_stage_id(stage)) def split(self, stage, iterator, lengths, inner_to_outer=True): """ Schedule primitive corresponds to te.split. @@ -185,10 +188,9 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): res_its : List[Iterator] The splitted new Iterators """ - stage_id = self._resolve_stage_id(stage) - - self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths, - inner_to_outer) + self.state_object, res = _ffi_api.StateSplit(self.state_object, + self._resolve_stage_id(stage), + iterator, lengths, inner_to_outer) return res def fuse(self, stage, iters): @@ -200,35 +202,103 @@ def fuse(self, stage, iters): The Stage to be fused, can be a Stage order index, Stage operation or stage output tensor. iters : List[Iterator] - The iterators to be fused + The iterators to be fused. Returns ------- res_it : Iterator - The fused Iterator + The fused Iterator. 
""" - stage_id = self._resolve_stage_id(stage) - - self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) + self.state_object, res = _ffi_api.StateFuse(self.state_object, + self._resolve_stage_id(stage), iters) return res def vectorize(self, stage, iterator): - stage_id = self._resolve_stage_id(stage) - self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator) + """ Schedule primitive corresponds to te.vectorize. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be vectorized, can be a Stage order index, Stage operation or stage + output tensor. + iterator : Iterator + The iterator to be vectorized. + + Returns + ------- + res_it : Iterator + The vectorized Iterator. + """ + self.state_object, res = _ffi_api.StateVectorize(self.state_object, + self._resolve_stage_id(stage), iterator) return res def parallel(self, stage, iterator): - stage_id = self._resolve_stage_id(stage) - self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator) + """ Schedule primitive corresponds to te.parallel. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be paralleled, can be a Stage order index, Stage operation or stage + output tensor. + iterator : Iterator + The iterator to be paralleled. + + Returns + ------- + res_it : Iterator + The paralleled Iterator. + """ + self.state_object, res = _ffi_api.StateParallel(self.state_object, + self._resolve_stage_id(stage), iterator) return res def unroll(self, stage, iterator, max_unroll=None): - stage_id = self._resolve_stage_id(stage) - self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, iterator, + """ Schedule primitive corresponds to te.unrolled. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be unrolled, can be a Stage order index, Stage operation or stage + output tensor. + iterator : Iterator + The iterator to be unrolled. 
+ max_unroll : Optional[int] + The max unroll limit. Iterator with extent larger than this limit will be skipped. + + Returns + ------- + res_it : Iterator + The unrolled Iterator. + """ + self.state_object, res = _ffi_api.StateUnroll(self.state_object, + self._resolve_stage_id(stage), iterator, max_unroll if max_unroll else -1) return res - def bind_thread(self, stage_id, iterator, thread_name): + def bind(self, stage, iterator, thread_name): + """ Schedule primitive corresponds to te.bind. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be binded, can be a Stage order index, Stage operation or stage + output tensor. + iterator : Iterator + The iterator to be binded. + thread_name : str + The thread type to be binded. Currently support: + - vthread + - blockIdx.x + - threadIdx.x + - blockIdx.y + - threadIdx.y + + Returns + ------- + res_it : Iterator + The binded Iterator. + """ trans_table = { "vthread": 4, "blockIdx.x": 5, @@ -236,11 +306,12 @@ def bind_thread(self, stage_id, iterator, thread_name): "blockIdx.y": 7, "threadIdx.y": 8, } - thread_id = trans_table[thread_name] - stage_id = self._resolve_stage_id(stage_id) + if not thread_name in trans_table.keys(): + raise ValueError("Invalid thread_name: ", thread_name) - self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, iterator, - thread_id) + self.state_object, res = _ffi_api.StateBind(self.state_object, + self._resolve_stage_id(stage), iterator, + trans_table[thread_name]) return res def copy(self): diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 63491a889a16..b646f8c63142 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -91,29 +91,26 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array& iters, } /********** AttachMap **********/ -void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, - int target_iter_id) { +void AttachMap::SetComputeAtIter(int 
stage_id, int target_stage_id, int target_iter_id) { AttachMapNode* pnode = CopyOnWrite(); - // delete the current entry of stage + // Delete the current entry of this stage DeleteStageEntry(pnode, stage_id); - // store the new relation + // Store the new relations to map IterKey iter_key(target_stage_id, target_iter_id); - pnode->stage_to_attach_iter[stage_id] = - std::make_pair(target_stage_id, target_iter_id); + pnode->stage_to_attach_iter[stage_id] = iter_key; pnode->iter_to_attached_stages[iter_key].push_back(stage_id); } void AttachMap::DeleteStage(int stage_id) { AttachMapNode* pnode = CopyOnWrite(); - - // delete the entry of old stage + // Delete the original stage entry DeleteStageEntry(pnode, stage_id); } -void AttachMap::ReplaceIters(const std::vector& old_iters, - const std::vector& new_iters) { +void AttachMap::UpdateIters(const std::vector& old_iters, + const std::vector& new_iters) { AttachMapNode* pnode = CopyOnWrite(); CHECK_EQ(old_iters.size(), new_iters.size()); @@ -123,12 +120,12 @@ void AttachMap::ReplaceIters(const std::vector& old_iters, continue; } - // replace iter in the value of `stage_to_attach_iter` + // Replace iter in the value of `stage_to_attach_iter` for (const auto& s : entry->second) { pnode->stage_to_attach_iter[s] = new_iters[i]; } - // replace iter in the key of `iter_to_attached_stages` + // Replace iter in the key of `iter_to_attached_stages` std::vector attached_stages = std::move(entry->second); pnode->iter_to_attached_stages.erase(entry); pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); @@ -138,47 +135,17 @@ void AttachMap::ReplaceIters(const std::vector& old_iters, void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { auto old_entry = pnode->stage_to_attach_iter.find(stage_id); if (old_entry != pnode->stage_to_attach_iter.end()) { - // delete value in `iter_to_attached_stages` + // Delete value in `iter_to_attached_stages` auto entry2 = 
pnode->iter_to_attached_stages.find(old_entry->second); DeleteItem(&entry2->second, stage_id); if (entry2->second.size() == 0) { pnode->iter_to_attached_stages.erase(entry2); } - // delete key in `stage_to_attach_iter` + // Delete key in `stage_to_attach_iter` pnode->stage_to_attach_iter.erase(old_entry); } } -AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { - AttachMap map = AttachMap(make_object()); - auto pmap = map.CopyOnWrite(); - for (const auto& x : operator->()->stage_to_attach_iter) { - auto key = x.first; - if (key >= start_id) { - key += offset; - } - auto value = x.second; - if (value.first >= start_id) { - value.first += offset; - } - pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); - } - for (const auto& x : operator->()->iter_to_attached_stages) { - auto key = x.first; - if (key.first >= start_id) { - key.first += offset; - } - auto value = x.second; - for (auto& i : value) { - if (i >= start_id) { - i += offset; - } - } - pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); - } - return map; -} - /********** State **********/ State::State(const Array& ops) { auto node = make_object(); @@ -202,11 +169,10 @@ void State::reorder(int stage_id, const Array& order) { DoReorderStep(step); } -void State::compute_at(int stage_id, int target_stage_id, - const Iterator& target_iter) { +void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; - ComputeAtStep step = ComputeAtStep( - stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); + ComputeAtStep step = + ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); CopyOnWrite()->transform_steps.push_back(step); return DoComputeAtStep(step); } @@ -244,8 +210,8 @@ Iterator State::fuse(int stage_id, const Array& iters) { Iterator State::vectorize(int stage_id, const Iterator& it) { const Stage& stage = 
operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStep( - stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } @@ -276,20 +242,17 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { return DoAnnotationStep(step); } -Iterator State::bind_thread(int stage_id, const Iterator& it, - IteratorAnnotation thread_type) { +Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadY) { LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kBlockY, " << "kThreadX, kThreadY"; } - AnnotationStep step = AnnotationStep( - stage_id, GetIndex(stage->iters, it), thread_type); + AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } - /********** Step implementations for state **********/ void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; @@ -310,16 +273,13 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { // ComputeDAG::ReplayAndInferBound std::vector new_iters; for (const Iterator& it : stage->iters) { - new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, - it->annotation)); + new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); } StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(step->stage_id, - Stage(stage->op, stage->op_type, std::move(new_iters), ComputeAtKind::kIter, - stage->attrs)); - pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, - step->target_iter_id); + pstate->stages.Set(step->stage_id, 
Stage(stage->op, stage->op_type, std::move(new_iters), + ComputeAtKind::kIter, stage->attrs)); + pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); } void State::DoComputeRootStep(const ComputeRootStep& step) { @@ -330,15 +290,13 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { // ComputeDAG::ReplayAndInferBound std::vector new_iters; for (const Iterator& it : stage->iters) { - new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, - it->annotation)); + new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); } // update attach map StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, - std::move(new_iters), ComputeAtKind::kRoot, - stage->attrs)); + pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + ComputeAtKind::kRoot, stage->attrs)); pstate->attach_map.DeleteStage(step->stage_id); } @@ -348,11 +306,9 @@ void State::DoComputeInlineStep(const ComputeInlineStep& step) { StateNode* pstate = CopyOnWrite(); // CHECK the validity of compute_inline - const auto& iter_to_attached_stages = - pstate->attach_map->iter_to_attached_stages; + const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages; for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), - 0) + CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0) << "Invalid compute_inline: Because there are some other stages " "that are attached to the target stage"; } @@ -431,15 +387,15 @@ Array State::DoSplitStepCommon(int stage_id, int iter_id, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); pstate->concrete &= concrete; - // we have to replace the iterators in attach map, - // these two vectors keep the replacement mapping - std::vector from_iters; - std::vector to_iters; + // We have to 
update the iterator relations in attach map, these two vectors keep the replacement + // mapping + std::vector from_iters; + std::vector to_iters; for (size_t i = iter_id; i < old_iter_size; ++i) { from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, i + lengths.size()); } - pstate->attach_map.ReplaceIters(from_iters, to_iters); + pstate->attach_map.UpdateIters(from_iters, to_iters); return outs; } @@ -463,10 +419,9 @@ Iterator State::DoFuseStep(const FuseStep& step) { } if (i != step->fused_ids.size() - 1) { - const auto& iter_to_attached_stage = - operator->()->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair( - stage_id, step->fused_ids[i])) != iter_to_attached_stage.end()) { + const auto& iter_to_attached_stage = operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair(stage_id, step->fused_ids[i])) != + iter_to_attached_stage.end()) { LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " << "stages. 
State before fusion:\n" << *this; @@ -507,23 +462,25 @@ Iterator State::DoFuseStep(const FuseStep& step) { pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); - // we have to replace the iterators in attach map, - // these two vectors keep the replacement mapping - std::vector from_iters; - std::vector to_iters; + // We have to update the iterator relations in attach map, these two vectors keep the replacement + // mapping + std::vector from_iters; + std::vector to_iters; const size_t begin_id = step->fused_ids.front(), end_id = step->fused_ids.back(); for (size_t i = 0; i < old_iter_size; ++i) { if (i <= begin_id) { continue; - } else if (i > end_id) { // move forward + } else if (i > end_id) { + // move forward from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, i - end_id + begin_id); - } else { // move to the fused id + } else { + // move to the fused id from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, begin_id); } } - pstate->attach_map.ReplaceIters(from_iters, to_iters); + pstate->attach_map.UpdateIters(from_iters, to_iters); return new_it; } @@ -533,8 +490,7 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { Iterator it = stage->iters[step->iter_id]; CHECK(it->annotation == IteratorAnnotation::kNone); - Iterator new_it = Iterator(it->name, it->range, it->iter_kind, - step->annotation); + Iterator new_it = Iterator(it->name, it->range, it->iter_kind, step->annotation); Stage new_stage = stage; new_stage.CopyOnWrite()->iters.Set(step->iter_id, std::move(new_it)); StateNode* pstate = CopyOnWrite(); @@ -617,12 +573,12 @@ void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_ } if (state.defined()) { - AttachMap::IterKey iter_key(stage_id, i); + IterKey iter_key(stage_id, i); auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); if (pair != state->attach_map->iter_to_attached_stages.end()) { + // Print the 
attached stage for (const auto& attach_stage_id : pair->second) { - PrintStage(os, attach_stage_id, state, base_indent + indent, - delete_trivial_loop); + PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop); } } } @@ -688,7 +644,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") .set_body_typed([](State state, int stage_id, int target_stage_id, - const Iterator& target_iter) { + const Iterator& target_iter) { state.compute_at(stage_id, target_stage_id, target_iter); return state; }); @@ -731,21 +687,17 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel") }); TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int max_unroll) { + .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) { const auto& res = state.unroll(stage_id, it, max_unroll); return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateBindThread") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int thread_type) { - const auto& res = - state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); +TVM_REGISTER_GLOBAL("auto_scheduler.StateBind") + .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) { + const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type)); return Array{state, res}; }); - TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index b5f1227a4c07..0f95794fa886 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -196,23 +196,27 @@ class Stage : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); }; +/*! \brief Use stage_id to represent a stage. */ +using StageKey = int; +/*! 
\brief Use stage_id and iter_id to represent a iterator. */ +using IterKey = std::pair; + /*! * \brief stores the compute_at relation between stages * This stores a bi-directional mapping from stages and iter: * 1. Stage to its attached iterator - * 2. Iterator to the stage attached to it + * 2. Iterator to the stage attached to it * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages * to query the relations */ -class AttachMapNode: public Object { +class AttachMapNode : public Object { public: - using StageKey = int; - using IterKey = std::pair; // stage_id and iter_id - + /*! \brief A Map to store the mapping of stage to its attached iterator. */ std::unordered_map stage_to_attach_iter; + /*! \brief A Map to store the mapping of iterator to the stage attached to it. */ std::unordered_map> iter_to_attached_stages; - static constexpr const char* _type_key = "ansor.AttachMap"; + static constexpr const char* _type_key = "auto_scheduler.AttachMap"; TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); }; @@ -222,19 +226,35 @@ class AttachMapNode: public Object { */ class AttachMap : public ObjectRef { public: - using StageKey = int; - using IterKey = std::pair; // stage_id and iter_id - + /*! + * \brief Process the stage/iterator mapping after compute at. + * \param stage_id The index of the stage to be compute at. + * \param target_stage_id The index of stage that this step will compute at to. + * \param target_iter_id The index of iterator in target stage that this step will compute at to. + */ void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); + /*! + * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage. + * \param stage_id The index of the stage to be compute at. + */ void DeleteStage(int stage_id); - void ReplaceIters(const std::vector& old_iters, - const std::vector& new_iters); - AttachMap ApplyStageIdOfffset(int start_id, int offset) const; + /*! 
+ * \brief Update the iterator relations in AttachMap. + * \param old_iters The original IterKey. + * \param new_iters The new IterKey to update. + */ + void UpdateIters(const std::vector& old_iters, const std::vector& new_iters); TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); private: + /*! + * \brief To delete the entry of a specific stage. This will remove the items related to this + * stage in both `stage_to_attach_iter` and `iter_to_attached_stages` map. + * \param pnode A mutable pointer to AttachMapNode. + * \param stage_id The index of stage that will be removed from the map. + */ static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); }; @@ -314,13 +334,27 @@ class State : public ObjectRef { * \param order The expected iterator order. */ void reorder(int stage_id, const Array& order); + /*! + * \brief Schedule primitive corresponds to te.compute_at. + * \param stage_id The index of the stage to be reordered. + * \param target_stage_id The index of stage that this step will compute at to. + * \param target_iter The iterator in target stage that this step will compute at to. + */ void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); + /*! + * \brief Schedule primitive corresponds to te.compute_root. + * \param stage_id The index of the stage to be reordered. + */ void compute_root(int stage_id); + /*! + * \brief Schedule primitive corresponds to te.compute_inline. + * \param stage_id The index of the stage to be reordered. + */ void compute_inline(int stage_id); /*! * \brief Schedule primitive corresponds to te.split. * \param stage_id The index of the stage to be split. - * \param it The iterator the be split. + * \param it The iterator to be split. * \param lengths The multiple split factors. Can be None to be filled by search policy. * \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner. 
* \return The iterator results after split. @@ -334,11 +368,38 @@ class State : public ObjectRef { * \return The iterator result after fuse. */ Iterator fuse(int stage_id, const Array& iters); + /*! + * \brief Schedule primitive corresponds to te.vectorize. + * \param stage_id The index of the stage to be vectorized. + * \param it The iterator to be vectorized. + * \return The iterator result after vectorize. + */ Iterator vectorize(int stage_id, const Iterator& it); + /*! + * \brief Schedule primitive corresponds to te.parallel. + * \param stage_id The index of the stage to be paralleled. + * \param it The iterator to be paralleled. + * \return The iterator result after parallel. + */ Iterator parallel(int stage_id, const Iterator& it); + /*! + * \brief Schedule primitive corresponds to te.unroll. + * \param stage_id The index of the stage to be unrolled. + * \param it The iterator to be unrolled. + * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be + * skipped. + * \return The iterator result after unrolled. + */ Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); - Iterator bind_thread(int stage_id, const Iterator& it, - IteratorAnnotation thread_type); + /*! + * \brief Schedule primitive corresponds to te.bind. + * \param stage_id The index of the stage to be binded. + * \param it The iterator to be binded. + * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as + * this input. + * \return The iterator result after binded. 
+ */ + Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 61957ff82f33..0e01dab434fb 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -161,8 +161,7 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); - data->push_back(::tvm::auto_scheduler::ComputeAtStep( - stage_id, target_stage_id, iter_id)); + data->push_back(::tvm::auto_scheduler::ComputeAtStep(stage_id, target_stage_id, iter_id)); } else if (name == "CR") { s = reader->NextArrayItem(); CHECK(s); @@ -208,14 +207,17 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { } data->push_back(::tvm::auto_scheduler::FuseStep(stage_id, fused_ids)); } else if (name == "AN") { - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&ann); - data->push_back(::tvm::auto_scheduler::AnnotationStep(stage_id, - iter_id, ::tvm::auto_scheduler::IteratorAnnotation(ann))); + data->push_back(::tvm::auto_scheduler::AnnotationStep( + stage_id, iter_id, ::tvm::auto_scheduler::IteratorAnnotation(ann))); } else { LOG(FATAL) << "Invalid step format"; } diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 3852ed642122..f8be067ad55b 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -91,25 +91,23 @@ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_ data_ = std::move(node); } -void 
ComputeAtStepNode::ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes) const { +void ComputeAtStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { te::Stage stage = (*stages)[stage_id]; - const IterVar& target_axis = - (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id]; + const IterVar& target_axis = (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id]; stage.compute_at((*stages)[target_stage_id], target_axis); stages->Set(stage_id, std::move(stage)); } -String ComputeAtStepNode::PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes) const { +String ComputeAtStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; const auto& target_stage = (*stages)[target_stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" - << CleanName(target_stage->op->name) << "], " - << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); + ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name) + << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); ss << ")\n"; ApplyToSchedule(stages, stage_to_axes); @@ -123,15 +121,15 @@ ComputeRootStep::ComputeRootStep(int stage_id) { data_ = std::move(node); } -void ComputeRootStepNode::ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes) const { +void ComputeRootStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; stage.compute_root(); stages->Set(stage_id, std::move(stage)); } -String ComputeRootStepNode::PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes) const { +String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -148,16 +146,15 @@ ComputeInlineStep::ComputeInlineStep(int 
stage_id) { data_ = std::move(node); } -void ComputeInlineStepNode::ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes) const { +void ComputeInlineStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; stage.compute_inline(); stages->Set(stage_id, std::move(stage)); } -String ComputeInlineStepNode::PrintAsPythonAPI( - Array *stages, - StageToAxesMap *stage_to_axes) const { +String ComputeInlineStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -321,8 +318,7 @@ String FuseStepNode::PrintAsPythonAPI(Array* stages, } /********** Annotation **********/ -AnnotationStep::AnnotationStep(int stage_id, int iter_id, - IteratorAnnotation ann) { +AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; @@ -330,64 +326,95 @@ AnnotationStep::AnnotationStep(int stage_id, int iter_id, data_ = std::move(node); } -void AnnotationStepNode::ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes) const { +void AnnotationStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { te::Stage stage = (*stages)[stage_id]; const Array& axes = (*stage_to_axes)[stage]; switch (annotation) { case IteratorAnnotation::kUnroll: - stage.unroll(axes[iter_id]); break; + stage.unroll(axes[iter_id]); + break; case IteratorAnnotation::kVectorize: - stage.vectorize(axes[iter_id]); break; + stage.vectorize(axes[iter_id]); + break; case IteratorAnnotation::kParallel: - stage.parallel(axes[iter_id]); break; + stage.parallel(axes[iter_id]); + break; case IteratorAnnotation::kVThread: - stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break; + stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); + break; case IteratorAnnotation::kBlockX: - stage.bind(axes[iter_id], 
te::thread_axis(Range(), "blockIdx.x")); break; + stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); + break; case IteratorAnnotation::kBlockY: - stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); break; + stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); + break; case IteratorAnnotation::kThreadX: - stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x")); break; + stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x")); + break; case IteratorAnnotation::kThreadY: - stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break; - case IteratorAnnotation::kNone: break; + stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); + break; + case IteratorAnnotation::kNone: + break; default: - LOG(FATAL) << "Invalid Annotation " << static_cast(annotation); break; + LOG(FATAL) << "Invalid Annotation " << static_cast(annotation); + break; } stages->Set(stage_id, std::move(stage)); } -String AnnotationStepNode::PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes) const { +String AnnotationStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; const auto& iter = (*stage_to_axes)[stage][iter_id]; ss << "s[" << CleanName(stage->op->name) << "]."; switch (annotation) { - case IteratorAnnotation::kUnroll: ss << "unroll("; break; - case IteratorAnnotation::kVectorize: ss << "vectorize("; break; - case IteratorAnnotation::kParallel: ss << "parallel("; break; + case IteratorAnnotation::kUnroll: + ss << "unroll("; + break; + case IteratorAnnotation::kVectorize: + ss << "vectorize("; + break; + case IteratorAnnotation::kParallel: + ss << "parallel("; + break; case IteratorAnnotation::kVThread: case IteratorAnnotation::kBlockX: case IteratorAnnotation::kBlockY: case IteratorAnnotation::kThreadX: - case IteratorAnnotation::kThreadY: ss << "bind("; break; - case IteratorAnnotation::kNone: 
break; + case IteratorAnnotation::kThreadY: + ss << "bind("; + break; + case IteratorAnnotation::kNone: + break; default: - LOG(FATAL) << "Invalid annotation " << static_cast(annotation); break; + LOG(FATAL) << "Invalid annotation " << static_cast(annotation); + break; } ss << CleanName(iter->var->name_hint); switch (annotation) { - case IteratorAnnotation::kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break; - case IteratorAnnotation::kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break; - case IteratorAnnotation::kBlockY: ss << ", tvm.thread_axis(\"blockIdy.y\")"; break; - case IteratorAnnotation::kThreadX: ss << ", tvm.thread_axis(\"threadIdx.x\")"; break; - case IteratorAnnotation::kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break; - default: break; + case IteratorAnnotation::kVThread: + ss << ", tvm.thread_axis(\"vthread\")"; + break; + case IteratorAnnotation::kBlockX: + ss << ", tvm.thread_axis(\"blockIdx.x\")"; + break; + case IteratorAnnotation::kBlockY: + ss << ", tvm.thread_axis(\"blockIdy.y\")"; + break; + case IteratorAnnotation::kThreadX: + ss << ", tvm.thread_axis(\"threadIdx.x\")"; + break; + case IteratorAnnotation::kThreadY: + ss << ", tvm.thread_axis(\"threadIdx.y\")"; + break; + default: + break; } ss << ")\n"; diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index efe630560ca5..58f1c28a8792 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -144,16 +144,27 @@ class ReorderStep : public Step { }; /*! \brief Compute at step that corresponds to te::Stage::compute_at */ -class ComputeAtStepNode: public StepNode { +class ComputeAtStepNode : public StepNode { public: + /*! \brief The index of stage that this step will compute at to. */ int target_stage_id; + /*! \brief The index of iterator in target stage that this step will compute at to. */ int target_iter_id; - void ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes) const; + /*! 
+ * \brief Apply the current state to tvm.schedule + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - String PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); @@ -165,19 +176,35 @@ class ComputeAtStepNode: public StepNode { */ class ComputeAtStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be compute at. + * \param target_stage_id The index of stage that this step will compute at to. + * \param target_iter_id The index of iterator in target stage that this step will compute at to. + */ ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); }; /*! \brief Compute root step that corresponds to te::Stage::compute_root */ -class ComputeRootStepNode: public StepNode { +class ComputeRootStepNode : public StepNode { public: - void ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Apply the current state to tvm.schedule + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return The iterator result after fuse. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - String PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Print step as equivalent python schedule API. 
+ * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); @@ -189,19 +216,33 @@ class ComputeRootStepNode: public StepNode { */ class ComputeRootStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be compute root + */ explicit ComputeRootStep(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); }; /*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ -class ComputeInlineStepNode: public StepNode { +class ComputeInlineStepNode : public StepNode { public: - void ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Apply the current state to tvm.schedule + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return The iterator result after fuse. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - String PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); @@ -213,6 +254,10 @@ class ComputeInlineStepNode: public StepNode { */ class ComputeInlineStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be compute inline. 
+ */ explicit ComputeInlineStep(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); @@ -323,16 +368,28 @@ class FuseStep : public Step { * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::vectorize, te::Stage::bind) */ -class AnnotationStepNode: public StepNode { +class AnnotationStepNode : public StepNode { public: + /*! \brief The index of the iterator to add annotation. */ int iter_id; + /*! \brief The annotation type of this step. */ IteratorAnnotation annotation; - void ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Apply the current state to tvm.schedule + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return The iterator result after fuse. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - String PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.AnnotationStep"; TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); @@ -344,6 +401,12 @@ class AnnotationStepNode: public StepNode { */ class AnnotationStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to add annotation. + * \param iter_id The index of the iterator to add annotation. + * \param ann The annotation type of this step. 
+ */ AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index 976f2802842c..7dc3618750a8 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -89,15 +89,6 @@ inline int GetIndex(const Array& array, const T& to_locate) { return -1; } -/*! \brief Delete an element in a vector */ -template -inline void DeleteItem(Array* array, const T& to_delete) { - auto iter = std::find(array->begin(), array->end(), to_delete); - if (iter != array->end()) { - array->erase(iter); - } -} - template inline void DeleteItem(std::vector* array, const T& to_delete) { auto iter = std::find(array->begin(), array->end(), to_delete); diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index fa503a20205f..6d3df2155871 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -64,9 +64,9 @@ def test_split_fuse_reorder_annotation(): s1.parallel(C, j1) s1.unroll(C, j2) s1.vectorize(C, j3) - s1.bind_thread(C, i1, "blockIdx.x") - s1.bind_thread(C, i2, "vthread") - s1.bind_thread(C, i3, "threadIdx.y") + s1.bind(C, i1, "blockIdx.x") + s1.bind(C, i2, "vthread") + s1.bind(C, i3, "threadIdx.y") def test_compute_at_root_inline(): @@ -87,6 +87,7 @@ def test_compute_at_root_inline(): s0.compute_inline(bn_mul) s0.compute_inline(bias_add) s0.compute_at(conv, relu, s0[relu].iters[2]) + print(s0) assert str(s0) == \ "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ "for i1 (0,3)\n" + \ From 308f141ebd012cd5036724748a3551e7a2215c10 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 16 Jul 2020 15:24:29 +0800 Subject: [PATCH 04/14] Update --- python/tvm/auto_scheduler/loop_state.py | 4 ++-- src/auto_scheduler/loop_state.cc | 26 ++++++++++++------------- 2 files changed, 
15 insertions(+), 15 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index b334f6674366..66924a7c8a32 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -186,7 +186,7 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): Returns ------- res_its : List[Iterator] - The splitted new Iterators + The splitted new Iterators. """ self.state_object, res = _ffi_api.StateSplit(self.state_object, self._resolve_stage_id(stage), @@ -254,7 +254,7 @@ def parallel(self, stage, iterator): return res def unroll(self, stage, iterator, max_unroll=None): - """ Schedule primitive corresponds to te.unrolled. + """ Schedule primitive corresponds to te.unroll. Parameters ---------- diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index b646f8c63142..8e08f1bfc97a 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -268,10 +268,9 @@ void State::DoReorderStep(const ReorderStep& step) { void State::DoComputeAtStep(const ComputeAtStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; - // after compute_at, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call - // ComputeDAG::ReplayAndInferBound - std::vector new_iters; + // After compute_at, we don't know the accurate length information any more + // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound + Array new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); } @@ -279,43 +278,44 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), ComputeAtKind::kIter, stage->attrs)); + // Update attach map 
pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); } void State::DoComputeRootStep(const ComputeRootStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; - // after compute_root, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call - // ComputeDAG::ReplayAndInferBound - std::vector new_iters; + // After compute_at, we don't know the accurate length information any more + // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound + Array new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); } - // update attach map StateNode* pstate = CopyOnWrite(); pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), ComputeAtKind::kRoot, stage->attrs)); + // Update attach map pstate->attach_map.DeleteStage(step->stage_id); } void State::DoComputeInlineStep(const ComputeInlineStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; - StateNode* pstate = CopyOnWrite(); - // CHECK the validity of compute_inline - const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages; for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0) + CHECK_EQ(operator->()->attach_map->iter_to_attached_stages.count( + std::make_pair(step->stage_id, i)), + 0) << "Invalid compute_inline: Because there are some other stages " "that are attached to the target stage"; } + StateNode* pstate = CopyOnWrite(); auto new_stage = pstate->stages[step->stage_id]; new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined; pstate->stages.Set(step->stage_id, std::move(new_stage)); + // Update attach map pstate->attach_map.DeleteStage(step->stage_id); } From 35eb81d93934da3d6afcc1e92205598a1bb437fc Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" 
Date: Thu, 16 Jul 2020 15:39:45 +0800 Subject: [PATCH 05/14] Update --- src/auto_scheduler/loop_state.cc | 9 ++++----- src/auto_scheduler/loop_state.h | 22 +++++++++++++++++++++- src/auto_scheduler/transform_step.cc | 11 +++-------- src/auto_scheduler/utils.h | 3 ++- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 8e08f1bfc97a..e65766b45547 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -226,10 +226,8 @@ Iterator State::parallel(int stage_id, const Iterator& it) { Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = - AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); - // don't unroll if the extent is larger than max_unroll + // Don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { if (auto imm = it->range->extent.as()) { if (imm->value > max_unroll) { @@ -238,6 +236,8 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { } } + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } @@ -493,8 +493,7 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { Iterator new_it = Iterator(it->name, it->range, it->iter_kind, step->annotation); Stage new_stage = stage; new_stage.CopyOnWrite()->iters.Set(step->iter_id, std::move(new_it)); - StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(step->stage_id, std::move(new_stage)); + CopyOnWrite()->stages.Set(step->stage_id, std::move(new_stage)); return new_it; } diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 0f95794fa886..f577bc4f1012 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ 
-270,6 +270,10 @@ class StateNode : public Object { Array stages; /*! \brief History transformation steps. */ Array transform_steps; + /*! + * \brief The attach relations of stages and iterators. This is used to track the compute at + * operation. + */ AttachMap attach_map; /*! * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all @@ -280,7 +284,6 @@ class StateNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("stages", &stages); v->Visit("transform_steps", &transform_steps); - v->Visit("attach_map", &attach_map); v->Visit("concrete", &concrete); } @@ -414,8 +417,20 @@ class State : public ObjectRef { * \param step A ReorderStep. */ void DoReorderStep(const ReorderStep& step); + /*! + * \brief Apply compute at step to current state. + * \param step A ComputeAtStep. + */ void DoComputeAtStep(const ComputeAtStep& step); + /*! + * \brief Apply compute root step to current state. + * \param step A ComputeRootStep. + */ void DoComputeRootStep(const ComputeRootStep& step); + /*! + * \brief Apply compute inline to current state. + * \param step A ComputeInline. + */ void DoComputeInlineStep(const ComputeInlineStep& step); /*! * \brief Apply split step to current state. @@ -429,6 +444,11 @@ class State : public ObjectRef { * \return The iterator result after fuse. */ Iterator DoFuseStep(const FuseStep& step); + /*! + * \brief Apply annotation step to current state. + * \param step A AnnotationStep. + * \return The iterator result after annotate. + */ Iterator DoAnnotationStep(const AnnotationStep& step); /*! 
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index f8be067ad55b..62e9c896c02a 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -105,11 +105,8 @@ String ComputeAtStepNode::PrintAsPythonAPI(Array* stages, std::stringstream ss; const auto& stage = (*stages)[stage_id]; const auto& target_stage = (*stages)[target_stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name) - << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); - - ss << ")\n"; + << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n"; ApplyToSchedule(stages, stage_to_axes); return ss.str(); } @@ -125,6 +122,7 @@ void ComputeRootStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; stage.compute_root(); + stages->Set(stage_id, std::move(stage)); } @@ -132,10 +130,8 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; ApplyToSchedule(stages, stage_to_axes); - return ss.str(); } @@ -150,6 +146,7 @@ void ComputeInlineStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; stage.compute_inline(); + stages->Set(stage_id, std::move(stage)); } @@ -157,10 +154,8 @@ String ComputeInlineStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n"; ApplyToSchedule(stages, stage_to_axes); - return ss.str(); } diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index 7dc3618750a8..0541e0817350 100644 --- a/src/auto_scheduler/utils.h +++ 
b/src/auto_scheduler/utils.h @@ -63,7 +63,7 @@ struct hash> { namespace tvm { namespace auto_scheduler { -/********** Utilities for Array, std::string **********/ +/********** Utilities for Array, std::vector, std::string **********/ /*! \brief Get the first appearance index of elements in an Array */ template inline void GetIndices(const Array& array, const Array& to_locate, Array* indices) { @@ -89,6 +89,7 @@ inline int GetIndex(const Array& array, const T& to_locate) { return -1; } +/*! \brief Delete the item in a std::vector. */ template inline void DeleteItem(std::vector* array, const T& to_delete) { auto iter = std::find(array->begin(), array->end(), to_delete); From a4589a55bc21f453cc59fced9809f379539f1bc9 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 16 Jul 2020 16:15:57 +0800 Subject: [PATCH 06/14] Update measure record UT --- .../unittest/test_auto_scheduler_measure.py | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 1bcd0540ecb4..c6f01e2b2a29 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -18,14 +18,51 @@ """ Test measurement and log serialization. 
""" import tvm -from tvm import auto_scheduler +import topi +from tvm import te, auto_scheduler import tempfile from test_auto_scheduler_common import get_tiled_matmul def test_record(): - dag, s = get_tiled_matmul() + A = te.placeholder((512, 512), name='A') + B = te.placeholder((512, 512), name='B') + k = te.reduce_axis((0, 512), name='k') + C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') + D = topi.nn.relu(C) + k = te.reduce_axis((0, 512), name='k') + E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='C') + F = topi.nn.relu(E) + + dag = auto_scheduler.ComputeDAG([A, B, F]) + s = dag.get_init_state() + + # Split + its0 = s.split(C, s[C].iters[0], [4, 8, 8]) + its1 = s.split(C, s[C].iters[4], [8, 4, 4]) + # Reorder + s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], + its1[3]]) + # Fuse + s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]]) + # Compute at + s.split(F, s[F].iters[0], [2]) + s.compute_at(E, F, s[F].iters[0]) + # Compute inline + s.compute_inline(D) + # Compute root + s.compute_root(D) + # Parallel + s.parallel(C, s[C].iters[0]) + # Thread bind + s.bind(C, s[C].iters[1], "blockIdx.x") + s.bind(C, s[C].iters[2], "threadIdx.y") + s.bind(C, s[C].iters[3], "vthread") + # Unroll + s.unroll(C, s[C].iters[4]) + # Vectorize + s.vectorize(C, s[C].iters[6]) if not tvm.runtime.enabled("llvm"): return From 8b051332a7cfe619e952e19f456074133d2cb767 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 17 Jul 2020 16:15:56 +0800 Subject: [PATCH 07/14] Update --- python/tvm/auto_scheduler/loop_state.py | 51 ++++++-- src/auto_scheduler/loop_state.cc | 68 +++++------ src/auto_scheduler/loop_state.h | 27 ++++- src/auto_scheduler/transform_step.cc | 50 ++++---- src/auto_scheduler/transform_step.h | 8 +- src/auto_scheduler/utils.h | 4 +- .../test_auto_scheduler_loop_state.py | 113 ++++++++++-------- .../unittest/test_auto_scheduler_measure.py | 5 +- 8 
files changed, 199 insertions(+), 127 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 66924a7c8a32..c6abe3b77dab 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -83,6 +83,24 @@ class State: ----- This is a wrapper class of StateObject to deal with copy-on-write property """ + + # Static trans table for thread bind + # This is used to transform the annotation name to C++ enum + ANNOTATION_TRANS_TABLE = { + "none": 0, + "unroll": 1, + "vectorize": 2, + "parallel": 3, + "vthread": 4, + "blockIdx.x": 5, + "threadIdx.x": 6, + "blockIdx.y": 7, + "threadIdx.y": 8, + "blockIdx.z": 9, + "threadIdx.z": 10, + "tensorize": 11 + } + def __init__(self, state_object, dag): self.state_object = state_object self.compute_dag = dag @@ -135,6 +153,11 @@ def compute_at(self, stage, target_stage, target_iter): output tensor. target_iter : Iterator The target Iterator of compute_at. + + Notes + ----- + After compute_at, the extent of each iterator may not be accurate any more, so the bound + information will be removed from this state. Run ComputeDAG::InferBound to recover. """ self.state_object = _ffi_api.StateComputeAt(self.state_object, self._resolve_stage_id(stage), @@ -149,6 +172,11 @@ def compute_root(self, stage): stage : Union[int, Operation, Tensor] The Stage to be compute root, can be a Stage order index, Stage operation or stage output tensor. + + Notes + ----- + After compute_root, the extent of each iterator may not be accurate any more, so the bound + information will be removed from this state. Run ComputeDAG::InferBound to recover. """ self.state_object = _ffi_api.StateComputeRoot(self.state_object, self._resolve_stage_id(stage)) @@ -187,6 +215,11 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): ------- res_its : List[Iterator] The splitted new Iterators. 
+ + Notes + ----- + If we do split on an iterator which has stages attached at it(by compute_at), the inner + most iterator of split results will become the new attach point. """ self.state_object, res = _ffi_api.StateSplit(self.state_object, self._resolve_stage_id(stage), @@ -208,6 +241,11 @@ def fuse(self, stage, iters): ------- res_it : Iterator The fused Iterator. + + Notes + ----- + If the iterators to be fused have stages attached at them(by compute_at), the fused + result will become the new attach point. """ self.state_object, res = _ffi_api.StateFuse(self.state_object, self._resolve_stage_id(stage), iters) @@ -293,25 +331,20 @@ def bind(self, stage, iterator, thread_name): - threadIdx.x - blockIdx.y - threadIdx.y + - blockIdx.z + - threadIdx.z Returns ------- res_it : Iterator The binded Iterator. """ - trans_table = { - "vthread": 4, - "blockIdx.x": 5, - "threadIdx.x": 6, - "blockIdx.y": 7, - "threadIdx.y": 8, - } - if not thread_name in trans_table.keys(): + if not thread_name in State.ANNOTATION_TRANS_TABLE.keys(): raise ValueError("Invalid thread_name: ", thread_name) self.state_object, res = _ffi_api.StateBind(self.state_object, self._resolve_stage_id(stage), iterator, - trans_table[thread_name]) + State.ANNOTATION_TRANS_TABLE[thread_name]) return res def copy(self): diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index e65766b45547..0dbece221aa4 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -97,7 +97,7 @@ void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_i // Delete the current entry of this stage DeleteStageEntry(pnode, stage_id); - // Store the new relations to map + // Store the new stage/iterator relations to map IterKey iter_key(target_stage_id, target_iter_id); pnode->stage_to_attach_iter[stage_id] = iter_key; pnode->iter_to_attached_stages[iter_key].push_back(stage_id); @@ -109,23 +109,25 @@ void AttachMap::DeleteStage(int stage_id) { 
DeleteStageEntry(pnode, stage_id); } -void AttachMap::UpdateIters(const std::vector& old_iters, +void AttachMap::UpdateIters(const std::vector& original_iters, const std::vector& new_iters) { + CHECK_EQ(original_iters.size(), new_iters.size()); AttachMapNode* pnode = CopyOnWrite(); - - CHECK_EQ(old_iters.size(), new_iters.size()); - for (size_t i = 0; i < old_iters.size(); ++i) { - auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); + for (size_t i = 0; i < original_iters.size(); ++i) { + auto entry = pnode->iter_to_attached_stages.find(original_iters[i]); + // We get > from this map if (entry == pnode->iter_to_attached_stages.end()) { + // Skip if this iterator does not have any attach relations continue; } - // Replace iter in the value of `stage_to_attach_iter` + // Update the attaching target of a stage to the new iter in `stage_to_attach_iter` for (const auto& s : entry->second) { pnode->stage_to_attach_iter[s] = new_iters[i]; } - // Replace iter in the key of `iter_to_attached_stages` + // Remove the original iterator relation from `iter_to_attached_stages` and add the new + // iterator to it std::vector attached_stages = std::move(entry->second); pnode->iter_to_attached_stages.erase(entry); pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); @@ -134,14 +136,17 @@ void AttachMap::UpdateIters(const std::vector& old_iters, void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { auto old_entry = pnode->stage_to_attach_iter.find(stage_id); + // We get from this map if (old_entry != pnode->stage_to_attach_iter.end()) { - // Delete value in `iter_to_attached_stages` + // Delete the stage in `iter_to_attached_stages`, if the corresponding iterator does not have + // any attached stage, delete this item too auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); - DeleteItem(&entry2->second, stage_id); + // We get > from this map + FindAndDeleteItem(&entry2->second, stage_id); if
(entry2->second.size() == 0) { pnode->iter_to_attached_stages.erase(entry2); } - // Delete key in `stage_to_attach_iter` + // Delete the stage in `stage_to_attach_iter` pnode->stage_to_attach_iter.erase(old_entry); } } @@ -244,9 +249,9 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; - if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadY) { + if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) { LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kBlockY, " - << "kThreadX, kThreadY"; + << "kThreadX, kThreadY, kBlockZ, kThreadZ"; } AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); @@ -268,8 +273,8 @@ void State::DoReorderStep(const ReorderStep& step) { void State::DoComputeAtStep(const ComputeAtStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; - // After compute_at, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound + // Remove the bound information of each iterator since they may not be accurate after + // compute at Array new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); @@ -285,8 +290,8 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { void State::DoComputeRootStep(const ComputeRootStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; - // After compute_at, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound + // Remove the bound information of each iterator since they may not be accurate after + // compute root 
Array new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); @@ -302,13 +307,13 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { void State::DoComputeInlineStep(const ComputeInlineStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; - // CHECK the validity of compute_inline + // Check the validity of compute_inline for (size_t i = 0; i < stage->iters.size(); ++i) { CHECK_EQ(operator->()->attach_map->iter_to_attached_stages.count( std::make_pair(step->stage_id, i)), 0) - << "Invalid compute_inline: Because there are some other stages " - "that are attached to the target stage"; + << "Invalid compute_inline: There are some other stages that are attached to the " + << "target stage"; } StateNode* pstate = CopyOnWrite(); @@ -387,8 +392,8 @@ Array State::DoSplitStepCommon(int stage_id, int iter_id, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); pstate->concrete &= concrete; - // We have to update the iterator relations in attach map, these two vectors keep the replacement - // mapping + // Two vectors are used to represent the iterator relation before and after split + // The original iterators in AttachMap will be updated with the new iterators std::vector from_iters; std::vector to_iters; for (size_t i = iter_id; i < old_iter_size; ++i) { @@ -462,8 +467,8 @@ Iterator State::DoFuseStep(const FuseStep& step) { pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); - // We have to update the iterator relations in attach map, these two vectors keep the replacement - // mapping + // Two vectors are used to represent the iterator relation before and after fuse + // The original iterators in AttachMap will be updated with the new iterators std::vector from_iters; std::vector to_iters; const size_t begin_id = step->fused_ids.front(), end_id = step->fused_ids.back(); @@ -492,7 
+497,7 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { CHECK(it->annotation == IteratorAnnotation::kNone); Iterator new_it = Iterator(it->name, it->range, it->iter_kind, step->annotation); Stage new_stage = stage; - new_stage.CopyOnWrite()->iters.Set(step->iter_id, std::move(new_it)); + new_stage.CopyOnWrite()->iters.Set(step->iter_id, new_it); CopyOnWrite()->stages.Set(step->stage_id, std::move(new_stage)); return new_it; } @@ -521,19 +526,6 @@ void State::DoSteps(const ComputeDAG& dag) { } } -static const char* IteratorAnnotationString[] = { - "for", // kNone = 0 - "unroll", // kUnroll = 1 - "vectorize", // kVectorize = 2 - "parallel", // kParallel = 3 - "vthread", // kVThread = 4 - "gpu.blockIdx.x", // kBlockX = 5 - "gpu.threadIdx.x", // kThreadX = 6 - "gpu.blockIdx.y", // kBlockY = 7 - "gpu.threadIdx.y", // kThreadY = 8 - "tensorize" // kTensorized = 9 -}; - // Print stage to ostream void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent, bool delete_trivial_loop) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index f577bc4f1012..2179c511d7a1 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -112,6 +112,8 @@ class IteratorNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("range", &range); + v->Visit("iter_kind", &iter_kind); + v->Visit("annotation", &annotation); } static constexpr const char* _type_key = "auto_scheduler.Iterator"; @@ -239,11 +241,13 @@ class AttachMap : public ObjectRef { */ void DeleteStage(int stage_id); /*! - * \brief Update the iterator relations in AttachMap. - * \param old_iters The original IterKey. + * \brief Find the relations of original iterators in AttachMap, and update them with the new + * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated. + * \param original_iters The original IterKey. * \param new_iters The new IterKey to update. 
*/ - void UpdateIters(const std::vector& old_iters, const std::vector& new_iters); + void UpdateIters(const std::vector& original_iters, + const std::vector& new_iters); TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); @@ -342,11 +346,15 @@ class State : public ObjectRef { * \param stage_id The index of the stage to be reordered. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter The iterator in target stage that this step will compute at to. + * \note After compute_at, the extent of each iterator may not be accurate any more, so the + * bound information will be removed from this state. Run ComputeDAG::InferBound to recover. */ void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! * \brief Schedule primitive corresponds to te.compute_root. * \param stage_id The index of the stage to be reordered. + * \note After compute_root, the extent of each iterator may not be accurate any more, so the + * bound information will be removed from this state. Run ComputeDAG::InferBound to recover. */ void compute_root(int stage_id); /*! @@ -361,6 +369,8 @@ class State : public ObjectRef { * \param lengths The multiple split factors. Can be None to be filled by search policy. * \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner. * \return The iterator results after split. + * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner + * most iterator of split results will become the new attach point. */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); @@ -369,6 +379,8 @@ class State : public ObjectRef { * \param stage_id The index of the stage to be fused. * \param iters The iterators to be fused. * \return The iterator result after fuse. 
+ * \note If the iterators to be fused have stages attached at them(by compute_at), the fused + * result will become the new attach point. */ Iterator fuse(int stage_id, const Array& iters); /*! @@ -408,8 +420,9 @@ class State : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); private: - /* Do transform steps - * Note: The following functions only change loop state but do not change transform_history. + /*! + * \brief Do transform steps. + * \note The following functions only change loop state but do not change transform_history. * We separate these functions out, so you can call them for replay easily given history steps */ /*! @@ -420,11 +433,15 @@ class State : public ObjectRef { /*! * \brief Apply compute at step to current state. * \param step A ComputeAtStep. + * \note After compute_at, the extent of each iterator may not be accurate any more, so the + * bound information will be removed from this state. Run ComputeDAG::InferBound to recover. */ void DoComputeAtStep(const ComputeAtStep& step); /*! * \brief Apply compute root step to current state. * \param step A ComputeRootStep. + * \note After compute_root, the extent of each iterator may not be accurate any more, so the + * bound information will be removed from this state. Run ComputeDAG::InferBound to recover. */ void DoComputeRootStep(const ComputeRootStep& step); /*! 
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 62e9c896c02a..9bb8bffc144f 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -36,6 +36,21 @@ namespace tvm { namespace auto_scheduler { +const char* IteratorAnnotationString[] = { + "for", // kNone = 0 + "unroll", // kUnroll = 1 + "vectorize", // kVectorize = 2 + "parallel", // kParallel = 3 + "vthread", // kVThread = 4 + "blockIdx.x", // kBlockX = 5 + "threadIdx.x", // kThreadX = 6 + "blockIdx.y", // kBlockY = 7 + "threadIdx.y", // kThreadY = 8 + "blockIdx.z", // kBlockZ = 9 + "threadIdx.z", // kThreadZ = 10 + "tensorize" // kTensorized = 11 +}; + /********** Reorder **********/ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { auto node = make_object(); @@ -94,8 +109,9 @@ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_ void ComputeAtStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { te::Stage stage = (*stages)[stage_id]; - const IterVar& target_axis = (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id]; - stage.compute_at((*stages)[target_stage_id], target_axis); + const auto& target_stage = (*stages)[target_stage_id]; + const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id]; + stage.compute_at(target_stage, target_axis); stages->Set(stage_id, std::move(stage)); } @@ -122,7 +138,6 @@ void ComputeRootStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; stage.compute_root(); - stages->Set(stage_id, std::move(stage)); } @@ -146,7 +161,6 @@ void ComputeInlineStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; stage.compute_inline(); - stages->Set(stage_id, std::move(stage)); } @@ -337,19 +351,14 @@ void AnnotationStepNode::ApplyToSchedule(Array* stages, stage.parallel(axes[iter_id]); break; case 
IteratorAnnotation::kVThread: - stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); - break; case IteratorAnnotation::kBlockX: - stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); - break; case IteratorAnnotation::kBlockY: - stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); - break; + case IteratorAnnotation::kBlockZ: case IteratorAnnotation::kThreadX: - stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x")); - break; case IteratorAnnotation::kThreadY: - stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); + case IteratorAnnotation::kThreadZ: + stage.bind(axes[iter_id], + te::thread_axis(Range(), IteratorAnnotationString[static_cast(annotation)])); break; case IteratorAnnotation::kNone: break; @@ -381,8 +390,10 @@ String AnnotationStepNode::PrintAsPythonAPI(Array* stages, case IteratorAnnotation::kVThread: case IteratorAnnotation::kBlockX: case IteratorAnnotation::kBlockY: + case IteratorAnnotation::kBlockZ: case IteratorAnnotation::kThreadX: case IteratorAnnotation::kThreadY: + case IteratorAnnotation::kThreadZ: ss << "bind("; break; case IteratorAnnotation::kNone: @@ -394,19 +405,14 @@ String AnnotationStepNode::PrintAsPythonAPI(Array* stages, ss << CleanName(iter->var->name_hint); switch (annotation) { case IteratorAnnotation::kVThread: - ss << ", tvm.thread_axis(\"vthread\")"; - break; case IteratorAnnotation::kBlockX: - ss << ", tvm.thread_axis(\"blockIdx.x\")"; - break; case IteratorAnnotation::kBlockY: - ss << ", tvm.thread_axis(\"blockIdy.y\")"; - break; + case IteratorAnnotation::kBlockZ: case IteratorAnnotation::kThreadX: - ss << ", tvm.thread_axis(\"threadIdx.x\")"; - break; case IteratorAnnotation::kThreadY: - ss << ", tvm.thread_axis(\"threadIdx.y\")"; + case IteratorAnnotation::kThreadZ: + ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast(annotation)] + << "\")"; break; default: break; diff --git a/src/auto_scheduler/transform_step.h 
b/src/auto_scheduler/transform_step.h index 58f1c28a8792..51cac332ec45 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -73,10 +73,16 @@ enum class IteratorAnnotation : int { kBlockY = 7, /*! \brief This iterator has been bind to threadIdx.y. */ kThreadY = 8, + /*! \brief This iterator has been bind to blockIdx.z. */ + kBlockZ = 9, + /*! \brief This iterator has been bind to threadIdx.z. */ + kThreadZ = 10, /*! \brief This iterator has been mapped with a tensorize intrinsic. */ - kTensorized = 9 + kTensorize = 11 }; +extern const char* IteratorAnnotationString[]; /*! * \brief The base class of transformation steps. Each step has its corresponding tvm.te * schedule primitives. diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index 0541e0817350..ccac3bbb75a2 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -89,9 +89,9 @@ inline int GetIndex(const Array& array, const T& to_locate) { return -1; } -/*! \brief Delete the item in a std::vector. */ +/*! \brief Delete the item in a std::vector if it exists.
*/ template -inline void DeleteItem(std::vector* array, const T& to_delete) { +inline void FindAndDeleteItem(std::vector* array, const T& to_delete) { auto iter = std::find(array->begin(), array->end(), to_delete); if (iter != array->end()) { array->erase(iter); diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 6d3df2155871..2c1528e16735 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -27,7 +27,7 @@ def test_split_fuse_reorder_annotation(): - A, B, C = matmul_auto_scheduler_test(512, 512, 512) + A, B, C = matmul_auto_scheduler_test(N=512, M=512, K=512) dag = auto_scheduler.ComputeDAG([A, B, C]) s0 = dag.get_init_state() i, j, k = s0[C].iters @@ -61,16 +61,34 @@ def test_split_fuse_reorder_annotation(): assert s1[C].iters[4].range.extent == 8 assert s1[C].iters[5].range.extent == 2 - s1.parallel(C, j1) - s1.unroll(C, j2) - s1.vectorize(C, j3) - s1.bind(C, i1, "blockIdx.x") - s1.bind(C, i2, "vthread") - s1.bind(C, i3, "threadIdx.y") + res = s1.bind(C, i1, "blockIdx.x") + assert res == s1[C].iters[0] + assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["blockIdx.x"] + + res = s1.bind(C, i2, "vthread") + assert res == s1[C].iters[1] + assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vthread"] + + res = s1.bind(C, i3, "threadIdx.y") + assert res == s1[C].iters[2] + assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["threadIdx.y"] + + res = s1.parallel(C, j1) + assert res == s1[C].iters[3] + assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["parallel"] + + res = s1.unroll(C, j2) + assert res == s1[C].iters[4] + assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["unroll"] + + res = s1.vectorize(C, j3) + assert res == s1[C].iters[5] + assert res.annotation == 
auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"] def test_compute_at_root_inline(): - dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64, + kernel_size=7, strides=2, padding=3)) s0 = dag.get_init_state() # data, padding, kernel = 0, 1, 2 @@ -87,51 +105,50 @@ def test_compute_at_root_inline(): s0.compute_inline(bn_mul) s0.compute_inline(bias_add) s0.compute_at(conv, relu, s0[relu].iters[2]) - print(s0) assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" + """Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n""" + \ + """for i1 (0,3)\n""" + \ + """ for i2 (0,230)\n""" + \ + """ for i3 (0,230)\n""" + \ + """ pad_temp = ...\n""" + \ + """for i1 (0,64)\n""" + \ + """ for i2 (0,112)\n""" + \ + """ for nn (None)\n""" + \ + """ for ff (None)\n""" + \ + """ for yy (None)\n""" + \ + """ for xx (None)\n""" + \ + """ for rc (None)\n""" + \ + """ for ry (None)\n""" + \ + """ for rx (None)\n""" + \ + """ compute = ...\n""" + \ + """ for i3 (0,112)\n""" + \ + """ compute = ...\n""" s0.compute_root(conv) s0.compute_root(bn_mul) assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - "for i (None)\n" + 
\ - " for j (None)\n" + \ - " for k (None)\n" + \ - " for l (None)\n" + \ - " Bn_mul = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" + """Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n""" + \ + """for i1 (0,3)\n""" + \ + """ for i2 (0,230)\n""" + \ + """ for i3 (0,230)\n""" + \ + """ pad_temp = ...\n""" + \ + """for nn (None)\n""" + \ + """ for ff (None)\n""" + \ + """ for yy (None)\n""" + \ + """ for xx (None)\n""" + \ + """ for rc (None)\n""" + \ + """ for ry (None)\n""" + \ + """ for rx (None)\n""" + \ + """ compute = ...\n""" + \ + """for i (None)\n""" + \ + """ for j (None)\n""" + \ + """ for k (None)\n""" + \ + """ for l (None)\n""" + \ + """ Bn_mul = ...\n""" + \ + """for i1 (0,64)\n""" + \ + """ for i2 (0,112)\n""" + \ + """ for i3 (0,112)\n""" + \ + """ compute = ...\n""" if __name__ == "__main__": diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index c6f01e2b2a29..1403ae088bb0 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -26,6 +26,9 @@ def test_record(): + if not tvm.runtime.enabled("llvm"): + return + A = te.placeholder((512, 512), name='A') B = te.placeholder((512, 512), name='B') k = te.reduce_axis((0, 512), name='k') @@ -64,8 +67,6 @@ def test_record(): # Vectorize s.vectorize(C, s[C].iters[6]) - if not tvm.runtime.enabled("llvm"): - return target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) From 8fb87ae573bb84513a59734a915b941b8e284748 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 17 Jul 2020 17:30:41 +0800 Subject: [PATCH 08/14] Update --- src/auto_scheduler/loop_state.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 2179c511d7a1..e484e8b947eb 100644 --- a/src/auto_scheduler/loop_state.h +++ 
b/src/auto_scheduler/loop_state.h @@ -453,12 +453,16 @@ class State : public ObjectRef { * \brief Apply split step to current state. * \param step A SplitStep. * \return The iterator results after split. + * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner + * most iterator of split results will become the new attach point. */ Array DoSplitStep(const SplitStep& step); /*! * \brief Apply fuse step to current state. * \param step A FuseStep. * \return The iterator result after fuse. + * \note If the iterators to be fused have stages attached at them(by compute_at), the fused + * result will become the new attach point. */ Iterator DoFuseStep(const FuseStep& step); /*! From cc0ad1d434c92e271a6695669a3476d46a9cef2c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 18 Jul 2020 17:50:53 +0800 Subject: [PATCH 09/14] Update --- python/tvm/auto_scheduler/loop_state.py | 58 +++++++------- src/auto_scheduler/loop_state.cc | 2 +- src/auto_scheduler/loop_state.h | 30 +++++--- .../test_auto_scheduler_loop_state.py | 76 +++++++++---------- .../unittest/test_auto_scheduler_measure.py | 2 +- 5 files changed, 87 insertions(+), 81 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index c6abe3b77dab..238851ad6d42 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -132,8 +132,8 @@ def reorder(self, stage, order): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be reordered, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be reordered, which can be specified by the integer index, Operation, + or output tensor of the stage. order : List[Iterator] Iterators in the expected order. 
""" @@ -146,18 +146,20 @@ def compute_at(self, stage, target_stage, target_iter): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute at, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be compute at, which can be specified by the integer index, Operation, + or output tensor of the stage. target_stage : Union[int, Operation, Tensor] - The target stage of compute_at, can be a Stage order index, Stage operation or stage - output tensor. + The target stage of compute_at, which can be specified by the integer index, Operation, + or output tensor of the stage. target_iter : Iterator The target Iterator of compute_at. Notes ----- - After compute_at, the extent of each iterator may not be accurate any more, so the bound - information will be removed from this state. Run ComputeDAG::InferBound to recover. + After compute_at, we need careful dependency analysis to compute the accurate bound + information. However, it is relatively expensive and complicated, so we just fill "None" + as bound for the newly created iterators. + Call ComputeDAG::InferBound on the returned state to get the complete bound information. """ self.state_object = _ffi_api.StateComputeAt(self.state_object, self._resolve_stage_id(stage), @@ -170,13 +172,15 @@ def compute_root(self, stage): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute root, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be compute root, which can be specified by the integer index, Operation, + or output tensor of the stage. Notes ----- - After compute_root, the extent of each iterator may not be accurate any more, so the bound - information will be removed from this state. Run ComputeDAG::InferBound to recover. + After compute_root, we need careful dependency analysis to compute the accurate bound + information. 
However, it is relatively expensive and complicated, so we just fill "None" + as bound for the newly created iterators. + Call ComputeDAG::InferBound on the returned state to get the complete bound information. """ self.state_object = _ffi_api.StateComputeRoot(self.state_object, self._resolve_stage_id(stage)) @@ -187,8 +191,8 @@ def compute_inline(self, stage): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute inline, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be compute inlined, which can be specified by the integer index, Operation, + or output tensor of the stage. """ self.state_object = _ffi_api.StateComputeInline(self.state_object, self._resolve_stage_id(stage)) @@ -202,8 +206,8 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be split, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be split, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator The iterator to be split. lengths: List[int] @@ -232,8 +236,8 @@ def fuse(self, stage, iters): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be fused, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be fused, which can be specified by the integer index, Operation, + or output tensor of the stage. iters : List[Iterator] The iterators to be fused. @@ -257,8 +261,8 @@ def vectorize(self, stage, iterator): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be vectorized, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be vectorized, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator The iterator to be vectorized. 
@@ -277,8 +281,8 @@ def parallel(self, stage, iterator): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be paralleled, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be paralleled, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator The iterator to be paralleled. @@ -297,8 +301,8 @@ def unroll(self, stage, iterator, max_unroll=None): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be unrolled, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be unrolled, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator The iterator to be unrolled. max_unroll : Optional[int] @@ -320,12 +324,12 @@ def bind(self, stage, iterator, thread_name): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be binded, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be binded, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator The iterator to be binded. thread_name : str - The thread type to be binded. Currently support: + The thread type to be binded. 
Candidates: - vthread - blockIdx.x - threadIdx.x diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 0dbece221aa4..1d4ac7091bfd 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -250,7 +250,7 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) { - LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kBlockY, " + LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, " << "kThreadX, kThreadY, kBlockZ, kThreadZ"; } AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index e484e8b947eb..33c27282bcfd 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -154,10 +154,10 @@ class StageNode : public Object { public: /*! \brief The operator of this stage */ te::Operation op; - /*! \brief The type of this stage. */ - StageKind op_type; /*! \brief The iterators in this stage. */ Array iters; + /*! \brief The type of this stage. */ + StageKind op_type; /*! \brief The compute location of this stage. */ ComputeAtKind compute_at; /*! \brief Other stage-level attributes. */ @@ -166,6 +166,8 @@ class StageNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); v->Visit("iters", &iters); + v->Visit("op_type", &op_type); + v->Visit("compute_at", &compute_at); } static constexpr const char* _type_key = "auto_scheduler.Stage"; @@ -346,15 +348,19 @@ class State : public ObjectRef { * \param stage_id The index of the stage to be reordered. * \param target_stage_id The index of stage that this step will compute at to. 
* \param target_iter The iterator in target stage that this step will compute at to. - * \note After compute_at, the extent of each iterator may not be accurate any more, so the - * bound information will be removed from this state. Run ComputeDAG::InferBound to recover. + * \note After compute_at, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! * \brief Schedule primitive corresponds to te.compute_root. * \param stage_id The index of the stage to be reordered. - * \note After compute_root, the extent of each iterator may not be accurate any more, so the - * bound information will be removed from this state. Run ComputeDAG::InferBound to recover. + * \note After compute_root, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ void compute_root(int stage_id); /*! @@ -433,15 +439,19 @@ class State : public ObjectRef { /*! * \brief Apply compute at step to current state. * \param step A ComputeAtStep. - * \note After compute_at, the extent of each iterator may not be accurate any more, so the - * bound information will be removed from this state. Run ComputeDAG::InferBound to recover. + * \note After compute_at, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. 
+ * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ void DoComputeAtStep(const ComputeAtStep& step); /*! * \brief Apply compute root step to current state. * \param step A ComputeRootStep. - * \note After compute_root, the extent of each iterator may not be accurate any more, so the - * bound information will be removed from this state. Run ComputeDAG::InferBound to recover. + * \note After compute_root, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ void DoComputeRootStep(const ComputeRootStep& step); /*! diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 2c1528e16735..32ea8faa84d0 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -102,53 +102,45 @@ def test_compute_at_root_inline(): relu = s0.stage_ops[10] s0.compute_inline(bn_add) + assert s0[bn_add].compute_at == 1 + s0.compute_inline(bn_mul) + assert s0[bn_mul].compute_at == 1 + s0.compute_inline(bias_add) + assert s0[bias_add].compute_at == 1 + + assert s0[conv].iters[0].range.extent == 1 + assert s0[conv].iters[1].range.extent == 64 + assert s0[conv].iters[2].range.extent == 112 + assert s0[conv].iters[3].range.extent == 112 + assert s0[conv].iters[4].range.extent == 3 + assert s0[conv].iters[5].range.extent == 7 + assert s0[conv].iters[6].range.extent == 7 s0.compute_at(conv, relu, s0[relu].iters[2]) - assert str(s0) == \ - """Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n""" + \ - """for i1 (0,3)\n""" + \ - """ for i2 (0,230)\n""" + \ - """ for i3 (0,230)\n""" + \ - """ pad_temp = ...\n""" + \ - """for i1 (0,64)\n""" + \ - """ for i2 
(0,112)\n""" + \ - """ for nn (None)\n""" + \ - """ for ff (None)\n""" + \ - """ for yy (None)\n""" + \ - """ for xx (None)\n""" + \ - """ for rc (None)\n""" + \ - """ for ry (None)\n""" + \ - """ for rx (None)\n""" + \ - """ compute = ...\n""" + \ - """ for i3 (0,112)\n""" + \ - """ compute = ...\n""" + assert s0[conv].compute_at == 2 + s0 = dag.infer_bound_from_state(s0) + assert s0[conv].iters[0].range.extent == 1 + assert s0[conv].iters[1].range.extent == 1 + assert s0[conv].iters[2].range.extent == 1 + assert s0[conv].iters[3].range.extent == 112 + assert s0[conv].iters[4].range.extent == 3 + assert s0[conv].iters[5].range.extent == 7 + assert s0[conv].iters[6].range.extent == 7 - s0.compute_root(conv) s0.compute_root(bn_mul) - assert str(s0) == \ - """Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n""" + \ - """for i1 (0,3)\n""" + \ - """ for i2 (0,230)\n""" + \ - """ for i3 (0,230)\n""" + \ - """ pad_temp = ...\n""" + \ - """for nn (None)\n""" + \ - """ for ff (None)\n""" + \ - """ for yy (None)\n""" + \ - """ for xx (None)\n""" + \ - """ for rc (None)\n""" + \ - """ for ry (None)\n""" + \ - """ for rx (None)\n""" + \ - """ compute = ...\n""" + \ - """for i (None)\n""" + \ - """ for j (None)\n""" + \ - """ for k (None)\n""" + \ - """ for l (None)\n""" + \ - """ Bn_mul = ...\n""" + \ - """for i1 (0,64)\n""" + \ - """ for i2 (0,112)\n""" + \ - """ for i3 (0,112)\n""" + \ - """ compute = ...\n""" + assert s0[bn_mul].compute_at == 0 + + s0.compute_root(conv) + assert s0[conv].compute_at == 0 + s0 = dag.infer_bound_from_state(s0) + assert s0[conv].iters[0].range.extent == 1 + assert s0[conv].iters[1].range.extent == 64 + assert s0[conv].iters[2].range.extent == 112 + assert s0[conv].iters[3].range.extent == 112 + assert s0[conv].iters[4].range.extent == 3 + assert s0[conv].iters[5].range.extent == 7 + assert s0[conv].iters[6].range.extent == 7 if __name__ == "__main__": diff --git a/tests/python/unittest/test_auto_scheduler_measure.py 
b/tests/python/unittest/test_auto_scheduler_measure.py index 1403ae088bb0..23b738a8478f 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -58,7 +58,7 @@ def test_record(): s.compute_root(D) # Parallel s.parallel(C, s[C].iters[0]) - # Thread bind + # Thread bind(The blockIdx & threadIdx are used in GPU, just for record testing here) s.bind(C, s[C].iters[1], "blockIdx.x") s.bind(C, s[C].iters[2], "threadIdx.y") s.bind(C, s[C].iters[3], "vthread") From 13809c2e1cff494b6757bc1e6923cd9a0a759c5b Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 18 Jul 2020 18:52:46 +0800 Subject: [PATCH 10/14] Move state implementation to step --- src/auto_scheduler/compute_dag.cc | 18 +- src/auto_scheduler/loop_state.cc | 280 ++------------------------- src/auto_scheduler/loop_state.h | 136 +------------ src/auto_scheduler/transform_step.cc | 252 +++++++++++++++++++++++- src/auto_scheduler/transform_step.h | 145 ++++++++++++-- 5 files changed, 409 insertions(+), 422 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 3db3cba7b534..e675c24d1238 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -332,22 +332,22 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; } } - // Call each step's PrintAsPythonAPI method + // Call each step's ApplyToPythonAPI method for (const auto& step : transform_steps) { if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); } else if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); } else if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); } else if (auto ps = 
step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); } else if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); } else if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); } else if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); } else { LOG(FATAL) << "Invalid Step"; } @@ -368,7 +368,7 @@ State ComputeDAG::InferBound(const State& state) const { ret_state = operator->()->init_state; pstate = ret_state.CopyOnWrite(); pstate->transform_steps = state->transform_steps; - ret_state.DoSteps(*this); + ret_state.ApplySteps(*this); } else { ret_state = state; pstate = ret_state.CopyOnWrite(); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 1d4ac7091bfd..08512e9b0cd8 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -171,7 +171,7 @@ void State::reorder(int stage_id, const Array& order) { GetIndices(stage->iters, order, &after_ids); ReorderStep step = ReorderStep(stage_id, after_ids); CopyOnWrite()->transform_steps.push_back(step); - DoReorderStep(step); + step->ApplyToState(this); } void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { @@ -179,19 +179,19 @@ void State::compute_at(int stage_id, int target_stage_id, const Iterator& target ComputeAtStep step = ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); CopyOnWrite()->transform_steps.push_back(step); - return DoComputeAtStep(step); + step->ApplyToState(this); } void State::compute_root(int stage_id) { ComputeRootStep step = ComputeRootStep(stage_id); CopyOnWrite()->transform_steps.push_back(step); - return DoComputeRootStep(step); + step->ApplyToState(this); 
} void State::compute_inline(int stage_id) { ComputeInlineStep step = ComputeInlineStep(stage_id); CopyOnWrite()->transform_steps.push_back(step); - return DoComputeInlineStep(step); + step->ApplyToState(this); } Array State::split(int stage_id, const Iterator& it, @@ -201,7 +201,7 @@ Array State::split(int stage_id, const Iterator& it, SplitStep(stage_id, GetIndex(stage->iters, it), it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); - return DoSplitStep(step); + return step->ApplyToState(this); } Iterator State::fuse(int stage_id, const Array& iters) { @@ -210,7 +210,7 @@ Iterator State::fuse(int stage_id, const Array& iters) { GetIndices(stage->iters, iters, &indices); FuseStep step = FuseStep(stage_id, indices); CopyOnWrite()->transform_steps.push_back(step); - return DoFuseStep(step); + return step->ApplyToState(this); } Iterator State::vectorize(int stage_id, const Iterator& it) { @@ -218,7 +218,7 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); CopyOnWrite()->transform_steps.push_back(step); - return DoAnnotationStep(step); + return step->ApplyToState(this); } Iterator State::parallel(int stage_id, const Iterator& it) { @@ -226,7 +226,7 @@ Iterator State::parallel(int stage_id, const Iterator& it) { AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel); CopyOnWrite()->transform_steps.push_back(step); - return DoAnnotationStep(step); + return step->ApplyToState(this); } Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { @@ -244,7 +244,7 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); CopyOnWrite()->transform_steps.push_back(step); - return 
DoAnnotationStep(step); + return step->ApplyToState(this); } Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { @@ -255,271 +255,27 @@ Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread } AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); - return DoAnnotationStep(step); + return step->ApplyToState(this); } -/********** Step implementations for state **********/ -void State::DoReorderStep(const ReorderStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - Array iters; - for (auto x : step->after_ids) { - iters.push_back(stage->iters[x]); - } - StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(step->stage_id, - Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); -} - -void State::DoComputeAtStep(const ComputeAtStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - - // Remove the bound information of each iterator since they may not be accurate after - // compute at - Array new_iters; - for (const Iterator& it : stage->iters) { - new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); - } - - StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), - ComputeAtKind::kIter, stage->attrs)); - // Update attach map - pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); -} - -void State::DoComputeRootStep(const ComputeRootStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - - // Remove the bound information of each iterator since they may not be accurate after - // compute root - Array new_iters; - for (const Iterator& it : stage->iters) { - new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); - } - - StateNode* pstate = CopyOnWrite(); - 
pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), - ComputeAtKind::kRoot, stage->attrs)); - // Update attach map - pstate->attach_map.DeleteStage(step->stage_id); -} - -void State::DoComputeInlineStep(const ComputeInlineStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - - // Check the validity of compute_inline - for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ(operator->()->attach_map->iter_to_attached_stages.count( - std::make_pair(step->stage_id, i)), - 0) - << "Invalid compute_inline: There are some other stages that are attached to the " - << "target stage"; - } - - StateNode* pstate = CopyOnWrite(); - auto new_stage = pstate->stages[step->stage_id]; - new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined; - pstate->stages.Set(step->stage_id, std::move(new_stage)); - // Update attach map - pstate->attach_map.DeleteStage(step->stage_id); -} - -// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep -Array State::DoSplitStepCommon(int stage_id, int iter_id, - const Array>& lengths, - bool inner_to_outer) { - const Stage& stage = operator->()->stages[stage_id]; - const Iterator& it = stage->iters[iter_id]; - size_t old_iter_size = stage->iters.size(); - bool concrete = true; - - Optional tosplit_min, tosplit_extent; - if (it->range.defined()) { - tosplit_min = it->range->min; - tosplit_extent = it->range->extent; - } else { - tosplit_min = NullOpt; - tosplit_extent = NullOpt; - } - - Array outs; - for (size_t i = 0; i < lengths.size(); ++i) { - Optional l; - String name; - if (inner_to_outer) { - l = lengths[lengths.size() - i - 1]; - name = it->name + "." + std::to_string(lengths.size() - i); - } else { - l = lengths[i]; - name = it->name + "." 
+ std::to_string(i); - } - Iterator res; - if (l && tosplit_min && tosplit_extent) { - res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind, - IteratorAnnotation::kNone); - tosplit_min = Integer(0); - tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value()); - } else { - res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone); - tosplit_min = NullOpt; - tosplit_extent = NullOpt; - concrete = false; - } - outs.push_back(std::move(res)); - } - - Range range; - if (tosplit_min && tosplit_extent) { - range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value()); - } - if (inner_to_outer) { - outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone)); - // Reverse the Iterator array - Array temp(outs.rbegin(), outs.rend()); - outs = std::move(temp); - } else { - outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind, - IteratorAnnotation::kNone)); - } - - Array new_iters; - new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); - new_iters.insert(new_iters.end(), outs.begin(), outs.end()); - new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); - - StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(stage_id, - Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); - pstate->concrete &= concrete; - - // Two vectors are used to represent the iterator relation before and after split - // The original iterators in AttachMap will be updated with the new iterators - std::vector from_iters; - std::vector to_iters; - for (size_t i = iter_id; i < old_iter_size; ++i) { - from_iters.emplace_back(stage_id, i); - to_iters.emplace_back(stage_id, i + lengths.size()); - } - pstate->attach_map.UpdateIters(from_iters, to_iters); - - return outs; -} - -Array State::DoSplitStep(const SplitStep& step) { - return 
DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer); -} - -Iterator State::DoFuseStep(const FuseStep& step) { - int stage_id = step->stage_id; - const Stage& stage = operator->()->stages[stage_id]; - size_t old_iter_size = static_cast(stage->iters.size()); - - String new_name; - PrimExpr new_extent = 1; - IteratorKind new_iter_kind = IteratorKind::kSpecial; - - for (size_t i = 0; i < step->fused_ids.size(); ++i) { - if (i > 0) { - CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1); - } - - if (i != step->fused_ids.size() - 1) { - const auto& iter_to_attached_stage = operator->()->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair(stage_id, step->fused_ids[i])) != - iter_to_attached_stage.end()) { - LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " - << "stages. State before fusion:\n" - << *this; - } - } - - const Iterator& it = stage->iters[step->fused_ids[i]]; - new_name = new_name + it->name + "@"; - - if (it->range.defined() && new_extent.defined()) { - new_extent = new_extent * it->range->extent; - } else { - new_extent = PrimExpr(); - } - - if (i == 0) { - new_iter_kind = it->iter_kind; - } else { - if (new_iter_kind != it->iter_kind) { - new_iter_kind = IteratorKind::kMixed; - } - } - } - - Range range; - if (new_extent.defined()) { - range = Range::FromMinExtent(0, new_extent); - } - Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone); - Array new_iters; - new_iters.insert(new_iters.end(), stage->iters.begin(), - stage->iters.begin() + step->fused_ids.front()); - new_iters.push_back(new_it); - new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1, - stage->iters.end()); - - StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(stage_id, - Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); - - // Two vectors are used to represent the 
iterator relation before and after fuse - // The original iterators in AttachMap will be updated with the new iterators - std::vector from_iters; - std::vector to_iters; - const size_t begin_id = step->fused_ids.front(), end_id = step->fused_ids.back(); - for (size_t i = 0; i < old_iter_size; ++i) { - if (i <= begin_id) { - continue; - } else if (i > end_id) { - // move forward - from_iters.emplace_back(stage_id, i); - to_iters.emplace_back(stage_id, i - end_id + begin_id); - } else { - // move to the fused id - from_iters.emplace_back(stage_id, i); - to_iters.emplace_back(stage_id, begin_id); - } - } - pstate->attach_map.UpdateIters(from_iters, to_iters); - - return new_it; -} - -Iterator State::DoAnnotationStep(const AnnotationStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - Iterator it = stage->iters[step->iter_id]; - - CHECK(it->annotation == IteratorAnnotation::kNone); - Iterator new_it = Iterator(it->name, it->range, it->iter_kind, step->annotation); - Stage new_stage = stage; - new_stage.CopyOnWrite()->iters.Set(step->iter_id, new_it); - CopyOnWrite()->stages.Set(step->stage_id, std::move(new_stage)); - return new_it; -} - -void State::DoSteps(const ComputeDAG& dag) { +void State::ApplySteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; for (const auto& step : operator->()->transform_steps) { if (auto ps = step.as()) { - DoReorderStep(GetRef(ps)); + ps->ApplyToState(this); } else if (auto ps = step.as()) { - DoComputeAtStep(GetRef(ps)); + ps->ApplyToState(this); } else if (auto ps = step.as()) { - DoComputeRootStep(GetRef(ps)); + ps->ApplyToState(this); } else if (auto ps = step.as()) { - DoComputeInlineStep(GetRef(ps)); + ps->ApplyToState(this); } else if (auto ps = step.as()) { - DoSplitStep(GetRef(ps)); + ps->ApplyToState(this); } else if (auto ps = step.as()) { - DoFuseStep(GetRef(ps)); + ps->ApplyToState(this); } else if (auto ps = step.as()) { - 
DoAnnotationStep(GetRef(ps)); + ps->ApplyToState(this); } else { LOG(FATAL) << "Invalid step: " << step; } diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 33c27282bcfd..dc951132720f 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -82,62 +82,6 @@ enum class ComputeAtKind : int { kIter = 2, }; -/*! \brief The type of an iterator. */ -enum class IteratorKind : int { - /*! \brief Spatial iterator. */ - kSpatial = 0, - /*! \brief Reduction iterator. */ - kReduction = 1, - /*! \brief Fused spatial and reduction iterator. */ - kMixed = 2, - /*! \brief Special iterator. (e.g. virtual root iterator) */ - kSpecial = 3 -}; - -/*! - * \brief A for loop iterator - * Similar to tvm::IterVar in `include/tvm/tir/expr.h` - */ -class IteratorNode : public Object { - public: - /*! \brief The name of this iterator. */ - String name; - /*! \brief The range of this iterator. */ - Range range; - /*! \brief The iterator type of this iterator. */ - IteratorKind iter_kind; - /*! \brief The annotation type of this iterator. */ - IteratorAnnotation annotation; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("range", &range); - v->Visit("iter_kind", &iter_kind); - v->Visit("annotation", &annotation); - } - - static constexpr const char* _type_key = "auto_scheduler.Iterator"; - TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); -}; - -/*! - * \brief Managed reference to IteratorNode. - * \sa IteratorNode - */ -class Iterator : public ObjectRef { - public: - /*! - * \brief The constructor. - * \param name The name of this iterator. - * \param range The range of this iterator. - * \param iter_kind The iterator type of this iterator. - * \param annotation The annotation type of this iterator. - */ - Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation); - - TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); -}; - /*! 
\brief Stage-level attributes. */ struct StageAttributes { /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */ @@ -327,13 +271,15 @@ class State : public ObjectRef { String ToStr(bool delete_trivial_loop = true) const; /*! - * \brief General do step functions with a runtime dynamic dispatcher. This will re-apply all the - * transform steps with the initial state. + * \brief General call step functions with a runtime dynamic dispatcher. This will re-apply all + * the transform steps from the initial state. * \param dag The original ComputeDAG of this state. - * \note This is different from the class member `current_compute_dag`, for some transform step - * may change the op stage structure of the ComputeDAG. + * \note The input `dag` is different from the class member `current_compute_dag`. + * This function takes the initial ComputeDAG as input to replay all the history. While the + * `current_compute_dag` is used to track the current stage status, for some transform step may + * change the op stage structure. */ - void DoSteps(const ComputeDAG& dag); + void ApplySteps(const ComputeDAG& dag); /* Step APIs for State. */ @@ -424,74 +370,6 @@ class State : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); - - private: - /*! - * \brief Do transform steps. - * \note The following functions only change loop state but do not change transform_history. - * We separate these functions out, so you can call them for replay easily given history steps */ - - /*! - * \brief Apply reorder step to current state. - * \param step A ReorderStep. - */ - void DoReorderStep(const ReorderStep& step); - /*! - * \brief Apply compute at step to current state. - * \param step A ComputeAtStep. - * \note After compute_at, we need careful dependency analysis to compute the accurate bound - * information. 
However, it is relatively expensive and complicated, so we just fill "None" as - * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. - */ - void DoComputeAtStep(const ComputeAtStep& step); - /*! - * \brief Apply compute root step to current state. - * \param step A ComputeRootStep. - * \note After compute_root, we need careful dependency analysis to compute the accurate bound - * information. However, it is relatively expensive and complicated, so we just fill "None" as - * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. - */ - void DoComputeRootStep(const ComputeRootStep& step); - /*! - * \brief Apply compute inline to current state. - * \param step A ComputeInline. - */ - void DoComputeInlineStep(const ComputeInlineStep& step); - /*! - * \brief Apply split step to current state. - * \param step A SplitStep. - * \return The iterator results after split. - * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner - * most iterator of split results will become the new attach point. - */ - Array DoSplitStep(const SplitStep& step); - /*! - * \brief Apply fuse step to current state. - * \param step A FuseStep. - * \return The iterator result after fuse. - * \note If the iterators to be fused have stages attached at them(by compute_at), the fused - * result will become the new attach point. - */ - Iterator DoFuseStep(const FuseStep& step); - /*! - * \brief Apply annotation step to current state. - * \param step A AnnotationStep. - * \return The iterator result after annotate. - */ - Iterator DoAnnotationStep(const AnnotationStep& step); - - /*! - * \brief Common function for DoSplitStep and DoFollowSplitStep(Will be added later). - * \param stage_id The index of the stage to be split. - * \param iter_id The index of the iterator to be split. 
- * \param lengths The multiple split factors. - * \param inner_to_outer The split direction. - * \return The iterator results after split. - */ - Array DoSplitStepCommon(int stage_id, int iter_id, - const Array>& lengths, bool inner_to_outer); }; } // namespace auto_scheduler diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 9bb8bffc144f..d901b7e64ac2 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -29,6 +29,7 @@ #include #include +#include #include "loop_state.h" #include "utils.h" @@ -62,6 +63,16 @@ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { data_ = std::move(node); } +void ReorderStepNode::ApplyToState(State* state) const { + const Stage& stage = (*state)->stages[stage_id]; + Array iters; + for (auto x : after_ids) { + iters.push_back(stage->iters[x]); + } + state->CopyOnWrite()->stages.Set( + stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); +} + void ReorderStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; @@ -79,7 +90,7 @@ void ReorderStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String ReorderStepNode::PrintAsPythonAPI(Array* stages, +String ReorderStepNode::ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; std::stringstream ss; @@ -106,6 +117,23 @@ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_ data_ = std::move(node); } +void ComputeAtStepNode::ApplyToState(State* state) const { + const Stage& stage = (*state)->stages[stage_id]; + + // Remove the bound information of each iterator since they may not be accurate after + // compute at + Array new_iters; + for (const Iterator& it : stage->iters) { + new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); + } + + StateNode* pstate = 
state->CopyOnWrite(); + pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + ComputeAtKind::kIter, stage->attrs)); + // Update attach map + pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id); +} + void ComputeAtStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { te::Stage stage = (*stages)[stage_id]; @@ -116,7 +144,7 @@ void ComputeAtStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String ComputeAtStepNode::PrintAsPythonAPI(Array* stages, +String ComputeAtStepNode::ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -134,6 +162,23 @@ ComputeRootStep::ComputeRootStep(int stage_id) { data_ = std::move(node); } +void ComputeRootStepNode::ApplyToState(State* state) const { + const Stage& stage = (*state)->stages[stage_id]; + + // Remove the bound information of each iterator since they may not be accurate after + // compute root + Array new_iters; + for (const Iterator& it : stage->iters) { + new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); + } + + StateNode* pstate = state->CopyOnWrite(); + pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + ComputeAtKind::kRoot, stage->attrs)); + // Update attach map + pstate->attach_map.DeleteStage(stage_id); +} + void ComputeRootStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; @@ -141,7 +186,7 @@ void ComputeRootStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, +String ComputeRootStepNode::ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -157,6 +202,24 @@ ComputeInlineStep::ComputeInlineStep(int stage_id) { data_ = 
std::move(node); } +void ComputeInlineStepNode::ApplyToState(State* state) const { + const Stage& stage = (*state)->stages[stage_id]; + + // Check the validity of compute_inline + for (size_t i = 0; i < stage->iters.size(); ++i) { + CHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0) + << "Invalid compute_inline: There are some other stages that are attached to the " + << "target stage"; + } + + StateNode* pstate = state->CopyOnWrite(); + auto new_stage = pstate->stages[stage_id]; + new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined; + pstate->stages.Set(stage_id, std::move(new_stage)); + // Update attach map + pstate->attach_map.DeleteStage(stage_id); +} + void ComputeInlineStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; @@ -164,7 +227,7 @@ void ComputeInlineStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String ComputeInlineStepNode::PrintAsPythonAPI(Array* stages, +String ComputeInlineStepNode::ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -174,6 +237,86 @@ String ComputeInlineStepNode::PrintAsPythonAPI(Array* stages, } /********** Split **********/ +// common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep +Array ApplySplitToState(State* state, int stage_id, int iter_id, + const Array>& lengths, bool inner_to_outer) { + const Stage& stage = (*state)->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + size_t old_iter_size = stage->iters.size(); + bool concrete = true; + + Optional tosplit_min, tosplit_extent; + if (it->range.defined()) { + tosplit_min = it->range->min; + tosplit_extent = it->range->extent; + } else { + tosplit_min = NullOpt; + tosplit_extent = NullOpt; + } + + Array outs; + for (size_t i = 0; i < lengths.size(); ++i) { + Optional l; + String name; + if (inner_to_outer) { + l 
= lengths[lengths.size() - i - 1]; + name = it->name + "." + std::to_string(lengths.size() - i); + } else { + l = lengths[i]; + name = it->name + "." + std::to_string(i); + } + Iterator res; + if (l && tosplit_min && tosplit_extent) { + res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind, + IteratorAnnotation::kNone); + tosplit_min = Integer(0); + tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value()); + } else { + res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone); + tosplit_min = NullOpt; + tosplit_extent = NullOpt; + concrete = false; + } + outs.push_back(std::move(res)); + } + + Range range; + if (tosplit_min && tosplit_extent) { + range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value()); + } + if (inner_to_outer) { + outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone)); + // Reverse the Iterator array + Array temp(outs.rbegin(), outs.rend()); + outs = std::move(temp); + } else { + outs.push_back(Iterator(it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_kind, + IteratorAnnotation::kNone)); + } + + Array new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), outs.begin(), outs.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); + + StateNode* pstate = state->CopyOnWrite(); + pstate->stages.Set(stage_id, + Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + pstate->concrete &= concrete; + + // Two vectors are used to represent the iterator relation before and after split + // The original iterators in AttachMap will be updated with the new iterators + std::vector from_iters; + std::vector to_iters; + for (size_t i = iter_id; i < old_iter_size; ++i) { + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i + lengths.size()); + } + pstate->attach_map.UpdateIters(from_iters, to_iters); + + return outs; +} + Array ApplySplitToSchedule(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, int iter_id, const Array>& lengths, bool inner_to_outer) { @@ -262,12 +405,16 @@ SplitStep::SplitStep(int stage_id, int iter_id, Optional extent, data_ = std::move(node); } +Array SplitStepNode::ApplyToState(State* state) const { + return ApplySplitToState(state, stage_id, iter_id, lengths, inner_to_outer); +} + Array SplitStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -String SplitStepNode::PrintAsPythonAPI(Array* stages, +String SplitStepNode::ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } @@ -283,6 +430,85 @@ FuseStep::FuseStep(int stage_id, const Array& fused_ids) { data_ = std::move(node); } +Iterator FuseStepNode::ApplyToState(State* state) const { 
+ const Stage& stage = (*state)->stages[stage_id]; + size_t old_iter_size = static_cast(stage->iters.size()); + + String new_name; + PrimExpr new_extent = 1; + IteratorKind new_iter_kind = IteratorKind::kSpecial; + + for (size_t i = 0; i < fused_ids.size(); ++i) { + if (i > 0) { + CHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1); + } + + if (i != fused_ids.size() - 1) { + const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i])) != + iter_to_attached_stage.end()) { + LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " + << "stages. State before fusion:\n" + << (*state); + } + } + + const Iterator& it = stage->iters[fused_ids[i]]; + new_name = new_name + it->name + "@"; + + if (it->range.defined() && new_extent.defined()) { + new_extent = new_extent * it->range->extent; + } else { + new_extent = PrimExpr(); + } + + if (i == 0) { + new_iter_kind = it->iter_kind; + } else { + if (new_iter_kind != it->iter_kind) { + new_iter_kind = IteratorKind::kMixed; + } + } + } + + Range range; + if (new_extent.defined()) { + range = Range::FromMinExtent(0, new_extent); + } + Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone); + Array new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + fused_ids.front()); + new_iters.push_back(new_it); + new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1, + stage->iters.end()); + + StateNode* pstate = state->CopyOnWrite(); + pstate->stages.Set(stage_id, + Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + + // Two vectors are used to represent the iterator relation before and after fuse + // The original iterators in AttachMap will be updated with the new iterators + std::vector from_iters; + std::vector to_iters; + const size_t begin_id = fused_ids.front(), end_id = 
fused_ids.back(); + for (size_t i = 0; i < old_iter_size; ++i) { + if (i <= begin_id) { + continue; + } else if (i > end_id) { + // move forward + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i - end_id + begin_id); + } else { + // move to the fused id + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, begin_id); + } + } + pstate->attach_map.UpdateIters(from_iters, to_iters); + + return new_it; +} + IterVar FuseStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; @@ -305,7 +531,7 @@ IterVar FuseStepNode::ApplyToSchedule(Array* stages, return fused_axis; } -String FuseStepNode::PrintAsPythonAPI(Array* stages, +String FuseStepNode::ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; std::stringstream to_fuse; @@ -335,6 +561,18 @@ AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann data_ = std::move(node); } +Iterator AnnotationStepNode::ApplyToState(State* state) const { + const Stage& stage = (*state)->stages[stage_id]; + Iterator it = stage->iters[iter_id]; + + CHECK(it->annotation == IteratorAnnotation::kNone); + Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation); + Stage new_stage = stage; + new_stage.CopyOnWrite()->iters.Set(iter_id, new_it); + state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage)); + return new_it; +} + void AnnotationStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { te::Stage stage = (*stages)[stage_id]; @@ -370,7 +608,7 @@ void AnnotationStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String AnnotationStepNode::PrintAsPythonAPI(Array* stages, +String AnnotationStepNode::ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; diff --git 
a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 51cac332ec45..5e7344d791b7 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -28,7 +28,7 @@ * Take fuse step for example: * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction * function `FuseStep::FuseStep(...)` in `transform_steps.cc` - * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`. + * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::ApplyToPythonAPI`. * - In these two functions you need to lower this step with tvm's te schedule API * 3. Implement `State::fuse` and `State::DoFuseStep`. * - In these two functions you need to incrementally update all data structures in State with @@ -53,6 +53,18 @@ namespace auto_scheduler { typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; +/*! \brief The type of an iterator. */ +enum class IteratorKind : int { + /*! \brief Spatial iterator. */ + kSpatial = 0, + /*! \brief Reduction iterator. */ + kReduction = 1, + /*! \brief Fused spatial and reduction iterator. */ + kMixed = 2, + /*! \brief Special iterator. (e.g. virtual root iterator) */ + kSpecial = 3 +}; + /*! \brief The type of an iterator's annotation. */ enum class IteratorAnnotation : int { /*! \brief This iterator has no annotation. */ @@ -83,6 +95,52 @@ enum class IteratorAnnotation : int { extern const char* IteratorAnnotationString[]; +/*! + * \brief A for loop iterator + * Similar to tvm::IterVar in `include/tvm/tir/expr.h` + */ +class IteratorNode : public Object { + public: + /*! \brief The name of this iterator. */ + String name; + /*! \brief The range of this iterator. */ + Range range; + /*! \brief The iterator type of this iterator. */ + IteratorKind iter_kind; + /*! \brief The annotation type of this iterator. 
*/ + IteratorAnnotation annotation; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("range", &range); + v->Visit("iter_kind", &iter_kind); + v->Visit("annotation", &annotation); + } + + static constexpr const char* _type_key = "auto_scheduler.Iterator"; + TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); +}; + +/*! + * \brief Managed reference to IteratorNode. + * \sa IteratorNode + */ +class Iterator : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param name The name of this iterator. + * \param range The range of this iterator. + * \param iter_kind The iterator type of this iterator. + * \param annotation The annotation type of this iterator. + */ + Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation); + + TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); +}; + +class State; + /*! * \brief The base class of transformation steps. Each step has its corresponding tvm.te * schedule primitives. @@ -115,7 +173,13 @@ class ReorderStepNode : public StepNode { Array after_ids; /*! - * \brief Apply the current state to tvm.schedule + * \brief Apply the current step to State + * \param state A mutable pointer to State. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. */ @@ -127,7 +191,7 @@ class ReorderStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.ReorderStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); @@ -158,7 +222,17 @@ class ComputeAtStepNode : public StepNode { int target_iter_id; /*! 
- * \brief Apply the current state to tvm.schedule + * \brief Apply the current step to State + * \param state A mutable pointer to State. + * \note After compute_at, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. */ @@ -170,7 +244,7 @@ class ComputeAtStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); @@ -197,7 +271,17 @@ class ComputeAtStep : public Step { class ComputeRootStepNode : public StepNode { public: /*! - * \brief Apply the current state to tvm.schedule + * \brief Apply the current step to State + * \param state A mutable pointer to State. + * \note After compute_root, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator result after fuse. 
@@ -210,7 +294,7 @@ class ComputeRootStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); @@ -235,7 +319,13 @@ class ComputeRootStep : public Step { class ComputeInlineStepNode : public StepNode { public: /*! - * \brief Apply the current state to tvm.schedule + * \brief Apply the current step to State + * \param state A mutable pointer to State. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator result after fuse. @@ -248,7 +338,7 @@ class ComputeInlineStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); @@ -288,7 +378,16 @@ class SplitStepNode : public StepNode { bool inner_to_outer; /*! - * \brief Apply the current state to tvm.schedule + * \brief Apply the current step to State + * \param state A mutable pointer to State. + * \return The iterator results after split. + * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner + * most iterator of split results will become the new attach point. + */ + Array ApplyToState(State* state) const; + + /*! 
+ * \brief Apply the current step to tvm.schedule * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator results after split. @@ -302,7 +401,7 @@ class SplitStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.SplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); @@ -335,7 +434,16 @@ class FuseStepNode : public StepNode { Array fused_ids; /*! - * \brief Apply the current state to tvm.schedule + * \brief Apply the current step to State + * \param state A mutable pointer to State. + * \return The iterator result after fuse. + * \note If the iterators to be fused have stages attached at them(by compute_at), the fused + * result will become the new attach point. + */ + Iterator ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator result after fuse. @@ -348,7 +456,7 @@ class FuseStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.FuseStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); @@ -382,7 +490,14 @@ class AnnotationStepNode : public StepNode { IteratorAnnotation annotation; /*! - * \brief Apply the current state to tvm.schedule + * \brief Apply the current step to State + * \param state A mutable pointer to State. + * \return The iterator result after annotate. 
+ */ + Iterator ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator result after fuse. @@ -395,7 +510,7 @@ class AnnotationStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "auto_scheduler.AnnotationStep"; TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); From 30828a8708e3ca8b27278ad5130f79f97ba70dab Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 19 Jul 2020 22:35:50 +0800 Subject: [PATCH 11/14] Move measure_record implementation to step --- src/auto_scheduler/measure_record.cc | 198 +++------------------------ src/auto_scheduler/transform_step.cc | 169 +++++++++++++++++++++++ src/auto_scheduler/transform_step.h | 86 ++++++++++++ src/auto_scheduler/utils.h | 21 +++ 4 files changed, 294 insertions(+), 180 deletions(-) diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 0e01dab434fb..889cd5b9e0cd 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -42,25 +42,6 @@ namespace dmlc { namespace json { -inline std::vector IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) { - std::vector out; - for (const auto& x : data) { - CHECK(x.defined()); - out.push_back(x); - } - return out; -} - -inline std::vector IntArrayToVector( - const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) { - std::vector out; - for (const auto& x : data) { - CHECK(x); - out.push_back(x.value()); - } - return out; -} - template <> struct Handler<::tvm::Array<::tvm::auto_scheduler::Stage>> { inline static void Write(dmlc::JSONWriter* writer, @@ -82,44 +63,10 
@@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::auto_scheduler::Step>& data) { writer->BeginArray(false); - for (size_t i = 0; i < data.size(); ++i) { + for (const auto& step : data) { writer->WriteArraySeperator(); writer->BeginArray(false); - if (auto ps = data[i].as<::tvm::auto_scheduler::ReorderStepNode>()) { - writer->WriteArrayItem(std::string("RE")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(IntArrayToVector(ps->after_ids)); - } else if (auto ps = data[i].as<::tvm::auto_scheduler::ComputeAtStepNode>()) { - writer->WriteArrayItem(std::string("CA")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->target_stage_id); - writer->WriteArrayItem(ps->target_iter_id); - } else if (auto ps = data[i].as<::tvm::auto_scheduler::ComputeRootStepNode>()) { - writer->WriteArrayItem(std::string("CR")); - writer->WriteArrayItem(ps->stage_id); - } else if (auto ps = data[i].as<::tvm::auto_scheduler::ComputeInlineStepNode>()) { - writer->WriteArrayItem(std::string("CI")); - writer->WriteArrayItem(ps->stage_id); - } else if (auto ps = data[i].as<::tvm::auto_scheduler::SplitStepNode>()) { - writer->WriteArrayItem(std::string("SP")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->extent ? 
::tvm::auto_scheduler::GetIntImm(ps->extent.value()) - : 0); - writer->WriteArrayItem(IntArrayToVector(ps->lengths)); - writer->WriteArrayItem(static_cast(ps->inner_to_outer)); - } else if (auto ps = data[i].as<::tvm::auto_scheduler::FuseStepNode>()) { - writer->WriteArrayItem(std::string("FU")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(IntArrayToVector(ps->fused_ids)); - } else if (auto ps = data[i].as<::tvm::auto_scheduler::AnnotationStepNode>()) { - writer->WriteArrayItem(std::string("AN")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(static_cast(ps->annotation)); - } else { - LOG(FATAL) << "Invalid step: " << data[i]; - } + step->WriteToRecord(writer); writer->EndArray(); } writer->EndArray(); @@ -127,102 +74,12 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::auto_scheduler::Step>* data) { - std::vector int_list; - bool s, inner_to_outer; - std::string name, scope_name, pragma_type, ti_func_name; - int stage_id, iter_id, extent, ann, target_stage_id; - reader->BeginArray(); data->clear(); while (reader->NextArrayItem()) { reader->BeginArray(); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&name); - if (name == "RE") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&int_list); - ::tvm::Array<::tvm::Integer> after_ids; - for (const auto& i : int_list) { - after_ids.push_back(i); - } - data->push_back(::tvm::auto_scheduler::ReorderStep(stage_id, after_ids)); - } else if (name == "CA") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&target_stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&iter_id); - data->push_back(::tvm::auto_scheduler::ComputeAtStep(stage_id, target_stage_id, iter_id)); - } else if (name 
== "CR") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - data->push_back(::tvm::auto_scheduler::ComputeRootStep(stage_id)); - } else if (name == "CI") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - data->push_back(::tvm::auto_scheduler::ComputeInlineStep(stage_id)); - } else if (name == "SP") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&extent); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&int_list); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&inner_to_outer); - ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths; - for (const auto& i : int_list) { - lengths.push_back(::tvm::Integer(i)); - } - data->push_back(::tvm::auto_scheduler::SplitStep( - stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, lengths, inner_to_outer)); - } else if (name == "FU") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&int_list); - ::tvm::Array<::tvm::Integer> fused_ids; - for (const auto& i : int_list) { - fused_ids.push_back(i); - } - data->push_back(::tvm::auto_scheduler::FuseStep(stage_id, fused_ids)); - } else if (name == "AN") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&ann); - data->push_back(::tvm::auto_scheduler::AnnotationStep( - stage_id, iter_id, ::tvm::auto_scheduler::IteratorAnnotation(ann))); - } else { - LOG(FATAL) << "Invalid step format"; - } - s = reader->NextArrayItem(); - CHECK(!s); + data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader)); + CHECK(!reader->NextArrayItem()); } } }; @@ -237,15 +94,11 @@ struct Handler<::tvm::auto_scheduler::StateNode> { } inline static void 
Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::StateNode* data) { reader->BeginArray(); - bool s; - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(&data->stages); - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(&data->transform_steps); - s = reader->NextArrayItem(); - CHECK(!s); + CHECK(!reader->NextArrayItem()); } }; @@ -260,19 +113,14 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { std::string target_str; - bool s; - reader->BeginArray(); - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(&target_str); data->workload_key = std::move(target_str); - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(&target_str); data->target = ::tvm::Target::Create(target_str); - s = reader->NextArrayItem(); - CHECK(!s); + CHECK(!reader->NextArrayItem()); } }; @@ -286,20 +134,16 @@ struct Handler<::tvm::auto_scheduler::MeasureInputNode> { writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureInputNode* data) { - bool s; auto task_node = ::tvm::make_object<::tvm::auto_scheduler::SearchTaskNode>(); auto state_node = ::tvm::make_object<::tvm::auto_scheduler::StateNode>(); state_node->concrete = true; reader->BeginArray(); - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(task_node.get()); - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(state_node.get()); - s = reader->NextArrayItem(); - CHECK(!s); + CHECK(!reader->NextArrayItem()); data->task = ::tvm::auto_scheduler::SearchTask(task_node); data->state = ::tvm::auto_scheduler::State(state_node); @@ -326,28 +170,22 @@ struct Handler<::tvm::auto_scheduler::MeasureResultNode> { } inline static void Read(dmlc::JSONReader* reader, 
::tvm::auto_scheduler::MeasureResultNode* data) { - bool s; std::vector tmp; reader->BeginArray(); - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(&tmp); data->costs.clear(); for (const auto& i : tmp) { data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); } - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(&data->error_no); - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(&data->all_cost); - s = reader->NextArrayItem(); - CHECK(s); + CHECK(reader->NextArrayItem()); reader->Read(&data->timestamp); - s = reader->NextArrayItem(); - CHECK(!s); + CHECK(!reader->NextArrayItem()); } }; diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index d901b7e64ac2..1f6dcfdcc02d 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -52,6 +53,30 @@ const char* IteratorAnnotationString[] = { "tensorize" // kTensorized = 11 }; +Step StepReadFromRecord(dmlc::JSONReader* reader) { + std::string name; + CHECK(reader->NextArrayItem()); + reader->Read(&name); + if (name == ReorderStepNode::record_prefix_str) { + return ReorderStep(reader); + } else if (name == ComputeAtStepNode::record_prefix_str) { + return ComputeAtStep(reader); + } else if (name == ComputeRootStepNode::record_prefix_str) { + return ComputeRootStep(reader); + } else if (name == ComputeInlineStepNode::record_prefix_str) { + return ComputeInlineStep(reader); + } else if (name == SplitStepNode::record_prefix_str) { + return SplitStep(reader); + } else if (name == FuseStepNode::record_prefix_str) { + return FuseStep(reader); + } else if (name == AnnotationStepNode::record_prefix_str) { + return AnnotationStep(reader); + } else { + LOG(FATAL) << "Invalid step format"; + } + return Step(); +} + /********** Reorder **********/ 
ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { auto node = make_object(); @@ -63,6 +88,28 @@ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { data_ = std::move(node); } +ReorderStep::ReorderStep(dmlc::JSONReader* reader) { + auto node = make_object(); + CHECK(reader->NextArrayItem()); + reader->Read(&node->stage_id); + CHECK(reader->NextArrayItem()); + std::vector int_list; + reader->Read(&int_list); + ::tvm::Array<::tvm::Integer> after_ids; + for (const auto& i : int_list) { + after_ids.push_back(i); + } + node->after_ids = after_ids; + data_ = std::move(node); +} + +void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(IntArrayToVector(after_ids)); +} + void ReorderStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; Array iters; @@ -117,6 +164,24 @@ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_ data_ = std::move(node); } +ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) { + auto node = make_object(); + CHECK(reader->NextArrayItem()); + reader->Read(&node->stage_id); + CHECK(reader->NextArrayItem()); + reader->Read(&node->target_stage_id); + CHECK(reader->NextArrayItem()); + reader->Read(&node->target_iter_id); + data_ = std::move(node); +} + +void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(target_stage_id); + writer->WriteArrayItem(target_iter_id); +} void ComputeAtStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; @@ -162,6 +227,19 @@ ComputeRootStep::ComputeRootStep(int stage_id) { data_ = std::move(node); } +ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) { + auto node = make_object(); 
+ CHECK(reader->NextArrayItem()); + reader->Read(&node->stage_id); + data_ = std::move(node); +} + +void ComputeRootStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); +} + void ComputeRootStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; @@ -202,6 +280,19 @@ ComputeInlineStep::ComputeInlineStep(int stage_id) { data_ = std::move(node); } +ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) { + auto node = make_object(); + CHECK(reader->NextArrayItem()); + reader->Read(&node->stage_id); + data_ = std::move(node); +} + +void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); +} + void ComputeInlineStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; @@ -405,6 +496,41 @@ SplitStep::SplitStep(int stage_id, int iter_id, Optional extent, data_ = std::move(node); } +SplitStep::SplitStep(dmlc::JSONReader* reader) { + auto node = make_object(); + CHECK(reader->NextArrayItem()); + reader->Read(&node->stage_id); + CHECK(reader->NextArrayItem()); + reader->Read(&node->iter_id); + int int_val; + CHECK(reader->NextArrayItem()); + reader->Read(&int_val); + if (int_val) { + node->extent = Integer(int_val); + } + CHECK(reader->NextArrayItem()); + std::vector int_list; + reader->Read(&int_list); + ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths; + for (const auto& i : int_list) { + lengths.push_back(::tvm::Integer(i)); + } + node->lengths = lengths; + CHECK(reader->NextArrayItem()); + reader->Read(&node->inner_to_outer); + data_ = std::move(node); +} + +void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + 
writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0); + writer->WriteArrayItem(IntArrayToVector(lengths)); + writer->WriteArrayItem(static_cast(inner_to_outer)); +} + Array SplitStepNode::ApplyToState(State* state) const { return ApplySplitToState(state, stage_id, iter_id, lengths, inner_to_outer); } @@ -430,6 +556,28 @@ FuseStep::FuseStep(int stage_id, const Array& fused_ids) { data_ = std::move(node); } +FuseStep::FuseStep(dmlc::JSONReader* reader) { + auto node = make_object(); + CHECK(reader->NextArrayItem()); + reader->Read(&node->stage_id); + std::vector int_list; + CHECK(reader->NextArrayItem()); + reader->Read(&int_list); + ::tvm::Array<::tvm::Integer> fused_ids; + for (const auto& i : int_list) { + fused_ids.push_back(i); + } + node->fused_ids = fused_ids; + data_ = std::move(node); +} + +void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(IntArrayToVector(fused_ids)); +} + Iterator FuseStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; size_t old_iter_size = static_cast(stage->iters.size()); @@ -561,6 +709,27 @@ AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann data_ = std::move(node); } +AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) { + auto node = make_object(); + CHECK(reader->NextArrayItem()); + reader->Read(&node->stage_id); + CHECK(reader->NextArrayItem()); + reader->Read(&node->iter_id); + CHECK(reader->NextArrayItem()); + int int_val; + reader->Read(&int_val); + node->annotation = IteratorAnnotation(int_val); + data_ = std::move(node); +} + +void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + 
writer->WriteArrayItem(static_cast(annotation)); +} + Iterator AnnotationStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; Iterator it = stage->iters[iter_id]; diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 5e7344d791b7..d9465259a5f5 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -43,6 +43,7 @@ #define TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ #include +#include #include #include @@ -150,6 +151,12 @@ class StepNode : public Object { /*! \brief The index of the stage. */ int stage_id; + /*! + * \brief Serialize the current step record to JSONWriter. + * \param writer The output JSONWriter. + */ + virtual void WriteToRecord(dmlc::JSONWriter* writer) const = 0; + static constexpr const char* _type_key = "auto_scheduler.Step"; TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); }; @@ -163,6 +170,12 @@ class Step : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); }; +/*! + * \brief Read a step record from JSONReader and create the corresponding step. + * \param reader The input JSONReader. + */ +Step StepReadFromRecord(dmlc::JSONReader* reader); + /*! \brief Reorder step that corresponds to te::Stage::reorder */ class ReorderStepNode : public StepNode { public: @@ -172,6 +185,8 @@ class ReorderStepNode : public StepNode { */ Array after_ids; + void WriteToRecord(dmlc::JSONWriter* writer) const final; + /*! * \brief Apply the current step to State * \param state A mutable pointer to State. 
@@ -193,6 +208,8 @@ class ReorderStepNode : public StepNode { */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + static constexpr const char* record_prefix_str = "RE"; + static constexpr const char* _type_key = "auto_scheduler.ReorderStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); }; @@ -210,6 +227,13 @@ class ReorderStep : public Step { */ ReorderStep(int stage_id, const Array& after_ids); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit ReorderStep(dmlc::JSONReader* reader); + TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; @@ -221,6 +245,8 @@ class ComputeAtStepNode : public StepNode { /*! \brief The index of iterator in target stage that this step will compute at to. */ int target_iter_id; + void WriteToRecord(dmlc::JSONWriter* writer) const final; + /*! * \brief Apply the current step to State * \param state A mutable pointer to State. @@ -246,6 +272,8 @@ class ComputeAtStepNode : public StepNode { */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + static constexpr const char* record_prefix_str = "CA"; + static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); }; @@ -264,12 +292,21 @@ class ComputeAtStep : public Step { */ ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit ComputeAtStep(dmlc::JSONReader* reader); + TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); }; /*! \brief Compute root step that corresponds to te::Stage::compute_root */ class ComputeRootStepNode : public StepNode { public: + void WriteToRecord(dmlc::JSONWriter* writer) const final; + /*! 
* \brief Apply the current step to State * \param state A mutable pointer to State. @@ -296,6 +333,8 @@ class ComputeRootStepNode : public StepNode { */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + static constexpr const char* record_prefix_str = "CR"; + static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); }; @@ -312,12 +351,20 @@ class ComputeRootStep : public Step { */ explicit ComputeRootStep(int stage_id); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit ComputeRootStep(dmlc::JSONReader* reader); + TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); }; /*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ class ComputeInlineStepNode : public StepNode { public: + void WriteToRecord(dmlc::JSONWriter* writer) const final; /*! * \brief Apply the current step to State * \param state A mutable pointer to State. @@ -340,6 +387,8 @@ class ComputeInlineStepNode : public StepNode { */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + static constexpr const char* record_prefix_str = "CI"; + static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); }; @@ -356,6 +405,13 @@ class ComputeInlineStep : public Step { */ explicit ComputeInlineStep(int stage_id); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. 
+ */ + explicit ComputeInlineStep(dmlc::JSONReader* reader); + TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); }; @@ -377,6 +433,7 @@ class SplitStepNode : public StepNode { */ bool inner_to_outer; + void WriteToRecord(dmlc::JSONWriter* writer) const final; /*! * \brief Apply the current step to State * \param state A mutable pointer to State. @@ -403,6 +460,8 @@ class SplitStepNode : public StepNode { */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + static constexpr const char* record_prefix_str = "SP"; + static constexpr const char* _type_key = "auto_scheduler.SplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); }; @@ -424,6 +483,13 @@ class SplitStep : public Step { SplitStep(int stage_id, int iter_id, Optional extent, const Array>& lengths, bool inner_to_outer); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit SplitStep(dmlc::JSONReader* reader); + TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; @@ -433,6 +499,7 @@ class FuseStepNode : public StepNode { /*! \brief The ids of iterators to fuse. */ Array fused_ids; + void WriteToRecord(dmlc::JSONWriter* writer) const final; /*! * \brief Apply the current step to State * \param state A mutable pointer to State. @@ -458,6 +525,8 @@ class FuseStepNode : public StepNode { */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + static constexpr const char* record_prefix_str = "FU"; + static constexpr const char* _type_key = "auto_scheduler.FuseStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); }; @@ -475,6 +544,13 @@ class FuseStep : public Step { */ FuseStep(int stage_id, const Array& fused_ids); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. 
+ */ + explicit FuseStep(dmlc::JSONReader* reader); + TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; @@ -489,6 +565,7 @@ class AnnotationStepNode : public StepNode { /*! \brief The annotation type of this step. */ IteratorAnnotation annotation; + void WriteToRecord(dmlc::JSONWriter* writer) const final; /*! * \brief Apply the current step to State * \param state A mutable pointer to State. @@ -512,6 +589,8 @@ class AnnotationStepNode : public StepNode { */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + static constexpr const char* record_prefix_str = "AN"; + static constexpr const char* _type_key = "auto_scheduler.AnnotationStep"; TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); }; @@ -530,6 +609,13 @@ class AnnotationStep : public Step { */ AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit AnnotationStep(dmlc::JSONReader* reader); + TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); }; diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index ccac3bbb75a2..de800da13b64 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -107,6 +107,27 @@ inline void StrReplace(std::string* base, const std::string& from, const std::st } } +/*! \brief Convert a Array to std::vector. */ +inline std::vector IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) { + std::vector out; + for (const auto& x : data) { + CHECK(x.defined()); + out.push_back(x); + } + return out; +} + +/*! \brief Convert a Array> to std::vector. 
*/ +inline std::vector IntArrayToVector( + const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) { + std::vector out; + for (const auto& x : data) { + CHECK(x); + out.push_back(x.value()); + } + return out; +} + /********** Utilities for TVM Containers / ByteArray **********/ /*! \brief Compute mean of a FloatImm array */ inline double FloatArrayMean(const Array& float_array) { From 4bb58ceb38faa26bb0533732269a09d343f199e0 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 20 Jul 2020 10:27:53 +0800 Subject: [PATCH 12/14] Order update & API update --- src/auto_scheduler/compute_dag.cc | 40 +- src/auto_scheduler/loop_state.cc | 19 +- src/auto_scheduler/loop_state.h | 2 +- src/auto_scheduler/transform_step.cc | 845 ++++++++++-------- src/auto_scheduler/transform_step.h | 406 +++++---- .../unittest/test_auto_scheduler_measure.py | 2 +- 6 files changed, 686 insertions(+), 628 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index e675c24d1238..fe7fe79d170b 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -270,27 +270,9 @@ std::pair> ComputeDAG::ApplySteps( } // Apply the history steps to TVM schedule + // Call each step's ApplyToSchedule method for (const auto& step : transform_steps) { - // Call each step's ApplyToSchedule method - // Note: some steps have extra parameters that must be passed and they may need different - // return value, so the ApplyToSchedule is not able to be merged to single interface - if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, 
stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else { - LOG(FATAL) << "Invalid Step"; - } + StepApplyToSchedule(step, stages, stage_to_axes); } return std::make_pair(schedule, operator->()->tensors); @@ -334,23 +316,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } // Call each step's ApplyToPythonAPI method for (const auto& step : transform_steps) { - if (auto ps = step.as()) { - ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); - } else if (auto ps = step.as()) { - ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); - } else if (auto ps = step.as()) { - ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); - } else if (auto ps = step.as()) { - ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); - } else if (auto ps = step.as()) { - ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); - } else if (auto ps = step.as()) { - ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); - } else if (auto ps = step.as()) { - ss << ps->ApplyToPythonAPI(&stages, &stage_to_axes); - } else { - LOG(FATAL) << "Invalid Step"; - } + ss << StepApplyToPythonAPI(step, &stages, &stage_to_axes); } return ss.str(); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 08512e9b0cd8..8634a75a0bac 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -261,24 +261,9 @@ Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread void State::ApplySteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; + // Call each step's ApplyToState method for (const auto& step : operator->()->transform_steps) { - if (auto ps = step.as()) { - ps->ApplyToState(this); - } else if (auto ps = step.as()) { - ps->ApplyToState(this); - } else if (auto ps = step.as()) { - ps->ApplyToState(this); - } else if (auto ps = step.as()) { - ps->ApplyToState(this); - } else if (auto ps = 
step.as()) { - ps->ApplyToState(this); - } else if (auto ps = step.as()) { - ps->ApplyToState(this); - } else if (auto ps = step.as()) { - ps->ApplyToState(this); - } else { - LOG(FATAL) << "Invalid step: " << step; - } + StepApplyToState(step, this, dag); } } diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index dc951132720f..2666b1082f93 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -92,7 +92,7 @@ struct StageAttributes { /*! * \brief A op stage in the compute declaration. - * Similar to te::Stage in `include/schedule.h`. + * Similar to te::Stage in `include/tvm/te/schedule.h`. */ class StageNode : public Object { public: diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 1f6dcfdcc02d..c15c377aa39b 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -57,272 +57,450 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { std::string name; CHECK(reader->NextArrayItem()); reader->Read(&name); - if (name == ReorderStepNode::record_prefix_str) { + if (name == AnnotationStepNode::record_prefix_str) { + return AnnotationStep(reader); + } else if (name == FuseStepNode::record_prefix_str) { + return FuseStep(reader); + } else if (name == ReorderStepNode::record_prefix_str) { return ReorderStep(reader); + } else if (name == SplitStepNode::record_prefix_str) { + return SplitStep(reader); } else if (name == ComputeAtStepNode::record_prefix_str) { return ComputeAtStep(reader); - } else if (name == ComputeRootStepNode::record_prefix_str) { - return ComputeRootStep(reader); } else if (name == ComputeInlineStepNode::record_prefix_str) { return ComputeInlineStep(reader); - } else if (name == SplitStepNode::record_prefix_str) { - return SplitStep(reader); - } else if (name == FuseStepNode::record_prefix_str) { - return FuseStep(reader); - } else if (name == AnnotationStepNode::record_prefix_str) { - return 
AnnotationStep(reader); + } else if (name == ComputeRootStepNode::record_prefix_str) { + return ComputeRootStep(reader); } else { - LOG(FATAL) << "Invalid step format"; + LOG(FATAL) << "Invalid step format: " << name; } return Step(); } -/********** Reorder **********/ -ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { - auto node = make_object(); - node->stage_id = stage_id; - for (const auto& x : after_ids) { - CHECK(x->IsInstance()); - } - node->after_ids = after_ids; - data_ = std::move(node); -} - -ReorderStep::ReorderStep(dmlc::JSONReader* reader) { - auto node = make_object(); - CHECK(reader->NextArrayItem()); - reader->Read(&node->stage_id); - CHECK(reader->NextArrayItem()); - std::vector int_list; - reader->Read(&int_list); - ::tvm::Array<::tvm::Integer> after_ids; - for (const auto& i : int_list) { - after_ids.push_back(i); +void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { + if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else { + LOG(FATAL) << "Invalid step: " << step; } - node->after_ids = after_ids; - data_ = std::move(node); -} - -void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { - writer->WriteArraySeperator(); - writer->WriteString(record_prefix_str); - writer->WriteArrayItem(stage_id); - writer->WriteArrayItem(IntArrayToVector(after_ids)); } -void ReorderStepNode::ApplyToState(State* state) const { - const Stage& stage = (*state)->stages[stage_id]; - Array iters; - for (auto x : after_ids) { - iters.push_back(stage->iters[x]); +void StepApplyToSchedule(const Step& step, Array* stages, + StageToAxesMap* 
stage_to_axes) { + if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step: " << step; } - state->CopyOnWrite()->stages.Set( - stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); } -void ReorderStepNode::ApplyToSchedule(Array* stages, - StageToAxesMap* stage_to_axes) const { - auto stage = (*stages)[stage_id]; - const Array& axes = stage_to_axes->at(stage); - CHECK_EQ(after_ids.size(), axes.size()); - - Array new_axes; - new_axes.reserve(axes.size()); - for (auto i : after_ids) { - new_axes.push_back(axes[i]); +String StepApplyToPythonAPI(const Step& step, Array* stages, + StageToAxesMap* stage_to_axes) { + if (auto ps = step.as()) { + return ps->ApplyToPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->ApplyToPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->ApplyToPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->ApplyToPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->ApplyToPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->ApplyToPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->ApplyToPythonAPI(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step: " << step; } - stage.reorder(new_axes); - - stage_to_axes->Set(stage, std::move(new_axes)); - stages->Set(stage_id, std::move(stage)); + 
return ""; } -String ReorderStepNode::ApplyToPythonAPI(Array* stages, - StageToAxesMap* stage_to_axes) const { - const auto& stage = (*stages)[stage_id]; - std::stringstream ss; - - ss << "s[" << CleanName(stage->op->name) << "].reorder("; - for (size_t i = 0; i < after_ids.size(); ++i) { - ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); - if (i != after_ids.size() - 1) { - ss << ", "; - } - } - ss << ")\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} +/********** Primitives working on single stage **********/ -/********** Compute At **********/ -ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { - auto node = make_object(); +/********** Annotation **********/ +AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { + auto node = make_object(); node->stage_id = stage_id; - node->target_stage_id = target_stage_id; - node->target_iter_id = target_iter_id; + node->iter_id = iter_id; + node->annotation = ann; data_ = std::move(node); } -ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) { - auto node = make_object(); +AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) { + auto node = make_object(); CHECK(reader->NextArrayItem()); reader->Read(&node->stage_id); CHECK(reader->NextArrayItem()); - reader->Read(&node->target_stage_id); + reader->Read(&node->iter_id); CHECK(reader->NextArrayItem()); - reader->Read(&node->target_iter_id); + int int_val; + reader->Read(&int_val); + node->annotation = IteratorAnnotation(int_val); data_ = std::move(node); } -void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { +void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArraySeperator(); writer->WriteString(record_prefix_str); writer->WriteArrayItem(stage_id); - writer->WriteArrayItem(target_stage_id); - writer->WriteArrayItem(target_iter_id); + writer->WriteArrayItem(iter_id); + 
writer->WriteArrayItem(static_cast(annotation)); } -void ComputeAtStepNode::ApplyToState(State* state) const { - const Stage& stage = (*state)->stages[stage_id]; - // Remove the bound information of each iterator since they may not be accurate after - // compute at - Array new_iters; - for (const Iterator& it : stage->iters) { - new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); - } +Iterator AnnotationStepNode::ApplyToState(State* state) const { + const Stage& stage = (*state)->stages[stage_id]; + Iterator it = stage->iters[iter_id]; - StateNode* pstate = state->CopyOnWrite(); - pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), - ComputeAtKind::kIter, stage->attrs)); - // Update attach map - pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id); + CHECK(it->annotation == IteratorAnnotation::kNone); + Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation); + Stage new_stage = stage; + new_stage.CopyOnWrite()->iters.Set(iter_id, new_it); + state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage)); + return new_it; } -void ComputeAtStepNode::ApplyToSchedule(Array* stages, - StageToAxesMap* stage_to_axes) const { +void AnnotationStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { te::Stage stage = (*stages)[stage_id]; - const auto& target_stage = (*stages)[target_stage_id]; - const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id]; - stage.compute_at(target_stage, target_axis); + const Array& axes = (*stage_to_axes)[stage]; + + switch (annotation) { + case IteratorAnnotation::kUnroll: + stage.unroll(axes[iter_id]); + break; + case IteratorAnnotation::kVectorize: + stage.vectorize(axes[iter_id]); + break; + case IteratorAnnotation::kParallel: + stage.parallel(axes[iter_id]); + break; + case IteratorAnnotation::kVThread: + case IteratorAnnotation::kBlockX: + case IteratorAnnotation::kBlockY: + case 
IteratorAnnotation::kBlockZ: + case IteratorAnnotation::kThreadX: + case IteratorAnnotation::kThreadY: + case IteratorAnnotation::kThreadZ: + stage.bind(axes[iter_id], + te::thread_axis(Range(), IteratorAnnotationString[static_cast(annotation)])); + break; + case IteratorAnnotation::kNone: + break; + default: + LOG(FATAL) << "Invalid Annotation " << static_cast(annotation); + break; + } stages->Set(stage_id, std::move(stage)); } -String ComputeAtStepNode::ApplyToPythonAPI(Array* stages, - StageToAxesMap* stage_to_axes) const { +String AnnotationStepNode::ApplyToPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - const auto& target_stage = (*stages)[target_stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name) - << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n"; + const auto& iter = (*stage_to_axes)[stage][iter_id]; + + ss << "s[" << CleanName(stage->op->name) << "]."; + switch (annotation) { + case IteratorAnnotation::kUnroll: + ss << "unroll("; + break; + case IteratorAnnotation::kVectorize: + ss << "vectorize("; + break; + case IteratorAnnotation::kParallel: + ss << "parallel("; + break; + case IteratorAnnotation::kVThread: + case IteratorAnnotation::kBlockX: + case IteratorAnnotation::kBlockY: + case IteratorAnnotation::kBlockZ: + case IteratorAnnotation::kThreadX: + case IteratorAnnotation::kThreadY: + case IteratorAnnotation::kThreadZ: + ss << "bind("; + break; + case IteratorAnnotation::kNone: + break; + default: + LOG(FATAL) << "Invalid annotation " << static_cast(annotation); + break; + } + ss << CleanName(iter->var->name_hint); + switch (annotation) { + case IteratorAnnotation::kVThread: + case IteratorAnnotation::kBlockX: + case IteratorAnnotation::kBlockY: + case IteratorAnnotation::kBlockZ: + case IteratorAnnotation::kThreadX: + case IteratorAnnotation::kThreadY: + case 
IteratorAnnotation::kThreadZ: + ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast(annotation)] + << "\")"; + break; + default: + break; + } + ss << ")\n"; + ApplyToSchedule(stages, stage_to_axes); return ss.str(); } -/********** Compute Root **********/ -ComputeRootStep::ComputeRootStep(int stage_id) { - auto node = make_object(); +/********** Fuse **********/ +FuseStep::FuseStep(int stage_id, const Array& fused_ids) { + auto node = make_object(); node->stage_id = stage_id; + for (const auto& x : fused_ids) { + CHECK(x->IsInstance()); + } + node->fused_ids = fused_ids; data_ = std::move(node); } -ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) { - auto node = make_object(); +FuseStep::FuseStep(dmlc::JSONReader* reader) { + auto node = make_object(); CHECK(reader->NextArrayItem()); reader->Read(&node->stage_id); + std::vector int_list; + CHECK(reader->NextArrayItem()); + reader->Read(&int_list); + ::tvm::Array<::tvm::Integer> fused_ids; + for (const auto& i : int_list) { + fused_ids.push_back(i); + } + node->fused_ids = fused_ids; data_ = std::move(node); } -void ComputeRootStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { +void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArraySeperator(); writer->WriteString(record_prefix_str); writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(IntArrayToVector(fused_ids)); } -void ComputeRootStepNode::ApplyToState(State* state) const { +Iterator FuseStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; + size_t old_iter_size = static_cast(stage->iters.size()); - // Remove the bound information of each iterator since they may not be accurate after - // compute root - Array new_iters; - for (const Iterator& it : stage->iters) { - new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); - } - - StateNode* pstate = state->CopyOnWrite(); - pstate->stages.Set(stage_id, Stage(stage->op, 
stage->op_type, std::move(new_iters), - ComputeAtKind::kRoot, stage->attrs)); - // Update attach map - pstate->attach_map.DeleteStage(stage_id); + String new_name; + PrimExpr new_extent = 1; + IteratorKind new_iter_kind = IteratorKind::kSpecial; + + for (size_t i = 0; i < fused_ids.size(); ++i) { + if (i > 0) { + CHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1); + } + + if (i != fused_ids.size() - 1) { + const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i])) != + iter_to_attached_stage.end()) { + LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " + << "stages. State before fusion:\n" + << (*state); + } + } + + const Iterator& it = stage->iters[fused_ids[i]]; + new_name = new_name + it->name + "@"; + + if (it->range.defined() && new_extent.defined()) { + new_extent = new_extent * it->range->extent; + } else { + new_extent = PrimExpr(); + } + + if (i == 0) { + new_iter_kind = it->iter_kind; + } else { + if (new_iter_kind != it->iter_kind) { + new_iter_kind = IteratorKind::kMixed; + } + } + } + + Range range; + if (new_extent.defined()) { + range = Range::FromMinExtent(0, new_extent); + } + Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone); + Array new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + fused_ids.front()); + new_iters.push_back(new_it); + new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1, + stage->iters.end()); + + StateNode* pstate = state->CopyOnWrite(); + pstate->stages.Set(stage_id, + Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + + // Two vectors are used to represent the iterator relation before and after fuse + // The original iterators in AttachMap will be updated with the new iterators + std::vector from_iters; + std::vector to_iters; + const size_t begin_id = 
fused_ids.front(), end_id = fused_ids.back(); + for (size_t i = 0; i < old_iter_size; ++i) { + if (i <= begin_id) { + continue; + } else if (i > end_id) { + // move forward + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i - end_id + begin_id); + } else { + // move to the fused id + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, begin_id); + } + } + pstate->attach_map.UpdateIters(from_iters, to_iters); + + return new_it; } -void ComputeRootStepNode::ApplyToSchedule(Array* stages, - StageToAxesMap* stage_to_axes) const { +IterVar FuseStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; - stage.compute_root(); + const Array& axes = stage_to_axes->at(stage); + + Array to_fuse; + for (const auto& i : fused_ids) { + to_fuse.push_back(axes[i]); + } + IterVar fused_axis; + stage.fuse(to_fuse, &fused_axis); + + Array new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front()); + new_axes.push_back(fused_axis); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end()); + + stage_to_axes->Set(stage, std::move(new_axes)); stages->Set(stage_id, std::move(stage)); + return fused_axis; } -String ComputeRootStepNode::ApplyToPythonAPI(Array* stages, - StageToAxesMap* stage_to_axes) const { - std::stringstream ss; +String FuseStepNode::ApplyToPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; - ApplyToSchedule(stages, stage_to_axes); + std::stringstream to_fuse; + + for (size_t i = 0; i < fused_ids.size(); ++i) { + to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint); + if (i != fused_ids.size() - 1) { + to_fuse << ", "; + } + } + + std::stringstream ss; + const auto& fused = ApplyToSchedule(stages, stage_to_axes); + + ss << CleanName(fused->var->name_hint) << " = s[" 
<< CleanName(stage->op->name) << "].fuse(" + << to_fuse.str() << ")\n"; + return ss.str(); } -/********** Compute Inline **********/ -ComputeInlineStep::ComputeInlineStep(int stage_id) { - auto node = make_object(); +/********** Reorder **********/ +ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { + auto node = make_object(); node->stage_id = stage_id; + for (const auto& x : after_ids) { + CHECK(x->IsInstance()); + } + node->after_ids = after_ids; data_ = std::move(node); } -ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) { - auto node = make_object(); +ReorderStep::ReorderStep(dmlc::JSONReader* reader) { + auto node = make_object(); CHECK(reader->NextArrayItem()); reader->Read(&node->stage_id); + CHECK(reader->NextArrayItem()); + std::vector int_list; + reader->Read(&int_list); + ::tvm::Array<::tvm::Integer> after_ids; + for (const auto& i : int_list) { + after_ids.push_back(i); + } + node->after_ids = after_ids; data_ = std::move(node); } -void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { +void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArraySeperator(); writer->WriteString(record_prefix_str); writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(IntArrayToVector(after_ids)); } -void ComputeInlineStepNode::ApplyToState(State* state) const { +void ReorderStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; - - // Check the validity of compute_inline - for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0) - << "Invalid compute_inline: There are some other stages that are attached to the " - << "target stage"; + Array iters; + for (auto x : after_ids) { + iters.push_back(stage->iters[x]); } - - StateNode* pstate = state->CopyOnWrite(); - auto new_stage = pstate->stages[stage_id]; - new_stage.CopyOnWrite()->compute_at = 
ComputeAtKind::kInlined; - pstate->stages.Set(stage_id, std::move(new_stage)); - // Update attach map - pstate->attach_map.DeleteStage(stage_id); + state->CopyOnWrite()->stages.Set( + stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); } -void ComputeInlineStepNode::ApplyToSchedule(Array* stages, - StageToAxesMap* stage_to_axes) const { +void ReorderStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; - stage.compute_inline(); + const Array& axes = stage_to_axes->at(stage); + CHECK_EQ(after_ids.size(), axes.size()); + + Array new_axes; + new_axes.reserve(axes.size()); + for (auto i : after_ids) { + new_axes.push_back(axes[i]); + } + stage.reorder(new_axes); + + stage_to_axes->Set(stage, std::move(new_axes)); stages->Set(stage_id, std::move(stage)); } -String ComputeInlineStepNode::ApplyToPythonAPI(Array* stages, - StageToAxesMap* stage_to_axes) const { - std::stringstream ss; +String ReorderStepNode::ApplyToPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n"; + std::stringstream ss; + + ss << "s[" << CleanName(stage->op->name) << "].reorder("; + for (size_t i = 0; i < after_ids.size(); ++i) { + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + if (i != after_ids.size() - 1) { + ss << ", "; + } + } + ss << ")\n"; + ApplyToSchedule(stages, stage_to_axes); return ss.str(); } @@ -545,287 +723,176 @@ String SplitStepNode::ApplyToPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -/********** Fuse **********/ -FuseStep::FuseStep(int stage_id, const Array& fused_ids) { - auto node = make_object(); +/********** Primitives working on multiple stages **********/ + +/********** Compute At **********/ +ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, 
int target_iter_id) { + auto node = make_object(); node->stage_id = stage_id; - for (const auto& x : fused_ids) { - CHECK(x->IsInstance()); - } - node->fused_ids = fused_ids; + node->target_stage_id = target_stage_id; + node->target_iter_id = target_iter_id; data_ = std::move(node); } -FuseStep::FuseStep(dmlc::JSONReader* reader) { - auto node = make_object(); +ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) { + auto node = make_object(); CHECK(reader->NextArrayItem()); reader->Read(&node->stage_id); - std::vector int_list; CHECK(reader->NextArrayItem()); - reader->Read(&int_list); - ::tvm::Array<::tvm::Integer> fused_ids; - for (const auto& i : int_list) { - fused_ids.push_back(i); - } - node->fused_ids = fused_ids; + reader->Read(&node->target_stage_id); + CHECK(reader->NextArrayItem()); + reader->Read(&node->target_iter_id); data_ = std::move(node); } -void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { +void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArraySeperator(); writer->WriteString(record_prefix_str); writer->WriteArrayItem(stage_id); - writer->WriteArrayItem(IntArrayToVector(fused_ids)); + writer->WriteArrayItem(target_stage_id); + writer->WriteArrayItem(target_iter_id); } - -Iterator FuseStepNode::ApplyToState(State* state) const { +void ComputeAtStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; - size_t old_iter_size = static_cast(stage->iters.size()); - String new_name; - PrimExpr new_extent = 1; - IteratorKind new_iter_kind = IteratorKind::kSpecial; + // Remove the bound information of each iterator since they may not be accurate after + // compute at + Array new_iters; + for (const Iterator& it : stage->iters) { + new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); + } - for (size_t i = 0; i < fused_ids.size(); ++i) { - if (i > 0) { - CHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1); - } + StateNode* pstate = 
state->CopyOnWrite(); + pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + ComputeAtKind::kIter, stage->attrs)); + // Update attach map + pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id); +} - if (i != fused_ids.size() - 1) { - const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i])) != - iter_to_attached_stage.end()) { - LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " - << "stages. State before fusion:\n" - << (*state); - } - } - - const Iterator& it = stage->iters[fused_ids[i]]; - new_name = new_name + it->name + "@"; - - if (it->range.defined() && new_extent.defined()) { - new_extent = new_extent * it->range->extent; - } else { - new_extent = PrimExpr(); - } +void ComputeAtStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + te::Stage stage = (*stages)[stage_id]; + const auto& target_stage = (*stages)[target_stage_id]; + const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id]; + stage.compute_at(target_stage, target_axis); - if (i == 0) { - new_iter_kind = it->iter_kind; - } else { - if (new_iter_kind != it->iter_kind) { - new_iter_kind = IteratorKind::kMixed; - } - } - } + stages->Set(stage_id, std::move(stage)); +} - Range range; - if (new_extent.defined()) { - range = Range::FromMinExtent(0, new_extent); - } - Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone); - Array new_iters; - new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + fused_ids.front()); - new_iters.push_back(new_it); - new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1, - stage->iters.end()); +String ComputeAtStepNode::ApplyToPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; 
+ const auto& target_stage = (*stages)[target_stage_id]; + ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name) + << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n"; + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} - StateNode* pstate = state->CopyOnWrite(); - pstate->stages.Set(stage_id, - Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); +/********** Compute Inline **********/ +ComputeInlineStep::ComputeInlineStep(int stage_id) { + auto node = make_object(); + node->stage_id = stage_id; + data_ = std::move(node); +} - // Two vectors are used to represent the iterator relation before and after fuse - // The original iterators in AttachMap will be updated with the new iterators - std::vector from_iters; - std::vector to_iters; - const size_t begin_id = fused_ids.front(), end_id = fused_ids.back(); - for (size_t i = 0; i < old_iter_size; ++i) { - if (i <= begin_id) { - continue; - } else if (i > end_id) { - // move forward - from_iters.emplace_back(stage_id, i); - to_iters.emplace_back(stage_id, i - end_id + begin_id); - } else { - // move to the fused id - from_iters.emplace_back(stage_id, i); - to_iters.emplace_back(stage_id, begin_id); - } - } - pstate->attach_map.UpdateIters(from_iters, to_iters); +ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) { + auto node = make_object(); + CHECK(reader->NextArrayItem()); + reader->Read(&node->stage_id); + data_ = std::move(node); +} - return new_it; +void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); } -IterVar FuseStepNode::ApplyToSchedule(Array* stages, - StageToAxesMap* stage_to_axes) const { - auto stage = (*stages)[stage_id]; - const Array& axes = stage_to_axes->at(stage); +void ComputeInlineStepNode::ApplyToState(State* state) const { + 
const Stage& stage = (*state)->stages[stage_id]; - Array to_fuse; - for (const auto& i : fused_ids) { - to_fuse.push_back(axes[i]); + // Check the validity of compute_inline + for (size_t i = 0; i < stage->iters.size(); ++i) { + CHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0) + << "Invalid compute_inline: There are some other stages that are attached to the " + << "target stage"; } - IterVar fused_axis; - stage.fuse(to_fuse, &fused_axis); - Array new_axes; - new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front()); - new_axes.push_back(fused_axis); - new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end()); + StateNode* pstate = state->CopyOnWrite(); + auto new_stage = pstate->stages[stage_id]; + new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined; + pstate->stages.Set(stage_id, std::move(new_stage)); + // Update attach map + pstate->attach_map.DeleteStage(stage_id); +} - stage_to_axes->Set(stage, std::move(new_axes)); +void ComputeInlineStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + auto stage = (*stages)[stage_id]; + stage.compute_inline(); stages->Set(stage_id, std::move(stage)); - return fused_axis; } -String FuseStepNode::ApplyToPythonAPI(Array* stages, - StageToAxesMap* stage_to_axes) const { - const auto& stage = (*stages)[stage_id]; - std::stringstream to_fuse; - - for (size_t i = 0; i < fused_ids.size(); ++i) { - to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint); - if (i != fused_ids.size() - 1) { - to_fuse << ", "; - } - } - +String ComputeInlineStepNode::ApplyToPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { std::stringstream ss; - const auto& fused = ApplyToSchedule(stages, stage_to_axes); - - ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse(" - << to_fuse.str() << ")\n"; - + const auto& stage = (*stages)[stage_id]; + ss << "s[" 
<< CleanName(stage->op->name) << "].compute_inline()\n"; + ApplyToSchedule(stages, stage_to_axes); return ss.str(); } -/********** Annotation **********/ -AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { - auto node = make_object(); +/********** Compute Root **********/ +ComputeRootStep::ComputeRootStep(int stage_id) { + auto node = make_object(); node->stage_id = stage_id; - node->iter_id = iter_id; - node->annotation = ann; data_ = std::move(node); } -AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) { - auto node = make_object(); +ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) { + auto node = make_object(); CHECK(reader->NextArrayItem()); reader->Read(&node->stage_id); - CHECK(reader->NextArrayItem()); - reader->Read(&node->iter_id); - CHECK(reader->NextArrayItem()); - int int_val; - reader->Read(&int_val); - node->annotation = IteratorAnnotation(int_val); data_ = std::move(node); } -void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { +void ComputeRootStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArraySeperator(); writer->WriteString(record_prefix_str); writer->WriteArrayItem(stage_id); - writer->WriteArrayItem(iter_id); - writer->WriteArrayItem(static_cast(annotation)); } -Iterator AnnotationStepNode::ApplyToState(State* state) const { +void ComputeRootStepNode::ApplyToState(State* state) const { const Stage& stage = (*state)->stages[stage_id]; - Iterator it = stage->iters[iter_id]; - CHECK(it->annotation == IteratorAnnotation::kNone); - Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation); - Stage new_stage = stage; - new_stage.CopyOnWrite()->iters.Set(iter_id, new_it); - state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage)); - return new_it; -} - -void AnnotationStepNode::ApplyToSchedule(Array* stages, - StageToAxesMap* stage_to_axes) const { - te::Stage stage = (*stages)[stage_id]; - const Array& axes = (*stage_to_axes)[stage]; 
- - switch (annotation) { - case IteratorAnnotation::kUnroll: - stage.unroll(axes[iter_id]); - break; - case IteratorAnnotation::kVectorize: - stage.vectorize(axes[iter_id]); - break; - case IteratorAnnotation::kParallel: - stage.parallel(axes[iter_id]); - break; - case IteratorAnnotation::kVThread: - case IteratorAnnotation::kBlockX: - case IteratorAnnotation::kBlockY: - case IteratorAnnotation::kBlockZ: - case IteratorAnnotation::kThreadX: - case IteratorAnnotation::kThreadY: - case IteratorAnnotation::kThreadZ: - stage.bind(axes[iter_id], - te::thread_axis(Range(), IteratorAnnotationString[static_cast(annotation)])); - break; - case IteratorAnnotation::kNone: - break; - default: - LOG(FATAL) << "Invalid Annotation " << static_cast(annotation); - break; + // Remove the bound information of each iterator since they may not be accurate after + // compute root + Array new_iters; + for (const Iterator& it : stage->iters) { + new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation)); } + StateNode* pstate = state->CopyOnWrite(); + pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + ComputeAtKind::kRoot, stage->attrs)); + // Update attach map + pstate->attach_map.DeleteStage(stage_id); +} + +void ComputeRootStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + auto stage = (*stages)[stage_id]; + stage.compute_root(); stages->Set(stage_id, std::move(stage)); } -String AnnotationStepNode::ApplyToPythonAPI(Array* stages, - StageToAxesMap* stage_to_axes) const { +String ComputeRootStepNode::ApplyToPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - const auto& iter = (*stage_to_axes)[stage][iter_id]; - - ss << "s[" << CleanName(stage->op->name) << "]."; - switch (annotation) { - case IteratorAnnotation::kUnroll: - ss << "unroll("; - break; - case IteratorAnnotation::kVectorize: - ss << "vectorize("; - break; - 
case IteratorAnnotation::kParallel: - ss << "parallel("; - break; - case IteratorAnnotation::kVThread: - case IteratorAnnotation::kBlockX: - case IteratorAnnotation::kBlockY: - case IteratorAnnotation::kBlockZ: - case IteratorAnnotation::kThreadX: - case IteratorAnnotation::kThreadY: - case IteratorAnnotation::kThreadZ: - ss << "bind("; - break; - case IteratorAnnotation::kNone: - break; - default: - LOG(FATAL) << "Invalid annotation " << static_cast(annotation); - break; - } - ss << CleanName(iter->var->name_hint); - switch (annotation) { - case IteratorAnnotation::kVThread: - case IteratorAnnotation::kBlockX: - case IteratorAnnotation::kBlockY: - case IteratorAnnotation::kBlockZ: - case IteratorAnnotation::kThreadX: - case IteratorAnnotation::kThreadY: - case IteratorAnnotation::kThreadZ: - ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast(annotation)] - << "\")"; - break; - default: - break; - } - ss << ")\n"; - + ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; ApplyToSchedule(stages, stage_to_axes); return ss.str(); } diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index d9465259a5f5..62d64c168a6d 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -20,23 +20,27 @@ /*! * \file auto_scheduler/transform_step.h * \brief Transformation steps. For each schedule primitive, there is a corresponding transform - * step. The implementation of each step consists of 2 parts: - * - transform_step.cc: How each step interacts with TE and TE's schedule primitives - * - loop_state.cc: How each step updates LoopState + * step. * * \note To add a new transform step: * Take fuse step for example: - * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction - * function `FuseStep::FuseStep(...)` in `transform_steps.cc` - * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::ApplyToPythonAPI`. + * 1. 
Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first + * construction function `FuseStep::FuseStep()` in `transform_steps.cc`. + * 2. Implement `FuseStepNode::ApplyToSchedule()` and `FuseStepNode::ApplyToPythonAPI()`. * - In these two functions you need to lower this step with tvm's te schedule API - * 3. Implement `State::fuse` and `State::DoFuseStep`. + * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`. * - In these two functions you need to incrementally update all data structures in State with - * CopyOnWrite style - * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works. - * 5. Add log record serialization support in `struct Handler>` - * in `record.cc`. - * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test. + * CopyOnWrite style. + * 4. Add your step implementation to `StepApplyToState`, `StepApplyToSchedule` and + * `StepApplyToPythonAPI`, make sure it works. + * 5. Log record serialization support: + * - Add `FuseStepNode::WriteToRecord` which takes a mutable JSONWriter pointer as input and + * output the record to it. + * - Add another construction function that takes a mutable JSONReader as input, this will get a + * step record from the reader and create the step. + * - Add the step implementation to `StepReadFromRecord`. + * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test, the test should + * at lease consists of two parts: the functional test and the record serialization test. */ #ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ @@ -140,8 +144,6 @@ class Iterator : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); }; -class State; - /*! * \brief The base class of transformation steps. Each step has its corresponding tvm.te * schedule primitives. 
@@ -170,249 +172,231 @@ class Step : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); }; +// Forward declaration +class State; +class ComputeDAG; + /*! * \brief Read a step record from JSONReader and create the corresponding step. * \param reader The input JSONReader. */ Step StepReadFromRecord(dmlc::JSONReader* reader); -/*! \brief Reorder step that corresponds to te::Stage::reorder */ -class ReorderStepNode : public StepNode { - public: - /*! - * \brief The iterator ids after reorder. - * This array should specify the order of all iterators. - */ - Array after_ids; - - void WriteToRecord(dmlc::JSONWriter* writer) const final; - - /*! - * \brief Apply the current step to State - * \param state A mutable pointer to State. - */ - void ApplyToState(State* state) const; - - /*! - * \brief Apply the current step to tvm.schedule - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. - */ - void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - - /*! - * \brief Print step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. - * \return Python schedule code. - */ - String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - - static constexpr const char* record_prefix_str = "RE"; - - static constexpr const char* _type_key = "auto_scheduler.ReorderStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); -}; +/*! + * \brief Apply the step to State. + * \param step The step to be applied to State. + * \param state A mutable pointer to State. + * \param dag The original ComputeDAG of this state. + * \return The iterator result after annotate. + */ +void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); /*! - * \brief Managed reference to ReorderStepNode. - * \sa ReorderStepNode + * \brief Apply the step to tvm.schedule. 
+ * \param step The step to be applied to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
 */
-class ReorderStep : public Step {
- public:
- /*!
- * \brief The constructor.
- * \param stage_id The index of the stage to be reordered.
- * \param after_ids The expected indexes of the iterators after reorder.
- */
- ReorderStep(int stage_id, const Array& after_ids);
+void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes);

- /*!
- * \brief The constructor used to read a step record from JSONReader and create the
- * corresponding step.
- * \param reader The input JSONReader.
- */
- explicit ReorderStep(dmlc::JSONReader* reader);
+/*!
+ * \brief Print the step as equivalent python schedule API.
+ * \param step The step to be applied to python API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+String StepApplyToPythonAPI(const Step& step, Array* stages,
+ StageToAxesMap* stage_to_axes);

- TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode);
-};
+/********** Primitives working on single stage **********/

-/*! \brief Compute at step that corresponds to te::Stage::compute_at */
-class ComputeAtStepNode : public StepNode {
+/*!
+ * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
+ * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::unroll, te::Stage::bind)
+ */
+class AnnotationStepNode : public StepNode {
 public:
- /*! \brief The index of stage that this step will compute at to. */
- int target_stage_id;
- /*! \brief The index of iterator in target stage that this step will compute at to. */
- int target_iter_id;
+ /*! \brief The index of the iterator to add annotation. */
+ int iter_id;
+ /*! \brief The annotation type of this step.
*/ + IteratorAnnotation annotation; void WriteToRecord(dmlc::JSONWriter* writer) const final; /*! - * \brief Apply the current step to State + * \brief Apply the current step to State. * \param state A mutable pointer to State. - * \note After compute_at, we need careful dependency analysis to compute the accurate bound - * information. However, it is relatively expensive and complicated, so we just fill "None" as - * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + * \return The iterator result after annotate. */ - void ApplyToState(State* state) const; + Iterator ApplyToState(State* state) const; /*! - * \brief Apply the current step to tvm.schedule + * \brief Apply the current step to tvm.schedule. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! - * \brief Print step as equivalent python schedule API. + * \brief Print the current step as equivalent python schedule API. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* record_prefix_str = "CA"; + static constexpr const char* record_prefix_str = "AN"; - static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); + static constexpr const char* _type_key = "auto_scheduler.AnnotationStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); }; /*! - * \brief Managed reference to ComputeAtStepNode. - * \sa ComputeAtStepNode + * \brief Managed reference to AnnotationStepNode. + * \sa AnnotationStepNode */ -class ComputeAtStep : public Step { +class AnnotationStep : public Step { public: /*! * \brief The constructor. 
- * \param stage_id The index of the stage to be compute at. - * \param target_stage_id The index of stage that this step will compute at to. - * \param target_iter_id The index of iterator in target stage that this step will compute at to. + * \param stage_id The index of the stage to add annotation. + * \param iter_id The index of the iterator to add annotation. + * \param ann The annotation type of this step. */ - ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); + AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); /*! * \brief The constructor used to read a step record from JSONReader and create the * corresponding step. * \param reader The input JSONReader. */ - explicit ComputeAtStep(dmlc::JSONReader* reader); + explicit AnnotationStep(dmlc::JSONReader* reader); - TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); + TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); }; -/*! \brief Compute root step that corresponds to te::Stage::compute_root */ -class ComputeRootStepNode : public StepNode { +/*! \brief Fuse step that corresponds to te::Stage::fuse */ +class FuseStepNode : public StepNode { public: + /*! \brief The ids of iterators to fuse. */ + Array fused_ids; + void WriteToRecord(dmlc::JSONWriter* writer) const final; /*! - * \brief Apply the current step to State + * \brief Apply the current step to State. * \param state A mutable pointer to State. - * \note After compute_at, we need careful dependency analysis to compute the accurate bound - * information. However, it is relatively expensive and complicated, so we just fill "None" as - * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + * \return The iterator result after fuse. + * \note If the iterators to be fused have stages attached at them(by compute_at), the fused + * result will become the new attach point. 
*/ - void ApplyToState(State* state) const; + Iterator ApplyToState(State* state) const; /*! - * \brief Apply the current step to tvm.schedule + * \brief Apply the current step to tvm.schedule. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator result after fuse. */ - void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! - * \brief Print step as equivalent python schedule API. + * \brief Print the current step as equivalent python schedule API. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* record_prefix_str = "CR"; + static constexpr const char* record_prefix_str = "FU"; - static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); + static constexpr const char* _type_key = "auto_scheduler.FuseStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); }; /*! - * \brief Managed reference to ComputeRootStepNode. - * \sa ComputeRootStepNode + * \brief Managed reference to FuseStepNode. + * \sa FuseStepNode */ -class ComputeRootStep : public Step { +class FuseStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be compute root + * \param stage_id The index of the stage to be fused. + * \param fused_ids The index of the iterators to be fused. */ - explicit ComputeRootStep(int stage_id); + FuseStep(int stage_id, const Array& fused_ids); /*! * \brief The constructor used to read a step record from JSONReader and create the * corresponding step. * \param reader The input JSONReader. 
*/ - explicit ComputeRootStep(dmlc::JSONReader* reader); + explicit FuseStep(dmlc::JSONReader* reader); - TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); + TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; -/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ -class ComputeInlineStepNode : public StepNode { +/*! \brief Reorder step that corresponds to te::Stage::reorder */ +class ReorderStepNode : public StepNode { public: + /*! + * \brief The iterator ids after reorder. + * This array should specify the order of all iterators. + */ + Array after_ids; + void WriteToRecord(dmlc::JSONWriter* writer) const final; + /*! - * \brief Apply the current step to State + * \brief Apply the current step to State. * \param state A mutable pointer to State. */ void ApplyToState(State* state) const; /*! - * \brief Apply the current step to tvm.schedule + * \brief Apply the current step to tvm.schedule. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. - * \return The iterator result after fuse. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! - * \brief Print step as equivalent python schedule API. + * \brief Print the current step as equivalent python schedule API. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* record_prefix_str = "CI"; + static constexpr const char* record_prefix_str = "RE"; - static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); + static constexpr const char* _type_key = "auto_scheduler.ReorderStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); }; /*! - * \brief Managed reference to ComputeInlineStepNode. 
- * \sa ComputeInlineStepNode + * \brief Managed reference to ReorderStepNode. + * \sa ReorderStepNode */ -class ComputeInlineStep : public Step { +class ReorderStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be compute inline. + * \param stage_id The index of the stage to be reordered. + * \param after_ids The expected indexes of the iterators after reorder. */ - explicit ComputeInlineStep(int stage_id); + ReorderStep(int stage_id, const Array& after_ids); /*! * \brief The constructor used to read a step record from JSONReader and create the * corresponding step. * \param reader The input JSONReader. */ - explicit ComputeInlineStep(dmlc::JSONReader* reader); + explicit ReorderStep(dmlc::JSONReader* reader); - TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); + TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; /*! @@ -434,8 +418,9 @@ class SplitStepNode : public StepNode { bool inner_to_outer; void WriteToRecord(dmlc::JSONWriter* writer) const final; + /*! - * \brief Apply the current step to State + * \brief Apply the current step to State. * \param state A mutable pointer to State. * \return The iterator results after split. * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner @@ -444,7 +429,7 @@ class SplitStepNode : public StepNode { Array ApplyToState(State* state) const; /*! - * \brief Apply the current step to tvm.schedule + * \brief Apply the current step to tvm.schedule. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator results after split. @@ -453,7 +438,7 @@ class SplitStepNode : public StepNode { StageToAxesMap* stage_to_axes) const; /*! - * \brief Print step as equivalent python schedule API. + * \brief Print the current step as equivalent python schedule API. * \param stages A pointer to a `te::Stage` Array. 
* \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. @@ -493,88 +478,145 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; -/*! \brief Fuse step that corresponds to te::Stage::fuse */ -class FuseStepNode : public StepNode { +/********** Primitives working on multiple stages **********/ + +/*! \brief Compute at step that corresponds to te::Stage::compute_at */ +class ComputeAtStepNode : public StepNode { public: - /*! \brief The ids of iterators to fuse. */ - Array fused_ids; + /*! \brief The index of stage that this step will compute at to. */ + int target_stage_id; + /*! \brief The index of iterator in target stage that this step will compute at to. */ + int target_iter_id; void WriteToRecord(dmlc::JSONWriter* writer) const final; + /*! - * \brief Apply the current step to State + * \brief Apply the current step to State. * \param state A mutable pointer to State. - * \return The iterator result after fuse. - * \note If the iterators to be fused have stages attached at them(by compute_at), the fused - * result will become the new attach point. + * \note After compute_at, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ - Iterator ApplyToState(State* state) const; + void ApplyToState(State* state) const; /*! - * \brief Apply the current step to tvm.schedule + * \brief Apply the current step to tvm.schedule. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. - * \return The iterator result after fuse. */ - tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! 
- * \brief Print step as equivalent python schedule API. + * \brief Print the current step as equivalent python schedule API. * \param stages A pointer to a `te::Stage` Array. * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* record_prefix_str = "FU"; + static constexpr const char* record_prefix_str = "CA"; - static constexpr const char* _type_key = "auto_scheduler.FuseStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); + static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); }; /*! - * \brief Managed reference to FuseStepNode. - * \sa FuseStepNode + * \brief Managed reference to ComputeAtStepNode. + * \sa ComputeAtStepNode */ -class FuseStep : public Step { +class ComputeAtStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be fused. - * \param fused_ids The index of the iterators to be fused. + * \param stage_id The index of the stage to be compute at. + * \param target_stage_id The index of stage that this step will compute at to. + * \param target_iter_id The index of iterator in target stage that this step will compute at to. */ - FuseStep(int stage_id, const Array& fused_ids); + ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); /*! * \brief The constructor used to read a step record from JSONReader and create the * corresponding step. * \param reader The input JSONReader. */ - explicit FuseStep(dmlc::JSONReader* reader); + explicit ComputeAtStep(dmlc::JSONReader* reader); - TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); + TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); +}; + +/*! 
\brief Compute inline step that corresponds to te::Stage::compute_inline */ +class ComputeInlineStepNode : public StepNode { + public: + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return The iterator result after fuse. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + + static constexpr const char* record_prefix_str = "CI"; + + static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); }; /*! - * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. - * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::vectorize, te::Stage::bind) + * \brief Managed reference to ComputeInlineStepNode. + * \sa ComputeInlineStepNode */ -class AnnotationStepNode : public StepNode { +class ComputeInlineStep : public Step { public: - /*! \brief The index of the iterator to add annotation. */ - int iter_id; - /*! \brief The annotation type of this step. */ - IteratorAnnotation annotation; + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be compute inline. + */ + explicit ComputeInlineStep(int stage_id); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. 
+ */
+ explicit ComputeInlineStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode);
+};
+/*! \brief Compute root step that corresponds to te::Stage::compute_root */
+class ComputeRootStepNode : public StepNode {
+ public:
 void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
 /*!
- * \brief Apply the current step to State
+ * \brief Apply the current step to State.
 * \param state A mutable pointer to State.
- * \return The iterator result after annotate.
+ * \note After compute_root, we need careful dependency analysis to compute the accurate bound
+ * information. However, it is relatively expensive and complicated, so we just fill "None" as
+ * bound for the newly created iterators.
+ * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
 */
- Iterator ApplyToState(State* state) const;
+ void ApplyToState(State* state) const;

 /*!
- * \brief Apply the current step to tvm.schedule
+ * \brief Apply the current step to tvm.schedule.
 * \param stages A pointer to a `te::Stage` Array.
 * \param stage_to_axes A pointer to a StageToAxesMap.
 * \return The iterator result after fuse.
@@ -582,41 +624,39 @@ class AnnotationStepNode : public StepNode {
 void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const;

 /*!
- * \brief Print step as equivalent python schedule API.
+ * \brief Print the current step as equivalent python schedule API.
 * \param stages A pointer to a `te::Stage` Array.
 * \param stage_to_axes A pointer to a StageToAxesMap.
 * \return Python schedule code.
*/ String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* record_prefix_str = "AN"; + static constexpr const char* record_prefix_str = "CR"; - static constexpr const char* _type_key = "auto_scheduler.AnnotationStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); + static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); }; /*! - * \brief Managed reference to AnnotationStepNode. - * \sa AnnotationStepNode + * \brief Managed reference to ComputeRootStepNode. + * \sa ComputeRootStepNode */ -class AnnotationStep : public Step { +class ComputeRootStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to add annotation. - * \param iter_id The index of the iterator to add annotation. - * \param ann The annotation type of this step. + * \param stage_id The index of the stage to be compute root */ - AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); + explicit ComputeRootStep(int stage_id); /*! * \brief The constructor used to read a step record from JSONReader and create the * corresponding step. * \param reader The input JSONReader. 
*/ - explicit AnnotationStep(dmlc::JSONReader* reader); + explicit ComputeRootStep(dmlc::JSONReader* reader); - TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); + TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); }; } // namespace auto_scheduler diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 23b738a8478f..5fb5349e83c8 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -60,7 +60,7 @@ def test_record(): s.parallel(C, s[C].iters[0]) # Thread bind(The blockIdx & threadIdx are used in GPU, just for record testing here) s.bind(C, s[C].iters[1], "blockIdx.x") - s.bind(C, s[C].iters[2], "threadIdx.y") + s.bind(C, s[C].iters[2], "threadIdx.z") s.bind(C, s[C].iters[3], "vthread") # Unroll s.unroll(C, s[C].iters[4]) From 0cb3dbd0fe2b7ed007bec3dfcafa2f824219c161 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 20 Jul 2020 14:19:29 +0800 Subject: [PATCH 13/14] Update the order of state api --- python/tvm/auto_scheduler/loop_state.py | 260 ++++++++++++------------ src/auto_scheduler/loop_state.cc | 184 ++++++++--------- src/auto_scheduler/loop_state.h | 119 +++++------ 3 files changed, 283 insertions(+), 280 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 238851ad6d42..ab041cf4a43d 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -126,108 +126,100 @@ def stage_ops(self): """ return [stage.op for stage in self.stages] - def reorder(self, stage, order): - """ Schedule primitive corresponds to te.reorder. - - Parameters - ---------- - stage : Union[int, Operation, Tensor] - The Stage to be reordered, which can be specified by the integer index, Operation, - or output tensor of the stage. - order : List[Iterator] - Iterators in the expected order. 
- """ - self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage), - order) - - def compute_at(self, stage, target_stage, target_iter): - """ Schedule primitive corresponds to te.compute_at. + def bind(self, stage, iterator, thread_name): + """ Schedule primitive corresponds to te.bind. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute at, which can be specified by the integer index, Operation, - or output tensor of the stage. - target_stage : Union[int, Operation, Tensor] - The target stage of compute_at, which can be specified by the integer index, Operation, + The Stage to be binded, which can be specified by the integer index, Operation, or output tensor of the stage. - target_iter : Iterator - The target Iterator of compute_at. + iterator : Iterator + The iterator to be binded. + thread_name : str + The thread type to be binded. Candidates: + - vthread + - blockIdx.x + - threadIdx.x + - blockIdx.y + - threadIdx.y + - blockIdx.z + - threadIdx.z - Notes - ----- - After compute_at, we need careful dependency analysis to compute the accurate bound - information. However, it is relatively expensive and complicated, so we just fill "None" - as bound for the newly created iterators. - Call ComputeDAG::InferBound on the returned state to get the complete bound information. + Returns + ------- + res_it : Iterator + The binded Iterator. """ - self.state_object = _ffi_api.StateComputeAt(self.state_object, - self._resolve_stage_id(stage), - self._resolve_stage_id(target_stage), - target_iter) + if not thread_name in State.ANNOTATION_TRANS_TABLE.keys(): + raise ValueError("Invalid thread_name: ", thread_name) - def compute_root(self, stage): - """ Schedule primitive corresponds to te.compute_root. 
+ self.state_object, res = _ffi_api.StateBind(self.state_object, + self._resolve_stage_id(stage), iterator, + State.ANNOTATION_TRANS_TABLE[thread_name]) + return res + + def parallel(self, stage, iterator): + """ Schedule primitive corresponds to te.parallel. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute root, which can be specified by the integer index, Operation, + The Stage to be paralleled, which can be specified by the integer index, Operation, or output tensor of the stage. + iterator : Iterator + The iterator to be paralleled. - Notes - ----- - After compute_root, we need careful dependency analysis to compute the accurate bound - information. However, it is relatively expensive and complicated, so we just fill "None" - as bound for the newly created iterators. - Call ComputeDAG::InferBound on the returned state to get the complete bound information. + Returns + ------- + res_it : Iterator + The paralleled Iterator. """ - self.state_object = _ffi_api.StateComputeRoot(self.state_object, - self._resolve_stage_id(stage)) + self.state_object, res = _ffi_api.StateParallel(self.state_object, + self._resolve_stage_id(stage), iterator) + return res - def compute_inline(self, stage): - """ Schedule primitive corresponds to te.compute_inline. + def unroll(self, stage, iterator, max_unroll=None): + """ Schedule primitive corresponds to te.unroll. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute inlined, which can be specified by the integer index, Operation, + The Stage to be unrolled, which can be specified by the integer index, Operation, or output tensor of the stage. - """ - self.state_object = _ffi_api.StateComputeInline(self.state_object, - self._resolve_stage_id(stage)) + iterator : Iterator + The iterator to be unrolled. + max_unroll : Optional[int] + The max unroll limit. Iterator with extent larger than this limit will be skipped. 
- def split(self, stage, iterator, lengths, inner_to_outer=True): - """ Schedule primitive corresponds to te.split. + Returns + ------- + res_it : Iterator + The unrolled Iterator. + """ + self.state_object, res = _ffi_api.StateUnroll(self.state_object, + self._resolve_stage_id(stage), iterator, + max_unroll if max_unroll else -1) + return res - This API supports multiple split factors. (e.g. with 2 split factors, the original iterator - will be split to 3 parts, use `inner_to_outer` to control the split order) + def vectorize(self, stage, iterator): + """ Schedule primitive corresponds to te.vectorize. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be split, which can be specified by the integer index, Operation, + The Stage to be vectorized, which can be specified by the integer index, Operation, or output tensor of the stage. iterator : Iterator - The iterator to be split. - lengths: List[int] - The multiple split factors. Can be None to be filled by search policy. - inner_to_outer: boolean = True - Whether the factor go from inner to outer, or from outer to inner. + The iterator to be vectorized. Returns ------- - res_its : List[Iterator] - The splitted new Iterators. - - Notes - ----- - If we do split on an iterator which has stages attached at it(by compute_at), the inner - most iterator of split results will become the new attach point. + res_it : Iterator + The vectorized Iterator. """ - self.state_object, res = _ffi_api.StateSplit(self.state_object, - self._resolve_stage_id(stage), - iterator, lengths, inner_to_outer) + self.state_object, res = _ffi_api.StateVectorize(self.state_object, + self._resolve_stage_id(stage), iterator) return res def fuse(self, stage, iters): @@ -255,101 +247,109 @@ def fuse(self, stage, iters): self._resolve_stage_id(stage), iters) return res - def vectorize(self, stage, iterator): - """ Schedule primitive corresponds to te.vectorize. 
+ def reorder(self, stage, order): + """ Schedule primitive corresponds to te.reorder. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be vectorized, which can be specified by the integer index, Operation, + The Stage to be reordered, which can be specified by the integer index, Operation, or output tensor of the stage. - iterator : Iterator - The iterator to be vectorized. - - Returns - ------- - res_it : Iterator - The vectorized Iterator. + order : List[Iterator] + Iterators in the expected order. """ - self.state_object, res = _ffi_api.StateVectorize(self.state_object, - self._resolve_stage_id(stage), iterator) - return res + self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage), + order) - def parallel(self, stage, iterator): - """ Schedule primitive corresponds to te.parallel. + def split(self, stage, iterator, lengths, inner_to_outer=True): + """ Schedule primitive corresponds to te.split. + + This API supports multiple split factors. (e.g. with 2 split factors, the original iterator + will be split to 3 parts, use `inner_to_outer` to control the split order) Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be paralleled, which can be specified by the integer index, Operation, + The Stage to be split, which can be specified by the integer index, Operation, or output tensor of the stage. iterator : Iterator - The iterator to be paralleled. + The iterator to be split. + lengths: List[int] + The multiple split factors. Can be None to be filled by search policy. + inner_to_outer: boolean = True + Whether the factor go from inner to outer, or from outer to inner. Returns ------- - res_it : Iterator - The paralleled Iterator. + res_its : List[Iterator] + The splitted new Iterators. + + Notes + ----- + If we do split on an iterator which has stages attached at it(by compute_at), the inner + most iterator of split results will become the new attach point. 
""" - self.state_object, res = _ffi_api.StateParallel(self.state_object, - self._resolve_stage_id(stage), iterator) + self.state_object, res = _ffi_api.StateSplit(self.state_object, + self._resolve_stage_id(stage), + iterator, lengths, inner_to_outer) return res - def unroll(self, stage, iterator, max_unroll=None): - """ Schedule primitive corresponds to te.unroll. + def compute_at(self, stage, target_stage, target_iter): + """ Schedule primitive corresponds to te.compute_at. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be unrolled, which can be specified by the integer index, Operation, + The Stage to be compute at, which can be specified by the integer index, Operation, or output tensor of the stage. - iterator : Iterator - The iterator to be unrolled. - max_unroll : Optional[int] - The max unroll limit. Iterator with extent larger than this limit will be skipped. + target_stage : Union[int, Operation, Tensor] + The target stage of compute_at, which can be specified by the integer index, Operation, + or output tensor of the stage. + target_iter : Iterator + The target Iterator of compute_at. - Returns - ------- - res_it : Iterator - The unrolled Iterator. + Notes + ----- + After compute_at, we need careful dependency analysis to compute the accurate bound + information. However, it is relatively expensive and complicated, so we just fill "None" + as bound for the newly created iterators. + Call ComputeDAG::InferBound on the returned state to get the complete bound information. """ - self.state_object, res = _ffi_api.StateUnroll(self.state_object, - self._resolve_stage_id(stage), iterator, - max_unroll if max_unroll else -1) - return res + self.state_object = _ffi_api.StateComputeAt(self.state_object, + self._resolve_stage_id(stage), + self._resolve_stage_id(target_stage), + target_iter) - def bind(self, stage, iterator, thread_name): - """ Schedule primitive corresponds to te.bind. 
+ def compute_inline(self, stage): + """ Schedule primitive corresponds to te.compute_inline. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be binded, which can be specified by the integer index, Operation, + The Stage to be compute inlined, which can be specified by the integer index, Operation, or output tensor of the stage. - iterator : Iterator - The iterator to be binded. - thread_name : str - The thread type to be binded. Candidates: - - vthread - - blockIdx.x - - threadIdx.x - - blockIdx.y - - threadIdx.y - - blockIdx.z - - threadIdx.z - - Returns - ------- - res_it : Iterator - The binded Iterator. """ - if not thread_name in State.ANNOTATION_TRANS_TABLE.keys(): - raise ValueError("Invalid thread_name: ", thread_name) + self.state_object = _ffi_api.StateComputeInline(self.state_object, + self._resolve_stage_id(stage)) - self.state_object, res = _ffi_api.StateBind(self.state_object, - self._resolve_stage_id(stage), iterator, - State.ANNOTATION_TRANS_TABLE[thread_name]) - return res + def compute_root(self, stage): + """ Schedule primitive corresponds to te.compute_root. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be compute root, which can be specified by the integer index, Operation, + or output tensor of the stage. + + Notes + ----- + After compute_root, we need careful dependency analysis to compute the accurate bound + information. However, it is relatively expensive and complicated, so we just fill "None" + as bound for the newly created iterators. + Call ComputeDAG::InferBound on the returned state to get the complete bound information. + """ + self.state_object = _ffi_api.StateComputeRoot(self.state_object, + self._resolve_stage_id(stage)) def copy(self): """ Do deep copy of this State. 
""" diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 8634a75a0bac..bfe547864ed1 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -163,43 +163,47 @@ State::State(const Array& ops) { } /********** Schedule primitives apis for state **********/ -void State::reorder(int stage_id, const Array& order) { +Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; - CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " - << "should be specified"; - Array after_ids; - GetIndices(stage->iters, order, &after_ids); - ReorderStep step = ReorderStep(stage_id, after_ids); + if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) { + LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, " + << "kThreadX, kThreadY, kBlockZ, kThreadZ"; + } + AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); - step->ApplyToState(this); + return step->ApplyToState(this); } -void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { - const Stage& target_stage = operator->()->stages[target_stage_id]; - ComputeAtStep step = - ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); +Iterator State::parallel(int stage_id, const Iterator& it) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel); CopyOnWrite()->transform_steps.push_back(step); - step->ApplyToState(this); + return step->ApplyToState(this); } -void State::compute_root(int stage_id) { - ComputeRootStep step = ComputeRootStep(stage_id); - CopyOnWrite()->transform_steps.push_back(step); - step->ApplyToState(this); -} +Iterator State::unroll(int stage_id, 
const Iterator& it, int max_unroll) { + const Stage& stage = operator->()->stages[stage_id]; -void State::compute_inline(int stage_id) { - ComputeInlineStep step = ComputeInlineStep(stage_id); + // Don't unroll if the extent is larger than max_unroll + if (max_unroll != -1 && it->range.defined()) { + if (auto imm = it->range->extent.as()) { + if (imm->value > max_unroll) { + return it; + } + } + } + + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); CopyOnWrite()->transform_steps.push_back(step); - step->ApplyToState(this); + return step->ApplyToState(this); } -Array State::split(int stage_id, const Iterator& it, - const Array>& lengths, bool inner_to_outer) { +Iterator State::vectorize(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - SplitStep step = - SplitStep(stage_id, GetIndex(stage->iters, it), - it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } @@ -213,49 +217,45 @@ Iterator State::fuse(int stage_id, const Array& iters) { return step->ApplyToState(this); } -Iterator State::vectorize(int stage_id, const Iterator& it) { +void State::reorder(int stage_id, const Array& order) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = - AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + << "should be specified"; + Array after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); CopyOnWrite()->transform_steps.push_back(step); - return step->ApplyToState(this); + step->ApplyToState(this); } -Iterator State::parallel(int stage_id, const Iterator& it) { +Array 
State::split(int stage_id, const Iterator& it, + const Array>& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = - AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel); + SplitStep step = + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } -Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { - const Stage& stage = operator->()->stages[stage_id]; - - // Don't unroll if the extent is larger than max_unroll - if (max_unroll != -1 && it->range.defined()) { - if (auto imm = it->range->extent.as()) { - if (imm->value > max_unroll) { - return it; - } - } - } +void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { + const Stage& target_stage = operator->()->stages[target_stage_id]; + ComputeAtStep step = + ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); + CopyOnWrite()->transform_steps.push_back(step); + step->ApplyToState(this); +} - AnnotationStep step = - AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); +void State::compute_inline(int stage_id) { + ComputeInlineStep step = ComputeInlineStep(stage_id); CopyOnWrite()->transform_steps.push_back(step); - return step->ApplyToState(this); + step->ApplyToState(this); } -Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { - const Stage& stage = operator->()->stages[stage_id]; - if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) { - LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, " - << "kThreadX, kThreadY, kBlockZ, kThreadZ"; - } - AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); +void State::compute_root(int 
stage_id) { + ComputeRootStep step = ComputeRootStep(stage_id); CopyOnWrite()->transform_steps.push_back(step); - return step->ApplyToState(this); + step->ApplyToState(this); } void State::ApplySteps(const ComputeDAG& dag) { @@ -368,35 +368,27 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); /********** State interface API for ffi **********/ -TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") - .set_body_typed([](State state, int stage_id, const Array& order) { - state.reorder(stage_id, order); - return state; - }); - -TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") - .set_body_typed([](State state, int stage_id, int target_stage_id, - const Iterator& target_iter) { - state.compute_at(stage_id, target_stage_id, target_iter); - return state; +TVM_REGISTER_GLOBAL("auto_scheduler.StateBind") + .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) { + const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type)); + return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot") - .set_body_typed([](State state, int stage_id) { - state.compute_root(stage_id); - return state; +TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel") + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.parallel(stage_id, it); + return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline") - .set_body_typed([](State state, int stage_id) { - state.compute_inline(stage_id); - return state; +TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll") + .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) { + const auto& res = state.unroll(stage_id, it, max_unroll); + return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array>& lengths, bool inner_to_outer) { - const auto& res = state.split(stage_id, it, lengths, inner_to_outer); 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.vectorize(stage_id, it); return Array{state, res}; }); @@ -406,28 +398,36 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") - .set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.vectorize(stage_id, it); - return Array{state, res}; +TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") + .set_body_typed([](State state, int stage_id, const Array& order) { + state.reorder(stage_id, order); + return state; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel") - .set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.parallel(stage_id, it); +TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array>& lengths, bool inner_to_outer) { + const auto& res = state.split(stage_id, it, lengths, inner_to_outer); return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll") - .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) { - const auto& res = state.unroll(stage_id, it, max_unroll); - return Array{state, res}; +TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") + .set_body_typed([](State state, int stage_id, int target_stage_id, + const Iterator& target_iter) { + state.compute_at(stage_id, target_stage_id, target_iter); + return state; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateBind") - .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) { - const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type)); - return Array{state, res}; +TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline") + .set_body_typed([](State state, int stage_id) { + state.compute_inline(stage_id); + return state; + }); + 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot") + .set_body_typed([](State state, int stage_id) { + state.compute_root(stage_id); + return state; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 2666b1082f93..4d6477b92b0f 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -281,39 +281,55 @@ class State : public ObjectRef { */ void ApplySteps(const ComputeDAG& dag); - /* Step APIs for State. */ + /********** Step APIs working on single stage **********/ /*! - * \brief Schedule primitive corresponds to te.reorder. - * \param stage_id The index of the stage to be reordered. - * \param order The expected iterator order. + * \brief Schedule primitive corresponds to te.bind. + * \param stage_id The index of the stage to be binded. + * \param it The iterator to be binded. + * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as + * this input. + * \return The iterator result after binded. */ - void reorder(int stage_id, const Array& order); + Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); /*! - * \brief Schedule primitive corresponds to te.compute_at. - * \param stage_id The index of the stage to be reordered. - * \param target_stage_id The index of stage that this step will compute at to. - * \param target_iter The iterator in target stage that this step will compute at to. - * \note After compute_at, we need careful dependency analysis to compute the accurate bound - * information. However, it is relatively expensive and complicated, so we just fill "None" as - * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + * \brief Schedule primitive corresponds to te.parallel. + * \param stage_id The index of the stage to be paralleled. 
+ * \param it The iterator to be paralleled. + * \return The iterator result after parallel. */ - void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); + Iterator parallel(int stage_id, const Iterator& it); /*! - * \brief Schedule primitive corresponds to te.compute_root. - * \param stage_id The index of the stage to be reordered. - * \note After compute_root, we need careful dependency analysis to compute the accurate bound - * information. However, it is relatively expensive and complicated, so we just fill "None" as - * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + * \brief Schedule primitive corresponds to te.unroll. + * \param stage_id The index of the stage to be unrolled. + * \param it The iterator to be unrolled. + * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be + * skipped. + * \return The iterator result after unrolled. */ - void compute_root(int stage_id); + Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); /*! - * \brief Schedule primitive corresponds to te.compute_inline. + * \brief Schedule primitive corresponds to te.vectorize. + * \param stage_id The index of the stage to be vectorized. + * \param it The iterator to be vectorized. + * \return The iterator result after vectorize. + */ + Iterator vectorize(int stage_id, const Iterator& it); + /*! + * \brief Schedule primitive corresponds to te.fuse. + * \param stage_id The index of the stage to be fused. + * \param iters The iterators to be fused. + * \return The iterator result after fuse. + * \note If the iterators to be fused have stages attached at them(by compute_at), the fused + * result will become the new attach point. + */ + Iterator fuse(int stage_id, const Array& iters); + /*! + * \brief Schedule primitive corresponds to te.reorder. * \param stage_id The index of the stage to be reordered. 
+ * \param order The expected iterator order. */ - void compute_inline(int stage_id); + void reorder(int stage_id, const Array& order); /*! * \brief Schedule primitive corresponds to te.split. * \param stage_id The index of the stage to be split. @@ -326,47 +342,34 @@ class State : public ObjectRef { */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); + + /********** Step APIs working on multiple stages **********/ + /*! - * \brief Schedule primitive corresponds to te.fuse. - * \param stage_id The index of the stage to be fused. - * \param iters The iterators to be fused. - * \return The iterator result after fuse. - * \note If the iterators to be fused have stages attached at them(by compute_at), the fused - * result will become the new attach point. - */ - Iterator fuse(int stage_id, const Array& iters); - /*! - * \brief Schedule primitive corresponds to te.vectorize. - * \param stage_id The index of the stage to be vectorized. - * \param it The iterator to be vectorized. - * \return The iterator result after vectorize. - */ - Iterator vectorize(int stage_id, const Iterator& it); - /*! - * \brief Schedule primitive corresponds to te.parallel. - * \param stage_id The index of the stage to be paralleled. - * \param it The iterator to be paralleled. - * \return The iterator result after parallel. + * \brief Schedule primitive corresponds to te.compute_at. + * \param stage_id The index of the stage to be reordered. + * \param target_stage_id The index of stage that this step will compute at to. + * \param target_iter The iterator in target stage that this step will compute at to. + * \note After compute_at, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. 
*/ - Iterator parallel(int stage_id, const Iterator& it); + void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! - * \brief Schedule primitive corresponds to te.unroll. - * \param stage_id The index of the stage to be unrolled. - * \param it The iterator to be unrolled. - * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be - * skipped. - * \return The iterator result after unrolled. + * \brief Schedule primitive corresponds to te.compute_inline. + * \param stage_id The index of the stage to be reordered. */ - Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); + void compute_inline(int stage_id); /*! - * \brief Schedule primitive corresponds to te.bind. - * \param stage_id The index of the stage to be binded. - * \param it The iterator to be binded. - * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as - * this input. - * \return The iterator result after binded. + * \brief Schedule primitive corresponds to te.compute_root. + * \param stage_id The index of the stage to be reordered. + * \note After compute_root, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. 
*/ - Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); + void compute_root(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); From d217fd4399890a391b76b90cb7be9d3842b0f697 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 21 Jul 2020 09:59:49 +0800 Subject: [PATCH 14/14] Update --- src/auto_scheduler/compute_dag.cc | 4 +- src/auto_scheduler/measure_record.cc | 67 ++++++++++++-------- src/auto_scheduler/transform_step.cc | 92 ++++++++++++++++++---------- src/auto_scheduler/transform_step.h | 20 +++--- 4 files changed, 114 insertions(+), 69 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index fe7fe79d170b..d81dff66d402 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -314,9 +314,9 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; } } - // Call each step's ApplyToPythonAPI method + // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << StepApplyToPythonAPI(step, &stages, &stage_to_axes); + ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes); } return ss.str(); diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 889cd5b9e0cd..39f9ad86c958 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -74,12 +74,14 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::auto_scheduler::Step>* data) { + bool s; reader->BeginArray(); data->clear(); while (reader->NextArrayItem()) { reader->BeginArray(); data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader)); - CHECK(!reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(!s); } } }; @@ -93,12 +95,16 @@ struct 
Handler<::tvm::auto_scheduler::StateNode> { writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::StateNode* data) { + bool s; reader->BeginArray(); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->stages); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->transform_steps); - CHECK(!reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(!s); } }; @@ -112,15 +118,19 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { - std::string target_str; + bool s; + std::string str_value; reader->BeginArray(); - CHECK(reader->NextArrayItem()); - reader->Read(&target_str); - data->workload_key = std::move(target_str); - CHECK(reader->NextArrayItem()); - reader->Read(&target_str); - data->target = ::tvm::Target::Create(target_str); - CHECK(!reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&str_value); + data->workload_key = std::move(str_value); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&str_value); + data->target = ::tvm::Target::Create(str_value); + s = reader->NextArrayItem(); + CHECK(!s); } }; @@ -138,12 +148,16 @@ struct Handler<::tvm::auto_scheduler::MeasureInputNode> { auto state_node = ::tvm::make_object<::tvm::auto_scheduler::StateNode>(); state_node->concrete = true; + bool s; reader->BeginArray(); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(task_node.get()); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(state_node.get()); - CHECK(!reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(!s); data->task = ::tvm::auto_scheduler::SearchTask(task_node); data->state = ::tvm::auto_scheduler::State(state_node); @@ -170,22 +184,27 @@ struct 
Handler<::tvm::auto_scheduler::MeasureResultNode> { } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureResultNode* data) { - std::vector tmp; - + std::vector double_list; + bool s; reader->BeginArray(); - CHECK(reader->NextArrayItem()); - reader->Read(&tmp); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&double_list); data->costs.clear(); - for (const auto& i : tmp) { + for (const auto& i : double_list) { data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); } - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->error_no); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->all_cost); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->timestamp); - CHECK(!reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(!s); } }; diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index c15c377aa39b..6c672a5215f2 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -55,7 +55,9 @@ const char* IteratorAnnotationString[] = { Step StepReadFromRecord(dmlc::JSONReader* reader) { std::string name; - CHECK(reader->NextArrayItem()); + bool s; + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&name); if (name == AnnotationStepNode::record_prefix_str) { return AnnotationStep(reader); @@ -118,22 +120,22 @@ void StepApplyToSchedule(const Step& step, Array* stages, } } -String StepApplyToPythonAPI(const Step& step, Array* stages, +String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes) { if (auto ps = step.as()) { - return ps->ApplyToPythonAPI(stages, stage_to_axes); + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { - return ps->ApplyToPythonAPI(stages, stage_to_axes); + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if 
(auto ps = step.as()) { - return ps->ApplyToPythonAPI(stages, stage_to_axes); + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { - return ps->ApplyToPythonAPI(stages, stage_to_axes); + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { - return ps->ApplyToPythonAPI(stages, stage_to_axes); + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { - return ps->ApplyToPythonAPI(stages, stage_to_axes); + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { - return ps->ApplyToPythonAPI(stages, stage_to_axes); + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -153,11 +155,15 @@ AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) { auto node = make_object(); - CHECK(reader->NextArrayItem()); + bool s; + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->stage_id); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->iter_id); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); int int_val; reader->Read(&int_val); node->annotation = IteratorAnnotation(int_val); @@ -219,7 +225,7 @@ void AnnotationStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String AnnotationStepNode::ApplyToPythonAPI(Array* stages, +String AnnotationStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -285,10 +291,13 @@ FuseStep::FuseStep(int stage_id, const Array& fused_ids) { FuseStep::FuseStep(dmlc::JSONReader* reader) { auto node = make_object(); - CHECK(reader->NextArrayItem()); + bool s; + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->stage_id); std::vector int_list; - CHECK(reader->NextArrayItem()); + 
s = reader->NextArrayItem(); + CHECK(s); reader->Read(&int_list); ::tvm::Array<::tvm::Integer> fused_ids; for (const auto& i : int_list) { @@ -406,7 +415,7 @@ IterVar FuseStepNode::ApplyToSchedule(Array* stages, return fused_axis; } -String FuseStepNode::ApplyToPythonAPI(Array* stages, +String FuseStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; std::stringstream to_fuse; @@ -440,9 +449,12 @@ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { ReorderStep::ReorderStep(dmlc::JSONReader* reader) { auto node = make_object(); - CHECK(reader->NextArrayItem()); + bool s; + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->stage_id); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); std::vector int_list; reader->Read(&int_list); ::tvm::Array<::tvm::Integer> after_ids; @@ -487,7 +499,7 @@ void ReorderStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String ReorderStepNode::ApplyToPythonAPI(Array* stages, +String ReorderStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; std::stringstream ss; @@ -676,17 +688,22 @@ SplitStep::SplitStep(int stage_id, int iter_id, Optional extent, SplitStep::SplitStep(dmlc::JSONReader* reader) { auto node = make_object(); - CHECK(reader->NextArrayItem()); + bool s; + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->stage_id); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->iter_id); int int_val; - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&int_val); if (int_val) { node->extent = Integer(int_val); } - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); std::vector int_list; reader->Read(&int_list); ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths; @@ -694,7 +711,8 @@ 
SplitStep::SplitStep(dmlc::JSONReader* reader) { lengths.push_back(::tvm::Integer(i)); } node->lengths = lengths; - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->inner_to_outer); data_ = std::move(node); } @@ -718,7 +736,7 @@ Array SplitStepNode::ApplyToSchedule(Array* stages, return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -String SplitStepNode::ApplyToPythonAPI(Array* stages, +String SplitStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } @@ -736,11 +754,15 @@ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_ ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) { auto node = make_object(); - CHECK(reader->NextArrayItem()); + bool s; + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->stage_id); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->target_stage_id); - CHECK(reader->NextArrayItem()); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->target_iter_id); data_ = std::move(node); } @@ -779,7 +801,7 @@ void ComputeAtStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String ComputeAtStepNode::ApplyToPythonAPI(Array* stages, +String ComputeAtStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -799,7 +821,9 @@ ComputeInlineStep::ComputeInlineStep(int stage_id) { ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) { auto node = make_object(); - CHECK(reader->NextArrayItem()); + bool s; + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->stage_id); data_ = std::move(node); } @@ -835,7 +859,7 @@ void ComputeInlineStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, 
std::move(stage)); } -String ComputeInlineStepNode::ApplyToPythonAPI(Array* stages, +String ComputeInlineStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -853,7 +877,9 @@ ComputeRootStep::ComputeRootStep(int stage_id) { ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) { auto node = make_object(); - CHECK(reader->NextArrayItem()); + bool s; + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&node->stage_id); data_ = std::move(node); } @@ -888,7 +914,7 @@ void ComputeRootStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String ComputeRootStepNode::ApplyToPythonAPI(Array* stages, +String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 62d64c168a6d..ce3ca50ffae6 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -26,13 +26,13 @@ * Take fuse step for example: * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first * construction function `FuseStep::FuseStep()` in `transform_steps.cc`. - * 2. Implement `FuseStepNode::ApplyToSchedule()` and `FuseStepNode::ApplyToPythonAPI()`. + * 2. Implement `FuseStepNode::ApplyToSchedule()` and `FuseStepNode::PrintAsPythonAPI()`. * - In these two functions you need to lower this step with tvm's te schedule API * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`. * - In these two functions you need to incrementally update all data structures in State with * CopyOnWrite style. * 4. Add your step implementation to `StepApplyToState`, `StepApplyToSchedule` and - * `StepApplyToPythonAPI`, make sure it works. + * `StepPrintAsPythonAPI`, make sure it works. * 5. 
Log record serialization support: * - Add `FuseStepNode::WriteToRecord` which takes a mutable JSONWriter pointer as input and * output the record to it. @@ -206,7 +206,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ -String StepApplyToPythonAPI(const Step& step, Array* stages, +String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes); /********** Primitives working on single stage **********/ @@ -244,7 +244,7 @@ class AnnotationStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* record_prefix_str = "AN"; @@ -307,7 +307,7 @@ class FuseStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* record_prefix_str = "FU"; @@ -368,7 +368,7 @@ class ReorderStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* record_prefix_str = "RE"; @@ -443,7 +443,7 @@ class SplitStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. 
*/ - String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* record_prefix_str = "SP"; @@ -513,7 +513,7 @@ class ComputeAtStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* record_prefix_str = "CA"; @@ -570,7 +570,7 @@ class ComputeInlineStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* record_prefix_str = "CI"; @@ -629,7 +629,7 @@ class ComputeRootStepNode : public StepNode { * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ - String ApplyToPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* record_prefix_str = "CR";