diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index 6e67fef0f283..da0d196f4912 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -194,6 +194,24 @@ class ComputeDAGNode : public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
 };
 
+/*!
+ * \brief Options for applying layout rewrite.
+ * This is an optimization to rewrite the layout of input tensors according to the schedule we get.
+ */
+enum class LayoutRewriteOption : int {
+  /*! \brief Do not process layout rewrite. */
+  NoRewrite = 0,
+  /*! \brief Insert layout transformation stages for input placeholders in the compute DAG. */
+  InsertTransformStage = 1,
+  /*!
+   * \brief Do not insert layout transformation stages and assume the input placeholders
+   * are pre-transformed.
+   * \note The lowered function with this option does not accept the original input shapes,
+   * so this option must be used along with a layout conversion pass in Relay.
+   */
+  RewriteForPreTransformed = 2,
+};
+
 /*!
  * \brief Managed reference to ComputeDAGNode.
  * \sa ComputeDAGNode
@@ -214,8 +232,10 @@ class ComputeDAG : public ObjectRef {
    * \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders`
    * according to the loop nest derived with `transform_steps`.
    * \param transform_steps Transform steps of a state.
+   * \param layout_rewrite The layout rewrite option to apply.
+   * \return The updated ComputeDAG after layout rewrite.
    */
-  void RewriteLayout(const Array<Step>& transform_steps);
+  ComputeDAG RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite) const;
 
   /*!
    * \brief Apply the history transform steps to get a TVM schedule.
@@ -225,14 +245,14 @@ class ComputeDAG : public ObjectRef {
    * \param stage_to_axes The map that stores all axes for one stage.
    *        Pass a valid pointer if this information needs to be used outside this function.
    * \param layout_rewrite Rewrite the layout of placeholders specified by
-   *        attr `layout_free_placeholders`
+   *        attr `layout_free_placeholders`.
    * \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower`
    *         or `tvm.build`.
    */
-  std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(const Array<Step>& transform_steps,
-                                                         Array<te::Stage>* stages = nullptr,
-                                                         StageToAxesMap* stage_to_axes = nullptr,
-                                                         bool layout_rewrite = false) const;
+  std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
+      const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
+      StageToAxesMap* stage_to_axes = nullptr,
+      LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const;
 
   /*!
    * \brief Print transform steps as equivalent python schedule API.
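The three LayoutRewriteOption values above are mirrored on the Python side by the class attributes added to ComputeDAG in compute_dag.py below, and are selected through apply_steps_from_state. A minimal usage sketch follows, assuming the get_tiled_matmul helper from the unit tests further down; any ComputeDAG/State pair works the same way:

# Hedged usage sketch for the new layout_rewrite options; get_tiled_matmul is the
# helper already used by test_auto_scheduler_layout_rewrite.py below.
from tvm.auto_scheduler.compute_dag import ComputeDAG
from test_auto_scheduler_common import get_tiled_matmul

dag, state = get_tiled_matmul()

# Default: no layout rewrite, buffers keep their original shapes.
sch, bufs = dag.apply_steps_from_state(state)

# Insert an explicit layout transformation stage; the lowered function still takes
# the original input shapes and repacks the weight internally.
sch, bufs = dag.apply_steps_from_state(state, layout_rewrite=ComputeDAG.InsertTransformStage)

# Assume the inputs are already pre-transformed; the lowered function no longer
# accepts the original shapes, so this must be paired with a layout conversion
# pass in Relay.
sch, bufs = dag.apply_steps_from_state(state, layout_rewrite=ComputeDAG.RewriteForPreTransformed)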
diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h
index 7be3554c7c5d..4cc1551e76fc 100755
--- a/include/tvm/auto_scheduler/transform_step.h
+++ b/include/tvm/auto_scheduler/transform_step.h
@@ -182,7 +182,23 @@ class StepNode : public Object {
  */
 class Step : public ObjectRef {
  public:
-  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
+  /*!
+   * \brief CopyOnWrite function for Step.
+   * This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different
+   * steps.
+   * \return A base StepNode pointer, which needs to be cast to the concrete StepNode type before
+   * making any modifications.
+   * \code
+   *
+   * SplitStep ref;
+   * StepNode* mutable_ref = ref.CopyOnWrite();
+   * dynamic_cast<SplitStepNode*>(mutable_ref)->... = ...;
+   *
+   * \endcode
+   */
+  StepNode* CopyOnWrite();
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
 };
 
 // Forward declaration
@@ -267,7 +283,7 @@ class AnnotationStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "AN";
 
   static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode);
 };
 
 /*!
@@ -330,7 +346,7 @@ class FuseStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "FU";
 
   static constexpr const char* _type_key = "auto_scheduler.FuseStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode);
 };
 
 /*!
@@ -390,7 +406,7 @@ class PragmaStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "PR";
 
   static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode);
 };
 
 /*!
@@ -452,7 +468,7 @@ class ReorderStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "RE";
 
   static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode);
 };
 
 /*!
@@ -527,7 +543,7 @@ class SplitStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "SP";
 
   static constexpr const char* _type_key = "auto_scheduler.SplitStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode);
 };
 
 /*!
@@ -607,7 +623,7 @@ class FollowSplitStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "FSP";
 
   static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode);
 };
 
 /*!
@@ -688,7 +704,7 @@ class FollowFusedSplitStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "FFSP";
 
   static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode);
 };
 
 /*!
@@ -754,7 +770,7 @@ class StorageAlignStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "SA";
 
   static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode);
 };
 
 /*!
@@ -822,7 +838,7 @@ class ComputeAtStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "CA";
 
   static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode);
 };
 
 /*!
@@ -879,7 +895,7 @@ class ComputeInlineStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "CI";
 
   static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode);
 };
 
 /*!
@@ -938,7 +954,7 @@ class ComputeRootStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "CR";
 
   static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode);
 };
 
 /*!
@@ -1010,7 +1026,7 @@ class CacheReadStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "CHR";
 
   static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode);
 };
 
 /*!
@@ -1081,7 +1097,7 @@ class CacheWriteStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "CHW";
 
   static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode);
 };
 
 /*!
@@ -1148,7 +1164,7 @@ class RfactorStepNode : public StepNode {
   static constexpr const char* record_prefix_str = "RF";
 
   static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode);
 };
 
 /*!
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index 2fc0d7d0bf8c..d50ff395b679 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -51,6 +51,11 @@ class ComputeDAG(Object):
         Input/output tensors or workload key for a compute declaration.
     """
 
+    # Layout Rewrite Options
+    NoRewrite = 0
+    InsertTransformStage = 1
+    RewriteForPreTransformed = 2
+
     def __init__(self, compute_or_sche):
         if isinstance(compute_or_sche, str):
            compute = workload_key_to_tensors(compute_or_sche)
@@ -81,7 +86,7 @@ def get_init_state(self):
         """
         return State(self.init_state, self)
 
-    def apply_steps_from_state(self, state, layout_rewrite=False):
+    def apply_steps_from_state(self, state, layout_rewrite=NoRewrite):
         """
         Apply the history transform steps from a State to get a TVM schedule.
 
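For InsertTransformStage, the RewriteLayout implementation in compute_dag.cc below keeps the original placeholder and adds an extra te::compute stage named auto_schedule_layout_transpose that gathers it into the new layout, then fuses all but the innermost axis of that stage and parallelizes the fused axis. A rough TE-level sketch of the stage it builds; the 4D packed layout used here is illustrative only, the real layout comes from the tuned schedule:

# Illustrative sketch only; RewriteLayout builds the equivalent stage directly in
# C++ from the layout string it derives, so the shapes and factors here are made up.
import tvm
from tvm import te

N = K = 512
weight = te.placeholder((N, K), name="weight")  # original input placeholder

# Suppose the tuned schedule wants the packed layout [N//16, K//64, 16, 64].
packed = te.compute(
    (N // 16, K // 64, 16, 64),
    lambda no, ko, ni, ki: weight[no * 16 + ni, ko * 64 + ki],
    name="auto_schedule_layout_transpose",
)
# RewriteLayout then appends a FuseStep over all but the innermost axis of this new
# stage plus an AnnotationStep that parallelizes the fused axis, and bumps the
# stage_id of every existing step (including ComputeAt targets) that refers to a
# stage at or after the insertion point.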
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index c6cf094ee202..090e6daf9859 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -796,8 +796,8 @@ std::string GetOrigLayout(std::set* placeholder_axis_names, const t return orig_layout; } -std::string GetNewLayout(Array* new_shape, const State& state, const int stage_id, - const Stage& stage, const te::Operation& op, const te::Tensor& placeholder, +std::string GetNewLayout(const State& state, const int stage_id, const Stage& stage, + const te::Operation& op, const te::Tensor& placeholder, const std::set& placeholder_axis_names) { std::ostringstream os; Array stage_iters; @@ -852,7 +852,6 @@ std::string GetNewLayout(Array* new_shape, const State& state, const i if (placeholder_axis_names.count(ori_iter_name)) { os << iter->range->extent << ori_iter_name; new_names.push_back(ori_iter_name); - new_shape->push_back(iter->range->extent); } } std::string new_layout = os.str(); @@ -862,16 +861,22 @@ std::string GetNewLayout(Array* new_shape, const State& state, const i return new_layout; } -void ComputeDAG::RewriteLayout(const Array& transform_steps) { - ComputeDAGNode* p_dag = this->CopyOnWrite(); +ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, + LayoutRewriteOption layout_rewrite) const { + CHECK(layout_rewrite != LayoutRewriteOption::NoRewrite) + << "Call ComputeDAG::RewriteLayout with NoRewrite."; + ComputeDAG new_dag = *this; + ComputeDAGNode* p_dag = new_dag.CopyOnWrite(); + auto node = make_object(); - node->transform_steps = transform_steps; + node->transform_steps = *transform_steps; node->concrete = true; const State& state = InferBound(State(node)); + OperationSet handled_ops; - int stage_id = -1; - for (const auto& stage : state->stages) { - stage_id += 1; + for (size_t stage_id = 0; stage_id < state->stages.size(); stage_id++) { + const auto& stage = state->stages[stage_id]; + const te::Operation& op = stage->op; if (!op->IsInstance()) { continue; @@ -881,15 +886,13 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { continue; } const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; - Array placeholders = Downcast>(attr_value); - for (const auto& placeholder : placeholders) { + for (const auto& placeholder : Downcast>(attr_value)) { const auto& placeholder_op = placeholder->op; // Check whether this placeholder has already been handled if (handled_ops.count(placeholder_op)) { continue; } - // Skip the op that is not direct consumer of this placeholder. // This is usually caused by cache read/write. 
bool direct_consumer = false; @@ -902,28 +905,89 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { if (!direct_consumer) { continue; } + handled_ops.insert(placeholder_op); + // Process original layout std::set placeholder_axis_names; - GetOrigLayout(&placeholder_axis_names, op, placeholder); + std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder); + Array origin_shape; + std::vector origin_axes; + ParseKernelLayout(origin_layout, &origin_shape, &origin_axes); - Array new_shape; + // Process new layout std::string new_layout = - GetNewLayout(&new_shape, state, stage_id, stage, op, placeholder, placeholder_axis_names); - - handled_ops.insert(placeholder_op); - - Array old_ops = p_dag->ops; - ArrayNode* pops = p_dag->ops.CopyOnWrite(); - - // Create new placeholder - te::Operation new_placeholder_op; - new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape, + GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names); + Array new_shape; + std::vector new_axes; + ParseKernelLayout(new_layout, &new_shape, &new_axes); + + // Process op updates + te::Operation new_op_to_update; + if (layout_rewrite == LayoutRewriteOption::RewriteForPreTransformed) { + // Create new placeholder + new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape, placeholder_op.as()->dtype); + } else if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) { + // Process index strides + std::unordered_map axes_stride; + for (const auto& i : origin_axes) { + axes_stride[i] = Integer(1); + } + Array new_stride(new_shape.size(), PrimExpr()); + PrimExpr temp = Integer(1); + for (int i = new_shape.size() - 1; i >= 0; i--) { + new_stride.Set(i, axes_stride[new_axes[i]]); + axes_stride[new_axes[i]] *= new_shape[i]; + } - te::Operation new_compute_op, old_compute_op; + // Add extra layout transpose stage + const auto& layout_transform_tensor = te::compute( + new_shape, + [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, + &new_axes](const tvm::runtime::Array& indices) -> tvm::PrimExpr { + Array access_indices; + for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) { + PrimExpr temp = Integer(0); + for (size_t i = 0; i < new_shape.size(); i++) { + if (origin_axes[indice_index].compare(new_axes[i]) == 0) { + temp += indices[i] * new_stride[i]; + } + } + access_indices.push_back(temp); + } + return placeholder_op.output(0)(access_indices); + }, + "auto_schedule_layout_transpose"); + new_op_to_update = layout_transform_tensor->op; + + // Update the transform steps + for (size_t i = 0; i < transform_steps->size(); i++) { + Step step = (*transform_steps)[i]; + if (step->stage_id >= static_cast(stage_id)) { + step.CopyOnWrite()->stage_id++; + } + if (step->IsInstance()) { + auto compute_at_step = tvm::Downcast(step); + if (compute_at_step->target_stage_id >= static_cast(stage_id)) { + dynamic_cast(compute_at_step.CopyOnWrite())->target_stage_id++; + } + transform_steps->Set(i, std::move(compute_at_step)); + } else { + transform_steps->Set(i, std::move(step)); + } + } + Array to_fuse; + for (size_t i = 0; i < new_shape.size() - 1; i++) { + to_fuse.push_back(i); + } + transform_steps->push_back(FuseStep(stage_id, to_fuse)); + transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel)); + } + + te::Operation new_compute_op, original_compute_op; Array new_body; IndexRewriter index_rewriter(placeholder_op, new_layout); - for (auto& op : old_ops) { + for (const auto& 
op : p_dag->ops) { if (auto* pop = op.as()) { bool need_update = false; for (auto& t : op->InputTensors()) { @@ -933,35 +997,45 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { } } if (need_update) { - for (auto& body : pop->body) { + for (const auto& body : pop->body) { new_body.push_back(index_rewriter.Rewrite(body)); } - old_compute_op = op; - ICHECK(!new_compute_op.defined()); + original_compute_op = op; + CHECK(!new_compute_op.defined()); new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body); } } } - // construct the map from old_op to new_op + // construct the map from original_op to new_op std::unordered_map updated_ops; - for (size_t i = 0; i < old_ops.size(); ++i) { - auto old_op = old_ops[i]; - if (old_op == placeholder_op) { - pops->SetItem(i, new_placeholder_op); - updated_ops[placeholder_op] = new_placeholder_op; - } else if (old_op == old_compute_op) { - pops->SetItem(i, new_compute_op); - updated_ops[old_compute_op] = new_compute_op; + + Array original_ops = p_dag->ops; + p_dag->ops.clear(); + for (size_t i = 0; i < original_ops.size(); ++i) { + const auto& original_op = original_ops[i]; + if (original_op == placeholder_op) { + if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) { + p_dag->ops.push_back(placeholder_op); + } + p_dag->ops.push_back(new_op_to_update); + updated_ops[placeholder_op] = new_op_to_update; + } else if (original_op == original_compute_op) { + p_dag->ops.push_back(new_compute_op); + updated_ops[original_compute_op] = new_compute_op; } else { - pops->SetItem(i, old_op); + p_dag->ops.push_back(original_op); } } + ArrayNode* pops = p_dag->ops.CopyOnWrite(); // Because ops is sorted in topo-order, only do one pass linear scan here. for (size_t i = 0; i < pops->size(); ++i) { - auto old_op = Downcast(pops->at(i)); - if (auto* pop = old_op.as()) { + const auto& original_op = Downcast(pops->at(i)); + if (auto* pop = original_op.as()) { + if (original_op == new_op_to_update) { + continue; + } auto inputs = pop->InputTensors(); std::unordered_map rmap; for (auto input : inputs) { @@ -977,8 +1051,8 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { } } if (!rmap.empty()) { - te::Operation new_op = pop->ReplaceInputs(old_op, rmap); - updated_ops[old_op] = new_op; + te::Operation new_op = pop->ReplaceInputs(original_op, rmap); + updated_ops[original_op] = new_op; pops->SetItem(i, new_op); } } @@ -986,9 +1060,12 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { Array old_tensors = p_dag->tensors; ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite(); - for (size_t i = 0; i < old_tensors.size(); ++i) { const auto& old_tensor = old_tensors[i]; + if (layout_rewrite != LayoutRewriteOption::RewriteForPreTransformed && + old_tensor->op->IsInstance()) { + continue; + } auto it = updated_ops.find(old_tensor->op); te::Operation new_op; while (it != updated_ops.end()) { @@ -1018,15 +1095,17 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { } p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops); p_dag->init_state = State(p_dag->ops); + + return new_dag; } std::pair> ComputeDAG::ApplySteps( const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes, - bool layout_rewrite) const { - if (layout_rewrite && !transform_steps.empty()) { - ComputeDAG new_dag = *this; - new_dag.RewriteLayout(transform_steps); - return new_dag.ApplySteps(transform_steps, stages, stage_to_axes, false); + LayoutRewriteOption layout_rewrite) const { + if (layout_rewrite != 
LayoutRewriteOption::NoRewrite && !transform_steps.empty()) { + Array steps = transform_steps; + const auto& dag = RewriteLayout(&steps, layout_rewrite); + return dag.ApplySteps(steps); } // Temporal object to be used if the input pointer is nullptr @@ -1305,11 +1384,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG") }); TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState") - .set_body_typed([](const ComputeDAG& dag, const State& state, const bool layout_rewrite) { + .set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) { te::Schedule sch; Array return_tensors; std::tie(sch, return_tensors) = - dag.ApplySteps(state->transform_steps, nullptr, nullptr, layout_rewrite); + dag.ApplySteps(state->transform_steps, nullptr, nullptr, + static_cast(layout_rewrite)); return Array{sch, return_tensors}; }); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 23d6eb64da6c..517f7ff91f55 100755 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -445,6 +445,12 @@ String State::ToStr(bool delete_trivial_loop) const { return os.str(); } +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto& stage = tvm::Downcast(ref); + p->stream << stage->GetTypeKey() << "(" << stage.get() << ": " << stage->op->name << ")"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { PrintState(&p->stream, tvm::Downcast(ref), true); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 852f1e1f17d8..5560907dcffa 100755 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -122,6 +122,58 @@ const char* IteratorAnnotationString[] = { "tensorize" // kTensorized = 11 }; +StepNode* Step::CopyOnWrite() { + CHECK(data_ != nullptr); + if (!data_.unique()) { + if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else if (const auto& ps = as()) { + auto n = make_object(*ps); + ObjectPtr(std::move(n)).swap(data_); + } else { + LOG(FATAL) << "Invalid step: " << (*this); + } + } + return static_cast(data_.get()); +} + Step 
StepReadFromRecord(dmlc::JSONReader* reader) { std::string name; bool s; diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index 3ce7a438eef4..4a11d0fb0ca0 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -28,18 +28,26 @@ def test_apply_steps_with_layout_rewrite(): dag, s = get_tiled_matmul() - _, bufs = dag.apply_steps_from_state(s, layout_rewrite=False) + _, bufs = dag.apply_steps_from_state(s) assert bufs[1].shape[0] == 512 assert bufs[1].shape[1] == 512 - _, bufs = dag.apply_steps_from_state(s, layout_rewrite=True) + _, bufs = dag.apply_steps_from_state( + s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed + ) assert bufs[1].shape[0] == 4 assert bufs[1].shape[1] == 8 assert bufs[1].shape[2] == 4 assert bufs[1].shape[3] == 4 assert bufs[1].shape[4] == 512 + _, bufs = dag.apply_steps_from_state( + s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage + ) + assert bufs[1].shape[0] == 512 + assert bufs[1].shape[1] == 512 -def test_layout_rewrite_correctness(): +@tvm.testing.requires_llvm +def test_correctness_layout_rewrite_rewrite_for_preTransformed(): N = 128 target = tvm.target.Target("llvm") task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target) @@ -50,16 +58,19 @@ def test_layout_rewrite_correctness(): search_policy = auto_scheduler.SketchPolicy(task) + measure_ctx = auto_scheduler.LocalRPCMeasureContext() tuning_options = auto_scheduler.TuningOptions( num_measure_trials=2, - runner="local", + runner=measure_ctx.runner, verbose=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) auto_scheduler.auto_schedule(task, search_policy, tuning_options) inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) - s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True) - s_ref, bufs_ref = dag.apply_steps_from_state(inp.state, layout_rewrite=False) + s, bufs = dag.apply_steps_from_state( + inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed + ) + s_ref, bufs_ref = dag.apply_steps_from_state(inp.state) np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] np_args_ref = [np.array(x) for x in np_args] @@ -100,10 +111,60 @@ def test_layout_rewrite_correctness(): func_ref(*args_ref) ctx.sync() - np.testing.assert_allclose(np_args[0], np_args_ref[0]) - np.testing.assert_allclose(np_args[2], np_args_ref[2]) + tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), rtol=1e-4) + tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), rtol=1e-4) + del measure_ctx + + +@tvm.testing.requires_llvm +def test_correctness_layout_rewrite_insert_transform_stage(): + N = 128 + target = tvm.target.Target("llvm") + task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target) + dag = task.compute_dag + + with tempfile.NamedTemporaryFile() as fp: + log_file = fp.name + + search_policy = auto_scheduler.SketchPolicy(task) + + measure_ctx = auto_scheduler.LocalRPCMeasureContext() + tuning_options = auto_scheduler.TuningOptions( + num_measure_trials=2, + runner=measure_ctx.runner, + verbose=1, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + ) + auto_scheduler.auto_schedule(task, search_policy, tuning_options) + inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) + s, bufs = 
dag.apply_steps_from_state( + inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage + ) + + s_ref, bufs_ref = dag.apply_steps_from_state(inp.state) + np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] + + func = tvm.build(s, bufs, target=target) + func_ref = tvm.build(s_ref, bufs_ref, target=target) + + ctx = tvm.context(str(target)) + ctx_ref = tvm.cpu() + + args = [tvm.nd.array(x, ctx=ctx) for x in np_args] + args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args] + ctx.sync() + + func(*args) + func_ref(*args_ref) + ctx.sync() + + tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), rtol=1e-4) + tvm.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy(), rtol=1e-4) + tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), rtol=1e-4) + del measure_ctx if __name__ == "__main__": test_apply_steps_with_layout_rewrite() - test_layout_rewrite_correctness() + test_correctness_layout_rewrite_rewrite_for_preTransformed() + test_correctness_layout_rewrite_insert_transform_stage() diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py index 7851d922013d..2debc14fc356 100644 --- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py +++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py @@ -21,11 +21,14 @@ import multiprocessing import numpy as np +import tvm +import tvm.testing from tvm import auto_scheduler from test_auto_scheduler_common import matmul_auto_scheduler_test +@tvm.testing.requires_llvm def test_task_scheduler_round_robin(): tasks = [] for n in [2, 4, 8]: @@ -39,8 +42,10 @@ def objective_func(costs): num_trials_per_task = 2 # Tune all tasks + measure_ctx = auto_scheduler.LocalRPCMeasureContext() tune_option = auto_scheduler.TuningOptions( num_measure_trials=num_trials_per_task * len(tasks), + runner=measure_ctx.runner, num_measures_per_round=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) @@ -67,13 +72,16 @@ def objective_func(costs): num_measures_per_round=1, ) task_scheduler.tune(tune_option, search_policy="sketch.random") + del measure_ctx +@tvm.testing.requires_llvm def task_scheduler_round_robin_spawn(): assert multiprocessing.get_start_method(False) == "spawn" test_task_scheduler_round_robin() +@tvm.testing.requires_llvm def test_task_scheduler_round_robin_spawn(): ctx = multiprocessing.get_context("spawn") p = ctx.Process(target=task_scheduler_round_robin_spawn) @@ -81,6 +89,7 @@ def test_task_scheduler_round_robin_spawn(): p.join() +@tvm.testing.requires_llvm def test_task_scheduler_gradient(): tasks = [] for n in [2, 4]: @@ -95,8 +104,10 @@ def objective_func(costs): n_trials = 5 # Tune all tasks + measure_ctx = auto_scheduler.LocalRPCMeasureContext() tune_option = auto_scheduler.TuningOptions( num_measure_trials=n_trials, + runner=measure_ctx.runner, num_measures_per_round=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) @@ -118,6 +129,7 @@ def objective_func(costs): assert counters[tasks[0].workload_key] == n_trials - 1 assert counters[tasks[1].workload_key] == 1 + del measure_ctx if __name__ == "__main__":
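The two new correctness tests follow the same pattern; the condensed sketch below covers the InsertTransformStage case, where the rewritten and reference schedules accept identical buffer shapes. The RewriteForPreTransformed case additionally needs the weight input repacked into the new layout before the comparison, as the full test above does. Names follow the test code; this is a sketch, not a replacement for the tests:

# Condensed sketch of the correctness check used by
# test_correctness_layout_rewrite_insert_transform_stage above.
import numpy as np
import tvm
import tvm.testing
from tvm import topi
from tvm.auto_scheduler.compute_dag import ComputeDAG

def check_insert_transform_stage(dag, state, target):
    s, bufs = dag.apply_steps_from_state(state, layout_rewrite=ComputeDAG.InsertTransformStage)
    s_ref, bufs_ref = dag.apply_steps_from_state(state)  # NoRewrite reference

    func = tvm.build(s, bufs, target=target)
    func_ref = tvm.build(s_ref, bufs_ref, target=target)

    # Same random inputs for both versions; with InsertTransformStage the buffer
    # shapes are unchanged, so the arguments can be shared directly.
    np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
    args = [tvm.nd.array(x) for x in np_args]
    args_ref = [tvm.nd.array(x) for x in np_args]

    func(*args)
    func_ref(*args_ref)
    tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), rtol=1e-4)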