Commit

[IR] Refactor code in pd_op_to_kernel_pass.cc and reset vlog level (#57054)

* fix merge bug

* fix codestyle
chen2016013 authored Sep 8, 2023
1 parent 14f00dc commit 9358b4b
Showing 2 changed files with 175 additions and 180 deletions.
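
In brief, this commit does three things: it merges the pd.data special case into the pd.feed branch of HandleForSpecialOp; it replaces the three inlined builtin.combine / builtin.slice / builtin.split branches in PdOpLowerToKernelPass with a SpecialOpList set and a DealWithSpecialBuiltinOps helper, factoring the repeated map lookup into GetNewInput; and it adds program dumps at the entry and exit of the pass, guarded by VLOG_IS_ON(2). A minimal compilable sketch of the dispatch shape this introduces (kSpecialOps and LowerOneOp are illustrative names, not Paddle code):

#include <string>
#include <unordered_set>

// Sketch of the pattern this commit introduces: the three inlined
// "builtin.*" branches collapse into one set-membership test plus a
// single shared helper, so the main lowering loop stays flat.
const std::unordered_set<std::string> kSpecialOps = {
    "builtin.combine", "builtin.slice", "builtin.split"};

void LowerOneOp(const std::string& op_name) {
  if (kSpecialOps.count(op_name)) {
    // hand off to one shared helper (DealWithSpecialBuiltinOps in the pass)
    return;
  }
  // ... regular kernel lowering ...
}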
24 changes: 2 additions & 22 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -253,7 +253,8 @@ void HandleForSpecialOp(
variable_list);
}

if (op_name == "pd.feed") {
if (op_name == "pd.feed" || op_name == "pd.data") {
VLOG(6) << "Handle for" << op_name;
auto value = op->result(0);
VLOG(6) << "link feed output to feed in variable" << inner_scope;

@@ -273,27 +274,6 @@
variable_list);
}

if (op_name == "pd.data") {
VLOG(6) << "Handle for pd.data";
auto var_name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();

auto value = op->result(0);

paddle::framework::Variable* var = inner_scope->FindVar(var_name);
PADDLE_ENFORCE(var,
paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", var_name));

AddNewData(value,
var_name,
var,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}

if (op_name == "builtin.combine") {
auto out_value = op->result(0);

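Net change in this file: pd.data loses its dedicated FindVar/AddNewData branch and is handled by the pd.feed path instead. A tiny compilable sketch of the merged predicate (IsFeedLikeOp is an illustrative name, not in the commit):

#include <string>

// True for the two ops that now share the feed handling path.
bool IsFeedLikeOp(const std::string& op_name) {
  return op_name == "pd.feed" || op_name == "pd.data";
}
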
331 changes: 173 additions & 158 deletions paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
@@ -62,6 +62,166 @@ const std::unordered_set<std::string> UnchangeOutputOps = {
"builtin.get_parameter",
"pd.shadow_output"};

const std::unordered_set<std::string> SpecialOpList = {
"builtin.combine", "builtin.slice", "builtin.split"};

ir::OpResult GetNewInput(
const ir::Value cur_in,
const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
const int index,
const std::string op_name) {
PADDLE_ENFORCE_EQ(
map_value_pair.count(cur_in),
true,
phi::errors::PreconditionNotMet(
"[%d]'s input of [%s] op MUST be in map pair", index, op_name));
auto new_in = map_value_pair.at(cur_in);
return new_in;
}
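
// Note: GetNewInput replaces the PADDLE_ENFORCE_EQ + map_value_pair.at()
// pair that each builtin branch used to inline; all three branches below
// call it as GetNewInput(cur_in, *map_value_pair, i, op_item->name()).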

void DealWithSpecialBuiltinOps(
ir::Operation* op_item,
ir::Program* program,
std::unordered_map<ir::Operation*, ir::Operation*>* map_op_pair,
std::unordered_map<ir::Value, ir::OpResult>* map_value_pair,
ir::IrContext* ctx) {
if (op_item->name() == "builtin.combine") {
std::vector<phi::Place> out_places;
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> vec_inner_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
vec_inputs.push_back(new_in);
vec_inner_types.push_back(new_in.type());
if (new_in.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
out_places.push_back(
new_in.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place());
} else if (new_in.type()
.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
out_places.push_back(
new_in.type()
.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
.place());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support dense tensor type for now"));
}
}
}
// Copy op output type
std::vector<ir::Type> op_output_types;
ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.push_back(t1);

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_item->attributes(), op_output_types, op_info);
program->block()->push_back(op);
(*map_op_pair)[op_item] = op;
// only deal with single output
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
(*map_value_pair)[op_item->result(i)] = op->result(i);
}
}
}
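  // The tail above (Operation::Create, block()->push_back, and the
  // map_op_pair / map_value_pair bookkeeping) repeats in the slice and
  // split branches below; only the output-type computation differs.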

if (op_item->name() == "builtin.slice") {
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> op_output_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
vec_inputs.push_back(new_in);
if (new_in.type().isa<ir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
auto index = op_item->attributes()
.at("index")
.dyn_cast<ir::Int32Attribute>()
.data();
op_output_types.push_back(vec_types[index]);
} else {
PADDLE_THROW(
phi::errors::Unimplemented("only support vector type for now"));
}
}
}

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_item->attributes(), op_output_types, op_info);
program->block()->push_back(op);
(*map_op_pair)[op_item] = op;
// only deal with single output
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
(*map_value_pair)[op_item->result(i)] = op->result(i);
}
}
}
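  // builtin.slice projects one element out of a VectorType input: the
  // "index" attribute picks which inner type becomes the single output.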

if (op_item->name() == "builtin.split") {
std::vector<phi::Place> out_places(op_item->num_results());
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> op_output_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name());
vec_inputs.push_back(new_in);

if (new_in.type().isa<ir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
for (uint64_t idx = 0; idx < vec_types.size(); idx++) {
op_output_types.push_back(vec_types[idx]);
}
} else {
PADDLE_THROW(
phi::errors::Unimplemented("only support vector type for now"));
}
}
}

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_item->attributes(), op_output_types, op_info);
program->block()->push_back(op);
(*map_op_pair)[op_item] = op;
// only deal with single output
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
(*map_value_pair)[op_item->result(i)] = op->result(i);
}
}
}
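  // builtin.split is the inverse of combine: every inner type of the
  // VectorType input becomes its own output. As copied here, out_places
  // is sized from num_results() but never read afterwards.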
VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
}
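
// The main loop of PdOpLowerToKernelPass reaches this helper through a
// SpecialOpList.count(op_item->name()) check; see the replacement hunk
// further down in this file's diff.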

bool NeedFallBackCpu(const ir::Operation* op,
const std::string& kernel_fn_name,
const phi::KernelKey& kernel_key) {
@@ -620,6 +780,11 @@ phi::KernelKey GetKernelKey(

std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
phi::Place place) {
if (VLOG_IS_ON(2)) {
std::stringstream ss;
prog->Print(ss);
VLOG(2) << "Program after lowering to kernel pass : " << ss.str();
}
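// VLOG_IS_ON(2) keeps the stringstream dump from being built unless
// verbosity is at least 2 (e.g. run with GLOG_v=2); this guard and the
// matching one at the end of the pass correspond to the "reset vlog
// level" part of the commit title.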
auto program = std::make_unique<ir::Program>(ir::IrContext::Instance());

auto block = prog->block();
@@ -647,163 +812,9 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
continue;
}

if (op_item->name() == "builtin.combine") {
std::vector<phi::Place> out_places;
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> vec_inner_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
true,
phi::errors::PreconditionNotMet(
"[%d]'s input of [%s] op MUST in map pair",
i,
op_item->name()));
auto new_in = map_value_pair.at(cur_in);
vec_inputs.push_back(new_in);
vec_inner_types.push_back(new_in.type());
if (new_in.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
out_places.push_back(
new_in.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place());
} else if (new_in.type()
.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
out_places.push_back(
new_in.type()
.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
.place());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support dense tensor type for now"));
}
}
}
// Copy op output type
std::vector<ir::Type> op_output_types;
ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.push_back(t1);

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_item->attributes(), op_output_types, op_info);
program->block()->push_back(op);
map_op_pair[op_item] = op;
// only deal with single output
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
map_value_pair[op_item->result(i)] = op->result(i);
}
}
VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
continue;
}

if (op_item->name() == "builtin.slice") {
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> op_output_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
true,
phi::errors::PreconditionNotMet(
"[%d]'s input of [%s] op MUST in map pair",
i,
op_item->name()));
auto new_in = map_value_pair.at(cur_in);
vec_inputs.push_back(new_in);

if (new_in.type().isa<ir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
auto index = op_item->attributes()
.at("index")
.dyn_cast<ir::Int32Attribute>()
.data();
op_output_types.push_back(vec_types[index]);
} else {
PADDLE_THROW(
phi::errors::Unimplemented("only support vector type for now"));
}
}
}

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_item->attributes(), op_output_types, op_info);
program->block()->push_back(op);
map_op_pair[op_item] = op;
// only deal with single output
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
map_value_pair[op_item->result(i)] = op->result(i);
}
}
VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
continue;
}

if (op_item->name() == "builtin.split") {
std::vector<phi::Place> out_places(op_item->num_results());
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> op_output_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
true,
phi::errors::PreconditionNotMet(
"[%d]'s input of [%s] op MUST in map pair",
i,
op_item->name()));
auto new_in = map_value_pair.at(cur_in);
vec_inputs.push_back(new_in);

if (new_in.type().isa<ir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
for (uint64_t idx = 0; idx < vec_types.size(); idx++) {
op_output_types.push_back(vec_types[idx]);
}
} else {
PADDLE_THROW(
phi::errors::Unimplemented("only support vector type for now"));
}
}
}

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_item->attributes(), op_output_types, op_info);
program->block()->push_back(op);
map_op_pair[op_item] = op;
// only deal with single output
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
map_value_pair[op_item->result(i)] = op->result(i);
}
}
VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
if (SpecialOpList.count(op_item->name())) {
DealWithSpecialBuiltinOps(
op_item, program.get(), &map_op_pair, &map_value_pair, ctx);
continue;
}
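    // These few lines replace the three inlined branches deleted above;
    // the branch bodies moved into DealWithSpecialBuiltinOps, with the
    // repeated map lookup factored into GetNewInput.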

@@ -1167,7 +1178,11 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
}
}
}

if (VLOG_IS_ON(2)) {
std::stringstream ss1;
program->Print(ss1);
VLOG(2) << "Program after lowering to kernel pass : " << ss1.str();
}
return program;
}
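
For orientation, a hedged sketch of a call into this pass (Paddle includes and program construction omitted; PdOpLowerToKernelPass and its signature come from the hunk above, everything else is illustrative):

// Illustrative only: lower a dialect program for CPU; the pass returns
// ownership of the new kernel-dialect program.
std::unique_ptr<ir::Program> kernel_program =
    PdOpLowerToKernelPass(dialect_program.get(), phi::CPUPlace());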

