diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 11fec642c718..da399ab976d6 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -432,9 +432,13 @@ class ScheduleNode : public runtime::Object { * \param block_rv The block to be moved * \param loop_rv The loop where the block to be moved under * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + * \param index The block index of the loop body subtree blocks: + * - `index = -1` means inserted into the last possible insertion point; + * - `index = -2` means inserted into the first possible insertion point; + * - Otherwise, `index` is a nonnegative number that indicates the insertion point */ - virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, - bool preserve_unit_loops) = 0; + virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + int index = -1) = 0; /*! * \brief Move a consumer block under the specific loop, and regenerate the * loops induced by the block so that the buffer region consumed by the consumer block could @@ -449,9 +453,13 @@ class ScheduleNode : public runtime::Object { * \param block_rv The block to be moved * \param loop_rv The loop where the block to be moved under * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + * \param index The block index of the loop body subtree blocks: + * - `index = -1` means inserted into the last possible insertion point; + * - `index = -2` means inserted into the first possible insertion point; + * - Otherwise, `index` is a nonnegative number that indicates the insertion point */ virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, - bool preserve_unit_loops) = 0; + bool preserve_unit_loops, int index = -1) = 0; /*! * \brief Inline a block into its consumer(s). 
It requires: * 1) The block is a complete non-root block, which only produces one buffer diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index e18bee35a5e1..04cc1bc26ad1 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1274,6 +1274,7 @@ def compute_at( block: Union[BlockRV, str], loop: LoopRV, preserve_unit_loops: bool = False, + index: int = -1, ) -> None: """Compute-At. Move a producer block under the specific loop, and regenerate the loops induced by the block so that the buffer region produced by the producer block could @@ -1303,6 +1304,12 @@ def compute_at( preserve_unit_loops: bool Whether to keep the trivial loops whose extents are 1 + index: int + The block index of the loop body subtree blocks: + - `index = -1` means inserted into the last possible insertion point; + - `index = -2` means inserted into the first possible insertion point; + - Otherwise, `index` is a nonnegative number that indicates the insertion point + Examples -------- @@ -1360,6 +1367,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: block, loop, preserve_unit_loops, + index, ) @type_checked @@ -1368,6 +1376,7 @@ def reverse_compute_at( block: Union[BlockRV, str], loop: LoopRV, preserve_unit_loops: bool = False, + index: int = -1, ) -> None: """Reverse-Compute-At. 
Move a consumer block under the specific loop, and regenerate the loops induced by the block so that the buffer region consumed by the consumer block could @@ -1394,6 +1403,12 @@ def reverse_compute_at( preserve_unit_loops: bool Whether to keep the trivial loops whose extents are 1 + index: int + The block index of the loop body subtree blocks: + - `index = -1` means inserted into the last possible insertion point; + - `index = -2` means inserted into the first possible insertion point; + - Otherwise, `index` is a nonnegative number that indicates the insertion point + Examples -------- @@ -1451,6 +1466,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: block, loop, preserve_unit_loops, + index, ) @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index c16638f748b4..5f773a02d6ff 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -574,7 +574,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, - bool preserve_unit_loops) { + bool preserve_unit_loops, int index) { static StmtSRef inline_mark = StmtSRef::InlineMark(); static StmtSRef root_mark = StmtSRef::RootMark(); StmtSRef loop_sref = this->GetSRef(loop_rv); @@ -586,14 +586,14 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_); } else { TVM_TIR_SCHEDULE_BEGIN(); - tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops); + tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index); TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_); } this->state_->DebugVerify(); } void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, - bool 
preserve_unit_loops) { + bool preserve_unit_loops, int index) { static StmtSRef inline_mark = StmtSRef::InlineMark(); static StmtSRef root_mark = StmtSRef::RootMark(); StmtSRef loop_sref = this->GetSRef(loop_rv); @@ -605,7 +605,7 @@ void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopR TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_); } else { TVM_TIR_SCHEDULE_BEGIN(); - tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops); + tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index); TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_); } this->state_->DebugVerify(); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index cdd0a5b7b0a2..92b9de408873 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -119,9 +119,10 @@ class ConcreteScheduleNode : public ScheduleNode { BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) override; /******** Schedule: Compute location ********/ - void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; - void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, - bool preserve_unit_loops) override; + void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + int index = -1) override; + void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + int index = -1) override; void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; /******** Schedule: Reduction ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 14203a0d167e..05d9e4cf944a 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -299,10 +299,14 @@ TVM_DLL StmtSRef ReIndex(ScheduleState
self, const StmtSRef& block_sref, int buf * \param self The schedule state * \param block_sref The block to be moved * \param loop_sref The loop where the block to be moved to * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + * \param index The block index of the loop body subtree blocks: + * - `index = -1` means inserted into the last possible insertion point; + * - `index = -2` means inserted into the first possible insertion point; + * - Otherwise, `index` is a nonnegative number that indicates the insertion point */ TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, - bool preserve_unit_loops); + bool preserve_unit_loops, int index = -1); /*! * \brief Move a consumer block under the specific loop, and regenerate the * loops induced by the block so that the buffer region consumed by the consumer block could @@ -318,9 +321,13 @@ TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stm * \param block_sref The block to be moved * \param loop_sref The loop where the block to be moved to * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + * \param index The block index of the loop body subtree blocks: + * - `index = -1` means inserted into the last possible insertion point; + * - `index = -2` means inserted into the first possible insertion point; + * - Otherwise, `index` is a nonnegative number that indicates the insertion point */ TVM_DLL void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, - const StmtSRef& loop_sref, bool preserve_unit_loops); + const StmtSRef& loop_sref, bool preserve_unit_loops, int index = -1); /*! * \brief Inline a block into its consumer(s). 
It requires: * 1) The block is a complete non-root block, which only produces one buffer diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 98a6b2400ee3..8baedfd70dd0 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -129,15 +129,19 @@ class NotInSameScopeError : public ScheduleError { * \param producer_srefs The producer blocks * \param consumer_srefs The consumer blocks * \param block2realize A cache that maps a block to its realize - * \return The last position the new block can be inserted onto, and the + * \param index The block index of the loop body subtree blocks: + * - `index = -1` means inserted into the last possible insertion point; + * - `index = -2` means inserted into the first possible insertion point; + * - Otherwise, `index` is a nonnegative number that indicates the insertion point + * \return The possible position the new block can be inserted into, and the * producer-consumer-relationship is still satisfied. * \throws ScheduleError if there is no such insertion point found */ template -int FindInsertionPoint( - const ScheduleState& self, const Array& subtrees, const Array& producer_srefs, - const Array& consumer_srefs, - std::unordered_map* block2realize) { +int FindInsertionPoint(const ScheduleState& self, const Array& subtrees, + const Array& producer_srefs, const Array& consumer_srefs, + std::unordered_map* block2realize, + int index) { ProducerConsumerSplit split = ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize); // Step 1. Check if all the producers are visited in the subtrees, if required to @@ -159,8 +163,22 @@ int FindInsertionPoint( // Step 3. Check if there is at least one index of the position can be inserted into // The valid indices are: (last_producer_position, first_consumer_position] ICHECK(split.last_producer_position < split.first_consumer_position); - // Step 4. 
Return the last valid insertion point - return split.first_consumer_position; + // Step 4. Return the possible insertion point according to index + int insert_position; + if (index == -1) { + insert_position = split.first_consumer_position; + } else if (index == -2) { + insert_position = split.last_producer_position + 1; + } else if (index >= 0 && index >= split.last_producer_position + 1 && + index <= split.first_consumer_position) { + insert_position = index; + } else { + LOG(FATAL) << "Valid index:(-1, -2, [" << split.last_producer_position + 1 << ", " + << split.first_consumer_position << "]), " + << "current index=" << index; + throw; + } + return insert_position; } /*! @@ -556,7 +574,8 @@ void CalculateProvidedRequiredRegions( template void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops, - arith::Analyzer* analyzer, bool check_only = false) { + arith::Analyzer* analyzer, bool check_only = false, + int index = -1) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); // Step 1. Bunch of checks @@ -588,7 +607,8 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s /*self=*/self, /*subtrees=*/AsArray(loop->body), /*producer_srefs=*/producer_srefs, - /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize); + /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize, + /*index=*/index); // Step 4. Calculate the region provided by a single execution instance of `block`, // as well as the region required by dependent blocks under `loop`. 
// Here is the definition of `provide` and `require`: @@ -626,17 +646,17 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s } void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, - bool preserve_unit_loops) { + bool preserve_unit_loops, int index) { arith::Analyzer analyzer; - ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer); + ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, &analyzer, + false, index); } void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, - bool preserve_unit_loops) { + bool preserve_unit_loops, int index) { arith::Analyzer analyzer; ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer); + &analyzer, false, index); } bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, @@ -671,20 +691,21 @@ struct ComputeAtTraits : public UnpackedInstTraits { private: static constexpr size_t kNumInputs = 2; - static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv, - Bool preserve_unit_loops) { - return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool()); + Bool preserve_unit_loops, IntImm index) { + return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value); } static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv, - Bool preserve_unit_loops) { + Bool preserve_unit_loops, IntImm index) { PythonAPICall py("compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); py.Input("preserve_unit_loops", preserve_unit_loops.operator bool()); + py.Input("index", index); return py.Str(); } @@ -698,20 +719,22 @@ struct ReverseComputeAtTraits : public 
UnpackedInstTraitsReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool()); + Bool preserve_unit_loops, IntImm index) { + return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), + index->value); } static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv, - Bool preserve_unit_loops) { + Bool preserve_unit_loops, IntImm index) { PythonAPICall py("reverse_compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); py.Input("preserve_unit_loops", preserve_unit_loops.operator bool()); + py.Input("index", index); return py.Str(); } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 07d4da54d7fb..04ddc0507dc4 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -322,24 +322,25 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, - bool preserve_unit_loops) { - ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops); + bool preserve_unit_loops, int index) { + ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, index); static const InstructionKind& kind = InstructionKind::Get("ComputeAt"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{Integer(preserve_unit_loops)}, - /*outputs=*/{})); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv, loop_rv}, + /*attrs=*/{Integer(preserve_unit_loops), Integer(index)}, + /*outputs=*/{})); } void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, - bool preserve_unit_loops) { - ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops); + bool preserve_unit_loops, int index) { + ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops, 
index); static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{Integer(preserve_unit_loops)}, + /*attrs=*/{Integer(preserve_unit_loops), Integer(index)}, /*outputs=*/{})); } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 865a21687950..d98e4ba4bb95 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -79,9 +79,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) final; /******** Schedule: Compute location ********/ - void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; - void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, - bool preserve_unit_loops) final; + void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + int index = -1) final; + void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + int index = -1) final; void ComputeInline(const BlockRV& block_rv) final; void ReverseComputeInline(const BlockRV& block_rv) final; /******** Schedule: Reduction ********/ diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 5f76e77592e3..592d32d6245d 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -80,7 +80,7 @@ def test_gpu_softmax_mn(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)", 'sch.bind(loop=l6, 
thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)", @@ -93,7 +93,7 @@ def test_gpu_softmax_mn(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)", 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)", @@ -107,7 +107,7 @@ def test_gpu_softmax_mn(): "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)", 'sch.bind(loop=l7, thread_axis="threadIdx.x")', - "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)", + "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)", 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', "l8, l9, l10 = sch.get_loops(block=b1)", "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)", @@ -117,7 +117,7 @@ def test_gpu_softmax_mn(): "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l17, l18 = sch.split(loop=l15, factors=[None, v16], preserve_unit_iters=True)", 'sch.bind(loop=l18, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l14, 
preserve_unit_loops=True)", + "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True, index=-1)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l19, l20, l21 = sch.get_loops(block=b0)", "l22, l23 = sch.split(loop=l21, factors=[None, v16], preserve_unit_iters=True)", @@ -157,7 +157,7 @@ def test_gpu_softmax_mn_after_inline(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)", 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)", @@ -171,14 +171,14 @@ def test_gpu_softmax_mn_after_inline(): "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)", 'sch.bind(loop=l7, thread_axis="threadIdx.x")', - "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)", + "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)", 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', "l8, l9, l10 = sch.get_loops(block=b1)", "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)", 'sch.bind(loop=l12, thread_axis="threadIdx.x")', "b13, b14 = sch.get_consumers(block=b0)", "l15, l16, l17, l18 = sch.get_loops(block=b13)", - "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)", + "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True, index=-1)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l19, l20, l21 = sch.get_loops(block=b0)", 
"l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)", @@ -206,7 +206,7 @@ def test_gpu_batch_norm_bmn(): "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)", 'sch.bind(loop=l5, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)", + "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True, index=-1)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l6, l7, l8, l9 = sch.get_loops(block=b0)", "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)", diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index 87159fcb3110..fe1220c50925 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -62,7 +62,7 @@ def test_cpu_matmul(): "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', - "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -76,7 +76,7 @@ def test_cpu_matmul(): "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', - "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)", ], [ 'b0 = 
sch.get_block(name="C", func_name="main")', @@ -123,7 +123,7 @@ def test_cpu_matmul_relu(): "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", "b24, = sch.get_consumers(block=b0)", - "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -137,7 +137,7 @@ def test_cpu_matmul_relu(): "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", "b24, = sch.get_consumers(block=b0)", - "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -193,15 +193,15 @@ def test_cuda_matmul(): 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)', 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)", 'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)", + "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)", "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', 'b43 = sch.cache_read(block=b0, read_buffer_index=1, 
storage_scope="shared")', - "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)", + "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)", "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)", "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", @@ -247,15 +247,15 @@ def test_cuda_matmul_relu(): "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)", 'sch.bind(loop=l32, thread_axis="threadIdx.x")', 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)", 'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)", + "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)", "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', 'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)", + "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)", "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)", "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", @@ -402,7 +402,7 @@ def test_multi_level_tiling_conv2d_nchwc_vnni(): l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True) sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, 
l61, l69, l77) b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global") -sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split( +sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True, index=-1)""".split( "\n" ), """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") @@ -437,7 +437,7 @@ def test_multi_level_tiling_conv2d_nchwc_vnni(): l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True) sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77) b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global") -sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split( +sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True, index=-1)""".split( "\n" ), """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") @@ -546,15 +546,15 @@ def test_multi_level_tiling_dense_dp4a(): l38 = sch.fuse(l17, l27, preserve_unit_iters=True) sch.bind(loop=l38, thread_axis="threadIdx.x") b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local") -sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True) +sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True, index=-1) b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True) +sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True, index=-1) l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40) l47 = sch.fuse(l45, l46, preserve_unit_iters=True) v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48) b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True) 
+sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True, index=-1) l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49) l56 = sch.fuse(l54, l55, preserve_unit_iters=True) v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) @@ -632,9 +632,9 @@ def test_cuda_tensor_core_matmul_relu(): l52 = sch.fuse(l31, l41, preserve_unit_iters=True) sch.bind(loop=l52, thread_axis="threadIdx.y") b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared") -sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True) +sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1) b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True) +sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1) v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) sch.reverse_compute_inline(block=b2) @@ -646,19 +646,19 @@ def test_cuda_tensor_core_matmul_relu(): b72 = sch.blockize(loop=l64) sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared") b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True) +sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1) l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73) l80 = sch.fuse(l78, l79, preserve_unit_iters=True) v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True) +sch.compute_at(block=b82, loop=l47, 
preserve_unit_loops=True, index=-1) l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82) l89 = sch.fuse(l87, l88, preserve_unit_iters=True) v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90) b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True) +sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1) l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91) l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True) l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True) @@ -667,7 +667,7 @@ def test_cuda_tensor_core_matmul_relu(): b112 = sch.blockize(loop=l102) sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True) +sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1) l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113) l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True) l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True) @@ -772,9 +772,9 @@ def test_cuda_tensor_core_software_pipeline_matmul_relu(): l52 = sch.fuse(l31, l41, preserve_unit_iters=True) sch.bind(loop=l52, thread_axis="threadIdx.y") b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared") -sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True) +sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1) b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True) 
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1) v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) sch.reverse_compute_inline(block=b2) @@ -786,19 +786,19 @@ def test_cuda_tensor_core_software_pipeline_matmul_relu(): b72 = sch.blockize(loop=l64) sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared") b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True) +sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1) l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73) l80 = sch.fuse(l78, l79, preserve_unit_iters=True) v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True) +sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1) l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82) l89 = sch.fuse(l87, l88, preserve_unit_iters=True) v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90) b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True) +sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1) l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91) l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True) l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True) @@ -807,7 +807,7 @@ def 
test_cuda_tensor_core_software_pipeline_matmul_relu(): b112 = sch.blockize(loop=l102) sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True) +sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1) l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113) l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True) l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True) @@ -895,7 +895,7 @@ def test_cuda_tensor_core_matmul_relu_global(): l51 = sch.fuse(l30, l40, preserve_unit_iters=True) sch.bind(loop=l51, thread_axis="threadIdx.y") b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True) +sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1) sch.reverse_compute_inline(block=b1) l53, l54, l55, l56, l57 = sch.get_loops(block=b52) l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True) @@ -905,19 +905,19 @@ def test_cuda_tensor_core_matmul_relu_global(): b69 = sch.blockize(loop=l61) sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global") b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True) +sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1) l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70) l77 = sch.fuse(l75, l76, preserve_unit_iters=True) v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) b79 = sch.cache_read(block=b19, read_buffer_index=1, 
storage_scope="shared") -sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True) +sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1) l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79) l86 = sch.fuse(l84, l85, preserve_unit_iters=True) v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87) b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True) +sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1) l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88) l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True) l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True) @@ -926,7 +926,7 @@ def test_cuda_tensor_core_matmul_relu_global(): b109 = sch.blockize(loop=l99) sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True) +sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1) l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110) l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True) l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True) @@ -995,7 +995,7 @@ def test_cuda_tensor_core_matmul_relu_global(): l51 = sch.fuse(l30, l40, preserve_unit_iters=True) sch.bind(loop=l51, thread_axis="threadIdx.y") b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True) +sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1) sch.reverse_compute_inline(block=b1) l53, l54, l55, l56, 
l57 = sch.get_loops(block=b52) l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True) @@ -1005,19 +1005,19 @@ def test_cuda_tensor_core_matmul_relu_global(): b69 = sch.blockize(loop=l61) sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global") b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True) +sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1) l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70) l77 = sch.fuse(l75, l76, preserve_unit_iters=True) v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True) +sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1) l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79) l86 = sch.fuse(l84, l85, preserve_unit_iters=True) v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87) b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True) +sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1) l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88) l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True) l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True) @@ -1026,7 +1026,7 @@ def test_cuda_tensor_core_matmul_relu_global(): b109 = sch.blockize(loop=l99) sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") b110 = sch.cache_read(block=b19, 
read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True) +sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1) l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110) l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True) l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True) @@ -1133,9 +1133,9 @@ def test_cuda_tensor_core_conv2d(): l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True) sch.bind(loop=l64, thread_axis="threadIdx.y") b65 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="shared") -sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True) +sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True, index=-1) b66 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True) +sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True, index=-1) v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) sch.reverse_compute_inline(block=b1) @@ -1147,19 +1147,19 @@ def test_cuda_tensor_core_conv2d(): b84 = sch.blockize(loop=l76) sch.annotate(block_or_loop=b84, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared") b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True) +sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True, index=-1) l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85) l92 = sch.fuse(l90, l91, preserve_unit_iters=True) v93 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93) b94 = sch.cache_read(block=b21, 
read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True) +sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True, index=-1) l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94) l101 = sch.fuse(l99, l100, preserve_unit_iters=True) v102 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102) b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True) +sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True, index=-1) l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b103) l111, l112 = sch.split(loop=l110, factors=[None, 16], preserve_unit_iters=True) l113, l114 = sch.split(loop=l109, factors=[None, 16], preserve_unit_iters=True) @@ -1168,7 +1168,7 @@ def test_cuda_tensor_core_conv2d(): b124 = sch.blockize(loop=l114) sch.annotate(block_or_loop=b124, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") b125 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True) +sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True, index=-1) l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b125) l133, l134 = sch.split(loop=l132, factors=[None, 16], preserve_unit_iters=True) l135, l136 = sch.split(loop=l131, factors=[None, 16], preserve_unit_iters=True) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py index b2df408e9d01..c951a5adf386 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -71,7 +71,7 @@ def 
test_random_compute_location(): [ 'b0 = sch.get_block(name="move", func_name="main")', "l1 = sch.sample_compute_location(block=b0)", - "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)", + "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True, index=-1)", ] ] mod = Add diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 0c20a4783ca0..72cba1a8fdc4 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1353,5 +1353,157 @@ def _create_prim_func(): verify_trace_roundtrip(sch=sch, mod=mod) +def test_compute_at_to_index(): + @T.prim_func + def multi_producers_conv( + data: T.Buffer[(1, 3, 224, 224), "int8"], + w: T.Buffer[(16, 3, 7, 7), "int8"], + conv: T.Buffer[(1, 16, 112, 112), "int32"], + ) -> None: + pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8") + wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8") + for i0, i1, i2, i3 in T.grid(1, 3, 230, 230): + with T.block("pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3]) + T.writes(pad[i0_1, i1_1, i2_1, i3_1]) + pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + 3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227, + data[i0_1, i1_1, i2_1 - 3, i3_1 - 3], + T.int8(0), + dtype="int8", + ) + for i0 in T.serial(1): + for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7): + with T.block("wbuf"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(w[v0, v1, v2, v3]) + T.writes(wbuf[v0, v1, v2, v3]) + wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3] + for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7): + with T.block("conv"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap( + "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6] + ) + T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx]) + T.writes(conv[nn, ff, yy, xx]) + with T.init(): + conv[nn, ff, yy, xx] = 0 + conv[nn, ff, yy, 
xx] = conv[nn, ff, yy, xx] + T.cast( + pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32" + ) * T.cast(wbuf[ff, rc, ry, rx], "int32") + + @T.prim_func + def multi_producers_after_compute_at( + data: T.Buffer[(1, 3, 224, 224), "int8"], + w: T.Buffer[(16, 3, 7, 7), "int8"], + conv: T.Buffer[(1, 16, 112, 112), "int32"], + ) -> None: + pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8") + wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8") + for i0 in T.serial(1): + for ax0, ax1, ax2 in T.grid(3, 229, 229): + with T.block("pad"): + i0_1 = T.axis.spatial(1, 0) + i1_1 = T.axis.spatial(3, ax0) + i2_1 = T.axis.spatial(230, ax1) + i3_1 = T.axis.spatial(230, ax2) + T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3]) + T.writes(pad[i0_1, i1_1, i2_1, i3_1]) + pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + 3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227, + data[i0_1, i1_1, i2_1 - 3, i3_1 - 3], + T.int8(0), + dtype="int8", + ) + for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7): + with T.block("wbuf"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(w[v0, v1, v2, v3]) + T.writes(wbuf[v0, v1, v2, v3]) + wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3] + for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7): + with T.block("conv"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap( + "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6] + ) + T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx]) + T.writes(conv[nn, ff, yy, xx]) + with T.init(): + conv[nn, ff, yy, xx] = 0 + conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast( + pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32" + ) * T.cast(wbuf[ff, rc, ry, rx], "int32") + + sch = tir.Schedule(multi_producers_conv, debug_mask="all") + block_c = sch.get_block("pad") + axis = sch.get_loops("conv")[0] + sch.compute_at(block_c, axis, index=-2) + tvm.ir.assert_structural_equal(multi_producers_after_compute_at, sch.mod["main"]) + + +def test_reverse_compute_at_to_index(): + @T.prim_func + def main(A: T.Buffer[(128, 128), 
"float32"], D: T.Buffer[(128, 128), "float32"]) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + C = T.alloc_buffer([128, 128], dtype="float32") + for i_0, j_0, i_1 in T.grid(8, 8, 16): + for j_1 in T.serial(16): + with T.block("B"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for ax0 in T.serial(16): + with T.block("C"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + ax0) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(D[vi, vj]) + D[vi, vj] = B[vi, vj] + T.float32(1) + + @T.prim_func + def main_reverse_compute_at( + A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(128, 128), "float32"] + ) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + C = T.alloc_buffer([128, 128], dtype="float32") + for i_0, j_0, i_1 in T.grid(8, 8, 16): + for j_1 in T.serial(16): + with T.block("B"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for ax0 in T.serial(16): + with T.block("D"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + ax0) + T.reads(B[vi, vj]) + T.writes(D[vi, vj]) + D[vi, vj] = B[vi, vj] + T.float32(1) + for ax0 in T.serial(16): + with T.block("C"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + ax0) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + sch = tir.Schedule(main, debug_mask="all") + block_c = sch.get_block("D") + axis = sch.get_loops("B")[2] + sch.reverse_compute_at(block_c, axis, index=1) + tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"]) + + if __name__ == "__main__": 
tvm.testing.main()