From 6caa5aff45615696683bd1be4a833c378bacee08 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 7 Apr 2023 16:00:56 -0400 Subject: [PATCH] =?UTF-8?q?[TIR]=20Blockize=20keeping=20T.init=20in=20inne?= =?UTF-8?q?r=20block=20when=20outer=20block=20does=20no=E2=80=A6=20(#172)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … reduction This PR modifies the behavior of tir.Schedule.blockize, so that when the outer block after blockization does no effective reduction, the T.init part will be kept in the inner block. NOTE: unit tests regarding blockize may fail due to structural inequality. Co-authored-by: Bohan Hou Co-authored-by: spectrometerHBH --- src/tir/schedule/primitive/blockize_tensorize.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index b1645d5cbd..5a9c55af00 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -229,8 +229,10 @@ Map DeriveBlockBinding(const Array& iter_vars, IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent), /*var=*/iter_var->var.copy_with_suffix("_o"), /*iter_type=*/iter_var->iter_type); - outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding)); - outer_iter_vars->push_back(outer_iter); + if (!is_one(outer_mark->extent)) { + outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding)); + outer_iter_vars->push_back(outer_iter); + } // create iter var for the inner block IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent), /*var=*/iter_var->var.copy_with_suffix("_i"), @@ -266,7 +268,7 @@ BlockRealize GenerateInner(bool is_write_reduction, Block block) { BlockNode* n = block.CopyOnWrite(); n->iter_vars = iter_vars; - n->init = NullOpt; + n->init = is_write_reduction ? NullOpt : std::move(block->init); if (is_write_reduction) { Array reads; reads.reserve(block->writes.size() + block->reads.size()); @@ -493,7 +495,7 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, /*name_hint=*/block_subst->name_hint + "_o", /*body=*/MakeLoopNest(inner_realize, loops), /*init=*/ - block_subst->init.defined() // + block_subst->init.defined() && has_outer_reduction // ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops, block_subst->name_hint + "_init") : Optional(NullOpt)));