Skip to content

Commit

Permalink
[TIR] Blockize keeping T.init in inner block when outer block does no… (
Browse files Browse the repository at this point in the history
#172)

… reduction

This PR modifies the behavior of tir.Schedule.blockize, so that when the
outer block after blockization does no effective reduction, the T.init
part will be kept in the inner block.

NOTE: unit tests regarding blockize may fail due to structural
inequality.

Co-authored-by: Bohan Hou <spectrometerh@gmail.com>

Co-authored-by: spectrometerHBH <spectrometerh@gmail.com>
  • Loading branch information
MasterJH5574 and spectrometerHBH committed Apr 18, 2023
1 parent 29ec796 commit 6caa5af
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/tir/schedule/primitive/blockize_tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars,
IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
/*var=*/iter_var->var.copy_with_suffix("_o"),
/*iter_type=*/iter_var->iter_type);
outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
outer_iter_vars->push_back(outer_iter);
if (!is_one(outer_mark->extent)) {
outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
outer_iter_vars->push_back(outer_iter);
}
// create iter var for the inner block
IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
/*var=*/iter_var->var.copy_with_suffix("_i"),
Expand Down Expand Up @@ -266,7 +268,7 @@ BlockRealize GenerateInner(bool is_write_reduction,
Block block) {
BlockNode* n = block.CopyOnWrite();
n->iter_vars = iter_vars;
n->init = NullOpt;
n->init = is_write_reduction ? NullOpt : std::move(block->init);
if (is_write_reduction) {
Array<BufferRegion> reads;
reads.reserve(block->writes.size() + block->reads.size());
Expand Down Expand Up @@ -493,7 +495,7 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
/*name_hint=*/block_subst->name_hint + "_o",
/*body=*/MakeLoopNest(inner_realize, loops),
/*init=*/
block_subst->init.defined() //
block_subst->init.defined() && has_outer_reduction //
? GenerateOuterInit(block_subst->init.value(), inner_realize, loops,
block_subst->name_hint + "_init")
: Optional<Stmt>(NullOpt)));
Expand Down

0 comments on commit 6caa5af

Please sign in to comment.