From dd09bbb31f068f29db55b139122d39becb28f5b3 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 3 Jun 2021 19:50:05 -0700 Subject: [PATCH] [TensorIR][M2a] ComputeInline,ReverseComputeInline (#8170) This PR is part of the TensorIR upstreaming effort (#7527), which adds the first 2 schedule primitives: - compute-Inline - reverse-compute-inline Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Cody Yu --- include/tvm/tir/schedule/schedule.h | 29 + python/tvm/tir/schedule/schedule.py | 115 +++ src/support/array.h | 72 ++ src/tir/schedule/analysis.h | 59 +- src/tir/schedule/analysis/analysis.cc | 140 ++++ src/tir/schedule/concrete_schedule.cc | 29 +- src/tir/schedule/concrete_schedule.h | 8 + src/tir/schedule/error.cc | 4 +- src/tir/schedule/error.h | 4 +- src/tir/schedule/primitive.h | 67 ++ src/tir/schedule/primitive/compute_inline.cc | 677 ++++++++++++++++++ src/tir/schedule/schedule.cc | 10 + src/tir/schedule/utils.h | 45 +- .../test_tir_schedule_compute_inline.py | 373 ++++++++++ 14 files changed, 1599 insertions(+), 33 deletions(-) create mode 100644 src/support/array.h create mode 100644 src/tir/schedule/primitive.h create mode 100644 src/tir/schedule/primitive/compute_inline.cc create mode 100644 tests/python/unittest/test_tir_schedule_compute_inline.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 2aee2cb136b3..9a09d0ad211f 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -195,6 +195,35 @@ class ScheduleNode : public runtime::Object { * \return A list of loops above the given block in its scope, from outer to inner */ virtual Array GetLoops(const BlockRV& block_rv) = 0; + /******** Schedule: loops manipulation ********/ + /******** Schedule: compute location ********/ + /*! 
+ * \brief Inline a block into its consumer(s). It requires: + * 1) The block is a complete non-root block, which only produces one buffer + * 2) The block must not be the only leaf in the scope. + * 3) The body of the block must be a BufferStore statement in the form of, + * A[i, j, k, ...] = ... + * where the indices of the LHS are all distinct atomic variables, + * and no variables other than those indexing variables are allowed in the statement. + * \param block The block to be inlined to its consumer(s) + */ + virtual void ComputeInline(const BlockRV& block) = 0; + /*! + * \brief Inline a block into its only producer. It requires: + * 1) The block is a complete non-root block, which only produces and consumes one buffer + * 2) The block must not be the only leaf in the scope. + * 3) The only producer of the block is a read-after-write producer and a complete non-root block + * 4) The body of the block must be a BufferStore statement in the form of, + * B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) + * where the indices of each `BufferLoad` on the RHS are all distinct atomic variables, + * and no variables other than those indexing variables are allowed in the statement. + * \param block The block to be inlined to its producer + */ + virtual void ReverseComputeInline(const BlockRV& block) = 0; + /******** Schedule: loop binding/annotation ********/ + /******** Schedule: cache read/write ********/ + /******** Schedule: reduction ********/ + /******** Schedule: blockize & tensorize ********/ }; /*! 
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d420f7d32db0..9452f5ab72ee 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -256,6 +256,121 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: """ return _ffi_api_schedule.ScheduleGetLoops(self, block) # pylint: disable=no-member + ########## Schedule: loops manipulation ########## + ########## Schedule: compute location ########## + def compute_inline(self, block: BlockRV) -> None: + """Inline a block into its consumer(s). It requires: + 1) The block is a complete non-root block, which only produces one buffer + 2) The block must not be the only leaf in the scope. + 3) The body of the block must be a BufferStore statement in the form of, + A[i, j, k, ...] = ... + where the indices of the LHS are all distinct atomic variables, + and no variables other than those indexing variables are allowed in the statement. + + Parameters + ---------- + block : BlockRV + The block to be inlined to its consumer(s) + + Examples + -------- + + Before compute-inline, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_inline(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do compute-inline: + + .. code-block:: python + + sch = tir.Schedule(before_inline, debug_mode=True) + sch.compute_inline(sch.get_block("B")) + print(tvm.script.asscript(sch.mod["main"])) + + After applying compute-inline, the IR becomes: + + .. 
code-block:: python + + @tvm.script.tir + def after_inline(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + """ + _ffi_api_schedule.ScheduleComputeInline(self, block) # pylint: disable=no-member + + def reverse_compute_inline(self, block: BlockRV) -> None: + """Inline a block into its only producer. It requires: + 1) The block is a complete non-root block, which only produces and consumes one buffer + 2) The block must not be the only leaf in the scope. + 3) The only producer of the block is a read-after-write producer + and a complete non-root block + 4) The body of the block must be a BufferStore statement in the form of, + B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) + where the indices of each `BufferLoad` on the RHS are all distinct atomic variables, + and no variables other than those indexing variables are allowed in the statement. + + Parameters + ---------- + block : BlockRV + The block to be inlined to its producer + + Examples + -------- + + Before reverse-compute-inline, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_inline(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do reverse-compute-inline: + + .. code-block:: python + + sch = tir.Schedule(before_inline, debug_mode=True) + sch.reverse_compute_inline(sch.get_block("C")) + print(tvm.script.asscript(sch.mod["main"])) + + After applying reverse-compute-inline, the IR becomes: + + .. 
code-block:: python + + @tvm.script.tir + def after_inline(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + """ + _ffi_api_schedule.ScheduleReverseComputeInline(self, block) # pylint: disable=no-member + + ########## Schedule: loop binding/annotation ########## + ########## Schedule: cache read/write ########## + ########## Schedule: reduction ########## + ########## Schedule: blockize & tensorize ########## + @_register_object("tir.ConcreteSchedule") class ConcreteSchedule(Schedule): diff --git a/src/support/array.h b/src/support/array.h new file mode 100644 index 000000000000..12d76d18db21 --- /dev/null +++ b/src/support/array.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SUPPORT_ARRAY_H_ +#define TVM_SUPPORT_ARRAY_H_ +#include + +#include + +namespace tvm { +namespace support { + +/*! 
+ * \brief Checks if two arrays contain the same objects + * \tparam T The type of objects in the array + * \param a The first array + * \param b The second array + * \return A boolean indicating if they are the same + */ +template +inline bool ArrayWithSameContent(const Array& a, const Array& b) { + if (a.size() != b.size()) { + return false; + } + int n = a.size(); + for (int i = 0; i < n; ++i) { + if (!a[i].same_as(b[i])) { + return false; + } + } + return true; +} + +/*! + * \brief Checks if two arrays contain the same objects + * \tparam T The type of objects in the array + * \param a The first array + * \param b The second array + * \return A boolean indicating if they are the same + */ +template +inline bool ArrayWithSameContent(const std::vector& a, const std::vector& b) { + if (a.size() != b.size()) { + return false; + } + int n = a.size(); + for (int i = 0; i < n; ++i) { + if (a[i] != b[i]) { + return false; + } + } + return true; +} + +} // namespace support +} // namespace tvm +#endif // TVM_SUPPORT_ARRAY_H_ diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 8d52a621b900..dd7fee37e2d1 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -26,13 +26,13 @@ namespace tir { /******** Verification ********/ /*! - * \brief Verify the sref tree state is consistent with the IR + * \brief Verifies the sref tree state is consistent with the IR * \param self The schedule state containing the sref to be verified * \throw An exception will be thrown if the sref tree is not valid */ void VerifySRefTree(const ScheduleState& self); /*! - * \brief Verify the cached flags in the schedule state, including: + * \brief Verifies the cached flags in the schedule state, including: * - affine_binding * - region_cover * - stage_pipeline @@ -41,10 +41,53 @@ void VerifySRefTree(const ScheduleState& self); */ void VerifyCachedFlags(const ScheduleState& self); -/******** Binding ********/ +/******** Scope ********/ +/*! 
+ * \brief Gets the sref to the scope root block, exclusive + * \param sref The block or loop sref to be retrieved + * \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR + */ +Optional GetScopeRoot(const StmtSRef& sref); + +/*! + * \brief Checks if the scope the specified sref is in is a stage-pipeline and returns it + * \param prim The name of the schedule primitive + * \param self The schedule state + * \param sref The sref whose scope is to be checked + * \throw ScheduleError if the sref is the root of the AST (so it has no scope root), or its + * scope root is not a stage pipeline + * \return The block sref to the scope root + */ +StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref); + +/*! + * \brief Checks whether the block is a complete block under the scope + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root The sref to the root block of the scope that `block_sref` is in + * \return A boolean indicating if the block is a complete block + * \note Definition of a complete block: + * 1) All block vars are data parallel + * 2) Dominant: the block is the only writer of its output, + * dominating the reader of its output buffers + * 3) No overlap between the buffers the block reads and writes + */ +bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root); + +/*! + * \brief Checks if the block is a complete block + * \param self The schedule state + * \param block_sref The sref to the block whose completeness is to be checked + * \param scope_root_sref The scope root of the block + * \throw ScheduleError If the block is not a complete block + */ +void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); +/******** Binding ********/ /*! - * \brief Verify if the block binding in a specific BlockRealize is an affine binding. 
+ * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. * The binding can be represented as an injective affine map from the loop iterators. * \param realize The BlockRealize to be analyzed * \param loop_var_ranges The ranges of the loop variables @@ -55,7 +98,7 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va arith::Analyzer* analyzer); /*! - * \brief Extract the ranges of loop variables in a path of the sref tree + * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path * \param high_exclusive The highest node in the path, defaults to the scope root if not specified * \param extra_relax_scope If the scope is not global, the method will look beyond the limit and @@ -78,7 +121,7 @@ Map GetBindings(const BlockRealize& realize); /******** Block-loop relation ********/ /*! - * \brief Retrieve blocks in a specific function with its name + * \brief Retrieves blocks in a specific function with its name * \param self The schedule state * \param name The name of the blocks to be retrieved * \param func_name The name of the function @@ -86,14 +129,14 @@ Map GetBindings(const BlockRealize& realize); */ Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); /*! - * \brief Get the parent loops of the block in its scope, from outer to inner + * \brief Gets the parent loops of the block in its scope, from outer to inner * \param self The schedule state * \param block_sref The query block * \return A list of loops above the given block in its scope, from outer to inner */ Array GetLoops(const StmtSRef& block_sref); /*! 
- * \brief Get the leaf blocks of a scope where a specific block/loop is in + * \brief Gets the leaf blocks of a scope where a specific block/loop is in * \param self The schedule state * \param parent_sref The StmtSRef that points to the parent block/loop * \return A list of leaf blocks diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e4b767bc40ad..d58dece3c644 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -21,6 +21,146 @@ namespace tvm { namespace tir { +/******** Scope ********/ + +Optional GetScopeRoot(const StmtSRef& sref) { + for (const StmtSRefNode* p = sref->parent; p != nullptr; p = p->parent) { + if (p->stmt->IsInstance()) { + return GetRef(p); + } + } + return NullOpt; +} + +StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref) { + class RootBlockError : public ScheduleError { + public: + explicit RootBlockError(IRModule mod) : mod_(mod) {} + IRModule mod() const final { return mod_; } + String FastErrorString() const final { + return "ScheduleError: The primitive does not operate on the root block"; + } + String DetailRenderTemplate() const final { + return "The primitive does not operate on the root block"; + } + Array LocationsOfInterest() const final { return {}; } + IRModule mod_; + }; + + class NotStagePipelineError : public ScheduleError { + public: + explicit NotStagePipelineError(IRModule mod, Block block) : mod_(mod), block_(block) {} + IRModule mod() const final { return mod_; } + String FastErrorString() const final { + return "ScheduleError: The scope root is not a stage pipeline"; + } + String DetailRenderTemplate() const final { + return R"(The scope {0} is not a stage pipeline. 
+Definition of a scope that is a stage pipeline: +- The region cover property holds for every of its child blocks +- No write-after-read dependency or opaque dependency, +- only read-after-write and write-after-write are allowed +- All the statements in the scope are schedulable statements, i.e. Block and For +)"; + } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + }; + + StmtSRef scope_root_sref{nullptr}; + if (Optional opt_scope_root_sref = GetScopeRoot(sref)) { + scope_root_sref = opt_scope_root_sref.value(); + } else { + throw RootBlockError(self->mod); + } + bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; + if (stage_pipeline == false) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); + throw NotStagePipelineError(self->mod, GetRef(block)); + } + return scope_root_sref; +} + +/*! + * \brief Check the dominant property of a block: + * the block is the only writer of its output, dominating the reader of its output buffers + * \param self The schedule state + * \param block_sref The block whose dominant property is to be checked + * \return A boolean indicating if the block is a dominant block + */ +bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) { + // Check whether the input block is the only writer of its outputs + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = + self->buffer_writers; + for (const BufferRegion& write_region : block->writes) { + ICHECK(buffer_writers.count(write_region->buffer)) + << "InternalError: buffer \"" << write_region->buffer->name + << "\" does not exist in the current scope, when querying block:\n" + << GetRef(block); + if (buffer_writers.at(write_region->buffer).size() != 1) { + return false; + } + } + return true; +} + +bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& 
scope_root) { + BlockScope scope = self->GetBlockScope(scope_root); + // Cond 1. All block vars are data parallel + const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != kDataPar) { + return false; + } + } + // Cond 2. Dominant: the block is the only writer of its output, + // dominating the reader of its output buffers + if (!IsDominantBlock(scope, block_sref)) { + return false; + } + // Cond 3. No overlap between the buffers the block reads and writes + std::unordered_set written_buffers; + written_buffers.reserve(block->writes.size()); + for (const BufferRegion& write : block->writes) { + written_buffers.insert(write->buffer.get()); + } + for (const BufferRegion& read : block->reads) { + if (written_buffers.count(read->buffer.get())) { + return false; + } + } + return true; +} + +void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + class IncompleteBlockError : public ScheduleError { + public: + explicit IncompleteBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} + String FastErrorString() const final { return "ScheduleError: Incomplete block"; } + String DetailRenderTemplate() const final { + return R"(The block {0} is not a complete block. 
+Definition of a complete block: +1) All block vars are data parallel +2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +3) No overlap between the buffers the block reads and writes)"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + }; + + bool result = IsCompleteBlock(self, block_sref, scope_root_sref); + if (result == false) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw IncompleteBlockError(self->mod, GetRef(block)); + } +} + /******** Binding ********/ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 60ab7920c37b..0563d39427b1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -195,11 +195,11 @@ Schedule ConcreteScheduleNode::Copy() const { * \param level An ScheduleErrorRenderLevel enum, level of error rendering * \sa ScheduleErrorRenderLevel */ -#define TVM_TIR_SCHEDULE_END(level) \ +#define TVM_TIR_SCHEDULE_END(primitive, level) \ } \ catch (const ScheduleError& error) { \ if ((level) == ScheduleErrorRenderLevel::kDetail) { \ - throw tvm::runtime::Error(error.RenderReport()); \ + throw tvm::runtime::Error(error.RenderReport(primitive)); \ } else if ((level) == ScheduleErrorRenderLevel::kFast) { \ throw tvm::runtime::Error(error.FastErrorString()); \ } else if ((level) == ScheduleErrorRenderLevel::kNone) { \ @@ -221,7 +221,6 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_na } } - String primitive() const final { return "get-block"; } IRModule mod() const final { return mod_; } Array LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; } @@ -249,7 +248,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_na if 
(blocks.size() != 1) { TVM_TIR_SCHEDULE_BEGIN(); throw NotSingleResult(name, this->state_->mod, blocks); - TVM_TIR_SCHEDULE_END(this->error_render_level_); + TVM_TIR_SCHEDULE_END("get-block", this->error_render_level_); } return CreateRV(blocks[0]); } @@ -258,6 +257,28 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } +/******** Schedule: loops manipulation ********/ +/******** Schedule: compute location ********/ + +void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ComputeInline(state_, this->GetSRef(block_rv)); + TVM_TIR_SCHEDULE_END("compute-inline", this->error_render_level_); + this->state_->DebugVerify(); +} + +void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); + TVM_TIR_SCHEDULE_END("reverse-compute-inline", this->error_render_level_); + this->state_->DebugVerify(); +} + +/******** Schedule: loop binding/annotation ********/ +/******** Schedule: cache read/write ********/ +/******** Schedule: reduction ********/ +/******** Schedule: blockize & tensorize ********/ + /******** FFI ********/ TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index ab467cec9ee3..8945fb9ee0dc 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -77,6 +77,14 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Block/Loop relation ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; + /******** Schedule: loops manipulation ********/ + /******** Schedule: compute location ********/ + void ComputeInline(const BlockRV& block) override; + void ReverseComputeInline(const BlockRV& block) override; + /******** Schedule: loop 
binding/annotation ********/ + /******** Schedule: cache read/write ********/ + /******** Schedule: reduction ********/ + /******** Schedule: blockize & tensorize ********/ /******** Utility functions ********/ protected: diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index f64d4aeb984b..d8dcf57b91e4 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -21,10 +21,10 @@ namespace tvm { namespace tir { -String ScheduleError::RenderReport() const { +String ScheduleError::RenderReport(const String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; - os << "ScheduleError: An error occurred in the schedule primitive '" << this->primitive() + os << "ScheduleError: An error occurred in the schedule primitive '" << primitive << "'.\n\nThe IR is:\n" << AsTVMScript(mod); Array locs = LocationsOfInterest(); diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h index 1031672f0010..46447cfbde49 100644 --- a/src/tir/schedule/error.h +++ b/src/tir/schedule/error.h @@ -29,8 +29,6 @@ class ScheduleError : public tvm::runtime::Error { public: /*! \brief Base constructor */ ScheduleError() : tvm::runtime::Error("") {} - /*! \brief The error occurred in this scheduling primitive */ - virtual String primitive() const = 0; /*! \brief The error occurred in this IRModule */ virtual IRModule mod() const = 0; /*! \brief The locations of interest that we want to point out */ @@ -51,7 +49,7 @@ class ScheduleError : public tvm::runtime::Error { */ virtual String FastErrorString() const = 0; /*! 
\brief Render the ScheduleError with the template provided by `DetailRenderTemplate` */ - String RenderReport() const; + String RenderReport(const String& primitive) const; }; } // namespace tir diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h new file mode 100644 index 000000000000..ab8299e38169 --- /dev/null +++ b/src/tir/schedule/primitive.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_ +#define TVM_TIR_SCHEDULE_PRIMITIVE_H_ + +#include + +namespace tvm { +namespace tir { + +/******** Schedule: loops manipulation ********/ + +/******** Schedule: compute location ********/ +/*! + * \brief Inline a block into its consumer(s). It requires: + * 1) The block is a complete non-root block, which only produces one buffer + * 2) The block must not be the only leaf in the scope. + * 3) The body of the block must be a BufferStore statement in the form of, + * A[i, j, k, ...] = ... + * where the indices of the LHS are all distinct atomic variables, + * and no variables other than those indexing variables are allowed in the statement. 
+ * \param self The state of the schedule + * \param block_sref The sref to the block to be inlined to its consumer(s) + */ +TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref); +/*! + * \brief Inline a block into its only producer. It requires: + * 1) The block is a complete non-root block, which only produces and consumes one buffer + * 2) The block must not be the only leaf in the scope. + * 3) The only producer of the block is a read-after-write producer and a complete non-root block + * 4) The body of the block must be a BufferStore statement in the form of, + * B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) + * where the indices of each `BufferLoad` on the RHS are all distinct atomic variables, + * and no variables other than those indexing variables are allowed in the statement. + * \param self The state of the schedule + * \param block_sref The sref to the block to be inlined to its producer + */ +TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref); + +/******** Schedule: loop binding/annotation ********/ + +/******** Schedule: cache read/write ********/ + +/******** Schedule: reduction ********/ + +/******** Schedule: blockize & tensorize ********/ + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_PRIMITIVE_H_ diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc new file mode 100644 index 000000000000..6bd6388fafff --- /dev/null +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -0,0 +1,677 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of + 'A[i, j, k, ...] = f(i, j, k, ...)', +where the indices on the left are distinct atomic variables, +and there should be no variables other than the index variables)"; + +static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of + `B[...] = g(i, j, k, A[i, j, k, ...] ...)`, +where A is the only buffer the block consumes, whose indices are distinct atomic variables, +and there should be no variables other than the index variables)"; + +class NotSingleReadWriteBuffer : public ScheduleError { + public: + explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block) + : mod_(mod), is_read_(is_read), block_(std::move(block)) {} + + String FastErrorString() const final { + return is_read_ ? 
"ScheduleError: The block is allowed to read only a single buffer region" + : "ScheduleError: The block is allowed to write only a single buffer region"; + } + + String DetailRenderTemplate() const final { + if (is_read_) { + int k = block_->reads.size(); + return "The block is only allowed to read a single buffer region, but it reads " + + std::to_string(k) + " region(s): {0}"; + } else { + int k = block_->writes.size(); + return "The block is only allowed to write a single buffer region, but it writes " + + std::to_string(k) + " region(s): {0}"; + } + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + bool is_read_; + Block block_; + + static Buffer GetSingleRead(const ScheduleState& self, const Block& block) { + if (block->reads.size() != 1) { + throw NotSingleReadWriteBuffer(self->mod, true, block); + } + return block->reads[0]->buffer; + } + + static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) { + if (block->writes.size() != 1) { + throw NotSingleReadWriteBuffer(self->mod, false, block); + } + return block->writes[0]->buffer; + } +}; + +class BodyAnalysisError : public ScheduleError { + public: + explicit BodyAnalysisError(bool is_reverse, IRModule mod, Block block) + : is_reverse_(is_reverse), mod_(mod), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The block cannot be inlined because its body pattern does not meet the " + "condition for inlining"; + } + + String DetailRenderTemplate() const final { + return is_reverse_ ? 
kErrBodyReverseInline : kErrBodyInline; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + bool is_reverse_; + IRModule mod_; + Block block_; +}; +/*! \brief Error raised when the block to be removed is the only leaf in its scope */ +class OnlyLeafError : public ScheduleError { + public: + explicit OnlyLeafError(IRModule mod, Block leaf_block, StmtSRef scope_root_sref) + : mod_(mod), leaf_block_(std::move(leaf_block)), scope_root_(nullptr) { + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); + this->scope_root_ = GetRef(scope_root); + } + + String FastErrorString() const final { + return "ScheduleError: Cannot remove the only leaf in the scope"; + } + + String DetailRenderTemplate() const final { + return "Block {0} is the only leaf in the scope {1}, which cannot be removed; Otherwise the " + "scope will be empty."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } + + IRModule mod_; + Block leaf_block_; + Block scope_root_; +}; +/*! \brief Error raised when the consumer to be inlined lacks a single complete producer */ +class NonSingleProducerError : public ScheduleError { + public: + explicit NonSingleProducerError(IRModule mod, Block block) + : mod_(mod), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The consumer block to be inlined is required to have only a single " + "producer block, and the producer block should be a complete block who has only a " + "single consumer"; + } + + String DetailRenderTemplate() const final { + return "The consumer block {0} to be inlined is required to have only a single " + "producer block, and the producer block should be a complete block who has only a " + "single consumer"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; + /*! \brief Throw unless the consumer has one RAW producer that is complete with one consumer */ + static void Check(const ScheduleState& self, const StmtSRef& consumer_block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope =
self->GetBlockScope(scope_root_sref); + Array producers = scope->GetDepsByDst(consumer_block_sref); + if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) { + const StmtSRef& producer_block_sref = producers[0]->src; + if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) { + Array consumers = scope->GetDepsBySrc(producer_block_sref); + if (consumers.size() == 1) { + return; + } + } + } + const BlockNode* block = TVM_SREF_TO_BLOCK(block, consumer_block_sref); + throw NonSingleProducerError(self->mod, GetRef(block)); + } +}; +/*! \brief Error raised when the inlined buffer is accessed opaquely or is region-matched */ +class OpaqueAccessError : public ScheduleError { + public: + explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref) + : mod_(mod), scope_root_(nullptr) { + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); + this->scope_root_ = GetRef(scope_root); + } + + String FastErrorString() const final { + return "ScheduleError: The buffer to be inlined has opaque access (e.g. `B.data`), or its " + "subregion is matched into other blocks"; + } + + String DetailRenderTemplate() const final { + return "The buffer to be inlined has opaque access (e.g. `B.data`), or its " + "subregion is matched into other blocks: {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {scope_root_}; } + + IRModule mod_; + Block scope_root_; +}; + +/*! + * \brief Construct a new AST, with a specific sref tree leaf removed. + * Ancestors of the leaf that have only a single child will be removed too. + * \param leaf_block_sref The block/loop sref to the sref tree leaf to be removed + * \param src_stmt The root of the subtree where the replacement begins + * \param tgt_stmt The root of the subtree after the replacement + * \return A boolean indicating if the leaf can be removed successfully + * \note Removal is not conducted beyond scope-level. + * + * An example of the removal plan: say we are removing the leaf block "B" from the AST.
+ * + * \code + * with block([], "scope_root"): + * ... + * with block([128, 128], "B") as [vi, vj]: + * B[vi, vj] = A[vi, vj] + 1.0 + * with block([128, 128], "C") as [vi, vj]: + * C[vi, vj] = B[vi, vj] * 2.0 + * \endcode + * + * This method does not mutate the AST; instead it returns a `(src_stmt, tgt_stmt)` pair as a + * plan to substitute certain pieces of the IR. + * + * In our example, it returns block "scope_root" as `src_stmt`, and the result `tgt_stmt` is: + * + * \code + * with block([], "scope_root"): + * ... + * with block([128, 128], "C") as [vi, vj]: + * C[vi, vj] = B[vi, vj] * 2.0 + * \endcode + */ +bool LeafBlockRemovalPlan(const StmtSRef& leaf_block_sref, Stmt* src_stmt, Stmt* tgt_stmt) { + // Go upwards until we find an ancestor with more than one child + const StmtNode* last_stmt = leaf_block_sref->stmt; + StmtSRefNode* sref = leaf_block_sref->parent; + for (;; last_stmt = sref->stmt, sref = sref->parent) { + if (const auto* loop = sref->StmtAs()) { + if (const auto* seq = loop->body.as()) { + if (seq->size() > 1) { + break; + } + } + } else { + // Removal is not done beyond scope-level. + // When encountering a block, i.e. the scope root, we simply stop + break; + } + } + if (const auto* block = sref->StmtAs()) { + if (const auto* seq = block->body.as()) { + ObjectPtr n = make_object(*block); + n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + *src_stmt = GetRef(block); + *tgt_stmt = Stmt(std::move(n)); + return true; + } + } + if (const auto* loop = sref->StmtAs()) { + if (const auto* seq = loop->body.as()) { + ObjectPtr n = make_object(*loop); + n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + *src_stmt = GetRef(loop); + *tgt_stmt = Stmt(std::move(n)); + return true; + } + } + return false;  // The ancestor we stopped at has no SeqStmt body to remove the leaf from +} + +/*!
+ * \brief The base class of the inliner, which handles: + * 1) Substitute a subtree with the specific block being inlined + * 2) Update the block signature to reflect the changes of read/write/allocated buffers + * 3) Maintain a list of index variables and their substitution of the buffer being inlined + */ +class BaseInliner : public StmtExprMutator { + protected: + explicit BaseInliner(const Buffer& inlined_buffer, const Block& inlined_block, + const StmtSRef& scope_root_sref) + : inlined_buffer_(inlined_buffer), + inlined_store_(inlined_block->body.as()), + scope_root_sref_(scope_root_sref) { + AddBuffersInBlockSignature(inlined_block.get()); + } + /*! \brief Flag opaque access if the var is the data var of the inlined buffer */ + PrimExpr VisitExpr_(const VarNode* var) final { + CheckOpaqueAccess(var); + return StmtExprMutator::VisitExpr_(var); + } + /*! \brief Flag opaque access if the load reads through the inlined buffer's data var */ + PrimExpr VisitExpr_(const LoadNode* load) final { + CheckOpaqueAccess(load->buffer_var.get()); + return StmtExprMutator::VisitExpr_(load); + } + /*! \brief Flag opaque access if the store writes through the inlined buffer's data var */ + Stmt VisitStmt_(const StoreNode* store) final { + CheckOpaqueAccess(store->buffer_var.get()); + return StmtExprMutator::VisitStmt_(store); + } + /*! \brief Visit the loop; if it is `src_stmt`, visit `tgt_stmt` in its place */ + Stmt VisitStmt_(const ForNode* loop) final { + if (src_stmt.get() == loop) { + loop = tgt_stmt.as(); + ICHECK(loop != nullptr); + } + return StmtExprMutator::VisitStmt_(loop); + } + /*! \brief Visit the block, refresh its signature and record the mapping in `block_reuse` */ + Stmt VisitStmt_(const BlockNode* block) final { + CheckMatchBufferRegion(block); + AddBuffersInBlockSignature(block); + Block src_block = GetRef(block); + if (src_block.same_as(src_stmt)) { + block = tgt_stmt.as(); + ICHECK(block != nullptr); + } + Block tgt_block = Downcast(StmtExprMutator::VisitStmt_(block)); + bool is_scope_root = src_block.get() == scope_root_sref_->stmt; + tgt_block = UpdateBuffersInBlockSignature(std::move(tgt_block), is_scope_root); + block_reuse.Set(src_block, tgt_block); + return std::move(tgt_block); + } + + /*! + * \brief Check if the indices are atomic distinct variables and the access is n-dimensional. + * If so, set `self->idx_vars_` properly.
+ * \param indices The indices to be extracted + * \param expected_ndim The expected ndim of the access + * \return A boolean flag indicating if the check is successful + */ + bool UpdateAndCheckIndexVars(const Array& indices, int expected_ndim) { + int n = indices.size(); + if (n != expected_ndim) { + // Failure: dimension mismatch + return false; + } + std::vector result; + result.reserve(n); + for (const PrimExpr& i : indices) { + if (const auto* var = i.as()) { + result.push_back(var); + } else { + // Failure: indexing expression is not a variable + return false; + } + } + using DistinctSet = std::unordered_set; + int n_distinct = DistinctSet(result.begin(), result.end()).size(); + if (n != n_distinct) { + // Failure: indexing variables are not distinct + return false; + } + if (idx_vars_.empty()) { + idx_vars_ = std::move(result); + } else if (!support::ArrayWithSameContent(idx_vars_, result)) { + // Failure: indexing variables are not consistent in different BufferLoads + return false; + } + return true; + } + + /*! + * \brief Set the mapping of index substitution `self->idx_sub_` + * \param indices The expressions that the corresponding index variables are replaced with + */ + void SetIndexSubstitution(const Array& indices) { + ICHECK_EQ(indices.size(), idx_vars_.size()); + int n = idx_vars_.size(); + idx_sub_.reserve(n); + for (int i = 0; i < n; ++i) { + idx_sub_[idx_vars_[i]] = indices[i]; + } + } + + private: + /*!
+ * \brief Add the buffers in the block signature to the `buffer_var_map_`, + * which is used for auto-completion of a block's read/write region + * \param block The block whose signature is to be added + */ + void AddBuffersInBlockSignature(const BlockNode* block) { + for (const BufferRegion& buffer_region : block->reads) { + const Buffer& buffer = buffer_region->buffer; + buffer_var_map_.Set(buffer->data, buffer); + } + for (const BufferRegion& buffer_region : block->writes) { + const Buffer& buffer = buffer_region->buffer; + buffer_var_map_.Set(buffer->data, buffer); + } + for (const Buffer& buffer : block->alloc_buffers) { + buffer_var_map_.Set(buffer->data, buffer); + } + } + + /*! + * \brief Update the following parts of the block signature: + * 1) tir.alloc_buffer, if the block is scope root (the inlined buffer is dropped) + * 2) tir.reads, if the block is not scope root (re-detected from the block) + * 3) tir.writes, if the block is not scope root (re-detected from the block) + * \param block The block to be updated + * \param is_scope_root A flag indicating if a block is the scope root of the block to be inlined + * \return The updated block + */ + Block UpdateBuffersInBlockSignature(Block block, bool is_scope_root) { + // Step 1. Update `BlockNode::alloc_buffers`: the scope root no longer allocates the inlined buffer + Array alloc_buffers; + if (is_scope_root) { + alloc_buffers.reserve(block->alloc_buffers.size()); + for (const Buffer& alloc_buffer : block->alloc_buffers) { + if (!alloc_buffer.same_as(inlined_buffer_)) { + alloc_buffers.push_back(alloc_buffer); + } + } + } else { + alloc_buffers = std::move(block->alloc_buffers); + } + // Step 2. Update `BlockNode::reads` and `BlockNode::writes` of non-root blocks + Array reads = std::move(block->reads); + Array writes = std::move(block->writes); + if (!is_scope_root) { + Array> inspected = GetBlockAccessRegion(block, buffer_var_map_); + reads = std::move(inspected[0]); + writes = std::move(inspected[1]); + } + // Step 3.
Assemble the result + BlockNode* n = block.CopyOnWrite(); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->alloc_buffers = std::move(alloc_buffers); + return block; + } + + /*! + * \brief Opaque access to the buffer to be inlined is disallowed. + * This method sets `has_opaque_access` if the var is the data var of the inlined buffer + * \param buffer_var The buffer var to be checked + */ + void CheckOpaqueAccess(const VarNode* buffer_var) { + if (inlined_buffer_->data.get() == buffer_var) { + this->has_opaque_access = true; + } + } + + /*! + * \brief The buffer to be inlined is not allowed to be region matched. + * This method sets `has_opaque_access` if the block matches a region of the inlined buffer. + * \param block The block to be checked + */ + void CheckMatchBufferRegion(const BlockNode* block) { + for (const MatchBufferRegion& match_buffer_region : block->match_buffers) { + const Buffer& matched = match_buffer_region->source->buffer; + if (matched.same_as(inlined_buffer_)) { + this->has_opaque_access = true; + } + } + } + + protected: + /*! \brief The buffer to be inlined */ + Buffer inlined_buffer_{nullptr}; + /*! \brief The body of the block to be inlined; nullptr if the body is not a BufferStore */ + const BufferStoreNode* inlined_store_{nullptr}; + /*! \brief The scope root */ + StmtSRef scope_root_sref_{nullptr}; + /*! \brief Maps a buffer's data field to itself */ + Map buffer_var_map_; + /*! \brief The indices used for indexing the buffer to be inlined */ + std::vector idx_vars_; + /*! \brief The mapping to substitute index variables to PrimExprs */ + std::unordered_map idx_sub_; + + public: + /*! + * \brief The Stmt to be replaced when removing the leaf block + * \note The pair (src_stmt, tgt_stmt) is produced by LeafBlockRemovalPlan to indicate a + * transformation on top of the input AST. We take this approach to avoid changing the AST twice + */ + Stmt src_stmt{nullptr}; + /*! \brief The Stmt that replaces `src_stmt` when removing the leaf block */ + Stmt tgt_stmt{nullptr}; + /*!
\brief Maps each visited block to the block that replaces it */ + Map block_reuse; + /*! \brief Indicates if there is any opaque access of the inlined buffer */ + bool has_opaque_access{false}; +}; + +/*! + * \brief Helper to inline the producer block into its consumer(s) + * The derived class implements the following functionalities: + * 1) Substitute `BufferLoad` on the buffer to be inlined + * to its value calculation in the producer block + * 2) Analyze the producer block to determine the remapping of index variables + */ +class ComputeInliner : public BaseInliner { + public: + explicit ComputeInliner(const Buffer& inlined_buffer, const Block& producer_block, + const StmtSRef& scope_root_sref) + : BaseInliner(inlined_buffer, producer_block, scope_root_sref) {} + /*! \brief Check the producer body is a BufferStore indexed by all its (distinct) variables */ + bool BodyPatternAllowInline(const Block& producer_block) { + if (inlined_store_ == nullptr) { + return false; + } + int n_vars = UndefinedVars(GetRef(inlined_store_), {}).size(); + if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) { + return false; + } + return true; + } + + private: + using BaseInliner::VisitExpr_; + using BaseInliner::VisitStmt_; + /*! \brief Replace each load of the inlined buffer with the producer's value expression */ + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (!load->buffer.same_as(inlined_buffer_)) { + return std::move(load); + } + return ReplaceInlinedBuffer(std::move(load)); + } + /*! \brief Substitute the producer's index variables in its RHS with the load's indices */ + PrimExpr ReplaceInlinedBuffer(BufferLoad load) { + SetIndexSubstitution(load->indices); + return Substitute(inlined_store_->value, idx_sub_); + } +}; + +/*!
+ * \brief Helper to inline the consumer block into its producer + * The derived class implements the following functionalities: + * 1) Analyze the consumer block to determine the remapping of index variables + * 2) Substitute `BufferStore` of the buffer to be inlined, + * replacing it with direct writing to the buffer that the consumer writes + */ +class ReverseComputeInliner : public BaseInliner { + class Substituter : public StmtExprMutator { + public: + explicit Substituter(ReverseComputeInliner* self) : self_(self) {} + + private: + PrimExpr VisitExpr_(const VarNode* var) final { + auto it = self_->idx_sub_.find(var); + ICHECK(it != self_->idx_sub_.end()); + return (*it).second; + } + /*! \brief Replace each load of the inlined buffer with the producer's RHS */ + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + return load->buffer.same_as(self_->inlined_buffer_) ? self_->producer_rhs_ : load; + } + + ReverseComputeInliner* self_; + }; + + public: + explicit ReverseComputeInliner(const Buffer& inlined_buffer, const Block& consumer_block, + const StmtSRef& scope_root_sref) + : BaseInliner(inlined_buffer, consumer_block, scope_root_sref) {} + /*! \brief Check the consumer body is a BufferStore loading the inlined buffer consistently */ + bool BodyPatternAllowInline(const Block& consumer_block) { + if (inlined_store_ == nullptr) { + // Failure: block body is not BufferStore + return false; + } + std::vector loads = ExtractBufferLoad(inlined_buffer_, inlined_store_); + if (loads.size() == 0) { + // Failure: no BufferLoad from the `inlined_buffer_` + return false; + } + int n_vars = UndefinedVars(GetRef(inlined_store_), {}).size(); + for (const BufferLoadNode* load : loads) { + if (!UpdateAndCheckIndexVars(load->indices, n_vars)) { + // Failure: incorrect or inconsistent index vars + return false; + } + } + return true; + } + + private: + using BaseInliner::VisitExpr_; + using BaseInliner::VisitStmt_; + + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if
(!store->buffer.same_as(inlined_buffer_)) { + return std::move(store); + } + return ReplaceInlinedBuffer(std::move(store)); + } + /*! \brief Replace the producer's store with the consumer's store, rewritten by Substituter */ + Stmt ReplaceInlinedBuffer(BufferStore producer) { + SetIndexSubstitution(producer->indices); + producer_rhs_ = producer->value; + return Substituter(this)(GetRef(inlined_store_)); + } + + /*! + * \brief Extract the expressions that load from a specific buffer + * \param buffer The buffer to be loaded from + * \param from The BufferStore statement whose indices and value are searched + * \return A list of `BufferLoad` expressions + */ + static std::vector ExtractBufferLoad(const Buffer& buffer, + const BufferStoreNode* from) { + struct Extractor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.get() == buffer) { + result.push_back(load); + } + ExprVisitor::VisitExpr_(load); + } + const BufferNode* buffer; + std::vector result; + } extractor; + extractor.buffer = buffer.get(); + for (const PrimExpr& expr : from->indices) { + extractor(expr); + } + extractor(from->value); + return std::move(extractor.result); + } + + /*! \brief The RHS value of the producer's BufferStore statement */ + PrimExpr producer_rhs_{nullptr}; +}; + +void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { + const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); + Block producer_block = GetRef(_producer_block); + Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); + // Step 1. Get the scope block + StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, producer_block_sref); + // Step 2. Check completeness + CheckCompleteBlock(self, producer_block_sref, scope_root_sref); + // Step 3. Analyze the block body + ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref); + if (!inliner.BodyPatternAllowInline(producer_block)) { + throw BodyAnalysisError(false, self->mod, producer_block); + } + // Step 4.
Create a plan that removes the leaf block to be inlined + if (!LeafBlockRemovalPlan(producer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt)) { + throw OnlyLeafError(self->mod, producer_block, scope_root_sref); + } + // Step 5. Create an AST where the leaf `producer_block_sref` points to is removed, + // and rewrite the blocks that read from the buffer written by the removed block + Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); + if (inliner.has_opaque_access) {  // The flag is collected while the inliner visits the AST above + throw OpaqueAccessError(self->mod, scope_root_sref); + } + // Step 6. Do the real mutation on the AST and the sref tree in the schedule state + self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); +} + +void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { + const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); + Block consumer_block = GetRef(_consumer_block); + Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block); + // Step 1. Get the scope block + StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, consumer_block_sref); + // Step 2. Check completeness + CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); + // Step 3. Check if the consumer has a single complete producer + NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref); + // Step 4. Analyze the block body + ReverseComputeInliner inliner(inlined_buffer, consumer_block, scope_root_sref); + if (!inliner.BodyPatternAllowInline(consumer_block)) { + throw BodyAnalysisError(true, self->mod, consumer_block); + } + // Step 5. Create a plan that removes the leaf block to be inlined + if (!LeafBlockRemovalPlan(consumer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt)) { + throw OnlyLeafError(self->mod, consumer_block, scope_root_sref); + } + // Step 6.
Create an AST where the leaf `consumer_block_sref` points to is removed, + // and rewrite the store of the producer block that writes to the inlined buffer + Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); + if (inliner.has_opaque_access) { + throw OpaqueAccessError(self->mod, scope_root_sref); + } + // Step 7. Do the real mutation on the AST and the sref tree in the schedule state + self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index a1a4f09a7525..115f7936f64e 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -122,6 +122,16 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); +/******** (FFI) loops manipulation ********/ +/******** (FFI) compute location ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") + .set_body_method(&ScheduleNode::ComputeInline); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") + .set_body_method(&ScheduleNode::ReverseComputeInline); +/******** (FFI) loop binding/annotation ********/ +/******** (FFI) cache read/write ********/ +/******** (FFI) reduction ********/ +/******** (FFI) blockize & tensorize ********/ } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index e7c73120c730..19ed995ac8cc 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -34,8 +34,10 @@ #include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" +#include "../../support/array.h" #include "./analysis.h" #include "./error.h" +#include "./primitive.h" namespace tvm { namespace tir { @@ -114,6 +116,33 @@ inline bool CanRelaxStorageUndereThread(const runtime::StorageScope& storage_sco return static_cast(storage_scope.rank) <=
static_cast(thread_scope.rank); } +/******** SeqStmt ********/ + +/*! + * \brief Remove a specific Stmt from a SeqStmt. If a SeqStmt contains a BlockRealize, + * whose block is the Stmt to be removed, then remove that BlockRealize too. + * \param seq The SeqStmt to be removed from, which must contain more than one statement + * \param to_remove The Stmt to be removed + * \return The remaining statements, assembled by `SeqStmt::Flatten` + */ +inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { + ICHECK_GT(seq->size(), 1); + Array new_stmts; + new_stmts.reserve(seq->size()); + for (const Stmt& stmt : seq->seq) { + if (to_remove.same_as(stmt)) { + continue; + } + if (const auto* realize = stmt.as()) { + if (to_remove.same_as(realize->block)) { + continue; + } + } + new_stmts.push_back(stmt); + } + return SeqStmt::Flatten(new_stmts); +} + /******** Integer set ********/ /*! @@ -132,22 +161,6 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } -/*! - * \brief Converts an N-dimensional integer set to N-dimensional region - * \param nd_int_set The integer set - * \return The region as the result of conversion - */ -inline Array AsRegion(const Array& nd_int_set, arith::Analyzer* analyzer) { - Array result; - result.reserve(nd_int_set.size()); - for (const arith::IntSet& int_set : nd_int_set) { - PrimExpr min = analyzer->Simplify(int_set.min()); - PrimExpr extent = analyzer->Simplify(int_set.max() - int_set.min() + 1); - result.push_back(Range::FromMinExtent(std::move(min), std::move(extent))); - } - return result; -} - } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py new file mode 100644 index 000000000000..c34ec8d610d6 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -0,0 +1,373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_multi_producer_consumer(a: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + D = tir.match_buffer(d, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + with tir.block([128, 128], "D") as [vi, vj]: + D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers + + +@tvm.script.tir +def elementwise_multi_consumer_inlined(a: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + D = tir.match_buffer(d, 
(128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + with tir.block([128, 128], "D") as [vi, vj]: + D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] + + +@tvm.script.tir +def elementwise_standalone(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_standalone_dce(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_under_loop(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + for j in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_inlined(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +@tvm.script.tir +def fail_multi_reader_writer(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.alloc_buffer((128, 128)) + D = tir.match_buffer(d, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + C[vi, vj] = A[vi, vj] + 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + D[vi, vj] = B[vi, vj] + C[vi, vj] + + +@tvm.script.tir +def 
elementwise_multi_reverse_loads(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 + + +@tvm.script.tir +def elementwise_multi_reverse_loads_inlined(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 + + +@tvm.script.tir +def opaque_access_load(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + tir.reads(B[0:128, 0:128]) + tir.writes(C[0:128, 0:128]) + C[vi, vj] = tir.load("float32", B.data, vi * 128 + vj) + 1.0 + + +@tvm.script.tir +def opaque_access_store(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + tir.reads(B[0:128, 0:128]) + tir.writes(C[0:128, 0:128]) + tir.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) + C[vi, vj] = tir.load("float32", B.data, vi * 16 + vj) + 1.0 + + +@tvm.script.tir +def buffer_matched(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + Bb = tir.match_buffer_region(B[vi : vi + 1, vj]) + C[vi, vj] = Bb[0, 0] + 1.0 + + 
+@tvm.script.tir
+def elementwise_predicate(a: ty.handle, c: ty.handle) -> None:  # Fixture for test_compute_inline_predicate: block "C" carries a tir.where predicate that reads B.
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "C") as [vi, vj]:
+            tir.where(B[i, j] < 10.0)  # predicate indexes B with the loop variables i, j
+            C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def elementwise_predicate_inlined(a: ty.handle, c: ty.handle) -> None:  # Expected IR for test_compute_inline_predicate: B is gone from both the predicate and the body.
+    A = tir.match_buffer(a, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "C") as [vi, vj]:
+            tir.where(A[i, j] * 2.0 < 10.0)  # B[i, j] appears as A[i, j] * 2.0
+            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
+
+
+@tvm.script.tir
+def elementwise_multi_loads(a: ty.handle, c: ty.handle) -> None:  # Fixture where block "C" loads B at three different column offsets; used by two tests below.
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 126], "C") as [vi, vj]:  # second extent is 126, so vj + 2 is at most 127
+        C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2]
+
+
+@tvm.script.tir
+def elementwise_multi_loads_inlined(a: ty.handle, c: ty.handle) -> None:  # Expected IR for test_compute_inline_multi_loads: each B load appears as A[...] * 2.0.
+    A = tir.match_buffer(a, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    with tir.block([128, 126], "C") as [vi, vj]:
+        C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+
+def test_compute_inline_elementwise():  # compute_inline of block "B" in `elementwise` must yield `elementwise_inlined`.
+    sch = tir.Schedule(elementwise, debug_mode=True)
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")  # fetched before the transformation and checked again after it
+    sch.compute_inline(block_b)
+    tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"])
+    assert sch.get(block_c).name_hint == "C"  # the earlier handle still resolves to a block named "C"
+
+
+def test_compute_inline_under_loop():  # Same check starting from `elementwise_under_loop`; the expected result is again `elementwise_inlined`.
+    sch = tir.Schedule(elementwise_under_loop, debug_mode=True)
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    sch.compute_inline(block_b)
+    tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"])
+    assert sch.get(block_c).name_hint == "C"
+
+
+def test_compute_inline_as_dce():  # compute_inline of "B" in `elementwise_standalone` must yield `elementwise_standalone_dce` (fixtures defined outside this chunk).
+    sch = tir.Schedule(elementwise_standalone, debug_mode=True)
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    sch.compute_inline(block_b)
+    tvm.ir.assert_structural_equal(elementwise_standalone_dce, sch.mod["main"])
+    assert sch.get(block_c).name_hint == "C"
+
+
+def test_compute_inline_multi_consumer():  # compute_inline of "B" with blocks "C" and "D" also present; both handles must stay valid.
+    sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mode=True)
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    block_d = sch.get_block("D")
+    sch.compute_inline(block_b)
+    tvm.ir.assert_structural_equal(elementwise_multi_consumer_inlined, sch.mod["main"])
+    assert sch.get(block_c).name_hint == "C"
+    assert sch.get(block_d).name_hint == "D"
+
+
+def test_compute_inline_fail_multi_writer():  # compute_inline of "B" in `fail_multi_reader_writer` must raise ScheduleError.
+    sch = tir.Schedule(fail_multi_reader_writer, debug_mode=True, error_render_level="detail")  # the only test here that passes error_render_level
+    block_b = sch.get_block("B")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.compute_inline(block_b)
+
+
+def test_reverse_compute_inline_elementwise():  # reverse_compute_inline of block "C" in `elementwise` must yield `elementwise_inlined`.
+    sch = tir.Schedule(elementwise, debug_mode=True)
+    block_b = sch.get_block("B")  # fetched before the transformation and checked again after it
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"])
+    assert sch.get(block_b).name_hint == "B"  # the earlier handle still resolves to a block named "B"
+
+
+def test_reverse_compute_inline_under_loop():  # Same check starting from `elementwise_under_loop`.
+    sch = tir.Schedule(elementwise_under_loop, debug_mode=True)
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"])
+    assert sch.get(block_b).name_hint == "B"
+
+
+def test_reverse_compute_inline_fail_as_dce():  # reverse_compute_inline of "B" in `elementwise_standalone` must raise ScheduleError.
+    sch = tir.Schedule(elementwise_standalone, debug_mode=True)
+    block_b = sch.get_block("B")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reverse_compute_inline(block_b)
+
+
+def test_reverse_compute_inline_fail_multi_producer():  # reverse_compute_inline of "D" in `elementwise_multi_producer_consumer` must raise ScheduleError.
+    sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mode=True)
+    block_d = sch.get_block("D")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reverse_compute_inline(block_d)
+
+
+def test_reverse_compute_inline_fail_multi_reader():  # reverse_compute_inline of "C" in `fail_multi_reader_writer` must raise ScheduleError.
+    sch = tir.Schedule(fail_multi_reader_writer, debug_mode=True)
+    block_c = sch.get_block("C")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reverse_compute_inline(block_c)
+
+
+def test_reverse_compute_multi_reverse_loads():  # reverse_compute_inline of "C" when it loads B[vi, vj] twice must yield `elementwise_multi_reverse_loads_inlined`.
+    sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mode=True)
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    tvm.ir.assert_structural_equal(elementwise_multi_reverse_loads_inlined, sch.mod["main"])
+
+
+def test_reverse_compute_fail_multi_reverse_loads():  # reverse_compute_inline of "C" in `elementwise_multi_loads` must raise ScheduleError.
+    sch = tir.Schedule(elementwise_multi_loads, debug_mode=True)
+    block_c = sch.get_block("C")  # "C" loads B at vj, vj + 1 and vj + 2 - presumably the reason for the rejection; confirm against the primitive
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reverse_compute_inline(block_c)
+
+
+def test_opaque_access_load():  # compute_inline of "B" must raise ScheduleError when "C" reads B via tir.load.
+    sch = tir.Schedule(opaque_access_load, debug_mode=True)
+    block_b = sch.get_block("B")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.compute_inline(block_b)
+
+
+def test_opaque_access_store():  # compute_inline of "B" must raise ScheduleError for the `opaque_access_store` fixture.
+    sch = tir.Schedule(opaque_access_store, debug_mode=True)
+    block_b = sch.get_block("B")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.compute_inline(block_b)
+
+
+def test_buffer_matched():  # compute_inline of "B" must raise ScheduleError when "C" reads B through match_buffer_region.
+    sch = tir.Schedule(buffer_matched, debug_mode=True)
+    block_b = sch.get_block("B")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.compute_inline(block_b)
+
+
+def test_compute_inline_predicate():  # compute_inline of "B" must also rewrite the B load inside the tir.where predicate of "C".
+    sch = tir.Schedule(elementwise_predicate, debug_mode=True)
+    block_b = sch.get_block("B")
+    sch.compute_inline(block_b)
+    tvm.ir.assert_structural_equal(elementwise_predicate_inlined, sch.mod["main"])
+
+
+def test_compute_inline_multi_loads():  # compute_inline of "B" must rewrite all three offset loads of B in "C".
+    sch = tir.Schedule(elementwise_multi_loads, debug_mode=True)
+    block_b = sch.get_block("B")
+    sch.compute_inline(block_b)
+    tvm.ir.assert_structural_equal(elementwise_multi_loads_inlined, sch.mod["main"])
+
+
+if __name__ == "__main__":  # run every test in this chunk when the file is executed as a script
+    test_compute_inline_elementwise()
+    test_compute_inline_under_loop()
+    test_compute_inline_as_dce()
+    test_compute_inline_multi_consumer()
+    test_compute_inline_fail_multi_writer()
+    test_reverse_compute_inline_elementwise()
+    test_reverse_compute_inline_under_loop()
+    test_reverse_compute_inline_fail_as_dce()
+    test_reverse_compute_inline_fail_multi_producer()
+    test_reverse_compute_inline_fail_multi_reader()
+    test_reverse_compute_multi_reverse_loads()
+    test_reverse_compute_fail_multi_reverse_loads()
+    test_opaque_access_load()
+    test_opaque_access_store()
+    test_buffer_matched()
+    test_compute_inline_predicate()
+    test_compute_inline_multi_loads()