From 2b5cc895fec3f5e9251ab04ad934354a5e157426 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 23 May 2023 02:51:42 +0000 Subject: [PATCH 01/30] add vector type support for program translator --- paddle/fluid/translator/op_translator.cc | 142 +++++++++++++++--- paddle/fluid/translator/op_translator.h | 3 +- paddle/fluid/translator/program_translator.cc | 15 +- paddle/fluid/translator/program_translator.h | 28 +++- paddle/fluid/translator/translate.h | 7 +- paddle/ir/builtin_op.cc | 3 + paddle/ir/builtin_op.h | 35 ++++- paddle/ir/printer.cc | 40 +++-- test/cpp/ir/program_translator_test.cc | 5 +- 9 files changed, 231 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index c6ff3f94125fb5..65102a4eb0b192 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -24,6 +24,8 @@ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/type_translator.h" +#include "paddle/ir/builtin_op.h" +#include "paddle/ir/builtin_type.h" #include "paddle/ir/ir_context.h" #include "paddle/ir/value.h" #include "paddle/phi/core/enforce.h" @@ -84,23 +86,101 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { return op_info; } +inline ir::Operation* InsertSliceOperationForTarget( + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const VariableDefiningInfo& defining_info, + const std::string& arg_name) { + std::string slice_op_name(ir::SliceOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name); + std::unordered_map op_attribute_map = { + {"index", ir::Int32_tAttribute::get(ctx, defining_info.idx_in_vector)}, + }; + ir::VectorType src_vec_type = + defining_info.value.type().dyn_cast(); + ir::Operation* operation = + ir::Operation::create({defining_info.value}, + {src_vec_type[defining_info.idx_in_vector]}, + op_attribute_map, + op_info); + program->InsertOp(operation); + ir::OpResult target_op_result = operation->GetResultByIndex(0); + (*param_map)[arg_name] = VariableDefiningInfo(target_op_result); + return operation; +} + +inline ir::Operation* InsertCombineOperationForTarget( + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const std::vector& args) { + std::string combine_op_name(ir::CombineOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name); + + std::vector src_values; + std::vector types_in_vec; + for (auto arg_name : args) { + auto defining_info = param_map->at(arg_name); + src_values.push_back(defining_info.value); + types_in_vec.push_back(defining_info.value.type()); + } + ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); + ir::Operation* operation = + ir::Operation::create(src_values, {target_vec_type}, {}, op_info); + program->InsertOp(operation); + return operation; +} + inline std::vector GenerateOperationInput( - TranslationContext* param_map, const OpDesc& op_desc) { + ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const OpDesc& op_desc) { std::vector op_inputs = {}; + + // scan all inputs to see if any of them is generated as a vector + // so need an additional `SliceOp` to take it out. for (const auto& n : op_desc.Inputs()) { auto& name = n.first; - VLOG(10) << "[input retriving]" - << "[" << op_desc.Type() << "]" << name; auto& args = n.second; + for (const auto& arg_name : args) { PADDLE_ENFORCE_NE( param_map->count(arg_name), 0, platform::errors::PreconditionNotMet( - "arg %s as input should be exists before prasing %d", + "arg %s.%s as input should be exists before prasing %d", + name, arg_name, op_desc.Type())); - op_inputs.push_back((*param_map)[arg_name]); + auto defining_info = (*param_map)[arg_name]; + if (defining_info.generated_by_vector) { + InsertSliceOperationForTarget( + ctx, param_map, program, defining_info, arg_name); + } + } + } + + for (const auto& n : op_desc.Inputs()) { + auto& name = n.first; + VLOG(10) << "[input retriving]" + << "[" << op_desc.Type() << "]" << name; + auto& args = n.second; + + // if src type is Tensor or a Vector with size <= 1 + if (args.size() <= 1) { + for (const auto& arg_name : args) { + auto defining_info = (*param_map)[arg_name]; + op_inputs.push_back(defining_info.value); + } + + // if src type is Vector , need an additional `CombineOp` to + // assemble them. + } else { + auto* combine_op = + InsertCombineOperationForTarget(ctx, param_map, program, args); + op_inputs.push_back(combine_op->GetResultByIndex(0)); } } return op_inputs; @@ -119,16 +199,39 @@ inline std::tuple GenerateOperationOutput( VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "]" << name; auto& args = n.second; - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name << " " - << var->GetType(); - - ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); - arg_to_idx[arg_name] = op_output_types.size(); - op_output_types.push_back(translated_var_type); + size_t cur_output_idx = op_output_types.size(); + + // if src type is Tensor or a Vector with size <= 1 + if (args.size() <= 1) { + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name << " " << arg_name + << " " << var->GetType(); + + ir::Type translated_var_type = + type_translator[var->GetType()](ctx, *var); + + arg_to_idx[arg_name] = cur_output_idx; + op_output_types.push_back(translated_var_type); + } + + // if src type is Vector + } else { + std::vector types; + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name << " " << arg_name + << " " << var->GetType(); + ir::Type translated_var_type = + type_translator[var->GetType()](ctx, *var); + types.push_back(translated_var_type); + arg_to_idx[arg_name] = cur_output_idx; + } + ir::Type vec_type = ir::VectorType::get(ctx, types); + op_output_types.push_back(vec_type); } } return {op_output_types, arg_to_idx}; @@ -143,12 +246,17 @@ inline void RecordOpResultMapping(TranslationContext* param_map, VLOG(10) << "[output recording]" << "[" << op_desc.Type() << "]" << name; auto& args = n.second; + size_t idx_in_vector = 0; for (const auto& arg_name : args) { auto idx = arg_to_idx.at(arg_name); VLOG(10) << "[output recording]" << "[" << op_desc.Type() << "]" << arg_name << " " << idx; - (*param_map)[arg_name] = operation->GetResultByIndex(idx); + ir::OpResult value = operation->GetResultByIndex(idx); + bool generated_by_vector = value.type().isa(); + (*param_map)[arg_name] = VariableDefiningInfo( + value, generated_by_vector, generated_by_vector ? idx_in_vector : -1); + idx_in_vector++; } } } @@ -157,7 +265,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(param_map, op_desc); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; @@ -193,7 +301,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(param_map, op_desc); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputTypeList op_output_types = {}; auto op_info = LoopkUpOpInfo(ctx, op_desc); diff --git a/paddle/fluid/translator/op_translator.h b/paddle/fluid/translator/op_translator.h index c767f639d534b9..92c03458300257 100644 --- a/paddle/fluid/translator/op_translator.h +++ b/paddle/fluid/translator/op_translator.h @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/translator/program_translator.h" #include "paddle/ir/ir_context.h" #include "paddle/ir/operation.h" #include "paddle/ir/program.h" @@ -28,8 +29,6 @@ namespace paddle { namespace translator { -using TranslationContext = std::unordered_map; - class OpTranslator { public: using ResultIdx = size_t; diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 7cdbef589067ba..7618972f108040 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/translator/op_translator.h" +#include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/attribute.h" #include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_type.h" @@ -38,6 +39,11 @@ ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ctx = ir::IrContext::Instance(); } +const std::unordered_set ProgramTranslator::no_cast_var_names = { + "feed", + "fetch", +}; + void ProgramTranslator::Translate() { PADDLE_ENFORCE_EQ( legacy_program->Size(), @@ -59,19 +65,24 @@ void ProgramTranslator::Translate() { void ProgramTranslator::ExtractParameterFromSingleBlock( const BlockDesc& block) { + auto& type_translator = TypeTranslator::instance(); + for (auto& var : block.AllVars()) { if (!var->Persistable()) continue; if (param_map.count(var->Name()) != 0) continue; + if (no_cast_var_names.count(var->Name()) != 0) continue; std::string get_parameter_op_name(ir::GetParameterOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); std::unordered_map op_attribute_map = { {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, }; + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( - {}, {ir::Float32Type::get(ctx)}, op_attribute_map, op_info); + {}, {translated_var_type}, op_attribute_map, op_info); program->InsertOp(operation); - param_map[var->Name()] = operation->GetResultByIndex(0); + param_map[var->Name()] = + VariableDefiningInfo(operation->GetResultByIndex(0)); VLOG(10) << "[op translated][get parameter]" << operation; program->SetParameter(var->Name(), nullptr); diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index 569b93b06aa6d9..f7fd4e2890ea64 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -27,7 +27,25 @@ namespace paddle { namespace translator { -using TranslationContext = std::unordered_map; +struct VariableDefiningInfo { + VariableDefiningInfo(ir::OpResult value, + bool generated_by_vector = false, + int idx_in_vector = -1) + : value(value), + generated_by_vector(generated_by_vector), + idx_in_vector(idx_in_vector) {} + VariableDefiningInfo() {} + + ir::OpResult value; + + bool generated_by_vector = + false; // true if target variabe is generated by Vector + int idx_in_vector = + -1; // positive if target variabe is generated by Vector +}; + +using TranslationContext = + std::unordered_map; class ProgramTranslator { using ProgramDesc = ::paddle::framework::ProgramDesc; @@ -45,6 +63,14 @@ class ProgramTranslator { TranslationContext param_map; ir::IrContext* ctx; + /// In the legacy program desc, there are two special named varibales: + /// 1. "feed", the input variable of feed op + /// 2. "fetch", the output variable of fetch op + /// However, new feed has no input and new fetch has no output + /// So we don't handle these two vairables when + /// `ExtractParameterFromSingleBlock` + static const std::unordered_set no_cast_var_names; + void ExtractParameterFromSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block); }; diff --git a/paddle/fluid/translator/translate.h b/paddle/fluid/translator/translate.h index aa2571f74c4cf4..8d6b7e7b1bbdc2 100644 --- a/paddle/fluid/translator/translate.h +++ b/paddle/fluid/translator/translate.h @@ -22,10 +22,7 @@ namespace paddle { -using LegacyProgramDesc = ::paddle::framework::ProgramDesc; -using Program = ::ir::Program; - -std::unique_ptr TranslateLegacyProgramToProgram( - const LegacyProgramDesc& legacy_program); +std::unique_ptr<::ir::ProgramProgram> TranslateLegacyProgramToProgram( + const ::paddle::framework::ProgramDesc& legacy_program); } // namespace paddle diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index d000d086b0f4cd..5d618af72c8784 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -21,4 +21,7 @@ const char *GetParameterOp::attributes_name[attributes_num] = { const char *SetParameterOp::attributes_name[attributes_num] = { "parameter_name"}; +const char **CombineOp::attributes_name = nullptr; +const char *SliceOp::attributes_name[attributes_num] = {"index"}; + } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index ca29867ff4a132..ae2c9c768dc5ba 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -22,7 +22,8 @@ namespace ir { /// The built-in Dialect will use this macro to quickly register all built-in /// OPs. /// -#define GET_BUILT_IN_OP_LIST ir::GetParameterOp, ir::SetParameterOp +#define GET_BUILT_IN_OP_LIST \ + ir::GetParameterOp, ir::SetParameterOp, ir::CombineOp, ir::SliceOp /// /// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute, @@ -40,7 +41,7 @@ class GetParameterOp : public ir::Op { }; /// -/// \brief GetParameterOp: SetParameterOp(OpOperand, {StrAttribute, +/// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute, /// StrAttribute}) /// class SetParameterOp : public ir::Op { @@ -54,4 +55,34 @@ class SetParameterOp : public ir::Op { static const char* attributes_name[attributes_num]; }; +/// +/// \brief CombineOp: CombineOp(OpOperand, {StrAttribute, +/// StrAttribute}) +/// +class CombineOp : public ir::Op { + public: + using Op::Op; + + static const char* name() { return "builtin.combine"; } + + static constexpr uint32_t attributes_num = 0; + + static const char** attributes_name; +}; + +/// +/// \brief SliceOp: SliceOp(OpOperand, {StrAttribute, +/// StrAttribute}) +/// +class SliceOp : public ir::Op { + public: + using Op::Op; + + static const char* name() { return "builtin.slice"; } + + static constexpr uint32_t attributes_num = 1; + + static const char* attributes_name[attributes_num]; +}; + } // namespace ir diff --git a/paddle/ir/printer.cc b/paddle/ir/printer.cc index b4d3acb930a342..421dedabb3a100 100644 --- a/paddle/ir/printer.cc +++ b/paddle/ir/printer.cc @@ -28,6 +28,21 @@ namespace ir { namespace { constexpr char newline[] = "\n"; + +template +void PrintInterleave(ForwardIterator begin, + ForwardIterator end, + UnaryFunctor print_func, + NullFunctor between_func) { + if (begin == end) return; + print_func(*begin); + begin++; + for (; begin != end; begin++) { + between_func(); + print_func(*begin); + } +} + } // namespace class Printer { @@ -47,6 +62,15 @@ class Printer { os << "i32"; } else if (type.isa()) { os << "i64"; + } else if (type.isa()) { + os << "vec<"; + auto inner_types = type.dyn_cast().data(); + PrintInterleave( + inner_types.begin(), + inner_types.end(), + [this](ir::Type v) { this->PrintType(v); }, + [this]() { this->os << ","; }); + os << ">"; } else { auto& dialect = type.dialect(); dialect.PrintType(type, os); @@ -77,22 +101,6 @@ class ProgramPrinter : public Printer { } } - template - void PrintInterleave(ForwardIterator begin, - ForwardIterator end, - UnaryFunctor print_func, - NullFunctor between_func) { - if (begin == end) return; - print_func(*begin); - begin++; - for (; begin != end; begin++) { - between_func(); - print_func(*begin); - } - } - void PrintValue(ir::Value v) { const void* key = static_cast(v.impl()); auto ret = aliases.find(key); diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc index e88bbb2cf391fd..ac8141b398f19c 100644 --- a/test/cpp/ir/program_translator_test.cc +++ b/test/cpp/ir/program_translator_test.cc @@ -58,6 +58,7 @@ TEST(PaddleDialectTest, Translator) { auto program = paddle::TranslateLegacyProgramToProgram(p); std::list ops = program->ops(); - EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num()); - VLOG(0) << *program << std::endl; + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + 20); + std::cout << *program << std::endl; } From 489b2cc98cfe77445dd2e03c99ed408bff5018e6 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 23 May 2023 07:25:30 +0000 Subject: [PATCH 02/30] polish --- paddle/fluid/translator/translate.h | 2 +- paddle/fluid/translator/type_translator.cc | 4 ++++ paddle/fluid/translator/type_translator.h | 10 +++++----- paddle/ir/builtin_op.h | 6 ++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/translator/translate.h b/paddle/fluid/translator/translate.h index 8d6b7e7b1bbdc2..9cf013ea2bca36 100644 --- a/paddle/fluid/translator/translate.h +++ b/paddle/fluid/translator/translate.h @@ -22,7 +22,7 @@ namespace paddle { -std::unique_ptr<::ir::ProgramProgram> TranslateLegacyProgramToProgram( +std::unique_ptr<::ir::Program> TranslateLegacyProgramToProgram( const ::paddle::framework::ProgramDesc& legacy_program); } // namespace paddle diff --git a/paddle/fluid/translator/type_translator.cc b/paddle/fluid/translator/type_translator.cc index 9792a4b8537cef..e0c31913ffefe2 100644 --- a/paddle/fluid/translator/type_translator.cc +++ b/paddle/fluid/translator/type_translator.cc @@ -22,6 +22,10 @@ namespace paddle { namespace translator { +using OpDesc = paddle::framework::OpDesc; +using BlockDesc = paddle::framework::BlockDesc; +using VarDesc = paddle::framework::VarDesc; +using VarType = paddle::framework::proto::VarType; using DenseTensorType = paddle::dialect::DenseTensorType; using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; diff --git a/paddle/fluid/translator/type_translator.h b/paddle/fluid/translator/type_translator.h index b16c1a222a5342..707b8913f3385d 100644 --- a/paddle/fluid/translator/type_translator.h +++ b/paddle/fluid/translator/type_translator.h @@ -27,13 +27,13 @@ namespace paddle { namespace translator { -using OpDesc = paddle::framework::OpDesc; -using BlockDesc = paddle::framework::BlockDesc; -using VarDesc = paddle::framework::VarDesc; -using VarType = paddle::framework::proto::VarType; -using TypeTranslateFn = std::function; +using TypeTranslateFn = + std::function; class TypeTranslator { + public: + using VarType = paddle::framework::proto::VarType; + private: TypeTranslator(); // Disallow instantiation outside of the class. std::unordered_map handlers; diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index ae2c9c768dc5ba..8d4485c988d84c 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -56,8 +56,7 @@ class SetParameterOp : public ir::Op { }; /// -/// \brief CombineOp: CombineOp(OpOperand, {StrAttribute, -/// StrAttribute}) +/// \brief CombineOp: CombineOp(OpOperand) /// class CombineOp : public ir::Op { public: @@ -71,8 +70,7 @@ class CombineOp : public ir::Op { }; /// -/// \brief SliceOp: SliceOp(OpOperand, {StrAttribute, -/// StrAttribute}) +/// \brief SliceOp: SliceOp(OpOperand) /// class SliceOp : public ir::Op { public: From 324d8972dc39771fce65f092ac558585ece7bc0a Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 23 May 2023 12:13:04 +0000 Subject: [PATCH 03/30] support basic attribute type --- .../fluid/translator/attribute_translator.cc | 60 +++++++++++++++++++ .../fluid/translator/attribute_translator.h | 48 +++++++++++++++ paddle/fluid/translator/op_translator.cc | 48 ++++++++++++++- 3 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/translator/attribute_translator.cc create mode 100644 paddle/fluid/translator/attribute_translator.h diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc new file mode 100644 index 00000000000000..c2e6f24f39e478 --- /dev/null +++ b/paddle/fluid/translator/attribute_translator.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/translator/attribute_translator.h" + +#include +#include + +#include "paddle/utils/variant.h" + +namespace paddle { +namespace translator { + +class AttributeVisitor { + public: + ir::IrContext* ctx; + AttributeVisitor() { ctx = ir::IrContext::Instance(); } + ~AttributeVisitor() {} + + public: + ir::Attribute operator()(int i) { return ir::Int64_tAttribute::get(ctx, i); } + + ir::Attribute operator()(float f) { return ir::FloatAttribute::get(ctx, f); } + + ir::Attribute operator()(bool b) { return ir::BoolAttribute::get(ctx, b); } + + ir::Attribute operator()(double d) { + return ir::DoubleAttribute::get(ctx, d); + } + + ir::Attribute operator()(std::string str) { + return ir::StrAttribute::get(ctx, str); + } + + template + ir::Attribute operator()(T attr) { + return ir::Attribute(nullptr); + } +}; + +AttributeTranslator::AttributeTranslator() { visitor = new AttributeVisitor(); } + +ir::Attribute AttributeTranslator::operator[]( + const framework::Attribute& attr) { + return paddle::visit(*visitor, attr); +} + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h new file mode 100644 index 00000000000000..9e7117d5db6412 --- /dev/null +++ b/paddle/fluid/translator/attribute_translator.h @@ -0,0 +1,48 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/ir/attribute.h" +#include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/ir_context.h" + +#pragma once + +namespace paddle { +namespace translator { + +class AttributeVisitor; + +class AttributeTranslator { + private: + AttributeTranslator(); + AttributeVisitor* visitor; + + public: + AttributeTranslator(const AttributeTranslator&) = delete; + AttributeTranslator& operator=(const AttributeTranslator&) = delete; + AttributeTranslator(AttributeTranslator&&) = delete; + AttributeTranslator& operator=(AttributeTranslator&&) = delete; + + static auto& instance() { + static AttributeTranslator attribute_translator; + return attribute_translator; + } + + ir::Attribute operator[](const framework::Attribute& attr); +}; + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 65102a4eb0b192..85ad10f73db37b 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -22,11 +22,13 @@ #include #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/translator/attribute_translator.h" #include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_type.h" #include "paddle/ir/ir_context.h" +#include "paddle/ir/operation.h" #include "paddle/ir/value.h" #include "paddle/phi/core/enforce.h" @@ -237,6 +239,40 @@ inline std::tuple GenerateOperationOutput( return {op_output_types, arg_to_idx}; } +inline ir::AttributeMap TranslateOpAttribute(const OpDesc& op_desc) { + auto& attribute_translator = AttributeTranslator::instance(); + ir::AttributeMap attribute_map = {}; + for (auto attr_in_op_desc : op_desc.GetAttrMap()) { + const auto& attr_name = attr_in_op_desc.first; + const auto& attr_value = attr_in_op_desc.second; + ir::Attribute new_attr = attribute_translator[attr_value]; + attribute_map[attr_name] = new_attr; + if (!new_attr) { + VLOG(0) << "empty attribute in " << op_desc.Type() + << " name: " << attr_name; + } else { + VLOG(10) << "new attribute in " << op_desc.Type() + << " name: " << attr_name << " " << new_attr.storage(); + } + } + + for (auto attr_in_op_desc : op_desc.GetRuntimeAttrMap()) { + const auto& attr_name = attr_in_op_desc.first; + const auto& attr_value = attr_in_op_desc.second; + ir::Attribute new_attr = attribute_translator[attr_value]; + attribute_map[attr_name] = new_attr; + if (!new_attr) { + VLOG(0) << "empty runtime attribute in " << op_desc.Type() + << " name: " << attr_name; + } else { + VLOG(10) << "new runtime attribute in " << op_desc.Type() + << " name: " << attr_name << " " << new_attr.storage(); + } + } + + return std::move(attribute_map); +} + inline void RecordOpResultMapping(TranslationContext* param_map, const OpDesc& op_desc, ir::Operation* operation, @@ -271,8 +307,10 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types = {}; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto attribute_map = TranslateOpAttribute(op_desc); + ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); program->InsertOp(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); @@ -289,8 +327,10 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types = {}; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto attribute_map = TranslateOpAttribute(op_desc); + ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); program->InsertOp(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); @@ -305,8 +345,10 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types = {}; auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto attribute_map = TranslateOpAttribute(op_desc); + ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); program->InsertOp(operation); return operation; From f9e17d695f70b80525486093f0fd675c5d0898e6 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 23 May 2023 12:42:38 +0000 Subject: [PATCH 04/30] resolve conflicts --- paddle/ir/builtin_dialect.cc | 5 ++++- paddle/ir/builtin_op.cc | 7 +++++++ paddle/ir/builtin_op.h | 8 +++++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/paddle/ir/builtin_dialect.cc b/paddle/ir/builtin_dialect.cc index 9c8cacf7bff94a..32f8750b710941 100644 --- a/paddle/ir/builtin_dialect.cc +++ b/paddle/ir/builtin_dialect.cc @@ -44,7 +44,10 @@ void BuiltinDialect::initialize() { ir::Int64_tAttribute, ir::ArrayAttribute>(); - RegisterOps(); + RegisterOps(); } } // namespace ir diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index 1f4ffaaa8c3510..1f1331c7396f03 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -59,6 +59,13 @@ void SetParameterOp::verify(const std::vector &inputs, } const char **CombineOp::attributes_name = nullptr; +void CombineOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) {} + const char *SliceOp::attributes_name[attributes_num] = {"index"}; +void SliceOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) {} } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index e2e855889cec78..5751632df58652 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -36,7 +36,7 @@ class GetParameterOp : public ir::Op { /// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute, /// StrAttribute}) /// -class SliceOp : public ir::Op { +class SetParameterOp : public ir::Op { public: using Op::Op; static const char *name() { return "builtin.set_parameter"; } @@ -59,6 +59,9 @@ class CombineOp : public ir::Op { static constexpr uint32_t attributes_num = 0; static const char **attributes_name; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); }; /// @@ -73,6 +76,9 @@ class SliceOp : public ir::Op { static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); }; } // namespace ir From 9cd463f0cfce6c25430d8d884638bf1b2ae4bb14 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 03:49:36 +0000 Subject: [PATCH 05/30] add verify for combine/slice and unittests --- paddle/ir/builtin_op.cc | 98 ++++++++++++++++++++++++++++++++-- paddle/ir/builtin_op.h | 2 +- paddle/ir/type.cc | 7 ++- paddle/ir/type.h | 2 + test/cpp/ir/ir_program_test.cc | 59 ++++++++++++++++++++ 5 files changed, 163 insertions(+), 5 deletions(-) diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index 1f1331c7396f03..63bfc2196dca35 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -13,7 +13,10 @@ // limitations under the License. #include "paddle/ir/builtin_op.h" + #include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/builtin_type.h" +#include "paddle/phi/core/enforce.h" namespace ir { const char *GetParameterOp::attributes_name[attributes_num] = { @@ -58,14 +61,103 @@ void SetParameterOp::verify(const std::vector &inputs, } } -const char **CombineOp::attributes_name = nullptr; void CombineOp::verify(const std::vector &inputs, const std::vector &outputs, - const ir::AttributeMap &attributes) {} + const ir::AttributeMap &attributes) { + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + // outputs[0].type == Vector + PADDLE_ENFORCE(outputs[0].isa(), + phi::errors::PreconditionNotMet( + "The type %s of outputs[0] must be equal to VectorType.", + outputs[0])); + ir::VectorType output_type = outputs[0].dyn_cast(); + // inputs.size() == outputs[0].size() + PADDLE_ENFORCE_EQ( + output_type.size(), + inputs.size(), + phi::errors::PreconditionNotMet( + "The size %d of outputs[0] must be equal to size %d of inputs.", + output_type.size(), + inputs.size())); + + // forall i in inputs.size(): inputs[i].type == outputs[0][i].type + for (size_t i = 0; i < inputs.size(); i++) { + PADDLE_ENFORCE_EQ( + output_type[i], + inputs[i].type(), + phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be " + "equal to type %s of inputs[%d].", + output_type[i], + i, + inputs[i].type(), + i)); + } +} const char *SliceOp::attributes_name[attributes_num] = {"index"}; void SliceOp::verify(const std::vector &inputs, const std::vector &outputs, - const ir::AttributeMap &attributes) {} + const ir::AttributeMap &attributes) { + // inputs.size() == 1 + PADDLE_ENFORCE_EQ( + inputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", inputs.size())); + + // inputs[0].type == Vector + PADDLE_ENFORCE(inputs[0].type().isa(), + phi::errors::PreconditionNotMet( + "The type %s of inputs[0] must be equal to VectorType.", + inputs[0].type())); + ir::VectorType input_type = inputs[0].type().dyn_cast(); + + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + + // attributes contains index: Int32 + PADDLE_ENFORCE_NE( + attributes.count("index"), + 0, + phi::errors::PreconditionNotMet("The attributes must contains index.")); + const ir::Attribute &attr = attributes.at("index"); + PADDLE_ENFORCE( + attr.isa(), + phi::errors::PreconditionNotMet("The attribute index must be INT32.")); + auto index = attr.dyn_cast().data(); + + // index >= 0 and < inputs[0].size() + PADDLE_ENFORCE_GE( + index, + 0, + phi::errors::PreconditionNotMet( + "The index %d must be greater or equal than 0.", index)); + PADDLE_ENFORCE_LT( + index, + input_type.size(), + phi::errors::PreconditionNotMet( + "The index %d must be less or equal than size %d of inputs[0].", + index, + input_type.size())); + + // inputs[index].type == outputs[0].type + PADDLE_ENFORCE_EQ( + input_type[index], + outputs[0], + phi::errors::PreconditionNotMet( + "The type %s of inputs[%d] must be equal to type %s of outputs[0].", + input_type[index], + index, + outputs[0])); +} } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index 5751632df58652..d1d2c20b5725db 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -58,7 +58,7 @@ class CombineOp : public ir::Op { static constexpr uint32_t attributes_num = 0; - static const char **attributes_name; + static constexpr const char **attributes_name = nullptr; static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); diff --git a/paddle/ir/type.cc b/paddle/ir/type.cc index bde3194f8fd414..e9c24672e5b5e4 100644 --- a/paddle/ir/type.cc +++ b/paddle/ir/type.cc @@ -16,6 +16,11 @@ #include "paddle/ir/dialect.h" namespace ir { -IrContext *Type::ir_context() const { return dialect().ir_context(); } +IrContext* Type::ir_context() const { return dialect().ir_context(); } + +std::ostream& operator<<(std::ostream& os, Type type) { + type.print(os); + return os; +} } // namespace ir diff --git a/paddle/ir/type.h b/paddle/ir/type.h index fce17db82ebf5a..89d153c089476e 100644 --- a/paddle/ir/type.h +++ b/paddle/ir/type.h @@ -89,6 +89,8 @@ class Type { const Storage *storage_{nullptr}; }; +std::ostream &operator<<(std::ostream &os, Type type); + } // namespace ir namespace std { diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index c430d9b320b406..711a40407de686 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -211,3 +211,62 @@ TEST(program_test, program) { EXPECT_EQ(ops.size() == 4, true); EXPECT_EQ(program.parameters_num() == 3, true); } + +TEST(program_test, slice_combine_test) { + // (1) Init environment. + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + // (2) Create an empty program object + ir::Program program; + // ir::Program *program = new ir::Program(); + EXPECT_EQ(program.ops().size() == 0, true); + + // (3) Create a float32 DenseTensor Parameter and save into Program + ir::Type fp32_dtype = ir::Float32Type::get(ctx); + + // (4) Def a = GetParameterOp("a") + std::string op1_name = ir::GetParameterOp::name(); + ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); + std::unordered_map op1_attribute{ + {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; + ir::Operation *op1 = + ir::Operation::create({}, {fp32_dtype}, op1_attribute, op1_info); + program.InsertOp(op1); + + // (5) Def b = GetParameterOp("b") + std::string op2_name = std::string(ir::GetParameterOp::name()); + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); + std::unordered_map op2_attribute{ + {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; + ir::Operation *op2 = + ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info); + program.InsertOp(op2); + + std::string combine_op_name = std::string(ir::CombineOp::name()); + + ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name); + + ir::Type output_type = + ir::VectorType::get(ctx, std::vector({fp32_dtype, fp32_dtype})); + ir::Operation *combine_op = ir::Operation::create( + {op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, + {output_type}, + {}, + combine_op_info); + program.InsertOp(combine_op); + + std::string slice_op_name = std::string(ir::SliceOp::name()); + ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); + ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); + ir::Operation *slice_op = + ir::Operation::create({combine_op->GetResultByIndex(0)}, + {fp32_dtype}, + {{"index", index_attr}}, + slice_op_info); + program.InsertOp(slice_op); + + // (8) Traverse Program + std::list ops = program.ops(); + EXPECT_EQ(ops.size() == 4, true); +} From 6d17079c47c48c4afcc140f85644eb9b190ffa3e Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 03:51:56 +0000 Subject: [PATCH 06/30] polish --- test/cpp/ir/ir_program_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index 711a40407de686..7c6c9acaf52c66 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -243,10 +243,9 @@ TEST(program_test, slice_combine_test) { ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info); program.InsertOp(op2); + // (6) Def combine_op = CombineOp("a", "b") std::string combine_op_name = std::string(ir::CombineOp::name()); - ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name); - ir::Type output_type = ir::VectorType::get(ctx, std::vector({fp32_dtype, fp32_dtype})); ir::Operation *combine_op = ir::Operation::create( @@ -256,6 +255,7 @@ TEST(program_test, slice_combine_test) { combine_op_info); program.InsertOp(combine_op); + // (7) Def slice_op = SliceOp(combine_op, 0) std::string slice_op_name = std::string(ir::SliceOp::name()); ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); From b45df9c30cdf17cca9617ee8c35f8acc02a8548a Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 08:59:31 +0000 Subject: [PATCH 07/30] support more type in attribute translator --- paddle/fluid/dialect/pd_attribute.cc | 4 +- paddle/fluid/dialect/pd_attribute.h | 2 +- paddle/fluid/dialect/pd_attribute_storage.h | 4 +- .../fluid/translator/attribute_translator.cc | 70 +++++++++++++++++++ paddle/fluid/translator/op_translator.cc | 2 + paddle/fluid/translator/program_translator.cc | 2 +- test/cpp/ir/program_translator_test.cc | 20 +++--- 7 files changed, 89 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/dialect/pd_attribute.cc b/paddle/fluid/dialect/pd_attribute.cc index 49e6865160dd74..7d0cd5c44d0f09 100644 --- a/paddle/fluid/dialect/pd_attribute.cc +++ b/paddle/fluid/dialect/pd_attribute.cc @@ -18,7 +18,9 @@ namespace paddle { namespace dialect { phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); } -phi::Scalar ScalarAttribute::data() const { return storage()->GetAsKey(); } +paddle::experimental::Scalar ScalarAttribute::data() const { + return storage()->GetAsKey(); +} phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); } diff --git a/paddle/fluid/dialect/pd_attribute.h b/paddle/fluid/dialect/pd_attribute.h index 75eed82dfc4b35..a4a48da218849f 100644 --- a/paddle/fluid/dialect/pd_attribute.h +++ b/paddle/fluid/dialect/pd_attribute.h @@ -47,7 +47,7 @@ class ScalarAttribute : public ir::Attribute { return storage() < right.storage(); } - phi::Scalar data() const; + paddle::experimental::Scalar data() const; }; class DataTypeAttribute : public ir::Attribute { diff --git a/paddle/fluid/dialect/pd_attribute_storage.h b/paddle/fluid/dialect/pd_attribute_storage.h index 352dcc8b0e4343..2e066413af7d69 100644 --- a/paddle/fluid/dialect/pd_attribute_storage.h +++ b/paddle/fluid/dialect/pd_attribute_storage.h @@ -55,7 +55,7 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage { }; struct ScalarAttributeStorage : public ir::AttributeStorage { - using ParamKey = phi::Scalar; + using ParamKey = paddle::experimental::Scalar; explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; } @@ -73,7 +73,7 @@ struct ScalarAttributeStorage : public ir::AttributeStorage { ParamKey GetAsKey() const { return ParamKey(data_); } private: - phi::Scalar data_; + paddle::experimental::Scalar data_; }; struct DataTypeAttributeStorage : public ir::AttributeStorage { diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc index c2e6f24f39e478..ababb7e3b50609 100644 --- a/paddle/fluid/translator/attribute_translator.cc +++ b/paddle/fluid/translator/attribute_translator.cc @@ -17,6 +17,8 @@ #include #include +#include "paddle/fluid/dialect/pd_attribute.h" +#include "paddle/phi/common/scalar.h" #include "paddle/utils/variant.h" namespace paddle { @@ -43,6 +45,74 @@ class AttributeVisitor { return ir::StrAttribute::get(ctx, str); } + ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { + return paddle::dialect::ScalarAttribute::get(ctx, scalar); + } + + ir::Attribute operator()(const std::vector& strs) { + std::vector attrs; + attrs.reserve(strs.size()); + for (const auto& v : strs) { + attrs.push_back(ir::StrAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& fs) { + std::vector attrs; + attrs.reserve(fs.size()); + for (const auto& v : fs) { + attrs.push_back(ir::FloatAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& is) { + std::vector attrs; + attrs.reserve(is.size()); + for (const auto& v : is) { + attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& bs) { + std::vector attrs; + attrs.reserve(bs.size()); + for (const auto& v : bs) { + attrs.push_back(ir::BoolAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& i64s) { + std::vector attrs; + attrs.reserve(i64s.size()); + for (const auto& v : i64s) { + attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& ds) { + std::vector attrs; + attrs.reserve(ds.size()); + for (const auto& v : ds) { + attrs.push_back(ir::DoubleAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()( + const std::vector& ss) { + std::vector attrs; + attrs.reserve(ss.size()); + for (const auto& v : ss) { + attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + template ir::Attribute operator()(T attr) { return ir::Attribute(nullptr); diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 85ad10f73db37b..adf89cddc0bfdc 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -245,6 +245,8 @@ inline ir::AttributeMap TranslateOpAttribute(const OpDesc& op_desc) { for (auto attr_in_op_desc : op_desc.GetAttrMap()) { const auto& attr_name = attr_in_op_desc.first; const auto& attr_value = attr_in_op_desc.second; + VLOG(0) << "attribute in " << op_desc.Type() << " name: " << attr_name + << " " << attr_value.index(); ir::Attribute new_attr = attribute_translator[attr_value]; attribute_map[attr_name] = new_attr; if (!new_attr) { diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 7618972f108040..c98fc801a81708 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -75,7 +75,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( std::string get_parameter_op_name(ir::GetParameterOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); std::unordered_map op_attribute_map = { - {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, + {"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, }; ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc index 5ae7b0b9b31e7e..d7a35c17cbe878 100644 --- a/test/cpp/ir/program_translator_test.cc +++ b/test/cpp/ir/program_translator_test.cc @@ -48,18 +48,18 @@ ProgramDesc load_from_file(const std::string &file_name) { TEST(PaddleDialectTest, Translator) { LOG(WARNING) << "TODO"; - // auto p = load_from_file("restnet50_main.prog"); - // std::cout << p.Size() << std::endl; + auto p = load_from_file("restnet50_main.prog"); + std::cout << p.Size() << std::endl; - // EXPECT_EQ(p.Size(), 1u); + EXPECT_EQ(p.Size(), 1u); - // ir::IrContext *ctx = ir::IrContext::Instance(); - // ctx->GetOrRegisterDialect(); - // ctx->GetOrRegisterDialect(); - // auto program = paddle::TranslateLegacyProgramToProgram(p); + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); - // std::list ops = program->ops(); + std::list ops = program->ops(); // ops.size() = op size in BlockDesc + get_parameter_op + combine op - // EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + - // 20); std::cout << *program << std::endl; + EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + 20); + std::cout << *program << std::endl; } From 5940ac4afd381062f37536d70dbaf55e37eb77dd Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 09:03:18 +0000 Subject: [PATCH 08/30] modify by reviews --- paddle/fluid/translator/op_translator.cc | 4 +- paddle/fluid/translator/program_translator.h | 4 +- paddle/ir/builtin_dialect.cc | 43 +++++++++----------- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index adf89cddc0bfdc..ec7b53cdaf3976 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -122,7 +122,7 @@ inline ir::Operation* InsertCombineOperationForTarget( std::vector src_values; std::vector types_in_vec; - for (auto arg_name : args) { + for (const auto& arg_name : args) { auto defining_info = param_map->at(arg_name); src_values.push_back(defining_info.value); types_in_vec.push_back(defining_info.value.type()); @@ -139,7 +139,7 @@ inline std::vector GenerateOperationInput( TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - std::vector op_inputs = {}; + std::vector op_inputs; // scan all inputs to see if any of them is generated as a vector // so need an additional `SliceOp` to take it out. diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index f7fd4e2890ea64..e185d9d0c33881 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -39,9 +39,9 @@ struct VariableDefiningInfo { ir::OpResult value; bool generated_by_vector = - false; // true if target variabe is generated by Vector + false; // true if target variable is generated by Vector int idx_in_vector = - -1; // positive if target variabe is generated by Vector + -1; // positive if target variable is generated by Vector }; using TranslationContext = diff --git a/paddle/ir/builtin_dialect.cc b/paddle/ir/builtin_dialect.cc index 32f8750b710941..955a6ebdce8c8f 100644 --- a/paddle/ir/builtin_dialect.cc +++ b/paddle/ir/builtin_dialect.cc @@ -18,36 +18,33 @@ #include "paddle/ir/builtin_type.h" namespace ir { -BuiltinDialect::BuiltinDialect(ir::IrContext *context) - : ir::Dialect(name(), context, ir::TypeId::get()) { +BuiltinDialect::BuiltinDialect(IrContext *context) + : Dialect(name(), context, TypeId::get()) { initialize(); } void BuiltinDialect::initialize() { // Register all built-in types defined in builtin_type.h. - RegisterTypes(); + RegisterTypes(); - RegisterAttributes(); + RegisterAttributes(); - RegisterOps(); + RegisterOps(); } } // namespace ir From b328cbb00ada67e3e71f50a766dd9c25565dd6e4 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 10:16:33 +0000 Subject: [PATCH 09/30] fix merge mistakes --- paddle/fluid/translator/program_translator.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index 851449580de6a6..e185d9d0c33881 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -71,14 +71,6 @@ class ProgramTranslator { /// `ExtractParameterFromSingleBlock` static const std::unordered_set no_cast_var_names; - /// In the legacy program desc, there are two special named varibales: - /// 1. "feed", the input variable of feed op - /// 2. "fetch", the output variable of fetch op - /// However, new feed has no input and new fetch has no output - /// So we don't handle these two vairables when - /// `ExtractParameterFromSingleBlock` - static const std::unordered_set no_cast_var_names; - void ExtractParameterFromSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block); }; From e35a288002b45067ca7db91c3057ae9a7e5195f2 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 26 May 2023 05:59:45 +0000 Subject: [PATCH 10/30] refine code --- CMakeLists.txt | 2 +- paddle/fluid/dialect/CMakeLists.txt | 4 +- paddle/fluid/dialect/legacy_pd_op.h | 13 +- paddle/fluid/dialect/op_gen.py | 204 ++++++++++++++++++++++++++-- paddle/fluid/dialect/pd_dialect.cc | 6 +- paddle/fluid/dialect/utils.h | 40 ++++++ 6 files changed, 243 insertions(+), 26 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0bf93f0ed9d26f..bb6432e4175f26 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,7 +307,7 @@ option(WITH_CUDNN_FRONTEND "Compile with CUDNN Frontend API support (experimental)" OFF) option(WITH_CUDNN_DSO "Compile PaddlePaddle with cuDNN dynamic-link libraries" OFF) -option(WITH_NEWIR "Compile PaddlePaddle with NEWIR" OFF) +option(WITH_NEWIR "Compile PaddlePaddle with NEWIR" ON) if(WITH_RECORD_BUILDTIME) set_property( diff --git a/paddle/fluid/dialect/CMakeLists.txt b/paddle/fluid/dialect/CMakeLists.txt index 8130b75f637cf9..3f17ac21724c66 100644 --- a/paddle/fluid/dialect/CMakeLists.txt +++ b/paddle/fluid/dialect/CMakeLists.txt @@ -8,13 +8,13 @@ set(op_forward_yaml_file1 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml ) set(op_forward_yaml_file2 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/static_ops.parsed.yaml + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml ) set(op_backward_yaml_file1 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml ) set(op_backward_yaml_file2 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/static_backward.parsed.yaml + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ) set(op_yaml_files ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2} diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h index 21be24720dc3fa..60a0a4b0871826 100644 --- a/paddle/fluid/dialect/legacy_pd_op.h +++ b/paddle/fluid/dialect/legacy_pd_op.h @@ -36,9 +36,10 @@ namespace dialect { // TODO(zhangbo): As operators are supplemented and defined, they are gradually // removed. -REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d -REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed -REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: batch_norm +REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d +REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed +// REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: +// batch_norm REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_ REIGSTER_EMPTY_OP(elementwise_add, ElementwiseAddOp); // To be customized: add (elementwise_add) @@ -74,10 +75,10 @@ REIGSTER_EMPTY_OP( FlattenContiguousRangeGradOp); // flatten_grad // (flatten_contiguous_range_grad) REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); // To be customized: pool2d_grad -REIGSTER_EMPTY_OP(batch_norm_grad, - BatchNormGradOp); // To be customized: batch_norm_grad +// REIGSTER_EMPTY_OP(batch_norm_grad, +// BatchNormGradOp); // To be customized: batch_norm_grad REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad -REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum) +// REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum) REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 } // namespace dialect diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 94930d0f133104..6abef860838c32 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -29,7 +29,10 @@ {op_declare} #else +#include + #include "paddle/ir/op_base.h" +#include "paddle/fluid/dialect/utils.h" {input} #endif @@ -45,6 +48,9 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ static const char *name() {{ return "{dialect_op_name}"; }} {attribute_declare} static constexpr uint32_t attributes_num = {attribute_num}; + static std::vector inputs_info(); + static std::vector outputs_info(); + static std::vector attributes_info(); static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); {get_inputs_and_outputs} }}; @@ -79,6 +85,37 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }}; """ +# get op input info +OP_INPUT_INFO_TEMPLATE = """ +std::vector {op_name}::inputs_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_INPUT_INFO_TEMPLATE = ( + """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})""" +) + +# get op output info +OP_OUTPUT_INFO_TEMPLATE = """ +std::vector {op_name}::outputs_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_OUTPUT_INFO_TEMPLATE = ( + """OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})""" +) + +# get op attribute info +OP_ATTRIBUTE_INFO_TEMPLATE = """ +std::vector {op_name}::attributes_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = ( + """OpAttributeInfo("{name}", "{typename}", "{data_type}")""" +) + +# verify OP_VERIFY_TEMPLATE = """ void {op_name}::verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}."; @@ -170,32 +207,65 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ """ +def to_phi_and_fluid_op_name(op_item): + # Templat: - op : phi_name (fluid_name) + names = op_item.split('(') + if len(names) == 1: + phi_fluid_name = names[0].strip() + return phi_fluid_name, phi_fluid_name + else: + phi_name = names[0].strip() + fluid_name = names[1].split(')')[0].strip() + return phi_name, fluid_name + + +# ===================================== +# Parse Op Compat From Yaml +# ===================================== +class OpCompatParser: + def __init__(self, ops_compat_yaml_file): + self.ops_compat_yaml_file = ops_compat_yaml_file + with open(self.ops_compat_yaml_file, "r") as f: + self.ops_compat = yaml.safe_load(f) + + def get_compat(self, op_name): + for compat in self.ops_compat: + phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op']) + if op_name == phi_name: + return compat + return None + + # ===================================== -# Parse Op information from Yaml item +# Parse Op Information From Yaml # ===================================== class OpInfoParser: - def __init__(self, op_yaml_item): + def __init__(self, op_yaml_item, op_compat_item): self.op_yaml_item = op_yaml_item + self.op_compat_item = op_compat_item self.op_phi_name = self.parse_op_phi_name() - + # parse inputs self.input_name_list = self.parse_input_name_list() self.input_type_list = self.parse_input_type_list() self.input_optional_list = self.parse_input_optional_list() + self.input_no_need_buffer_list = self.parse_input_no_need_buffer_list() self.cross_check( self.input_name_list, self.input_type_list, self.input_optional_list ) - + # parse outputs self.output_name_list = self.parse_output_name_list() self.output_type_list = self.parse_output_type_list() self.output_optional_list = self.parse_output_optional_list() + self.output_intermediate_list = self.parse_output_intermediate_list() self.cross_check( self.output_name_list, self.output_type_list, self.output_optional_list, ) - + # parse attributes self.attribute_name_list = self.parse_attribute_name_list() self.attribute_type_list = self.parse_attribute_type_list() + self.attribute_data_type_list = self.parse_attribute_data_type_list() self.cross_check(self.attribute_name_list, self.attribute_type_list) def cross_check(self, name_list, type_list, optional_list=None): @@ -229,9 +299,21 @@ def parse_input_type_list(self): def parse_input_optional_list(self): optional_list = [] for input_info in self.op_yaml_item['inputs']: - optional_list.append(input_info['optional']) + if input_info['optional']: + optional_list.append("true") + else: + optional_list.append("false") return optional_list + def parse_input_no_need_buffer_list(self): + no_need_buffer_list = [] + for input_info in self.op_yaml_item['inputs']: + if input_info['no_need_buffer']: + no_need_buffer_list.append("true") + else: + no_need_buffer_list.append("false") + return no_need_buffer_list + def parse_output_name_list(self): name_list = [] for output_info in self.op_yaml_item['outputs']: @@ -255,11 +337,26 @@ def parse_output_optional_list(self): optional_list = [] for output_info in self.op_yaml_item['outputs']: if 'optional' in output_info: - optional_list.append(output_info['optional']) + if output_info['optional']: + optional_list.append("true") + else: + optional_list.append("false") else: - optional_list.append(False) + optional_list.append("false") return optional_list + def parse_output_intermediate_list(self): + intermediate_list = [] + for output_info in self.op_yaml_item['outputs']: + if 'intermediate' in output_info: + if output_info['intermediate']: + intermediate_list.append("true") + else: + intermediate_list.append("false") + else: + intermediate_list.append("false") + return intermediate_list + def parse_attribute_name_list(self): name_list = [] for attribute_info in self.op_yaml_item['attrs']: @@ -301,6 +398,15 @@ def parse_attribute_type_list(self): type_list.append(attr_types_map[attribute_info['typename']]) return type_list + def parse_attribute_data_type_list(self): + data_type_list = [] + for attribute_info in self.op_yaml_item['attrs']: + if 'data_type' in attribute_info: + data_type_list.append(attribute_info['data_type']) + else: + data_type_list.append("") + return data_type_list + def parse_op_phi_name(self): return self.op_yaml_item['name'] @@ -314,10 +420,11 @@ def to_pascal_case(s): # ===================================== -# Generate op definition files +# Generate Op Definition Files # ===================================== def OpGenerator( op_yaml_files, + op_compat_yaml_file, namespaces, dialect_name, op_def_h_file, @@ -330,6 +437,8 @@ def OpGenerator( os.remove(op_def_cc_file) # (2) Prepare: Get all op item in all op_yaml_files + op_compat_parser = OpCompatParser(op_compat_yaml_file) + op_yaml_items = [] for yaml_file in op_yaml_files: with open(yaml_file, "r") as f: @@ -337,7 +446,9 @@ def OpGenerator( op_yaml_items = op_yaml_items + ops op_info_items = [] for op in op_yaml_items: - op_info_items.append(OpInfoParser(op)) + op_info_items.append( + OpInfoParser(op, op_compat_parser.get_compat(op['name'])) + ) # (3) CodeGen: Traverse op_info_items and generate ops_name_list = [] # all op class name store in this list @@ -351,11 +462,14 @@ def OpGenerator( op_input_name_list = op_info.input_name_list op_input_type_list = op_info.input_type_list op_input_optional_list = op_info.input_optional_list + op_input_no_need_buffer_list = op_info.input_no_need_buffer_list op_output_name_list = op_info.output_name_list op_output_type_list = op_info.output_type_list op_output_optional_list = op_info.output_optional_list + op_output_intermediate_list = op_info.output_intermediate_list op_attribute_name_list = op_info.attribute_name_list op_attribute_type_list = op_info.attribute_type_list + op_attribute_data_type_list = op_info.attribute_data_type_list op_interfaces = [] op_traits = [] @@ -370,11 +484,16 @@ def OpGenerator( op_get_inputs_outputs_str = "" for idx in range(len(op_input_name_list)): op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( - input_name=op_input_name_list[idx], input_index=idx + input_name="input_" + str(idx) + "_" + op_input_name_list[idx], + input_index=idx, ) for idx in range(len(op_output_name_list)): op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( - output_name=op_output_name_list[idx], output_index=idx + output_name="output_" + + str(idx) + + "_" + + op_output_name_list[idx], + output_index=idx, ) # gen op_declare_str/op_defined_str @@ -410,6 +529,59 @@ def OpGenerator( attribute_names=attribute_names_str, ) + # generate get op info funciton: inputs + inputs_info_str = "" + if len(op_input_name_list) > 0: + input_info_list = [] + for idx in range(len(op_input_name_list)): + input_info_list.append( + CONSTRUCT_INPUT_INFO_TEMPLATE.format( + name=op_input_name_list[idx], + typename=op_input_type_list[idx], + optional=op_input_optional_list[idx], + no_need_buffer=op_input_no_need_buffer_list[idx], + ) + ) + inputs_info_str = ", ".join(input_info_list) + op_inputs_info_func_str = OP_INPUT_INFO_TEMPLATE.format( + op_name=op_class_name, impl=inputs_info_str + ) + + # generate get op info funciton: outputs + outputs_info_str = "" + if len(op_output_name_list) > 0: + output_info_list = [] + for idx in range(len(op_output_name_list)): + output_info_list.append( + CONSTRUCT_OUTPUT_INFO_TEMPLATE.format( + name=op_output_name_list[idx], + typename=op_output_type_list[idx], + optional=op_output_optional_list[idx], + intermediate=op_output_intermediate_list[idx], + ) + ) + outputs_info_str = ", ".join(output_info_list) + op_outputs_info_func_str = OP_OUTPUT_INFO_TEMPLATE.format( + op_name=op_class_name, impl=outputs_info_str + ) + + # generate get op info funciton: attributes + attribute_info_str = "" + if len(op_attribute_name_list) > 0: + attribute_info_list = [] + for idx in range(len(op_attribute_name_list)): + attribute_info_list.append( + CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( + name=op_attribute_name_list[idx], + typename=op_attribute_type_list[idx], + data_type=op_attribute_data_type_list[idx], + ) + ) + attribute_info_str = ", ".join(attribute_info_list) + op_attributes_info_func_str = OP_ATTRIBUTE_INFO_TEMPLATE.format( + op_name=op_class_name, impl=attribute_info_str + ) + # generate op verify function: inputs_type_check_str if len(op_input_type_list) == 0: inputs_type_check_str = ( @@ -425,7 +597,7 @@ def OpGenerator( is_vector = True input_type = input_type[15:-1] check_str = "" - if is_optional: + if is_optional == "true": if is_vector: check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( index=idx, standard=input_type @@ -460,7 +632,7 @@ def OpGenerator( is_vector = True output_type = output_type[15:-1] check_str = "" - if is_optional: + if is_optional == "true": if is_vector: check_str = ( OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( @@ -515,6 +687,9 @@ def OpGenerator( ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) ops_defined_list.append(op_defined_str) + ops_defined_list.append(op_inputs_info_func_str) + ops_defined_list.append(op_outputs_info_func_str) + ops_defined_list.append(op_attributes_info_func_str) ops_defined_list.append(op_verify_str) # (4) Generate head file str @@ -588,6 +763,7 @@ def ParseArguments(): # auto code generate OpGenerator( op_yaml_files, + op_compat_yaml_file, namespaces, dialect_name, op_def_h_file, diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index 9baeb4c1f9b1d0..bbb449464acda7 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -113,7 +113,7 @@ void PaddleDialect::initialize() { RegisterInterfaces(); RegisterOps(); } diff --git a/paddle/fluid/dialect/utils.h b/paddle/fluid/dialect/utils.h index 69af6fd1ddce80..85df76ed4e8f48 100644 --- a/paddle/fluid/dialect/utils.h +++ b/paddle/fluid/dialect/utils.h @@ -132,5 +132,45 @@ inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout( } } +struct OpInputInfo { + std::string name_ = ""; + std::string typename_ = ""; + bool optional_ = false; + bool no_need_buffer_ = false; + OpInputInfo(std::string name, + std::string type_name, + bool optional, + bool no_need_buffer) + : name_(name), + typename_(type_name), + optional_(optional), + no_need_buffer_(no_need_buffer) {} +}; + +struct OpOutputInfo { + std::string name_ = ""; + std::string typename_ = ""; + bool optional_ = false; + bool intermediate_ = false; + OpOutputInfo(std::string name, + std::string type_name, + bool optional, + bool intermediate) + : name_(name), + typename_(type_name), + optional_(optional), + intermediate_(intermediate) {} +}; + +struct OpAttributeInfo { + std::string name_ = ""; + std::string typename_ = ""; + std::string data_type_ = ""; + OpAttributeInfo(std::string name, + std::string type_name, + std::string data_type) + : name_(name), typename_(type_name), data_type_(data_type) {} +}; + } // namespace dialect } // namespace paddle From 1ba19dcdbff0cb460b4262541a771854b5807a32 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 26 May 2023 07:49:53 +0000 Subject: [PATCH 11/30] refine code --- paddle/fluid/dialect/legacy_pd_op.h | 45 +-- paddle/fluid/dialect/op_gen.py | 417 +++++++++++++++------------- paddle/fluid/dialect/pd_dialect.cc | 24 +- paddle/phi/api/yaml/legacy_ops.yaml | 2 +- 4 files changed, 222 insertions(+), 266 deletions(-) diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h index 60a0a4b0871826..d835bf929e8d1a 100644 --- a/paddle/fluid/dialect/legacy_pd_op.h +++ b/paddle/fluid/dialect/legacy_pd_op.h @@ -36,50 +36,9 @@ namespace dialect { // TODO(zhangbo): As operators are supplemented and defined, they are gradually // removed. -REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d -REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed -// REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: -// batch_norm +REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_ -REIGSTER_EMPTY_OP(elementwise_add, - ElementwiseAddOp); // To be customized: add (elementwise_add) -REIGSTER_EMPTY_OP(pool2d, Pool2DOp); // To be customized: pool2d -REIGSTER_EMPTY_OP( - flatten_contiguous_range, - FlattenContiguousRangeOp); // flatten (flatten_contiguous_range) -REIGSTER_EMPTY_OP(matmul_v2, - MatmulV2Op); // To be customized: matmul (matmul_v2) -REIGSTER_EMPTY_OP(reshape2, Reshape2Op); // To be customized: reshape -REIGSTER_EMPTY_OP(softmax_with_cross_entropy, - SoftmaxWithCrossEntropyOp); // cross_entropy_with_softmax - // (softmax_with_cross_entropy) -REIGSTER_EMPTY_OP(reduce_mean, - ReduceMeanOp); // To be customized: mean (reduce_mean) -REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op); // topk (top_k_v2) -REIGSTER_EMPTY_OP(fill_constant, - FillConstantOp); // To be customized: full (fill_constant) -REIGSTER_EMPTY_OP(reduce_mean_grad, - ReduceMeanGradOp); // To be customized: reduce_mean_grad -REIGSTER_EMPTY_OP( - softmax_with_cross_entropy_grad, - SoftmaxWithCrossEntropyGradOp); // cross_entropy_with_softmax_grad - // (softmax_with_cross_entropy_grad) -REIGSTER_EMPTY_OP( - elementwise_add_grad, - ElementwiseAddGradOp); // To be customized: add_grad (elementwise_add_grad) -REIGSTER_EMPTY_OP( - matmul_v2_grad, - MatmulV2GradOp); // To be customized: matmul_grad (matmul_v2_grad) -REIGSTER_EMPTY_OP( - flatten_contiguous_range_grad, - FlattenContiguousRangeGradOp); // flatten_grad - // (flatten_contiguous_range_grad) -REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); // To be customized: pool2d_grad -// REIGSTER_EMPTY_OP(batch_norm_grad, -// BatchNormGradOp); // To be customized: batch_norm_grad -REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad -// REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum) -REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 +REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 6abef860838c32..e0aec60bf81540 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -408,7 +408,21 @@ def parse_attribute_data_type_list(self): return data_type_list def parse_op_phi_name(self): - return self.op_yaml_item['name'] + if self.parse_op_inplace_info() is None: + return [self.op_yaml_item['name']] + else: + if self.op_yaml_item['name'][-1] == "_": + return [self.op_yaml_item['name']] + else: + return [ + self.op_yaml_item['name'], + self.op_yaml_item['name'] + "_", + ] + + def parse_op_inplace_info(self): + if 'inplace' in self.op_yaml_item: + return self.op_yaml_item['inplace'] + return None def to_pascal_case(s): @@ -456,9 +470,6 @@ def OpGenerator( ops_defined_list = [] # all op class defined store in this list for op_info in op_info_items: # get op info - op_name = op_info.op_phi_name - op_class_name = to_pascal_case(op_name) + "Op" - op_dialect_name = dialect_name + "." + op_name op_input_name_list = op_info.input_name_list op_input_type_list = op_info.input_type_list op_input_optional_list = op_info.input_optional_list @@ -473,224 +484,232 @@ def OpGenerator( op_interfaces = [] op_traits = [] - # gen interface/trait str - op_interfaces_str = "" - if len(op_interfaces) > 0: - op_interfaces_str = "," + ",".join(op_interfaces) - op_traits_str = "" - if len(op_interfaces) > 0: - op_traits_str = "," + ",".join(op_traits) - - op_get_inputs_outputs_str = "" - for idx in range(len(op_input_name_list)): - op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( - input_name="input_" + str(idx) + "_" + op_input_name_list[idx], - input_index=idx, - ) - for idx in range(len(op_output_name_list)): - op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( - output_name="output_" - + str(idx) - + "_" - + op_output_name_list[idx], - output_index=idx, - ) - - # gen op_declare_str/op_defined_str - if len(op_attribute_name_list) == 0: - op_declare_str = OP_DECLARE_TEMPLATE.format( - op_name=op_class_name, - dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, - traits=op_traits_str, - attribute_declare=op_0_attribute_declare_str, - attribute_num=0, - get_inputs_and_outputs=op_get_inputs_outputs_str, - ) - op_defined_str = "" - else: - op_declare_str = OP_DECLARE_TEMPLATE.format( - op_name=op_class_name, - dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, - traits=op_traits_str, - attribute_declare=op_n_attribute_declare_str.format( - attribute_num=len(op_attribute_name_list) - ), - attribute_num=len(op_attribute_name_list), - get_inputs_and_outputs=op_get_inputs_outputs_str, - ) - attribute_names_str = ( - '"' + '", "'.join(op_attribute_name_list) + '"' - ) - op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format( - op_name=op_class_name, - attribute_num=len(op_attribute_name_list), - attribute_names=attribute_names_str, - ) - - # generate get op info funciton: inputs - inputs_info_str = "" - if len(op_input_name_list) > 0: - input_info_list = [] + # If op has inplace info, we will generate inplace op and non-inplace op. + print("op_info.op_phi_name: ", op_info.op_phi_name) + for op_name in op_info.op_phi_name: + op_class_name = to_pascal_case(op_name) + "Op" + op_dialect_name = dialect_name + "." + op_name + + # gen interface/trait str + op_interfaces_str = "" + if len(op_interfaces) > 0: + op_interfaces_str = "," + ",".join(op_interfaces) + op_traits_str = "" + if len(op_interfaces) > 0: + op_traits_str = "," + ",".join(op_traits) + + op_get_inputs_outputs_str = "" for idx in range(len(op_input_name_list)): - input_info_list.append( - CONSTRUCT_INPUT_INFO_TEMPLATE.format( - name=op_input_name_list[idx], - typename=op_input_type_list[idx], - optional=op_input_optional_list[idx], - no_need_buffer=op_input_no_need_buffer_list[idx], - ) + op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( + input_name=op_input_name_list[idx], + input_index=idx, ) - inputs_info_str = ", ".join(input_info_list) - op_inputs_info_func_str = OP_INPUT_INFO_TEMPLATE.format( - op_name=op_class_name, impl=inputs_info_str - ) - - # generate get op info funciton: outputs - outputs_info_str = "" - if len(op_output_name_list) > 0: - output_info_list = [] for idx in range(len(op_output_name_list)): - output_info_list.append( - CONSTRUCT_OUTPUT_INFO_TEMPLATE.format( - name=op_output_name_list[idx], - typename=op_output_type_list[idx], - optional=op_output_optional_list[idx], - intermediate=op_output_intermediate_list[idx], - ) + op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( + output_name=op_output_name_list[idx], + output_index=idx, ) - outputs_info_str = ", ".join(output_info_list) - op_outputs_info_func_str = OP_OUTPUT_INFO_TEMPLATE.format( - op_name=op_class_name, impl=outputs_info_str - ) - # generate get op info funciton: attributes - attribute_info_str = "" - if len(op_attribute_name_list) > 0: - attribute_info_list = [] - for idx in range(len(op_attribute_name_list)): - attribute_info_list.append( - CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( - name=op_attribute_name_list[idx], - typename=op_attribute_type_list[idx], - data_type=op_attribute_data_type_list[idx], - ) + # gen op_declare_str/op_defined_str + if len(op_attribute_name_list) == 0: + op_declare_str = OP_DECLARE_TEMPLATE.format( + op_name=op_class_name, + dialect_op_name=op_dialect_name, + interfaces=op_interfaces_str, + traits=op_traits_str, + attribute_declare=op_0_attribute_declare_str, + attribute_num=0, + get_inputs_and_outputs=op_get_inputs_outputs_str, + ) + op_defined_str = "" + else: + op_declare_str = OP_DECLARE_TEMPLATE.format( + op_name=op_class_name, + dialect_op_name=op_dialect_name, + interfaces=op_interfaces_str, + traits=op_traits_str, + attribute_declare=op_n_attribute_declare_str.format( + attribute_num=len(op_attribute_name_list) + ), + attribute_num=len(op_attribute_name_list), + get_inputs_and_outputs=op_get_inputs_outputs_str, + ) + attribute_names_str = ( + '"' + '", "'.join(op_attribute_name_list) + '"' + ) + op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format( + op_name=op_class_name, + attribute_num=len(op_attribute_name_list), + attribute_names=attribute_names_str, ) - attribute_info_str = ", ".join(attribute_info_list) - op_attributes_info_func_str = OP_ATTRIBUTE_INFO_TEMPLATE.format( - op_name=op_class_name, impl=attribute_info_str - ) - # generate op verify function: inputs_type_check_str - if len(op_input_type_list) == 0: - inputs_type_check_str = ( - "// Inputs num is 0, not need to check inputs type." + # generate get op info funciton: inputs + inputs_info_str = "" + if len(op_input_name_list) > 0: + input_info_list = [] + for idx in range(len(op_input_name_list)): + input_info_list.append( + CONSTRUCT_INPUT_INFO_TEMPLATE.format( + name=op_input_name_list[idx], + typename=op_input_type_list[idx], + optional=op_input_optional_list[idx], + no_need_buffer=op_input_no_need_buffer_list[idx], + ) + ) + inputs_info_str = ", ".join(input_info_list) + op_inputs_info_func_str = OP_INPUT_INFO_TEMPLATE.format( + op_name=op_class_name, impl=inputs_info_str ) - else: - inputs_type_check_str = "" - for idx in range(len(op_input_type_list)): - input_type = op_input_type_list[idx] - is_optional = op_input_optional_list[idx] - is_vector = False - if input_type.startswith("ir::VectorType<"): - is_vector = True - input_type = input_type[15:-1] - check_str = "" - if is_optional == "true": - if is_vector: - check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type + + # generate get op info funciton: outputs + outputs_info_str = "" + if len(op_output_name_list) > 0: + output_info_list = [] + for idx in range(len(op_output_name_list)): + output_info_list.append( + CONSTRUCT_OUTPUT_INFO_TEMPLATE.format( + name=op_output_name_list[idx], + typename=op_output_type_list[idx], + optional=op_output_optional_list[idx], + intermediate=op_output_intermediate_list[idx], + ) ) - else: - check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type + outputs_info_str = ", ".join(output_info_list) + op_outputs_info_func_str = OP_OUTPUT_INFO_TEMPLATE.format( + op_name=op_class_name, impl=outputs_info_str + ) + + # generate get op info funciton: attributes + attribute_info_str = "" + if len(op_attribute_name_list) > 0: + attribute_info_list = [] + for idx in range(len(op_attribute_name_list)): + attribute_info_list.append( + CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( + name=op_attribute_name_list[idx], + typename=op_attribute_type_list[idx], + data_type=op_attribute_data_type_list[idx], + ) ) + attribute_info_str = ", ".join(attribute_info_list) + op_attributes_info_func_str = OP_ATTRIBUTE_INFO_TEMPLATE.format( + op_name=op_class_name, impl=attribute_info_str + ) + + # generate op verify function: inputs_type_check_str + if len(op_input_type_list) == 0: + inputs_type_check_str = ( + "// Inputs num is 0, not need to check inputs type." + ) else: - if is_vector: - check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) + inputs_type_check_str = "" + for idx in range(len(op_input_type_list)): + input_type = op_input_type_list[idx] + is_optional = op_input_optional_list[idx] + is_vector = False + if input_type.startswith("ir::VectorType<"): + is_vector = True + input_type = input_type[15:-1] + check_str = "" + if is_optional == "true": + if is_vector: + check_str = ( + INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + ) + else: + check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) else: - check_str = INPUT_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) - inputs_type_check_str += check_str + if is_vector: + check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + else: + check_str = INPUT_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + inputs_type_check_str += check_str - # generate op verify function: outputs_type_check_str - if len(op_output_type_list) == 0: - outputs_type_check_str = ( - "// Outputs num is 0, not need to check outputs type." - ) - else: - outputs_type_check_str = "" - for idx in range(len(op_output_type_list)): - output_type = op_output_type_list[idx] - is_optional = op_output_optional_list[idx] - is_vector = False - if output_type.startswith("ir::VectorType<"): - is_vector = True - output_type = output_type[15:-1] - check_str = "" - if is_optional == "true": - if is_vector: - check_str = ( - OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + # generate op verify function: outputs_type_check_str + if len(op_output_type_list) == 0: + outputs_type_check_str = ( + "// Outputs num is 0, not need to check outputs type." + ) + else: + outputs_type_check_str = "" + for idx in range(len(op_output_type_list)): + output_type = op_output_type_list[idx] + is_optional = op_output_optional_list[idx] + is_vector = False + if output_type.startswith("ir::VectorType<"): + is_vector = True + output_type = output_type[15:-1] + check_str = "" + if is_optional == "true": + if is_vector: + check_str = ( + OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + ) + else: + check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( index=idx, standard=output_type ) - ) else: - check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type - ) + if is_vector: + check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + else: + check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + outputs_type_check_str += check_str + + # generate op verify function: attributes_check_str + if len(op_attribute_name_list) == 0: + attributes_check_str = ( + "// Attributes num is 0, not need to check attributes type." + ) else: - if is_vector: - check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type + attributes_check_str = "" + for idx in range(len(op_attribute_name_list)): + attribute_name = op_attribute_name_list[idx] + attribute_type = op_attribute_type_list[idx] + if attribute_type.startswith("ir::ArrayAttribute<"): + attribute_type = attribute_type[19:-1] + attributes_check_str += ( + ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( + attribute_name=attribute_name, + standard=attribute_type, + ) ) else: - check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type + attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format( + attribute_name=attribute_name, standard=attribute_type ) - outputs_type_check_str += check_str - # generate op verify function: attributes_check_str - if len(op_attribute_name_list) == 0: - attributes_check_str = ( - "// Attributes num is 0, not need to check attributes type." + # generate op verify function + op_verify_str = OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + inputs_size=len(op_input_type_list), + outputs_size=len(op_output_type_list), + inputs_type_check=inputs_type_check_str, + outputs_type_check=outputs_type_check_str, + attributes_check=attributes_check_str, ) - else: - attributes_check_str = "" - for idx in range(len(op_attribute_name_list)): - attribute_name = op_attribute_name_list[idx] - attribute_type = op_attribute_type_list[idx] - if attribute_type.startswith("ir::ArrayAttribute<"): - attribute_type = attribute_type[19:-1] - attributes_check_str += ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( - attribute_name=attribute_name, standard=attribute_type - ) - else: - attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format( - attribute_name=attribute_name, standard=attribute_type - ) - - # generate op verify function - op_verify_str = OP_VERIFY_TEMPLATE.format( - op_name=op_class_name, - inputs_size=len(op_input_type_list), - outputs_size=len(op_output_type_list), - inputs_type_check=inputs_type_check_str, - outputs_type_check=outputs_type_check_str, - attributes_check=attributes_check_str, - ) - ops_name_list.append(op_class_name) - ops_declare_list.append(op_declare_str) - ops_defined_list.append(op_defined_str) - ops_defined_list.append(op_inputs_info_func_str) - ops_defined_list.append(op_outputs_info_func_str) - ops_defined_list.append(op_attributes_info_func_str) - ops_defined_list.append(op_verify_str) + ops_name_list.append(op_class_name) + ops_declare_list.append(op_declare_str) + ops_defined_list.append(op_defined_str) + ops_defined_list.append(op_inputs_info_func_str) + ops_defined_list.append(op_outputs_info_func_str) + ops_defined_list.append(op_attributes_info_func_str) + ops_defined_list.append(op_verify_str) # (4) Generate head file str op_namespaces_prev = "" diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index bbb449464acda7..05ec2001eaa88e 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -111,29 +111,7 @@ void PaddleDialect::initialize() { >(); RegisterInterfaces(); - RegisterOps(); + RegisterOps(); } void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 4dba3fdd74ec80..e4ab20dc5ee8c3 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -296,7 +296,7 @@ - op : einsum args : (Tensor[] x, str equation) - output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()} + output : Tensor(out), Tensor[](inner_cache){x.size()}, Tensor[](xshape){x.size()} infer_meta : func : EinsumRawInferMeta param : [x, equation] From 6c2ac5da6cc40702d9ab6ab16a8c6851f809a02a Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 26 May 2023 12:27:37 +0000 Subject: [PATCH 12/30] add interface --- paddle/fluid/dialect/CMakeLists.txt | 3 +- paddle/fluid/dialect/legacy_pd_op.h | 44 ---------------------- paddle/fluid/dialect/op_gen.py | 37 +++++++++--------- paddle/fluid/dialect/pd_dialect.cc | 2 - paddle/fluid/dialect/pd_interface.h | 58 +++++++++++++++++++++++++++++ paddle/fluid/dialect/pd_op.yaml | 52 ++++++++++++++++++++++++++ test/cpp/ir/ir_program_test.cc | 14 ++++++- 7 files changed, 145 insertions(+), 65 deletions(-) delete mode 100644 paddle/fluid/dialect/legacy_pd_op.h create mode 100644 paddle/fluid/dialect/pd_interface.h create mode 100644 paddle/fluid/dialect/pd_op.yaml diff --git a/paddle/fluid/dialect/CMakeLists.txt b/paddle/fluid/dialect/CMakeLists.txt index 3f17ac21724c66..d041f2554887df 100644 --- a/paddle/fluid/dialect/CMakeLists.txt +++ b/paddle/fluid/dialect/CMakeLists.txt @@ -16,8 +16,9 @@ set(op_backward_yaml_file1 set(op_backward_yaml_file2 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ) +set(op_yaml_file3 ${PADDLE_SOURCE_DIR}/paddle/fluid/dialect/pd_op.yaml) set(op_yaml_files - ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2} + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3} ) set(op_namespace paddle,dialect) set(dialect_name pd) diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h deleted file mode 100644 index d835bf929e8d1a..00000000000000 --- a/paddle/fluid/dialect/legacy_pd_op.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ir/op_base.h" - -namespace paddle { -namespace dialect { - -#define OPNAME(op_name) "pd." #op_name - -#define REIGSTER_EMPTY_OP(op_name, className) \ - class className : public ir::Op { \ - public: \ - static const char *name() { return OPNAME(op_name); } \ - static constexpr const char **attributes_name = nullptr; \ - static constexpr uint32_t attributes_num = 0; \ - static void verify(const std::vector &inputs, \ - const std::vector &outputs, \ - const ir::AttributeMap &attributes) { \ - LOG(WARNING) << "This is a fake verify"; \ - } \ - }; - -// TODO(zhangbo): As operators are supplemented and defined, they are gradually -// removed. -REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed -REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_ -REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 - -} // namespace dialect -} // namespace paddle diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index e0aec60bf81540..a9ae9739376a73 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -33,6 +33,7 @@ #include "paddle/ir/op_base.h" #include "paddle/fluid/dialect/utils.h" +#include "paddle/fluid/dialect/pd_interface.h" {input} #endif @@ -48,9 +49,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ static const char *name() {{ return "{dialect_op_name}"; }} {attribute_declare} static constexpr uint32_t attributes_num = {attribute_num}; - static std::vector inputs_info(); - static std::vector outputs_info(); - static std::vector attributes_info(); + static OpInfoTuple GetOpInfo(); static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); {get_inputs_and_outputs} }}; @@ -86,6 +85,15 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ """ # get op input info +OP_INFO_TEMPLATE = """ +OpInfoTuple {op_name}::GetOpInfo() {{ + std::vector inputs = {{ {inputs} }}; + std::vector attributes = {{ {attributes} }}; + std::vector outputs = {{ {outputs} }}; + return std::make_tuple(inputs, attributes, outputs); +}} +""" + OP_INPUT_INFO_TEMPLATE = """ std::vector {op_name}::inputs_info() {{ return {{ {impl} }}; @@ -481,11 +489,10 @@ def OpGenerator( op_attribute_name_list = op_info.attribute_name_list op_attribute_type_list = op_info.attribute_type_list op_attribute_data_type_list = op_info.attribute_data_type_list - op_interfaces = [] + op_interfaces = ["GetOpInfoInterface"] op_traits = [] # If op has inplace info, we will generate inplace op and non-inplace op. - print("op_info.op_phi_name: ", op_info.op_phi_name) for op_name in op_info.op_phi_name: op_class_name = to_pascal_case(op_name) + "Op" op_dialect_name = dialect_name + "." + op_name @@ -495,7 +502,7 @@ def OpGenerator( if len(op_interfaces) > 0: op_interfaces_str = "," + ",".join(op_interfaces) op_traits_str = "" - if len(op_interfaces) > 0: + if len(op_traits) > 0: op_traits_str = "," + ",".join(op_traits) op_get_inputs_outputs_str = "" @@ -557,9 +564,6 @@ def OpGenerator( ) ) inputs_info_str = ", ".join(input_info_list) - op_inputs_info_func_str = OP_INPUT_INFO_TEMPLATE.format( - op_name=op_class_name, impl=inputs_info_str - ) # generate get op info funciton: outputs outputs_info_str = "" @@ -575,9 +579,6 @@ def OpGenerator( ) ) outputs_info_str = ", ".join(output_info_list) - op_outputs_info_func_str = OP_OUTPUT_INFO_TEMPLATE.format( - op_name=op_class_name, impl=outputs_info_str - ) # generate get op info funciton: attributes attribute_info_str = "" @@ -592,8 +593,12 @@ def OpGenerator( ) ) attribute_info_str = ", ".join(attribute_info_list) - op_attributes_info_func_str = OP_ATTRIBUTE_INFO_TEMPLATE.format( - op_name=op_class_name, impl=attribute_info_str + + op_info_func_str = OP_INFO_TEMPLATE.format( + op_name=op_class_name, + inputs=inputs_info_str, + attributes=attribute_info_str, + outputs=outputs_info_str, ) # generate op verify function: inputs_type_check_str @@ -706,9 +711,7 @@ def OpGenerator( ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) ops_defined_list.append(op_defined_str) - ops_defined_list.append(op_inputs_info_func_str) - ops_defined_list.append(op_outputs_info_func_str) - ops_defined_list.append(op_attributes_info_func_str) + ops_defined_list.append(op_info_func_str) ops_defined_list.append(op_verify_str) # (4) Generate head file str diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index 05ec2001eaa88e..664d3873828a50 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -16,7 +16,6 @@ #include "paddle/fluid/dialect/pd_attribute.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/dialect/CMakeLists.txt. -#include "paddle/fluid/dialect/legacy_pd_op.h" #include "paddle/fluid/dialect/pd_op.h" #include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/pd_type_storage.h" @@ -111,7 +110,6 @@ void PaddleDialect::initialize() { >(); RegisterInterfaces(); - RegisterOps(); } void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { diff --git a/paddle/fluid/dialect/pd_interface.h b/paddle/fluid/dialect/pd_interface.h new file mode 100644 index 00000000000000..25c2397c00d1c1 --- /dev/null +++ b/paddle/fluid/dialect/pd_interface.h @@ -0,0 +1,58 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/dialect/utils.h" +#include "paddle/ir/op_base.h" + +using OpInfoTuple = std::tuple, + std::vector, + std::vector>; + +namespace paddle { +namespace dialect { +class GetOpInfoInterface : public ir::OpInterfaceBase { + public: + struct Concept { + explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *)) + : get_op_info_(get_op_info) {} + OpInfoTuple (*get_op_info_)(ir::Operation *); + }; + + template + struct Model : public Concept { + static OpInfoTuple GetOpInfo(ir::Operation *op) { + ConcreteOp concret_op = ConcreteOp(op); + if (concret_op == nullptr) throw("concret_op is nullptr"); + return concret_op.GetOpInfo(); + } + + Model() : Concept(GetOpInfo) { + static_assert(sizeof(Model) == sizeof(Concept), + "sizeof(Model) != sizeof(Concept)"); + } + }; + + GetOpInfoInterface(ir::Operation *op, Concept *impl) + : ir::OpInterfaceBase(op), impl_(impl) {} + + OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); } + + private: + Concept *impl_; +}; + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/dialect/pd_op.yaml b/paddle/fluid/dialect/pd_op.yaml new file mode 100644 index 00000000000000..7ca8646a038c6c --- /dev/null +++ b/paddle/fluid/dialect/pd_op.yaml @@ -0,0 +1,52 @@ +- name: feed + inputs: + - typename: Tensor[] + name: x + optional: false + no_need_buffer: false + data_transform: {} + attrs: + - {typename: int, name: col} + outputs: + - {typename: Tensor, name: out, optional: false, intermediate: false} + no_need_buffer: null + data_transform: null + infer_meta: + func: null + param: null + kernel: + func: null + param: null + backend: null + layout: null + data_type: null + dispatch: null + force_backend: null + inplace: null + backward: null +- name: fetch + inputs: + - typename: Tensor + name: x + optional: false + no_need_buffer: false + data_transform: {} + attrs: + - {typename: int, name: col} + outputs: + - {typename: 'Tensor[]', name: out, optional: false, intermediate: false} + no_need_buffer: null + data_transform: null + infer_meta: + func: null + param: null + kernel: + func: null + param: null + backend: null + layout: null + data_type: null + dispatch: null + force_backend: null + inplace: null + backward: null diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index 9fb72fec13c23b..0c443f122738a5 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -15,6 +15,7 @@ #include #include "paddle/fluid/dialect/pd_dialect.h" +#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/utils.h" #include "paddle/ir/builtin_attribute.h" @@ -177,7 +178,18 @@ TEST(program_test, program) { EXPECT_EQ(*(dst_tensor->data() + i), data_a[i] + data_b[i]); } - // (7) Def SetParameterOp(c, "c") + // (7) Def AbsOp(b) + ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); + std::unordered_map abs_op_attribute; + ir::Operation *abs_op = ir::Operation::create({op1->GetResultByIndex(0)}, + {dense_tensor_dtype}, + abs_op_attribute, + abs_info); + paddle::dialect::GetOpInfoInterface interface = + abs_op->dyn_cast(); + EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name_ == "x", true); + + // (8) Def SetParameterOp(c, "c") std::string op4_name = builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name()); ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name); From 6967ebacace89cfeee9ccd96209be825bbe4b153 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 29 May 2023 11:09:23 +0000 Subject: [PATCH 13/30] fix: op name normalization --- paddle/fluid/dialect/pd_op.yaml | 14 +++------- .../fluid/translator/attribute_translator.cc | 4 +-- paddle/fluid/translator/op_compat_gen.py | 4 +++ paddle/fluid/translator/op_translator.cc | 19 +++++++++++-- paddle/phi/api/yaml/op_compat.yaml | 11 +++++--- test/cpp/ir/program_translator_test.cc | 28 +++++++++++-------- 6 files changed, 50 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/dialect/pd_op.yaml b/paddle/fluid/dialect/pd_op.yaml index 7ca8646a038c6c..a8053b0d1cb83f 100644 --- a/paddle/fluid/dialect/pd_op.yaml +++ b/paddle/fluid/dialect/pd_op.yaml @@ -1,12 +1,7 @@ - name: feed - inputs: - - typename: Tensor[] - name: x - optional: false - no_need_buffer: false - data_transform: {} + inputs: [] attrs: - - {typename: int, name: col} + - {typename: str, name: name} outputs: - {typename: Tensor, name: out, optional: false, intermediate: false} no_need_buffer: null @@ -32,9 +27,8 @@ no_need_buffer: false data_transform: {} attrs: - - {typename: int, name: col} - outputs: - - {typename: 'Tensor[]', name: out, optional: false, intermediate: false} + - {typename: str, name: name} + outputs: [] no_need_buffer: null data_transform: null infer_meta: diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc index ababb7e3b50609..554b1f3ac7f3b3 100644 --- a/paddle/fluid/translator/attribute_translator.cc +++ b/paddle/fluid/translator/attribute_translator.cc @@ -31,7 +31,7 @@ class AttributeVisitor { ~AttributeVisitor() {} public: - ir::Attribute operator()(int i) { return ir::Int64_tAttribute::get(ctx, i); } + ir::Attribute operator()(int i) { return ir::Int32_tAttribute::get(ctx, i); } ir::Attribute operator()(float f) { return ir::FloatAttribute::get(ctx, f); } @@ -71,7 +71,7 @@ class AttributeVisitor { std::vector attrs; attrs.reserve(is.size()); for (const auto& v : is) { - attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); + attrs.push_back(ir::Int32_tAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index d6aeeeaf8ee4bf..85b50d1d8de5b7 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -57,6 +57,10 @@ def insert_new_mappings(op_name_str): insert_new_mappings(op_compat_item["op"]) if "backward" in op_compat_item: insert_new_mappings(op_compat_item["backward"]) + + # special op mappings + op_name_mappings["fetch_v2"] = "fetch" + op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 1451115a71d423..cdb3cd19ca7b37 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -47,8 +47,15 @@ using OpOutputMapping = std::unordered_map; static const char kTargetDialectPrefix[] = "pd."; +static const std::unordered_set special_inplace_ops = { + "batch_norm", +}; + inline bool IsInplace(const OpDesc& op_desc) { bool inplace = false; + if (special_inplace_ops.count(op_desc.Type())) { + return inplace; + } auto input_names = op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); @@ -319,10 +326,14 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); auto attribute_map = TranslateOpAttribute(op_desc); + VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; ir::Operation* operation = ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); + VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end."; program->InsertOp(operation); + + VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end."; RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); return operation; @@ -338,7 +349,9 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); - auto attribute_map = TranslateOpAttribute(op_desc); + ir::AttributeMap attribute_map = { + {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, + }; ir::Operation* operation = ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); @@ -356,7 +369,9 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types; auto op_info = LoopkUpOpInfo(ctx, op_desc); - auto attribute_map = TranslateOpAttribute(op_desc); + ir::AttributeMap attribute_map = { + {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])}, + }; ir::Operation* operation = ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index bbe3017e27eba7..f6dfe19ec0c5f3 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -14,6 +14,10 @@ # extra : # attrs : [bool is_test = false] + extra : + attrs : [bool use_mkldnn = false, bool use_quantizer = false, + str mkldnn_data_type = "float32", bool is_test = false] + - op : abs backward : abs_grad inputs : @@ -1532,7 +1536,7 @@ out : Out - op : mean (reduce_mean) - backward : reduce_mean_grad + backward : mean_grad (reduce_mean_grad) inputs : x : X outputs : @@ -1775,9 +1779,8 @@ - op : pool2d backward : pool2d_grad - extra : - attrs : [bool use_mkldnn = false, bool use_quantizer = false, - str mkldnn_data_type = "float32", bool is_test = false] + attrs: + kernel_size: ksize - op : pool3d backward : pool3d_grad diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc index 0035f860c5861b..cf2fe272e9c7e8 100644 --- a/test/cpp/ir/program_translator_test.cc +++ b/test/cpp/ir/program_translator_test.cc @@ -48,16 +48,20 @@ ProgramDesc load_from_file(const std::string &file_name) { TEST(PaddleDialectTest, Translator) { LOG(WARNING) << "TODO"; - // auto p = load_from_file("restnet50_main.prog"); - // EXPECT_EQ(p.Size(), 1u); - - // ir::IrContext *ctx = ir::IrContext::Instance(); - // ctx->GetOrRegisterDialect(); - // ctx->GetOrRegisterDialect(); - // auto program = paddle::TranslateLegacyProgramToProgram(p); - - // size_t op_size = program->block()->size(); - // // ops.size() = op size in BlockDesc + get_parameter_op + combine op - // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20); - // VLOG(0) << *program; + auto p = load_from_file("restnet50_main.prog"); + EXPECT_EQ(p.Size(), 1u); + + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + size_t op_size = program->block()->size(); + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20); + + std::ofstream ostrm( + "/home/lvyongkang/Paddle/build/log_resnet_after_translated", + std::ios::out); + ostrm << *program; } From df0b5cd124b2ee9242b0e67a7f0a10d991f5202c Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 29 May 2023 11:25:42 +0000 Subject: [PATCH 14/30] fix typo --- paddle/phi/api/yaml/op_compat.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index f6dfe19ec0c5f3..8dad83915075cf 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -14,10 +14,6 @@ # extra : # attrs : [bool is_test = false] - extra : - attrs : [bool use_mkldnn = false, bool use_quantizer = false, - str mkldnn_data_type = "float32", bool is_test = false] - - op : abs backward : abs_grad inputs : From f1c02dd8b9687abd483e81c8b6a3d275048c3061 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 30 May 2023 07:23:18 +0000 Subject: [PATCH 15/30] refactor input translator --- paddle/fluid/dialect/pd_interface.h | 12 +-- paddle/fluid/translator/op_compat_gen.py | 43 ++++++++- paddle/fluid/translator/op_compat_info.cc.j2 | 14 ++- paddle/fluid/translator/op_compat_info.h | 16 ++++ paddle/fluid/translator/op_translator.cc | 99 +++++++++++++++----- paddle/fluid/translator/utils.h | 41 ++++++++ paddle/ir/builtin_op.cc | 17 ++++ paddle/ir/builtin_op.h | 14 +++ 8 files changed, 220 insertions(+), 36 deletions(-) create mode 100644 paddle/fluid/translator/utils.h diff --git a/paddle/fluid/dialect/pd_interface.h b/paddle/fluid/dialect/pd_interface.h index 25c2397c00d1c1..1dac2f86ec35ca 100644 --- a/paddle/fluid/dialect/pd_interface.h +++ b/paddle/fluid/dialect/pd_interface.h @@ -26,18 +26,14 @@ namespace dialect { class GetOpInfoInterface : public ir::OpInterfaceBase { public: struct Concept { - explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *)) + explicit Concept(OpInfoTuple (*get_op_info)()) : get_op_info_(get_op_info) {} - OpInfoTuple (*get_op_info_)(ir::Operation *); + OpInfoTuple (*get_op_info_)(); }; template struct Model : public Concept { - static OpInfoTuple GetOpInfo(ir::Operation *op) { - ConcreteOp concret_op = ConcreteOp(op); - if (concret_op == nullptr) throw("concret_op is nullptr"); - return concret_op.GetOpInfo(); - } + static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); } Model() : Concept(GetOpInfo) { static_assert(sizeof(Model) == sizeof(Concept), @@ -48,7 +44,7 @@ class GetOpInfoInterface : public ir::OpInterfaceBase { GetOpInfoInterface(ir::Operation *op, Concept *impl) : ir::OpInterfaceBase(op), impl_(impl) {} - OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); } + OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); } private: Concept *impl_; diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index 85b50d1d8de5b7..e8d32774cb4517 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -14,6 +14,7 @@ import argparse from pathlib import Path +from typing import Dict import yaml from jinja2 import Environment, FileSystemLoader, StrictUndefined @@ -46,17 +47,48 @@ def to_phi_and_fluid_op_name(op_item): with open(op_compat_yaml_file, "r") as f: op_compat_infos = yaml.safe_load(f) op_name_mappings = {} + op_arg_name_mappings = {} for op_compat_item in op_compat_infos: - def insert_new_mappings(op_name_str): + def insert_new_mappings(op_name_str: str) -> str: normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str) if normalized_name == legacy_name: - return + return normalized_name op_name_mappings[legacy_name] = normalized_name + return normalized_name + + def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): + if op_name is None: + return + if op_name not in op_arg_name_mappings: + op_arg_name_mappings[op_name] = {} + op_arg_name_mappings[op_name].update(arg_mapping) - insert_new_mappings(op_compat_item["op"]) + normalized_op_name = insert_new_mappings(op_compat_item["op"]) + normailized_backward_op_name = None if "backward" in op_compat_item: - insert_new_mappings(op_compat_item["backward"]) + normailized_backward_op_name = insert_new_mappings( + op_compat_item["backward"] + ) + if "inputs" in op_compat_item: + insert_new_arg_mappings( + normalized_op_name, op_compat_item["inputs"] + ) + insert_new_arg_mappings( + normailized_backward_op_name, op_compat_item["inputs"] + ) + if "attrs" in op_compat_item: + insert_new_arg_mappings(normalized_op_name, op_compat_item["attrs"]) + insert_new_arg_mappings( + normailized_backward_op_name, op_compat_item["attrs"] + ) + if "outputs" in op_compat_item: + insert_new_arg_mappings( + normalized_op_name, op_compat_item["outputs"] + ) + insert_new_arg_mappings( + normailized_backward_op_name, op_compat_item["outputs"] + ) # special op mappings op_name_mappings["fetch_v2"] = "fetch" @@ -64,7 +96,8 @@ def insert_new_mappings(op_name_str): op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( - op_name_paris=op_name_mappings + op_name_pairs=op_name_mappings, + op_arg_name_pairs=op_arg_name_mappings, ) f.write(op_compat_definition) diff --git a/paddle/fluid/translator/op_compat_info.cc.j2 b/paddle/fluid/translator/op_compat_info.cc.j2 index af42cf9b8abdc3..a44941595fbb8d 100644 --- a/paddle/fluid/translator/op_compat_info.cc.j2 +++ b/paddle/fluid/translator/op_compat_info.cc.j2 @@ -5,10 +5,22 @@ namespace translator { OpNameNormalizer::OpNameNormalizer() { op_name_mappings = { - {% for legacy_name, normalized_name in op_name_paris.items() %} + {% for legacy_name, normalized_name in op_name_pairs.items() %} { "{{legacy_name}}", "{{normalized_name}}" }, {% endfor %} }; + op_arg_name_mappings = { + {% for op_name, arg_name_mappings in op_arg_name_pairs.items() %} + { + "{{op_name}}", + { + {% for normalized_name, legacy_name in arg_name_mappings.items() %} + { "{{normalized_name}}", "{{legacy_name}}" }, + {% endfor %} + }, + }, + {% endfor %} + }; } } // namespace translator diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h index 86acafe7a0f1a5..52258a5b08b9ba 100644 --- a/paddle/fluid/translator/op_compat_info.h +++ b/paddle/fluid/translator/op_compat_info.h @@ -17,6 +17,8 @@ #include "glog/logging.h" +#include "paddle/fluid/translator/utils.h" + #pragma once namespace paddle { @@ -26,6 +28,8 @@ class OpNameNormalizer { private: OpNameNormalizer(); // Disallow instantiation outside of the class. std::unordered_map op_name_mappings; + std::unordered_map> + op_arg_name_mappings; public: OpNameNormalizer(const OpNameNormalizer&) = delete; @@ -44,6 +48,18 @@ class OpNameNormalizer { } return op_name_mappings.at(op_type); } + + std::string GetLegacyArgName(const std::string& op_type, + const std::string& arg_name) { + if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { + return UnderscoreToCamelCase(arg_name); + } + auto& arg_mappings = op_arg_name_mappings[op_type]; + if (arg_mappings.find(arg_name) == arg_mappings.end()) { + return UnderscoreToCamelCase(arg_name); + } + return arg_mappings.at(op_type); + } }; } // namespace translator diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index cdb3cd19ca7b37..d22b430b8569d6 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -15,12 +15,14 @@ #include "paddle/fluid/translator/op_translator.h" #include +#include #include #include #include #include #include +#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/translator/attribute_translator.h" #include "paddle/fluid/translator/op_compat_info.h" @@ -44,6 +46,12 @@ using BlockDesc = paddle::framework::BlockDesc; using VarDesc = paddle::framework::VarDesc; using OpOutputTypeList = std::vector; using OpOutputMapping = std::unordered_map; +using OpInputInfo = paddle::dialect::OpInputInfo; +using OpInputInfoList = std::vector; +using OpAttributeInfo = paddle::dialect::OpAttributeInfo; +using OpAttributeInfoList = std::vector; +using OpOutputInfo = paddle::dialect::OpOutputInfo; +using OpOutputInfoList = std::vector; static const char kTargetDialectPrefix[] = "pd."; @@ -150,13 +158,25 @@ inline ir::Operation* InsertCombineOperationForTarget( return operation; } +inline ir::Operation* InsertConstantOperationForOptionalArg( + ir::IrContext* ctx, ir::Program* program) { + std::string constant_op_name(ir::ConstantOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); + + ir::Type null_type = ir::Type(nullptr); + ir::Operation* operation = + ir::Operation::create({}, {null_type}, {}, op_info); + program->InsertOp(operation); + return operation; +} + inline std::vector GenerateOperationInput( ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, - const OpDesc& op_desc) { - std::vector op_inputs; - + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfoList& input_infos) { // scan all inputs to see if any of them is generated as a vector // so need an additional `SliceOp` to take it out. for (const auto& n : op_desc.Inputs()) { @@ -168,7 +188,7 @@ inline std::vector GenerateOperationInput( param_map->count(arg_name), 0, platform::errors::PreconditionNotMet( - "arg %s.%s as input should be exists before prasing %d", + "arg %s.%s as input should be exists before prasing %s", name, arg_name, op_desc.Type())); @@ -180,27 +200,41 @@ inline std::vector GenerateOperationInput( } } - for (const auto& n : op_desc.Inputs()) { - auto& name = n.first; - VLOG(10) << "[input retriving]" - << "[" << op_desc.Type() << "]" << name; - auto& args = n.second; + std::vector op_inputs; + auto& op_normalizer = OpNameNormalizer::instance(); - // if src type is Tensor or a Vector with size <= 1 - if (args.size() <= 1) { - for (const auto& arg_name : args) { - auto defining_info = (*param_map)[arg_name]; - op_inputs.push_back(defining_info.value); - } + for (const auto& info : input_infos) { + std::string legacy_input_name = + op_normalizer.GetLegacyArgName(normalized_op_name, info.name_); + + // return empty type if this arg is optional and not shown in OpDesc + // TODO(lyk): HasInput doesnot consider variadic attribute + if (!op_desc.HasInput(legacy_input_name)) { + PADDLE_ENFORCE(info.optional_, + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_input_name); + auto* constant_op = InsertConstantOperationForOptionalArg(ctx, program); + op_inputs.push_back(constant_op->GetResultByIndex(0)); + continue; + } + + const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true); + bool is_vector = (info.typename_.find("VectorType") != std::string::npos); + // if src type is Tensor + if (!is_vector) { + auto defining_info = (*param_map)[legacy_input_vars[0]]; + op_inputs.push_back(defining_info.value); // if src type is Vector , need an additional `CombineOp` to // assemble them. } else { - auto* combine_op = - InsertCombineOperationForTarget(ctx, param_map, program, args); + auto* combine_op = InsertCombineOperationForTarget( + ctx, param_map, program, legacy_input_vars); op_inputs.push_back(combine_op->GetResultByIndex(0)); } } + return op_inputs; } @@ -319,12 +353,23 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto* op_info_concept = + op_info.GetInterfaceImpl(); + + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos) = + op_info_concept->get_op_info_(); + + auto op_inputs = GenerateOperationInput( + ctx, param_map, program, op_desc, op_info.name(), input_infos); OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); - auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto attribute_map = TranslateOpAttribute(op_desc); VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; @@ -343,12 +388,12 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { + auto op_info = LoopkUpOpInfo(ctx, op_desc); std::vector op_inputs; OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); - auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::AttributeMap attribute_map = { {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, }; @@ -365,10 +410,20 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); + + auto* op_info_concept = + op_info.GetInterfaceImpl(); + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos) = + op_info_concept->get_op_info_(); + + auto op_inputs = GenerateOperationInput( + ctx, param_map, program, op_desc, op_info.name(), input_infos); OpOutputTypeList op_output_types; - auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::AttributeMap attribute_map = { {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])}, }; diff --git a/paddle/fluid/translator/utils.h b/paddle/fluid/translator/utils.h new file mode 100644 index 00000000000000..5711f1f550fb14 --- /dev/null +++ b/paddle/fluid/translator/utils.h @@ -0,0 +1,41 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace paddle { +namespace translator { + +static std::string UnderscoreToCamelCase(std::string str) { + std::string camel_case; + bool next_upper = true; + for (char c : str) { + if (c == '_') { + next_upper = true; + } else { + if (next_upper) { + camel_case += toupper(c); + next_upper = false; + } else { + camel_case += c; + } + } + } + return camel_case; +} + +} // namespace translator +} // namespace paddle diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index 63bfc2196dca35..2b4900d6a8ff93 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -160,4 +160,21 @@ void SliceOp::verify(const std::vector &inputs, outputs[0])); } +void ConstantOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) { + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + // inputs.size() == 0 + PADDLE_ENFORCE_EQ( + inputs.size(), + 0, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); +} + } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index d1d2c20b5725db..b4042e3abc0f53 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -81,4 +81,18 @@ class SliceOp : public ir::Op { const ir::AttributeMap &attributes); }; +class ConstantOp : public ir::Op { + public: + using Op::Op; + + static const char *name() { return "builtin.constant"; } + + static constexpr uint32_t attributes_num = 0; + + static constexpr const char **attributes_name = nullptr; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); +}; + } // namespace ir From 0854418515640d005bab7587316646c638132070 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 30 May 2023 08:01:47 +0000 Subject: [PATCH 16/30] fix merge conflicts --- paddle/fluid/translator/attribute_translator.h | 6 +++--- paddle/fluid/translator/op_translator.cc | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h index 9e7117d5db6412..ed64c27b120eb1 100644 --- a/paddle/fluid/translator/attribute_translator.h +++ b/paddle/fluid/translator/attribute_translator.h @@ -14,9 +14,9 @@ #include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/type_defs.h" -#include "paddle/ir/attribute.h" -#include "paddle/ir/builtin_attribute.h" -#include "paddle/ir/ir_context.h" +#include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/ir_context.h" #pragma once diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 9b3dd54d802dbe..42a9d9b05748cf 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -165,7 +165,7 @@ inline ir::Operation* InsertConstantOperationForOptionalArg( ir::Type null_type = ir::Type(nullptr); ir::Operation* operation = - ir::Operation::create({}, {null_type}, {}, op_info); + ir::Operation::create({}, {}, {null_type}, op_info); program->InsertOp(operation); return operation; } @@ -205,12 +205,12 @@ inline std::vector GenerateOperationInput( for (const auto& info : input_infos) { std::string legacy_input_name = - op_normalizer.GetLegacyArgName(normalized_op_name, info.name_); + op_normalizer.GetLegacyArgName(normalized_op_name, info.name); // return empty type if this arg is optional and not shown in OpDesc // TODO(lyk): HasInput doesnot consider variadic attribute if (!op_desc.HasInput(legacy_input_name)) { - PADDLE_ENFORCE(info.optional_, + PADDLE_ENFORCE(info.optional, "Op %s arg %s should be optional if it can be empty", op_desc.Type(), legacy_input_name); @@ -220,7 +220,7 @@ inline std::vector GenerateOperationInput( } const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true); - bool is_vector = (info.typename_.find("VectorType") != std::string::npos); + bool is_vector = (info.type_name.find("VectorType") != std::string::npos); // if src type is Tensor if (!is_vector) { From 38f74537495272727f95fda09277e49d1772b457 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 30 May 2023 08:36:39 +0000 Subject: [PATCH 17/30] fix op normalizer bug --- paddle/fluid/translator/op_compat_gen.py | 26 ++++----- paddle/fluid/translator/op_compat_info.h | 16 ++++- paddle/fluid/translator/op_translator.cc | 74 ++++++++++++++++-------- 3 files changed, 75 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index e8d32774cb4517..edf85211b8cb16 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -53,9 +53,9 @@ def to_phi_and_fluid_op_name(op_item): def insert_new_mappings(op_name_str: str) -> str: normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str) if normalized_name == legacy_name: - return normalized_name + return normalized_name, legacy_name op_name_mappings[legacy_name] = normalized_name - return normalized_name + return normalized_name, legacy_name def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): if op_name is None: @@ -64,30 +64,26 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): op_arg_name_mappings[op_name] = {} op_arg_name_mappings[op_name].update(arg_mapping) - normalized_op_name = insert_new_mappings(op_compat_item["op"]) - normailized_backward_op_name = None + _, legacy_name = insert_new_mappings(op_compat_item["op"]) + legacy_backward_op_name = None if "backward" in op_compat_item: - normailized_backward_op_name = insert_new_mappings( + _, legacy_backward_op_name = insert_new_mappings( op_compat_item["backward"] ) if "inputs" in op_compat_item: + insert_new_arg_mappings(legacy_name, op_compat_item["inputs"]) insert_new_arg_mappings( - normalized_op_name, op_compat_item["inputs"] - ) - insert_new_arg_mappings( - normailized_backward_op_name, op_compat_item["inputs"] + legacy_backward_op_name, op_compat_item["inputs"] ) if "attrs" in op_compat_item: - insert_new_arg_mappings(normalized_op_name, op_compat_item["attrs"]) + insert_new_arg_mappings(legacy_name, op_compat_item["attrs"]) insert_new_arg_mappings( - normailized_backward_op_name, op_compat_item["attrs"] + legacy_backward_op_name, op_compat_item["attrs"] ) if "outputs" in op_compat_item: + insert_new_arg_mappings(legacy_name, op_compat_item["outputs"]) insert_new_arg_mappings( - normalized_op_name, op_compat_item["outputs"] - ) - insert_new_arg_mappings( - normailized_backward_op_name, op_compat_item["outputs"] + legacy_backward_op_name, op_compat_item["outputs"] ) # special op mappings diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h index 52258a5b08b9ba..38ae7c0ae68233 100644 --- a/paddle/fluid/translator/op_compat_info.h +++ b/paddle/fluid/translator/op_compat_info.h @@ -58,7 +58,21 @@ class OpNameNormalizer { if (arg_mappings.find(arg_name) == arg_mappings.end()) { return UnderscoreToCamelCase(arg_name); } - return arg_mappings.at(op_type); + return arg_mappings.at(arg_name); + } + + std::string GetLegacyAttrName(const std::string& op_type, + const std::string& arg_name) { + if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { + VLOG(10) << "[" << op_type << "] not found"; + return arg_name; + } + auto& arg_mappings = op_arg_name_mappings[op_type]; + if (arg_mappings.find(arg_name) == arg_mappings.end()) { + VLOG(10) << "[" << op_type << "][" << arg_name << "] not found"; + return arg_name; + } + return arg_mappings.at(arg_name); } }; diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 42a9d9b05748cf..f9268c009cd925 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -205,7 +205,7 @@ inline std::vector GenerateOperationInput( for (const auto& info : input_infos) { std::string legacy_input_name = - op_normalizer.GetLegacyArgName(normalized_op_name, info.name); + op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); // return empty type if this arg is optional and not shown in OpDesc // TODO(lyk): HasInput doesnot consider variadic attribute @@ -289,38 +289,61 @@ inline std::tuple GenerateOperationOutput( return {op_output_types, arg_to_idx}; } -inline ir::AttributeMap TranslateOpAttribute(const OpDesc& op_desc) { +inline ir::AttributeMap TranslateOpAttribute( + std::string normalized_op_name, + const OpAttributeInfoList& op_attr_infos, + const OpDesc& op_desc) { auto& attribute_translator = AttributeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); ir::AttributeMap attribute_map = {}; - for (auto attr_in_op_desc : op_desc.GetAttrMap()) { - const auto& attr_name = attr_in_op_desc.first; - const auto& attr_value = attr_in_op_desc.second; - VLOG(0) << "attribute in " << op_desc.Type() << " name: " << attr_name - << " " << attr_value.index(); - ir::Attribute new_attr = attribute_translator[attr_value]; - attribute_map[attr_name] = new_attr; + + for (const auto& info : op_attr_infos) { + auto legacy_attr_name = + op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); + paddle::framework::Attribute legacy_attr = + op_desc.GetAttr(legacy_attr_name); + VLOG(10) << "attribute in " << op_desc.Type() + << " name: " << legacy_attr_name << " " << legacy_attr.index(); + ir::Attribute new_attr = attribute_translator[legacy_attr]; + attribute_map[info.name] = new_attr; if (!new_attr) { VLOG(0) << "empty attribute in " << op_desc.Type() - << " name: " << attr_name; + << " name: " << info.name; } else { VLOG(10) << "new attribute in " << op_desc.Type() - << " name: " << attr_name << " " << new_attr.storage(); + << " name: " << info.name << " " << new_attr.storage(); } } - for (auto attr_in_op_desc : op_desc.GetRuntimeAttrMap()) { - const auto& attr_name = attr_in_op_desc.first; - const auto& attr_value = attr_in_op_desc.second; - ir::Attribute new_attr = attribute_translator[attr_value]; - attribute_map[attr_name] = new_attr; - if (!new_attr) { - VLOG(0) << "empty runtime attribute in " << op_desc.Type() - << " name: " << attr_name; - } else { - VLOG(10) << "new runtime attribute in " << op_desc.Type() - << " name: " << attr_name << " " << new_attr.storage(); - } - } + // for (auto attr_in_op_desc : op_desc.GetAttrMap()) { + // const auto& attr_name = attr_in_op_desc.first; + // const auto& attr_value = attr_in_op_desc.second; + // VLOG(0) << "attribute in " << op_desc.Type() << " name: " << attr_name + // << " " << attr_value.index(); + // ir::Attribute new_attr = attribute_translator[attr_value]; + // attribute_map[attr_name] = new_attr; + // if (!new_attr) { + // VLOG(0) << "empty attribute in " << op_desc.Type() + // << " name: " << attr_name; + // } else { + // VLOG(10) << "new attribute in " << op_desc.Type() + // << " name: " << attr_name << " " << new_attr.storage(); + // } + // } + + // for (auto attr_in_op_desc : op_desc.GetRuntimeAttrMap()) { + // const auto& attr_name = attr_in_op_desc.first; + // const auto& attr_value = attr_in_op_desc.second; + // ir::Attribute new_attr = attribute_translator[attr_value]; + // attribute_map[attr_name] = new_attr; + // if (!new_attr) { + // VLOG(0) << "empty runtime attribute in " << op_desc.Type() + // << " name: " << attr_name; + // } else { + // VLOG(10) << "new runtime attribute in " << op_desc.Type() + // << " name: " << attr_name << " " << new_attr.storage(); + // } + // } return std::move(attribute_map); } @@ -370,7 +393,8 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); - auto attribute_map = TranslateOpAttribute(op_desc); + auto attribute_map = + TranslateOpAttribute(op_info.name(), attr_infos, op_desc); VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; ir::Operation* operation = From e8c623459ba2ed21928154015d519ae39b1d0cef Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 30 May 2023 09:57:01 +0000 Subject: [PATCH 18/30] refactor attribute translator --- .../fluid/translator/attribute_translator.cc | 133 +++++++++++++++--- .../fluid/translator/attribute_translator.h | 10 +- paddle/fluid/translator/op_translator.cc | 9 +- 3 files changed, 131 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc index 554b1f3ac7f3b3..4390ba68616d56 100644 --- a/paddle/fluid/translator/attribute_translator.cc +++ b/paddle/fluid/translator/attribute_translator.cc @@ -18,6 +18,10 @@ #include #include "paddle/fluid/dialect/pd_attribute.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/place.h" #include "paddle/phi/common/scalar.h" #include "paddle/utils/variant.h" @@ -31,25 +35,38 @@ class AttributeVisitor { ~AttributeVisitor() {} public: - ir::Attribute operator()(int i) { return ir::Int32_tAttribute::get(ctx, i); } + virtual ir::Attribute operator()(int i) { + VLOG(10) << "translating int"; + return ir::Int32_tAttribute::get(ctx, i); + } - ir::Attribute operator()(float f) { return ir::FloatAttribute::get(ctx, f); } + virtual ir::Attribute operator()(float f) { + VLOG(10) << "translating float"; + return ir::FloatAttribute::get(ctx, f); + } - ir::Attribute operator()(bool b) { return ir::BoolAttribute::get(ctx, b); } + virtual ir::Attribute operator()(bool b) { + VLOG(10) << "translating bool"; + return ir::BoolAttribute::get(ctx, b); + } - ir::Attribute operator()(double d) { + virtual ir::Attribute operator()(double d) { + VLOG(10) << "translating double"; return ir::DoubleAttribute::get(ctx, d); } - ir::Attribute operator()(std::string str) { + virtual ir::Attribute operator()(std::string str) { + VLOG(10) << "translating string"; return ir::StrAttribute::get(ctx, str); } - ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { + virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { + VLOG(10) << "translating scalar"; return paddle::dialect::ScalarAttribute::get(ctx, scalar); } - ir::Attribute operator()(const std::vector& strs) { + virtual ir::Attribute operator()(const std::vector& strs) { + VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(strs.size()); for (const auto& v : strs) { @@ -58,7 +75,8 @@ class AttributeVisitor { return ir::ArrayAttribute::get(ctx, attrs); } - ir::Attribute operator()(const std::vector& fs) { + virtual ir::Attribute operator()(const std::vector& fs) { + VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(fs.size()); for (const auto& v : fs) { @@ -67,7 +85,8 @@ class AttributeVisitor { return ir::ArrayAttribute::get(ctx, attrs); } - ir::Attribute operator()(const std::vector& is) { + virtual ir::Attribute operator()(const std::vector& is) { + VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(is.size()); for (const auto& v : is) { @@ -76,7 +95,8 @@ class AttributeVisitor { return ir::ArrayAttribute::get(ctx, attrs); } - ir::Attribute operator()(const std::vector& bs) { + virtual ir::Attribute operator()(const std::vector& bs) { + VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(bs.size()); for (const auto& v : bs) { @@ -85,7 +105,8 @@ class AttributeVisitor { return ir::ArrayAttribute::get(ctx, attrs); } - ir::Attribute operator()(const std::vector& i64s) { + virtual ir::Attribute operator()(const std::vector& i64s) { + VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(i64s.size()); for (const auto& v : i64s) { @@ -94,7 +115,8 @@ class AttributeVisitor { return ir::ArrayAttribute::get(ctx, attrs); } - ir::Attribute operator()(const std::vector& ds) { + virtual ir::Attribute operator()(const std::vector& ds) { + VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(ds.size()); for (const auto& v : ds) { @@ -103,8 +125,9 @@ class AttributeVisitor { return ir::ArrayAttribute::get(ctx, attrs); } - ir::Attribute operator()( + virtual ir::Attribute operator()( const std::vector& ss) { + VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(ss.size()); for (const auto& v : ss) { @@ -113,17 +136,95 @@ class AttributeVisitor { return ir::ArrayAttribute::get(ctx, attrs); } + virtual ir::Attribute operator()(const paddle::blank& blank) { + VLOG(10) << "translating paddle::blank"; + return ir::Attribute(nullptr); + } + template ir::Attribute operator()(T attr) { + VLOG(10) << "translating null type"; return ir::Attribute(nullptr); } }; -AttributeTranslator::AttributeTranslator() { visitor = new AttributeVisitor(); } +class IntArrayAttributeVisitor : public AttributeVisitor { + public: + using AttributeVisitor::AttributeVisitor; + ir::Attribute operator()(const std::vector& is) override { + VLOG(10) << "translating vector to IntArray"; + phi::IntArray data(is); + return paddle::dialect::IntArrayAttribute::get(ctx, data); + } + + ir::Attribute operator()(const std::vector& is) override { + VLOG(10) << "translating vector to IntArray"; + phi::IntArray data(is); + return paddle::dialect::IntArrayAttribute::get(ctx, data); + } +}; + +class ScalarAttributeVisitor : public AttributeVisitor { + public: + using AttributeVisitor::AttributeVisitor; + ir::Attribute operator()(int i) override { + VLOG(10) << "translating int to Scalar"; + phi::Scalar data(i); + return paddle::dialect::ScalarAttribute::get(ctx, data); + } -ir::Attribute AttributeTranslator::operator[]( + ir::Attribute operator()(float f) override { + VLOG(10) << "translating float to Scalar"; + phi::Scalar data(f); + return paddle::dialect::ScalarAttribute::get(ctx, data); + } +}; + +class DataTypeAttributeVisitor : public AttributeVisitor { + public: + using AttributeVisitor::AttributeVisitor; + ir::Attribute operator()(int i) override { + VLOG(10) << "translating int to DataType: " << i; + phi::DataType data = static_cast(i); + return paddle::dialect::DataTypeAttribute::get(ctx, data); + } +}; + +class PlaceAttributeVisitor : public AttributeVisitor { + public: + using AttributeVisitor::AttributeVisitor; + + ir::Attribute operator()(const paddle::blank& blank) override { + VLOG(10) << "translating paddle::blank"; + phi::Place data(phi::AllocationType::CPU); + return paddle::dialect::PlaceAttribute::get(ctx, data); + } +}; + +AttributeTranslator::AttributeTranslator() { + general_visitor = new AttributeVisitor(); + special_visitors["paddle::dialect::IntArrayAttribute"] = + new IntArrayAttributeVisitor(); + special_visitors["paddle::dialect::ScalarAttribute"] = + new ScalarAttributeVisitor(); + special_visitors["paddle::dialect::DataTypeAttribute"] = + new DataTypeAttributeVisitor(); + special_visitors["paddle::dialect::PlaceAttribute"] = + new PlaceAttributeVisitor(); +} + +ir::Attribute AttributeTranslator::operator()( const framework::Attribute& attr) { - return paddle::visit(*visitor, attr); + return paddle::visit(*general_visitor, attr); +} + +ir::Attribute AttributeTranslator::operator()( + const std::string& target_type, const framework::Attribute& attr) { + if (special_visitors.find(target_type) == special_visitors.end()) { + VLOG(10) << "[" << target_type << "] not found"; + return paddle::visit(*general_visitor, attr); + } + return paddle::visit(*(special_visitors.at(target_type)), attr); } } // namespace translator diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h index ed64c27b120eb1..ea509c7e346736 100644 --- a/paddle/fluid/translator/attribute_translator.h +++ b/paddle/fluid/translator/attribute_translator.h @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include + #include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/ir/core/attribute.h" @@ -28,7 +31,8 @@ class AttributeVisitor; class AttributeTranslator { private: AttributeTranslator(); - AttributeVisitor* visitor; + AttributeVisitor* general_visitor; + std::unordered_map special_visitors; public: AttributeTranslator(const AttributeTranslator&) = delete; @@ -41,7 +45,9 @@ class AttributeTranslator { return attribute_translator; } - ir::Attribute operator[](const framework::Attribute& attr); + ir::Attribute operator()(const framework::Attribute& attr); + ir::Attribute operator()(const std::string& target_type, + const framework::Attribute& attr); }; } // namespace translator diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index f9268c009cd925..9cb0a3e41916c0 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -300,11 +300,14 @@ inline ir::AttributeMap TranslateOpAttribute( for (const auto& info : op_attr_infos) { auto legacy_attr_name = op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); - paddle::framework::Attribute legacy_attr = - op_desc.GetAttr(legacy_attr_name); + + paddle::framework::Attribute legacy_attr; + if (op_desc.HasAttr(legacy_attr_name)) { + legacy_attr = op_desc.GetAttr(legacy_attr_name); + } VLOG(10) << "attribute in " << op_desc.Type() << " name: " << legacy_attr_name << " " << legacy_attr.index(); - ir::Attribute new_attr = attribute_translator[legacy_attr]; + ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr); attribute_map[info.name] = new_attr; if (!new_attr) { VLOG(0) << "empty attribute in " << op_desc.Type() From 261fe9729692d066d1c1f2ee2d401e46b78efc09 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 30 May 2023 13:07:05 +0000 Subject: [PATCH 19/30] fix bug --- paddle/fluid/translator/op_compat_info.h | 18 +++++++++++++ paddle/fluid/translator/op_translator.cc | 33 +----------------------- paddle/fluid/translator/utils.h | 1 + paddle/ir/core/builtin_dialect.cc | 2 +- paddle/ir/core/operation.cc | 6 +++++ paddle/ir/core/value.cc | 3 +++ paddle/phi/api/yaml/op_compat.yaml | 2 ++ 7 files changed, 32 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h index 38ae7c0ae68233..654329dbbe357e 100644 --- a/paddle/fluid/translator/op_compat_info.h +++ b/paddle/fluid/translator/op_compat_info.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -51,6 +52,23 @@ class OpNameNormalizer { std::string GetLegacyArgName(const std::string& op_type, const std::string& arg_name) { + bool is_grad_op = (op_type.find("grad") != std::string::npos); + bool is_grad_arg = (arg_name.find("grad") != std::string::npos); + if (is_grad_op && is_grad_arg) { + std::string target = "_grad"; + std::string data = "@GRAD"; + + size_t first_grad_pos = arg_name.find_first_of(target); + std::string legacy_name = + this->GetLegacyArgName(op_type, arg_name.substr(0, first_grad_pos)); + legacy_name += arg_name.substr(first_grad_pos); + for (size_t pos = 0; + legacy_name.npos != (pos = legacy_name.find(target, pos)); + pos += data.length()) { + legacy_name.replace(pos, target.length(), data); + } + return legacy_name; + } if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { return UnderscoreToCamelCase(arg_name); } diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 9cb0a3e41916c0..a2cceaf282176b 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -214,8 +214,7 @@ inline std::vector GenerateOperationInput( "Op %s arg %s should be optional if it can be empty", op_desc.Type(), legacy_input_name); - auto* constant_op = InsertConstantOperationForOptionalArg(ctx, program); - op_inputs.push_back(constant_op->GetResultByIndex(0)); + op_inputs.push_back(ir::OpResult(nullptr)); continue; } @@ -318,36 +317,6 @@ inline ir::AttributeMap TranslateOpAttribute( } } - // for (auto attr_in_op_desc : op_desc.GetAttrMap()) { - // const auto& attr_name = attr_in_op_desc.first; - // const auto& attr_value = attr_in_op_desc.second; - // VLOG(0) << "attribute in " << op_desc.Type() << " name: " << attr_name - // << " " << attr_value.index(); - // ir::Attribute new_attr = attribute_translator[attr_value]; - // attribute_map[attr_name] = new_attr; - // if (!new_attr) { - // VLOG(0) << "empty attribute in " << op_desc.Type() - // << " name: " << attr_name; - // } else { - // VLOG(10) << "new attribute in " << op_desc.Type() - // << " name: " << attr_name << " " << new_attr.storage(); - // } - // } - - // for (auto attr_in_op_desc : op_desc.GetRuntimeAttrMap()) { - // const auto& attr_name = attr_in_op_desc.first; - // const auto& attr_value = attr_in_op_desc.second; - // ir::Attribute new_attr = attribute_translator[attr_value]; - // attribute_map[attr_name] = new_attr; - // if (!new_attr) { - // VLOG(0) << "empty runtime attribute in " << op_desc.Type() - // << " name: " << attr_name; - // } else { - // VLOG(10) << "new runtime attribute in " << op_desc.Type() - // << " name: " << attr_name << " " << new_attr.storage(); - // } - // } - return std::move(attribute_map); } diff --git a/paddle/fluid/translator/utils.h b/paddle/fluid/translator/utils.h index 5711f1f550fb14..7065f46992c6aa 100644 --- a/paddle/fluid/translator/utils.h +++ b/paddle/fluid/translator/utils.h @@ -15,6 +15,7 @@ #pragma once #include +#include namespace paddle { namespace translator { diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index 600efd4bc240a7..eecc5dc05c154b 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -44,7 +44,7 @@ void BuiltinDialect::initialize() { Int64_tAttribute, ArrayAttribute>(); - RegisterOps(); + RegisterOps(); } } // namespace ir diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 4f9575c03d349a..d6f2f4592fc42c 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -64,6 +64,7 @@ Operation *Operation::create(const std::vector &inputs, // 2. Malloc memory. char *base_ptr = reinterpret_cast(aligned_malloc(base_size, 8)); // 3.1. Construct OpResults. + VLOG(10) << "debug"; for (size_t idx = num_results; idx > 0; idx--) { if (idx > max_inline_result_num) { new (base_ptr) @@ -73,7 +74,9 @@ Operation *Operation::create(const std::vector &inputs, new (base_ptr) detail::OpInlineResultImpl(output_types[idx - 1], idx - 1); base_ptr += sizeof(detail::OpInlineResultImpl); } + VLOG(10) << "debug " << idx; } + VLOG(10) << "debug"; // 3.2. Construct Operation. Operation *op = new (base_ptr) Operation(attribute, op_info, num_results, num_operands, num_regions); @@ -82,10 +85,13 @@ Operation *Operation::create(const std::vector &inputs, if ((reinterpret_cast(base_ptr) & 0x7) != 0) { throw("The address of OpOperandImpl must be divisible by 8."); } + VLOG(10) << "debug"; for (size_t idx = 0; idx < num_operands; idx++) { new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); + VLOG(10) << "debug " << idx; base_ptr += sizeof(detail::OpOperandImpl); } + VLOG(10) << "debug"; // 3.4. Construct Regions if (num_regions > 0) { op->regions_ = reinterpret_cast(base_ptr); diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 631f0ba7adfd2a..30e920e089fd70 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -112,6 +112,9 @@ void OpOperandImpl::release_source() { source_ = nullptr; } OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner) : source_(source), owner_(owner) { + if (!source) { + return; + } prev_use_addr_ = source.impl()->first_use_addr(); next_use_ = source.impl()->first_use(); if (next_use_) { diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index a1aa11bc39d581..eea7491b0633ca 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -462,6 +462,8 @@ - op : conv2d backward : conv2d_grad + inputs : + out : Output extra : attrs : [bool is_test = false, bool use_cudnn = true, bool use_mkldnn = false, bool use_addto = false, bool force_fp32_output = false, From 6d57a5f5cbbd8e53d76dd2ce495cde6fe7cd6774 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 31 May 2023 08:17:01 +0000 Subject: [PATCH 20/30] refactor output translator --- paddle/fluid/dialect/op_gen.py | 29 +++++-- paddle/fluid/translator/op_translator.cc | 87 +++++++++++++++------ paddle/ir/core/operation.cc | 6 -- paddle/ir/core/printer.cc | 28 +++++-- paddle/ir/core/value.cc | 3 +- paddle/phi/api/yaml/op_compat.yaml | 1 + test/cpp/ir/core/program_translator_test.cc | 8 +- 7 files changed, 110 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index acb3b2721278f3..01914a5951f220 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -141,6 +141,14 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ }} """ +GRAD_OP_VERIFY_TEMPLATE = """ +void {op_name}::verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ + (void)inputs; + (void)outputs; + (void)attributes; +}} +""" + INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true, phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); """ @@ -703,14 +711,19 @@ def OpGenerator( ) # generate op verify function - op_verify_str = OP_VERIFY_TEMPLATE.format( - op_name=op_class_name, - inputs_size=len(op_input_type_list), - outputs_size=len(op_output_type_list), - inputs_type_check=inputs_type_check_str, - outputs_type_check=outputs_type_check_str, - attributes_check=attributes_check_str, - ) + if "Grad" in op_class_name: + op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + ) + else: + op_verify_str = OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + inputs_size=len(op_input_type_list), + outputs_size=len(op_output_type_list), + inputs_type_check=inputs_type_check_str, + outputs_type_check=outputs_type_check_str, + attributes_check=attributes_check_str, + ) ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index a2cceaf282176b..e7d8112577c7ba 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -207,7 +207,7 @@ inline std::vector GenerateOperationInput( std::string legacy_input_name = op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); - // return empty type if this arg is optional and not shown in OpDesc + // return empty OpResult if this arg is optional and not shown in OpDesc // TODO(lyk): HasInput doesnot consider variadic attribute if (!op_desc.HasInput(legacy_input_name)) { PADDLE_ENFORCE(info.optional, @@ -225,6 +225,7 @@ inline std::vector GenerateOperationInput( if (!is_vector) { auto defining_info = (*param_map)[legacy_input_vars[0]]; op_inputs.push_back(defining_info.value); + // if src type is Vector , need an additional `CombineOp` to // assemble them. } else { @@ -238,48 +239,75 @@ inline std::vector GenerateOperationInput( } inline std::tuple GenerateOperationOutput( - ir::IrContext* ctx, const OpDesc& op_desc) { + ir::IrContext* ctx, + const OpDesc& op_desc, + const OpOutputInfoList& output_infos) { OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; auto& type_translator = TypeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); const BlockDesc* block = op_desc.Block(); - for (const auto& n : op_desc.Outputs()) { - auto& name = n.first; - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name; - auto& args = n.second; + for (const auto& info : output_infos) { size_t cur_output_idx = op_output_types.size(); + std::string legacy_output_name = + op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); - // if src type is Tensor or a Vector with size <= 1 - if (args.size() <= 1) { - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name - << " " << var->GetType(); + // return empty type if this arg is optional and not shown in OpDesc + // TODO(lyk): HasOutput doesnot consider variadic attribute + if (!op_desc.HasOutput(legacy_output_name)) { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "] optional " << info.name << " :" + << info.type_name << " " << legacy_output_name; + PADDLE_ENFORCE(info.optional, + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_output_name); + op_output_types.push_back(ir::Type(nullptr)); + continue; + } - ir::Type translated_var_type = - type_translator[var->GetType()](ctx, *var); + const auto& legacy_output_vars = op_desc.Output(legacy_output_name); + bool is_vector = (info.type_name.find("VectorType") != std::string::npos); - arg_to_idx[arg_name] = cur_output_idx; - op_output_types.push_back(translated_var_type); + // if src type is Tensor + if (!is_vector) { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " :" + << info.type_name << " " << legacy_output_name; + if (legacy_output_vars.size() == 0) { + op_output_types.push_back(ir::Type(nullptr)); + continue; } + auto& var_name = legacy_output_vars[0]; + VarDesc* var = block->FindVarRecursive(var_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " " << var_name + << " " << var->GetType(); + + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); + + arg_to_idx[var_name] = cur_output_idx; + op_output_types.push_back(translated_var_type); + // if src type is Vector } else { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " :" + << info.type_name << " " << legacy_output_name; std::vector types; - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); + for (const auto& var_name : legacy_output_vars) { + VarDesc* var = block->FindVarRecursive(var_name); VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name + << "[" << op_desc.Type() << "]" << info.name << " " << var_name << " " << var->GetType(); ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); types.push_back(translated_var_type); - arg_to_idx[arg_name] = cur_output_idx; + arg_to_idx[var_name] = cur_output_idx; } ir::Type vec_type = ir::VectorType::get(ctx, types); op_output_types.push_back(vec_type); @@ -363,7 +391,8 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; - std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); + std::tie(op_output_types, arg_to_idx) = + GenerateOperationOutput(ctx, op_desc, output_infos); auto attribute_map = TranslateOpAttribute(op_info.name(), attr_infos, op_desc); @@ -385,11 +414,21 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, ir::Program* program, const OpDesc& op_desc) { auto op_info = LoopkUpOpInfo(ctx, op_desc); + + auto* op_info_concept = + op_info.GetInterfaceImpl(); + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos) = + op_info_concept->get_op_info_(); + std::vector op_inputs; OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; - std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); + std::tie(op_output_types, arg_to_idx) = + GenerateOperationOutput(ctx, op_desc, output_infos); ir::AttributeMap attribute_map = { {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, }; diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index d6f2f4592fc42c..4f9575c03d349a 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -64,7 +64,6 @@ Operation *Operation::create(const std::vector &inputs, // 2. Malloc memory. char *base_ptr = reinterpret_cast(aligned_malloc(base_size, 8)); // 3.1. Construct OpResults. - VLOG(10) << "debug"; for (size_t idx = num_results; idx > 0; idx--) { if (idx > max_inline_result_num) { new (base_ptr) @@ -74,9 +73,7 @@ Operation *Operation::create(const std::vector &inputs, new (base_ptr) detail::OpInlineResultImpl(output_types[idx - 1], idx - 1); base_ptr += sizeof(detail::OpInlineResultImpl); } - VLOG(10) << "debug " << idx; } - VLOG(10) << "debug"; // 3.2. Construct Operation. Operation *op = new (base_ptr) Operation(attribute, op_info, num_results, num_operands, num_regions); @@ -85,13 +82,10 @@ Operation *Operation::create(const std::vector &inputs, if ((reinterpret_cast(base_ptr) & 0x7) != 0) { throw("The address of OpOperandImpl must be divisible by 8."); } - VLOG(10) << "debug"; for (size_t idx = 0; idx < num_operands; idx++) { new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); - VLOG(10) << "debug " << idx; base_ptr += sizeof(detail::OpOperandImpl); } - VLOG(10) << "debug"; // 3.4. Construct Regions if (num_regions > 0) { op->regions_ = reinterpret_cast(base_ptr); diff --git a/paddle/ir/core/printer.cc b/paddle/ir/core/printer.cc index 5dc91142fb5e28..8788365b08e40f 100644 --- a/paddle/ir/core/printer.cc +++ b/paddle/ir/core/printer.cc @@ -50,6 +50,11 @@ class Printer { explicit Printer(std::ostream& os) : os(os) {} void PrintType(ir::Type type) { + if (!type) { + os << ""; + return; + } + if (type.isa()) { os << "f16"; } else if (type.isa()) { @@ -82,10 +87,6 @@ class Printer { }; void Type::print(std::ostream& os) const { - if (!*this) { - os << ""; - return; - } Printer p(os); p.PrintType(*this); } @@ -104,6 +105,10 @@ class ProgramPrinter : public Printer { } void PrintValue(ir::Value v) { + if (!v) { + os << "<>"; + return; + } const void* key = static_cast(v.impl()); auto ret = aliases.find(key); if (ret != aliases.end()) { @@ -175,8 +180,12 @@ class ProgramPrinter : public Printer { std::vector op_operand_types; op_operand_types.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { - op_operand_types.push_back( - op->GetOperandByIndex(idx).impl()->source().type()); + auto op_operand = op->GetOperandByIndex(idx); + if (op_operand) { + op_operand_types.push_back(op_operand.impl()->source().type()); + } else { + op_operand_types.push_back(ir::Type(nullptr)); + } } PrintInterleave( op_operand_types.begin(), @@ -190,7 +199,12 @@ class ProgramPrinter : public Printer { std::vector op_result_types; op_result_types.reserve(num_op_result); for (size_t idx = 0; idx < num_op_result; idx++) { - op_result_types.push_back(op->GetResultByIndex(idx).type()); + auto op_result = op->GetResultByIndex(idx); + if (op_result) { + op_result_types.push_back(op_result.type()); + } else { + op_result_types.push_back(ir::Type(nullptr)); + } } PrintInterleave( op_result_types.begin(), diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 30e920e089fd70..022d1cc105a575 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -42,7 +42,7 @@ bool OpOperand::operator!=(OpOperand other) const { bool OpOperand::operator!() const { return impl_ == nullptr; } -OpOperand::operator bool() const { return impl_; } +OpOperand::operator bool() const { return impl_ && impl_->source(); } detail::OpOperandImpl *OpOperand::impl() const { return impl_; } @@ -124,6 +124,7 @@ OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner) } void OpOperandImpl::remove_from_ud_chain() { + if (!source_) return; if (!prev_use_addr_) return; if (prev_use_addr_ == source_.impl()->first_use_addr()) { /// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index eea7491b0633ca..ad21e9a5c74ac5 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1931,6 +1931,7 @@ x : X outputs: out : Out + xshape: XShape int_array: shape : data_type : int diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index 81ac1d07ea44da..b98bedd016fca8 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -47,7 +47,6 @@ ProgramDesc load_from_file(const std::string &file_name) { } TEST(PaddleDialectTest, Translator) { - LOG(WARNING) << "TODO"; auto p = load_from_file("restnet50_main.prog"); EXPECT_EQ(p.Size(), 1u); @@ -58,10 +57,7 @@ TEST(PaddleDialectTest, Translator) { size_t op_size = program->block()->size(); // ops.size() = op size in BlockDesc + get_parameter_op + combine op - EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20); + EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); - std::ofstream ostrm( - "/home/lvyongkang/Paddle/build/log_resnet_after_translated", - std::ios::out); - ostrm << *program; + std::cout << *program << std::endl; } From 0717f1105462a9fc1a219df5ba81cc4e3a1c54f5 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 31 May 2023 08:29:03 +0000 Subject: [PATCH 21/30] fix typo --- paddle/phi/api/yaml/op_compat.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index ad21e9a5c74ac5..96e1752f6eb489 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1780,6 +1780,9 @@ backward : pool2d_grad attrs: kernel_size: ksize + extra : + attrs : [bool use_mkldnn = false, bool use_quantizer = false, + str mkldnn_data_type = "float32", bool is_test = false] - op : pool3d backward : pool3d_grad From c9d7c63a936c3aef071ce629404f67a389338d8f Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 31 May 2023 10:59:36 +0000 Subject: [PATCH 22/30] fix --- paddle/fluid/translator/op_translator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index e7d8112577c7ba..bac7ea2efb1d1a 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -345,7 +345,7 @@ inline ir::AttributeMap TranslateOpAttribute( } } - return std::move(attribute_map); + return attribute_map; } inline void RecordOpResultMapping(TranslationContext* param_map, From 2fb54d8ee98072656e4ff1df427d28777a7e8b4a Mon Sep 17 00:00:00 2001 From: kangguangli Date: Thu, 1 Jun 2023 07:07:21 +0000 Subject: [PATCH 23/30] fix approval error --- paddle/fluid/translator/op_translator.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index bac7ea2efb1d1a..6d86108f2cfb36 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -211,9 +211,10 @@ inline std::vector GenerateOperationInput( // TODO(lyk): HasInput doesnot consider variadic attribute if (!op_desc.HasInput(legacy_input_name)) { PADDLE_ENFORCE(info.optional, - "Op %s arg %s should be optional if it can be empty", - op_desc.Type(), - legacy_input_name); + platform::errors::PreconditionNotMet( + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_input_name)); op_inputs.push_back(ir::OpResult(nullptr)); continue; } @@ -262,9 +263,10 @@ inline std::tuple GenerateOperationOutput( << "[" << op_desc.Type() << "] optional " << info.name << " :" << info.type_name << " " << legacy_output_name; PADDLE_ENFORCE(info.optional, - "Op %s arg %s should be optional if it can be empty", - op_desc.Type(), - legacy_output_name); + platform::errors::PreconditionNotMet( + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_output_name)); op_output_types.push_back(ir::Type(nullptr)); continue; } From e945a413f81ade27ac3eab8d9e0024c242e8f983 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Thu, 1 Jun 2023 07:10:54 +0000 Subject: [PATCH 24/30] fix coverage --- test/cpp/ir/core/ir_program_test.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 8b78b739a65943..5eb6e4e11d0980 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -248,13 +248,10 @@ TEST(program_test, slice_combine_test) { ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info); program.InsertOp(op1); - // (5) Def b = GetParameterOp("b") - std::string op2_name = std::string(ir::GetParameterOp::name()); + // (5) Def b = Constant("b") + std::string op2_name = std::string(ir::ConstantOp::name()); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); - std::unordered_map op2_attribute{ - {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; - ir::Operation *op2 = - ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info); + ir::Operation *op2 = ir::Operation::create({}, {}, {fp32_dtype}, op2_info); program.InsertOp(op2); // (6) Def combine_op = CombineOp("a", "b") From c9358f472a62f54bcd11433e94db732961f67596 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Thu, 1 Jun 2023 08:24:14 +0000 Subject: [PATCH 25/30] fix op_compat parser --- paddle/fluid/translator/op_compat_gen.py | 28 +++++++++++++----------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index edf85211b8cb16..5bc9df7ee8b34b 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -34,7 +34,7 @@ def OpNameNormalizerInitialization( op_compat_yaml_file: str = "", output_source_file: str = "" ) -> None: def to_phi_and_fluid_op_name(op_item): - # Templat: - op : phi_name (fluid_name) + # Template: - op : phi_name (fluid_name) names = op_item.split('(') if len(names) == 1: phi_fluid_name = names[0].strip() @@ -65,26 +65,28 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): op_arg_name_mappings[op_name].update(arg_mapping) _, legacy_name = insert_new_mappings(op_compat_item["op"]) - legacy_backward_op_name = None + legacy_backward_op_names = [] if "backward" in op_compat_item: - _, legacy_backward_op_name = insert_new_mappings( - op_compat_item["backward"] + backward_op_name_mapping_paris = op_compat_item["backward"].split( + "," ) + for pair in backward_op_name_mapping_paris: + _, legacy_backward_op_name = insert_new_mappings(pair) + legacy_backward_op_names.append(legacy_backward_op_name) + if "inputs" in op_compat_item: insert_new_arg_mappings(legacy_name, op_compat_item["inputs"]) - insert_new_arg_mappings( - legacy_backward_op_name, op_compat_item["inputs"] - ) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["inputs"]) + if "attrs" in op_compat_item: insert_new_arg_mappings(legacy_name, op_compat_item["attrs"]) - insert_new_arg_mappings( - legacy_backward_op_name, op_compat_item["attrs"] - ) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["attrs"]) if "outputs" in op_compat_item: insert_new_arg_mappings(legacy_name, op_compat_item["outputs"]) - insert_new_arg_mappings( - legacy_backward_op_name, op_compat_item["outputs"] - ) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["outputs"]) # special op mappings op_name_mappings["fetch_v2"] = "fetch" From b785d35a97df8e6a9abc612b2316304e0919b7bf Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 03:01:54 +0000 Subject: [PATCH 26/30] fix merge conflicts --- paddle/fluid/translator/op_translator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index e8e1f10865a67b..205dc88d3daaa2 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -166,7 +166,7 @@ inline ir::Operation* InsertConstantOperationForOptionalArg( ir::Type null_type = ir::Type(nullptr); ir::Operation* operation = ir::Operation::create({}, {}, {null_type}, op_info); - program->InsertOp(operation); + program->block()->push_back(operation); return operation; } From a0f21e6677705103fad6da4eae52bf0e83241a49 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 03:21:25 +0000 Subject: [PATCH 27/30] fix merge conflicts --- paddle/fluid/translator/op_translator.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 205dc88d3daaa2..5dca36db1c8b59 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -385,7 +385,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, OpInputInfoList input_infos; OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos) = + std::tie(input_infos, attr_infos, output_infos, std::ignore) = op_info_concept->get_op_info_(); auto op_inputs = GenerateOperationInput( @@ -422,7 +422,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, OpInputInfoList input_infos; OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos) = + std::tie(input_infos, attr_infos, output_infos, std::ignore) = op_info_concept->get_op_info_(); std::vector op_inputs; @@ -454,7 +454,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, OpInputInfoList input_infos; OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos) = + std::tie(input_infos, attr_infos, output_infos, std::ignore) = op_info_concept->get_op_info_(); auto op_inputs = GenerateOperationInput( From c2ede6faaca4de8fba7b11ffc1d69671e9375c23 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 06:21:31 +0000 Subject: [PATCH 28/30] fix merge conflicts --- paddle/fluid/dialect/op_gen.py | 2 +- paddle/fluid/dialect/pd_interface.h | 5 +---- paddle/ir/core/builtin_dialect.cc | 7 ++++++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 0e085734f42fbf..0d8c4d336f1329 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -1186,7 +1186,7 @@ def OpGenerator( ) # generate op verify function - if "Grad" in op_class_name: + if "GradOp" in op_class_name or "Grad_Op" in op_class_name: op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format( op_name=op_class_name, ) diff --git a/paddle/fluid/dialect/pd_interface.h b/paddle/fluid/dialect/pd_interface.h index 9ca4871e346747..f49ba35fcc3e77 100644 --- a/paddle/fluid/dialect/pd_interface.h +++ b/paddle/fluid/dialect/pd_interface.h @@ -36,10 +36,7 @@ class GetOpInfoInterface : public ir::OpInterfaceBase { struct Model : public Concept { static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); } - Model() : Concept(GetOpInfo) { - static_assert(sizeof(Model) == sizeof(Concept), - "sizeof(Model) != sizeof(Concept)"); - } + Model() : Concept(GetOpInfo) {} }; GetOpInfoInterface(ir::Operation *op, Concept *impl) diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index f3e1dbe77fe723..f3a5183cd6ca61 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -45,7 +45,12 @@ void BuiltinDialect::initialize() { Int64_tAttribute, ArrayAttribute>(); - RegisterOps(); + RegisterOps(); } } // namespace ir From 7f97345c4e7c853c13d92da7eb330a42c5cde4a3 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 06:31:44 +0000 Subject: [PATCH 29/30] fix merge conflicts --- paddle/ir/core/value.cc | 2 -- paddle/ir/core/value.h | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 3a4ddb8b9cd4b2..92537a17a9dd08 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -38,8 +38,6 @@ Value OpOperand::source() const { return impl_->source(); } Operation *OpOperand::owner() const { return impl_->owner(); } -OpOperand::operator bool() const { return impl_ && impl_->source(); } - // Value Value::Value(const detail::ValueImpl *impl) : impl_(const_cast(impl)) {} diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index 795aa6f5a80663..ee2221c1bebc5d 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -49,7 +49,7 @@ class OpOperand { bool operator!() const { return impl_ == nullptr; } - operator bool() const { return impl_; } + operator bool() const { return impl_ && impl_->source(); } OpOperand next_use() const; From 5f81713280bd5b3d10a9d01328306cc2072a7649 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 06:35:02 +0000 Subject: [PATCH 30/30] fix merge conflicts --- paddle/ir/core/value.cc | 1 + paddle/ir/core/value.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 92537a17a9dd08..d9032e8ae90a63 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -31,6 +31,7 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { impl_ = const_cast(impl); return *this; } +OpOperand::operator bool() const { return impl_ && impl_->source(); } OpOperand OpOperand::next_use() const { return impl_->next_use(); } diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index ee2221c1bebc5d..0825dd06cbd9bf 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -49,7 +49,7 @@ class OpOperand { bool operator!() const { return impl_ == nullptr; } - operator bool() const { return impl_ && impl_->source(); } + operator bool() const; OpOperand next_use() const;