[TensorIR][M2a] ComputeInline,ReverseComputeInline (#8170)
This PR is part of the TensorIR upstreaming effort (#7527). It adds the first two schedule primitives:

- compute-inline
- reverse-compute-inline

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Cody Yu <comaniac0422@gmail.com>
7 people authored Jun 4, 2021
1 parent b753772 commit dd09bbb
Showing 14 changed files with 1,599 additions and 33 deletions.
29 changes: 29 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -195,6 +195,35 @@ class ScheduleNode : public runtime::Object {
* \return A list of loops above the given block in its scope, from outer to inner
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
/******** Schedule: loops manipulation ********/
/******** Schedule: compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
* 2) The block must not be the only leaf in the scope.
* 3) The body of the block must be a BufferStore statement in the form of,
* A[i, j, k, ...] = ...
* where the indices of the LHS are all distinct atomic variables,
* and no variables other than those indexing variables are allowed in the statement.
* \param block The block to be inlined to its consumer(s)
*/
virtual void ComputeInline(const BlockRV& block) = 0;
/*!
* \brief Inline a block into its only producer. It requires:
* 1) The block is a complete non-root block, which only produces and consumes one buffer
* 2) The block must not be the only leaf in the scope.
* 3) The only producer of the block is a read-after-write producer and a complete non-root block
* 4) The body of the block must be a BufferStore statement in the form of,
* B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...)
* where the indices of each `BufferLoad` on the RHS are all distinct atomic variables,
* and no variables other than those indexing variables are allowed in the statement.
* \param block The block to be inlined to its producer
*/
virtual void ReverseComputeInline(const BlockRV& block) = 0;
/******** Schedule: loop binding/annotation ********/
/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/
/******** Schedule: blockize & tensorize ********/
};

/*!
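Requirement 3 of `ComputeInline` (the store indices are distinct atomic variables) is what lets inlining work as a plain substitution: every load of the produced buffer in a consumer is replaced by the producer's right-hand side, with the store variables renamed to the consumer's index expressions. The following is a minimal sketch of that idea on a toy tuple-based expression tree. The representation and the names `substitute` and `inline_loads` are hypothetical illustrations, not TVM APIs, and the sketch performs none of the legality checks listed above.

```python
# Toy expression tree (hypothetical, not TVM's IR):
#   ("var", name) | ("const", value) | ("load", buffer, [index_exprs])
#   | ("add" | "mul", lhs, rhs)


def substitute(expr, var_map):
    """Replace variables in `expr` according to `var_map`."""
    kind = expr[0]
    if kind == "var":
        return var_map.get(expr[1], expr)
    if kind == "const":
        return expr
    if kind == "load":
        return ("load", expr[1], [substitute(i, var_map) for i in expr[2]])
    return (kind, substitute(expr[1], var_map), substitute(expr[2], var_map))


def inline_loads(expr, buffer, store_vars, store_value):
    """Rewrite every load of `buffer` in `expr`, given the producer's store
    `buffer[store_vars...] = store_value`."""
    kind = expr[0]
    if kind in ("var", "const"):
        return expr
    if kind == "load":
        indices = [inline_loads(i, buffer, store_vars, store_value) for i in expr[2]]
        if expr[1] == buffer:
            # Distinct atomic store variables make this mapping well defined.
            return substitute(store_value, dict(zip(store_vars, indices)))
        return ("load", expr[1], indices)
    return (
        kind,
        inline_loads(expr[1], buffer, store_vars, store_value),
        inline_loads(expr[2], buffer, store_vars, store_value),
    )


# Producer: B[vi, vj] = A[vi, vj] * 2.0
producer_value = ("mul", ("load", "A", [("var", "vi"), ("var", "vj")]), ("const", 2.0))
# Consumer: C[i, j] = B[i, j] + 1.0
consumer_value = ("add", ("load", "B", [("var", "i"), ("var", "j")]), ("const", 1.0))

# Consumer after inlining: C[i, j] = A[i, j] * 2.0 + 1.0
inlined = inline_loads(consumer_value, "B", ["vi", "vj"], producer_value)
```

If the store indices were compound expressions such as `B[vi + 1, vj]`, the mapping from store variables to consumer indices would need an inverse, which is presumably why the primitive restricts the left-hand side to plain variables.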
115 changes: 115 additions & 0 deletions python/tvm/tir/schedule/schedule.py
@@ -256,6 +256,121 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
"""
return _ffi_api_schedule.ScheduleGetLoops(self, block) # pylint: disable=no-member

########## Schedule: loops manipulation ##########
########## Schedule: compute location ##########
def compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its consumer(s). It requires:
1) The block is a complete non-root block, which only produces one buffer
2) The block must not be the only leaf in the scope.
3) The body of the block must be a BufferStore statement in the form of,
A[i, j, k, ...] = ...
where the indices of the LHS are all distinct atomic variables,
and no variables other than those indexing variables are allowed in the statement.
Parameters
----------
block : BlockRV
The block to be inlined to its consumer(s)
Examples
--------
Before compute-inline, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_inline(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0
Create the schedule and do compute-inline:
.. code-block:: python
sch = tir.Schedule(before_inline, debug_mode=True)
sch.compute_inline(sch.get_block("B"))
print(tvm.script.asscript(sch.mod["main"]))
After applying compute-inline, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_inline(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = A[vi, vj] * 2.0 + 1.0
"""
_ffi_api_schedule.ScheduleComputeInline(self, block) # pylint: disable=no-member

def reverse_compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its only producer. It requires:
1) The block is a complete non-root block, which only produces and consumes one buffer
2) The block must not be the only leaf in the scope.
3) The only producer of the block is a read-after-write producer
and a complete non-root block
4) The body of the block must be a BufferStore statement in the form of,
B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...)
where the indices of each `BufferLoad` on the RHS are all distinct atomic variables,
and no variables other than those indexing variables are allowed in the statement.
Parameters
----------
block : BlockRV
The block to be inlined to its producer
Examples
--------
Before reverse-compute-inline, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_inline(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0
Create the schedule and do reverse-compute-inline:
.. code-block:: python
sch = tir.Schedule(before_inline, debug_mode=True)
sch.reverse_compute_inline(sch.get_block("C"))
print(tvm.script.asscript(sch.mod["main"]))
After applying reverse-compute-inline, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_inline(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = A[vi, vj] * 2.0 + 1.0
"""
_ffi_api_schedule.ScheduleReverseComputeInline(self, block) # pylint: disable=no-member

########## Schedule: loop binding/annotation ##########
########## Schedule: cache read/write ##########
########## Schedule: reduction ##########
########## Schedule: blockize & tensorize ##########


@_register_object("tir.ConcreteSchedule")
class ConcreteSchedule(Schedule):
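Both docstring examples above end in the same `after_inline` function, so either primitive should leave the computed values unchanged. The sketch below checks that on a plain-Python model of the two functions. It uses nested lists instead of TVMScript buffers and a small `N` in place of 128; both are assumptions made only so the snippet runs without TVM.

```python
N = 4  # small stand-in for the 128 x 128 buffers in the docstrings


def before_inline(A):
    # Block "B": B[vi, vj] = A[vi, vj] * 2.0
    B = [[A[i][j] * 2.0 for j in range(N)] for i in range(N)]
    # Block "C": C[vi, vj] = B[vi, vj] + 1.0
    C = [[B[i][j] + 1.0 for j in range(N)] for i in range(N)]
    return C


def after_inline(A):
    # Single block "C": C[vi, vj] = A[vi, vj] * 2.0 + 1.0
    return [[A[i][j] * 2.0 + 1.0 for j in range(N)] for i in range(N)]


A = [[float(i * N + j) for j in range(N)] for i in range(N)]
same_result = before_inline(A) == after_inline(A)
```

The intermediate buffer `B` disappears in `after_inline`, which is the practical effect of both primitives in this example: one fewer allocation and one fewer pass over the data.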
72 changes: 72 additions & 0 deletions src/support/array.h
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SUPPORT_ARRAY_H_
#define TVM_SUPPORT_ARRAY_H_
#include <tvm/runtime/container.h>

#include <vector>

namespace tvm {
namespace support {

/*!
* \brief Checks if two arrays contain the same objects
* \tparam T The type of objects in the array
* \param a The first array
* \param b The second array
* \return A boolean indicating if they are the same
*/
template <class T>
inline bool ArrayWithSameContent(const Array<T>& a, const Array<T>& b) {
if (a.size() != b.size()) {
return false;
}
int n = a.size();
for (int i = 0; i < n; ++i) {
if (!a[i].same_as(b[i])) {
return false;
}
}
return true;
}

/*!
* \brief Checks if two arrays contain the same objects
* \tparam T The type of objects in the array
* \param a The first array
* \param b The second array
* \return A boolean indicating if they are the same
*/
template <class T>
inline bool ArrayWithSameContent(const std::vector<T*>& a, const std::vector<T*>& b) {
if (a.size() != b.size()) {
return false;
}
int n = a.size();
for (int i = 0; i < n; ++i) {
if (a[i] != b[i]) {
return false;
}
}
return true;
}

} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_ARRAY_H_
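Both `ArrayWithSameContent` overloads compare elements by identity (`same_as` on object references, `!=` on raw pointers) rather than by structural equality. A rough Python analogue, using `is` for identity, is sketched below. The `Node` class and the function name `array_with_same_content` are hypothetical and exist only for this illustration.

```python
def array_with_same_content(a, b):
    """Return True if both sequences hold the very same objects, in order."""
    if len(a) != len(b):
        return False
    return all(x is y for x, y in zip(a, b))


class Node:
    """Hypothetical object with value-based equality."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return isinstance(other, Node) and self.value == other.value


n1, n2 = Node(1), Node(2)
n1_copy = Node(1)  # equal to n1 by value, but a different object

shared = array_with_same_content([n1, n2], [n1, n2])
copied = array_with_same_content([n1, n2], [n1_copy, n2])
equal_by_value = [n1, n2] == [n1_copy, n2]
```

Here `shared` is true and `copied` is false even though the two lists compare equal by value, which mirrors the distinction the C++ helper draws.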
59 changes: 51 additions & 8 deletions src/tir/schedule/analysis.h
@@ -26,13 +26,13 @@ namespace tir {

/******** Verification ********/
/*!
- * \brief Verify the sref tree state is consistent with the IR
+ * \brief Verifies the sref tree state is consistent with the IR
* \param self The schedule state containing the sref to be verified
* \throw An exception will be thrown if the sref tree is not valid
*/
void VerifySRefTree(const ScheduleState& self);
/*!
- * \brief Verify the cached flags in the schedule state, including:
+ * \brief Verifies the cached flags in the schedule state, including:
* - affine_binding
* - region_cover
* - stage_pipeline
@@ -41,10 +41,53 @@ void VerifySRefTree(const ScheduleState& self);
*/
void VerifyCachedFlags(const ScheduleState& self);

- /******** Binding ********/
+ /******** Scope ********/
/*!
* \brief Gets the sref to the scope root block, exclusive
* \param sref The block or loop sref to be retrieved
* \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR
*/
Optional<StmtSRef> GetScopeRoot(const StmtSRef& sref);

/*!
* \brief Checks if the scope that the specified sref is in is a stage pipeline, and returns the scope root
* \param prim The name of the schedule primitive
* \param self The schedule state
* \param sref The sref whose scope is to be checked
* \throw ScheduleError if the sref is the root of the AST (so it has no scope root), or its
* scope root is not a stage pipeline
* \return The block sref to the scope root
*/
StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref);

/*!
* \brief Checks whether the block is a complete block under the scope
* \param self The schedule state
* \param block_sref The block to be checked
* \param scope_root The sref to the root block of the scope that `block_sref` is in
* \return A boolean indicating if the block is a complete block
* \note Definition of a complete block:
* 1) All block vars are data parallel
* 2) Dominant: the block is the only writer of its output,
* dominating the readers of its output buffers
* 3) No overlap between the buffers the block reads and writes
*/
bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root);

/*!
* \brief Checks if the block is a complete block
* \param self The schedule state
* \param block_sref The sref to the block whose completeness is to be checked
* \param scope_root_sref The scope root of the block
* \throw ScheduleError If the block is not a complete block
*/
void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref);

/******** Binding ********/
/*!
- * \brief Verify if the block binding in a specific BlockRealize is an affine binding.
+ * \brief Verifies if the block binding in a specific BlockRealize is an affine binding.
* The binding can be represented as an injective affine map from the loop iterators.
* \param realize The BlockRealize to be analyzed
* \param loop_var_ranges The ranges of the loop variables
@@ -55,7 +98,7 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
arith::Analyzer* analyzer);

/*!
- * \brief Extract the ranges of loop variables in a path of the sref tree
+ * \brief Extracts the ranges of loop variables in a path of the sref tree
* \param low_inclusive The lowest node in the path
* \param high_exclusive The highest node in the path, defaults to the scope root if not specified
* \param extra_relax_scope If the scope is not global, the method will look beyond the limit and
@@ -78,22 +121,22 @@ Map<Var, PrimExpr> GetBindings(const BlockRealize& realize);

/******** Block-loop relation ********/
/*!
- * \brief Retrieve blocks in a specific function with its name
+ * \brief Retrieves blocks in a specific function with its name
* \param self The schedule state
* \param name The name of the blocks to be retrieved
* \param func_name The name of the function
* \return A list of blocks with the specific name
*/
Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const String& func_name);
/*!
- * \brief Get the parent loops of the block in its scope, from outer to inner
+ * \brief Gets the parent loops of the block in its scope, from outer to inner
* \param self The schedule state
* \param block_sref The query block
* \return A list of loops above the given block in its scope, from outer to inner
*/
Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
/*!
- * \brief Get the leaf blocks of a scope where a specific block/loop is in
+ * \brief Gets the leaf blocks of a scope where a specific block/loop is in
* \param self The schedule state
* \param parent_sref The StmtSRef that points to the parent block/loop
* \return A list of leaf blocks
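The three-part definition of a complete block in the `IsCompleteBlock` note can be modeled at buffer granularity as follows. This is a simplified sketch under stated assumptions: `ToyBlock`, its fields, and `is_complete_block` are hypothetical names, and the real check also reasons about regions and dominance over readers, which this toy ignores.

```python
from dataclasses import dataclass, field


@dataclass
class ToyBlock:
    """Hypothetical stand-in for a block: iterator kinds plus buffer names."""

    name: str
    iter_types: list  # e.g. ["data_par", "reduce"]
    reads: set = field(default_factory=set)
    writes: set = field(default_factory=set)


def is_complete_block(block, scope_blocks):
    # 1) All block vars are data parallel.
    if any(t != "data_par" for t in block.iter_types):
        return False
    # 2) The block is the only writer of each of its output buffers
    #    (buffer-level approximation of the "dominant" condition).
    for other in scope_blocks:
        if other is not block and other.writes & block.writes:
            return False
    # 3) No overlap between the buffers the block reads and writes.
    return not (block.reads & block.writes)


b = ToyBlock("B", ["data_par", "data_par"], reads={"A"}, writes={"B"})
c = ToyBlock("C", ["data_par", "data_par"], reads={"B"}, writes={"C"})
# A reduction-style update: reads and writes "S" and has a reduce iterator.
r = ToyBlock("R", ["data_par", "reduce"], reads={"C", "S"}, writes={"S"})
scope = [b, c, r]

complete_flags = {blk.name: is_complete_block(blk, scope) for blk in scope}
```

In this model blocks `B` and `C` from the docstring examples count as complete, while the reduction-style block `R` does not, so only the former would pass a check like `CheckCompleteBlock`.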