Skip to content

Commit

Permalink
[CINN] Replace Old Stage Style Lower and Polyhedral ASTGen in GraphCo…
Browse files Browse the repository at this point in the history
…mpiler to the New Lower and ASTGen (#57454)

Replace the old stage-style lower and polyhedral ASTGen in graph_compiler and op_lowering_impl with the new lower and ASTGen.

TODO: if this PR runs successfully for a while, we will remove the old-style stage & schedule completely and clean up the code in the next PR. In that PR, we will rename LowerToAst to Lower, rename LowerToAstVec to LowerVec, and update the test code that uses them.
  • Loading branch information
zhhsplendid authored Oct 8, 2023
1 parent c13609a commit a498e0b
Show file tree
Hide file tree
Showing 13 changed files with 298 additions and 84 deletions.
83 changes: 83 additions & 0 deletions paddle/cinn/ast_gen_ius/ast_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/cinn/ir/operation.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"

namespace cinn {
namespace ast_gen_ius {
Expand Down Expand Up @@ -84,11 +85,75 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
tensor_group->MarkShareMemBuffer(tensor, init_tensor);
tensor_group->CtrlDepend(tensor, init_tensor);
Expr init_body = ir::Store::Make(init_tensor, init_value, axis_exprs);
// create schedule block itervars, i0,i1...
std::vector<ir::Var> block_vars;
std::vector<ir::Expr> iter_values;
// reduce body and reduce init schedule block should have different objects
// for same axis so we re-create objects
std::vector<Var> axis_vars = common::GenDefaultAxis(axis_len);
for (int i = 0; i < shape.size(); ++i) {
block_vars.push_back(Var(Expr(0),
shape[i],
cinn::UniqName("i" + std::to_string(i)),
/*is_reduce = */ false));
optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars[i]);
axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
iter_values.push_back(Expr(0));
} else {
iter_values.push_back(axis_vars[i]);
}
}
init_body = ir::ScheduleBlockRealize::Make(
iter_values,
ir::ScheduleBlock::Make(
block_vars, {}, {}, reduce_init_name, init_body));

// For the remaining reduce axis, make reduce body
const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
ir::Expr reduce_body =
ConvertReduceBody(tensor->body(), tensor, axis_exprs);
// create schedule block itervars, i0,i1...
std::vector<ir::Var> reduce_block_vars;
std::vector<ir::Expr> reduce_iter_values;
// reduce body and reduce init schedule block should have different objects
// for same axis so we re-create objects
std::vector<Var> reduce_axis_vars = common::GenDefaultAxis(axis_len);
for (int i = 0; i < shape.size(); ++i) {
reduce_block_vars.push_back(Var(Expr(0),
shape[i],
cinn::UniqName("i" + std::to_string(i)),
/*is_reduce = */ false));
reduce_axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
reduce_iter_values.push_back(Expr(0));
} else {
reduce_iter_values.push_back(axis_vars[i]);
}
}
for (int i = 0; i < reduce_axis.size(); ++i) {
int count = shape.size() + i;
reduce_block_vars.push_back(
Var(reduce_axis[i]->lower_bound,
reduce_axis[i]->upper_bound,
cinn::UniqName("i" + std::to_string(count)),
/*is_reduce = */ true));
ir::Var reduce_axis_var = reduce_axis[i];
reduce_axis_var->is_reduce_axis = true;
reduce_iter_values.push_back(reduce_axis_var);
}
for (int i = 0; i < axis.size(); ++i) {
optim::ReplaceVarWithExpr(&reduce_body, axis[i], reduce_block_vars[i]);
}
for (int i = axis.size(); i < reduce_block_vars.size(); ++i) {
optim::ReplaceVarWithExpr(
&reduce_body, reduce_axis[i - axis.size()], reduce_block_vars[i]);
}

reduce_body = ir::ScheduleBlockRealize::Make(
reduce_iter_values,
ir::ScheduleBlock::Make(
reduce_block_vars, {}, {}, tensor->name, reduce_body));
for (int i = static_cast<int>(reduce_axis.size()) - 1; i >= 0; --i) {
reduce_body = ir::For::Make(reduce_axis[i],
reduce_axis[i]->lower_bound,
Expand All @@ -114,6 +179,24 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
return body;
} else {
ir::Expr body = ir::Store::Make(tensor, tensor->body(), axis_exprs);
// create schedule block itervars, i0,i1...
std::vector<ir::Var> block_vars;
std::vector<ir::Expr> iter_values;
std::vector<Var> axis_vars = common::GenDefaultAxis(axis_len);
for (int i = 0; i < shape.size(); ++i) {
block_vars.push_back(Var(
Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false));
optim::ReplaceVarWithExpr(&body, axis[i], block_vars[i]);
axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
iter_values.push_back(Expr(0));
} else {
iter_values.push_back(axis_vars[i]);
}
}
body = ir::ScheduleBlockRealize::Make(
iter_values,
ir::ScheduleBlock::Make(block_vars, {}, {}, tensor->name, body));
for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
ir::Var loop_var = axis[i];
ir::Expr loop_extent = shape[i];
Expand Down
94 changes: 81 additions & 13 deletions paddle/cinn/ast_gen_ius/tensor_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,37 @@
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/poly/stage.h"

namespace cinn {
namespace ast_gen_ius {

TensorGroup::TensorGroup(const std::vector<ir::Tensor>& tensors) {
std::set<ir::Tensor> all_tensors(tensors.begin(), tensors.end());

for (auto& tensor : tensors) {
for (const ir::Tensor& tensor : tensors) {
output_tensor_names_.insert(tensor->name);
std::set<ir::Expr> used_tensors = ir::ir_utils::CollectIRNodes(
tensor->body(), [](const Expr* x) { return x->as_tensor(); });
for (const Expr& x : used_tensors) {
const ir::Tensor to_dep = x.as_tensor_ref();
all_tensors.insert(to_dep);
this->CtrlDepend(tensor, to_dep);
this->Insert(tensor);
}
}

void TensorGroup::ShowLog() const {
VLOG(6) << "Showing log for TensorGroup";
for (auto& p : name_to_tensor_) {
VLOG(6) << "Tensor name = " << p.first << " depends on {";
if (ctrl_dep_.count(p.first)) {
for (auto& dep_name : ctrl_dep_.at(p.first)) {
VLOG(6) << dep_name;
}
}
VLOG(6) << "}";
}
}

for (const ir::Tensor& t : all_tensors) {
name_to_tensor_.insert({t->name, t});
// Builds a TensorGroup from a name->tensor map. Every tensor in the map is
// recorded as an output tensor, and Insert() pulls in the tensors it
// depends on as well.
TensorGroup::TensorGroup(
    const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
  for (const auto& name_and_tensor : tensor_map) {
    const ir::Tensor& out_tensor = name_and_tensor.second;
    output_tensor_names_.insert(out_tensor->name);
    Insert(out_tensor);
  }
}

Expand All @@ -51,7 +62,23 @@ bool TensorGroup::Contain(const std::string& name) const {
}

// Inserts a tensor into the group and, recursively, every tensor its body
// reads from, recording a control dependency for each producer.
void TensorGroup::Insert(const ir::Tensor& tensor) {
  // Register the tensor under its name only if it is not already present.
  // (The previous revision also did an unconditional insert right before
  // this guard, which made the guard dead code — that line is removed.)
  if (!name_to_tensor_.count(tensor->name)) {
    name_to_tensor_.insert({tensor->name, tensor});
  }

  // Collect the tensors referenced by this tensor's body; std::set
  // de-duplicates repeated references to the same producer.
  std::set<ir::Tensor> dep_tensors;
  std::set<ir::Expr> used_tensors = ir::ir_utils::CollectIRNodes(
      tensor->body(), [](const Expr* x) { return x->as_tensor(); });
  for (const Expr& x : used_tensors) {
    const ir::Tensor to_dep = x.as_tensor_ref();
    dep_tensors.insert(to_dep);
    this->CtrlDepend(tensor, to_dep);
  }

  // Recursively insert the dependencies so the whole producer chain ends
  // up in the group.
  for (const ir::Tensor& t : dep_tensors) {
    this->Insert(t);
  }
}

ir::Tensor TensorGroup::Get(const std::string& name) {
Expand All @@ -72,6 +99,8 @@ std::vector<ir::Tensor> TensorGroup::GetGenFuncTopoOrder(
for (const auto& dep_pair : ctrl_dep_) {
const std::unordered_set<std::string>& dep_tensor_names = dep_pair.second;
in_degree[dep_pair.first] = dep_tensor_names.size();
VLOG(6) << "indegree[" << dep_pair.first
<< "] = " << dep_tensor_names.size();
}

std::vector<ir::Tensor> ret;
Expand All @@ -95,7 +124,6 @@ std::vector<ir::Tensor> TensorGroup::GetGenFuncTopoOrder(
while (!node_set.empty()) {
const std::string cur = *(node_set.begin());
node_set.erase(node_set.begin());

if (!input_arg_names.count(cur)) {
ret.push_back(name_to_tensor_[cur]);
}
Expand Down Expand Up @@ -187,5 +215,45 @@ absl::flat_hash_map<std::string, ir::Tensor> TensorGroup::AllocateBuffers() {
return name_to_tensor_;
}

// Re-applies the buffer-sharing relations recorded on each Stage to the
// underlying tensors: any stage tensor that has no buffer yet but is marked
// to share memory gets bound to the buffer of the named peer tensor.
void StageMapShareMemory(const poly::StageMap& stages) {
// Index every stage tensor by name for the lookups below.
absl::flat_hash_map<std::string, ir::_Tensor_*> tensor_map;
for (auto& stage : stages) {
tensor_map[stage.second->tensor()->name] = stage.second->tensor();
}
for (auto& stage : stages) {
// Only act on tensors that still lack a buffer and have at least one
// share-buffer target recorded in their stage meta.
if (!stage.second->tensor()->buffer.defined() &&
!stage.second->meta.tensors_to_share_buffer_with.empty()) {
for (auto& str : stage.second->meta.tensors_to_share_buffer_with) {
// Only bind to peers whose buffer actually exists.
if (tensor_map[str]->buffer.defined()) {
// Bind() may rewrite the shared buffer's shape, so snapshot it and
// restore it afterwards to keep the peer's buffer shape intact.
auto edited_shape = tensor_map[str]->buffer->shape;
stage.second->tensor()->Bind(tensor_map[str]->buffer);
tensor_map[str]->buffer->shape = edited_shape;
VLOG(3) << "Stage Tensor " << stage.second->tensor()->name
<< " bind buffer to " << tensor_map[str]->name << " , "
<< tensor_map[str]->buffer->name;
}
}
}
}
}

// Converts a poly::StageMap into an ast_gen_ius::TensorGroup: collects every
// stage tensor that carries a compute expression as an output tensor of the
// group, then re-applies the stages' buffer-sharing decisions.
//
// Fix: the previous revision also populated a local `reshape_tensors` set
// (tensors whose names end in "_reshape") that was never read; that dead
// code is removed.
TensorGroup ConvertStageMapToTensorGroup(const poly::StageMap& stage_map) {
  std::vector<ir::Tensor> stage_tensors;
  for (auto iter = stage_map.begin(); iter != stage_map.end(); ++iter) {
    // Only tensors with a compute expression participate in lowering.
    if (iter->second->has_expression()) {
      stage_tensors.push_back(ir::Tensor(iter->second->tensor()));
    }
  }

  ast_gen_ius::TensorGroup tensor_group(stage_tensors);
  // Propagate buffer-sharing recorded on the stages before the TensorGroup
  // takes over buffer management.
  StageMapShareMemory(stage_map);
  return tensor_group;
}

} // namespace ast_gen_ius
} // namespace cinn
15 changes: 15 additions & 0 deletions paddle/cinn/ast_gen_ius/tensor_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/poly/stage.h"

namespace cinn {
namespace ast_gen_ius {
Expand All @@ -41,11 +42,21 @@ class TensorGroup {
*/
explicit TensorGroup(const std::vector<ir::Tensor>& tensors);

/**
* Constructor for a TensorGroup, the argument tensors should be output tensor
* arguments of the AST body to be generated. The dependent tensors of the
* output tensors will be collected during construction.
*/
explicit TensorGroup(
const std::unordered_map<std::string, ir::Tensor>& tensor_map);

/**
* Destructor.
*/
~TensorGroup();

void ShowLog() const;

/**
* Returns true if TensorGroup collection contains a tensor with input name.
*/
Expand Down Expand Up @@ -119,5 +130,9 @@ class TensorGroup {
std::unordered_map<std::string, std::string> share_memory_tensor_;
};

// TODO(zhhsplendid): remove stage_map need to change all fcompute CINNValuePack
// we will change it in the next PR
TensorGroup ConvertStageMapToTensorGroup(const poly::StageMap& stage_map);

} // namespace ast_gen_ius
} // namespace cinn
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ TEST_F(TestMultiLevelTiling, Matmul) {
ScheduleBlock(temp_matmul_out_local_temp_buffer)
{
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(undefined:undefined)], _Y[reduce_k(undefined:undefined), j(undefined:undefined)])
read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(0:32)], _Y[reduce_k(0:32), j(undefined:undefined)])
write_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)])
{
temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))]))
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TEST(CodeGenC, module) {
ast_gen_ius::TensorGroup tensor_group({A, B, C});
auto func = lang::LowerToAst("add1", {A, B, C}, &tensor_group);

LOG(INFO) << "Huihuang debug: " << func << std::endl;
LOG(INFO) << "Func to codegen: " << func << std::endl;

builder.AddFunction(func);

Expand Down
24 changes: 15 additions & 9 deletions paddle/cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include "paddle/cinn/utils/enum_string.h"
#include "paddle/cinn/utils/profiler.h"

#include "paddle/cinn/ast_gen_ius/tensor_group.h"

namespace cinn {
namespace hlir {
namespace framework {
Expand Down Expand Up @@ -372,14 +374,17 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(

poly::StageMap stages = C.back();
std::string func_name_prefix = "fn_";
auto funcs = lang::LowerVec(func_name_prefix + node_id,
stages,
all_arg_tensors,
{},
{},
nullptr,
target,
true);

ast_gen_ius::TensorGroup tensor_group =
ast_gen_ius::ConvertStageMapToTensorGroup(stages);
auto funcs = lang::LowerToAstVec(
func_name_prefix + node_id, all_arg_tensors, &tensor_group, target);

VLOG(4) << "Lower op: " << node_id << ", get " << funcs.size()
<< " LoweredFunc:\n";
for (auto fun : funcs) {
VLOG(4) << fun;
}

std::vector<common::CINNValue> schedule_inputs;
for (int i = 0; i < C.size() - 1; ++i) {
Expand Down Expand Up @@ -426,7 +431,8 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body));
#endif
auto temp_buffers = lang::GetTempBuffers(
all_arg_tensors, stages, funcs_after_schedule[i]->body);
all_arg_tensors, tensor_group, funcs_after_schedule[i]->body);

funcs_after_schedule[i]->temp_bufs = temp_buffers;
funcs_after_schedule[i] =
ir::_LoweredFunc_::Make(funcs_after_schedule[i]->name,
Expand Down
17 changes: 9 additions & 8 deletions paddle/cinn/hlir/framework/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/cinn/hlir/framework/op_lowering_impl.h"

#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/op_lowering_util.h"
Expand Down Expand Up @@ -391,16 +392,16 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
}

// 2.Do lower
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("fn_" + node->id(),
tmp_stages,
*op_func_arg_tensors,
{},
{},
nullptr,
this->target_,
true);
ast_gen_ius::TensorGroup tensor_group =
ast_gen_ius::ConvertStageMapToTensorGroup(tmp_stages);
std::vector<ir::LoweredFunc> funcs = lang::LowerToAstVec(
"fn_" + node->id(), *op_func_arg_tensors, {&tensor_group}, this->target_);

VLOG(4) << "Lower op: " << node->op()->name << ", get " << funcs.size()
<< " LoweredFunc:\n";
for (auto fun : funcs) {
VLOG(4) << fun;
}

op_func_arg_tensors->clear();
for (int idx = 0; idx < pack.size() - 1; ++idx) {
Expand Down
Loading

0 comments on commit a498e0b

Please sign in to comment.