Skip to content

Commit

Permalink
Hot fix for bound predicate (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored Jan 6, 2022
1 parent c2f8106 commit 0f3892b
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class BufferFlattener : public StmtExprMutator {
if (!is_one(predicate)) {
body = IfThenElse(predicate, std::move(body));
}
// If the block carries a bound-predicate annotation, guard the body with an if-then-else on that predicate
const Optional<ObjectRef>& bound_predicate =
new_block->annotations.Get(tir::attr::require_block_var_bound_predicate);
if (bound_predicate.defined()) {
body = IfThenElse(Downcast<PrimExpr>(bound_predicate.value()), std::move(body));
}
// Step 3. Handle allocations in reverse order
for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
const Buffer& buffer = new_block->alloc_buffers[i - 1];
Expand Down
62 changes: 62 additions & 0 deletions tests/python/unittest/test_tir_transform_flatten_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,64 @@ def annotated_loops(a: T.handle) -> None:
A[i] = 0.0


# Input fixture for test_bound_predicate: a tiled kernel over a 224x224 image whose
# two cache-read blocks carry a "require_bound_predicate" block annotation.
# The test passes this function to _check together with
# flattened_tiled_pooling_cache_after_compute_at.
# NOTE(review): the C++ pass looks the annotation up via
# tir::attr::require_block_var_bound_predicate; presumably that constant's string
# value is "require_bound_predicate" -- confirm against the attr definition.
@T.prim_func
def tiled_pooling_cache_after_compute_at(a: T.handle, b: T.handle) -> None:
# X is read by the cache blocks; Y is both read and written by the compute block.
X = T.match_buffer(a, [224, 224], dtype="float32")
Y = T.match_buffer(b, [224, 224], dtype="float32")
# body
# with T.block("root")
# Two 10x10 float32 scratch buffers, both filled from X.
cache = T.alloc_buffer([10, 10], dtype="float32")
dache = T.alloc_buffer([10, 10], dtype="float32")
# 28 x 28 outer tiles; each tile covers 8 x 8 output positions (see the compute grid below).
for hh_0, ww_0 in T.grid(28, 28):
# 10 x 10 window per tile, starting one element before the tile origin (the "- 1" offset).
for ax0, ax1 in T.grid(10, 10):
with T.block("cache"):
T.reads(X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1])
T.writes(cache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1])
# The annotated predicate restricts both window coordinates to [0, 224).
T.block_attr({"require_bound_predicate":hh_0 * 8 - 1 + ax0 >= 0 and hh_0 * 8 - 1 + ax0 < 224 and ww_0 * 8 - 1 + ax1 >= 0 and ww_0 * 8 - 1 + ax1 < 224})
cache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] = X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]
# Second copy loop: same indices and predicate as above, but writing into dache.
for ax0, ax1 in T.grid(10, 10):
with T.block("dache"):
T.reads(X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1])
T.writes(dache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1])
T.block_attr({"require_bound_predicate":hh_0 * 8 - 1 + ax0 >= 0 and hh_0 * 8 - 1 + ax0 < 224 and ww_0 * 8 - 1 + ax1 >= 0 and ww_0 * 8 - 1 + ax1 < 224})
dache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] = X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]
# 8 x 8 positions per tile, each visiting a 3 x 3 neighbourhood (khh, kww).
for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
with T.block("compute"):
T.reads(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1], cache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1], dache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1])
T.writes(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1])
# Y takes the max of its current value and cache + dache at the neighbour;
# the if_then_else yields 0 when the neighbour index (... - 1) is outside [0, 224).
Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1] = T.max(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1],
T.if_then_else(T.likely(1 <= hh_0 * 8 + hh_1 + khh, dtype="bool")
and T.likely(hh_0 * 8 + hh_1 + khh < 225, dtype="bool")
and T.likely(1 <= ww_0 * 8 + ww_1 + kww, dtype="bool")
and T.likely(ww_0 * 8 + ww_1 + kww < 225, dtype="bool"),
cache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1]
+ dache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1],
T.float32(0), dtype="float32"))


# Expected counterpart of tiled_pooling_cache_after_compute_at in test_bound_predicate:
# blocks are gone, buffers are accessed through flat 1-D indices, and each
# "require_bound_predicate" annotation appears as an explicit `if` around the store.
@T.prim_func
def flattened_tiled_pooling_cache_after_compute_at(X: T.Buffer[(224, 224), "float32"], Y: T.Buffer[(224, 224), "float32"]) -> None:
# The two 10x10 scratch buffers become 100-element flat allocations.
cache = T.allocate([100], "float32", "global")
dache = T.allocate([100], "float32", "global")
for hh_0, ww_0 in T.grid(28, 28):
for ax0, ax1 in T.grid(10, 10):
# Same bound as the block annotation with the "- 1" moved across the comparison:
# hh_0 * 8 - 1 + ax0 in [0, 224)  <=>  1 <= hh_0 * 8 + ax0 < 225 (likewise for ww_0 / ax1).
if 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225:
# cache index = (hh_0 * 8 - 1 + ax0) * 10 + (ww_0 * 8 - 1 + ax1);
# X index     = (hh_0 * 8 - 1 + ax0) * 224 + (ww_0 * 8 - 1 + ax1).
T.store(cache, hh_0 * 80 + ax0 * 10 + ww_0 * 8 + ax1 - 11, T.load("float32", X.data, hh_0 * 1792 + ax0 * 224 + ww_0 * 8 + ax1 - 225), True)
for ax0, ax1 in T.grid(10, 10):
# Identical guard and indices as the loop above, storing into dache instead.
if 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225:
T.store(dache, hh_0 * 80 + ax0 * 10 + ww_0 * 8 + ax1 - 11, T.load("float32", X.data, hh_0 * 1792 + ax0 * 224 + ww_0 * 8 + ax1 - 225), True)
for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
# Y index = (hh_0 * 8 + hh_1) * 224 + (ww_0 * 8 + ww_1); the compute statement has
# no bound-predicate annotation in the input, so no `if` is expected here.
T.store(Y.data, hh_0 * 1792 + hh_1 * 224 + ww_0 * 8 + ww_1,
T.max(T.load("float32", Y.data, hh_0 * 1792 + hh_1 * 224 + ww_0 * 8 + ww_1),
T.if_then_else(T.likely(1 <= hh_0 * 8 + hh_1 + khh, dtype="bool")
and T.likely(hh_0 * 8 + hh_1 + khh < 225, dtype="bool")
and T.likely(1 <= ww_0 * 8 + ww_1 + kww, dtype="bool")
and T.likely(ww_0 * 8 + ww_1 + kww < 225, dtype="bool"),
# cache / dache index = (hh_0 * 8 + hh_1 + khh - 1) * 10 + (ww_0 * 8 + ww_1 + kww - 1).
T.load("float32", cache, hh_0 * 80 + hh_1 * 10 + khh * 10 + ww_0 * 8 + ww_1 + kww - 11)
+ T.load("float32", dache, hh_0 * 80 + hh_1 * 10 + khh * 10 + ww_0 * 8 + ww_1 + kww - 11),
T.float32(0), dtype="float32")), True)


def test_elementwise():
    """Run _check on the compacted elementwise fixture and its flattened counterpart."""
    before = compacted_elementwise_func
    expected = flattened_elementwise_func
    _check(before, expected)

Expand Down Expand Up @@ -305,6 +363,10 @@ def test_annotated_loops():
tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))


def test_bound_predicate():
    """Run _check on the bound-predicate fixture and its expected flattened form."""
    before = tiled_pooling_cache_after_compute_at
    expected = flattened_tiled_pooling_cache_after_compute_at
    _check(before, expected)


if __name__ == "__main__":
test_elementwise()
test_gpu_workload()
Expand Down

0 comments on commit 0f3892b

Please sign in to comment.