Skip to content

Commit

Permalink
[TensorIR][PASS] CompactBufferAllocation (apache#7923)
Browse files Browse the repository at this point in the history
Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Cody Yu <comaniac0422@gmail.com>
  • Loading branch information
4 people authored and Trevor Morris committed May 6, 2021
1 parent 54efdc6 commit 93634ed
Show file tree
Hide file tree
Showing 10 changed files with 1,115 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};

/*!
Expand Down
12 changes: 11 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ class BufferStore : public Stmt {
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};

/*!
Expand Down Expand Up @@ -991,13 +992,22 @@ class BufferRegion : public ObjectRef {
TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);

/*!
* \brief Create a BufferRegion which is full region of the given buffer..
* \brief Create a BufferRegion which is full region of the given buffer.
* \param buffer The buffer to generate full BufferRegion.
* \return The BufferRegion which covers all region of the given buffer
*/
TVM_DLL static BufferRegion FullRegion(Buffer buffer);

/*!
* \brief Create a BufferRegion which is a single point of the given buffer.
* \param buffer The buffer to generate single point BufferRegion.
* \param indices The access point indices of the buffer
* \return The BufferRegion which is the single point of the given buffer.
*/
TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);

TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode);
};

/*!
Expand Down
46 changes: 46 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,52 @@ TVM_DLL Pass LowerInitBlock();
*/
TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();

/*!
* \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the
* corresponding iter_values in BlockRealize, for opaque blocks by removing all
*. the iter_values in BlockRealize and iter_vars in Block.
* \return The pass.
*/
TVM_DLL Pass ConvertBlocksToOpaque();

/*!
* \brief Compact the buffer access region by removing the buffer regions that are not accessed,
* i.e. narrowing the buffer shape and adjust the access region if necessary.
* \example
* Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed.
* \code
*
* for i in range(0, 16):
* with tir.block([]):
* B = tir.alloc_buffer(16, 16)
* for j in range(0, 16):
* B[i, j] = A[i, j] + 1
* for j in range(0, 16):
* C[i, j] = B[i, j] + 1
*
* \endcode
*
* This pass narrows the buffer shape and adjust its accessed region accordingly.
* In this particular case, because only a `1 * 16` vector of `B` is accessed,
* the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`.
*
* \code
*
* for i in range(0, 16):
* with tir.block([]):
* B = tir.alloc_buffer(1, 16)
* for j in range(0, 16):
* B[0, j] = A[i, j] + 1
* for j in range(0, 16):
* C[i, j] = B[0, j] + 1
*
* \endcode
*
*
* \return The pass.
*/
TVM_DLL Pass CompactBufferAllocation();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
50 changes: 50 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,53 @@ def PlanAndUpdateBufferAllocationLocation():
The result pass
"""
return _ffi_api.PlanAndUpdateBufferAllocationLocation()


def ConvertBlocksToOpaque():
"""Substitute all the block vars with the PrimExprs they are bound to, indicated by
the corresponding iter_values in BlockRealize, and then convert the blocks into
opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.ConvertBlocksToOpaque()


def CompactBufferAllocation():
"""Compact the buffer access region. by removing the buffer regions that are not accessed,
i.e. narrowing the buffer shape and adjust the access region if necessary.
Example
-------
Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed.
.. code-block:: python
for i in range(0, 16):
with tir.block([]):
B = tir.alloc_buffer(16, 16)
for j in range(0, 16):
B[i, j] = A[i, j] + 1
for j in range(0, 16):
C[i, j] = B[i, j] + 1
This pass narrows the buffer shape and adjust its accessed region accordingly.
In this particular case, because only a `1 * 16` vector of `B` is accessed,
the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`.
.. code-block:: python
for i in range(0, 16):
with tir.block([]):
B = tir.alloc_buffer(1, 16)
for j in range(0, 16):
B[0, j] = A[i, j] + 1
for j in range(0, 16):
C[i, j] = B[0, j] + 1
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.CompactBufferAllocation()
19 changes: 19 additions & 0 deletions src/support/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#include <sys/wait.h>
#endif // __hexagon__
#endif // _WIN32

#include <tvm/runtime/container.h>

#include <algorithm>
#include <array>
#include <cctype>
Expand Down Expand Up @@ -128,6 +131,22 @@ inline std::vector<std::string> Split(const std::string& str, char delim) {
return ret;
}

/*!
* \brief Check whether the string starts with a given prefix.
* \param str The given string.
* \param prefix The given prefix.
* \return Whether the prefix matched.
*/
inline bool StartsWith(const String& str, const char* prefix) {
size_t n = str.length();
for (size_t i = 0; i < n; i++) {
if (prefix[i] == '\0') return true;
if (str.data()[i] != prefix[i]) return false;
}
// return true if the str is equal to the prefix
return prefix[n + 1] == '\0';
}

/*!
* \brief EndsWith check whether the strings ends with
* \param value The full string
Expand Down
8 changes: 8 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,14 @@ BufferRegion BufferRegion::FullRegion(Buffer buffer) {
return BufferRegion(buffer, region);
}

BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
Array<Range> region;
for (const PrimExpr& index : indices) {
region.push_back(Range::FromMinExtent(index, 1));
}
return BufferRegion(buffer, region);
}

TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array<Range> region) {
return BufferRegion(buffer, region);
});
Expand Down
Loading

0 comments on commit 93634ed

Please sign in to comment.