[TIR] Implement API for padded layout transformations #12720

Merged (27 commits) on Sep 19, 2022
Changes from 17 commits
Commits
93559cf
[UnitTests] Initial unit tests for padded transformation behavior
Lunderberg Aug 25, 2022
60ea527
[Utils][Fix] Correction for non-empty Callable type annotations
Lunderberg Aug 31, 2022
2971b5b
[TIR] Pass the pad_value argument from Python to C++
Lunderberg Aug 25, 2022
885fd78
[TIR] Added check to validate lack of transformation padding
Lunderberg Aug 25, 2022
185eead
Raise error if pad value doesn't match buffer's data type.
Lunderberg Aug 26, 2022
874bfc2
Simplify expressions in IndexMap::NonsurjectiveInverse
Lunderberg Sep 6, 2022
ddea093
Preparatory refactor, update BlockNode::alloc_buffers while visiting
Lunderberg Sep 6, 2022
2055bbf
Introduced LayoutTransformPlanner for planning how to pad
Lunderberg Sep 6, 2022
f3538cd
Implemented insertion of T.assume for input buffers
Lunderberg Sep 6, 2022
a6dbd30
Implement epilogue plan for explicitly setting pad value
Lunderberg Sep 6, 2022
619c5b7
Check LetStmt bindings when determining loop dependencies
Lunderberg Sep 6, 2022
aa9bbf7
Implement replacement plan for using tir::if_then_else
Lunderberg Sep 6, 2022
7f5707c
Removed development-only unit test
Lunderberg Sep 6, 2022
5a1e63f
Add default value of NullOpt for pad_value
Lunderberg Sep 7, 2022
c463043
Removed debug code
Lunderberg Sep 7, 2022
98a8446
Update unit tests to use non-opaque blocks
Lunderberg Sep 7, 2022
19af1ee
Resolve linting error
Lunderberg Sep 7, 2022
8eb775a
Removed iter_var usage from T.where clauses
Lunderberg Sep 7, 2022
e874020
Improve docstring on pad_value
Lunderberg Sep 7, 2022
6386db5
Documentation for TransformLayoutPlanner, rename for consistency
Lunderberg Sep 7, 2022
59a0acf
Updated C++ API to take IndexMap as input
Lunderberg Sep 12, 2022
d532610
Update expected trace in multi-level tiling tests
Lunderberg Sep 12, 2022
efb25ac
Update shared_32x16_to_ldmatrix_32x16_layout to be injective
Lunderberg Sep 13, 2022
13b8cef
Updated IndexMap docstrings for single PrimExpr returns
Lunderberg Sep 15, 2022
19a78e8
Updated docstring for valid pad values, validate
Lunderberg Sep 16, 2022
d801dab
Merge branch 'main' into padded_layout_api
Lunderberg Sep 16, 2022
6a4f4cc
Fix lint error
Lunderberg Sep 16, 2022
4 changes: 3 additions & 1 deletion include/tvm/tir/schedule/schedule.h
@@ -601,9 +601,11 @@ class ScheduleNode : public runtime::Object {
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_map The transformation to apply.
+ * \param pad_value The value to write into padding introduced by the transformation.
Lunderberg marked this conversation as resolved.
*/
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
- BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;
+ BufferIndexType buffer_index_type, const IndexMap& index_map,
+ const Optional<PrimExpr>& pad_value = NullOpt) = 0;

/*!
* \brief Apply a transformation represented by IndexMap to block
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/_type_checker.py
@@ -164,7 +164,7 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
return "atomic", [type_]


- def callable_str(subtypes):
+ def callable_str(*subtypes):
if subtypes:
*arg_types, return_type = subtypes
arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types)
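The one-character change above matters if the caller unpacks the subtype list into positional arguments, which is an assumption here since the dispatcher call site is not part of this hunk. A minimal sketch of the difference follows; `_type2str` is a simplified stand-in for the real helper, and the returned string format is illustrative only because the rest of the function is elided in the diff.

```python
from typing import Any


def _type2str(type_: Any) -> str:
    # Hypothetical stand-in for the helper of the same name in _type_checker.py.
    return getattr(type_, "__name__", str(type_))


def callable_str_old(subtypes):
    # Pre-fix signature: all subtypes must arrive packed into one argument.
    if subtypes:
        *arg_types, return_type = subtypes
        arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types)
        return f"Callable[[{arg_str}], {_type2str(return_type)}]"  # assumed format
    return "Callable"


def callable_str_new(*subtypes):
    # Post-fix signature: subtypes arrive as separate positional arguments.
    if subtypes:
        *arg_types, return_type = subtypes
        arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types)
        return f"Callable[[{arg_str}], {_type2str(return_type)}]"  # assumed format
    return "Callable"


subtypes = [int, float, str]
print(callable_str_new(*subtypes))  # Callable[[int, float], str]
try:
    callable_str_old(*subtypes)
except TypeError as err:
    # The old signature accepts exactly one positional argument.
    print("old signature fails:", err)
```

With the variadic signature, an empty `Callable` annotation (no subtypes) and a non-empty one both go through the same call path.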
17 changes: 16 additions & 1 deletion python/tvm/tir/schedule/schedule.py
@@ -2443,6 +2443,7 @@ def transform_layout(
block: Union[BlockRV, str],
buffer: Union[Tuple[str, int], str, Buffer],
index_map: Union[IndexMap, Callable],
+ pad_value: Optional[Union[int, float, IndexMap, Callable]] = None,
) -> None:
"""Apply a transformation represented by IndexMap to buffer

@@ -2479,6 +2480,20 @@ def transform_layout(
primitive will be called in addition to the
TransformLayout primitive.

+ pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]
Member
Please document the assumption that applies when pad_value is an IndexMap. As I recall, the RFC assumes it contains no BufferLoad from any buffer other than the current one.

Contributor Author
Thank you, the docstring has been updated. I've also added two unit tests: one validates that an error is raised when the pad value loads from a different buffer, and one specifies the intended behavior for a pad value that loads from the transformed buffer. The latter is currently marked with pytest.mark.xfail, as the support isn't implemented yet.


+ The value to be used for any padding introduced by the
+ transformation.
+
+ If None, the transformation may not introduce padding.
+
+ If an int, float or PrimExpr, pad_value is the
+ specific value to be present in the padding.
+
+ If an IndexMap or Callable, pad_value gives the
+ value to be present in the padding in terms of the
+ transformed index.
Comment on lines +2509 to +2511
Member
The C++ side only accepts Optional<PrimExpr>, so it seems this is not supported?

Contributor Author
Good point. I had been thinking of it as the (const Array<Var>&, const Array<PrimExpr>&) call signature on the TE side for the transformation, and was avoiding introducing additional structures. I had forgotten that the TIR schedule accepts an IndexMap for the transformation, and agree that the C++ side would be better expressed as an Optional<IndexMap> instead.

Contributor Author
Updates made to pass Optional<IndexMap> pad_value throughout C++ API, mimicking how IndexMap index_map is passed, along with a unit test to validate the functionality.
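To make the three documented cases for pad_value concrete, here is a minimal pure-Python sketch of the semantics. It does not call TVM; the function name `transform_with_padding` and the fixed index map `i -> (i // factor, i % factor)` are assumptions chosen for illustration. A constant pad value fills every padded slot, a callable computes the fill from the transformed indices, and `None` is rejected whenever the map would introduce padding.

```python
from typing import Callable, List, Optional, Union

Number = Union[int, float]
PadValue = Optional[Union[Number, Callable[[int, int], Number]]]


def transform_with_padding(
    data: List[Number], factor: int, pad_value: PadValue = None
) -> List[List[Number]]:
    """Apply i -> (i // factor, i % factor), filling any padded slots."""
    n = len(data)
    rows = -(-n // factor)  # ceiling division
    if rows * factor != n and pad_value is None:
        # Mirrors "If None, the transformation may not introduce padding."
        raise ValueError("transformation introduces padding but pad_value is None")
    out = []
    for io in range(rows):
        row = []
        for ii in range(factor):
            i = io * factor + ii  # inverse of the index map
            if i < n:
                row.append(data[i])  # slot backed by original data
            elif callable(pad_value):
                row.append(pad_value(io, ii))  # value in terms of transformed index
            else:
                row.append(pad_value)  # constant pad value
        out.append(row)
    return out


data = list(range(14))
print(transform_with_padding(data, 4, pad_value=0)[-1])  # [12, 13, 0, 0]
print(transform_with_padding(data, 4, pad_value=lambda io, ii: -(4 * io + ii))[-1])
# [12, 13, -14, -15]
```

The check `i < n` plays the role of the negated padding predicate: a slot is padding exactly when the inverse map lands outside the original extent.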


Examples
--------
Before transform_layout, in TensorIR, the IR is:
@@ -2538,7 +2553,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->

buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member
- self, block, buffer_index, buffer_index_type_enum, index_map
+ self, block, buffer_index, buffer_index_type_enum, index_map, pad_value
)
if axis_separators:
_ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member
3 changes: 2 additions & 1 deletion src/meta_schedule/postproc/rewrite_layout.cc
@@ -148,7 +148,8 @@ bool RewriteLayout(const Schedule& sch) {
// Apply schedule
BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
- sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value());
+ sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(),
+ NullOpt);
sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
}
}
@@ -492,7 +492,7 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
- state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map);
+ state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt);
};

for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
2 changes: 1 addition & 1 deletion src/tir/ir/index_map.cc
@@ -90,7 +90,7 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
// Unpack the map to an array, maintaining the same parameter order.
Array<PrimExpr> inverse_exprs;
for (const auto& index : (*this)->initial_indices) {
- inverse_exprs.push_back(inverse_exprs_map.at(index));
+ inverse_exprs.push_back(analyzer.Simplify(inverse_exprs_map.at(index)));
}

PrimExpr padding_predicate = padded_iter_map->padding_predicate;
6 changes: 4 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
@@ -761,9 +761,11 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann
/******** Schedule: Layout transformation ********/
void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
- const IndexMap& index_map) {
+ const IndexMap& index_map,
+ const Optional<PrimExpr>& pad_value) {
TVM_TIR_SCHEDULE_BEGIN();
- tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map);
+ tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map,
+ pad_value);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}
2 changes: 1 addition & 1 deletion src/tir/schedule/concrete_schedule.h
@@ -143,7 +143,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
- const IndexMap& index_map) override;
+ const IndexMap& index_map, const Optional<PrimExpr>& pad_value) override;
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
4 changes: 3 additions & 1 deletion src/tir/schedule/instruction_traits.h
@@ -430,7 +430,9 @@ TVM_ALWAYS_INLINE Array<ObjectRef> UnpackedInstTraits<TTraits>::_ConvertOutputs(
/********** PythonAPICall **********/

inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os) {
- if (const auto* str = obj.as<runtime::StringObj>()) {
+ if (!obj.defined()) {
+ os << "None";
+ } else if (const auto* str = obj.as<runtime::StringObj>()) {
os << str->data;
} else if (const auto* int_imm = obj.as<IntImmNode>()) {
os << int_imm->value;
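The added branch handles an undefined object before any type-based dispatch, so an omitted optional argument such as pad_value can be printed as `None`. A rough Python analogue of that ordering is sketched below; `as_python_string` is a hypothetical name, and only the three cases visible in this hunk are modeled.

```python
def as_python_string(obj) -> str:
    """Render one argument as Python source text (illustrative only)."""
    if obj is None:
        # Mirrors the added `!obj.defined()` branch: an absent optional prints as None.
        return "None"
    if isinstance(obj, str):
        # Strings are emitted as-is, like `os << str->data` in the hunk.
        return obj
    if isinstance(obj, int):
        return str(obj)
    # Remaining cases are elided in the diff, so this sketch simply rejects them.
    raise TypeError(f"unsupported argument type: {type(obj).__name__}")


args = ["block", 0, None]
print(", ".join(as_python_string(a) for a in args))  # block, 0, None
```

Checking for the absent value first keeps the type-specific branches free of null handling.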
4 changes: 3 additions & 1 deletion src/tir/schedule/primitive.h
@@ -474,9 +474,11 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String&
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_map The transformation to apply.
+ * \param pad_value The value to write into padding introduced by the transformation.
*/
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
- BufferIndexType buffer_index_type, const IndexMap& index_map);
+ BufferIndexType buffer_index_type, const IndexMap& index_map,
+ const Optional<PrimExpr>& pad_value);

/*!
* \brief Apply a transformation represented by IndexMap to block