[PIR] Refine Build Scope and Context for optional value (PaddlePaddle#57713)

* refine

* fix build scope

* fix bug

* add build phi context

* fix

* fix

* fix

* fix

* fix

* fix
zhangbo9674 authored and Frida-a committed Oct 14, 2023
1 parent b5776ba commit 362f3fb
Showing 4 changed files with 103 additions and 134 deletions.
148 changes: 34 additions & 114 deletions paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc
@@ -141,60 +141,6 @@ const std::unordered_set<std::string> SpecialOps = {"pd_op.feed",
"pd_op.shadow_output",
"pd_op.if"};

void AddNewData(pir::Value value,
std::string name,
paddle::framework::Variable* var,
std::unordered_map<pir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
if (value_2_var_name->count(value) == 0) {
value_2_var_name->emplace(value, name);
}

variable_2_var_name->emplace(var, name);
if (var_name_2_id->count(name) == 0) {
auto id = var_name_2_id->size();
var_name_2_id->emplace(name, id);
variable_list->push_back(var);
}
PADDLE_ENFORCE_EQ(
variable_list->size(),
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
}

void RenameData(pir::Value value,
std::string new_name,
std::string orig_name,
std::unordered_map<pir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id) {
(*value_2_var_name)[value] = new_name;

for (auto kv : (*value_2_var_name)) {
if (kv.second == orig_name) {
(*value_2_var_name)[kv.first] = new_name;
}
}

for (auto kv : (*variable_2_var_name)) {
if (kv.second == orig_name) {
(*variable_2_var_name)[kv.first] = new_name;
}
}

for (auto kv : *(var_name_2_id)) {
if (kv.first == orig_name) {
var_name_2_id->emplace(new_name, kv.second);
}
}
var_name_2_id->erase(orig_name);
}

using VariableNameMap =
std::unordered_map<const paddle::framework::Variable*, std::string>;

@@ -225,13 +171,6 @@ paddle::framework::Variable* CreateVar(
<< value_exe_info->GetScope();
var = value_exe_info->GetScope()->Var(name);
}
// AddNewData(value,
// name,
// var,
// value_2_var_name,
// variable_2_var_name,
// var_name_2_id,
// variable_list);

value_exe_info->Add(value, name);

@@ -246,7 +185,7 @@ void CheckInputVars(
if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand_source(i);
if (value) {
if (IsInvalid(value)) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
value_2_var_name.end(),
@@ -262,6 +201,10 @@
void BuildValue(pir::Value value,
const std::string& var_name_prefix,
paddle::framework::ValueExecutionInfo* value_exe_info) {
if (!IsInvalid(value)) {
return;
}

paddle::framework::Variable* var = nullptr;
auto& value_2_var_name = value_exe_info->GetValue2VarName();
if (value_2_var_name.find(value) != value_2_var_name.end()) {
@@ -319,13 +262,6 @@ void HandleForSpecialOp(
var->GetMutable<phi::DenseTensor>();
auto value = op->result(0);

// AddNewData(value,
// fetch_var_name,
// var,
// value_2_var_name,
// variable_2_var_name,
// var_name_2_id,
// variable_list);
value_exe_info->Add(value, fetch_var_name);
}

@@ -342,13 +278,6 @@ void HandleForSpecialOp(
paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", name));

// AddNewData(value,
// name,
// var,
// value_2_var_name,
// variable_2_var_name,
// var_name_2_id,
// variable_list);
value_exe_info->Add(value, name);
}

@@ -402,12 +331,6 @@ void HandleForSpecialOp(
<< param_name;
}

// RenameData(value,
// param_name,
// orig_name,
// value_2_var_name,
// variable_2_var_name,
// var_name_2_id);
value_exe_info->Rename(value, param_name, orig_name);
}

@@ -424,12 +347,7 @@ void HandleForSpecialOp(
const_cast<paddle::framework::Scope*>(value_exe_info->GetScope()->root())
->Rename(orig_name, var_name);
}
// RenameData(value,
// var_name,
// orig_name,
// value_2_var_name,
// variable_2_var_name,
// var_name_2_id);

value_exe_info->Rename(value, var_name, orig_name);
}

@@ -441,14 +359,6 @@ void HandleForSpecialOp(
.AsString();
auto value = op->result(0);

// paddle::framework::Variable* var =
// value_exe_info->GetScope()->FindVar(param_name); AddNewData(value,
// param_name,
// var,
// value_2_var_name,
// variable_2_var_name,
// var_name_2_id,
// variable_list);
value_exe_info->Add(value, param_name);
}

@@ -545,7 +455,7 @@ void HandleForInplaceOp(pir::Operation* op,

for (size_t i = 0; i < op->num_results(); ++i) {
pir::Value value = op->result(i);
if (value.type().storage() == nullptr) {
if (!IsInvalid(value)) {
continue;
}
std::string value_name = yaml_parser.OutputNames()[i];
@@ -652,23 +562,29 @@ void BuildRuntimeContext(
auto index = op_yaml_info.InputName2Id().at(name);
pir::Value ptr = op->operand_source(index);

if (!IsInvalid(ptr)) {
continue;
}
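// (Editorial note) Unlike BuildPhiContext in pir_adaptor_util.h, which must
// emplace empty placeholders because the phi KernelContext is positional,
// this name-keyed RuntimeContext can simply skip an absent optional input.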

auto legacy_attr_name = op_normalizer.GetLegacyArgName(fluid_op_name, name);
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name;

PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name),
phi::errors::PreconditionNotMet(
"can not find var[%s] in scope", in_var_name));
auto var = inner_scope->FindVar(in_var_name);
std::vector<paddle::framework::Variable*> vec_tmp = {var};
auto legacy_attr_name = op_normalizer.GetLegacyArgName(fluid_op_name, name);

runtime_ctx->inputs[legacy_attr_name].push_back(var);
}

auto& output_name_list = op_yaml_info.OutputNames();
for (size_t i = 0; i < output_name_list.size(); ++i) {
auto name = output_name_list[i];
pir::Value ptr = op->result(i);
auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name);

if (!IsInvalid(ptr)) {
continue;
}

auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackOutput: " << name << "\t" << in_var_name;
@@ -680,7 +596,6 @@
auto var = inner_scope->FindVar(in_var_name);

auto type = ptr.type();
auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name);
if (type.isa<paddle::dialect::AllocatedDenseTensorType>() ||
type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
runtime_ctx->outputs[legacy_arg_name] = {var};
@@ -718,18 +633,21 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(

auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();

// build inputs
for (auto& name : vec_kernel_fn_tensor_params) {
PADDLE_ENFORCE_EQ(
name2id.count(name),
true,
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
pir::Value ptr = op->operand_source(index);
auto legacy_attr_name = op_normalizer.GetLegacyArgName(fluid_op_name, name);

auto in_var_name = name_map.at(ptr);
if (!IsInvalid(ptr)) {
continue;
}

auto legacy_attr_name = op_normalizer.GetLegacyArgName(fluid_op_name, name);
in_name_map[legacy_attr_name].push_back(in_var_name);
in_name_map[legacy_attr_name].push_back(name_map.at(ptr));
}

// build attribute
@@ -805,20 +723,22 @@
}
}

// build outputs
auto& output_name_list = op_yaml_info.OutputNames();
for (size_t i = 0; i < output_name_list.size(); ++i) {
auto name = output_name_list[i];
pir::Value ptr = op->result(i);
auto legacy_arg_name =
op_normalizer.GetLegacyArgName(fluid_op_name, output_name_list[i]);

auto out_var_name = name_map.at(ptr);
if (!IsInvalid(ptr)) {
continue;
}

auto type = ptr.type();
auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name);
if (type.isa<paddle::dialect::AllocatedDenseTensorType>() ||
type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
out_name_map[legacy_arg_name].push_back(out_var_name);
} else if (type.isa<pir::VectorType>()) {
auto var = scope->FindVar(out_var_name);
if (ptr.type().isa<paddle::dialect::AllocatedDenseTensorType>() ||
ptr.type().isa<paddle::dialect::AllocatedSelectedRowsType>()) {
out_name_map[legacy_arg_name].push_back(name_map.at(ptr));
} else if (ptr.type().isa<pir::VectorType>()) {
auto var = scope->FindVar(name_map.at(ptr));
auto var_ref = var->Get<paddle::framework::VariableRefArray>();
for (size_t k = 0; k < var_ref.size(); ++k) {
PADDLE_ENFORCE(variable_2_var_name.count(var_ref[k]),
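Editorial note on the .cc deletions above: the free functions AddNewData and RenameData (and their commented-out call sites) are folded into ValueExecutionInfo::Add and ValueExecutionInfo::Rename. The sketch below is a minimal reconstruction of that consolidation using stand-in types, not the real class; in particular, the Add shown in the diff takes only (value, name) and presumably resolves the Variable internally, while this sketch passes the Variable explicitly.

```cpp
// Editorial sketch only -- not the actual ValueExecutionInfo from
// pir_adaptor_util.h. Stand-in types: ValueId replaces pir::Value,
// Variable replaces paddle::framework::Variable.
#include <cassert>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

struct Variable {};
using ValueId = int;

class ValueExecutionInfoSketch {
 public:
  // Consolidates the removed AddNewData: registers value->name and
  // variable->name, and assigns a stable id per unique variable name.
  void Add(ValueId value, const std::string& name, Variable* var) {
    value_2_var_name_.emplace(value, name);
    variable_2_var_name_.emplace(var, name);
    if (var_name_2_id_.count(name) == 0) {
      var_name_2_id_.emplace(name, static_cast<int>(variable_list_.size()));
      variable_list_.push_back(var);
    }
    assert(variable_list_.size() == var_name_2_id_.size());
  }

  // Consolidates the removed RenameData: rewrites every mapping that still
  // points at orig_name, then moves the name->id entry to new_name.
  void Rename(ValueId value, const std::string& new_name,
              const std::string& orig_name) {
    value_2_var_name_[value] = new_name;
    for (auto& kv : value_2_var_name_)
      if (kv.second == orig_name) kv.second = new_name;
    for (auto& kv : variable_2_var_name_)
      if (kv.second == orig_name) kv.second = new_name;
    auto it = var_name_2_id_.find(orig_name);
    if (it != var_name_2_id_.end()) {
      int id = it->second;          // copy the id before erasing the node
      var_name_2_id_.erase(it);
      var_name_2_id_.emplace(new_name, id);
    }
  }

 private:
  std::unordered_map<ValueId, std::string> value_2_var_name_;
  std::unordered_map<const Variable*, std::string> variable_2_var_name_;
  std::map<std::string, int> var_name_2_id_;
  std::vector<Variable*> variable_list_;
};

int main() {
  ValueExecutionInfoSketch info;
  Variable v1, v2;
  info.Add(/*value=*/1, "x", &v1);
  info.Add(/*value=*/2, "y", &v2);
  info.Rename(/*value=*/2, "y_renamed", "y");
  return 0;
}
```

Centralizing the four maps behind one class is what lets every call site in the diff shrink to value_exe_info->Add(value, name) or value_exe_info->Rename(value, param_name, orig_name).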
65 changes: 45 additions & 20 deletions paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
@@ -113,6 +113,18 @@ class ValueExecutionInfo {
} // namespace paddle

namespace pir {

// NOTE(zhangbo): Some Paddle operators have optional inputs or outputs. In
// PIR, an input or output is treated as absent when the pir::Value itself is
// null or when the type it holds is null.
inline bool IsInvalid(pir::Value value) {
if ((!value) || (!value.type())) {
return false;
}
return true;
}
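// Usage sketch (editorial, not part of the diff): note that, despite its
// name, IsInvalid returns true for *valid* values. Call sites above guard
// optional operands and results like so:
//
//   pir::Value v = op->operand_source(i);
//   if (!IsInvalid(v)) {
//     continue;  // optional input absent: skip, or emplace a placeholder
//   }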

void BuildScope(
const pir::Block& block,
const std::string& var_name_prefix,
@@ -155,20 +167,28 @@ void BuildPhiContext(

auto attr_map = op->attributes();

// EmplaceBackInputs
auto& vec_kernel_fn_tensor_params = op_yaml_info.TensorParams(is_kernel);

auto& name2id = op_yaml_info.InputName2Id();
for (auto& t : vec_kernel_fn_tensor_params) {
PADDLE_ENFORCE_EQ(
name2id.count(t),
true,
phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.InputName2Id().at(t);
pir::Value ptr = op->operand_source(index);
if (!ptr) {
phi::DenseTensor* ptr = nullptr;
OutType in_ptr(ptr);
ctx->EmplaceBackInput(in_ptr);

pir::Value ptr = op->operand_source(op_yaml_info.InputName2Id().at(t));

if (!IsInvalid(ptr)) {
if (op_yaml_info.GetInputType(op_yaml_info.InputName2Id().at(t)) ==
"pir::VectorType<paddle::dialect::DenseTensorType>") {
InListType optional_inputs;
ctx->EmplaceBackInputs(optional_inputs);
} else {
phi::DenseTensor* temp = nullptr;
InType optional_input(temp);
ctx->EmplaceBackInput(optional_input);
}
VLOG(6) << "ctx->EmplaceBackInput : an optioanl input " << t;
continue;
}

@@ -206,6 +226,7 @@
}
}

// EmplaceBackAttributes
auto& vec_kernel_fn_attr_params = op_yaml_info.AttrParams(is_kernel);
for (auto& t : vec_kernel_fn_attr_params) {
if (name2id.count(t)) {
@@ -367,29 +388,33 @@ void BuildPhiContext(
VLOG(6) << "ctx->EmplaceBackAttr: " << t;
}

// TODO(phlrain): use var type instead of op name
// EmplaceBackOutputs
for (size_t i = 0; i < op->num_results(); ++i) {
pir::Value out_ptr = op->result(i);
auto out_type = out_ptr.type();
if (out_type) {
auto& name = name_map.at(out_ptr);
VLOG(6) << "ctx->EmplaceBackOutput: " << name;
} else {
if (!IsInvalid(out_ptr)) {
if (op_yaml_info.GetOutputType(i) ==
"pir::VectorType<paddle::dialect::DenseTensorType>") {
OutListType optional_outputs;
ctx->EmplaceBackOutputs(optional_outputs);
} else {
phi::DenseTensor* temp = nullptr;
OutType optional_output(temp);
ctx->EmplaceBackOutput(optional_output);
}
VLOG(6) << "ctx->EmplaceBackOutput : an optioanl output";
continue;
}
if (!out_type) {
phi::DenseTensor* ptr = nullptr;
OutType out_ptr(ptr);
ctx->EmplaceBackOutput(out_ptr);
} else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {

if (out_ptr.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(inner_scope->FindVar(name_map.at(out_ptr))
->Get<phi::DenseTensor>()))));
} else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
} else if (out_ptr.type()
.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
&(inner_scope->FindVar(name_map.at(out_ptr))
->Get<phi::SelectedRows>()))));
} else if (out_type.isa<pir::VectorType>()) {
} else if (out_ptr.type().isa<pir::VectorType>()) {
OutListType outputs;
auto& variable_array = inner_scope->FindVar(name_map.at(out_ptr))
->Get<paddle::framework::VariableRefArray>();
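Editorial note on the BuildPhiContext changes above: because the phi KernelContext is positional, an absent optional argument cannot simply be skipped; an empty placeholder of the right shape (single tensor vs. tensor list) must be emplaced so later arguments keep their positions. Below is a minimal sketch of that convention with simplified stand-ins; DenseTensor, TensorRef, KernelContext, and the yaml type string are all assumptions, not the real signatures.

```cpp
#include <iostream>
#include <string>
#include <vector>

struct DenseTensor { std::string name; };
using TensorRef = DenseTensor*;  // nullptr means "optional argument absent"

struct KernelContext {
  std::vector<TensorRef> inputs;
  std::vector<std::vector<TensorRef>> input_lists;
  void EmplaceBackInput(TensorRef t) { inputs.push_back(t); }
  void EmplaceBackInputs(std::vector<TensorRef> ts) {
    input_lists.push_back(std::move(ts));
  }
};

// Mirrors the control flow added in BuildPhiContext: when the pir::Value for
// an input is invalid (an optional that was not provided), emplace an empty
// placeholder matching the declared yaml type instead of skipping the slot.
void EmplaceInput(KernelContext* ctx, bool value_is_valid,
                  const std::string& yaml_type, DenseTensor* real_tensor) {
  if (!value_is_valid) {
    if (yaml_type == "pir::VectorType<paddle::dialect::DenseTensorType>") {
      ctx->EmplaceBackInputs({});      // empty list for a missing vector input
    } else {
      ctx->EmplaceBackInput(nullptr);  // null ref for a missing single tensor
    }
    return;
  }
  ctx->EmplaceBackInput(real_tensor);  // input is present: pass it through
}

int main() {
  KernelContext ctx;
  DenseTensor x{"x"};
  EmplaceInput(&ctx, true, "paddle::dialect::DenseTensorType", &x);
  EmplaceInput(&ctx, false, "paddle::dialect::DenseTensorType", nullptr);
  std::cout << "inputs held: " << ctx.inputs.size()
            << ", second absent: " << (ctx.inputs[1] == nullptr) << "\n";
  return 0;
}
```

The output loop in the diff follows the same shape: a missing optional result gets an empty OutListType or a null OutType, while present results are fetched from the scope by name.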