Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【CINN】Add dynamic shape support for schedule primitives #59493

Merged
merged 10 commits into from
Dec 8, 2023
4 changes: 3 additions & 1 deletion paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ void CompilationTask::operator()() {
void CompilationTask::Lowering() {
auto op_lowerer = CreateOpLowerer<pir::GroupPtr>(context_->target_);
context_->SetLoweredFuncs(
op_lowerer.BucketLower(context_->group_, false, false, false));
op_lowerer.BucketLower(context_->group_, false, true, false));
// context_->SetLoweredFuncs(
// op_lowerer.BucketLower(context_->group_, false, false, false));
op_lowerer.InsertNameGeneToScope(context_->scope_);
}

Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ OpLowererImpl::BucketLower(const GroupPtr& group,

// 2.Do group schedule.
ir::ModuleExpr mod_expr(func_bodies);
ir::IRSchedule ir_sch(mod_expr);
ir::IRSchedule ir_sch(
mod_expr, -1, false, cinn::utils::ErrorMessageLevel::kGeneral, true);
ir_sch.MergeExprs();
std::vector<std::pair<ir::SymbolicPredicate, ir::Expr>> cond2func_bodies;
VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
Expand Down
39 changes: 15 additions & 24 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,25 @@ namespace ir {
void DynamicShapeGroupScheduler::Schedule() {
// Fake schedule for test
int max_spacial_numel = 1;
ScheduleBlockNode* node = schedule_block_graph_->EndPoints()[0];
ir::Expr block_realize = node->Block();
std::vector<ir::Expr> loops = ir_sch_->GetLoops(block_realize);
ir::Expr extent = loops[0].As<ir::For>()->extent;
ScheduleBlockNode* node0 = schedule_block_graph_->StartPoints()[0];
ScheduleBlockNode* node1 = schedule_block_graph_->EndPoints()[0];

ir::Expr predicate1 = ir::LE::Make(extent, Expr(1024));
ir::Expr block_realize0 = node0->Block();
ir::Expr block_realize1 = node1->Block();

auto block0 = ir_sch_->GetBlock("var");
ir_sch_->ComputeInline(block0);
auto reorder1 = ir_sch_->Reorder("var_1", {1, 0});

auto loops1 = ir_sch_->GetLoops("var_1");
auto splited_loops1 = ir_sch_->DySplit(loops1[1], {-1, 1, 32});
ir_sch_->Bind(splited_loops1[1], "blockIdx.x");
ir_sch_->Bind(splited_loops1[2], "threadIdx.x");

ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024));
std::unique_ptr<ir::IRSchedule> new_ir_sch1 =
std::make_unique<ir::IRSchedule>(*ir_sch_);
ScheduleBlockGraph sbg1(*new_ir_sch1);
sbg1.NodesWalk([&](ir::ScheduleBlockNode* node) {
std::vector<cinn::ir::Expr> splited_loops =
new_ir_sch1->Split(new_ir_sch1->GetLoops(node->Block())[0], {-1, 1});
new_ir_sch1->Bind(splited_loops[1], "blockIdx.x");
new_ir_sch1->Bind(new_ir_sch1->GetLoops(node->Block())[2], "threadIdx.x");
});
ir_schs_.emplace_back(predicate1, std::move(new_ir_sch1));

ir::Expr predicate2 = ir::GT::Make(extent, Expr(1024));
std::unique_ptr<ir::IRSchedule> new_ir_sch2 =
std::make_unique<ir::IRSchedule>(*ir_sch_);
ScheduleBlockGraph sbg2(*new_ir_sch2);
sbg2.NodesWalk([&](ir::ScheduleBlockNode* node) {
std::vector<cinn::ir::Expr> splited_loops =
new_ir_sch2->Split(new_ir_sch2->GetLoops(node->Block())[0], {-1, 1024});
new_ir_sch2->Bind(splited_loops[1], "blockIdx.x");
new_ir_sch2->Bind(new_ir_sch2->GetLoops(node->Block())[2], "threadIdx.x");
});
ir_schs_.emplace_back(predicate2, std::move(new_ir_sch2));
}

std::vector<std::pair<SymbolicPredicate, ir::Expr>>
Expand Down
199 changes: 192 additions & 7 deletions paddle/cinn/ir/schedule/impl/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,55 @@
namespace cinn {
namespace ir {

void DyScheduleImpl::MergeExprs() { CINN_NOT_IMPLEMENTED; }
void DyScheduleImpl::MergeExprs() {
auto exprs = this->GetModule().GetExprs();
if (exprs.size() == 1U) return;
CHECK(exprs[0].As<ir::Block>());
CHECK_EQ(exprs[0].As<ir::Block>()->stmts.size(), 1U);
CHECK(exprs[0].As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>());
CHECK(exprs[0]
.As<ir::Block>()
->stmts[0]
.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
std::vector<Expr> merged_block;
merged_block.push_back(exprs[0]
.As<ir::Block>()
->stmts[0]
.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body);
VLOG(3) << "Before merging, exprs[0] is : " << exprs[0];
for (int i = 1; i < exprs.size(); ++i) {
auto root_block = ir::ir_utils::CollectIRNodesWithoutTensor(
exprs[i],
[&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
x->As<ir::ScheduleBlockRealize>()->iter_values.empty();
},
true);
CHECK_EQ(root_block.size(), 1U);
for (auto& it_block : root_block) {
auto& block_body = it_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body;
merged_block.push_back(block_body);
}
}
for (auto& block : merged_block) {
VLOG(3) << "in merged_block, it has " << block;
}
auto merged_expr = ir::Block::Make(merged_block);
exprs[0]
.As<ir::Block>()
->stmts[0]
.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body = merged_expr;
VLOG(3) << "After merging, exprs[0] is : " << exprs[0];
exprs.erase(exprs.begin() + 1, exprs.end());
this->SetExprs(exprs);
}

bool DyScheduleImpl::HasBlock(const std::string& block_name) const {
auto exprs = module_expr_.GetExprs();
Expand Down Expand Up @@ -64,22 +112,158 @@ DeviceAPI DyScheduleImpl::GetDeviceAPI() const {
void DyScheduleImpl::Annotate(const Expr& block,
const std::string& key,
const attr_t& value) {
CINN_NOT_IMPLEMENTED;
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
auto copied_block = ir::ir_utils::IRCopy(block);
auto* schedule_block = copied_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>();
schedule_block->attrs.emplace(key, value);
this->Replace(block, copied_block);
}

void DyScheduleImpl::Unannotate(Expr& block,
const std::string& key) { // NOLINT
CINN_NOT_IMPLEMENTED;
const std::string& ann_key) { // NOLINT
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
auto* schedule_block = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>();
if (schedule_block->attrs.count(ann_key)) {
schedule_block->attrs.erase(ann_key);
} else {
LOG(WARNING) << "Can't find annotation with key: " << ann_key;
return;
}
}

void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
const Expr& block_target) {
CINN_NOT_IMPLEMENTED;
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block_target.As<ir::ScheduleBlockRealize>());
auto exprs = this->GetModule().GetExprs();
CHECK_EQ(exprs.size(), 1U);
auto expr = exprs[0];
auto vars = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars;
auto vars_target = block_target.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars;
auto old_iter_values = block.As<ir::ScheduleBlockRealize>()->iter_values;
auto iter_values_target =
block_target.As<ir::ScheduleBlockRealize>()->iter_values;
std::vector<Expr> new_iter_values;
for (int i = 0; i < vars.size() && i < vars_target.size(); ++i) {
CHECK(vars[i]->upper_bound.defined() &&
vars_target[i]->upper_bound.defined());
if (vars[i]->upper_bound.is_constant() &&
vars_target[i]->upper_bound.is_constant() &&
vars[i]->upper_bound.get_constant() ==
vars_target[i]->upper_bound.get_constant() &&
!vars[i]->is_reduce_axis && !vars_target[i]->is_reduce_axis) {
new_iter_values.push_back(iter_values_target[i]);
VLOG(3) << "new_iter_values.push_back " << iter_values_target[i];
} else {
break;
}
}

if (new_iter_values.empty())
LOG(FATAL) << "Cannot CopyTransformAndLoopInfo since shape[0] of source "
"and target is not equal! "
<< vars[0]->upper_bound << " v.s "
<< vars_target[0]->upper_bound;

int changed_loop_num = new_iter_values.size();
std::set<std::string> used_target_loop_vars;
for (auto& iter_val : new_iter_values) {
auto find_partial_loop =
ir::ir_utils::CollectIRNodesWithoutTensor(iter_val, [&](const Expr* x) {
if (x->as_var()) used_target_loop_vars.insert(x->as_var_ref()->name);
return x->as_var();
});
}
CHECK(!used_target_loop_vars.empty());
std::vector<Expr> used_target_loops;
auto expr_copy = ir::ir_utils::IRCopy(expr);
for (auto& var : used_target_loop_vars) {
auto find_loop_var = ir::ir_utils::CollectIRNodesWithoutTensor(
expr_copy,
[&](const Expr* x) {
return x->As<ir::For>() && x->As<ir::For>()->loop_var->name == var &&
Contains(*x, block_target);
},
true);
CHECK_EQ(find_loop_var.size(), 1U);
used_target_loops.push_back(*find_loop_var.begin());
VLOG(3) << "used_target_loops push_back " << used_target_loops.back();
}
std::sort(
used_target_loops.begin(), used_target_loops.end(), [&](Expr i, Expr j) {
return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size());
});
for (int i = new_iter_values.size(); i < old_iter_values.size(); ++i) {
CHECK(old_iter_values[i].as_var());
new_iter_values.push_back(old_iter_values[i]);
}
Expr new_loop;
VLOG(3) << "changed_loop_num is : " << changed_loop_num;
VLOG(3) << "old_iter_values.size() is : " << old_iter_values.size();
if (changed_loop_num >= static_cast<int>(old_iter_values.size())) {
new_loop = ir::ir_utils::IRCopy(block);
new_loop.As<ir::ScheduleBlockRealize>()->iter_values = new_iter_values;
} else {
CHECK(old_iter_values[changed_loop_num].as_var());
auto old_var = old_iter_values[changed_loop_num].as_var_ref();
auto find_partial_loop = ir::ir_utils::CollectIRNodesWithoutTensor(
expr,
[&](const Expr* x) {
return x->As<ir::For>() &&
x->As<ir::For>()->loop_var->name == old_var->name &&
Contains(*x, block);
},
true);
CHECK_EQ(find_partial_loop.size(), 1U);
new_loop = ir::ir_utils::IRCopy(*find_partial_loop.begin());
auto find_schedule_block = ir::ir_utils::CollectIRNodesWithoutTensor(
new_loop,
[&](const Expr* x) { return x->As<ir::ScheduleBlockRealize>(); },
true);
CHECK_EQ(find_schedule_block.size(), 1U);
Expr sch_block = (*find_schedule_block.begin());
sch_block.As<ir::ScheduleBlockRealize>()->iter_values = new_iter_values;
}
VLOG(3) << "new_loop is : " << new_loop;
CHECK(!used_target_loops.empty());
Expr res;
if (used_target_loops.size() == 1) {
auto for_loop = used_target_loops[0].As<ir::For>();
res = For::Make(for_loop->loop_var,
for_loop->min,
for_loop->extent,
for_loop->for_type(),
for_loop->device_api,
new_loop,
for_loop->vectorize_info(),
for_loop->bind_info());
} else {
Expr outer_loop = used_target_loops.front();
Expr inner_loop = used_target_loops.back();
inner_loop.As<ir::For>()->body = Block::Make({new_loop});
res = outer_loop;
}
VLOG(3) << "res is : " << res;
std::vector<Expr> all_loops = this->GetLoops(block);
CHECK(!all_loops.empty());
this->Replace(all_loops[0], res);
}

void DyScheduleImpl::CopyTransformAndLoopInfo(
const std::string& block_name, const std::string& block_target_name) {
CINN_NOT_IMPLEMENTED;
auto block = this->GetBlock(block_name);
auto block_target = this->GetBlock(block_target_name);
this->CopyTransformAndLoopInfo(block, block_target);
}

Expr DyScheduleImpl::SampleCategorical(
Expand All @@ -98,7 +282,8 @@ std::vector<Expr> DyScheduleImpl::SamplePerfectTile(
}

Expr DyScheduleImpl::AddUnitLoop(const Expr& block) const {
CINN_NOT_IMPLEMENTED;
auto exprs = module_expr_.GetExprs();
return analyzer::AddUnitLoop(exprs, block);
}

} // namespace ir
Expand Down
31 changes: 29 additions & 2 deletions paddle/cinn/ir/schedule/impl/compute_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,38 @@ void DyScheduleImpl::ReverseComputeAt(const Expr& block,
}

void DyScheduleImpl::ComputeInline(const Expr& schedule_block) {
CINN_NOT_IMPLEMENTED;
CHECK(schedule_block.As<ir::ScheduleBlockRealize>());
Expr root = this->GetRootBlock(schedule_block);
Expr store = CheckComputeInlineValidationAndGetStore(schedule_block, root);
ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(), store);
CHECK(inliner.BodyPatternAllowInline());
// Create a plan that removes the block to be inlined
LeafBlockRemovalPlan remove_plan(
schedule_block, &inliner.src_stmt, &inliner.tgt_stmt);
remove_plan(&root);
inliner(&root);
return;
}

void DyScheduleImpl::ReverseComputeInline(const Expr& schedule_block) {
CINN_NOT_IMPLEMENTED;
Expr root = this->GetRootBlock(schedule_block);
auto exprs =
CheckReverseComputeInlineValidationAndGetExprs(schedule_block, root);
Expr inlined_load = std::get<0>(exprs);
Expr inlined_store = std::get<1>(exprs);
Expr target_store = std::get<2>(exprs);
ReverseComputeInliner inliner(
inlined_store.As<ir::Store>()->tensor.as_tensor_ref(),
inlined_store,
inlined_load,
target_store);
CHECK(inliner.BodyPatternAllowInline());
// Create a plan that removes the block to be inlined
LeafBlockRemovalPlan remove_plan(
schedule_block, &inliner.src_stmt, &inliner.tgt_stmt);
remove_plan(&root);
inliner(&root);
inliner(&root);
}

} // namespace ir
Expand Down
Loading