Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Compute-inline] Prefer T.where for reverse compute-inlined block with predicate #17128

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,14 @@ class ReverseComputeInliner : public BaseInliner {
using BaseInliner::VisitStmt_;

/*! \brief Generate the predicate after inlining based on the consumer predicate */
Block BuildInlinedConsumerPredicate(const BlockNode* producer_block) {
BlockRealize BuildInlinedConsumerPredicate(BlockRealize producer_block_realize) {
// Bind the producer block iter domains for simplification
Map<Var, PrimExpr> subst_map;
Block producer_block = producer_block_realize->block;
for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
const IterVar& iter = producer_block->iter_vars[i];
const PrimExpr& binding = producer_block_realize->iter_values[i];
subst_map.Set(iter->var, binding);
analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
}
if (producer_block->annotations.count(tir::attr::auto_copy) != 0) {
Expand All @@ -705,30 +708,33 @@ class ReverseComputeInliner : public BaseInliner {
PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
// Simplify the predicate using the producer block iter domains
predicate = analyzer_.Simplify(predicate);
ObjectPtr<BlockNode> block = make_object<BlockNode>(*producer_block);
if (is_one(predicate)) {
return Block(block);
}
if (const auto* if_ = producer_block->body.as<tir::IfThenElseNode>()) {
PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
if (!StructuralEqual()(predicate, if_predicate)) {
predicate = analyzer_.Simplify(predicate && if_->condition);
return producer_block_realize;
}
if (const auto* if_ = producer_block->body.as<IfThenElseNode>()) {
if (!if_->else_case.defined()) {
PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
if (!StructuralEqual()(predicate, if_predicate)) {
predicate = analyzer_.Simplify(predicate && if_->condition);
producer_block.CopyOnWrite()->body = if_->then_case;
}
}
block->body = IfThenElse(predicate, if_->then_case);
return Block(block);
}
block->body = IfThenElse(predicate, block->body);
return Block(block);
PrimExpr outer_predicate = Substitute(predicate, subst_map);
auto n = producer_block_realize.CopyOnWrite();
n->block = producer_block;
n->predicate = analyzer_.Simplify(outer_predicate);
return GetRef<BlockRealize>(n);
}

Stmt VisitStmt_(const BlockNode* op) final {
Block src_block = GetRef<Block>(op);
Block tgt_block = Downcast<Block>(BaseInliner::VisitStmt_(op));
if (op == producer_block_) {
tgt_block = BuildInlinedConsumerPredicate(tgt_block.get());
block_reuse.Set(src_block, tgt_block);
Stmt VisitStmt_(const BlockRealizeNode* op) final {
Block src_block = op->block;
BlockRealize tgt_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
if (src_block.get() == producer_block_) {
tgt_block_realize = BuildInlinedConsumerPredicate(tgt_block_realize);
block_reuse.Set(src_block, tgt_block_realize->block);
}
return std::move(tgt_block);
return std::move(tgt_block_realize);
}

Stmt VisitStmt_(const BufferStoreNode* _store) final {
Expand Down
20 changes: 10 additions & 10 deletions tests/python/dlight/test_gpu_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[T.int64(0), v1, v2])
if v1 < m:
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -200,10 +200,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[0, v1, v2])
if v1 < m:
matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on

mod = tvm.IRModule({"main": func})
Expand Down Expand Up @@ -466,10 +466,10 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n)
T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2], lv3[T.int64(0), v1, v2])
T.writes(p_output0_intermediate[T.int64(0), v1, v2])
if v1 < n:
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]

# fmt: on

Expand Down Expand Up @@ -596,9 +596,9 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.reads(lv52[T.int64(0), v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n)
T.writes(var_T_multiply_intermediate[v1, v2])
if v1 < n:
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))

# fmt: on

Expand Down Expand Up @@ -666,10 +666,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1)
T.where(ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[T.int64(0), v1, v2])
if v1 < m:
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on


Expand Down
20 changes: 10 additions & 10 deletions tests/python/dlight/test_gpu_matmul_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m and ax2_2 * 4 + ax2_0 * 2 + ax2_1_1 < 15)
T.reads(compute_reindex_pad_local[v0, v1, v2])
T.writes(compute[v1, v2])
if v1 < m and v2 < 15:
compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2]
compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -417,11 +417,11 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64),
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < n)
T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2])
T.writes(p_output0_intermediate[0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
if v1 < n:
p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -690,11 +690,11 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < m)
T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2])
T.writes(matmul_1[0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
if v1 < m:
matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -831,10 +831,10 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha
v0 = T.axis.spatial(1, ax0_1)
v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
T.reads(C_reindex_pad_shared[v0, v1, v2])
T.writes(C[v1, 0, v2])
if v1 < batch_size:
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -971,10 +971,10 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f
v0 = T.axis.spatial(1, ax0_1)
v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
T.reads(C_reindex_pad_shared[v0, v1, v2])
T.writes(C[v1, 0, v2])
if v1 < batch_size:
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -856,11 +856,11 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1
v3 = T.axis.spatial(1, 0)
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127)
T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16])
T.block_attr({"meta_schedule.cooperative_fetch": 4})
if v0 * 32 + v2 * 16 + v4 < 127 and v1 * 16 + v5 < 127:
compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
# fmt: on

decision_0 = [
Expand Down
Loading
Loading