Skip to content

Commit

Permalink
[TIR] Enhance software pipeline validation and fix predicate of epilogue (apache#11106)
Browse files Browse the repository at this point in the history

* Fix pipeline validation

* fix predicate

* Update test_tir_transform_inject_software_pipeline.py

* Update inject_software_pipeline.cc
  • Loading branch information
vinx13 authored and Sergey Shtin committed May 17, 2022
1 parent 8e0853c commit 2b0d817
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 9 deletions.
74 changes: 65 additions & 9 deletions src/tir/transforms/inject_software_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,10 @@ class PipelineRewriter : public StmtExprMutator {
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var);
} else {
// normalize loop range
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min));
PrimExpr delta = start - pipeline_loop_->min;
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + delta);
Var loop_iter = Downcast<Var>(new_loop_var);
inbound = Substitute(inbound, Map<Var, PrimExpr>{{loop_iter, loop_iter + delta}});
}
new_block = Downcast<Block>(Substitute(new_block, subst_map));
stmts.push_back(BlockRealize({}, inbound, new_block));
Expand Down Expand Up @@ -570,6 +573,40 @@ class PipelineRewriter : public StmtExprMutator {
Array<Block> ordered_stmts_;
};

/*!
 * \brief Build the dependency graph among an array of blocks.
 * \param[in] blocks The array of blocks, given in program order.
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the
 * destination.
 * \param[out] dep_dst2src Optional, a map to store dependency edges from the destination to the
 * source.
 * \note Only read-after-write dependencies are recorded: a block depends on every earlier block
 * that writes a buffer it reads. Buffers are identified by their backing data Var.
 */
void BuildDependencyGraph(
    const Array<Block>& blocks,
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst,
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) {
  // Maps a buffer's data Var to the blocks seen so far that write that buffer.
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual> writers_of;

  for (const Block& consumer : blocks) {
    // Add an edge from every prior writer of each buffer this block reads.
    for (const BufferRegion& read_region : consumer->reads) {
      auto writers_it = writers_of.find(read_region->buffer->data);
      if (writers_it == writers_of.end()) {
        continue;
      }
      for (const Block& producer : writers_it->second) {
        if (dep_src2dst != nullptr) {
          (*dep_src2dst)[producer].push_back(consumer);
        }
        if (dep_dst2src != nullptr) {
          (*dep_dst2src)[consumer].push_back(producer);
        }
      }
    }
    // Register this block as a writer only after scanning its reads, so a block
    // that both reads and writes a buffer does not depend on itself.
    for (const BufferRegion& write_region : consumer->writes) {
      writers_of[write_region->buffer->data].push_back(consumer);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
public:
static Stmt Inject(const PrimFunc& func) {
Expand All @@ -587,24 +624,43 @@ class PipelineInjector : private StmtExprMutator {

/*!
* \brief Check the pipeline satisfies the following conditions:
* 1) No conflicting order: The order of each statement should be unique.
* 2) No reordering with the same stage: Statements in the same stage are not allowed to be
* reordered.
* 1. No conflicting order: The order of each statement should be unique.
* 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for
* dependency (e.g. read-after-write) from statement A to statement B, it requires:
* case 1: stage(A) < stage(B)
* case 2: stage(A) == stage(B) and order(A) < order(B)
*/
void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) {
std::unordered_set<int> used_orders;
std::unordered_map<int, int> stage_max_order;
std::unordered_map<int, const Block*> order_to_block;
std::unordered_map<const Block*, int> block_to_stage;
for (const Block& block : original_order) {
const auto& stmt_info = pipeline_info.at(block);
int stage = stmt_info.stage;
int order = stmt_info.order;
CHECK(!used_orders.count(order))
<< "ValueError: Two statements in the software pipeline cannot have the same order";
used_orders.insert(order);
CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order)
<< "ValueError: Statements in the same stage of the software pipeline must have "
"increasing order.";
stage_max_order[stage] = order;
}

std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual> dep_src2dst;
BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

for (const auto& pair : dep_src2dst) {
const Block& src = pair.first;
const auto& src_info = pipeline_info.at(src);
const Array<Block>& dsts = pair.second;
for (const Block& dst : dsts) {
const auto& dst_info = pipeline_info.at(dst);
CHECK_LE(src_info.stage, dst_info.stage)
<< "ValueError: statement " << dst << " in stage " << dst_info.stage
<< " cannot depends on statement " << src << " in a later stage " << src_info.stage;
if (src_info.stage == dst_info.stage) {
CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer "
"access dependency in the same stage of the "
"software pipeline cannot be reordered";
}
}
}
}

Expand Down
201 changes: 201 additions & 0 deletions tests/python/unittest/test_tir_transform_inject_software_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,199 @@ def transformed_simple_compute(
C[tx, 15] = B[1, tx, 0] + T.float32(1)


@T.prim_func
def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
    # Pipeline input: a three-statement chain A -> B -> C -> D, with one
    # statement per pipeline stage (stage annotation [0, 1, 2]).
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 1, 2],
                "software_pipeline_order": [0, 1, 2],
            },
        ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(D[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                C = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, 0])
                    # Fix: the stage-1 statement must consume B, as its
                    # T.reads(B[tx, 0]) annotation declares; it previously read
                    # A[tx, 0], contradicting the declared region.
                    C[tx, 0] = B[tx, 0] + T.float32(2)
                with T.block():
                    T.reads(C[tx, 0])
                    T.writes(D[tx, i])
                    D[tx, i] = C[tx, 0] + T.float32(1)


@T.prim_func
def transformed_three_stage_compute(
    A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]
) -> None:
    # Expected result of injecting the software pipeline into
    # `three_stage_compute`: a prologue (unrolled warm-up), a steady-state
    # body of 14 iterations, and an epilogue (unrolled drain). B and C get a
    # leading dimension of 2 (double buffering) because their producer and
    # consumer live in different stages.
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0:16])
            T.writes(D[tx, 0:16])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            with T.block():
                # Prologue: fill the pipeline for the first two iterations.
                T.reads(A[tx, 0:2], B[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0])
                for i in T.unroll(2):
                    with T.block():
                        T.reads(A[tx, i])
                        T.writes(B[0:2, tx, 0])
                        B[i, tx, 0] = A[tx, i] * T.float32(2)
                    with T.block():
                        T.where(1 <= i)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        # Fix: consume the double-buffered B, as declared by
                        # T.reads, instead of re-reading A[tx, 0].
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
            with T.block():
                # Steady state: all three stages active each iteration.
                T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
                for i in T.serial(14):
                    with T.block():
                        T.reads(A[tx, i + 2])
                        T.writes(B[0:2, tx, 0])
                        B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
                    with T.block():
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        # Fix: read B's other version, matching the declared
                        # T.reads region, rather than A[tx, 0].
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i])
                        D[tx, i] = C[i % 2, tx, 0] + T.float32(1)
            with T.block():
                # Epilogue: drain the last iterations; the T.where predicate
                # keeps the stage-1 compute from running past the loop extent.
                T.reads(B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(C[0:2, tx, 0], D[tx, 14:16])
                for i in T.unroll(2):
                    with T.block():
                        T.where(i < 1)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        # Fix: same B-instead-of-A correction as above.
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i + 14])
                        D[tx, i + 14] = C[i, tx, 0] + T.float32(1)


@T.prim_func
def dag_interleaving(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    # Pipeline input whose body is a DAG rather than a chain: two independent
    # load paths (A -> AS -> AL and B -> BS -> BL) feed a single compute.
    # The order annotation [0, 2, 1, 3, 4] interleaves the two paths while
    # each path's internal dependency order (load-shared before load-local)
    # is preserved, so the validator must accept this reordering.
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 0, 0, 0, 1],
                "software_pipeline_order": [0, 2, 1, 3, 4],
            },
        ):
            with T.block():
                # NOTE(review): the outer block declares reads of A only,
                # though the body also reads B — presumably intentional for
                # this fixture; confirm against the pass's region requirements.
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                AS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                BS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                AL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
                BL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(AS[tx, 0])
                    AS[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(AS[tx, 0])
                    T.writes(AL[0, 0])
                    AL[0, 0] = AS[tx, 0]
                with T.block():
                    T.reads(B[tx, i])
                    T.writes(BS[tx, 0])
                    BS[tx, 0] = B[tx, i] + T.float32(2)
                with T.block():
                    T.reads(BS[tx, 0])
                    T.writes(BL[0, 0])
                    BL[0, 0] = BS[tx, 0]
                with T.block():
                    T.reads(AL[0, 0], BL[0, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = AL[0, 0] * BL[0, 0]


@T.prim_func
def transformed_dag_interleaving(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    # Expected result of injecting the software pipeline into
    # `dag_interleaving`: one prologue iteration, a 15-iteration steady-state
    # body, and a single-statement epilogue. Only AL and BL gain a version
    # dimension (size 2) — they cross the stage-0/stage-1 boundary; AS and BS
    # stay single-versioned since producer and consumer share stage 0.
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0:16], B[tx, 0:16])
            T.writes(C[tx, 0:16])
            AS = T.alloc_buffer([16, 1], dtype="float32", scope="shared")
            BS = T.alloc_buffer([16, 1], dtype="float32", scope="shared")
            AL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local")
            BL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local")
            with T.block():
                # Prologue: run the four stage-0 statements for iteration 0.
                T.reads(A[tx, 0], B[tx, 0], AS[tx, 0], BS[tx, 0])
                T.writes(AS[tx, 0], BS[tx, 0], AL[0, 0, 0], BL[0, 0, 0])
                with T.block():
                    T.reads(A[tx, 0])
                    T.writes(AS[tx, 0])
                    AS[tx, 0] = A[tx, 0] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(BS[tx, 0])
                    BS[tx, 0] = B[tx, 0] + T.float32(2)
                with T.block():
                    T.reads(AS[tx, 0])
                    T.writes(AL[0, 0, 0])
                    AL[0, 0, 0] = AS[tx, 0]
                with T.block():
                    T.reads(BS[tx, 0])
                    T.writes(BL[0, 0, 0])
                    BL[0, 0, 0] = BS[tx, 0]
            with T.block():
                # Steady state: stage 0 prefetches iteration i+1 into version
                # (i+1)%2 while stage 1 consumes iteration i from version i%2.
                T.reads(
                    A[tx, 1:16], B[tx, 1:16], AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0]
                )
                T.writes(AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0], C[tx, 0:15])
                for i in T.serial(15):
                    with T.block():
                        T.reads(A[tx, i + 1])
                        T.writes(AS[tx, 0])
                        AS[tx, 0] = A[tx, i + 1] * T.float32(2)
                    with T.block():
                        T.reads(B[tx, i + 1])
                        T.writes(BS[tx, 0])
                        BS[tx, 0] = B[tx, i + 1] + T.float32(2)
                    with T.block():
                        T.reads(AS[tx, 0])
                        T.writes(AL[(i + 1) % 2, 0, 0])
                        AL[(i + 1) % 2, 0, 0] = AS[tx, 0]
                    with T.block():
                        T.reads(BS[tx, 0])
                        T.writes(BL[(i + 1) % 2, 0, 0])
                        BL[(i + 1) % 2, 0, 0] = BS[tx, 0]
                    with T.block():
                        T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0])
                        T.writes(C[tx, i])
                        C[tx, i] = AL[i % 2, 0, 0] * BL[i % 2, 0, 0]
            with T.block():
                # Epilogue: consume the last prefetched values (version 1).
                T.reads(AL[1, 0, 0], BL[1, 0, 0])
                T.writes(C[tx, 15])
                C[tx, 15] = AL[1, 0, 0] * BL[1, 0, 0]


@T.prim_func
def nested_pipeline_simple(
A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
Expand Down Expand Up @@ -792,6 +985,14 @@ def test_trivial_pipeline():
_check(trivial_pipeline, transformed_trivial_pipeline)


def test_three_stage_compute():
    # Three statements in three distinct stages chained by read-after-write
    # dependencies; checks the injected prologue/body/epilogue (including the
    # fixed epilogue predicate) against the expected transformed function.
    _check(three_stage_compute, transformed_three_stage_compute)


def test_dag_interleaving():
    # DAG-shaped pipeline body with interleaved statement order; exercises the
    # dependency-graph validation path (reordering independent statements is
    # legal, so the pass must accept and transform this pipeline).
    _check(dag_interleaving, transformed_dag_interleaving)


def test_nest_pipeline_simple():
_check(nested_pipeline_simple, transformed_nested_pipeline_simple)

Expand Down

0 comments on commit 2b0d817

Please sign in to comment.