Skip to content

Commit

Permalink
[MetaSchedule] Handle 'warp_execution' in RewriteCooperativeFetch (#1…
Browse files Browse the repository at this point in the history
…1955)

Updated `RewriteCooperativeFetch` to handle 'warp_execution' annotation when the extend of `threadIdx.x` is not specified
  • Loading branch information
vinx13 authored Jun 30, 2022
1 parent 6424f1f commit 26ad703
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 2 deletions.
33 changes: 32 additions & 1 deletion src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst,
return Downcast<BlockRV>(inst->inputs[0]);
}

/*!
* \brief Parse instruction: sch.annotate(..., attr::warp_execution)
* \param sch The schedule
* \param inst The instruction to be parsed
* \return Whether ths parsing is successful
*/
bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) {
static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
if (!inst->kind.same_as(inst_kind_annotate)) {
return false;
}
ICHECK_EQ(inst->inputs.size(), 2);
ICHECK_EQ(inst->attrs.size(), 1);
String ann_key = Downcast<String>(inst->attrs[0]);
return ann_key == attr::warp_execution;
}

} // namespace tir

namespace meta_schedule {
Expand All @@ -76,14 +93,24 @@ namespace meta_schedule {
class RewriteCooperativeFetchNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {}
void InitializeWithTuneContext(const TuneContext& context) final {
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
this->thread_warp_size_ = v.value()->value;
} else {
TVM_PY_LOG(INFO, context->logging_func) << "'thread_warp_size' is not defined in the target";
}
}

// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;

void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode);

private:
int thread_warp_size_ = -1;
};

bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
Expand All @@ -101,6 +128,10 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
thread_extent_y = new_thread_extent.value()->value;
continue;
}
if (tir::ParseWarpExecutionAnn(sch, inst)) {
thread_extent_x = thread_warp_size_;
continue;
}
Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane);
if (!opt_block_rv.defined()) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

import tvm
import tvm.testing
from tvm import tir
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.postproc import RewriteCooperativeFetch
Expand Down Expand Up @@ -99,6 +100,108 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
C[v0, v1] = C_local[v0, v1]


@tvm.script.ir_module
class WarpExecutionAfterRewrite:
@T.prim_func
def main(
A: T.Buffer[(512, 512), "float32"],
B: T.Buffer[(512, 512), "float32"],
C: T.Buffer[(512, 512), "float32"],
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"):
for i2_0 in T.serial(0, 1):
for ax0_ax1_fused_0 in T.serial(0, 1024):
for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"):
for ax0_ax1_fused_2 in T.thread_binding(
0, 32, thread="threadIdx.x"
):
with T.block("A_shared"):
v0 = T.axis.spatial(
512,
(
ax0_ax1_fused_0 * 256
+ ax0_ax1_fused_1 * 32
+ ax0_ax1_fused_2
)
// 512,
)
v1 = T.axis.spatial(
512,
(
ax0_ax1_fused_0 * 256
+ ax0_ax1_fused_1 * 32
+ ax0_ax1_fused_2
)
% 512,
)
T.reads([A[v0, v1]])
T.writes([A_shared[v0, v1]])
A_shared[v0, v1] = A[v0, v1]
for ax0_ax1_fused_0 in T.serial(0, 32):
for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"):
for ax0_ax1_fused_2 in T.thread_binding(
0, 32, thread="threadIdx.x"
):
for ax0_ax1_fused_3 in T.vectorized(0, 2):
with T.block("B_shared"):
v0 = T.axis.spatial(
512,
(
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 64
+ ax0_ax1_fused_2 * 2
+ ax0_ax1_fused_3
)
// 32,
)
v1 = T.axis.spatial(
512,
i0_0_i1_0_fused * 32
+ (
ax0_ax1_fused_0 * 512
+ ax0_ax1_fused_1 * 64
+ ax0_ax1_fused_2 * 2
+ ax0_ax1_fused_3
)
% 32,
)
T.reads([B[v0, v1]])
T.writes([B_shared[v0, v1]])
B_shared[v0, v1] = B[v0, v1]
for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
with T.block("C"):
i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
j = T.axis.spatial(
512,
i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4,
)
k = T.axis.reduce(512, i2_0 * 512 + i2_1 * 32 + i2_2)
T.reads([A_shared[i, k], B_shared[k, j]])
T.writes([C_local[i, j]])
T.block_attr({"warp_execution": 1})
with T.init():
C_local[i, j] = T.float32(0)
C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
for ax0, ax1 in T.grid(32, 4):
with T.block("C_local"):
v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
v1 = T.axis.spatial(
512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1
)
T.reads([C_local[v0, v1]])
T.writes([C[v0, v1]])
C[v0, v1] = C_local[v0, v1]


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on

Expand Down Expand Up @@ -147,5 +250,51 @@ def test_rewrite_cooperative_fetch():
tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0)


def test_rewrite_warp_execution():
mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512))
target = _target()
ctx = _create_context(mod, target)

sch = tir.Schedule(mod, debug_mask="all")
# fmt: off
# pylint: disable=line-too-long,invalid-name
b0 = sch.get_block(name="C", func_name="main")
b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
l2, l3, l4 = sch.get_loops(block=b0)
sch.annotate(b0, "warp_execution", 1)
v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16])
l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])
v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2])
l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])
v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32])
l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])
sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)
l31 = sch.fuse(l10, l20)
sch.bind(loop=l31, thread_axis="blockIdx.x")
l32 = sch.fuse(l11, l21)
sch.bind(loop=l32, thread_axis="vthread.x")
l33 = sch.fuse(l12, l22)
sch.bind(loop=l33, thread_axis="threadIdx.y")
b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)
_, _, _, _, l39, l40 = sch.get_loops(block=b34)
l41 = sch.fuse(l39, l40)
_, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1])
sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)
b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)
_, _, _, _, l49, l50 = sch.get_loops(block=b44)
l51 = sch.fuse(l49, l50)
_, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2])
sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)
sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)
# pylint: enable=line-too-long,invalid-name
# fmt: on
sch.enter_postproc()
assert ctx.postprocs[0].apply(sch)
print(sch.mod["main"].script())
tvm.ir.assert_structural_equal(sch.mod, WarpExecutionAfterRewrite)


if __name__ == "__main__":
test_rewrite_cooperative_fetch()
tvm.testing.main()

0 comments on commit 26ad703

Please sign in to comment.