Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][PASS][M1c] CompactBufferAllocation #7923

Merged
merged 1 commit into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};

/*!
Expand Down
12 changes: 11 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ class BufferStore : public Stmt {
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};

/*!
Expand Down Expand Up @@ -991,13 +992,22 @@ class BufferRegion : public ObjectRef {
TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);

/*!
* \brief Create a BufferRegion which is full region of the given buffer..
* \brief Create a BufferRegion which is full region of the given buffer.
* \param buffer The buffer to generate full BufferRegion.
* \return The BufferRegion which covers all region of the given buffer
*/
TVM_DLL static BufferRegion FullRegion(Buffer buffer);

/*!
* \brief Create a BufferRegion which is a single point of the given buffer.
* \param buffer The buffer to generate single point BufferRegion.
* \param indices The access point indices of the buffer
* \return The BufferRegion which is the single point of the given buffer.
*/
TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
tqchen marked this conversation as resolved.
Show resolved Hide resolved

TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode);
};

/*!
Expand Down
46 changes: 46 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,52 @@ TVM_DLL Pass LowerInitBlock();
*/
TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();

/*!
* \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the
* corresponding iter_values in BlockRealize, for opaque blocks by removing all
*. the iter_values in BlockRealize and iter_vars in Block.
* \return The pass.
*/
TVM_DLL Pass ConvertBlocksToOpaque();

/*!
* \brief Compact the buffer access region by removing the buffer regions that are not accessed,
* i.e. narrowing the buffer shape and adjust the access region if necessary.
* \example
* Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed.
* \code
*
* for i in range(0, 16):
* with tir.block([]):
* B = tir.alloc_buffer(16, 16)
* for j in range(0, 16):
* B[i, j] = A[i, j] + 1
* for j in range(0, 16):
* C[i, j] = B[i, j] + 1
*
* \endcode
*
* This pass narrows the buffer shape and adjust its accessed region accordingly.
* In this particular case, because only a `1 * 16` vector of `B` is accessed,
* the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`.
*
* \code
*
* for i in range(0, 16):
* with tir.block([]):
* B = tir.alloc_buffer(1, 16)
* for j in range(0, 16):
* B[0, j] = A[i, j] + 1
* for j in range(0, 16):
* C[i, j] = B[0, j] + 1
*
* \endcode
*
*
* \return The pass.
*/
TVM_DLL Pass CompactBufferAllocation();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
50 changes: 50 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,53 @@ def PlanAndUpdateBufferAllocationLocation():
The result pass
"""
return _ffi_api.PlanAndUpdateBufferAllocationLocation()


def ConvertBlocksToOpaque():
"""Substitute all the block vars with the PrimExprs they are bound to, indicated by
the corresponding iter_values in BlockRealize, and then convert the blocks into
opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.ConvertBlocksToOpaque()


def CompactBufferAllocation():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So,is this pass doing the same job as InferBound?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the functionality should be similar, except that infer bound needs to walk through the schedule tree, while in this case the split/reorder already updated the index, so the impl should be simpler

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaination!

"""Compact the buffer access region. by removing the buffer regions that are not accessed,
i.e. narrowing the buffer shape and adjust the access region if necessary.

Example
-------
Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed.
.. code-block:: python

for i in range(0, 16):
with tir.block([]):
B = tir.alloc_buffer(16, 16)
for j in range(0, 16):
B[i, j] = A[i, j] + 1
for j in range(0, 16):
C[i, j] = B[i, j] + 1
This pass narrows the buffer shape and adjust its accessed region accordingly.
In this particular case, because only a `1 * 16` vector of `B` is accessed,
the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`.
.. code-block:: python

for i in range(0, 16):
with tir.block([]):
B = tir.alloc_buffer(1, 16)
for j in range(0, 16):
B[0, j] = A[i, j] + 1
for j in range(0, 16):
C[i, j] = B[0, j] + 1

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.CompactBufferAllocation()
19 changes: 19 additions & 0 deletions src/support/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#include <sys/wait.h>
#endif // __hexagon__
#endif // _WIN32

#include <tvm/runtime/container.h>

#include <algorithm>
#include <array>
#include <cctype>
Expand Down Expand Up @@ -128,6 +131,22 @@ inline std::vector<std::string> Split(const std::string& str, char delim) {
return ret;
}

/*!
* \brief Check whether the string starts with a given prefix.
* \param str The given string.
* \param prefix The given prefix.
* \return Whether the prefix matched.
*/
inline bool StartsWith(const String& str, const char* prefix) {
size_t n = str.length();
for (size_t i = 0; i < n; i++) {
if (prefix[i] == '\0') return true;
if (str.data()[i] != prefix[i]) return false;
}
// return true if the str is equal to the prefix
return prefix[n + 1] == '\0';
}

/*!
* \brief EndsWith check whether the strings ends with
* \param value The full string
Expand Down
8 changes: 8 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,14 @@ BufferRegion BufferRegion::FullRegion(Buffer buffer) {
return BufferRegion(buffer, region);
}

BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
Array<Range> region;
for (const PrimExpr& index : indices) {
region.push_back(Range::FromMinExtent(index, 1));
}
return BufferRegion(buffer, region);
}

TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array<Range> region) {
return BufferRegion(buffer, region);
});
Expand Down
Loading