[MetaSchedule] Tile and pack intermediate output for CUDA TensorCore (#14108)

* [MetaSchedule] Tile and pack intermediate output for CUDA TensorCore

* clean up schedule rule mltc

* add lhs analyzer

* prevent simplifying single point

* clean up

* lint

* fix rewrite_tensorize test

* fix software pipeline test

* fix compile on mac

* fix test cases

* remove unused

* rebase

* only use json format for roundtrip

* lint

* Update src/tir/schedule/ir_comparator.h

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

---------

Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
3 people authored Mar 6, 2023
1 parent a15ade3 commit 424c749
Showing 11 changed files with 567 additions and 473 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/space_generation.py
@@ -88,7 +88,7 @@ def _find_match_sketch_id(
decisions=new_decisions,
).apply_to_schedule(sch, remove_postproc=True)
if structural_equal(sch.mod, expected_mod):
verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, text_format="json")
return sketch_id
return None

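For reference, a minimal sketch of what this round-trip check exercises; the toy workload and the import path of verify_trace_roundtrip are assumptions for illustration only.

import tvm
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip  # assumed import path


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


sch = tir.Schedule(matmul)
i, j, k = sch.get_loops(sch.get_block("C"))
sch.split(i, factors=[None, 32])
# Serialize the trace to JSON, replay it on a fresh schedule, and check that the
# replayed schedule matches the original one.
verify_trace_roundtrip(sch=sch, mod=matmul, debug_mask="all", text_format="json")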
1 change: 1 addition & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
@@ -162,6 +162,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
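A rough Python-level counterpart of the pass list above (a sketch only; the postprocessor builds the list in C++, and `mod` is an assumed IRModule of a scheduled kernel).

import tvm
from tvm import tir

verify_passes = tvm.transform.Sequential(
    [
        tir.transform.PlanAndUpdateBufferAllocationLocation(),
        tir.transform.ConvertBlocksToOpaque(),
        tir.transform.UnifyThreadBinding(),
        # Newly added: materialize the shared-memory local stage so that the
        # verifier sees the same buffers as the final lowering pipeline.
        tir.transform.ManifestSharedMemoryLocalStage(),
        tir.transform.CompactBufferAllocation(),
        tir.transform.LowerMatchBuffer(),
        tir.transform.InjectSoftwarePipeline(),
    ]
)
# lowered = verify_passes(mod)  # then check the lowered module against the GPU constraints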
13 changes: 9 additions & 4 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -186,15 +186,15 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
return results;
}

Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop,
int n_tiles) const {
std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoop(
const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const {
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
/*max_innermost_factor=*/max_innermost_factor);
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
/*factors=*/{factors.begin(), factors.end()});
return splits;
return {factors, splits};
}

std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
@@ -207,6 +207,9 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
state->tile_factors.resize(tiles.size());
std::vector<Array<tir::ExprRV>> tile_factors;
tile_factors.resize(tiles.size());
for (int i = 0, n = loops.size(); i < n; ++i) {
LoopRV loop = loops[i];
const std::vector<int>* idx = nullptr;
@@ -231,14 +234,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
if (n_tiles == 1) {
tiles[idx->at(0)].push_back(loop);
} else {
auto splits = SplitLoop(sch, block_rv, loop, n_tiles);
auto [factors, splits] = SplitLoop(sch, block_rv, loop, n_tiles);

// Put every tile to its slot
for (int j = 0; j < n_tiles; ++j) {
tiles[idx->at(j)].push_back(splits[j]);
tile_factors[idx->at(j)].push_back(factors[j]);
}
}
}
state->tile_factors = std::move(tile_factors);
// Step 3. Reorder to organize the tiles
sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
// Step 4. Bind the tiles to threads
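The relationship that the new tile_factors bookkeeping records can be illustrated with the Python schedule API (a sketch; the one-dimensional workload is an assumption).

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (1024,), "float32")
    B = T.match_buffer(b, (1024,), "float32")
    for i in T.serial(1024):
        with T.block("B"):
            vi = T.axis.spatial(1024, i)
            B[vi] = A[vi] + T.float32(1)


sch = tir.Schedule(add_one)
(i,) = sch.get_loops(sch.get_block("B"))
# SamplePerfectTile returns one factor per tile level; Split returns the loops
# created from those factors, so factors[j] is the extent sampled for splits[j].
factors = sch.sample_perfect_tile(i, n=4, max_innermost_factor=64)
splits = sch.split(i, factors=factors)
# SplitLoop now returns both lists, and TileLoopNest stores the factors in
# state->tile_factors alongside the loops kept in state->tiles.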
8 changes: 6 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -94,6 +94,8 @@ class StateNode : public Object {
tir::BlockRV block_rv;
/*! \brief The loop tiles */
Array<Array<tir::LoopRV>> tiles;
/*! \brief The factors of the loop tiles. */
Array<Array<tir::ExprRV>> tile_factors;
/*! \brief The mapping from buffer index to read cache block. */
std::unordered_map<int, tir::BlockRV> read_reuse;
/*! \brief The mapping from buffer index to write cache block. */
@@ -163,8 +165,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);

virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV block,
tir::LoopRV loop, int n_tiles) const;
virtual std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> SplitLoop(const tir::Schedule& sch,
tir::BlockRV block,
tir::LoopRV loop,
int n_tiles) const;

// Annotate a block to use cooperative fetching
void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;
176 changes: 163 additions & 13 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/tir/op.h>

#include <algorithm>
#include <utility>
@@ -124,6 +125,9 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
private:
// SubRule: Add tensorization-related transformations
inline std::vector<State> TransformForTensorization(TensorCoreState state) const;
// Subrule: Transform the layout of the output. This is necessary for an efficient cache write
// of the output to shared memory.
std::vector<State> TransformIntermediateOutputLayout(TensorCoreState state);
// Subrule: Add tensorized load
inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) const;
// Subrule: Add tensorized store
@@ -225,6 +229,9 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
return TransformForTensorization(Downcast<TensorCoreState>(state));
});
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
states = SubRule(std::move(states), [&](State state) {
return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
});
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
states = SubRule(std::move(states), [&](State state) {
return AddWriteReuseTensorCore(Downcast<TensorCoreState>(state));
@@ -248,25 +255,162 @@ void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch,
(*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name);
}

std::vector<State> MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout(
TensorCoreState state) {
// Transform the intermediate output to a packed layout
// [..., warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n]
// where warp_m, warp_n are thread indices bound to the warp id, accum_frag_m, accum_frag_n are
// the indices of the fragments within each warp, and accum_elem_m, accum_elem_n are the indices
// of the elements within each accumulator fragment.

// Get the shape of the wmma accumulator
auto [frag_shape_m, frag_shape_n] = [&]() {
tir::Block intrin_block =
Downcast<tir::BlockRealize>(
tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body)
->block;
tir::For loop_m = Downcast<tir::For>(intrin_block->body);
tir::For loop_n = Downcast<tir::For>(loop_m->body);
return std::make_tuple(loop_m->extent, loop_n->extent);
}();

// Get the tile index of the warp id (i.e. threadIdx.y)
auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y");
ICHECK(it != tile_binds.end());
auto tile_index_warp_id = std::distance(tile_binds.begin(), it);

// Get the extent of the loop indicated by `loop_idx` inside the warp scope.
// For example, after spatial loops i, j are tiled, we will have
// tile_factors = ((i0, j0), (i1, j1), ..., (in, jn))
// This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id.
// `loop_idx` can be negative, in which case it is counted from the end.
auto f_get_inner_tile_product = [&](int loop_idx) {
Array<tir::ExprRV> factors;
for (int i = tile_index_warp_id + 1; i < static_cast<int>(s_indices_.size()); ++i) {
auto s_factors = state->tile_factors[s_indices_[i]];
if (loop_idx < 0) {
loop_idx += s_factors.size();
}
factors.push_back(s_factors[loop_idx]);
}
ICHECK(!factors.empty());
if (factors.size() == 1) {
return factors[0];
}
auto result = factors[0];
for (int i = 1; i < static_cast<int>(factors.size()); ++i) {
result = result * factors[i];
}
return result;
};

// Compute the number of output fragments of each warp
auto warp_num_frag_m = f_get_inner_tile_product(-2);
auto warp_num_frag_n = f_get_inner_tile_product(-1);

Schedule& sch = state->sch;
int buffer_ndim = static_cast<int>(sch->Get(state->block_rv)->writes[0]->buffer->shape.size());
// The dimension of the buffer should be greater than or equal to that of the tensor intrin.
ICHECK_GE(buffer_ndim, 2);
int num_higher_dims = buffer_ndim - 2;

auto index_map =
tir::IndexMap::FromFunc(buffer_ndim,
// frag_shape_m and frag_shape_n are structured bindings, which cannot
// be automatically captured until C++20
[&, frag_shape_m = frag_shape_m,
frag_shape_n = frag_shape_n](const Array<tir::Var>& indices) {
Array<PrimExpr> result;
result.reserve(indices.size() + 4);
for (int i = 0; i < num_higher_dims; ++i) {
result.push_back(indices[i]);
}
const auto& m = indices[num_higher_dims];
const auto& n = indices[num_higher_dims + 1];
auto accum_m = floormod(m, frag_shape_m);
auto accum_n = floormod(n, frag_shape_n);
auto outer_m = floordiv(m, frag_shape_m);
auto outer_n = floordiv(n, frag_shape_n);

result.push_back(floordiv(outer_m, warp_num_frag_m));
result.push_back(floordiv(outer_n, warp_num_frag_n));
result.push_back(floormod(outer_m, warp_num_frag_m));
result.push_back(floormod(outer_n, warp_num_frag_n));
result.push_back(accum_m);
result.push_back(accum_n);
return result;
});
sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map,
/*pad_value=*/NullOpt, /*assume_injective_transform=*/true);

return {state};
}
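In Python terms, the index map built above corresponds to the following sketch; the concrete numbers (16x16 wmma accumulator fragments, two fragments per warp along each axis, a plain 2D output) are assumptions for illustration.

from tvm import tir

FRAG_M = FRAG_N = 16           # assumed wmma accumulator fragment shape
WARP_FRAG_M = WARP_FRAG_N = 2  # assumed number of fragments per warp along m/n

index_map = tir.IndexMap.from_func(
    lambda m, n: [
        (m // FRAG_M) // WARP_FRAG_M,  # warp_m
        (n // FRAG_N) // WARP_FRAG_N,  # warp_n
        (m // FRAG_M) % WARP_FRAG_M,   # accum_frag_m
        (n // FRAG_N) % WARP_FRAG_N,   # accum_frag_n
        m % FRAG_M,                    # accum_elem_m
        n % FRAG_N,                    # accum_elem_n
    ]
)
# A (64, 64) intermediate output becomes a (2, 2, 2, 2, 16, 16) packed buffer;
# sch.transform_layout(block, ("write", 0), index_map,
#                      assume_injective_transform=True) applies the same
# transformation the C++ rule performs above.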

std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
TensorCoreState state) const {
// Add the cache write stage for Tensor Core
int level = r_indices_.front() - 1;
const LoopRV& loop = state->tiles[level].back();
Schedule& sch = state->sch;
auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator");
sch->ReverseComputeAt(cache_write, loop, true);

if (state->write_reuse.count(0)) {
// Fuse the iterators of the cache_write
Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
ICHECK_GT(buffer_loops.size(), 2);
sch->Fuse(Array<LoopRV>{buffer_loops.end() - 2, // The src shmem is always 2D
buffer_loops.end()});
AnnotateCooperativeFetching(&sch, state->write_reuse[0]);

// The compute block has been tiled by the warp shape and the fragment shape.
// We need to bind the cache write block (from the accumulator to the shared memory) to the warp
// id. The schedule is as follows:
//
// After adding cache write for wmma.accumulator, we will have
// for i0, j0, i1, j1, accum_m, accum_n:
// shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n]
// for i0', j0', i1', j1', accum_m', accum_n':
// global_mem[i0', j0', i1', j1', accum_m', accum_n'] =
// shared_mem[i0', j0', i1', j1', accum_m', accum_n']
// where i0' and j0' are already bound to the block id and warp id.
//
// To reduce the shared memory usage and allow efficient data movement, we will apply
// transformations to generate the following schedule:
//
// for i1':
// for i0_j0 (fused and bound to threadIdx.y):
// for j1, accum_m, accum_n:
// shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n]
// for i0', j0', j1', accum_m', accum_n':
// global_mem[i0', j0', i1', j1', accum_m', accum_n'] =
// shared_mem[i0', j0', i1', j1', accum_m', accum_n']
//
// i1' is reordered to the outermost position. This effectively allows only one row (i.e. loop
// i1') of fragments to be moved to shared memory and then to global memory at a time.
// As a result, the shared memory for the output only needs a shape of [j1, accum_m, accum_n]
// instead of [i0 * i1 * accum_m, j0 * j1 * accum_n].

// Get the loops other than the innermost two loops (accum_m and accum_n).
auto f_get_loops = [&](const BlockRV& block_rv) -> std::array<LoopRV, 4> {
Array<LoopRV> buffer_loops = sch->GetLoops(block_rv);
ICHECK_GT(buffer_loops.size(), 6);
return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5],
buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]};
};
{
const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]);
sch->Reorder({i1, i0, j0, j1});
sch->ComputeAt(cache_write, i1, true);
}
{
auto loops = f_get_loops(cache_write);
const auto& i0 = loops[0];
const auto& j0 = loops[1];
auto fused = sch->Fuse({i0, j0});
sch->Bind(fused, "threadIdx.y");
}

sch->ReverseComputeInline(state->tensor_core_reindex_store);
TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin);
auto loops = sch->GetLoops(cache_write);
auto blockized_store = sch->Blockize(loops[loops.size() - 2]);
sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize,
state->intrin_group.store_intrin);

Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
ICHECK_GT(buffer_loops.size(), 5);
sch->Fuse(Array<LoopRV>{buffer_loops.end() - 5, // The src shmem is always 2D
buffer_loops.end()});
AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
return {state};
}
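The loop surgery described in the comment above, written with the Python schedule API (a sketch; `sch`, `cache_write`, and `shared_store` are assumed handles matching the blocks manipulated by the rule).

def schedule_accumulator_store(sch, cache_write, shared_store):
    # The last six loops of the shared->global store block are
    # [i0', j0', i1', j1', accum_m', accum_n'].
    i0, j0, i1, j1, _, _ = sch.get_loops(shared_store)[-6:]
    # Move i1' outermost so only one row of fragments is staged in shared
    # memory at a time.
    sch.reorder(i1, i0, j0, j1)
    sch.compute_at(cache_write, i1, preserve_unit_loops=True)
    # Fuse the warp axes of the accumulator->shared copy and bind the fused
    # loop to the warp id.
    i0_w, j0_w = sch.get_loops(cache_write)[-6:-4]
    fused = sch.fuse(i0_w, j0_w)
    sch.bind(fused, "threadIdx.y")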

@@ -508,7 +652,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
buffer_sub_index_map.Set(lhs_buffer, sub_index_map);
state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt);
state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map,
/*pad_value=*/NullOpt, /*assume_injective_transform=*/true);
};

for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
@@ -569,6 +714,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);

CHECK(node->reuse_write_.req == ReuseType::kMustReuse &&
runtime::StorageScope::Create(node->reuse_write_.scope).rank ==
runtime::StorageRank::kShared)
<< "ValueError: Shared memory write reuse must be enabled for MultiLevelTilingTensorCore.";

node->intrin_groups.reserve(intrin_groups.size());
for (const auto& intrin_group_config : intrin_groups) {
node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
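From Python, the rule is typically constructed as sketched below; the intrin names and the ReuseType helper are assumptions based on the meta_schedule Python API, and the new CHECK above requires reuse_write to be a "must" reuse in shared memory.

from tvm import meta_schedule as ms

rule = ms.schedule_rule.MultiLevelTilingTensorCore(
    intrin_groups=[
        {
            "init": "wmma_fill_16x16x16_f32",
            "load_a": "wmma_load_16x16x16_f16_a",
            "load_b": "wmma_load_16x16x16_f16_b",
            "compute": "wmma_sync_16x16x16_f16f16f32",
            "store": "wmma_store_16x16x16_f32_shared",
        }
    ],
    structure="SSSRRSRS",
    tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
    max_innermost_factor=4,
    vector_load_lens=[1, 2, 3, 4, 8, 16],
    reuse_read=ms.schedule_rule.ReuseType(req="must", levels=[4], scope="shared"),
    # Write reuse through shared memory is now mandatory for this rule.
    reuse_write=ms.schedule_rule.ReuseType(req="must", levels=[2], scope="shared"),
    use_software_pipeline=False,
)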
src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
@@ -48,11 +48,12 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode {
return ScheduleRule(n);
}

Array<tir::LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const;
std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> SplitLoop(const Schedule& sch, BlockRV block,
LoopRV loop, int n_tiles) const;
};

Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv,
LoopRV loop_rv, int n_tiles) const {
std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingWideVectorNode::SplitLoop(
const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const {
const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
@@ -99,12 +100,14 @@ Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch
Array<tir::LoopRV> outer_splits = sch->Split(
/*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()});
outer_splits.push_back(inner_splits[1]);
return outer_splits;
outer_factors.push_back(PrimExpr(vec_len));
return {outer_factors, outer_splits};
} else {
Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1));
factors.push_back(loop->extent);
return sch->Split(/*loop=*/loop_rv,
/*factors=*/{factors.begin(), factors.end()});
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop_rv,
/*factors=*/{factors.begin(), factors.end()});
return {factors, splits};
}
}
}
9 changes: 6 additions & 3 deletions src/tir/analysis/block_access_region_detector.cc
@@ -76,8 +76,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
Map<Var, Buffer> buffer_var_map_;
/*! \brief The target buffer var mapping to its matching */
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
/*! \brief The analyzer for simplifying*/
arith::Analyzer analyzer_;

/*!
* \brief Update read/write buffers and regions with provided buffer and region
@@ -330,7 +328,12 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
for (size_t j = 0; j < regions[i].size(); j++) {
const tvm::arith::IntSet& range = regions[i][j];
region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j])));
if (range.IsSinglePoint()) {
PrimExpr min = range.min();
region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1)));
} else {
region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j])));
}
}
res.push_back(BufferRegion(buffers[i], region));
}
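The effect of the single-point special case can be observed from Python (a sketch; the block navigation path and the toy function are assumptions).

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def copy_col(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16,), "float32")
    for i in T.serial(16):
        with T.block("copy"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi, 3]


# Navigate root block -> loop -> inner block (assumed nesting for this function).
block = copy_col.body.block.body.body.block
buffer_var_map = {buf.data: buf for buf in copy_col.buffer_map.values()}
reads, writes, _ = tir.analysis.get_block_access_region(block, buffer_var_map)
# With this change the single-point read stays an extent-1 region, e.g.
# A[vi, 3:4], instead of being widened through CoverRange.
print(reads, writes)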