From 93559cf35b25a22ebc3ee9ccc1eca1f672a27878 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 Aug 2022 11:50:03 -0500 Subject: [PATCH 01/26] [UnitTests] Initial unit tests for padded transformation behavior --- python/tvm/tir/schedule/schedule.py | 15 + .../test_tir_schedule_transform_layout.py | 332 ++++++++++++++++++ 2 files changed, 347 insertions(+) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d1293371a0e0..2890b7bbb558 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2443,6 +2443,7 @@ def transform_layout( block: Union[BlockRV, str], buffer: Union[Tuple[str, int], str, Buffer], index_map: Union[IndexMap, Callable], + pad_value: Optional[Union[int, float, IndexMap, Callable]] = None, ) -> None: """Apply a transformation represented by IndexMap to buffer @@ -2479,6 +2480,20 @@ def transform_layout( primitive will be called in addition to the TransformLayout primitive. + pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]] + + The value to be used for any padding introduced by the + transformation. + + If None, the transformation may not introduce padding. + + If an int, float or PrimExpr, the transformation is the + specific value to be present in the padding. + + If an IndexMap or Callable, the transformation is the + value to be present in the padding in terms of the + transformed index. + Examples -------- Before transform_layout, in TensorIR, the IR is: diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 0332df7fd312..5c56771c13cd 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -329,5 +329,337 @@ def test_transform_block_layout_fail_mixed_iter_type(use_block_name): ) +class BasePaddingCompare(tvm.testing.CompareBeforeAfter): + pad_value = tvm.testing.parameter(None) + + transformed_buffer = tvm.testing.parameter("A") + + @pytest.fixture + def transform(self, pad_value, transformed_buffer): + def transform(mod): + sch = tir.Schedule(mod) + sch.transform_layout( + "block", transformed_buffer, lambda i: [i // 4, i % 4], pad_value=pad_value + ) + # sch.transform_block_layout("block", lambda i: [i // 4, i % 4]) + return sch.mod + + return transform + + +class TestNoPadding(BasePaddingCompare): + """Transformations without padding do not depend on pad_value.""" + + pad_value = tvm.testing.parameter(None, 42) + + def before(): + A = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("block"): + A[i] = 0 + + def expected(): + A = T.alloc_buffer([4, 4], "int32") + for i in T.serial(16): + with T.block("block"): + A[i // 4, i % 4] = 0 + + +class TestNoPaddingMultipleUsage(BasePaddingCompare): + """Transformations without padding do not depend on pad_value. + + Like TestNoPadding, but the buffer A shows up in multiple + locations. To remain internally consistent, all instances of the + buffer should be rewritten. 
+ """ + + pad_value = tvm.testing.parameter(None, 42) + + def before(): + A = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("block"): + A[i] = 0 + + B = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("other"): + B[i] = A[i] + + def expected(): + A = T.alloc_buffer([4, 4], "int32") + for i in T.serial(16): + with T.block("block"): + A[i // 4, i % 4] = 0 + + B = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("other"): + B[i] = A[i // 4, i % 4] + + +class TestNoPaddingVirtualIndex(BasePaddingCompare): + """Like TestNoPadding, but accessed through block indices.""" + + pad_value = tvm.testing.parameter(None, 42) + + def before(): + A = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi] = 0 + + def expected(): + A = T.alloc_buffer([4, 4], "int32") + for i in T.serial(16): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi // 4, vi % 4] = 0 + + +@pytest.mark.xfail(reason="Not implemented yet") +class TestErrorIfPaddingForbidden(BasePaddingCompare): + """Unless padding is explicitly enabled, should raise error""" + + def before(): + A = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + A[i] = 0 + + expected = tvm.tir.schedule.schedule.ScheduleError + + +@pytest.mark.xfail(reason="Not implemented yet") +class TestErrorOnWrongPaddingType(BasePaddingCompare): + """The padding must have the same dtype as the buffer""" + + pad_value = tvm.testing.parameter(0.5) + + def before(): + A = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + A[i] = 0 + + expected = tvm.tir.schedule.schedule.ScheduleError + + +@pytest.mark.xfail(reason="Superceded by TestPaddedTransformIfThenElse") +class TestPaddedTransformPostProc(BasePaddingCompare): + """Set the transformation padding in a post-processing block. + + This test is incompatible with TestPaddedTransformIfThenElse, and + is here for initial development purposes. + """ + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + B[i] = A[i] + + def expected(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i in T.serial(14): + with T.block("block"): + B[i // 4, i % 4] = A[i] + + for i, j in T.grid(4, 4): + with T.block("buffer_B_padding"): + T.where(i == 3 and 2 <= j) + B[i, j] = 0 + + +class TestPaddedTransformIfThenElse(BasePaddingCompare): + """Use if_then_else to represent padding, if possible. + + For a block that is a producer of the pre-transformation buffer, + which visits all indices according to a row-major traversal, and + which has no effect other than producing the transformed buffer, + transform the loop iterators to be a row-major traversal of the + post-transformation buffer, with padding represented by + `T.if_then_else`. + + This test is incompatible with TestPaddedTransformPostProc. This + is the long-term intended method to be supported, with + TestPaddedTransformPostProc present for development purposes. 
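+
+    As a concrete illustration of the predicate used below: the map
+    `lambda i: [i // 4, i % 4]` sends a 14-element buffer to shape
+    (4, 4), whose last two elements `[3, 2]` and `[3, 3]` have no
+    pre-transformation index, so the padding condition in the
+    expected output is `i == 3 and 2 <= j`.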
+ """ + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + B[i] = A[i] + + def expected(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j in T.grid(4, 4): + with T.block("block"): + B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, A[i * 4 + j], dtype="int32") + + +class TestPaddedTransformWithoutLoop(BasePaddingCompare): + """Handle padded writes without a loop + + The statement being replaced may be something other than a + for-loop, such as if a loop has already been unrolled. + """ + + pad_value = tvm.testing.parameter(0) + + def before(A: T.Buffer[14, "int32"]): + with T.block("root"): + T.reads() + T.writes() + with T.block("block"): + A[0] = 0 + + def expected(A: T.Buffer[(4, 4), "int32"]): + with T.block("block"): + A[0, 0] = 0 + + for i, j in T.grid(4, 4): + with T.block("buffer_A_padding"): + T.where(i == 3 and 2 <= j) + A[i, j] = 0 + + +class TestPaddedTransformIfThenElseReduction(BasePaddingCompare): + """Like TestPaddedTransformIfThenElse, but with a reduction axis""" + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + B[i] = 0 + for k in T.serial(32): + with T.block("block"): + B[i] = B[i] + A[i, k] + + def expected(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j in T.grid(4, 4): + B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 0, dtype="int32") + for k in T.serial(32): + with T.block("block"): + B[i, j] = T.if_then_else( + i == 3 and 2 <= j, 0, B[i, j] + A[i * 4 + j, k], dtype="int32" + ) + + +class TestPaddedTransformIfThenElseReductionBlock(BasePaddingCompare): + """Like TestPaddedTransformIfThenElse, but with a reduction axis""" + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer(14, "int32") + for i, k in T.grid(14, 32): + with T.block("block"): + with T.init(): + B[i] = 0 + B[i] = B[i] + A[i, k] + + def expected(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j, k in T.grid(4, 4, 32): + with T.block("block"): + with T.init(): + B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 0, dtype="int32") + B[i, j] = T.if_then_else( + i == 3 and 2 <= j, 0, B[i, j] + A[i * 4 + j, k], dtype="int32" + ) + + +class TestPaddedTransformIfThenElseReductionBlockVirtualAxes(BasePaddingCompare): + """Like TestPaddedTransformIfThenElse, but with a reduction axis""" + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer(14, "int32") + for i, k in T.grid(14, 32): + with T.block("block"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0 + B[vi] = B[vi] + A[vi, vk] + + def expected(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j, k in T.grid(4, 4, 32): + with T.block("block"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 0, dtype="int32") + B[vi, vj] = T.if_then_else( + vi == 3 and 2 <= vj, 0, B[vi, vj] + A[vi * 4 + vj, vk], dtype="int32" + ) + + +class TestPaddedTransformPostProcIfRequiredDueToSideEffects(BasePaddingCompare): + """Set the transformation padding in a 
post-processing block. + + Like TestPaddedTransformIfThenElse, but the block that produces B + also has the effect of setting `C`. + """ + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + C = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + B[i] = A[i] + C[i] = 0 + + def expected(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer([4, 4], "int32") + C = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + B[i // 4, i % 4] = A[i] + C[i] = 0 + + for i, j in T.grid(4, 4): + with T.block("block_pad_B"): + T.where(i == 3 and 2 <= j) + B[i, j] = 0 + + +class TestPaddedTransformOfInputCreatesAssumption(BasePaddingCompare): + """Transformation of an input buffer places T.assume locally""" + + pad_value = tvm.testing.parameter(42) + + def before(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]): + for i in T.serial(14): + with T.block("block"): + B[i] = A[i] + + def expected(A: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]): + for i, j in T.grid(4, 4): + with T.block("buffer_A_assumption"): + T.assume(not (i == 3 and 2 <= j) or A[i, j] == 42) + + for i in T.serial(14): + with T.block("block"): + B[i] = A[i // 4, i % 4] + + if __name__ == "__main__": tvm.testing.main() From 60ea52753780a440844e7e10bce9cc870b0871d8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Aug 2022 13:30:15 -0500 Subject: [PATCH 02/26] [Utils][Fix] Correction for non-empty Callable type annotations --- python/tvm/tir/schedule/_type_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py index 0b48dfc2b0e6..0c66f7ef6cdf 100644 --- a/python/tvm/tir/schedule/_type_checker.py +++ b/python/tvm/tir/schedule/_type_checker.py @@ -164,7 +164,7 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]: return "atomic", [type_] -def callable_str(subtypes): +def callable_str(*subtypes): if subtypes: *arg_types, return_type = subtypes arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types) From 2971b5be3821bfcb5bbea77d22d501f4d786cf09 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 Aug 2022 14:21:28 -0500 Subject: [PATCH 03/26] [TIR] Pass the pad_value argument from Python to C++ --- include/tvm/tir/schedule/schedule.h | 4 +++- python/tvm/tir/schedule/schedule.py | 2 +- src/meta_schedule/postproc/rewrite_layout.cc | 3 ++- .../multi_level_tiling_tensor_core.cc | 2 +- src/tir/schedule/concrete_schedule.cc | 6 ++++-- src/tir/schedule/concrete_schedule.h | 2 +- src/tir/schedule/instruction_traits.h | 4 +++- src/tir/schedule/primitive.h | 4 +++- .../primitive/layout_transformation.cc | 18 +++++++++++++----- src/tir/schedule/schedule.cc | 6 ++++-- src/tir/schedule/traced_schedule.cc | 15 +++++++++------ src/tir/schedule/traced_schedule.h | 2 +- 12 files changed, 45 insertions(+), 23 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index da399ab976d6..d497faca3a8f 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -601,9 +601,11 @@ class ScheduleNode : public runtime::Object { * \param buffer_index The index of the buffer in block's read or write region. * \param buffer_index_type The type of the buffer index, kRead or kWrite. * \param index_map The transformation to apply. 
+ * \param pad_value The value to write into padding introduced by the transformation.
  */
 virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
-                             BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;
+                             BufferIndexType buffer_index_type, const IndexMap& index_map,
+                             const Optional<PrimExpr>& pad_value) = 0;
 
 /*!
  * \brief Apply a transformation represented by IndexMap to block
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 2890b7bbb558..35b4a97dda04 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2553,7 +2553,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
 
         buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
         _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: disable=no-member
-            self, block, buffer_index, buffer_index_type_enum, index_map
+            self, block, buffer_index, buffer_index_type_enum, index_map, pad_value
         )
         if axis_separators:
             _ffi_api.ScheduleSetAxisSeparator(  # type: ignore # pylint: disable=no-member
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index f4cbdfe737fb..40488736bfe3 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -148,7 +148,8 @@ bool RewriteLayout(const Schedule& sch) {
       // Apply schedule
       BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
       BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
-      sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value());
+      sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(),
+                           NullOpt);
       sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
     }
   }
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 7ddda9b2635b..691f4c80f53f 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -492,7 +492,7 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
     const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
         state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
     auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
-    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map);
+    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt);
   };
 
   for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index afc675799706..4c8271a45f9f 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -761,9 +761,11 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann
 /******** Schedule: Layout transformation ********/
 void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
                                            BufferIndexType buffer_index_type,
-                                           const IndexMap& index_map) {
+                                           const IndexMap& index_map,
+                                           const Optional<PrimExpr>& pad_value) {
   TVM_TIR_SCHEDULE_BEGIN();
-  tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map);
+  tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map,
+                       pad_value);
   this->state_->DebugVerify();
   TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
 }
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index e79d1d528809..e92d2aa35ac5 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -143,7 +143,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
   /******** Schedule: Layout transformation ********/
   void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
-                       const IndexMap& index_map) override;
+                       const IndexMap& index_map, const Optional<PrimExpr>& pad_value) override;
   void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
   void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                         BufferIndexType buffer_index_type,
diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h
index 56c69224fe17..122c5ff0d9fe 100644
--- a/src/tir/schedule/instruction_traits.h
+++ b/src/tir/schedule/instruction_traits.h
@@ -430,7 +430,9 @@ TVM_ALWAYS_INLINE Array<ObjectRef> UnpackedInstTraits<TTraits>::_ConvertOutputs(
 /********** PythonAPICall **********/
 
 inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os) {
-  if (const auto* str = obj.as<runtime::StringObj>()) {
+  if (!obj.defined()) {
+    os << "None";
+  } else if (const auto* str = obj.as<runtime::StringObj>()) {
     os << str->data;
   } else if (const auto* int_imm = obj.as<IntImmNode>()) {
     os << int_imm->value;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 05d9e4cf944a..280c57808f7d 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -474,9 +474,11 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String&
  * \param buffer_index The index of the buffer in block's read or write region.
  * \param buffer_index_type The type of the buffer index, kRead or kWrite.
  * \param index_map The transformation to apply.
+ * \param pad_value The value to write into padding introduced by the transformation.
  */
 TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
-                             BufferIndexType buffer_index_type, const IndexMap& index_map);
+                             BufferIndexType buffer_index_type, const IndexMap& index_map,
+                             const Optional<PrimExpr>& pad_value);
 
 /*!
* \brief Apply a transformation represented by IndexMap to block diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 8e2643db0103..8a20cc8e97a8 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -133,7 +133,8 @@ class BufferIsSubregionError : public ScheduleError { }; void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - BufferIndexType buffer_index_type, const IndexMap& index_map) { + BufferIndexType buffer_index_type, const IndexMap& index_map, + const Optional& pad_value) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); @@ -536,17 +537,20 @@ struct TransformLayoutTraits : public UnpackedInstTraits private: static constexpr size_t kNumInputs = 1; - static constexpr size_t kNumAttrs = 3; + static constexpr size_t kNumAttrs = 4; static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, - Integer buffer_index_type, IndexMap index_map) { + Integer buffer_index_type, IndexMap index_map, + Optional pad_value) { return sch->TransformLayout(block_rv, buffer_index.IntValue(), - static_cast(buffer_index_type->value), index_map); + static_cast(buffer_index_type->value), index_map, + pad_value); } static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - Integer buffer_index_type, IndexMap index_map) { + Integer buffer_index_type, IndexMap index_map, + Optional pad_value) { PythonAPICall py("transform_layout"); py.Input("block", block_rv); @@ -556,6 +560,8 @@ struct TransformLayoutTraits : public UnpackedInstTraits py.Input("buffer", os.str()); py.Input("index_map", index_map->ToPythonString()); + py.Input("pad_value", pad_value); + return py.Str(); } @@ -566,6 +572,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits attrs_record.push_back(attrs[0]); attrs_record.push_back(attrs[1]); attrs_record.push_back(String(::tvm::SaveJSON(attrs[2]))); + attrs_record.push_back(attrs[3]); return std::move(attrs_record); } @@ -575,6 +582,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits attrs.push_back(attrs_record[0]); attrs.push_back(attrs_record[1]); attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[2]))); + attrs.push_back(attrs_record[3]); return attrs; } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 091db344aadb..d3bf99d783dd 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -248,9 +248,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") /******** (FFI) Layout transformation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, - int buffer_index_type, const IndexMap& index_map) { + int buffer_index_type, const IndexMap& index_map, + const Optional& pad_value) { return self->TransformLayout(block_rv, buffer_index, - static_cast(buffer_index_type), index_map); + static_cast(buffer_index_type), index_map, + pad_value); }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout") .set_body_method(&ScheduleNode::TransformBlockLayout); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 04ddc0507dc4..340b614dd7f5 100644 --- a/src/tir/schedule/traced_schedule.cc +++ 
b/src/tir/schedule/traced_schedule.cc @@ -487,14 +487,17 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map) { - ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map); + const IndexMap& index_map, + const Optional& pad_value) { + ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map, + pad_value); static const InstructionKind& kind = InstructionKind::Get("TransformLayout"); trace_->Append( - /*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), index_map}, - /*outputs=*/{})); + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), index_map, pad_value}, + /*outputs=*/{})); } void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index d98e4ba4bb95..8ba1120df667 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -103,7 +103,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { void Unannotate(const BlockRV& block_rv, const String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map) override; + const IndexMap& index_map, const Optional& pad_value) override; void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, From 885fd78ed8fb99217099600dee2cdab06f2d366e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 Aug 2022 12:38:39 -0500 Subject: [PATCH 04/26] [TIR] Added check to validate lack of transformation padding --- .../primitive/layout_transformation.cc | 48 +++++++++++++++++++ .../test_tir_schedule_transform_layout.py | 1 - 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 8a20cc8e97a8..59eed9429199 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -132,6 +132,40 @@ class BufferIsSubregionError : public ScheduleError { Buffer buffer_; }; +class TransformationIntroducesPaddingError : public ScheduleError { + public: + TransformationIntroducesPaddingError(IRModule mod, Buffer buffer, IndexMap index_map, + PrimExpr padding_predicate) + : mod_(std::move(mod)), + buffer_(std::move(buffer)), + index_map_(std::move(index_map)), + padding_predicate_(std::move(padding_predicate)) {} + + String FastErrorString() const final { + std::ostringstream ss; + ss << "ScheduleError: Transformation would introduce padding at " << padding_predicate_ << "."; + return ss.str(); + } + + String DetailRenderTemplate() const final { + auto new_shape = index_map_->MapShape(buffer_->shape); + std::ostringstream os; + os << "The transformation " << index_map_ << " applied on buffer " << buffer_->name + << " of shape " << buffer_->shape << " would result in shape " << new_shape + << ". 
However, this would introduce padding wherever " << padding_predicate_ << " is true."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + Buffer buffer_; + IndexMap index_map_; + PrimExpr padding_predicate_; +}; + void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, const Optional& pad_value) { @@ -153,6 +187,20 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ new_buffer_node->shape = index_map->MapShape(old_buffer->shape); Buffer new_buffer{new_buffer_node}; + // Step 1.1: Validate that padding hasn't been introduced. + auto [inverse, padding_predicate] = [&]() { + Array region; + for (const auto& dim : old_buffer->shape) { + region.push_back(Range::FromMinExtent(0, dim)); + } + return index_map.NonSurjectiveInverse(region); + }(); + + bool has_padding = !is_zero(padding_predicate); + if (has_padding && !pad_value.defined()) { + throw TransformationIntroducesPaddingError(self->mod, old_buffer, index_map, padding_predicate); + } + // Step 2: Rewrite access indices and regions of the buffer auto [new_stmt, block_sref_reuse] = TransformLayoutRewriter::Rewrite( GetRef(scope_block), old_buffer, new_buffer, index_map); diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 5c56771c13cd..c8c543862f35 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -418,7 +418,6 @@ def expected(): A[vi // 4, vi % 4] = 0 -@pytest.mark.xfail(reason="Not implemented yet") class TestErrorIfPaddingForbidden(BasePaddingCompare): """Unless padding is explicitly enabled, should raise error""" From 185eead8e298958f3a02d568ad846594d1fca5b8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 26 Aug 2022 13:39:00 -0500 Subject: [PATCH 05/26] Raise error if pad value doesn't match buffer's data type. 
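
A pad_value whose dtype differs from the buffer's element type would
otherwise be written into the padding as an ill-typed store. A rough
sketch of the failure mode, following the new unit test (the schedule,
block, and buffer names here are illustrative):

    sch = tir.Schedule(mod)
    # "A" is an int32 buffer written by "block"; a float pad_value
    # is now rejected instead of being silently inserted.
    sch.transform_layout(
        "block", "A", lambda i: [i // 4, i % 4], pad_value=0.5
    )  # raises tvm.tir.schedule.schedule.ScheduleError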
---
 .../primitive/layout_transformation.cc        | 31 +++++++++++++++++++
 .../test_tir_schedule_transform_layout.py     |  1 -
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 59eed9429199..76fb691aa8e5 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -132,6 +132,34 @@ class BufferIsSubregionError : public ScheduleError {
   Buffer buffer_;
 };
 
+class TransformationPaddingTypeError : public ScheduleError {
+ public:
+  TransformationPaddingTypeError(IRModule mod, Buffer buffer, PrimExpr pad_value)
+      : mod_(mod), buffer_(buffer), pad_value_(pad_value) {}
+
+  String FastErrorString() const final {
+    std::ostringstream ss;
+    ss << "ScheduleError: Type mismatch " << buffer_->dtype << " vs " << pad_value_->dtype;
+    return ss.str();
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream ss;
+    ss << "ScheduleError: Buffer " << buffer_->name << " has elements of type " << buffer_->dtype
+       << ", but the transformation fills padding with " << pad_value_ << ", which is of type "
+       << pad_value_->dtype;
+    return ss.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+  PrimExpr pad_value_;
+};
+
 class TransformationIntroducesPaddingError : public ScheduleError {
  public:
   TransformationIntroducesPaddingError(IRModule mod, Buffer buffer, IndexMap index_map,
@@ -176,6 +204,9 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
   if (defining_site_sref.defined() && !is_alloc) {
     throw BufferIsSubregionError(self->mod, old_buffer);
   }
+  if (pad_value && pad_value.value()->dtype != old_buffer->dtype) {
+    throw TransformationPaddingTypeError(self->mod, old_buffer, pad_value.value());
+  }
 
   StmtSRef scope_sref = defining_site_sref.defined()
                             ? defining_site_sref.value()
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py
index c8c543862f35..a47c697a7a57 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -430,7 +430,6 @@ def before():
         expected = tvm.tir.schedule.schedule.ScheduleError
 
 
-@pytest.mark.xfail(reason="Not implemented yet")
 class TestErrorOnWrongPaddingType(BasePaddingCompare):
     """The padding must have the same dtype as the buffer"""
 
From 874bfc2789bc1f60a76a7cdb8d922eb5e7e14ed8 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 6 Sep 2022 12:06:34 -0500
Subject: [PATCH 06/26] Simplify expressions in IndexMap::NonSurjectiveInverse
---
 src/tir/ir/index_map.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 0e3c3b2774c8..3a72b198599f 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -90,7 +90,7 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
   // Unpack the map to an array, maintaining the same parameter order.
Array inverse_exprs; for (const auto& index : (*this)->initial_indices) { - inverse_exprs.push_back(inverse_exprs_map.at(index)); + inverse_exprs.push_back(analyzer.Simplify(inverse_exprs_map.at(index))); } PrimExpr padding_predicate = padded_iter_map->padding_predicate; From ddea093f99f89d4f4beadecd5fbfea2b438fc420 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 12:13:40 -0500 Subject: [PATCH 07/26] Preparatory refactor, update BlockNode::alloc_buffers while visiting --- .../primitive/layout_transformation.cc | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 76fb691aa8e5..7411bf080f17 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -97,6 +97,13 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { auto* n = block.CopyOnWrite(); RewriteAccessRegion(&n->reads, infered_access_regions[0]); RewriteAccessRegion(&n->writes, infered_access_regions[1]); + n->alloc_buffers.MutateByApply([this](const Buffer& buffer) { + if (buffer.same_as(old_buffer_)) { + return new_buffer_; + } else { + return buffer; + } + }); block_sref_reuse_.Set(GetRef(op), block); return std::move(block); } @@ -197,6 +204,7 @@ class TransformationIntroducesPaddingError : public ScheduleError { void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, const Optional& pad_value) { + // Step 1: Input handling and error checking const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); @@ -213,12 +221,6 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); - // Step 1: Infer the shape of the new buffer - ObjectPtr new_buffer_node = make_object(*(old_buffer.get())); - new_buffer_node->shape = index_map->MapShape(old_buffer->shape); - Buffer new_buffer{new_buffer_node}; - - // Step 1.1: Validate that padding hasn't been introduced. auto [inverse, padding_predicate] = [&]() { Array region; for (const auto& dim : old_buffer->shape) { @@ -232,22 +234,19 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ throw TransformationIntroducesPaddingError(self->mod, old_buffer, index_map, padding_predicate); } - // Step 2: Rewrite access indices and regions of the buffer - auto [new_stmt, block_sref_reuse] = TransformLayoutRewriter::Rewrite( - GetRef(scope_block), old_buffer, new_buffer, index_map); + // Step 2: Infer the shape of the new buffer + Buffer new_buffer = old_buffer; + new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape); + + // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block + // alloc_buffers. + auto [new_stmt, block_sref_reuse] = + TransformLayoutRewriter::Rewrite(GetRef(scope_block), old_buffer, new_buffer, + index_map, inverse, padding_predicate, pad_value); Block new_scope_block = Downcast(new_stmt); - // Step 3: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc. 
- if (defining_site_sref.defined()) { - auto* n = new_scope_block.CopyOnWrite(); - n->alloc_buffers.MutateByApply([&old_buffer, &new_buffer](const Buffer& buffer) { - if (buffer.same_as(old_buffer)) { - return new_buffer; - } - return buffer; - }); - block_sref_reuse.Set(GetRef(scope_block), new_scope_block); - } else { + // Step 4: Rewrite buffer_map of the PrimFunc if necessary. + if (!defining_site_sref.defined()) { GlobalVar g_var; GetRootPrimFunc(self->mod, scope_block, &g_var); IRModuleNode* new_mod = self->mod.CopyOnWrite(); From 2055bbf8044c8aa85e35cffe1b1d7758cc1ae02a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 12:21:51 -0500 Subject: [PATCH 08/26] Introduced LayoutTransformPlanner for planning how to pad --- .../primitive/layout_transformation.cc | 48 ++++++++++++++++--- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 7411bf080f17..e96e0486e4e2 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -16,12 +16,39 @@ * specific language governing permissions and limitations * under the License. */ + +#include +#include + #include "../../../arith/ir_mutator_with_analyzer.h" #include "../utils.h" namespace tvm { namespace tir { +class LayoutTransformPlanner : private StmtExprVisitor { + public: + struct NoPaddingRequired {}; + + using TransformPlan = std::variant; + static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, + IndexMap inverse, PrimExpr padding_predicate, + Optional pad_value) { + LayoutTransformPlanner visitor(old_buffer); + visitor(block); + return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value); + } + + private: + LayoutTransformPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {} + TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, + PrimExpr padding_predicate, Optional pad_value) const { + return NoPaddingRequired(); + } + + Buffer old_buffer_; +}; + class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { public: /*! 
@@ -33,23 +60,29 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { * \return The new AST rooting at the original parent scope and the map from the old block to the * new block */ - static std::pair> Rewrite(const Stmt& scope_stmt, - const Buffer& old_buffer, - const Buffer& new_buffer, - const IndexMap& index_map) { + static std::pair> Rewrite( + const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer, + const IndexMap& index_map, const IndexMap& inverse, const PrimExpr& padding_predicate, + const Optional& pad_value) { + auto plan = LayoutTransformPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, inverse, + padding_predicate, pad_value); + arith::Analyzer analyzer; - TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, &analyzer); - Stmt result = rewriter(scope_stmt); + TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer); + Block result = Downcast(rewriter(scope_stmt)); return {result, rewriter.block_sref_reuse_}; } private: TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer, - const IndexMap& index_map, arith::Analyzer* analyzer) + const IndexMap& index_map, + const LayoutTransformPlanner::TransformPlan& plan, + arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer), old_buffer_(old_buffer), new_buffer_(new_buffer), index_map_(index_map), + plan_(plan), buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {} void RewriteBufferAccess(Buffer* buffer, Array* indices) { @@ -111,6 +144,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { const Buffer& old_buffer_; const Buffer& new_buffer_; const IndexMap& index_map_; + const LayoutTransformPlanner::TransformPlan& plan_; Map buffer_data_to_buffer_; Map block_sref_reuse_; }; From f3538cdb0b7c04d7006cd2cf8f36dfa397acb814 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 12:27:46 -0500 Subject: [PATCH 09/26] Implemented insertion of T.assume for input buffers --- .../primitive/layout_transformation.cc | 65 ++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index e96e0486e4e2..b783f7d48bd9 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -28,9 +28,14 @@ namespace tir { class LayoutTransformPlanner : private StmtExprVisitor { public: + // Statement to be inserted prior to the analyzed block + struct ProloguePlan { + Stmt prologue; + }; + struct NoPaddingRequired {}; - using TransformPlan = std::variant; + using TransformPlan = std::variant; static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, Optional pad_value) { @@ -41,11 +46,65 @@ class LayoutTransformPlanner : private StmtExprVisitor { private: LayoutTransformPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {} + void VisitStmt_(const BufferStoreNode* op) override { + if (!op->buffer.same_as(old_buffer_)) { + return; + } + + WriteInfo write_info; + write_info.store = GetRef(op); + + write_info_.push_back(write_info); + + // Don't need to continue recursing, as the entire goal was to + // find the BufferStore. 
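+    // (A BufferStore may only appear as a statement, never nested
+    // inside the stored value expression, so returning early cannot
+    // skip over any other writes to the buffer.)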
+ } TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, Optional pad_value) const { + if (auto prologue_plan = + FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value); + prologue_plan.has_value()) { + return prologue_plan.value(); + } else { return NoPaddingRequired(); + } + } + + std::optional FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map, + IndexMap inverse, PrimExpr padding_predicate, + Optional pad_value) const { + if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) { + return std::nullopt; + } + + Array indices; + for (const auto& var : inverse->initial_indices) { + indices.push_back(var); + } + + PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value.value()); + Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr})); + + std::stringstream block_name; + block_name << "buffer_" << new_buffer->name << "_assumptions"; + auto read_region = BufferRegion::FromPoint(new_buffer, indices); + stmt = BlockRealize({}, Bool(true), Block({}, {read_region}, {}, block_name.str(), stmt)); + + for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { + size_t i = (inverse->initial_indices.size() - 1) - rev_i; + Var loop_var = inverse->initial_indices[i]; + PrimExpr extent = new_buffer->shape[i]; + stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt); + } + return ProloguePlan{stmt}; } + struct WriteInfo { + // The BufferStore object + BufferStore store; + }; + + std::vector write_info_; Buffer old_buffer_; }; @@ -70,6 +129,10 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { arith::Analyzer analyzer; TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer); Block result = Downcast(rewriter(scope_stmt)); + if (auto plan_ptr = std::get_if(&plan)) { + auto write_ptr = result.CopyOnWrite(); + write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body}); + } return {result, rewriter.block_sref_reuse_}; } From a6dbd308bda20615f5b89631b4adc191ffb79618 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 12:34:57 -0500 Subject: [PATCH 10/26] Implement epilogue plan for explicitly setting pad value --- .../primitive/layout_transformation.cc | 164 +++++++++++++++++- 1 file changed, 163 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index b783f7d48bd9..58b7dd6061a7 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -33,9 +33,18 @@ class LayoutTransformPlanner : private StmtExprVisitor { Stmt prologue; }; + + // The block to be inserted, along with the location at which it + // should be inserted. The location will be either a For or a + // Block, and will be after all writes the transformed buffer. 
+ struct EpiloguePlan { + Stmt insert_after; + Stmt new_block; + }; + struct NoPaddingRequired {}; - using TransformPlan = std::variant; + using TransformPlan = std::variant; static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, Optional pad_value) { @@ -46,25 +55,79 @@ class LayoutTransformPlanner : private StmtExprVisitor { private: LayoutTransformPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {} + + void VisitStmt_(const ForNode* op) override { + BindLoopVar context(this, GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const BlockRealizeNode* op) override { + BindBlockRealize context(this, GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const BufferStoreNode* op) override { if (!op->buffer.same_as(old_buffer_)) { return; } + std::optional> loop_dependency_range = std::nullopt; + for (const auto& index : op->indices) { + if (auto index_depth = LoopDependencyRange(index); index_depth.has_value()) { + if (loop_dependency_range) { + loop_dependency_range = { + std::min(loop_dependency_range.value().first, index_depth.value().first), + std::max(loop_dependency_range.value().second, index_depth.value().second)}; + } else { + loop_dependency_range = index_depth; + } + } + } + WriteInfo write_info; write_info.store = GetRef(op); + if (loop_dependency_range) { + size_t i = loop_dependency_range.value().first; + size_t j = loop_dependency_range.value().second; + ICHECK_LT(i, active_loops_.size()); + ICHECK_LT(j, active_loops_.size()); + + write_info.dependent_loopnest = {active_loops_.begin() + i, active_loops_.begin() + j + 1}; + } + write_info.innermost_block_realize = innermost_block_realize_; + write_info_.push_back(write_info); // Don't need to continue recursing, as the entire goal was to // find the BufferStore. 
} + + std::optional> LoopDependencyRange(const PrimExpr& expr) const { + std::optional> prev = std::nullopt; + for (const auto& var : UndefinedVars(expr)) { + auto it = loop_depth_lookup_.find(var.get()); + if (it != loop_depth_lookup_.end()) { + if (prev.has_value()) { + prev = {std::min(prev.value().first, it->second.first), + std::max(prev.value().second, it->second.second)}; + } else { + prev = it->second; + } + } + } + + return prev; + } TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, Optional pad_value) const { if (auto prologue_plan = FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value); prologue_plan.has_value()) { return prologue_plan.value(); + } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse, + padding_predicate, pad_value); + epilogue_plan.has_value()) { + return epilogue_plan.value(); } else { return NoPaddingRequired(); } @@ -99,12 +162,101 @@ class LayoutTransformPlanner : private StmtExprVisitor { return ProloguePlan{stmt}; } + std::optional FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map, + IndexMap inverse, PrimExpr padding_predicate, + Optional pad_value) const { + if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { + return std::nullopt; + } + + Array indices; + for (const auto& var : inverse->initial_indices) { + indices.push_back(var); + } + + Stmt stmt = BufferStore(new_buffer, pad_value.value(), indices); + + std::stringstream block_name; + block_name << "buffer_" << new_buffer->name << "_padding"; + auto write_region = BufferRegion::FromPoint(new_buffer, indices); + stmt = + BlockRealize({}, padding_predicate, Block({}, {}, {write_region}, block_name.str(), stmt)); + + ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { + size_t i = (inverse->initial_indices.size() - 1) - rev_i; + Var loop_var = inverse->initial_indices[i]; + PrimExpr extent = new_buffer->shape[i]; + stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt); + } + + const auto& info = write_info_.back(); + Stmt insert_after = [&]() -> Stmt { + if (info.dependent_loopnest.size()) { + return info.dependent_loopnest.front(); + } else if (info.innermost_block_realize) { + return info.innermost_block_realize.value(); + } else { + LOG(FATAL) << "Write occured outside of any block/loop"; + return Stmt(); + } + }(); + return EpiloguePlan{insert_after, stmt}; + } + + struct BindLoopVar { + BindLoopVar(LayoutTransformPlanner* self, For for_node) + : self_(self), var_(for_node->loop_var) { + size_t loop_depth = self_->active_loops_.size(); + self_->loop_depth_lookup_[var_.get()] = {loop_depth, loop_depth}; + self_->active_loops_.push_back(std::move(for_node)); + } + ~BindLoopVar() { + self_->active_loops_.pop_back(); + self_->loop_depth_lookup_.erase(var_.get()); + } + BindLoopVar(const BindLoopVar&) = delete; + BindLoopVar& operator=(const BindLoopVar&) = delete; + BindLoopVar(BindLoopVar&&) = delete; + BindLoopVar& operator=(BindLoopVar&&) = delete; + + LayoutTransformPlanner* self_{nullptr}; + Var var_; + }; + + struct BindBlockRealize { + BindBlockRealize(LayoutTransformPlanner* self, BlockRealize block_realize) : self_(self) { + cache_ = std::move(block_realize); + std::swap(self_->innermost_block_realize_, cache_); + } + ~BindBlockRealize() { std::swap(self_->innermost_block_realize_, cache_); } + BindBlockRealize(const BindBlockRealize&) = delete; + 
BindBlockRealize& operator=(const BindBlockRealize&) = delete; + BindBlockRealize(BindBlockRealize&&) = delete; + BindBlockRealize& operator=(BindBlockRealize&&) = delete; + + LayoutTransformPlanner* self_{nullptr}; + Optional cache_; + }; + struct WriteInfo { // The BufferStore object BufferStore store; + + // The block realize that contains the store, if any. + Optional innermost_block_realize; + + // The nested loops whose values contribute to the indices used in + // the store. Not all loop variables in the loopnest need to + // contribute, but the first and last must. + std::vector dependent_loopnest; }; std::vector write_info_; + std::vector active_loops_; + std::unordered_map> loop_depth_lookup_; + Optional innermost_block_realize_{NullOpt}; + Buffer old_buffer_; }; @@ -157,6 +309,16 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { using Parent::VisitExpr_; using Parent::VisitStmt_; + Stmt VisitStmt(const Stmt& stmt) final { + Stmt output = Parent::VisitStmt(stmt); + if (auto plan_ptr = std::get_if(&plan_)) { + if (plan_ptr->insert_after.same_as(stmt)) { + return SeqStmt({output, plan_ptr->new_block}); + } + } + return output; + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { BufferLoad buffer_load = Downcast(Parent::VisitExpr_(op)); if (buffer_load->buffer.same_as(old_buffer_)) { From 619c5b7562bef0202e5e6e613aec36bbd257f316 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 13:15:53 -0500 Subject: [PATCH 11/26] Check LetStmt bindings when determining loop dependencies --- .../primitive/layout_transformation.cc | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 58b7dd6061a7..09d433c90298 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -60,6 +60,12 @@ class LayoutTransformPlanner : private StmtExprVisitor { BindLoopVar context(this, GetRef(op)); StmtExprVisitor::VisitStmt_(op); } + + void VisitStmt_(const LetStmtNode* op) override { + BindLetVar context(this, op->var, op->value); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const BlockRealizeNode* op) override { BindBlockRealize context(this, GetRef(op)); StmtExprVisitor::VisitStmt_(op); @@ -224,6 +230,34 @@ class LayoutTransformPlanner : private StmtExprVisitor { Var var_; }; + struct BindLetVar { + BindLetVar() {} + BindLetVar(LayoutTransformPlanner* self, Var var, PrimExpr value) : self_(self), var_(var) { + if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) { + self_->loop_depth_lookup_[var_.get()] = loop_depth.value(); + } + } + ~BindLetVar() { + if (self_) { + self_->loop_depth_lookup_.erase(var_.get()); + } + } + BindLetVar(const BindLetVar&) = delete; + BindLetVar& operator=(const BindLetVar&) = delete; + BindLetVar(BindLetVar&& other) : BindLetVar() { swap(other); } + BindLetVar& operator=(BindLetVar&& other) { + swap(other); + return *this; + } + void swap(BindLetVar& other) { + std::swap(self_, other.self_); + std::swap(var_, other.var_); + } + + LayoutTransformPlanner* self_{nullptr}; + Var var_; + }; + struct BindBlockRealize { BindBlockRealize(LayoutTransformPlanner* self, BlockRealize block_realize) : self_(self) { cache_ = std::move(block_realize); From aa9bbf7b4ec7db8ceb4c67b3b7ee360e622caa86 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 13:18:09 -0500 Subject: [PATCH 12/26] Implement 
replacement plan for using tir::if_then_else For producer blocks that iterate over the pre-transformation shape, rewrite to iterate over the post-transformation shape, with `tir::if_then_else` to handle writing to indices corresponding to padding/non-padding. --- .../primitive/layout_transformation.cc | 288 +++++++++++++++++- 1 file changed, 287 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 09d433c90298..80cea729aed4 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -33,6 +33,11 @@ class LayoutTransformPlanner : private StmtExprVisitor { Stmt prologue; }; + // Loops within the analyzed block that should be replaced + struct ReplacementPlan { + Map replacements; + Map block_sref_reuse; + }; // The block to be inserted, along with the location at which it // should be inserted. The location will be either a For or a @@ -44,7 +49,9 @@ class LayoutTransformPlanner : private StmtExprVisitor { struct NoPaddingRequired {}; - using TransformPlan = std::variant; + using TransformPlan = + std::variant; + static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, Optional pad_value) { @@ -101,6 +108,30 @@ class LayoutTransformPlanner : private StmtExprVisitor { } write_info.innermost_block_realize = innermost_block_realize_; + write_info.contains_row_major_traversal = [&]() -> bool { + const auto& loopnest = write_info.dependent_loopnest; + if (loopnest.empty()) { + return false; + } + + if (loopnest.size() != old_buffer_->shape.size() || loopnest.size() != op->indices.size()) { + return false; + } + + for (size_t i = 0; i < loopnest.size(); i++) { + const For& loop = loopnest[i]; + const PrimExpr& buffer_dim = old_buffer_->shape[i]; + PrimExpr index = Substitute(op->indices[i], active_let_bindings_); + bool is_loop_over_axis = index.same_as(loop->loop_var) && is_const_int(loop->min, 0) && + ExprDeepEqual()(loop->extent, buffer_dim) && + loop->kind == ForKind::kSerial; + if (!is_loop_over_axis) { + return false; + } + } + + return true; + }(); write_info_.push_back(write_info); @@ -124,12 +155,48 @@ class LayoutTransformPlanner : private StmtExprVisitor { return prev; } + + class BufferStoreReplacer : public StmtExprMutator { + public: + BufferStoreReplacer(std::function(const BufferStoreNode*)> replace_store, + std::function(const BlockRealizeNode*, const BlockRealize&)> + replace_block_realize) + : replace_store_(replace_store), replace_block_realize_(replace_block_realize) {} + + Stmt VisitStmt_(const BufferStoreNode* op) final { + if (auto replacement = replace_store_(op)) { + auto store = Downcast(replacement.value()); + return StmtExprMutator::VisitStmt_(store.get()); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + auto realize = Downcast(StmtExprMutator::VisitStmt_(op)); + if (auto replacement = replace_block_realize_(op, realize)) { + return replacement.value(); + } else { + return std::move(realize); + } + } + + private: + std::function(const BufferStoreNode*)> replace_store_; + std::function(const BlockRealizeNode*, const BlockRealize&)> + replace_block_realize_; + }; + TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, Optional pad_value) const { if (auto prologue_plan = 
FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value); prologue_plan.has_value()) { return prologue_plan.value(); + } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, index_map, inverse, + padding_predicate, pad_value); + replacement_plan.has_value()) { + return replacement_plan.value(); } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value); epilogue_plan.has_value()) { @@ -168,6 +235,193 @@ class LayoutTransformPlanner : private StmtExprVisitor { return ProloguePlan{stmt}; } + std::optional FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map, + IndexMap inverse, + PrimExpr padding_predicate, + Optional pad_value) const { + if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { + return std::nullopt; + } + + auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional { + if (!info.contains_row_major_traversal || !pad_value.defined() || + is_zero(padding_predicate)) { + return NullOpt; + } + + Array old_indices = info.store->indices; + PrimExpr if_then_else_condition = padding_predicate; + Array new_indices; + for (const auto& var : inverse->initial_indices) { + new_indices.push_back(var); + } + + auto replace_block_realize = + [&]() -> std::function(const BlockRealizeNode*, const BlockRealize&)> { + auto no_change = [](const BlockRealizeNode*, const BlockRealize&) -> Optional { + return NullOpt; + }; + if (!info.innermost_block_realize) { + return no_change; + } + if (old_indices.empty()) { + return no_change; + } + + BlockRealize block_realize = info.innermost_block_realize.value(); + const auto& block = block_realize->block; + + // Find the block iterators that are used to access the buffer. Must be in the same order + // as they appear in the indices. + if (block->iter_vars.size() < old_indices.size()) { + return no_change; + } + const auto& iter_vars = block->iter_vars; + size_t block_index_start = 0; + for (; block_index_start < iter_vars.size() - old_indices.size(); block_index_start++) { + if (old_indices[0].same_as(iter_vars[block_index_start]->var)) { + break; + } + } + if (block_index_start >= iter_vars.size() - old_indices.size()) { + return no_change; + } + + for (size_t i = 0; i < old_indices.size(); i++) { + if (!old_indices[i].same_as(iter_vars[block_index_start + i]->var) || + iter_vars[block_index_start + i]->iter_type != kDataPar) { + return no_change; + } + } + + // If we got to this point, all indices used to access the + // buffer are virtual indices defined in the innermost block. + // Therefore, generate new virtual indices for iterating over + // the post-transform buffer. 
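+      // (One new data-parallel iterator is created per axis of the
+      // post-transform buffer; the original virtual indices are then
+      // recovered from them through the inverse map below.)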
+ Array new_iter_values; // For BlockRealize + Array new_iter_vars; // For Block + Array new_access_indices; // For BufferStore + Map loop_var_to_virtual_var; // For updating if_then_else_condition + + for (size_t i = 0; i < block_index_start; i++) { + new_iter_vars.push_back(iter_vars[i]); + new_iter_values.push_back(block_realize->iter_values[i]); + } + + ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + for (size_t i = 0; i < inverse->initial_indices.size(); i++) { + Var var = inverse->initial_indices[i]; + PrimExpr dim = new_buffer->shape[i]; + std::stringstream ss; + ss << "v_" << var->name_hint; + Var virtual_var(ss.str(), var.dtype()); + new_iter_values.push_back(var); + new_iter_vars.push_back(IterVar(Range::FromMinExtent(0, dim), virtual_var, kDataPar)); + new_access_indices.push_back(virtual_var); + loop_var_to_virtual_var.Set(var, virtual_var); + } + + for (size_t i = block_index_start + old_indices.size(); i < iter_vars.size(); i++) { + new_iter_vars.push_back(iter_vars[i]); + new_iter_values.push_back(block_realize->iter_values[i]); + } + + Map old_virtual_var_to_new_virtual_var; + ICHECK_EQ(inverse->final_indices.size(), old_indices.size()); + for (size_t i = 0; i < old_indices.size(); i++) { + Var var = Downcast(old_indices[i]); + PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var); + old_virtual_var_to_new_virtual_var.Set(var, expr); + } + + if_then_else_condition = Substitute(if_then_else_condition, loop_var_to_virtual_var); + new_indices = new_access_indices; + + return [target_realize = info.innermost_block_realize, new_iter_vars, new_iter_values, + old_virtual_var_to_new_virtual_var](const BlockRealizeNode* op, + const BlockRealize& visited) -> Optional { + if (op == target_realize.get()) { + Block block = visited->block; + block = + Downcast(Substitute(std::move(block), old_virtual_var_to_new_virtual_var)); + block.CopyOnWrite()->iter_vars = new_iter_vars; + + BlockRealize realize = visited; + { + auto write_ptr = realize.CopyOnWrite(); + write_ptr->block = block; + write_ptr->iter_values = new_iter_values; + } + return realize; + } else { + return NullOpt; + } + }; + }(); + + bool all_stores_replaced = true; + auto replace_store = [&](const BufferStoreNode* op) -> Optional { + if (!op->buffer.same_as(info.store->buffer)) { + all_stores_replaced = false; + return NullOpt; + } + ICHECK_EQ(old_indices.size(), op->indices.size()); + ExprDeepEqual expr_equal; + for (size_t i = 0; i < old_indices.size(); i++) { + if (!expr_equal(old_indices[i], op->indices[i])) { + all_stores_replaced = false; + return NullOpt; + } + } + + return BufferStore(new_buffer, + if_then_else(if_then_else_condition, pad_value.value(), op->value), + new_indices); + }; + + BufferStoreReplacer replacer(replace_store, replace_block_realize); + Stmt stmt = replacer(info.dependent_loopnest.back()->body); + if (!all_stores_replaced) { + return NullOpt; + } + + std::unordered_map var_remap; + ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size()); + for (size_t i = 0; i < info.dependent_loopnest.size(); i++) { + Var var = info.dependent_loopnest[i]->loop_var; + PrimExpr expr = inverse->final_indices[i]; + var_remap[var.get()] = expr; + } + stmt = Substitute(std::move(stmt), var_remap); + + ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { + size_t i = (inverse->initial_indices.size() - 1) - rev_i; + Var loop_var = inverse->initial_indices[i]; + 
PrimExpr extent = new_buffer->shape[i]; + stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt); + } + + return stmt; + }; + + Map loop_replacements; + + for (const auto& info : write_info_) { + if (info.dependent_loopnest.size()) { + if (auto opt_stmt = generate_if_then_else_block(info)) { + loop_replacements.Set(info.dependent_loopnest[0], opt_stmt.value()); + } + } + } + + if (loop_replacements.size()) { + return ReplacementPlan{std::move(loop_replacements)}; + } else { + return std::nullopt; + } + } + std::optional FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, Optional pad_value) const { @@ -235,11 +489,13 @@ class LayoutTransformPlanner : private StmtExprVisitor { BindLetVar(LayoutTransformPlanner* self, Var var, PrimExpr value) : self_(self), var_(var) { if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) { self_->loop_depth_lookup_[var_.get()] = loop_depth.value(); + self_->active_let_bindings_[var_.get()] = Substitute(value, self_->active_let_bindings_); } } ~BindLetVar() { if (self_) { self_->loop_depth_lookup_.erase(var_.get()); + self_->active_let_bindings_.erase(var_.get()); } } BindLetVar(const BindLetVar&) = delete; @@ -260,6 +516,11 @@ class LayoutTransformPlanner : private StmtExprVisitor { struct BindBlockRealize { BindBlockRealize(LayoutTransformPlanner* self, BlockRealize block_realize) : self_(self) { + ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size()); + for (size_t i = 0; i < block_realize->iter_values.size(); i++) { + bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var, + block_realize->iter_values[i]); + } cache_ = std::move(block_realize); std::swap(self_->innermost_block_realize_, cache_); } @@ -271,6 +532,7 @@ class LayoutTransformPlanner : private StmtExprVisitor { LayoutTransformPlanner* self_{nullptr}; Optional cache_; + std::vector bound_vars_; }; struct WriteInfo { @@ -284,11 +546,20 @@ class LayoutTransformPlanner : private StmtExprVisitor { // the store. Not all loop variables in the loopnest need to // contribute, but the first and last must. std::vector dependent_loopnest; + + // Whether the padding could be represented as a tir::if_then_else + // node. This requires that the surrounding loop iterators + // iterate over all pre-transformation buffer axes, that there are + // no data dependencies between loop iterations, and that + bool contains_row_major_traversal{false}; }; + struct LoopEntry {}; + std::vector write_info_; std::vector active_loops_; std::unordered_map> loop_depth_lookup_; + std::unordered_map active_let_bindings_; Optional innermost_block_realize_{NullOpt}; Buffer old_buffer_; @@ -353,6 +624,21 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { return output; } + Stmt VisitStmt_(const ForNode* op) final { + // Some replacements may include the original string, such as + // replacing `loop` with `{loop, post_proc}`. In this case, avoid + // infinite recursion. 
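+    // (A ReplacementPlan maps the root `For` of each rewritten loop
+    // nest to its replacement statement; see FinalizeReplacementPlan
+    // above.)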
+ + For node = GetRef(op); + if (auto plan_ptr = std::get_if(&plan_)) { + auto it = plan_ptr->replacements.find(node); + if (it != plan_ptr->replacements.end()) { + return VisitStmt((*it).second); + } + } + return Parent::VisitStmt_(op); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { BufferLoad buffer_load = Downcast(Parent::VisitExpr_(op)); if (buffer_load->buffer.same_as(old_buffer_)) { From 7f5707cbf4265c903690268c648a79864d125b4d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 16:22:04 -0500 Subject: [PATCH 13/26] Removed development-only unit test --- .../test_tir_schedule_transform_layout.py | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index a47c697a7a57..827f01f40c0e 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -444,35 +444,6 @@ def before(): expected = tvm.tir.schedule.schedule.ScheduleError -@pytest.mark.xfail(reason="Superceded by TestPaddedTransformIfThenElse") -class TestPaddedTransformPostProc(BasePaddingCompare): - """Set the transformation padding in a post-processing block. - - This test is incompatible with TestPaddedTransformIfThenElse, and - is here for initial development purposes. - """ - - pad_value = tvm.testing.parameter(0) - transformed_buffer = tvm.testing.parameter("B") - - def before(A: T.Buffer[14, "int32"]): - B = T.alloc_buffer(14, "int32") - for i in T.serial(14): - with T.block("block"): - B[i] = A[i] - - def expected(A: T.Buffer[14, "int32"]): - B = T.alloc_buffer([4, 4], "int32") - for i in T.serial(14): - with T.block("block"): - B[i // 4, i % 4] = A[i] - - for i, j in T.grid(4, 4): - with T.block("buffer_B_padding"): - T.where(i == 3 and 2 <= j) - B[i, j] = 0 - - class TestPaddedTransformIfThenElse(BasePaddingCompare): """Use if_then_else to represent padding, if possible. @@ -482,10 +453,6 @@ class TestPaddedTransformIfThenElse(BasePaddingCompare): transform the loop iterators to be a row-major traversal of the post-transformation buffer, with padding represented by `T.if_then_else`. - - This test is incompatible with TestPaddedTransformPostProc. This - is the long-term intended method to be supported, with - TestPaddedTransformPostProc present for development purposes. """ pad_value = tvm.testing.parameter(0) From 5a1e63fe5f571d2d69c094874d0dbc135b1c4b7c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Sep 2022 08:44:26 -0500 Subject: [PATCH 14/26] Add default value of NullOpt for pad_value --- include/tvm/tir/schedule/schedule.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index d497faca3a8f..901189e23d21 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -605,7 +605,7 @@ class ScheduleNode : public runtime::Object { */ virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value) = 0; + const Optional& pad_value = NullOpt) = 0; /*! 
* \brief Apply a transformation represented by IndexMap to block From c46304307913324b9342dae6d574f16e846d3140 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Sep 2022 08:45:53 -0500 Subject: [PATCH 15/26] Removed debug code --- tests/python/unittest/test_tir_schedule_transform_layout.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 827f01f40c0e..f11afc8fac74 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -341,7 +341,6 @@ def transform(mod): sch.transform_layout( "block", transformed_buffer, lambda i: [i // 4, i % 4], pad_value=pad_value ) - # sch.transform_block_layout("block", lambda i: [i // 4, i % 4]) return sch.mod return transform From 98a84469e1a69c9299de3fd18e645f25a8cb1b1e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Sep 2022 11:09:27 -0500 Subject: [PATCH 16/26] Update unit tests to use non-opaque blocks Unless specifically testing opaque blocks, all unit tests for the transform layout scheduling primitive now operate on non-opaque blocks. --- .../primitive/layout_transformation.cc | 41 +++++- .../test_tir_schedule_transform_layout.py | 139 +++++++++--------- 2 files changed, 99 insertions(+), 81 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 80cea729aed4..2e351da781bd 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -213,10 +213,22 @@ class LayoutTransformPlanner : private StmtExprVisitor { return std::nullopt; } + Array iter_vars; + Array iter_values; Array indices; - for (const auto& var : inverse->initial_indices) { - indices.push_back(var); + Map loop_indices_to_block_indices; + ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + for (size_t i = 0; i < inverse->initial_indices.size(); i++) { + const auto& loop_var = inverse->initial_indices[i]; + const auto& dim = new_buffer->shape[i]; + Var block_var("v_" + loop_var->name_hint, loop_var->dtype); + IterVar iter_var(Range(0, dim), block_var, kDataPar); + loop_indices_to_block_indices.Set(loop_var, block_var); + indices.push_back(iter_var->var); + iter_vars.push_back(iter_var); + iter_values.push_back(loop_var); } + padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices); PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value.value()); Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr})); @@ -224,7 +236,8 @@ class LayoutTransformPlanner : private StmtExprVisitor { std::stringstream block_name; block_name << "buffer_" << new_buffer->name << "_assumptions"; auto read_region = BufferRegion::FromPoint(new_buffer, indices); - stmt = BlockRealize({}, Bool(true), Block({}, {read_region}, {}, block_name.str(), stmt)); + stmt = BlockRealize(iter_values, Bool(true), + Block(iter_vars, {read_region}, {}, block_name.str(), stmt)); for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { size_t i = (inverse->initial_indices.size() - 1) - rev_i; @@ -283,7 +296,7 @@ class LayoutTransformPlanner : private StmtExprVisitor { break; } } - if (block_index_start >= iter_vars.size() - old_indices.size()) { + if (block_index_start > iter_vars.size() - old_indices.size()) { return no_change; } @@ -429,18 +442,30 @@ 
class LayoutTransformPlanner : private StmtExprVisitor { return std::nullopt; } + Array iter_vars; + Array iter_values; Array indices; - for (const auto& var : inverse->initial_indices) { - indices.push_back(var); + Map loop_indices_to_block_indices; + ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + for (size_t i = 0; i < inverse->initial_indices.size(); i++) { + const auto& loop_var = inverse->initial_indices[i]; + const auto& dim = new_buffer->shape[i]; + Var block_var("v_" + loop_var->name_hint, loop_var->dtype); + IterVar iter_var(Range(0, dim), block_var, kDataPar); + loop_indices_to_block_indices.Set(loop_var, block_var); + indices.push_back(iter_var->var); + iter_vars.push_back(iter_var); + iter_values.push_back(loop_var); } + padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices); Stmt stmt = BufferStore(new_buffer, pad_value.value(), indices); std::stringstream block_name; block_name << "buffer_" << new_buffer->name << "_padding"; auto write_region = BufferRegion::FromPoint(new_buffer, indices); - stmt = - BlockRealize({}, padding_predicate, Block({}, {}, {write_region}, block_name.str(), stmt)); + stmt = BlockRealize(iter_values, padding_predicate, + Block(iter_vars, {}, {write_region}, block_name.str(), stmt)); ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index f11afc8fac74..1812b8c34d57 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -355,13 +355,15 @@ def before(): A = T.alloc_buffer(16, "int32") for i in T.serial(16): with T.block("block"): - A[i] = 0 + vi = T.axis.remap("S", [i]) + A[vi] = 0 def expected(): A = T.alloc_buffer([4, 4], "int32") for i in T.serial(16): with T.block("block"): - A[i // 4, i % 4] = 0 + vi = T.axis.remap("S", [i]) + A[vi // 4, vi % 4] = 0 class TestNoPaddingMultipleUsage(BasePaddingCompare): @@ -378,27 +380,34 @@ def before(): A = T.alloc_buffer(16, "int32") for i in T.serial(16): with T.block("block"): - A[i] = 0 + vi = T.axis.remap("S", [i]) + A[vi] = 0 B = T.alloc_buffer(16, "int32") for i in T.serial(16): with T.block("other"): - B[i] = A[i] + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] def expected(): A = T.alloc_buffer([4, 4], "int32") for i in T.serial(16): with T.block("block"): - A[i // 4, i % 4] = 0 + vi = T.axis.remap("S", [i]) + A[vi // 4, vi % 4] = 0 B = T.alloc_buffer(16, "int32") for i in T.serial(16): with T.block("other"): - B[i] = A[i // 4, i % 4] + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // 4, vi % 4] + +class TestNoPaddingOpaqueBlock(BasePaddingCompare): + """Transformations without padding do not depend on pad_value. -class TestNoPaddingVirtualIndex(BasePaddingCompare): - """Like TestNoPadding, but accessed through block indices.""" + Like TestNoPadding, but buffer access is done in an opaque block. 
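+
+    Here, "opaque" means that the block accesses the buffer through
+    the loop variable directly, without declaring block iterators
+    through T.axis.remap.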
+ """ pad_value = tvm.testing.parameter(None, 42) @@ -406,15 +415,13 @@ def before(): A = T.alloc_buffer(16, "int32") for i in T.serial(16): with T.block("block"): - vi = T.axis.remap("S", [i]) - A[vi] = 0 + A[i] = 0 def expected(): A = T.alloc_buffer([4, 4], "int32") for i in T.serial(16): with T.block("block"): - vi = T.axis.remap("S", [i]) - A[vi // 4, vi % 4] = 0 + A[i // 4, i % 4] = 0 class TestErrorIfPaddingForbidden(BasePaddingCompare): @@ -424,7 +431,8 @@ def before(): A = T.alloc_buffer(14, "int32") for i in T.serial(14): with T.block("block"): - A[i] = 0 + vi = T.axis.remap("S", [i]) + A[vi] = 0 expected = tvm.tir.schedule.schedule.ScheduleError @@ -438,7 +446,8 @@ def before(): A = T.alloc_buffer(14, "int32") for i in T.serial(14): with T.block("block"): - A[i] = 0 + vi = T.axis.remap("S", [i]) + A[vi] = 0 expected = tvm.tir.schedule.schedule.ScheduleError @@ -461,13 +470,15 @@ def before(A: T.Buffer[14, "int32"]): B = T.alloc_buffer(14, "int32") for i in T.serial(14): with T.block("block"): - B[i] = A[i] + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] def expected(A: T.Buffer[14, "int32"]): B = T.alloc_buffer([4, 4], "int32") for i, j in T.grid(4, 4): with T.block("block"): - B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, A[i * 4 + j], dtype="int32") + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, A[vi * 4 + vj], dtype="int32") class TestPaddedTransformWithoutLoop(BasePaddingCompare): @@ -492,8 +503,9 @@ def expected(A: T.Buffer[(4, 4), "int32"]): for i, j in T.grid(4, 4): with T.block("buffer_A_padding"): - T.where(i == 3 and 2 <= j) - A[i, j] = 0 + vi, vj = T.axis.remap("SS", [i, j]) + T.where(vi == 3 and 2 <= vj) + A[vi, vj] = 0 class TestPaddedTransformIfThenElseReduction(BasePaddingCompare): @@ -502,75 +514,50 @@ class TestPaddedTransformIfThenElseReduction(BasePaddingCompare): pad_value = tvm.testing.parameter(0) transformed_buffer = tvm.testing.parameter("B") - def before(A: T.Buffer[(14, 32), "int32"]): - B = T.alloc_buffer(14, "int32") - for i in T.serial(14): - B[i] = 0 - for k in T.serial(32): - with T.block("block"): - B[i] = B[i] + A[i, k] - - def expected(A: T.Buffer[(14, 32), "int32"]): - B = T.alloc_buffer([4, 4], "int32") - for i, j in T.grid(4, 4): - B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 0, dtype="int32") - for k in T.serial(32): - with T.block("block"): - B[i, j] = T.if_then_else( - i == 3 and 2 <= j, 0, B[i, j] + A[i * 4 + j, k], dtype="int32" - ) - - -class TestPaddedTransformIfThenElseReductionBlock(BasePaddingCompare): - """Like TestPaddedTransformIfThenElse, but with a reduction axis""" - - pad_value = tvm.testing.parameter(0) - transformed_buffer = tvm.testing.parameter("B") - def before(A: T.Buffer[(14, 32), "int32"]): B = T.alloc_buffer(14, "int32") for i, k in T.grid(14, 32): with T.block("block"): + vi, vk = T.axis.remap("SR", [i, k]) with T.init(): - B[i] = 0 - B[i] = B[i] + A[i, k] + B[vi] = 0 + B[vi] = B[vi] + A[vi, vk] def expected(A: T.Buffer[(14, 32), "int32"]): B = T.alloc_buffer([4, 4], "int32") for i, j, k in T.grid(4, 4, 32): with T.block("block"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): - B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 0, dtype="int32") - B[i, j] = T.if_then_else( - i == 3 and 2 <= j, 0, B[i, j] + A[i * 4 + j, k], dtype="int32" + B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 0, dtype="int32") + B[vi, vj] = T.if_then_else( + vi == 3 and 2 <= vj, 0, B[vi, vj] + A[vi * 4 + vj, vk], dtype="int32" ) -class 
TestPaddedTransformIfThenElseReductionBlockVirtualAxes(BasePaddingCompare): - """Like TestPaddedTransformIfThenElse, but with a reduction axis""" +class TestPaddedTransformIfThenElseReductionOpaque(BasePaddingCompare): + """Like TestPaddedTransformIfThenElseReduction, but with opaque blocks""" pad_value = tvm.testing.parameter(0) transformed_buffer = tvm.testing.parameter("B") def before(A: T.Buffer[(14, 32), "int32"]): B = T.alloc_buffer(14, "int32") - for i, k in T.grid(14, 32): - with T.block("block"): - vi, vk = T.axis.remap("SR", [i, k]) - with T.init(): - B[vi] = 0 - B[vi] = B[vi] + A[vi, vk] + for i in T.serial(14): + B[i] = 0 + for k in T.serial(32): + with T.block("block"): + B[i] = B[i] + A[i, k] def expected(A: T.Buffer[(14, 32), "int32"]): B = T.alloc_buffer([4, 4], "int32") - for i, j, k in T.grid(4, 4, 32): - with T.block("block"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 0, dtype="int32") - B[vi, vj] = T.if_then_else( - vi == 3 and 2 <= vj, 0, B[vi, vj] + A[vi * 4 + vj, vk], dtype="int32" - ) + for i, j in T.grid(4, 4): + B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 0, dtype="int32") + for k in T.serial(32): + with T.block("block"): + B[i, j] = T.if_then_else( + i == 3 and 2 <= j, 0, B[i, j] + A[i * 4 + j, k], dtype="int32" + ) class TestPaddedTransformPostProcIfRequiredDueToSideEffects(BasePaddingCompare): @@ -588,21 +575,24 @@ def before(A: T.Buffer[14, "int32"]): C = T.alloc_buffer(14, "int32") for i in T.serial(14): with T.block("block"): - B[i] = A[i] - C[i] = 0 + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + C[vi] = 0 def expected(A: T.Buffer[14, "int32"]): B = T.alloc_buffer([4, 4], "int32") C = T.alloc_buffer(14, "int32") for i in T.serial(14): with T.block("block"): - B[i // 4, i % 4] = A[i] - C[i] = 0 + vi = T.axis.remap("S", [i]) + B[vi // 4, vi % 4] = A[vi] + C[vi] = 0 for i, j in T.grid(4, 4): with T.block("block_pad_B"): - T.where(i == 3 and 2 <= j) - B[i, j] = 0 + vi, vj = T.axis.remap("SS", [i, j]) + T.where(vi == 3 and 2 <= vj) + B[vi, vj] = 0 class TestPaddedTransformOfInputCreatesAssumption(BasePaddingCompare): @@ -613,16 +603,19 @@ class TestPaddedTransformOfInputCreatesAssumption(BasePaddingCompare): def before(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]): for i in T.serial(14): with T.block("block"): - B[i] = A[i] + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] def expected(A: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]): for i, j in T.grid(4, 4): with T.block("buffer_A_assumption"): - T.assume(not (i == 3 and 2 <= j) or A[i, j] == 42) + vi, vj = T.axis.remap("SS", [i, j]) + T.assume(not (vi == 3 and 2 <= vj) or A[vi, vj] == 42) for i in T.serial(14): with T.block("block"): - B[i] = A[i // 4, i % 4] + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // 4, vi % 4] if __name__ == "__main__": From 19af1ee4f55ff6243ce4393075cc474d61b815e9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Sep 2022 11:47:34 -0500 Subject: [PATCH 17/26] Resolve linting error --- src/tir/schedule/primitive/layout_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 2e351da781bd..676981e328be 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -61,7 +61,7 @@ class LayoutTransformPlanner : private StmtExprVisitor { } private: - LayoutTransformPlanner(Buffer old_buffer) : 
old_buffer_(old_buffer) {} + explicit LayoutTransformPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {} void VisitStmt_(const ForNode* op) override { BindLoopVar context(this, GetRef(op)); From 8eb775a9b20ecb5cd787f3c9189c7fc18e9944e2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Sep 2022 14:45:35 -0500 Subject: [PATCH 18/26] Removed iter_var usage from T.where clauses --- src/tir/schedule/primitive/layout_transformation.cc | 3 --- tests/python/unittest/test_tir_schedule_transform_layout.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 676981e328be..989ac338f326 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -445,19 +445,16 @@ class LayoutTransformPlanner : private StmtExprVisitor { Array iter_vars; Array iter_values; Array indices; - Map loop_indices_to_block_indices; ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t i = 0; i < inverse->initial_indices.size(); i++) { const auto& loop_var = inverse->initial_indices[i]; const auto& dim = new_buffer->shape[i]; Var block_var("v_" + loop_var->name_hint, loop_var->dtype); IterVar iter_var(Range(0, dim), block_var, kDataPar); - loop_indices_to_block_indices.Set(loop_var, block_var); indices.push_back(iter_var->var); iter_vars.push_back(iter_var); iter_values.push_back(loop_var); } - padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices); Stmt stmt = BufferStore(new_buffer, pad_value.value(), indices); diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 1812b8c34d57..abf167a8ad1a 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -504,7 +504,7 @@ def expected(A: T.Buffer[(4, 4), "int32"]): for i, j in T.grid(4, 4): with T.block("buffer_A_padding"): vi, vj = T.axis.remap("SS", [i, j]) - T.where(vi == 3 and 2 <= vj) + T.where(i == 3 and 2 <= j) A[vi, vj] = 0 @@ -591,7 +591,7 @@ def expected(A: T.Buffer[14, "int32"]): for i, j in T.grid(4, 4): with T.block("block_pad_B"): vi, vj = T.axis.remap("SS", [i, j]) - T.where(vi == 3 and 2 <= vj) + T.where(i == 3 and 2 <= j) B[vi, vj] = 0 From e874020ba681655fdf5e665badbc0b226f26fa40 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Sep 2022 16:36:26 -0500 Subject: [PATCH 19/26] Improve docstring on pad_value Specifically calling attention to how `pad_value` interacts with input buffers, that correctness depends on the calling scope providing the specified `pad_value`. --- include/tvm/tir/schedule/schedule.h | 15 ++++++++++++++- python/tvm/tir/schedule/schedule.py | 13 ++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 901189e23d21..547f9153b953 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -601,7 +601,20 @@ class ScheduleNode : public runtime::Object { * \param buffer_index The index of the buffer in block's read or write region. * \param buffer_index_type The type of the buffer index, kRead or kWrite. * \param index_map The transformation to apply. - * \param pad_value The value to write into padding introduced by the transformation. 
+ *
+ * \param pad_value The value to write into padding introduced by
+ * the transformation.  If the schedule contains a producer block
+ * for the specified buffer, the pad value will be written as
+ * part of the producer block if possible, or after the producer
+ * block otherwise.  If the buffer is instead an input, an
+ * annotation block will be inserted to state that the padding
+ * contains the known value.
+ *
+ * Note: If applied to an input buffer, the calling scope is
+ * responsible for ensuring that the pad_value is present.
+ * Algebraic simplifications, branch elimination, and other
+ * optimizations may assume that this precondition is met, and
+ * may return incorrect results if it is not.
 */
 virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
                              BufferIndexType buffer_index_type, const IndexMap& index_map,
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 35b4a97dda04..d1f685d96c74 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2483,7 +2483,18 @@ def transform_layout(
        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]

            The value to be used for any padding introduced by the
-            transformation.
+            transformation.  If the schedule contains a producer block
+            for the specified buffer, the pad value will be written as
+            part of the producer block if possible, or after the
+            producer block otherwise.  If the buffer is instead an
+            input, an annotation block will be inserted to state that
+            the padding contains the known value.
+
+            Note: If applied to an input buffer, the calling scope is
+            responsible for ensuring that the pad_value is present.
+            Algebraic simplifications, branch elimination, and other
+            optimizations may assume that this precondition is met,
+            and may return incorrect results if it is not.

            If None, the transformation may not introduce padding.

From 6386db593407df6bebaf20f39fc5807da7d2891e Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Wed, 7 Sep 2022 17:28:58 -0500
Subject: [PATCH 20/26] Documentation for TransformLayoutPlanner, rename for
 consistency

The previous name `LayoutTransformPlanner` didn't follow the pattern
of `TransformLayoutWriter`.  Therefore, renaming to
`TransformLayoutPlanner`.
---
 .../primitive/layout_transformation.cc | 127 +++++++++++++-----
 1 file changed, 96 insertions(+), 31 deletions(-)

diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 989ac338f326..107bdffde928 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -26,7 +26,44 @@
 namespace tvm {
 namespace tir {

-class LayoutTransformPlanner : private StmtExprVisitor {
+/*! \brief Planning stage prior to rewriting in TransformLayoutRewriter
+ *
+ * There are four ways that the transformation may be handled.  Each
+ * updates the buffer shape and the indices used to access the buffer
+ * in BufferStore/BufferLoad nodes, but they differ in how they handle
+ * the `pad_value`.  In order of preference, the different strategies
+ * are as follows:
+ *
+ * 1. NoPaddingRequired.  The transformation does not introduce
+ * padding, so only local changes to update the indices of
+ * BufferLoad/BufferStore nodes are required.  No blocks are added,
+ * removed, or replaced.
+ *
+ * 2. ProloguePlan.  The transformation introduces padding, but the
+ * analyzed block has no write stages for the transformed buffer.
+ * This buffer is an input, and the caller is responsible for ensuring
+ * that the padding contains the specified `pad_value`.  The generated
+ * prologue contains `builtin::assume()` calls that will expose this
+ * known value during scheduling/simplification, but will be removed
+ * during lowering.
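+ *
+ * For instance (mirroring TestPaddedTransformOfInputCreatesAssumption
+ * in the unit tests), transforming an input buffer of extent 14 with
+ * `lambda i: [i // 4, i % 4]` and pad_value=42 produces a prologue
+ * equivalent to `T.assume(not (i == 3 and 2 <= j) or A[i, j] == 42)`,
+ * evaluated over the padded shape [4, 4].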
+ *
+ * 3. ReplacementPlan.  The transformation introduces padding, has at
+ * least one write stage for the transformed buffer, and at least one
+ * of those write stages writes to all pre-transformation indices
+ * following a row-major traversal.  Each such write stage is
+ * rewritten to be a row-major traversal of the post-transformation
+ * indices, with a `tir::if_then_else` call to write either the
+ * specified `pad_value` into padding or the computed value into
+ * non-padding.
+ *
+ * 4. EpiloguePlan.  The transformation introduces padding, has at
+ * least one write stage for the transformed buffer, but no write
+ * stage can be rewritten to use `tir::if_then_else`.  The
+ * transformation still requires the `pad_value` to be written into
+ * the padding, so a new block is inserted after the last write stage
+ * to explicitly fill the padding.
+ *
+ */
+class TransformLayoutPlanner : private StmtExprVisitor {
 public:
 // Statement to be inserted prior to the analyzed block
 struct ProloguePlan {
@@ -55,13 +92,13 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map,
                           IndexMap inverse, PrimExpr padding_predicate,
                           Optional pad_value) {
-    LayoutTransformPlanner visitor(old_buffer);
+    TransformLayoutPlanner visitor(old_buffer);
 visitor(block);
 return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value);
 }

 private:
-  explicit LayoutTransformPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {}
+  explicit TransformLayoutPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {}

 void VisitStmt_(const ForNode* op) override {
 BindLoopVar context(this, GetRef(op));
@@ -69,7 +106,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 }

 void VisitStmt_(const LetStmtNode* op) override {
-    BindLetVar context(this, op->var, op->value);
+    BindVariableDefinition context(this, op->var, op->value);
 StmtExprVisitor::VisitStmt_(op);
 }

@@ -121,7 +158,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 for (size_t i = 0; i < loopnest.size(); i++) {
 const For& loop = loopnest[i];
 const PrimExpr& buffer_dim = old_buffer_->shape[i];
-      PrimExpr index = Substitute(op->indices[i], active_let_bindings_);
+      PrimExpr index = Substitute(op->indices[i], active_var_bindings_);
 bool is_loop_over_axis = index.same_as(loop->loop_var) && is_const_int(loop->min, 0) &&
                          ExprDeepEqual()(loop->extent, buffer_dim) &&
                          loop->kind == ForKind::kSerial;
@@ -487,7 +524,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 }

 struct BindLoopVar {
-    BindLoopVar(LayoutTransformPlanner* self, For for_node)
+    BindLoopVar(TransformLayoutPlanner* self, For for_node)
 : self_(self), var_(for_node->loop_var) {
 size_t loop_depth = self_->active_loops_.size();
 self_->loop_depth_lookup_[var_.get()] = {loop_depth, loop_depth};
@@ -502,42 +539,45 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 BindLoopVar(BindLoopVar&&) = delete;
 BindLoopVar& operator=(BindLoopVar&&) = delete;

-    LayoutTransformPlanner* self_{nullptr};
+    TransformLayoutPlanner* self_{nullptr};
 Var var_;
 };

-  struct BindLetVar {
-    BindLetVar() {}
-    BindLetVar(LayoutTransformPlanner* self, Var var, PrimExpr value) : self_(self), 
var_(var) { + struct BindVariableDefinition { + BindVariableDefinition() {} + BindVariableDefinition(TransformLayoutPlanner* self, Var var, PrimExpr value) + : self_(self), var_(var) { if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) { self_->loop_depth_lookup_[var_.get()] = loop_depth.value(); - self_->active_let_bindings_[var_.get()] = Substitute(value, self_->active_let_bindings_); + self_->active_var_bindings_[var_.get()] = Substitute(value, self_->active_var_bindings_); } } - ~BindLetVar() { + ~BindVariableDefinition() { if (self_) { self_->loop_depth_lookup_.erase(var_.get()); - self_->active_let_bindings_.erase(var_.get()); + self_->active_var_bindings_.erase(var_.get()); } } - BindLetVar(const BindLetVar&) = delete; - BindLetVar& operator=(const BindLetVar&) = delete; - BindLetVar(BindLetVar&& other) : BindLetVar() { swap(other); } - BindLetVar& operator=(BindLetVar&& other) { + BindVariableDefinition(const BindVariableDefinition&) = delete; + BindVariableDefinition& operator=(const BindVariableDefinition&) = delete; + BindVariableDefinition(BindVariableDefinition&& other) : BindVariableDefinition() { + swap(other); + } + BindVariableDefinition& operator=(BindVariableDefinition&& other) { swap(other); return *this; } - void swap(BindLetVar& other) { + void swap(BindVariableDefinition& other) { std::swap(self_, other.self_); std::swap(var_, other.var_); } - LayoutTransformPlanner* self_{nullptr}; + TransformLayoutPlanner* self_{nullptr}; Var var_; }; struct BindBlockRealize { - BindBlockRealize(LayoutTransformPlanner* self, BlockRealize block_realize) : self_(self) { + BindBlockRealize(TransformLayoutPlanner* self, BlockRealize block_realize) : self_(self) { ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size()); for (size_t i = 0; i < block_realize->iter_values.size(); i++) { bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var, @@ -552,9 +592,9 @@ class LayoutTransformPlanner : private StmtExprVisitor { BindBlockRealize(BindBlockRealize&&) = delete; BindBlockRealize& operator=(BindBlockRealize&&) = delete; - LayoutTransformPlanner* self_{nullptr}; + TransformLayoutPlanner* self_{nullptr}; Optional cache_; - std::vector bound_vars_; + std::vector bound_vars_; }; struct WriteInfo { @@ -576,14 +616,39 @@ class LayoutTransformPlanner : private StmtExprVisitor { bool contains_row_major_traversal{false}; }; - struct LoopEntry {}; - + /*! \brief Collected information about each BufferStore */ std::vector write_info_; + + /*! \brief The loop iterators surrounding the current node + * + * The outermost loop iterator is `active_loops_.front()`, and the + * innermost loop iterator is `active_loops_.back()`. + * + * Used to fill the `WriteInfo::dependent_loopnest` field. + */ std::vector active_loops_; + + /*! \brief Lookup for the outer/inner loops + * + * Used to fill the `WriteInfo::dependent_loopnest` field. + */ std::unordered_map> loop_depth_lookup_; - std::unordered_map active_let_bindings_; + + /*! \brief The variable mappings that are currently in-scope + * + * Used to determine whether the indices of a BufferStore are a + * row-major traversal, even if they are rebound in let/block + * mappings. + */ + std::unordered_map active_var_bindings_; + + /*! \brief The innermost BlockRealize surrounding the current node + * + * Used to fill the `WriteInfo::innermost_block_realize` field.. + */ Optional innermost_block_realize_{NullOpt}; + /*! 
\brief The buffer to be replaced */ Buffer old_buffer_; }; @@ -602,13 +667,13 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer, const IndexMap& index_map, const IndexMap& inverse, const PrimExpr& padding_predicate, const Optional& pad_value) { - auto plan = LayoutTransformPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, inverse, + auto plan = TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, inverse, padding_predicate, pad_value); arith::Analyzer analyzer; TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer); Block result = Downcast(rewriter(scope_stmt)); - if (auto plan_ptr = std::get_if(&plan)) { + if (auto plan_ptr = std::get_if(&plan)) { auto write_ptr = result.CopyOnWrite(); write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body}); } @@ -618,7 +683,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { private: TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer, const IndexMap& index_map, - const LayoutTransformPlanner::TransformPlan& plan, + const TransformLayoutPlanner::TransformPlan& plan, arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer), old_buffer_(old_buffer), @@ -638,7 +703,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { Stmt VisitStmt(const Stmt& stmt) final { Stmt output = Parent::VisitStmt(stmt); - if (auto plan_ptr = std::get_if(&plan_)) { + if (auto plan_ptr = std::get_if(&plan_)) { if (plan_ptr->insert_after.same_as(stmt)) { return SeqStmt({output, plan_ptr->new_block}); } @@ -652,7 +717,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { // infinite recursion. 
For node = GetRef(op); - if (auto plan_ptr = std::get_if(&plan_)) { + if (auto plan_ptr = std::get_if(&plan_)) { auto it = plan_ptr->replacements.find(node); if (it != plan_ptr->replacements.end()) { return VisitStmt((*it).second); @@ -711,7 +776,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { const Buffer& old_buffer_; const Buffer& new_buffer_; const IndexMap& index_map_; - const LayoutTransformPlanner::TransformPlan& plan_; + const TransformLayoutPlanner::TransformPlan& plan_; Map buffer_data_to_buffer_; Map block_sref_reuse_; }; From 59a0acf699b7b54d5adffe533d1a8b0799f3695d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Sep 2022 10:37:50 -0500 Subject: [PATCH 21/26] Updated C++ API to take IndexMap as input --- include/tvm/tir/schedule/schedule.h | 2 +- python/tvm/tir/function.py | 32 +++++--- python/tvm/tir/schedule/schedule.py | 9 +++ src/tir/schedule/concrete_schedule.cc | 2 +- src/tir/schedule/concrete_schedule.h | 2 +- src/tir/schedule/primitive.h | 2 +- .../primitive/layout_transformation.cc | 81 ++++++++++++++----- src/tir/schedule/schedule.cc | 2 +- src/tir/schedule/traced_schedule.cc | 2 +- src/tir/schedule/traced_schedule.h | 2 +- .../test_tir_schedule_transform_layout.py | 38 +++++++++ 11 files changed, 136 insertions(+), 38 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 547f9153b953..c07671c3ca49 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -618,7 +618,7 @@ class ScheduleNode : public runtime::Object { */ virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value = NullOpt) = 0; + const Optional& pad_value = NullOpt) = 0; /*! * \brief Apply a transformation represented by IndexMap to block diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 12c8053e39cc..4c6cdb08013d 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -389,17 +389,27 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = final_indices = [] axis_separators = [] - for val in mapping: - if isinstance(val, tvm.ir.PrimExpr): - final_indices.append(val) - elif val is IndexMap.AXIS_SEPARATOR: - axis_separators.append(len(final_indices)) - else: - raise TypeError( - "Expected mapping function to return list of " - "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. " - f"Instead received {val} of type {type(val)}." - ) + + try: + iter(mapping) + is_iterable = True + except TypeError: + is_iterable = False + + if is_iterable: + for val in mapping: + if isinstance(val, tvm.ir.PrimExpr): + final_indices.append(val) + elif val is IndexMap.AXIS_SEPARATOR: + axis_separators.append(len(final_indices)) + else: + raise TypeError( + "Expected mapping function to return list of " + "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. " + f"Instead received {val} of type {type(val)}." 
+ ) + else: + final_indices.append(mapping) return IndexMap(initial_indices, final_indices), axis_separators diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d1f685d96c74..9c79795845d5 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2562,6 +2562,15 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> else: axis_separators = [] + if pad_value is None: + pass + elif callable(pad_value): + pad_value = IndexMap.from_func(pad_value, ndim=len(index_map.final_indices)) + elif not isinstance(pad_value, IndexMap): + pad_value = IndexMap.from_func( + lambda *indices: pad_value, ndim=len(index_map.final_indices) + ) + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member self, block, buffer_index, buffer_index_type_enum, index_map, pad_value diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 4c8271a45f9f..a12aac0867ce 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -762,7 +762,7 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value) { + const Optional& pad_value) { TVM_TIR_SCHEDULE_BEGIN(); tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map, pad_value); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index e92d2aa35ac5..8d992f790044 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -143,7 +143,7 @@ class ConcreteScheduleNode : public ScheduleNode { void Unannotate(const BlockRV& block_rv, const String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map, const Optional& pad_value) override; + const IndexMap& index_map, const Optional& pad_value) override; void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 280c57808f7d..96bb14e4b1da 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -478,7 +478,7 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& */ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value); + const Optional& pad_value); /*! 
* \brief Apply a transformation represented by IndexMap to block diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 107bdffde928..516b9b8581a2 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -91,7 +91,9 @@ class TransformLayoutPlanner : private StmtExprVisitor { static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value) { + Optional pad_value) { + ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) + << "Internal error: Should be caught by ScheduleError checks prior to this point"; TransformLayoutPlanner visitor(old_buffer); visitor(block); return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value); @@ -225,7 +227,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { }; TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, - PrimExpr padding_predicate, Optional pad_value) const { + PrimExpr padding_predicate, Optional pad_value) const { if (auto prologue_plan = FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value); prologue_plan.has_value()) { @@ -245,7 +247,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value) const { + Optional pad_value) const { if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } @@ -267,7 +269,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { } padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices); - PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value.value()); + PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0]; + PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value_at_index); Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr})); std::stringstream block_name; @@ -288,7 +291,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value) const { + Optional pad_value) const { if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } @@ -424,8 +427,9 @@ class TransformLayoutPlanner : private StmtExprVisitor { } } + PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_indices)[0]; return BufferStore(new_buffer, - if_then_else(if_then_else_condition, pad_value.value(), op->value), + if_then_else(if_then_else_condition, pad_value_at_index, op->value), new_indices); }; @@ -474,7 +478,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value) const { + Optional pad_value) const { if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } @@ -493,7 +497,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { iter_values.push_back(loop_var); } - Stmt stmt = BufferStore(new_buffer, pad_value.value(), indices); + PrimExpr pad_value_at_index = 
pad_value.value()->MapIndices(indices)[0]; + Stmt stmt = BufferStore(new_buffer, pad_value_at_index, indices); std::stringstream block_name; block_name << "buffer_" << new_buffer->name << "_padding"; @@ -666,7 +671,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { static std::pair> Rewrite( const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer, const IndexMap& index_map, const IndexMap& inverse, const PrimExpr& padding_predicate, - const Optional& pad_value) { + const Optional& pad_value) { auto plan = TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, inverse, padding_predicate, pad_value); @@ -805,14 +810,44 @@ class BufferIsSubregionError : public ScheduleError { Buffer buffer_; }; +class TransformationPaddingIndexMapError : public ScheduleError { + public: + TransformationPaddingIndexMapError(IRModule mod, IndexMap pad_value) + : mod_(mod), pad_value_(pad_value) {} + + String FastErrorString() const final { + std::ostringstream ss; + ss << "ScheduleError: The IndexMap specifying pad_value has " + << pad_value_->final_indices.size() << " outputs, should only have one output"; + return ss.str(); + } + + String DetailRenderTemplate() const final { + std::ostringstream ss; + ss << "ScheduleError: Pad value is specified as " << pad_value_ << " which has " + << pad_value_->final_indices.size() << " outputs, but should only have one output"; + return ss.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + IndexMap pad_value_; +}; + class TransformationPaddingTypeError : public ScheduleError { public: - TransformationPaddingTypeError(IRModule mod, Buffer buffer, PrimExpr pad_value) - : mod_(mod), buffer_(buffer), pad_value_(pad_value) {} + TransformationPaddingTypeError(IRModule mod, Buffer buffer, IndexMap pad_value) + : mod_(mod), buffer_(buffer), pad_value_(pad_value) { + ICHECK_EQ(pad_value_->final_indices.size(), 1); + pad_value_dtype_ = pad_value_->final_indices[0].dtype(); + } String FastErrorString() const final { std::ostringstream ss; - ss << "ScheduleError: Type mismatch " << buffer_->dtype << " vs " << pad_value_->dtype; + ss << "ScheduleError: Type mismatch " << buffer_->dtype << " vs " << pad_value_dtype_; return ss.str(); } @@ -820,7 +855,7 @@ class TransformationPaddingTypeError : public ScheduleError { std::ostringstream ss; ss << "ScheduleError: Buffer " << buffer_->name << " has elements of type " << buffer_->dtype << ", but the transformation fills padding with " << pad_value_ << ", which is of type " - << pad_value_->dtype; + << pad_value_dtype_; return ss.str(); } @@ -830,7 +865,8 @@ class TransformationPaddingTypeError : public ScheduleError { private: IRModule mod_; Buffer buffer_; - PrimExpr pad_value_; + IndexMap pad_value_; + DataType pad_value_dtype_; }; class TransformationIntroducesPaddingError : public ScheduleError { @@ -869,7 +905,7 @@ class TransformationIntroducesPaddingError : public ScheduleError { void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value) { + const Optional& pad_value) { // Step 1: Input handling and error checking const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = @@ -878,8 +914,13 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ if (defining_site_sref.defined() && !is_alloc) { throw 
BufferIsSubregionError(self->mod, old_buffer); } - if (pad_value && pad_value.value()->dtype != old_buffer->dtype) { - throw TransformationPaddingTypeError(self->mod, old_buffer, pad_value.value()); + if (pad_value) { + if (pad_value.value()->final_indices.size() != 1) { + throw TransformationPaddingIndexMapError(self->mod, pad_value.value()); + } + if (pad_value.value()->final_indices[0]->dtype != old_buffer->dtype) { + throw TransformationPaddingTypeError(self->mod, old_buffer, pad_value.value()); + } } StmtSRef scope_sref = defining_site_sref.defined() @@ -1286,7 +1327,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, Integer buffer_index_type, IndexMap index_map, - Optional pad_value) { + Optional pad_value) { return sch->TransformLayout(block_rv, buffer_index.IntValue(), static_cast(buffer_index_type->value), index_map, pad_value); @@ -1294,7 +1335,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, Integer buffer_index_type, IndexMap index_map, - Optional pad_value) { + Optional pad_value) { PythonAPICall py("transform_layout"); py.Input("block", block_rv); @@ -1304,7 +1345,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits py.Input("buffer", os.str()); py.Input("index_map", index_map->ToPythonString()); - py.Input("pad_value", pad_value); + py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() : "None"); return py.Str(); } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index d3bf99d783dd..69b9a2d0e952 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -249,7 +249,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map, - const Optional& pad_value) { + const Optional& pad_value) { return self->TransformLayout(block_rv, buffer_index, static_cast(buffer_index_type), index_map, pad_value); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 340b614dd7f5..bbb1a239190e 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -488,7 +488,7 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value) { + const Optional& pad_value) { ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map, pad_value); static const InstructionKind& kind = InstructionKind::Get("TransformLayout"); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 8ba1120df667..8e99b76d390f 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -103,7 +103,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { void Unannotate(const BlockRV& block_rv, const String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map, const Optional& pad_value) override; + const IndexMap& index_map, const Optional& pad_value) override; void 
TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index abf167a8ad1a..2421cd451b2b 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -618,5 +618,43 @@ def expected(A: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]): B[vi] = A[vi // 4, vi % 4] +class TestPaddedTransformNonConstantValue(tvm.testing.CompareBeforeAfter): + """Allow an expression to specify the pad value. + + Like TestPaddedTransformIfThenElse, but the pad value depends on + the indices. + """ + + @pytest.fixture + def transform(self): + def transform(mod): + sch = tir.Schedule(mod) + sch.transform_layout( + "block", + "B", + lambda i: [i // 4, i % 4], + pad_value=lambda i, j: i + j, + ) + return sch.mod + + return transform + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + def expected(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j in T.grid(4, 4): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.if_then_else( + vi == 3 and 2 <= vj, vi + vj, A[vi * 4 + vj], dtype="int32" + ) + + if __name__ == "__main__": tvm.testing.main() From d53261026863e3ed7459bf9c4c1a01e6f1317c4c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Sep 2022 11:04:31 -0500 Subject: [PATCH 22/26] Update expected trace in multi-level tiling tests --- ...hedule_schedule_rule_multi_level_tiling.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index fe1220c50925..a1cd4a328337 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -600,9 +600,9 @@ def test_cuda_tensor_core_matmul_relu(): b2 = sch.reindex(block=b0, buffer=("write", 0)) b3 = sch.reindex(block=b0, buffer=("read", 0)) b4 = sch.reindex(block=b0, buffer=("read", 1)) -sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, )) -sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, )) -sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, )) +sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ), pad_value=None) +sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ), pad_value=None) +sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ), pad_value=None) sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, )) sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, )) sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, )) @@ -740,9 +740,9 @@ def test_cuda_tensor_core_software_pipeline_matmul_relu(): b2 = sch.reindex(block=b0, buffer=("write", 0)) b3 = sch.reindex(block=b0, buffer=("read", 0)) b4 = sch.reindex(block=b0, buffer=("read", 1)) -sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, )) 
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ), pad_value=None)
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ), pad_value=None)
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ), pad_value=None)
 sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
@@ -863,9 +863,9 @@ def test_cuda_tensor_core_matmul_relu_global():
 b1 = sch.reindex(block=b0, buffer=("write", 0))
 b2 = sch.reindex(block=b0, buffer=("read", 0))
 b3 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ), pad_value=None)
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ), pad_value=None)
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ), pad_value=None)
 sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
@@ -963,9 +963,9 @@ def test_cuda_tensor_core_matmul_relu_global():
 b1 = sch.reindex(block=b0, buffer=("write", 0))
 b2 = sch.reindex(block=b0, buffer=("read", 0))
 b3 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (j, k, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ), pad_value=None)
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (j, k, ), pad_value=None)
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ), pad_value=None)
 sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
@@ -1099,9 +1099,9 @@ def test_cuda_tensor_core_conv2d():
 b1 = sch.reindex(block=b0, buffer=("write", 0))
 b2 = sch.reindex(block=b0, buffer=("read", 0))
 b3 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda h, w, rh, rw, rc: (((h*16) + w), (((rh*96) + (rw*32)) + rc), ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda co, rh, rw, rc: ((((rh*96) + (rw*32)) + rc), co, ))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda h, w, co: (((h*16) + w), co, ))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda h, w, rh, rw, rc: (((h*16) + w), (((rh*96) + (rw*32)) + rc), ), pad_value=None)
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda co, rh, rw, rc: ((((rh*96) + (rw*32)) + rc), co, ), pad_value=None)
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda h, w, co: (((h*16) + w), co, ), pad_value=None)
 sch.transform_block_layout(block=b1, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
 sch.transform_block_layout(block=b2, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
 sch.transform_block_layout(block=b3, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))

From efb25ac01673213941a2cc48b358cb22760733e6 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 13 Sep 2022 10:44:29 -0500
Subject: [PATCH 23/26] Update shared_32x16_to_ldmatrix_32x16_layout to be injective

Previous version mapped the 512 input indices in a `(32,16)` array to
only 128 output indices.  This wasn't caught before, because the
bijectivity assertion was only triggered for TE schedules.

---
 python/tvm/tir/tensor_intrin/cuda.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 64d7c24840ae..a309b091285b 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)
     return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
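As a quick standalone sanity check on the fix above (not part of the patch
itself), the corrected mapping can be verified to be bijective over the
`(32, 16)` domain; the previous expression used `i % 4` for the thread id,
discarding `(i % 16) // 4` and thereby collapsing the 512 inputs onto only
128 outputs:

    # Verify that all 512 (i, j) pairs map to 512 distinct outputs.
    def fixed_layout(i, j):
        thread_id = (i % 16) // 4 + 4 * (j % 8)
        return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4

    outputs = {fixed_layout(i, j) for i in range(32) for j in range(16)}
    assert len(outputs) == 512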
From 13b8cef32eaa803401ed29cf9854dedccedbf763 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 15 Sep 2022 12:21:18 -0500
Subject: [PATCH 24/26] Updated IndexMap docstrings for single PrimExpr returns

---
 python/tvm/tir/function.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 4c6cdb08013d..36812f9fd722 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -294,8 +294,9 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
         The function to map from source indices to target indices.
 
         The function should accept `tir.Var` parameters and return
-        a list.  Each element of the returned list should be a
-        `tir.PrimExpr`.
+        either a `tir.PrimExpr` or a list of `tir.PrimExpr`.
+        Returning a `tir.PrimExpr` is equivalent to returning a
+        list of length 1 containing that `tir.PrimExpr`.
 
     ndim: Optional[int]
 
@@ -329,9 +330,12 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
     mapping_function : Callable
 
        The function to map from source indices to target indices.
-       The function should accept tir.Var parameters and return a
-       list.  Each element of the returned list should be either a
-       `tir.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`.
+       The function should accept tir.Var parameters and return
+       either a `tir.PrimExpr` or a list.  Each element of the
+       returned list should be either a `tir.PrimExpr` or the
+       object `IndexMap.AXIS_SEPARATOR`.  Returning a
+       `tir.PrimExpr` is equivalent to returning a list of length
+       1 containing that `tir.PrimExpr`.
 
     ndim: Optional[int]

From 19a78e8cfd33a792c4bb03260b31d1931401a752 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 16 Sep 2022 08:43:29 -0500
Subject: [PATCH 25/26] Updated docstring for valid pad values, validate

---
 python/tvm/tir/schedule/schedule.py           |  5 ++
 .../primitive/layout_transformation.cc        | 57 +++++++++++++
 .../test_tir_schedule_transform_layout.py     | 83 +++++++++++++++++++
 3 files changed, 145 insertions(+)

diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 9c79795845d5..2acbd6399daa 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2490,6 +2490,11 @@ def transform_layout(
             insert an annotation block to state that the padding
             contains the known value.
 
+            The pad value may not contain instances of BufferLoad,
+            except where it loads a value from the buffer being
+            transformed (e.g. to create a circular buffer with
+            padding that consists of repeated elements).
+
             Note: If applied to an input buffer, the calling scope is
             responsible for ensuring that the pad_value is present.
             Algebraic simplifications, branch elimination, and other
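For an input buffer, the paragraph above means the pad value becomes a stated
precondition rather than data that is physically written. A sketch of the kind
of annotation block this produces, assuming a 14-element buffer laid out as
`(4, 4)` with `pad_value=0` (names and shapes are illustrative; the `T.assume`
pattern follows the tests added at the end of this patch):

    from tvm.script import tir as T

    @T.prim_func
    def annotated(A: T.Buffer[(4, 4), "int32"]):
        # The two padded elements (vi == 3 and vj >= 2) are asserted rather
        # than written; simplification passes may rely on this assumption.
        for i, j in T.grid(4, 4):
            with T.block("buffer_A_assumption"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.assume(not (vi == 3 and 2 <= vj) or A[vi, vj] == 0)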
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 516b9b8581a2..0b16db357c70 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -869,6 +869,61 @@ class TransformationPaddingTypeError : public ScheduleError {
   DataType pad_value_dtype_;
 };
 
+class TransformationPaddingExpressionError : public ScheduleError {
+ public:
+  static void Check(IRModule mod, Buffer buffer, IndexMap pad_value) {
+    Visitor visitor(buffer);
+    ICHECK_EQ(pad_value->final_indices.size(), 1)
+        << "Internal error: Should be caught by ScheduleError checks prior to this point";
+    visitor(pad_value->final_indices[0]);
+    if (visitor.illegal_load) {
+      throw TransformationPaddingExpressionError(mod, buffer, pad_value,
+                                                 visitor.illegal_load.value());
+    }
+  }
+
+ private:
+  struct Visitor : ExprVisitor {
+    Visitor(const Buffer& buffer) : buffer_(buffer) {}
+
+    void VisitExpr_(const BufferLoadNode* op) final {
+      if (!op->buffer.same_as(buffer_)) {
+        illegal_load = GetRef<BufferLoad>(op);
+      }
+      ExprVisitor::VisitExpr_(op);
+    }
+
+    const Buffer& buffer_;
+    Optional<BufferLoad> illegal_load;
+  };
+
+  TransformationPaddingExpressionError(IRModule mod, Buffer buffer, IndexMap pad_value,
+                                       BufferLoad illegal_load)
+      : mod_(mod), buffer_(buffer), pad_value_(pad_value), illegal_load_(illegal_load) {}
+
+  String FastErrorString() const final {
+    std::ostringstream ss;
+    ss << "ScheduleError: Pad value may not contain a load from " << illegal_load_->buffer->name;
+    return ss.str();
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream ss;
+    ss << "ScheduleError: Pad value may only contain BufferLoad from the transformed buffer "
+       << buffer_->name << ", but pad_value " << pad_value_ << " contains expression "
+       << illegal_load_;
+    return ss.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+  IRModule mod_;
+  Buffer buffer_;
+  IndexMap pad_value_;
+  BufferLoad illegal_load_;
+};
+
 class TransformationIntroducesPaddingError : public ScheduleError {
  public:
   TransformationIntroducesPaddingError(IRModule mod, Buffer buffer, IndexMap index_map,
@@ -921,6 +976,8 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
     if (pad_value.value()->final_indices[0]->dtype != old_buffer->dtype) {
       throw TransformationPaddingTypeError(self->mod, old_buffer, pad_value.value());
     }
+
+    TransformationPaddingExpressionError::Check(self->mod, old_buffer, pad_value.value());
   }
 
   StmtSRef scope_sref = defining_site_sref.defined()
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 2421cd451b2b..8ed350cc4c46 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -656,5 +656,88 @@ def expected(A: T.Buffer[14, "int32"]):
         )
 
 
+@pytest.mark.xfail(reason="Not yet implemented")
+class TestPaddedTransformRepeatedBufferElement(tvm.testing.CompareBeforeAfter):
+    """Allow an expression to specify the pad value.
+
+    Like TestPaddedTransformOfInputCreatesAssumption, but the pad
+    value depends on another portion of the buffer.  In this case,
+    the padding at the end of A contains repeated elements from the
+    beginning of A.
+    """
+
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            sch = tir.Schedule(mod)
+
+            A = sch.get(sch.get_block("block")).reads[0].buffer
+            sch.transform_layout(
+                "block",
+                "A",
+                lambda i: [i // 4, i % 4],
+                pad_value=lambda i, j: A[(4 * i + j) % 14],
+            )
+            return sch.mod
+
+        return transform
+
+    def before(A: T.Buffer[14, "int32"]):
+        B = T.alloc_buffer(14, "int32")
+        for i in T.serial(14):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+    def expected(A: T.Buffer[(4, 4), "int32"]):
+        for i, j in T.grid(4, 4):
+            with T.block("buffer_A_assumption"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.assume(
+                    not (vi == 3 and 2 <= vj)
+                    or A[vi, vj] == A[((4 * vi + vj) % 14) // 4, ((4 * vi + vj) % 14) % 4]
+                )
+
+        B = T.alloc_buffer(14, "int32")
+        for i in T.serial(14):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi // 4, vi % 4]
+
+
+class TestPadValueMayNotReferenceOtherBuffer(tvm.testing.CompareBeforeAfter):
+    """The pad value expression may only load from the transformed buffer.
+
+    Like TestPaddedTransformRepeatedBufferElement, but the pad value
+    depends on a different buffer, which is not allowed.
+    """
+
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            sch = tir.Schedule(mod)
+
+            A = sch.get(sch.get_block("block")).reads[0].buffer
+            other = tir.decl_buffer(1, A.dtype, name="other")
+            sch.transform_layout(
+                "block",
+                "A",
+                lambda i: [i // 4, i % 4],
+                pad_value=lambda i, j: other[0],
+            )
+            return sch.mod
+
+        return transform
+
+    def before(A: T.Buffer[14, "int32"]):
+        B = T.alloc_buffer(14, "int32")
+        for i in T.serial(14):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+    expected = tvm.tir.schedule.schedule.ScheduleError
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 6a4f4ccb7556cd27b72bcc15f07f1844a5c6d273 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 16 Sep 2022 13:17:54 -0500
Subject: [PATCH 26/26] Fix lint error

---
 src/tir/schedule/primitive/layout_transformation.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 6cfc1c9b599a..025723e1793d 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -884,7 +884,7 @@ class TransformationPaddingExpressionError : public ScheduleError {
 
  private:
   struct Visitor : ExprVisitor {
-    Visitor(const Buffer& buffer) : buffer_(buffer) {}
+    explicit Visitor(const Buffer& buffer) : buffer_(buffer) {}
 
     void VisitExpr_(const BufferLoadNode* op) final {
       if (!op->buffer.same_as(buffer_)) {
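Taken together, the series makes padded layout transformations available
directly from the schedule API. A minimal end-to-end sketch of the resulting
usage (a hedged illustration only: the buffer shapes, block name, and scalar
pad value follow the conventions of the tests above, not a prescribed recipe):

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def func(A: T.Buffer[14, "int32"]):
        B = T.alloc_buffer(14, "int32")
        for i in T.serial(14):
            with T.block("block"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi]

    sch = tir.Schedule(func)
    # Lay B out as (4, 4); the two padded elements are filled with zero, so
    # the store becomes a T.if_then_else between A[...] and the pad value.
    sch.transform_layout("block", "B", lambda i: [i // 4, i % 4], pad_value=0)
    print(sch.mod.script())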