From 2b5cc895fec3f5e9251ab04ad934354a5e157426 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Tue, 23 May 2023 02:51:42 +0000
Subject: [PATCH 01/43] add vector type support for program translator

---
 paddle/fluid/translator/op_translator.cc      | 142 +++++++++++++++---
 paddle/fluid/translator/op_translator.h       |   3 +-
 paddle/fluid/translator/program_translator.cc |  15 +-
 paddle/fluid/translator/program_translator.h  |  28 +++-
 paddle/fluid/translator/translate.h           |   7 +-
 paddle/ir/builtin_op.cc                       |   3 +
 paddle/ir/builtin_op.h                        |  35 ++++-
 paddle/ir/printer.cc                          |  40 +++--
 test/cpp/ir/program_translator_test.cc        |   5 +-
 9 files changed, 231 insertions(+), 47 deletions(-)

diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc
index c6ff3f94125fb5..65102a4eb0b192 100644
--- a/paddle/fluid/translator/op_translator.cc
+++ b/paddle/fluid/translator/op_translator.cc
@@ -24,6 +24,8 @@
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/translator/program_translator.h"
 #include "paddle/fluid/translator/type_translator.h"
+#include "paddle/ir/builtin_op.h"
+#include "paddle/ir/builtin_type.h"
 #include "paddle/ir/ir_context.h"
 #include "paddle/ir/value.h"
 #include "paddle/phi/core/enforce.h"
@@ -84,23 +86,101 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) {
   return op_info;
 }

+inline ir::Operation* InsertSliceOperationForTarget(
+    ir::IrContext* ctx,
+    TranslationContext* param_map,
+    ir::Program* program,
+    const VariableDefiningInfo& defining_info,
+    const std::string& arg_name) {
+  std::string slice_op_name(ir::SliceOp::name());
+  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
+  std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
+      {"index", ir::Int32_tAttribute::get(ctx, defining_info.idx_in_vector)},
+  };
+  ir::VectorType src_vec_type =
+      defining_info.value.type().dyn_cast<ir::VectorType>();
+  ir::Operation* operation =
+      ir::Operation::create({defining_info.value},
+                            {src_vec_type[defining_info.idx_in_vector]},
+                            op_attribute_map,
+                            op_info);
+  program->InsertOp(operation);
+  ir::OpResult target_op_result = operation->GetResultByIndex(0);
+  (*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
+  return operation;
+}
+
+inline ir::Operation* InsertCombineOperationForTarget(
+    ir::IrContext* ctx,
+    TranslationContext* param_map,
+    ir::Program* program,
+    const std::vector<std::string>& args) {
+  std::string combine_op_name(ir::CombineOp::name());
+  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name);
+
+  std::vector<ir::OpResult> src_values;
+  std::vector<ir::Type> types_in_vec;
+  for (auto arg_name : args) {
+    auto defining_info = param_map->at(arg_name);
+    src_values.push_back(defining_info.value);
+    types_in_vec.push_back(defining_info.value.type());
+  }
+  ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
+  ir::Operation* operation =
+      ir::Operation::create(src_values, {target_vec_type}, {}, op_info);
+  program->InsertOp(operation);
+  return operation;
+}
+
 inline std::vector<ir::OpResult> GenerateOperationInput(
-    TranslationContext* param_map, const OpDesc& op_desc) {
+    ir::IrContext* ctx,
+    TranslationContext* param_map,
+    ir::Program* program,
+    const OpDesc& op_desc) {
   std::vector<ir::OpResult> op_inputs = {};
+
+  // scan all inputs to see if any of them was generated as a vector,
+  // in which case an additional `SliceOp` is needed to take the element out.
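+  // An illustrative example (notation approximate): for an op_desc input
+  // {"X": ["a"]} where "a" was recorded as element 0 of a vector-typed
+  // result %v, this scan emits
+  //   %a = "builtin.slice"(%v) {index: 0}
+  // and rebinds "a" in param_map to the sliced, scalar-typed result.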
  for (const auto& n : op_desc.Inputs()) {
     auto& name = n.first;
-    VLOG(10) << "[input retriving]"
-             << "[" << op_desc.Type() << "]" << name;
     auto& args = n.second;
+
     for (const auto& arg_name : args) {
       PADDLE_ENFORCE_NE(
           param_map->count(arg_name),
           0,
           platform::errors::PreconditionNotMet(
-              "arg %s as input should be exists before prasing %d",
+              "arg %s.%s as input should exist before parsing %s",
+              name,
               arg_name,
               op_desc.Type()));
-      op_inputs.push_back((*param_map)[arg_name]);
+      auto defining_info = (*param_map)[arg_name];
+      if (defining_info.generated_by_vector) {
+        InsertSliceOperationForTarget(
+            ctx, param_map, program, defining_info, arg_name);
+      }
+    }
+  }
+
+  for (const auto& n : op_desc.Inputs()) {
+    auto& name = n.first;
+    VLOG(10) << "[input retrieving]"
+             << "[" << op_desc.Type() << "]" << name;
+    auto& args = n.second;
+
+    // if src type is Tensor or a Vector<Tensor> with size <= 1
+    if (args.size() <= 1) {
+      for (const auto& arg_name : args) {
+        auto defining_info = (*param_map)[arg_name];
+        op_inputs.push_back(defining_info.value);
+      }
+
+      // if src type is Vector<Tensor>, an additional `CombineOp` is needed
+      // to assemble the elements.
+    } else {
+      auto* combine_op =
+          InsertCombineOperationForTarget(ctx, param_map, program, args);
+      op_inputs.push_back(combine_op->GetResultByIndex(0));
     }
   }
   return op_inputs;
@@ -119,16 +199,39 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
     VLOG(10) << "[output translating]"
              << "[" << op_desc.Type() << "]" << name;
     auto& args = n.second;
-    for (const auto& arg_name : args) {
-      VarDesc* var = block->FindVarRecursive(arg_name);
-      VLOG(10) << "[output translating]"
-               << "[" << op_desc.Type() << "]" << name << " " << arg_name << " "
-               << var->GetType();
-
-      ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);

-      arg_to_idx[arg_name] = op_output_types.size();
-      op_output_types.push_back(translated_var_type);
+    size_t cur_output_idx = op_output_types.size();
+
+    // if src type is Tensor or a Vector<Tensor> with size <= 1
+    if (args.size() <= 1) {
+      for (const auto& arg_name : args) {
+        VarDesc* var = block->FindVarRecursive(arg_name);
+        VLOG(10) << "[output translating]"
+                 << "[" << op_desc.Type() << "]" << name << " " << arg_name
+                 << " " << var->GetType();
+
+        ir::Type translated_var_type =
+            type_translator[var->GetType()](ctx, *var);
+
+        arg_to_idx[arg_name] = cur_output_idx;
+        op_output_types.push_back(translated_var_type);
+      }
+
+      // if src type is Vector<Tensor>
+    } else {
+      std::vector<ir::Type> types;
+      for (const auto& arg_name : args) {
+        VarDesc* var = block->FindVarRecursive(arg_name);
+        VLOG(10) << "[output translating]"
+                 << "[" << op_desc.Type() << "]" << name << " " << arg_name
+                 << " " << var->GetType();
+        ir::Type translated_var_type =
+            type_translator[var->GetType()](ctx, *var);
+        types.push_back(translated_var_type);
+        arg_to_idx[arg_name] = cur_output_idx;
+      }
+      ir::Type vec_type = ir::VectorType::get(ctx, types);
+      op_output_types.push_back(vec_type);
     }
   }
   return {op_output_types, arg_to_idx};
@@ -143,12 +246,17 @@ inline void RecordOpResultMapping(TranslationContext* param_map,
     VLOG(10) << "[output recording]"
              << "[" << op_desc.Type() << "]" << name;
     auto& args = n.second;
+    size_t idx_in_vector = 0;
     for (const auto& arg_name : args) {
       auto idx = arg_to_idx.at(arg_name);
       VLOG(10) << "[output recording]"
                << "[" << op_desc.Type() << "]" << arg_name << " " << idx;
-      (*param_map)[arg_name] = operation->GetResultByIndex(idx);
+      ir::OpResult value = operation->GetResultByIndex(idx);
+      bool generated_by_vector = value.type().isa<ir::VectorType>();
+      (*param_map)[arg_name] =
VariableDefiningInfo( + value, generated_by_vector, generated_by_vector ? idx_in_vector : -1); + idx_in_vector++; } } } @@ -157,7 +265,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(param_map, op_desc); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; @@ -193,7 +301,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(param_map, op_desc); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputTypeList op_output_types = {}; auto op_info = LoopkUpOpInfo(ctx, op_desc); diff --git a/paddle/fluid/translator/op_translator.h b/paddle/fluid/translator/op_translator.h index c767f639d534b9..92c03458300257 100644 --- a/paddle/fluid/translator/op_translator.h +++ b/paddle/fluid/translator/op_translator.h @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/translator/program_translator.h" #include "paddle/ir/ir_context.h" #include "paddle/ir/operation.h" #include "paddle/ir/program.h" @@ -28,8 +29,6 @@ namespace paddle { namespace translator { -using TranslationContext = std::unordered_map; - class OpTranslator { public: using ResultIdx = size_t; diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 7cdbef589067ba..7618972f108040 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/translator/op_translator.h" +#include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/attribute.h" #include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_type.h" @@ -38,6 +39,11 @@ ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ctx = ir::IrContext::Instance(); } +const std::unordered_set ProgramTranslator::no_cast_var_names = { + "feed", + "fetch", +}; + void ProgramTranslator::Translate() { PADDLE_ENFORCE_EQ( legacy_program->Size(), @@ -59,19 +65,24 @@ void ProgramTranslator::Translate() { void ProgramTranslator::ExtractParameterFromSingleBlock( const BlockDesc& block) { + auto& type_translator = TypeTranslator::instance(); + for (auto& var : block.AllVars()) { if (!var->Persistable()) continue; if (param_map.count(var->Name()) != 0) continue; + if (no_cast_var_names.count(var->Name()) != 0) continue; std::string get_parameter_op_name(ir::GetParameterOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); std::unordered_map op_attribute_map = { {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, }; + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( - {}, {ir::Float32Type::get(ctx)}, op_attribute_map, op_info); + {}, {translated_var_type}, op_attribute_map, op_info); program->InsertOp(operation); - param_map[var->Name()] = operation->GetResultByIndex(0); + param_map[var->Name()] = + VariableDefiningInfo(operation->GetResultByIndex(0)); VLOG(10) << "[op translated][get parameter]" << operation; program->SetParameter(var->Name(), nullptr); diff --git 
a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index 569b93b06aa6d9..f7fd4e2890ea64 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -27,7 +27,25 @@ namespace paddle { namespace translator { -using TranslationContext = std::unordered_map; +struct VariableDefiningInfo { + VariableDefiningInfo(ir::OpResult value, + bool generated_by_vector = false, + int idx_in_vector = -1) + : value(value), + generated_by_vector(generated_by_vector), + idx_in_vector(idx_in_vector) {} + VariableDefiningInfo() {} + + ir::OpResult value; + + bool generated_by_vector = + false; // true if target variabe is generated by Vector + int idx_in_vector = + -1; // positive if target variabe is generated by Vector +}; + +using TranslationContext = + std::unordered_map; class ProgramTranslator { using ProgramDesc = ::paddle::framework::ProgramDesc; @@ -45,6 +63,14 @@ class ProgramTranslator { TranslationContext param_map; ir::IrContext* ctx; + /// In the legacy program desc, there are two special named varibales: + /// 1. "feed", the input variable of feed op + /// 2. "fetch", the output variable of fetch op + /// However, new feed has no input and new fetch has no output + /// So we don't handle these two vairables when + /// `ExtractParameterFromSingleBlock` + static const std::unordered_set no_cast_var_names; + void ExtractParameterFromSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block); }; diff --git a/paddle/fluid/translator/translate.h b/paddle/fluid/translator/translate.h index aa2571f74c4cf4..8d6b7e7b1bbdc2 100644 --- a/paddle/fluid/translator/translate.h +++ b/paddle/fluid/translator/translate.h @@ -22,10 +22,7 @@ namespace paddle { -using LegacyProgramDesc = ::paddle::framework::ProgramDesc; -using Program = ::ir::Program; - -std::unique_ptr TranslateLegacyProgramToProgram( - const LegacyProgramDesc& legacy_program); +std::unique_ptr<::ir::ProgramProgram> TranslateLegacyProgramToProgram( + const ::paddle::framework::ProgramDesc& legacy_program); } // namespace paddle diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index d000d086b0f4cd..5d618af72c8784 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -21,4 +21,7 @@ const char *GetParameterOp::attributes_name[attributes_num] = { const char *SetParameterOp::attributes_name[attributes_num] = { "parameter_name"}; +const char **CombineOp::attributes_name = nullptr; +const char *SliceOp::attributes_name[attributes_num] = {"index"}; + } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index ca29867ff4a132..ae2c9c768dc5ba 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -22,7 +22,8 @@ namespace ir { /// The built-in Dialect will use this macro to quickly register all built-in /// OPs. 
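/// A hypothetical consumer (for orientation only) could register the whole
/// list in one call, e.g. `RegisterOps<GET_BUILT_IN_OP_LIST>();` inside a
/// dialect's initialize().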
/// -#define GET_BUILT_IN_OP_LIST ir::GetParameterOp, ir::SetParameterOp +#define GET_BUILT_IN_OP_LIST \ + ir::GetParameterOp, ir::SetParameterOp, ir::CombineOp, ir::SliceOp /// /// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute, @@ -40,7 +41,7 @@ class GetParameterOp : public ir::Op { }; /// -/// \brief GetParameterOp: SetParameterOp(OpOperand, {StrAttribute, +/// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute, /// StrAttribute}) /// class SetParameterOp : public ir::Op { @@ -54,4 +55,34 @@ class SetParameterOp : public ir::Op { static const char* attributes_name[attributes_num]; }; +/// +/// \brief CombineOp: CombineOp(OpOperand, {StrAttribute, +/// StrAttribute}) +/// +class CombineOp : public ir::Op { + public: + using Op::Op; + + static const char* name() { return "builtin.combine"; } + + static constexpr uint32_t attributes_num = 0; + + static const char** attributes_name; +}; + +/// +/// \brief SliceOp: SliceOp(OpOperand, {StrAttribute, +/// StrAttribute}) +/// +class SliceOp : public ir::Op { + public: + using Op::Op; + + static const char* name() { return "builtin.slice"; } + + static constexpr uint32_t attributes_num = 1; + + static const char* attributes_name[attributes_num]; +}; + } // namespace ir diff --git a/paddle/ir/printer.cc b/paddle/ir/printer.cc index b4d3acb930a342..421dedabb3a100 100644 --- a/paddle/ir/printer.cc +++ b/paddle/ir/printer.cc @@ -28,6 +28,21 @@ namespace ir { namespace { constexpr char newline[] = "\n"; + +template +void PrintInterleave(ForwardIterator begin, + ForwardIterator end, + UnaryFunctor print_func, + NullFunctor between_func) { + if (begin == end) return; + print_func(*begin); + begin++; + for (; begin != end; begin++) { + between_func(); + print_func(*begin); + } +} + } // namespace class Printer { @@ -47,6 +62,15 @@ class Printer { os << "i32"; } else if (type.isa()) { os << "i64"; + } else if (type.isa()) { + os << "vec<"; + auto inner_types = type.dyn_cast().data(); + PrintInterleave( + inner_types.begin(), + inner_types.end(), + [this](ir::Type v) { this->PrintType(v); }, + [this]() { this->os << ","; }); + os << ">"; } else { auto& dialect = type.dialect(); dialect.PrintType(type, os); @@ -77,22 +101,6 @@ class ProgramPrinter : public Printer { } } - template - void PrintInterleave(ForwardIterator begin, - ForwardIterator end, - UnaryFunctor print_func, - NullFunctor between_func) { - if (begin == end) return; - print_func(*begin); - begin++; - for (; begin != end; begin++) { - between_func(); - print_func(*begin); - } - } - void PrintValue(ir::Value v) { const void* key = static_cast(v.impl()); auto ret = aliases.find(key); diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc index e88bbb2cf391fd..ac8141b398f19c 100644 --- a/test/cpp/ir/program_translator_test.cc +++ b/test/cpp/ir/program_translator_test.cc @@ -58,6 +58,7 @@ TEST(PaddleDialectTest, Translator) { auto program = paddle::TranslateLegacyProgramToProgram(p); std::list ops = program->ops(); - EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num()); - VLOG(0) << *program << std::endl; + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + 20); + std::cout << *program << std::endl; } From 489b2cc98cfe77445dd2e03c99ed408bff5018e6 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 23 May 2023 07:25:30 +0000 Subject: [PATCH 02/43] polish --- paddle/fluid/translator/translate.h | 2 +- 
paddle/fluid/translator/type_translator.cc | 4 ++++ paddle/fluid/translator/type_translator.h | 10 +++++----- paddle/ir/builtin_op.h | 6 ++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/translator/translate.h b/paddle/fluid/translator/translate.h index 8d6b7e7b1bbdc2..9cf013ea2bca36 100644 --- a/paddle/fluid/translator/translate.h +++ b/paddle/fluid/translator/translate.h @@ -22,7 +22,7 @@ namespace paddle { -std::unique_ptr<::ir::ProgramProgram> TranslateLegacyProgramToProgram( +std::unique_ptr<::ir::Program> TranslateLegacyProgramToProgram( const ::paddle::framework::ProgramDesc& legacy_program); } // namespace paddle diff --git a/paddle/fluid/translator/type_translator.cc b/paddle/fluid/translator/type_translator.cc index 9792a4b8537cef..e0c31913ffefe2 100644 --- a/paddle/fluid/translator/type_translator.cc +++ b/paddle/fluid/translator/type_translator.cc @@ -22,6 +22,10 @@ namespace paddle { namespace translator { +using OpDesc = paddle::framework::OpDesc; +using BlockDesc = paddle::framework::BlockDesc; +using VarDesc = paddle::framework::VarDesc; +using VarType = paddle::framework::proto::VarType; using DenseTensorType = paddle::dialect::DenseTensorType; using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; diff --git a/paddle/fluid/translator/type_translator.h b/paddle/fluid/translator/type_translator.h index b16c1a222a5342..707b8913f3385d 100644 --- a/paddle/fluid/translator/type_translator.h +++ b/paddle/fluid/translator/type_translator.h @@ -27,13 +27,13 @@ namespace paddle { namespace translator { -using OpDesc = paddle::framework::OpDesc; -using BlockDesc = paddle::framework::BlockDesc; -using VarDesc = paddle::framework::VarDesc; -using VarType = paddle::framework::proto::VarType; -using TypeTranslateFn = std::function; +using TypeTranslateFn = + std::function; class TypeTranslator { + public: + using VarType = paddle::framework::proto::VarType; + private: TypeTranslator(); // Disallow instantiation outside of the class. std::unordered_map handlers; diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index ae2c9c768dc5ba..8d4485c988d84c 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -56,8 +56,7 @@ class SetParameterOp : public ir::Op { }; /// -/// \brief CombineOp: CombineOp(OpOperand, {StrAttribute, -/// StrAttribute}) +/// \brief CombineOp: CombineOp(OpOperand) /// class CombineOp : public ir::Op { public: @@ -71,8 +70,7 @@ class CombineOp : public ir::Op { }; /// -/// \brief SliceOp: SliceOp(OpOperand, {StrAttribute, -/// StrAttribute}) +/// \brief SliceOp: SliceOp(OpOperand) /// class SliceOp : public ir::Op { public: From 324d8972dc39771fce65f092ac558585ece7bc0a Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 23 May 2023 12:13:04 +0000 Subject: [PATCH 03/43] support basic attribute type --- .../fluid/translator/attribute_translator.cc | 60 +++++++++++++++++++ .../fluid/translator/attribute_translator.h | 48 +++++++++++++++ paddle/fluid/translator/op_translator.cc | 48 ++++++++++++++- 3 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/translator/attribute_translator.cc create mode 100644 paddle/fluid/translator/attribute_translator.h diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc new file mode 100644 index 00000000000000..c2e6f24f39e478 --- /dev/null +++ b/paddle/fluid/translator/attribute_translator.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/translator/attribute_translator.h" + +#include +#include + +#include "paddle/utils/variant.h" + +namespace paddle { +namespace translator { + +class AttributeVisitor { + public: + ir::IrContext* ctx; + AttributeVisitor() { ctx = ir::IrContext::Instance(); } + ~AttributeVisitor() {} + + public: + ir::Attribute operator()(int i) { return ir::Int64_tAttribute::get(ctx, i); } + + ir::Attribute operator()(float f) { return ir::FloatAttribute::get(ctx, f); } + + ir::Attribute operator()(bool b) { return ir::BoolAttribute::get(ctx, b); } + + ir::Attribute operator()(double d) { + return ir::DoubleAttribute::get(ctx, d); + } + + ir::Attribute operator()(std::string str) { + return ir::StrAttribute::get(ctx, str); + } + + template + ir::Attribute operator()(T attr) { + return ir::Attribute(nullptr); + } +}; + +AttributeTranslator::AttributeTranslator() { visitor = new AttributeVisitor(); } + +ir::Attribute AttributeTranslator::operator[]( + const framework::Attribute& attr) { + return paddle::visit(*visitor, attr); +} + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h new file mode 100644 index 00000000000000..9e7117d5db6412 --- /dev/null +++ b/paddle/fluid/translator/attribute_translator.h @@ -0,0 +1,48 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
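+//
+// Usage sketch (illustrative; compare the call site this commit adds to
+// op_translator.cc):
+//
+//   auto& attribute_translator = AttributeTranslator::instance();
+//   ir::Attribute new_attr = attribute_translator[legacy_attr];
+//   if (!new_attr) { /* unsupported variant alternative: null attribute */ }
+//
+// `legacy_attr` here stands for any paddle::framework::Attribute value.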
+ +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/ir/attribute.h" +#include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/ir_context.h" + +#pragma once + +namespace paddle { +namespace translator { + +class AttributeVisitor; + +class AttributeTranslator { + private: + AttributeTranslator(); + AttributeVisitor* visitor; + + public: + AttributeTranslator(const AttributeTranslator&) = delete; + AttributeTranslator& operator=(const AttributeTranslator&) = delete; + AttributeTranslator(AttributeTranslator&&) = delete; + AttributeTranslator& operator=(AttributeTranslator&&) = delete; + + static auto& instance() { + static AttributeTranslator attribute_translator; + return attribute_translator; + } + + ir::Attribute operator[](const framework::Attribute& attr); +}; + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 65102a4eb0b192..85ad10f73db37b 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -22,11 +22,13 @@ #include #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/translator/attribute_translator.h" #include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_type.h" #include "paddle/ir/ir_context.h" +#include "paddle/ir/operation.h" #include "paddle/ir/value.h" #include "paddle/phi/core/enforce.h" @@ -237,6 +239,40 @@ inline std::tuple GenerateOperationOutput( return {op_output_types, arg_to_idx}; } +inline ir::AttributeMap TranslateOpAttribute(const OpDesc& op_desc) { + auto& attribute_translator = AttributeTranslator::instance(); + ir::AttributeMap attribute_map = {}; + for (auto attr_in_op_desc : op_desc.GetAttrMap()) { + const auto& attr_name = attr_in_op_desc.first; + const auto& attr_value = attr_in_op_desc.second; + ir::Attribute new_attr = attribute_translator[attr_value]; + attribute_map[attr_name] = new_attr; + if (!new_attr) { + VLOG(0) << "empty attribute in " << op_desc.Type() + << " name: " << attr_name; + } else { + VLOG(10) << "new attribute in " << op_desc.Type() + << " name: " << attr_name << " " << new_attr.storage(); + } + } + + for (auto attr_in_op_desc : op_desc.GetRuntimeAttrMap()) { + const auto& attr_name = attr_in_op_desc.first; + const auto& attr_value = attr_in_op_desc.second; + ir::Attribute new_attr = attribute_translator[attr_value]; + attribute_map[attr_name] = new_attr; + if (!new_attr) { + VLOG(0) << "empty runtime attribute in " << op_desc.Type() + << " name: " << attr_name; + } else { + VLOG(10) << "new runtime attribute in " << op_desc.Type() + << " name: " << attr_name << " " << new_attr.storage(); + } + } + + return std::move(attribute_map); +} + inline void RecordOpResultMapping(TranslationContext* param_map, const OpDesc& op_desc, ir::Operation* operation, @@ -271,8 +307,10 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types = {}; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto attribute_map = TranslateOpAttribute(op_desc); + ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); program->InsertOp(operation); RecordOpResultMapping(param_map, op_desc, operation, 
arg_to_idx); @@ -289,8 +327,10 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types = {}; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto attribute_map = TranslateOpAttribute(op_desc); + ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); program->InsertOp(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); @@ -305,8 +345,10 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types = {}; auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto attribute_map = TranslateOpAttribute(op_desc); + ir::Operation* operation = - ir::Operation::create(op_inputs, op_output_types, {}, op_info); + ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); program->InsertOp(operation); return operation; From f9e17d695f70b80525486093f0fd675c5d0898e6 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 23 May 2023 12:42:38 +0000 Subject: [PATCH 04/43] resolve conflicts --- paddle/ir/builtin_dialect.cc | 5 ++++- paddle/ir/builtin_op.cc | 7 +++++++ paddle/ir/builtin_op.h | 8 +++++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/paddle/ir/builtin_dialect.cc b/paddle/ir/builtin_dialect.cc index 9c8cacf7bff94a..32f8750b710941 100644 --- a/paddle/ir/builtin_dialect.cc +++ b/paddle/ir/builtin_dialect.cc @@ -44,7 +44,10 @@ void BuiltinDialect::initialize() { ir::Int64_tAttribute, ir::ArrayAttribute>(); - RegisterOps(); + RegisterOps(); } } // namespace ir diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index 1f4ffaaa8c3510..1f1331c7396f03 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -59,6 +59,13 @@ void SetParameterOp::verify(const std::vector &inputs, } const char **CombineOp::attributes_name = nullptr; +void CombineOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) {} + const char *SliceOp::attributes_name[attributes_num] = {"index"}; +void SliceOp::verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) {} } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index e2e855889cec78..5751632df58652 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -36,7 +36,7 @@ class GetParameterOp : public ir::Op { /// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute, /// StrAttribute}) /// -class SliceOp : public ir::Op { +class SetParameterOp : public ir::Op { public: using Op::Op; static const char *name() { return "builtin.set_parameter"; } @@ -59,6 +59,9 @@ class CombineOp : public ir::Op { static constexpr uint32_t attributes_num = 0; static const char **attributes_name; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); }; /// @@ -73,6 +76,9 @@ class SliceOp : public ir::Op { static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; + static void verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes); }; } // namespace ir From 9cd463f0cfce6c25430d8d884638bf1b2ae4bb14 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 03:49:36 +0000 Subject: [PATCH 05/43] add verify for combine/slice and unittests --- paddle/ir/builtin_op.cc | 98 
++++++++++++++++++++++++++++++++-- paddle/ir/builtin_op.h | 2 +- paddle/ir/type.cc | 7 ++- paddle/ir/type.h | 2 + test/cpp/ir/ir_program_test.cc | 59 ++++++++++++++++++++ 5 files changed, 163 insertions(+), 5 deletions(-) diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc index 1f1331c7396f03..63bfc2196dca35 100644 --- a/paddle/ir/builtin_op.cc +++ b/paddle/ir/builtin_op.cc @@ -13,7 +13,10 @@ // limitations under the License. #include "paddle/ir/builtin_op.h" + #include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/builtin_type.h" +#include "paddle/phi/core/enforce.h" namespace ir { const char *GetParameterOp::attributes_name[attributes_num] = { @@ -58,14 +61,103 @@ void SetParameterOp::verify(const std::vector &inputs, } } -const char **CombineOp::attributes_name = nullptr; void CombineOp::verify(const std::vector &inputs, const std::vector &outputs, - const ir::AttributeMap &attributes) {} + const ir::AttributeMap &attributes) { + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + // outputs[0].type == Vector + PADDLE_ENFORCE(outputs[0].isa(), + phi::errors::PreconditionNotMet( + "The type %s of outputs[0] must be equal to VectorType.", + outputs[0])); + ir::VectorType output_type = outputs[0].dyn_cast(); + // inputs.size() == outputs[0].size() + PADDLE_ENFORCE_EQ( + output_type.size(), + inputs.size(), + phi::errors::PreconditionNotMet( + "The size %d of outputs[0] must be equal to size %d of inputs.", + output_type.size(), + inputs.size())); + + // forall i in inputs.size(): inputs[i].type == outputs[0][i].type + for (size_t i = 0; i < inputs.size(); i++) { + PADDLE_ENFORCE_EQ( + output_type[i], + inputs[i].type(), + phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be " + "equal to type %s of inputs[%d].", + output_type[i], + i, + inputs[i].type(), + i)); + } +} const char *SliceOp::attributes_name[attributes_num] = {"index"}; void SliceOp::verify(const std::vector &inputs, const std::vector &outputs, - const ir::AttributeMap &attributes) {} + const ir::AttributeMap &attributes) { + // inputs.size() == 1 + PADDLE_ENFORCE_EQ( + inputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", inputs.size())); + + // inputs[0].type == Vector + PADDLE_ENFORCE(inputs[0].type().isa(), + phi::errors::PreconditionNotMet( + "The type %s of inputs[0] must be equal to VectorType.", + inputs[0].type())); + ir::VectorType input_type = inputs[0].type().dyn_cast(); + + // outputs.size() == 1 + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", outputs.size())); + + // attributes contains index: Int32 + PADDLE_ENFORCE_NE( + attributes.count("index"), + 0, + phi::errors::PreconditionNotMet("The attributes must contains index.")); + const ir::Attribute &attr = attributes.at("index"); + PADDLE_ENFORCE( + attr.isa(), + phi::errors::PreconditionNotMet("The attribute index must be INT32.")); + auto index = attr.dyn_cast().data(); + + // index >= 0 and < inputs[0].size() + PADDLE_ENFORCE_GE( + index, + 0, + phi::errors::PreconditionNotMet( + "The index %d must be greater or equal than 0.", index)); + PADDLE_ENFORCE_LT( + index, + input_type.size(), + phi::errors::PreconditionNotMet( + "The index %d must be less or equal than size %d of inputs[0].", + index, + input_type.size())); + + // inputs[index].type == outputs[0].type + PADDLE_ENFORCE_EQ( 
+ input_type[index], + outputs[0], + phi::errors::PreconditionNotMet( + "The type %s of inputs[%d] must be equal to type %s of outputs[0].", + input_type[index], + index, + outputs[0])); +} } // namespace ir diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h index 5751632df58652..d1d2c20b5725db 100644 --- a/paddle/ir/builtin_op.h +++ b/paddle/ir/builtin_op.h @@ -58,7 +58,7 @@ class CombineOp : public ir::Op { static constexpr uint32_t attributes_num = 0; - static const char **attributes_name; + static constexpr const char **attributes_name = nullptr; static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); diff --git a/paddle/ir/type.cc b/paddle/ir/type.cc index bde3194f8fd414..e9c24672e5b5e4 100644 --- a/paddle/ir/type.cc +++ b/paddle/ir/type.cc @@ -16,6 +16,11 @@ #include "paddle/ir/dialect.h" namespace ir { -IrContext *Type::ir_context() const { return dialect().ir_context(); } +IrContext* Type::ir_context() const { return dialect().ir_context(); } + +std::ostream& operator<<(std::ostream& os, Type type) { + type.print(os); + return os; +} } // namespace ir diff --git a/paddle/ir/type.h b/paddle/ir/type.h index fce17db82ebf5a..89d153c089476e 100644 --- a/paddle/ir/type.h +++ b/paddle/ir/type.h @@ -89,6 +89,8 @@ class Type { const Storage *storage_{nullptr}; }; +std::ostream &operator<<(std::ostream &os, Type type); + } // namespace ir namespace std { diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index c430d9b320b406..711a40407de686 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -211,3 +211,62 @@ TEST(program_test, program) { EXPECT_EQ(ops.size() == 4, true); EXPECT_EQ(program.parameters_num() == 3, true); } + +TEST(program_test, slice_combine_test) { + // (1) Init environment. 
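+  // Sketch of the program built below, in (approximate) printer notation:
+  //   %0 = "builtin.get_parameter"() {parameter_name: "a"} : () -> f32
+  //   %1 = "builtin.get_parameter"() {parameter_name: "b"} : () -> f32
+  //   %2 = "builtin.combine"(%0, %1) : (f32, f32) -> vec<f32,f32>
+  //   %3 = "builtin.slice"(%2) {index: 0} : (vec<f32,f32>) -> f32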
+ ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + // (2) Create an empty program object + ir::Program program; + // ir::Program *program = new ir::Program(); + EXPECT_EQ(program.ops().size() == 0, true); + + // (3) Create a float32 DenseTensor Parameter and save into Program + ir::Type fp32_dtype = ir::Float32Type::get(ctx); + + // (4) Def a = GetParameterOp("a") + std::string op1_name = ir::GetParameterOp::name(); + ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); + std::unordered_map op1_attribute{ + {"parameter_name", ir::StrAttribute::get(ctx, "a")}}; + ir::Operation *op1 = + ir::Operation::create({}, {fp32_dtype}, op1_attribute, op1_info); + program.InsertOp(op1); + + // (5) Def b = GetParameterOp("b") + std::string op2_name = std::string(ir::GetParameterOp::name()); + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); + std::unordered_map op2_attribute{ + {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; + ir::Operation *op2 = + ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info); + program.InsertOp(op2); + + std::string combine_op_name = std::string(ir::CombineOp::name()); + + ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name); + + ir::Type output_type = + ir::VectorType::get(ctx, std::vector({fp32_dtype, fp32_dtype})); + ir::Operation *combine_op = ir::Operation::create( + {op1->GetResultByIndex(0), op2->GetResultByIndex(0)}, + {output_type}, + {}, + combine_op_info); + program.InsertOp(combine_op); + + std::string slice_op_name = std::string(ir::SliceOp::name()); + ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); + ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); + ir::Operation *slice_op = + ir::Operation::create({combine_op->GetResultByIndex(0)}, + {fp32_dtype}, + {{"index", index_attr}}, + slice_op_info); + program.InsertOp(slice_op); + + // (8) Traverse Program + std::list ops = program.ops(); + EXPECT_EQ(ops.size() == 4, true); +} From 6d17079c47c48c4afcc140f85644eb9b190ffa3e Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 03:51:56 +0000 Subject: [PATCH 06/43] polish --- test/cpp/ir/ir_program_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index 711a40407de686..7c6c9acaf52c66 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -243,10 +243,9 @@ TEST(program_test, slice_combine_test) { ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info); program.InsertOp(op2); + // (6) Def combine_op = CombineOp("a", "b") std::string combine_op_name = std::string(ir::CombineOp::name()); - ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name); - ir::Type output_type = ir::VectorType::get(ctx, std::vector({fp32_dtype, fp32_dtype})); ir::Operation *combine_op = ir::Operation::create( @@ -256,6 +255,7 @@ TEST(program_test, slice_combine_test) { combine_op_info); program.InsertOp(combine_op); + // (7) Def slice_op = SliceOp(combine_op, 0) std::string slice_op_name = std::string(ir::SliceOp::name()); ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); From b45df9c30cdf17cca9617ee8c35f8acc02a8548a Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 08:59:31 +0000 Subject: [PATCH 07/43] support more type in attribute translator --- paddle/fluid/dialect/pd_attribute.cc | 4 +- paddle/fluid/dialect/pd_attribute.h | 2 +- 
paddle/fluid/dialect/pd_attribute_storage.h | 4 +- .../fluid/translator/attribute_translator.cc | 70 +++++++++++++++++++ paddle/fluid/translator/op_translator.cc | 2 + paddle/fluid/translator/program_translator.cc | 2 +- test/cpp/ir/program_translator_test.cc | 20 +++--- 7 files changed, 89 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/dialect/pd_attribute.cc b/paddle/fluid/dialect/pd_attribute.cc index 49e6865160dd74..7d0cd5c44d0f09 100644 --- a/paddle/fluid/dialect/pd_attribute.cc +++ b/paddle/fluid/dialect/pd_attribute.cc @@ -18,7 +18,9 @@ namespace paddle { namespace dialect { phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); } -phi::Scalar ScalarAttribute::data() const { return storage()->GetAsKey(); } +paddle::experimental::Scalar ScalarAttribute::data() const { + return storage()->GetAsKey(); +} phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); } diff --git a/paddle/fluid/dialect/pd_attribute.h b/paddle/fluid/dialect/pd_attribute.h index 75eed82dfc4b35..a4a48da218849f 100644 --- a/paddle/fluid/dialect/pd_attribute.h +++ b/paddle/fluid/dialect/pd_attribute.h @@ -47,7 +47,7 @@ class ScalarAttribute : public ir::Attribute { return storage() < right.storage(); } - phi::Scalar data() const; + paddle::experimental::Scalar data() const; }; class DataTypeAttribute : public ir::Attribute { diff --git a/paddle/fluid/dialect/pd_attribute_storage.h b/paddle/fluid/dialect/pd_attribute_storage.h index 352dcc8b0e4343..2e066413af7d69 100644 --- a/paddle/fluid/dialect/pd_attribute_storage.h +++ b/paddle/fluid/dialect/pd_attribute_storage.h @@ -55,7 +55,7 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage { }; struct ScalarAttributeStorage : public ir::AttributeStorage { - using ParamKey = phi::Scalar; + using ParamKey = paddle::experimental::Scalar; explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; } @@ -73,7 +73,7 @@ struct ScalarAttributeStorage : public ir::AttributeStorage { ParamKey GetAsKey() const { return ParamKey(data_); } private: - phi::Scalar data_; + paddle::experimental::Scalar data_; }; struct DataTypeAttributeStorage : public ir::AttributeStorage { diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc index c2e6f24f39e478..ababb7e3b50609 100644 --- a/paddle/fluid/translator/attribute_translator.cc +++ b/paddle/fluid/translator/attribute_translator.cc @@ -17,6 +17,8 @@ #include #include +#include "paddle/fluid/dialect/pd_attribute.h" +#include "paddle/phi/common/scalar.h" #include "paddle/utils/variant.h" namespace paddle { @@ -43,6 +45,74 @@ class AttributeVisitor { return ir::StrAttribute::get(ctx, str); } + ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { + return paddle::dialect::ScalarAttribute::get(ctx, scalar); + } + + ir::Attribute operator()(const std::vector& strs) { + std::vector attrs; + attrs.reserve(strs.size()); + for (const auto& v : strs) { + attrs.push_back(ir::StrAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& fs) { + std::vector attrs; + attrs.reserve(fs.size()); + for (const auto& v : fs) { + attrs.push_back(ir::FloatAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& is) { + std::vector attrs; + attrs.reserve(is.size()); + for (const auto& v : is) { + attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); + } + return 
ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& bs) { + std::vector attrs; + attrs.reserve(bs.size()); + for (const auto& v : bs) { + attrs.push_back(ir::BoolAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& i64s) { + std::vector attrs; + attrs.reserve(i64s.size()); + for (const auto& v : i64s) { + attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()(const std::vector& ds) { + std::vector attrs; + attrs.reserve(ds.size()); + for (const auto& v : ds) { + attrs.push_back(ir::DoubleAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + ir::Attribute operator()( + const std::vector& ss) { + std::vector attrs; + attrs.reserve(ss.size()); + for (const auto& v : ss) { + attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + template ir::Attribute operator()(T attr) { return ir::Attribute(nullptr); diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 85ad10f73db37b..adf89cddc0bfdc 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -245,6 +245,8 @@ inline ir::AttributeMap TranslateOpAttribute(const OpDesc& op_desc) { for (auto attr_in_op_desc : op_desc.GetAttrMap()) { const auto& attr_name = attr_in_op_desc.first; const auto& attr_value = attr_in_op_desc.second; + VLOG(0) << "attribute in " << op_desc.Type() << " name: " << attr_name + << " " << attr_value.index(); ir::Attribute new_attr = attribute_translator[attr_value]; attribute_map[attr_name] = new_attr; if (!new_attr) { diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 7618972f108040..c98fc801a81708 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -75,7 +75,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( std::string get_parameter_op_name(ir::GetParameterOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); std::unordered_map op_attribute_map = { - {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, + {"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, }; ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc index 5ae7b0b9b31e7e..d7a35c17cbe878 100644 --- a/test/cpp/ir/program_translator_test.cc +++ b/test/cpp/ir/program_translator_test.cc @@ -48,18 +48,18 @@ ProgramDesc load_from_file(const std::string &file_name) { TEST(PaddleDialectTest, Translator) { LOG(WARNING) << "TODO"; - // auto p = load_from_file("restnet50_main.prog"); - // std::cout << p.Size() << std::endl; + auto p = load_from_file("restnet50_main.prog"); + std::cout << p.Size() << std::endl; - // EXPECT_EQ(p.Size(), 1u); + EXPECT_EQ(p.Size(), 1u); - // ir::IrContext *ctx = ir::IrContext::Instance(); - // ctx->GetOrRegisterDialect(); - // ctx->GetOrRegisterDialect(); - // auto program = paddle::TranslateLegacyProgramToProgram(p); + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); - // std::list ops = program->ops(); + std::list 
ops = program->ops(); // ops.size() = op size in BlockDesc + get_parameter_op + combine op - // EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + - // 20); std::cout << *program << std::endl; + EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + 20); + std::cout << *program << std::endl; } From 5940ac4afd381062f37536d70dbaf55e37eb77dd Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 09:03:18 +0000 Subject: [PATCH 08/43] modify by reviews --- paddle/fluid/translator/op_translator.cc | 4 +- paddle/fluid/translator/program_translator.h | 4 +- paddle/ir/builtin_dialect.cc | 43 +++++++++----------- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index adf89cddc0bfdc..ec7b53cdaf3976 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -122,7 +122,7 @@ inline ir::Operation* InsertCombineOperationForTarget( std::vector src_values; std::vector types_in_vec; - for (auto arg_name : args) { + for (const auto& arg_name : args) { auto defining_info = param_map->at(arg_name); src_values.push_back(defining_info.value); types_in_vec.push_back(defining_info.value.type()); @@ -139,7 +139,7 @@ inline std::vector GenerateOperationInput( TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - std::vector op_inputs = {}; + std::vector op_inputs; // scan all inputs to see if any of them is generated as a vector // so need an additional `SliceOp` to take it out. diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index f7fd4e2890ea64..e185d9d0c33881 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -39,9 +39,9 @@ struct VariableDefiningInfo { ir::OpResult value; bool generated_by_vector = - false; // true if target variabe is generated by Vector + false; // true if target variable is generated by Vector int idx_in_vector = - -1; // positive if target variabe is generated by Vector + -1; // positive if target variable is generated by Vector }; using TranslationContext = diff --git a/paddle/ir/builtin_dialect.cc b/paddle/ir/builtin_dialect.cc index 32f8750b710941..955a6ebdce8c8f 100644 --- a/paddle/ir/builtin_dialect.cc +++ b/paddle/ir/builtin_dialect.cc @@ -18,36 +18,33 @@ #include "paddle/ir/builtin_type.h" namespace ir { -BuiltinDialect::BuiltinDialect(ir::IrContext *context) - : ir::Dialect(name(), context, ir::TypeId::get()) { +BuiltinDialect::BuiltinDialect(IrContext *context) + : Dialect(name(), context, TypeId::get()) { initialize(); } void BuiltinDialect::initialize() { // Register all built-in types defined in builtin_type.h. 
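// (RegisterTypes, RegisterAttributes and RegisterOps below are variadic
// templates: every type listed in the template argument list is registered
// with this dialect in a single call.)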
- RegisterTypes(); + RegisterTypes(); - RegisterAttributes(); + RegisterAttributes(); - RegisterOps(); + RegisterOps(); } } // namespace ir From b328cbb00ada67e3e71f50a766dd9c25565dd6e4 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 24 May 2023 10:16:33 +0000 Subject: [PATCH 09/43] fix merge mistakes --- paddle/fluid/translator/program_translator.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index 851449580de6a6..e185d9d0c33881 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -71,14 +71,6 @@ class ProgramTranslator { /// `ExtractParameterFromSingleBlock` static const std::unordered_set no_cast_var_names; - /// In the legacy program desc, there are two special named varibales: - /// 1. "feed", the input variable of feed op - /// 2. "fetch", the output variable of fetch op - /// However, new feed has no input and new fetch has no output - /// So we don't handle these two vairables when - /// `ExtractParameterFromSingleBlock` - static const std::unordered_set no_cast_var_names; - void ExtractParameterFromSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block); }; From e35a288002b45067ca7db91c3057ae9a7e5195f2 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 26 May 2023 05:59:45 +0000 Subject: [PATCH 10/43] refine code --- CMakeLists.txt | 2 +- paddle/fluid/dialect/CMakeLists.txt | 4 +- paddle/fluid/dialect/legacy_pd_op.h | 13 +- paddle/fluid/dialect/op_gen.py | 204 ++++++++++++++++++++++++++-- paddle/fluid/dialect/pd_dialect.cc | 6 +- paddle/fluid/dialect/utils.h | 40 ++++++ 6 files changed, 243 insertions(+), 26 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0bf93f0ed9d26f..bb6432e4175f26 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,7 +307,7 @@ option(WITH_CUDNN_FRONTEND "Compile with CUDNN Frontend API support (experimental)" OFF) option(WITH_CUDNN_DSO "Compile PaddlePaddle with cuDNN dynamic-link libraries" OFF) -option(WITH_NEWIR "Compile PaddlePaddle with NEWIR" OFF) +option(WITH_NEWIR "Compile PaddlePaddle with NEWIR" ON) if(WITH_RECORD_BUILDTIME) set_property( diff --git a/paddle/fluid/dialect/CMakeLists.txt b/paddle/fluid/dialect/CMakeLists.txt index 8130b75f637cf9..3f17ac21724c66 100644 --- a/paddle/fluid/dialect/CMakeLists.txt +++ b/paddle/fluid/dialect/CMakeLists.txt @@ -8,13 +8,13 @@ set(op_forward_yaml_file1 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml ) set(op_forward_yaml_file2 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/static_ops.parsed.yaml + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml ) set(op_backward_yaml_file1 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml ) set(op_backward_yaml_file2 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/static_backward.parsed.yaml + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ) set(op_yaml_files ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2} diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h index 21be24720dc3fa..60a0a4b0871826 100644 --- a/paddle/fluid/dialect/legacy_pd_op.h +++ b/paddle/fluid/dialect/legacy_pd_op.h @@ -36,9 +36,10 @@ namespace dialect { // TODO(zhangbo): As operators are 
supplemented and defined, they are gradually // removed. -REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d -REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed -REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: batch_norm +REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d +REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed +// REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: +// batch_norm REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_ REIGSTER_EMPTY_OP(elementwise_add, ElementwiseAddOp); // To be customized: add (elementwise_add) @@ -74,10 +75,10 @@ REIGSTER_EMPTY_OP( FlattenContiguousRangeGradOp); // flatten_grad // (flatten_contiguous_range_grad) REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); // To be customized: pool2d_grad -REIGSTER_EMPTY_OP(batch_norm_grad, - BatchNormGradOp); // To be customized: batch_norm_grad +// REIGSTER_EMPTY_OP(batch_norm_grad, +// BatchNormGradOp); // To be customized: batch_norm_grad REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad -REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum) +// REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum) REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 } // namespace dialect diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 94930d0f133104..6abef860838c32 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -29,7 +29,10 @@ {op_declare} #else +#include + #include "paddle/ir/op_base.h" +#include "paddle/fluid/dialect/utils.h" {input} #endif @@ -45,6 +48,9 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ static const char *name() {{ return "{dialect_op_name}"; }} {attribute_declare} static constexpr uint32_t attributes_num = {attribute_num}; + static std::vector inputs_info(); + static std::vector outputs_info(); + static std::vector attributes_info(); static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); {get_inputs_and_outputs} }}; @@ -79,6 +85,37 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }}; """ +# get op input info +OP_INPUT_INFO_TEMPLATE = """ +std::vector {op_name}::inputs_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_INPUT_INFO_TEMPLATE = ( + """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})""" +) + +# get op output info +OP_OUTPUT_INFO_TEMPLATE = """ +std::vector {op_name}::outputs_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_OUTPUT_INFO_TEMPLATE = ( + """OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})""" +) + +# get op attribute info +OP_ATTRIBUTE_INFO_TEMPLATE = """ +std::vector {op_name}::attributes_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = ( + """OpAttributeInfo("{name}", "{typename}", "{data_type}")""" +) + +# verify OP_VERIFY_TEMPLATE = """ void {op_name}::verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}."; @@ -170,32 +207,65 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ """ +def to_phi_and_fluid_op_name(op_item): + # Templat: - op : phi_name (fluid_name) + names = op_item.split('(') + if len(names) == 1: + phi_fluid_name = names[0].strip() + return 
phi_fluid_name, phi_fluid_name + else: + phi_name = names[0].strip() + fluid_name = names[1].split(')')[0].strip() + return phi_name, fluid_name + + +# ===================================== +# Parse Op Compat From Yaml +# ===================================== +class OpCompatParser: + def __init__(self, ops_compat_yaml_file): + self.ops_compat_yaml_file = ops_compat_yaml_file + with open(self.ops_compat_yaml_file, "r") as f: + self.ops_compat = yaml.safe_load(f) + + def get_compat(self, op_name): + for compat in self.ops_compat: + phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op']) + if op_name == phi_name: + return compat + return None + + # ===================================== -# Parse Op information from Yaml item +# Parse Op Information From Yaml # ===================================== class OpInfoParser: - def __init__(self, op_yaml_item): + def __init__(self, op_yaml_item, op_compat_item): self.op_yaml_item = op_yaml_item + self.op_compat_item = op_compat_item self.op_phi_name = self.parse_op_phi_name() - + # parse inputs self.input_name_list = self.parse_input_name_list() self.input_type_list = self.parse_input_type_list() self.input_optional_list = self.parse_input_optional_list() + self.input_no_need_buffer_list = self.parse_input_no_need_buffer_list() self.cross_check( self.input_name_list, self.input_type_list, self.input_optional_list ) - + # parse outputs self.output_name_list = self.parse_output_name_list() self.output_type_list = self.parse_output_type_list() self.output_optional_list = self.parse_output_optional_list() + self.output_intermediate_list = self.parse_output_intermediate_list() self.cross_check( self.output_name_list, self.output_type_list, self.output_optional_list, ) - + # parse attributes self.attribute_name_list = self.parse_attribute_name_list() self.attribute_type_list = self.parse_attribute_type_list() + self.attribute_data_type_list = self.parse_attribute_data_type_list() self.cross_check(self.attribute_name_list, self.attribute_type_list) def cross_check(self, name_list, type_list, optional_list=None): @@ -229,9 +299,21 @@ def parse_input_type_list(self): def parse_input_optional_list(self): optional_list = [] for input_info in self.op_yaml_item['inputs']: - optional_list.append(input_info['optional']) + if input_info['optional']: + optional_list.append("true") + else: + optional_list.append("false") return optional_list + def parse_input_no_need_buffer_list(self): + no_need_buffer_list = [] + for input_info in self.op_yaml_item['inputs']: + if input_info['no_need_buffer']: + no_need_buffer_list.append("true") + else: + no_need_buffer_list.append("false") + return no_need_buffer_list + def parse_output_name_list(self): name_list = [] for output_info in self.op_yaml_item['outputs']: @@ -255,11 +337,26 @@ def parse_output_optional_list(self): optional_list = [] for output_info in self.op_yaml_item['outputs']: if 'optional' in output_info: - optional_list.append(output_info['optional']) + if output_info['optional']: + optional_list.append("true") + else: + optional_list.append("false") else: - optional_list.append(False) + optional_list.append("false") return optional_list + def parse_output_intermediate_list(self): + intermediate_list = [] + for output_info in self.op_yaml_item['outputs']: + if 'intermediate' in output_info: + if output_info['intermediate']: + intermediate_list.append("true") + else: + intermediate_list.append("false") + else: + intermediate_list.append("false") + return intermediate_list + def parse_attribute_name_list(self): 
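        # Note: the names gathered here pair index-wise with
        # parse_attribute_type_list() and parse_attribute_data_type_list();
        # cross_check() in __init__ relies on that alignment.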
name_list = [] for attribute_info in self.op_yaml_item['attrs']: @@ -301,6 +398,15 @@ def parse_attribute_type_list(self): type_list.append(attr_types_map[attribute_info['typename']]) return type_list + def parse_attribute_data_type_list(self): + data_type_list = [] + for attribute_info in self.op_yaml_item['attrs']: + if 'data_type' in attribute_info: + data_type_list.append(attribute_info['data_type']) + else: + data_type_list.append("") + return data_type_list + def parse_op_phi_name(self): return self.op_yaml_item['name'] @@ -314,10 +420,11 @@ def to_pascal_case(s): # ===================================== -# Generate op definition files +# Generate Op Definition Files # ===================================== def OpGenerator( op_yaml_files, + op_compat_yaml_file, namespaces, dialect_name, op_def_h_file, @@ -330,6 +437,8 @@ def OpGenerator( os.remove(op_def_cc_file) # (2) Prepare: Get all op item in all op_yaml_files + op_compat_parser = OpCompatParser(op_compat_yaml_file) + op_yaml_items = [] for yaml_file in op_yaml_files: with open(yaml_file, "r") as f: @@ -337,7 +446,9 @@ def OpGenerator( op_yaml_items = op_yaml_items + ops op_info_items = [] for op in op_yaml_items: - op_info_items.append(OpInfoParser(op)) + op_info_items.append( + OpInfoParser(op, op_compat_parser.get_compat(op['name'])) + ) # (3) CodeGen: Traverse op_info_items and generate ops_name_list = [] # all op class name store in this list @@ -351,11 +462,14 @@ def OpGenerator( op_input_name_list = op_info.input_name_list op_input_type_list = op_info.input_type_list op_input_optional_list = op_info.input_optional_list + op_input_no_need_buffer_list = op_info.input_no_need_buffer_list op_output_name_list = op_info.output_name_list op_output_type_list = op_info.output_type_list op_output_optional_list = op_info.output_optional_list + op_output_intermediate_list = op_info.output_intermediate_list op_attribute_name_list = op_info.attribute_name_list op_attribute_type_list = op_info.attribute_type_list + op_attribute_data_type_list = op_info.attribute_data_type_list op_interfaces = [] op_traits = [] @@ -370,11 +484,16 @@ def OpGenerator( op_get_inputs_outputs_str = "" for idx in range(len(op_input_name_list)): op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( - input_name=op_input_name_list[idx], input_index=idx + input_name="input_" + str(idx) + "_" + op_input_name_list[idx], + input_index=idx, ) for idx in range(len(op_output_name_list)): op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( - output_name=op_output_name_list[idx], output_index=idx + output_name="output_" + + str(idx) + + "_" + + op_output_name_list[idx], + output_index=idx, ) # gen op_declare_str/op_defined_str @@ -410,6 +529,59 @@ def OpGenerator( attribute_names=attribute_names_str, ) + # generate get op info funciton: inputs + inputs_info_str = "" + if len(op_input_name_list) > 0: + input_info_list = [] + for idx in range(len(op_input_name_list)): + input_info_list.append( + CONSTRUCT_INPUT_INFO_TEMPLATE.format( + name=op_input_name_list[idx], + typename=op_input_type_list[idx], + optional=op_input_optional_list[idx], + no_need_buffer=op_input_no_need_buffer_list[idx], + ) + ) + inputs_info_str = ", ".join(input_info_list) + op_inputs_info_func_str = OP_INPUT_INFO_TEMPLATE.format( + op_name=op_class_name, impl=inputs_info_str + ) + + # generate get op info funciton: outputs + outputs_info_str = "" + if len(op_output_name_list) > 0: + output_info_list = [] + for idx in range(len(op_output_name_list)): + output_info_list.append( + 
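To make the string assembly above concrete, here is the input-info template rendered offline for one hypothetical argument (the typename value is illustrative; the real one comes from the parsed YAML):

    CONSTRUCT_INPUT_INFO_TEMPLATE = (
        """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})"""
    )
    print(CONSTRUCT_INPUT_INFO_TEMPLATE.format(
        name="x", typename="Tensor", optional="false", no_need_buffer="false"))
    # OpInputInfo("x", "Tensor", false, false)
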
CONSTRUCT_OUTPUT_INFO_TEMPLATE.format( + name=op_output_name_list[idx], + typename=op_output_type_list[idx], + optional=op_output_optional_list[idx], + intermediate=op_output_intermediate_list[idx], + ) + ) + outputs_info_str = ", ".join(output_info_list) + op_outputs_info_func_str = OP_OUTPUT_INFO_TEMPLATE.format( + op_name=op_class_name, impl=outputs_info_str + ) + + # generate get op info funciton: attributes + attribute_info_str = "" + if len(op_attribute_name_list) > 0: + attribute_info_list = [] + for idx in range(len(op_attribute_name_list)): + attribute_info_list.append( + CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( + name=op_attribute_name_list[idx], + typename=op_attribute_type_list[idx], + data_type=op_attribute_data_type_list[idx], + ) + ) + attribute_info_str = ", ".join(attribute_info_list) + op_attributes_info_func_str = OP_ATTRIBUTE_INFO_TEMPLATE.format( + op_name=op_class_name, impl=attribute_info_str + ) + # generate op verify function: inputs_type_check_str if len(op_input_type_list) == 0: inputs_type_check_str = ( @@ -425,7 +597,7 @@ def OpGenerator( is_vector = True input_type = input_type[15:-1] check_str = "" - if is_optional: + if is_optional == "true": if is_vector: check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( index=idx, standard=input_type @@ -460,7 +632,7 @@ def OpGenerator( is_vector = True output_type = output_type[15:-1] check_str = "" - if is_optional: + if is_optional == "true": if is_vector: check_str = ( OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( @@ -515,6 +687,9 @@ def OpGenerator( ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) ops_defined_list.append(op_defined_str) + ops_defined_list.append(op_inputs_info_func_str) + ops_defined_list.append(op_outputs_info_func_str) + ops_defined_list.append(op_attributes_info_func_str) ops_defined_list.append(op_verify_str) # (4) Generate head file str @@ -588,6 +763,7 @@ def ParseArguments(): # auto code generate OpGenerator( op_yaml_files, + op_compat_yaml_file, namespaces, dialect_name, op_def_h_file, diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index 9baeb4c1f9b1d0..bbb449464acda7 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -113,7 +113,7 @@ void PaddleDialect::initialize() { RegisterInterfaces(); RegisterOps(); } diff --git a/paddle/fluid/dialect/utils.h b/paddle/fluid/dialect/utils.h index 69af6fd1ddce80..85df76ed4e8f48 100644 --- a/paddle/fluid/dialect/utils.h +++ b/paddle/fluid/dialect/utils.h @@ -132,5 +132,45 @@ inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout( } } +struct OpInputInfo { + std::string name_ = ""; + std::string typename_ = ""; + bool optional_ = false; + bool no_need_buffer_ = false; + OpInputInfo(std::string name, + std::string type_name, + bool optional, + bool no_need_buffer) + : name_(name), + typename_(type_name), + optional_(optional), + no_need_buffer_(no_need_buffer) {} +}; + +struct OpOutputInfo { + std::string name_ = ""; + std::string typename_ = ""; + bool optional_ = false; + bool intermediate_ = false; + OpOutputInfo(std::string name, + std::string type_name, + bool optional, + bool intermediate) + : name_(name), + typename_(type_name), + optional_(optional), + intermediate_(intermediate) {} +}; + +struct OpAttributeInfo { + std::string name_ = ""; + std::string typename_ = ""; + std::string data_type_ = ""; + OpAttributeInfo(std::string name, + std::string type_name, + std::string data_type) + : name_(name), typename_(type_name), 
data_type_(data_type) {} +}; + } // namespace dialect } // namespace paddle From 1ba19dcdbff0cb460b4262541a771854b5807a32 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 26 May 2023 07:49:53 +0000 Subject: [PATCH 11/43] refine code --- paddle/fluid/dialect/legacy_pd_op.h | 45 +-- paddle/fluid/dialect/op_gen.py | 417 +++++++++++++++------------- paddle/fluid/dialect/pd_dialect.cc | 24 +- paddle/phi/api/yaml/legacy_ops.yaml | 2 +- 4 files changed, 222 insertions(+), 266 deletions(-) diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h index 60a0a4b0871826..d835bf929e8d1a 100644 --- a/paddle/fluid/dialect/legacy_pd_op.h +++ b/paddle/fluid/dialect/legacy_pd_op.h @@ -36,50 +36,9 @@ namespace dialect { // TODO(zhangbo): As operators are supplemented and defined, they are gradually // removed. -REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d -REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed -// REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: -// batch_norm +REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_ -REIGSTER_EMPTY_OP(elementwise_add, - ElementwiseAddOp); // To be customized: add (elementwise_add) -REIGSTER_EMPTY_OP(pool2d, Pool2DOp); // To be customized: pool2d -REIGSTER_EMPTY_OP( - flatten_contiguous_range, - FlattenContiguousRangeOp); // flatten (flatten_contiguous_range) -REIGSTER_EMPTY_OP(matmul_v2, - MatmulV2Op); // To be customized: matmul (matmul_v2) -REIGSTER_EMPTY_OP(reshape2, Reshape2Op); // To be customized: reshape -REIGSTER_EMPTY_OP(softmax_with_cross_entropy, - SoftmaxWithCrossEntropyOp); // cross_entropy_with_softmax - // (softmax_with_cross_entropy) -REIGSTER_EMPTY_OP(reduce_mean, - ReduceMeanOp); // To be customized: mean (reduce_mean) -REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op); // topk (top_k_v2) -REIGSTER_EMPTY_OP(fill_constant, - FillConstantOp); // To be customized: full (fill_constant) -REIGSTER_EMPTY_OP(reduce_mean_grad, - ReduceMeanGradOp); // To be customized: reduce_mean_grad -REIGSTER_EMPTY_OP( - softmax_with_cross_entropy_grad, - SoftmaxWithCrossEntropyGradOp); // cross_entropy_with_softmax_grad - // (softmax_with_cross_entropy_grad) -REIGSTER_EMPTY_OP( - elementwise_add_grad, - ElementwiseAddGradOp); // To be customized: add_grad (elementwise_add_grad) -REIGSTER_EMPTY_OP( - matmul_v2_grad, - MatmulV2GradOp); // To be customized: matmul_grad (matmul_v2_grad) -REIGSTER_EMPTY_OP( - flatten_contiguous_range_grad, - FlattenContiguousRangeGradOp); // flatten_grad - // (flatten_contiguous_range_grad) -REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); // To be customized: pool2d_grad -// REIGSTER_EMPTY_OP(batch_norm_grad, -// BatchNormGradOp); // To be customized: batch_norm_grad -REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad -// REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum) -REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 +REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 6abef860838c32..e0aec60bf81540 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -408,7 +408,21 @@ def parse_attribute_data_type_list(self): return data_type_list def parse_op_phi_name(self): - return self.op_yaml_item['name'] + if self.parse_op_inplace_info() is None: + return 
[self.op_yaml_item['name']] + else: + if self.op_yaml_item['name'][-1] == "_": + return [self.op_yaml_item['name']] + else: + return [ + self.op_yaml_item['name'], + self.op_yaml_item['name'] + "_", + ] + + def parse_op_inplace_info(self): + if 'inplace' in self.op_yaml_item: + return self.op_yaml_item['inplace'] + return None def to_pascal_case(s): @@ -456,9 +470,6 @@ def OpGenerator( ops_defined_list = [] # all op class defined store in this list for op_info in op_info_items: # get op info - op_name = op_info.op_phi_name - op_class_name = to_pascal_case(op_name) + "Op" - op_dialect_name = dialect_name + "." + op_name op_input_name_list = op_info.input_name_list op_input_type_list = op_info.input_type_list op_input_optional_list = op_info.input_optional_list @@ -473,224 +484,232 @@ def OpGenerator( op_interfaces = [] op_traits = [] - # gen interface/trait str - op_interfaces_str = "" - if len(op_interfaces) > 0: - op_interfaces_str = "," + ",".join(op_interfaces) - op_traits_str = "" - if len(op_interfaces) > 0: - op_traits_str = "," + ",".join(op_traits) - - op_get_inputs_outputs_str = "" - for idx in range(len(op_input_name_list)): - op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( - input_name="input_" + str(idx) + "_" + op_input_name_list[idx], - input_index=idx, - ) - for idx in range(len(op_output_name_list)): - op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( - output_name="output_" - + str(idx) - + "_" - + op_output_name_list[idx], - output_index=idx, - ) - - # gen op_declare_str/op_defined_str - if len(op_attribute_name_list) == 0: - op_declare_str = OP_DECLARE_TEMPLATE.format( - op_name=op_class_name, - dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, - traits=op_traits_str, - attribute_declare=op_0_attribute_declare_str, - attribute_num=0, - get_inputs_and_outputs=op_get_inputs_outputs_str, - ) - op_defined_str = "" - else: - op_declare_str = OP_DECLARE_TEMPLATE.format( - op_name=op_class_name, - dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, - traits=op_traits_str, - attribute_declare=op_n_attribute_declare_str.format( - attribute_num=len(op_attribute_name_list) - ), - attribute_num=len(op_attribute_name_list), - get_inputs_and_outputs=op_get_inputs_outputs_str, - ) - attribute_names_str = ( - '"' + '", "'.join(op_attribute_name_list) + '"' - ) - op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format( - op_name=op_class_name, - attribute_num=len(op_attribute_name_list), - attribute_names=attribute_names_str, - ) - - # generate get op info funciton: inputs - inputs_info_str = "" - if len(op_input_name_list) > 0: - input_info_list = [] + # If op has inplace info, we will generate inplace op and non-inplace op. + print("op_info.op_phi_name: ", op_info.op_phi_name) + for op_name in op_info.op_phi_name: + op_class_name = to_pascal_case(op_name) + "Op" + op_dialect_name = dialect_name + "." 
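The list returned by the new parse_op_phi_name drives the per-name loop introduced above: an op carrying an `inplace` entry is generated twice, once as-is and once with a trailing underscore, while names already ending in "_" stay single. A self-contained restatement (op names illustrative):

    def expand_op_phi_names(name, has_inplace):
        # Mirrors parse_op_phi_name: inplace ops gain an extra "_"-suffixed variant.
        if not has_inplace or name.endswith("_"):
            return [name]
        return [name, name + "_"]

    assert expand_op_phi_names("relu", has_inplace=True) == ["relu", "relu_"]
    assert expand_op_phi_names("relu", has_inplace=False) == ["relu"]
    assert expand_op_phi_names("batch_norm_", has_inplace=True) == ["batch_norm_"]
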
+ op_name + + # gen interface/trait str + op_interfaces_str = "" + if len(op_interfaces) > 0: + op_interfaces_str = "," + ",".join(op_interfaces) + op_traits_str = "" + if len(op_interfaces) > 0: + op_traits_str = "," + ",".join(op_traits) + + op_get_inputs_outputs_str = "" for idx in range(len(op_input_name_list)): - input_info_list.append( - CONSTRUCT_INPUT_INFO_TEMPLATE.format( - name=op_input_name_list[idx], - typename=op_input_type_list[idx], - optional=op_input_optional_list[idx], - no_need_buffer=op_input_no_need_buffer_list[idx], - ) + op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( + input_name=op_input_name_list[idx], + input_index=idx, ) - inputs_info_str = ", ".join(input_info_list) - op_inputs_info_func_str = OP_INPUT_INFO_TEMPLATE.format( - op_name=op_class_name, impl=inputs_info_str - ) - - # generate get op info funciton: outputs - outputs_info_str = "" - if len(op_output_name_list) > 0: - output_info_list = [] for idx in range(len(op_output_name_list)): - output_info_list.append( - CONSTRUCT_OUTPUT_INFO_TEMPLATE.format( - name=op_output_name_list[idx], - typename=op_output_type_list[idx], - optional=op_output_optional_list[idx], - intermediate=op_output_intermediate_list[idx], - ) + op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( + output_name=op_output_name_list[idx], + output_index=idx, ) - outputs_info_str = ", ".join(output_info_list) - op_outputs_info_func_str = OP_OUTPUT_INFO_TEMPLATE.format( - op_name=op_class_name, impl=outputs_info_str - ) - # generate get op info funciton: attributes - attribute_info_str = "" - if len(op_attribute_name_list) > 0: - attribute_info_list = [] - for idx in range(len(op_attribute_name_list)): - attribute_info_list.append( - CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( - name=op_attribute_name_list[idx], - typename=op_attribute_type_list[idx], - data_type=op_attribute_data_type_list[idx], - ) + # gen op_declare_str/op_defined_str + if len(op_attribute_name_list) == 0: + op_declare_str = OP_DECLARE_TEMPLATE.format( + op_name=op_class_name, + dialect_op_name=op_dialect_name, + interfaces=op_interfaces_str, + traits=op_traits_str, + attribute_declare=op_0_attribute_declare_str, + attribute_num=0, + get_inputs_and_outputs=op_get_inputs_outputs_str, + ) + op_defined_str = "" + else: + op_declare_str = OP_DECLARE_TEMPLATE.format( + op_name=op_class_name, + dialect_op_name=op_dialect_name, + interfaces=op_interfaces_str, + traits=op_traits_str, + attribute_declare=op_n_attribute_declare_str.format( + attribute_num=len(op_attribute_name_list) + ), + attribute_num=len(op_attribute_name_list), + get_inputs_and_outputs=op_get_inputs_outputs_str, + ) + attribute_names_str = ( + '"' + '", "'.join(op_attribute_name_list) + '"' + ) + op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format( + op_name=op_class_name, + attribute_num=len(op_attribute_name_list), + attribute_names=attribute_names_str, ) - attribute_info_str = ", ".join(attribute_info_list) - op_attributes_info_func_str = OP_ATTRIBUTE_INFO_TEMPLATE.format( - op_name=op_class_name, impl=attribute_info_str - ) - # generate op verify function: inputs_type_check_str - if len(op_input_type_list) == 0: - inputs_type_check_str = ( - "// Inputs num is 0, not need to check inputs type." 
+ # generate get op info funciton: inputs + inputs_info_str = "" + if len(op_input_name_list) > 0: + input_info_list = [] + for idx in range(len(op_input_name_list)): + input_info_list.append( + CONSTRUCT_INPUT_INFO_TEMPLATE.format( + name=op_input_name_list[idx], + typename=op_input_type_list[idx], + optional=op_input_optional_list[idx], + no_need_buffer=op_input_no_need_buffer_list[idx], + ) + ) + inputs_info_str = ", ".join(input_info_list) + op_inputs_info_func_str = OP_INPUT_INFO_TEMPLATE.format( + op_name=op_class_name, impl=inputs_info_str ) - else: - inputs_type_check_str = "" - for idx in range(len(op_input_type_list)): - input_type = op_input_type_list[idx] - is_optional = op_input_optional_list[idx] - is_vector = False - if input_type.startswith("ir::VectorType<"): - is_vector = True - input_type = input_type[15:-1] - check_str = "" - if is_optional == "true": - if is_vector: - check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type + + # generate get op info funciton: outputs + outputs_info_str = "" + if len(op_output_name_list) > 0: + output_info_list = [] + for idx in range(len(op_output_name_list)): + output_info_list.append( + CONSTRUCT_OUTPUT_INFO_TEMPLATE.format( + name=op_output_name_list[idx], + typename=op_output_type_list[idx], + optional=op_output_optional_list[idx], + intermediate=op_output_intermediate_list[idx], + ) ) - else: - check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type + outputs_info_str = ", ".join(output_info_list) + op_outputs_info_func_str = OP_OUTPUT_INFO_TEMPLATE.format( + op_name=op_class_name, impl=outputs_info_str + ) + + # generate get op info funciton: attributes + attribute_info_str = "" + if len(op_attribute_name_list) > 0: + attribute_info_list = [] + for idx in range(len(op_attribute_name_list)): + attribute_info_list.append( + CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( + name=op_attribute_name_list[idx], + typename=op_attribute_type_list[idx], + data_type=op_attribute_data_type_list[idx], + ) ) + attribute_info_str = ", ".join(attribute_info_list) + op_attributes_info_func_str = OP_ATTRIBUTE_INFO_TEMPLATE.format( + op_name=op_class_name, impl=attribute_info_str + ) + + # generate op verify function: inputs_type_check_str + if len(op_input_type_list) == 0: + inputs_type_check_str = ( + "// Inputs num is 0, not need to check inputs type." 
+ ) else: - if is_vector: - check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) + inputs_type_check_str = "" + for idx in range(len(op_input_type_list)): + input_type = op_input_type_list[idx] + is_optional = op_input_optional_list[idx] + is_vector = False + if input_type.startswith("ir::VectorType<"): + is_vector = True + input_type = input_type[15:-1] + check_str = "" + if is_optional == "true": + if is_vector: + check_str = ( + INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + ) + else: + check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) else: - check_str = INPUT_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) - inputs_type_check_str += check_str + if is_vector: + check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + else: + check_str = INPUT_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + inputs_type_check_str += check_str - # generate op verify function: outputs_type_check_str - if len(op_output_type_list) == 0: - outputs_type_check_str = ( - "// Outputs num is 0, not need to check outputs type." - ) - else: - outputs_type_check_str = "" - for idx in range(len(op_output_type_list)): - output_type = op_output_type_list[idx] - is_optional = op_output_optional_list[idx] - is_vector = False - if output_type.startswith("ir::VectorType<"): - is_vector = True - output_type = output_type[15:-1] - check_str = "" - if is_optional == "true": - if is_vector: - check_str = ( - OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + # generate op verify function: outputs_type_check_str + if len(op_output_type_list) == 0: + outputs_type_check_str = ( + "// Outputs num is 0, not need to check outputs type." + ) + else: + outputs_type_check_str = "" + for idx in range(len(op_output_type_list)): + output_type = op_output_type_list[idx] + is_optional = op_output_optional_list[idx] + is_vector = False + if output_type.startswith("ir::VectorType<"): + is_vector = True + output_type = output_type[15:-1] + check_str = "" + if is_optional == "true": + if is_vector: + check_str = ( + OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + ) + else: + check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( index=idx, standard=output_type ) - ) else: - check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type - ) + if is_vector: + check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + else: + check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + outputs_type_check_str += check_str + + # generate op verify function: attributes_check_str + if len(op_attribute_name_list) == 0: + attributes_check_str = ( + "// Attributes num is 0, not need to check attributes type." 
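The nested if/else above encodes a four-way choice per argument, keyed on (optional, vector). The same selection written as a lookup table, for readability; this is a sketch, not generator code, and the output side is symmetric with the OUTPUT_* templates:

    CHECK_TEMPLATES = {
        (True,  True):  "INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE",
        (True,  False): "INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE",
        (False, True):  "INPUT_VECTORTYPE_CHECK_TEMPLATE",
        (False, False): "INPUT_TYPE_CHECK_TEMPLATE",
    }

    def pick_check_template(is_optional, input_type):
        # is_optional is the C++ literal string produced by the parser above.
        is_vector = input_type.startswith("ir::VectorType<")
        return CHECK_TEMPLATES[(is_optional == "true", is_vector)]

    assert (pick_check_template("true", "ir::VectorType<ir::Type>")
            == "INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE")
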
+ ) else: - if is_vector: - check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type + attributes_check_str = "" + for idx in range(len(op_attribute_name_list)): + attribute_name = op_attribute_name_list[idx] + attribute_type = op_attribute_type_list[idx] + if attribute_type.startswith("ir::ArrayAttribute<"): + attribute_type = attribute_type[19:-1] + attributes_check_str += ( + ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( + attribute_name=attribute_name, + standard=attribute_type, + ) ) else: - check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type + attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format( + attribute_name=attribute_name, standard=attribute_type ) - outputs_type_check_str += check_str - # generate op verify function: attributes_check_str - if len(op_attribute_name_list) == 0: - attributes_check_str = ( - "// Attributes num is 0, not need to check attributes type." + # generate op verify function + op_verify_str = OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + inputs_size=len(op_input_type_list), + outputs_size=len(op_output_type_list), + inputs_type_check=inputs_type_check_str, + outputs_type_check=outputs_type_check_str, + attributes_check=attributes_check_str, ) - else: - attributes_check_str = "" - for idx in range(len(op_attribute_name_list)): - attribute_name = op_attribute_name_list[idx] - attribute_type = op_attribute_type_list[idx] - if attribute_type.startswith("ir::ArrayAttribute<"): - attribute_type = attribute_type[19:-1] - attributes_check_str += ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( - attribute_name=attribute_name, standard=attribute_type - ) - else: - attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format( - attribute_name=attribute_name, standard=attribute_type - ) - - # generate op verify function - op_verify_str = OP_VERIFY_TEMPLATE.format( - op_name=op_class_name, - inputs_size=len(op_input_type_list), - outputs_size=len(op_output_type_list), - inputs_type_check=inputs_type_check_str, - outputs_type_check=outputs_type_check_str, - attributes_check=attributes_check_str, - ) - ops_name_list.append(op_class_name) - ops_declare_list.append(op_declare_str) - ops_defined_list.append(op_defined_str) - ops_defined_list.append(op_inputs_info_func_str) - ops_defined_list.append(op_outputs_info_func_str) - ops_defined_list.append(op_attributes_info_func_str) - ops_defined_list.append(op_verify_str) + ops_name_list.append(op_class_name) + ops_declare_list.append(op_declare_str) + ops_defined_list.append(op_defined_str) + ops_defined_list.append(op_inputs_info_func_str) + ops_defined_list.append(op_outputs_info_func_str) + ops_defined_list.append(op_attributes_info_func_str) + ops_defined_list.append(op_verify_str) # (4) Generate head file str op_namespaces_prev = "" diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index bbb449464acda7..05ec2001eaa88e 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -111,29 +111,7 @@ void PaddleDialect::initialize() { >(); RegisterInterfaces(); - RegisterOps(); + RegisterOps(); } void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 4dba3fdd74ec80..e4ab20dc5ee8c3 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -296,7 +296,7 @@ - op : einsum args : (Tensor[] x, str equation) - output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()} + output : 
Tensor(out), Tensor[](inner_cache){x.size()}, Tensor[](xshape){x.size()} infer_meta : func : EinsumRawInferMeta param : [x, equation] From 6c2ac5da6cc40702d9ab6ab16a8c6851f809a02a Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 26 May 2023 12:27:37 +0000 Subject: [PATCH 12/43] add interface --- paddle/fluid/dialect/CMakeLists.txt | 3 +- paddle/fluid/dialect/legacy_pd_op.h | 44 ---------------------- paddle/fluid/dialect/op_gen.py | 37 +++++++++--------- paddle/fluid/dialect/pd_dialect.cc | 2 - paddle/fluid/dialect/pd_interface.h | 58 +++++++++++++++++++++++++++++ paddle/fluid/dialect/pd_op.yaml | 52 ++++++++++++++++++++++++++ test/cpp/ir/ir_program_test.cc | 14 ++++++- 7 files changed, 145 insertions(+), 65 deletions(-) delete mode 100644 paddle/fluid/dialect/legacy_pd_op.h create mode 100644 paddle/fluid/dialect/pd_interface.h create mode 100644 paddle/fluid/dialect/pd_op.yaml diff --git a/paddle/fluid/dialect/CMakeLists.txt b/paddle/fluid/dialect/CMakeLists.txt index 3f17ac21724c66..d041f2554887df 100644 --- a/paddle/fluid/dialect/CMakeLists.txt +++ b/paddle/fluid/dialect/CMakeLists.txt @@ -16,8 +16,9 @@ set(op_backward_yaml_file1 set(op_backward_yaml_file2 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ) +set(op_yaml_file3 ${PADDLE_SOURCE_DIR}/paddle/fluid/dialect/pd_op.yaml) set(op_yaml_files - ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2} + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3} ) set(op_namespace paddle,dialect) set(dialect_name pd) diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h deleted file mode 100644 index d835bf929e8d1a..00000000000000 --- a/paddle/fluid/dialect/legacy_pd_op.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ir/op_base.h" - -namespace paddle { -namespace dialect { - -#define OPNAME(op_name) "pd." #op_name - -#define REIGSTER_EMPTY_OP(op_name, className) \ - class className : public ir::Op { \ - public: \ - static const char *name() { return OPNAME(op_name); } \ - static constexpr const char **attributes_name = nullptr; \ - static constexpr uint32_t attributes_num = 0; \ - static void verify(const std::vector &inputs, \ - const std::vector &outputs, \ - const ir::AttributeMap &attributes) { \ - LOG(WARNING) << "This is a fake verify"; \ - } \ - }; - -// TODO(zhangbo): As operators are supplemented and defined, they are gradually -// removed. 
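On the legacy_ops.yaml change in the previous patch: the new einsum spec names every output and keeps the `{x.size()}` arity expressions for the variadic ones. A throwaway tokenizer to show the three (type, name, size-expr) fields; this parsing is illustrative only, not the real YAML pipeline:

    import re

    spec = "Tensor(out), Tensor[](inner_cache){x.size()}, Tensor[](xshape){x.size()}"
    # type, then (name), then an optional {size expression}
    pattern = re.compile(r"(Tensor(?:\[\])?)\((\w+)\)(\{[^}]*\})?")
    print(pattern.findall(spec))
    # [('Tensor', 'out', ''), ('Tensor[]', 'inner_cache', '{x.size()}'),
    #  ('Tensor[]', 'xshape', '{x.size()}')]
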
-REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed -REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_ -REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 - -} // namespace dialect -} // namespace paddle diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index e0aec60bf81540..a9ae9739376a73 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -33,6 +33,7 @@ #include "paddle/ir/op_base.h" #include "paddle/fluid/dialect/utils.h" +#include "paddle/fluid/dialect/pd_interface.h" {input} #endif @@ -48,9 +49,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ static const char *name() {{ return "{dialect_op_name}"; }} {attribute_declare} static constexpr uint32_t attributes_num = {attribute_num}; - static std::vector inputs_info(); - static std::vector outputs_info(); - static std::vector attributes_info(); + static OpInfoTuple GetOpInfo(); static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); {get_inputs_and_outputs} }}; @@ -86,6 +85,15 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ """ # get op input info +OP_INFO_TEMPLATE = """ +OpInfoTuple {op_name}::GetOpInfo() {{ + std::vector inputs = {{ {inputs} }}; + std::vector attributes = {{ {attributes} }}; + std::vector outputs = {{ {outputs} }}; + return std::make_tuple(inputs, attributes, outputs); +}} +""" + OP_INPUT_INFO_TEMPLATE = """ std::vector {op_name}::inputs_info() {{ return {{ {impl} }}; @@ -481,11 +489,10 @@ def OpGenerator( op_attribute_name_list = op_info.attribute_name_list op_attribute_type_list = op_info.attribute_type_list op_attribute_data_type_list = op_info.attribute_data_type_list - op_interfaces = [] + op_interfaces = ["GetOpInfoInterface"] op_traits = [] # If op has inplace info, we will generate inplace op and non-inplace op. - print("op_info.op_phi_name: ", op_info.op_phi_name) for op_name in op_info.op_phi_name: op_class_name = to_pascal_case(op_name) + "Op" op_dialect_name = dialect_name + "." 
+ op_name @@ -495,7 +502,7 @@ def OpGenerator( if len(op_interfaces) > 0: op_interfaces_str = "," + ",".join(op_interfaces) op_traits_str = "" - if len(op_interfaces) > 0: + if len(op_traits) > 0: op_traits_str = "," + ",".join(op_traits) op_get_inputs_outputs_str = "" @@ -557,9 +564,6 @@ def OpGenerator( ) ) inputs_info_str = ", ".join(input_info_list) - op_inputs_info_func_str = OP_INPUT_INFO_TEMPLATE.format( - op_name=op_class_name, impl=inputs_info_str - ) # generate get op info funciton: outputs outputs_info_str = "" @@ -575,9 +579,6 @@ def OpGenerator( ) ) outputs_info_str = ", ".join(output_info_list) - op_outputs_info_func_str = OP_OUTPUT_INFO_TEMPLATE.format( - op_name=op_class_name, impl=outputs_info_str - ) # generate get op info funciton: attributes attribute_info_str = "" @@ -592,8 +593,12 @@ def OpGenerator( ) ) attribute_info_str = ", ".join(attribute_info_list) - op_attributes_info_func_str = OP_ATTRIBUTE_INFO_TEMPLATE.format( - op_name=op_class_name, impl=attribute_info_str + + op_info_func_str = OP_INFO_TEMPLATE.format( + op_name=op_class_name, + inputs=inputs_info_str, + attributes=attribute_info_str, + outputs=outputs_info_str, ) # generate op verify function: inputs_type_check_str @@ -706,9 +711,7 @@ def OpGenerator( ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) ops_defined_list.append(op_defined_str) - ops_defined_list.append(op_inputs_info_func_str) - ops_defined_list.append(op_outputs_info_func_str) - ops_defined_list.append(op_attributes_info_func_str) + ops_defined_list.append(op_info_func_str) ops_defined_list.append(op_verify_str) # (4) Generate head file str diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index 05ec2001eaa88e..664d3873828a50 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -16,7 +16,6 @@ #include "paddle/fluid/dialect/pd_attribute.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/dialect/CMakeLists.txt. -#include "paddle/fluid/dialect/legacy_pd_op.h" #include "paddle/fluid/dialect/pd_op.h" #include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/pd_type_storage.h" @@ -111,7 +110,6 @@ void PaddleDialect::initialize() { >(); RegisterInterfaces(); - RegisterOps(); } void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { diff --git a/paddle/fluid/dialect/pd_interface.h b/paddle/fluid/dialect/pd_interface.h new file mode 100644 index 00000000000000..25c2397c00d1c1 --- /dev/null +++ b/paddle/fluid/dialect/pd_interface.h @@ -0,0 +1,58 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
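Taken together, the three info lists now render through a single OP_INFO_TEMPLATE into one GetOpInfo definition per op. An offline render for a feed-like op is sketched below; the vector element types and the sample field values are written out by hand here and may differ in detail from the real generator output:

    OP_INFO_TEMPLATE = """
    OpInfoTuple {op_name}::GetOpInfo() {{
      std::vector<OpInputInfo> inputs = {{ {inputs} }};
      std::vector<OpAttributeInfo> attributes = {{ {attributes} }};
      std::vector<OpOutputInfo> outputs = {{ {outputs} }};
      return std::make_tuple(inputs, attributes, outputs);
    }}
    """
    print(OP_INFO_TEMPLATE.format(
        op_name="FeedOp",
        inputs="",
        attributes='OpAttributeInfo("name", "ir::StrAttribute", "")',
        outputs='OpOutputInfo("out", "DenseTensorType", false, false)',
    ))
    # emits the C++ body of FeedOp::GetOpInfo with the three braced lists filled in
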
+ +#pragma once + +#include "paddle/fluid/dialect/utils.h" +#include "paddle/ir/op_base.h" + +using OpInfoTuple = std::tuple, + std::vector, + std::vector>; + +namespace paddle { +namespace dialect { +class GetOpInfoInterface : public ir::OpInterfaceBase { + public: + struct Concept { + explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *)) + : get_op_info_(get_op_info) {} + OpInfoTuple (*get_op_info_)(ir::Operation *); + }; + + template + struct Model : public Concept { + static OpInfoTuple GetOpInfo(ir::Operation *op) { + ConcreteOp concret_op = ConcreteOp(op); + if (concret_op == nullptr) throw("concret_op is nullptr"); + return concret_op.GetOpInfo(); + } + + Model() : Concept(GetOpInfo) { + static_assert(sizeof(Model) == sizeof(Concept), + "sizeof(Model) != sizeof(Concept)"); + } + }; + + GetOpInfoInterface(ir::Operation *op, Concept *impl) + : ir::OpInterfaceBase(op), impl_(impl) {} + + OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); } + + private: + Concept *impl_; +}; + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/dialect/pd_op.yaml b/paddle/fluid/dialect/pd_op.yaml new file mode 100644 index 00000000000000..7ca8646a038c6c --- /dev/null +++ b/paddle/fluid/dialect/pd_op.yaml @@ -0,0 +1,52 @@ +- name: feed + inputs: + - typename: Tensor[] + name: x + optional: false + no_need_buffer: false + data_transform: {} + attrs: + - {typename: int, name: col} + outputs: + - {typename: Tensor, name: out, optional: false, intermediate: false} + no_need_buffer: null + data_transform: null + infer_meta: + func: null + param: null + kernel: + func: null + param: null + backend: null + layout: null + data_type: null + dispatch: null + force_backend: null + inplace: null + backward: null +- name: fetch + inputs: + - typename: Tensor + name: x + optional: false + no_need_buffer: false + data_transform: {} + attrs: + - {typename: int, name: col} + outputs: + - {typename: 'Tensor[]', name: out, optional: false, intermediate: false} + no_need_buffer: null + data_transform: null + infer_meta: + func: null + param: null + kernel: + func: null + param: null + backend: null + layout: null + data_type: null + dispatch: null + force_backend: null + inplace: null + backward: null diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc index 9fb72fec13c23b..0c443f122738a5 100644 --- a/test/cpp/ir/ir_program_test.cc +++ b/test/cpp/ir/ir_program_test.cc @@ -15,6 +15,7 @@ #include #include "paddle/fluid/dialect/pd_dialect.h" +#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/utils.h" #include "paddle/ir/builtin_attribute.h" @@ -177,7 +178,18 @@ TEST(program_test, program) { EXPECT_EQ(*(dst_tensor->data() + i), data_a[i] + data_b[i]); } - // (7) Def SetParameterOp(c, "c") + // (7) Def AbsOp(b) + ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); + std::unordered_map abs_op_attribute; + ir::Operation *abs_op = ir::Operation::create({op1->GetResultByIndex(0)}, + {dense_tensor_dtype}, + abs_op_attribute, + abs_info); + paddle::dialect::GetOpInfoInterface interface = + abs_op->dyn_cast(); + EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name_ == "x", true); + + // (8) Def SetParameterOp(c, "c") std::string op4_name = builtin_dialect->name() + "." 
+ std::string(ir::SetParameterOp::name()); ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name); From 6967ebacace89cfeee9ccd96209be825bbe4b153 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 29 May 2023 11:09:23 +0000 Subject: [PATCH 13/43] fix: op name normalization --- paddle/fluid/dialect/pd_op.yaml | 14 +++------- .../fluid/translator/attribute_translator.cc | 4 +-- paddle/fluid/translator/op_compat_gen.py | 4 +++ paddle/fluid/translator/op_translator.cc | 19 +++++++++++-- paddle/phi/api/yaml/op_compat.yaml | 11 +++++--- test/cpp/ir/program_translator_test.cc | 28 +++++++++++-------- 6 files changed, 50 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/dialect/pd_op.yaml b/paddle/fluid/dialect/pd_op.yaml index 7ca8646a038c6c..a8053b0d1cb83f 100644 --- a/paddle/fluid/dialect/pd_op.yaml +++ b/paddle/fluid/dialect/pd_op.yaml @@ -1,12 +1,7 @@ - name: feed - inputs: - - typename: Tensor[] - name: x - optional: false - no_need_buffer: false - data_transform: {} + inputs: [] attrs: - - {typename: int, name: col} + - {typename: str, name: name} outputs: - {typename: Tensor, name: out, optional: false, intermediate: false} no_need_buffer: null @@ -32,9 +27,8 @@ no_need_buffer: false data_transform: {} attrs: - - {typename: int, name: col} - outputs: - - {typename: 'Tensor[]', name: out, optional: false, intermediate: false} + - {typename: str, name: name} + outputs: [] no_need_buffer: null data_transform: null infer_meta: diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc index ababb7e3b50609..554b1f3ac7f3b3 100644 --- a/paddle/fluid/translator/attribute_translator.cc +++ b/paddle/fluid/translator/attribute_translator.cc @@ -31,7 +31,7 @@ class AttributeVisitor { ~AttributeVisitor() {} public: - ir::Attribute operator()(int i) { return ir::Int64_tAttribute::get(ctx, i); } + ir::Attribute operator()(int i) { return ir::Int32_tAttribute::get(ctx, i); } ir::Attribute operator()(float f) { return ir::FloatAttribute::get(ctx, f); } @@ -71,7 +71,7 @@ class AttributeVisitor { std::vector attrs; attrs.reserve(is.size()); for (const auto& v : is) { - attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); + attrs.push_back(ir::Int32_tAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index d6aeeeaf8ee4bf..85b50d1d8de5b7 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -57,6 +57,10 @@ def insert_new_mappings(op_name_str): insert_new_mappings(op_compat_item["op"]) if "backward" in op_compat_item: insert_new_mappings(op_compat_item["backward"]) + + # special op mappings + op_name_mappings["fetch_v2"] = "fetch" + op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 1451115a71d423..cdb3cd19ca7b37 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -47,8 +47,15 @@ using OpOutputMapping = std::unordered_map; static const char kTargetDialectPrefix[] = "pd."; +static const std::unordered_set special_inplace_ops = { + "batch_norm", +}; + inline bool IsInplace(const OpDesc& op_desc) { bool inplace = false; + if (special_inplace_ops.count(op_desc.Type())) { + return inplace; + } auto input_names = 
op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); @@ -319,10 +326,14 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); auto attribute_map = TranslateOpAttribute(op_desc); + VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; ir::Operation* operation = ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); + VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end."; program->InsertOp(operation); + + VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end."; RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); return operation; @@ -338,7 +349,9 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc); - auto attribute_map = TranslateOpAttribute(op_desc); + ir::AttributeMap attribute_map = { + {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, + }; ir::Operation* operation = ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); @@ -356,7 +369,9 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types; auto op_info = LoopkUpOpInfo(ctx, op_desc); - auto attribute_map = TranslateOpAttribute(op_desc); + ir::AttributeMap attribute_map = { + {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])}, + }; ir::Operation* operation = ir::Operation::create(op_inputs, op_output_types, attribute_map, op_info); diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index bbe3017e27eba7..f6dfe19ec0c5f3 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -14,6 +14,10 @@ # extra : # attrs : [bool is_test = false] + extra : + attrs : [bool use_mkldnn = false, bool use_quantizer = false, + str mkldnn_data_type = "float32", bool is_test = false] + - op : abs backward : abs_grad inputs : @@ -1532,7 +1536,7 @@ out : Out - op : mean (reduce_mean) - backward : reduce_mean_grad + backward : mean_grad (reduce_mean_grad) inputs : x : X outputs : @@ -1775,9 +1779,8 @@ - op : pool2d backward : pool2d_grad - extra : - attrs : [bool use_mkldnn = false, bool use_quantizer = false, - str mkldnn_data_type = "float32", bool is_test = false] + attrs: + kernel_size: ksize - op : pool3d backward : pool3d_grad diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc index 0035f860c5861b..cf2fe272e9c7e8 100644 --- a/test/cpp/ir/program_translator_test.cc +++ b/test/cpp/ir/program_translator_test.cc @@ -48,16 +48,20 @@ ProgramDesc load_from_file(const std::string &file_name) { TEST(PaddleDialectTest, Translator) { LOG(WARNING) << "TODO"; - // auto p = load_from_file("restnet50_main.prog"); - // EXPECT_EQ(p.Size(), 1u); - - // ir::IrContext *ctx = ir::IrContext::Instance(); - // ctx->GetOrRegisterDialect(); - // ctx->GetOrRegisterDialect(); - // auto program = paddle::TranslateLegacyProgramToProgram(p); - - // size_t op_size = program->block()->size(); - // // ops.size() = op size in BlockDesc + get_parameter_op + combine op - // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20); - // VLOG(0) << *program; + auto p = load_from_file("restnet50_main.prog"); + EXPECT_EQ(p.Size(), 1u); + + ir::IrContext *ctx = ir::IrContext::Instance(); 
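On the special_inplace_ops guard above: batch_norm shares argument names between inputs and outputs without being a true inplace op, so it is excluded before the name-overlap test. A rough Python restatement, assuming the remainder of IsInplace (outside this hunk) checks for shared argument names:

    SPECIAL_INPLACE_OPS = {"batch_norm"}

    def is_inplace(op_type, input_arg_names, output_arg_names):
        # Denylisted ops are never treated as inplace, whatever their arg names.
        if op_type in SPECIAL_INPLACE_OPS:
            return False
        return bool(set(input_arg_names) & set(output_arg_names))

    assert not is_inplace("batch_norm", ["X", "Mean"], ["Y", "Mean"])
    assert is_inplace("some_op_", ["X"], ["X"])  # hypothetical inplace op
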
+ ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + size_t op_size = program->block()->size(); + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20); + + std::ofstream ostrm( + "/home/lvyongkang/Paddle/build/log_resnet_after_translated", + std::ios::out); + ostrm << *program; } From df0b5cd124b2ee9242b0e67a7f0a10d991f5202c Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 29 May 2023 11:25:42 +0000 Subject: [PATCH 14/43] fix typo --- paddle/phi/api/yaml/op_compat.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index f6dfe19ec0c5f3..8dad83915075cf 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -14,10 +14,6 @@ # extra : # attrs : [bool is_test = false] - extra : - attrs : [bool use_mkldnn = false, bool use_quantizer = false, - str mkldnn_data_type = "float32", bool is_test = false] - - op : abs backward : abs_grad inputs : From f1c02dd8b9687abd483e81c8b6a3d275048c3061 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 30 May 2023 07:23:18 +0000 Subject: [PATCH 15/43] refactor input translator --- paddle/fluid/dialect/pd_interface.h | 12 +-- paddle/fluid/translator/op_compat_gen.py | 43 ++++++++- paddle/fluid/translator/op_compat_info.cc.j2 | 14 ++- paddle/fluid/translator/op_compat_info.h | 16 ++++ paddle/fluid/translator/op_translator.cc | 99 +++++++++++++++----- paddle/fluid/translator/utils.h | 41 ++++++++ paddle/ir/builtin_op.cc | 17 ++++ paddle/ir/builtin_op.h | 14 +++ 8 files changed, 220 insertions(+), 36 deletions(-) create mode 100644 paddle/fluid/translator/utils.h diff --git a/paddle/fluid/dialect/pd_interface.h b/paddle/fluid/dialect/pd_interface.h index 25c2397c00d1c1..1dac2f86ec35ca 100644 --- a/paddle/fluid/dialect/pd_interface.h +++ b/paddle/fluid/dialect/pd_interface.h @@ -26,18 +26,14 @@ namespace dialect { class GetOpInfoInterface : public ir::OpInterfaceBase { public: struct Concept { - explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *)) + explicit Concept(OpInfoTuple (*get_op_info)()) : get_op_info_(get_op_info) {} - OpInfoTuple (*get_op_info_)(ir::Operation *); + OpInfoTuple (*get_op_info_)(); }; template struct Model : public Concept { - static OpInfoTuple GetOpInfo(ir::Operation *op) { - ConcreteOp concret_op = ConcreteOp(op); - if (concret_op == nullptr) throw("concret_op is nullptr"); - return concret_op.GetOpInfo(); - } + static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); } Model() : Concept(GetOpInfo) { static_assert(sizeof(Model) == sizeof(Concept), @@ -48,7 +44,7 @@ class GetOpInfoInterface : public ir::OpInterfaceBase { GetOpInfoInterface(ir::Operation *op, Concept *impl) : ir::OpInterfaceBase(op), impl_(impl) {} - OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); } + OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); } private: Concept *impl_; diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index 85b50d1d8de5b7..e8d32774cb4517 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -14,6 +14,7 @@ import argparse from pathlib import Path +from typing import Dict import yaml from jinja2 import Environment, FileSystemLoader, StrictUndefined @@ -46,17 +47,48 @@ def to_phi_and_fluid_op_name(op_item): with 
open(op_compat_yaml_file, "r") as f: op_compat_infos = yaml.safe_load(f) op_name_mappings = {} + op_arg_name_mappings = {} for op_compat_item in op_compat_infos: - def insert_new_mappings(op_name_str): + def insert_new_mappings(op_name_str: str) -> str: normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str) if normalized_name == legacy_name: - return + return normalized_name op_name_mappings[legacy_name] = normalized_name + return normalized_name + + def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): + if op_name is None: + return + if op_name not in op_arg_name_mappings: + op_arg_name_mappings[op_name] = {} + op_arg_name_mappings[op_name].update(arg_mapping) - insert_new_mappings(op_compat_item["op"]) + normalized_op_name = insert_new_mappings(op_compat_item["op"]) + normailized_backward_op_name = None if "backward" in op_compat_item: - insert_new_mappings(op_compat_item["backward"]) + normailized_backward_op_name = insert_new_mappings( + op_compat_item["backward"] + ) + if "inputs" in op_compat_item: + insert_new_arg_mappings( + normalized_op_name, op_compat_item["inputs"] + ) + insert_new_arg_mappings( + normailized_backward_op_name, op_compat_item["inputs"] + ) + if "attrs" in op_compat_item: + insert_new_arg_mappings(normalized_op_name, op_compat_item["attrs"]) + insert_new_arg_mappings( + normailized_backward_op_name, op_compat_item["attrs"] + ) + if "outputs" in op_compat_item: + insert_new_arg_mappings( + normalized_op_name, op_compat_item["outputs"] + ) + insert_new_arg_mappings( + normailized_backward_op_name, op_compat_item["outputs"] + ) # special op mappings op_name_mappings["fetch_v2"] = "fetch" @@ -64,7 +96,8 @@ def insert_new_mappings(op_name_str): op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( - op_name_paris=op_name_mappings + op_name_pairs=op_name_mappings, + op_arg_name_pairs=op_arg_name_mappings, ) f.write(op_compat_definition) diff --git a/paddle/fluid/translator/op_compat_info.cc.j2 b/paddle/fluid/translator/op_compat_info.cc.j2 index af42cf9b8abdc3..a44941595fbb8d 100644 --- a/paddle/fluid/translator/op_compat_info.cc.j2 +++ b/paddle/fluid/translator/op_compat_info.cc.j2 @@ -5,10 +5,22 @@ namespace translator { OpNameNormalizer::OpNameNormalizer() { op_name_mappings = { - {% for legacy_name, normalized_name in op_name_paris.items() %} + {% for legacy_name, normalized_name in op_name_pairs.items() %} { "{{legacy_name}}", "{{normalized_name}}" }, {% endfor %} }; + op_arg_name_mappings = { + {% for op_name, arg_name_mappings in op_arg_name_pairs.items() %} + { + "{{op_name}}", + { + {% for normalized_name, legacy_name in arg_name_mappings.items() %} + { "{{normalized_name}}", "{{legacy_name}}" }, + {% endfor %} + }, + }, + {% endfor %} + }; } } // namespace translator diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h index 86acafe7a0f1a5..52258a5b08b9ba 100644 --- a/paddle/fluid/translator/op_compat_info.h +++ b/paddle/fluid/translator/op_compat_info.h @@ -17,6 +17,8 @@ #include "glog/logging.h" +#include "paddle/fluid/translator/utils.h" + #pragma once namespace paddle { @@ -26,6 +28,8 @@ class OpNameNormalizer { private: OpNameNormalizer(); // Disallow instantiation outside of the class. 
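The generator above now records, besides the op-name table, a per-op argument-name table that the backward op shares with its forward op. A toy run against the mean/reduce_mean entry from op_compat.yaml (condensed; the real code first routes the names through insert_new_mappings):

    op_arg_name_mappings = {}

    def insert_new_arg_mappings(op_name, arg_mapping):
        if op_name is None:
            return
        op_arg_name_mappings.setdefault(op_name, {}).update(arg_mapping)

    # from the compat entry: mean (reduce_mean), backward mean_grad (reduce_mean_grad)
    compat = {"inputs": {"x": "X"}, "outputs": {"out": "Out"}}
    for key in ("inputs", "outputs"):
        insert_new_arg_mappings("mean", compat[key])
        insert_new_arg_mappings("mean_grad", compat[key])
    print(op_arg_name_mappings)
    # {'mean': {'x': 'X', 'out': 'Out'}, 'mean_grad': {'x': 'X', 'out': 'Out'}}
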
std::unordered_map op_name_mappings; + std::unordered_map> + op_arg_name_mappings; public: OpNameNormalizer(const OpNameNormalizer&) = delete; @@ -44,6 +48,18 @@ class OpNameNormalizer { } return op_name_mappings.at(op_type); } + + std::string GetLegacyArgName(const std::string& op_type, + const std::string& arg_name) { + if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { + return UnderscoreToCamelCase(arg_name); + } + auto& arg_mappings = op_arg_name_mappings[op_type]; + if (arg_mappings.find(arg_name) == arg_mappings.end()) { + return UnderscoreToCamelCase(arg_name); + } + return arg_mappings.at(op_type); + } }; } // namespace translator diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index cdb3cd19ca7b37..d22b430b8569d6 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -15,12 +15,14 @@ #include "paddle/fluid/translator/op_translator.h" #include +#include #include #include #include #include #include +#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/translator/attribute_translator.h" #include "paddle/fluid/translator/op_compat_info.h" @@ -44,6 +46,12 @@ using BlockDesc = paddle::framework::BlockDesc; using VarDesc = paddle::framework::VarDesc; using OpOutputTypeList = std::vector; using OpOutputMapping = std::unordered_map; +using OpInputInfo = paddle::dialect::OpInputInfo; +using OpInputInfoList = std::vector; +using OpAttributeInfo = paddle::dialect::OpAttributeInfo; +using OpAttributeInfoList = std::vector; +using OpOutputInfo = paddle::dialect::OpOutputInfo; +using OpOutputInfoList = std::vector; static const char kTargetDialectPrefix[] = "pd."; @@ -150,13 +158,25 @@ inline ir::Operation* InsertCombineOperationForTarget( return operation; } +inline ir::Operation* InsertConstantOperationForOptionalArg( + ir::IrContext* ctx, ir::Program* program) { + std::string constant_op_name(ir::ConstantOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); + + ir::Type null_type = ir::Type(nullptr); + ir::Operation* operation = + ir::Operation::create({}, {null_type}, {}, op_info); + program->InsertOp(operation); + return operation; +} + inline std::vector GenerateOperationInput( ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, - const OpDesc& op_desc) { - std::vector op_inputs; - + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfoList& input_infos) { // scan all inputs to see if any of them is generated as a vector // so need an additional `SliceOp` to take it out. 
for (const auto& n : op_desc.Inputs()) { @@ -168,7 +188,7 @@ inline std::vector GenerateOperationInput( param_map->count(arg_name), 0, platform::errors::PreconditionNotMet( - "arg %s.%s as input should be exists before prasing %d", + "arg %s.%s as input should be exists before prasing %s", name, arg_name, op_desc.Type())); @@ -180,27 +200,41 @@ inline std::vector GenerateOperationInput( } } - for (const auto& n : op_desc.Inputs()) { - auto& name = n.first; - VLOG(10) << "[input retriving]" - << "[" << op_desc.Type() << "]" << name; - auto& args = n.second; + std::vector op_inputs; + auto& op_normalizer = OpNameNormalizer::instance(); - // if src type is Tensor or a Vector with size <= 1 - if (args.size() <= 1) { - for (const auto& arg_name : args) { - auto defining_info = (*param_map)[arg_name]; - op_inputs.push_back(defining_info.value); - } + for (const auto& info : input_infos) { + std::string legacy_input_name = + op_normalizer.GetLegacyArgName(normalized_op_name, info.name_); + + // return empty type if this arg is optional and not shown in OpDesc + // TODO(lyk): HasInput doesnot consider variadic attribute + if (!op_desc.HasInput(legacy_input_name)) { + PADDLE_ENFORCE(info.optional_, + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_input_name); + auto* constant_op = InsertConstantOperationForOptionalArg(ctx, program); + op_inputs.push_back(constant_op->GetResultByIndex(0)); + continue; + } + + const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true); + bool is_vector = (info.typename_.find("VectorType") != std::string::npos); + // if src type is Tensor + if (!is_vector) { + auto defining_info = (*param_map)[legacy_input_vars[0]]; + op_inputs.push_back(defining_info.value); // if src type is Vector , need an additional `CombineOp` to // assemble them. 
} else { - auto* combine_op = - InsertCombineOperationForTarget(ctx, param_map, program, args); + auto* combine_op = InsertCombineOperationForTarget( + ctx, param_map, program, legacy_input_vars); op_inputs.push_back(combine_op->GetResultByIndex(0)); } } + return op_inputs; } @@ -319,12 +353,23 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto* op_info_concept = + op_info.GetInterfaceImpl(); + + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos) = + op_info_concept->get_op_info_(); + + auto op_inputs = GenerateOperationInput( + ctx, param_map, program, op_desc, op_info.name(), input_infos); OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); - auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto attribute_map = TranslateOpAttribute(op_desc); VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; @@ -343,12 +388,12 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { + auto op_info = LoopkUpOpInfo(ctx, op_desc); std::vector op_inputs; OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); - auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::AttributeMap attribute_map = { {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, }; @@ -365,10 +410,20 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); + + auto* op_info_concept = + op_info.GetInterfaceImpl(); + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos) = + op_info_concept->get_op_info_(); + + auto op_inputs = GenerateOperationInput( + ctx, param_map, program, op_desc, op_info.name(), input_infos); OpOutputTypeList op_output_types; - auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::AttributeMap attribute_map = { {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])}, }; diff --git a/paddle/fluid/translator/utils.h b/paddle/fluid/translator/utils.h new file mode 100644 index 00000000000000..5711f1f550fb14 --- /dev/null +++ b/paddle/fluid/translator/utils.h @@ -0,0 +1,41 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
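After this refactor, each declared input reduces to a three-way decision: absent-but-optional becomes a null-typed ConstantOp placeholder, a plain Tensor maps straight to its recorded value, and a Tensor[] is assembled through a CombineOp. A runnable sketch with string stand-ins for the IR objects (all names here are illustrative, not translator API):

    def assemble_input(info, op_desc_inputs, param_map):
        # info: (name, type_name, optional) as reported by GetOpInfo;
        # op_desc_inputs: legacy arg name -> list of variable names.
        name, type_name, optional = info
        if name not in op_desc_inputs:       # optional arg absent from the OpDesc
            assert optional
            return "constant placeholder"    # stands in for the null-typed ConstantOp
        var_names = op_desc_inputs[name]
        if "VectorType" not in type_name:    # plain Tensor input
            return param_map[var_names[0]]
        return f"combine({var_names})"       # Tensor[]: assembled via CombineOp

    print(assemble_input(("x", "DenseTensorType", False), {"x": ["v0"]}, {"v0": "%0"}))
    print(assemble_input(("xs", "VectorType<...>", False), {"xs": ["v0", "v1"]}, {}))
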
+
+#pragma once
+
+#include <string>
+
+namespace paddle {
+namespace translator {
+
+static std::string UnderscoreToCamelCase(std::string str) {
+  std::string camel_case;
+  bool next_upper = true;
+  for (char c : str) {
+    if (c == '_') {
+      next_upper = true;
+    } else {
+      if (next_upper) {
+        camel_case += toupper(c);
+        next_upper = false;
+      } else {
+        camel_case += c;
+      }
+    }
+  }
+  return camel_case;
+}
+
+}  // namespace translator
+}  // namespace paddle
diff --git a/paddle/ir/builtin_op.cc b/paddle/ir/builtin_op.cc
index 63bfc2196dca35..2b4900d6a8ff93 100644
--- a/paddle/ir/builtin_op.cc
+++ b/paddle/ir/builtin_op.cc
@@ -160,4 +160,21 @@ void SliceOp::verify(const std::vector<ir::OpResult> &inputs,
                      outputs[0]));
 }
 
+void ConstantOp::verify(const std::vector<ir::OpResult> &inputs,
+                        const std::vector<ir::Type> &outputs,
+                        const ir::AttributeMap &attributes) {
+  // outputs.size() == 1
+  PADDLE_ENFORCE_EQ(
+      outputs.size(),
+      1,
+      phi::errors::PreconditionNotMet(
+          "The size %d of outputs must be equal to 1.", outputs.size()));
+  // inputs.size() == 0
+  PADDLE_ENFORCE_EQ(
+      inputs.size(),
+      0,
+      phi::errors::PreconditionNotMet(
+          "The size %d of inputs must be equal to 0.", inputs.size()));
+}
+
 }  // namespace ir
diff --git a/paddle/ir/builtin_op.h b/paddle/ir/builtin_op.h
index d1d2c20b5725db..b4042e3abc0f53 100644
--- a/paddle/ir/builtin_op.h
+++ b/paddle/ir/builtin_op.h
@@ -81,4 +81,18 @@ class SliceOp : public ir::Op<SliceOp> {
                      const ir::AttributeMap &attributes);
 };
 
+class ConstantOp : public ir::Op<ConstantOp> {
+ public:
+  using Op::Op;
+
+  static const char *name() { return "builtin.constant"; }
+
+  static constexpr uint32_t attributes_num = 0;
+
+  static constexpr const char **attributes_name = nullptr;
+  static void verify(const std::vector<ir::OpResult> &inputs,
+                     const std::vector<ir::Type> &outputs,
+                     const ir::AttributeMap &attributes);
+};
+
 }  // namespace ir
From 0854418515640d005bab7587316646c638132070 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Tue, 30 May 2023 08:01:47 +0000
Subject: [PATCH 16/43] fix merge conflicts

---
 paddle/fluid/translator/attribute_translator.h | 6 +++---
 paddle/fluid/translator/op_translator.cc       | 8 ++++----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h
index 9e7117d5db6412..ed64c27b120eb1 100644
--- a/paddle/fluid/translator/attribute_translator.h
+++ b/paddle/fluid/translator/attribute_translator.h
@@ -14,9 +14,9 @@
 
 #include "paddle/fluid/framework/attribute.h"
 #include "paddle/fluid/framework/type_defs.h"
-#include "paddle/ir/attribute.h"
-#include "paddle/ir/builtin_attribute.h"
-#include "paddle/ir/ir_context.h"
+#include "paddle/ir/core/attribute.h"
+#include "paddle/ir/core/builtin_attribute.h"
+#include "paddle/ir/core/ir_context.h"
 
 #pragma once
 
diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc
index 9b3dd54d802dbe..42a9d9b05748cf 100644
--- a/paddle/fluid/translator/op_translator.cc
+++ b/paddle/fluid/translator/op_translator.cc
@@ -165,7 +165,7 @@ inline ir::Operation* InsertConstantOperationForOptionalArg(
   ir::Type null_type = ir::Type(nullptr);
 
   ir::Operation* operation =
-      ir::Operation::create({}, {null_type}, {}, op_info);
+      ir::Operation::create({}, {}, {null_type}, op_info);
   program->InsertOp(operation);
   return operation;
 }
@@ -205,12 +205,12 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
 
   for (const auto& info : input_infos) {
     std::string legacy_input_name =
-        op_normalizer.GetLegacyArgName(normalized_op_name, info.name_);
+        op_normalizer.GetLegacyArgName(normalized_op_name, info.name);
 
     // return empty type if this arg is optional and not shown in OpDesc
     // TODO(lyk): HasInput doesnot consider variadic attribute
     if (!op_desc.HasInput(legacy_input_name)) {
-      PADDLE_ENFORCE(info.optional_,
+      PADDLE_ENFORCE(info.optional,
                      "Op %s arg %s should be optional if it can be empty",
                      op_desc.Type(),
                      legacy_input_name);
@@ -220,7 +220,7 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
     }
 
     const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true);
-    bool is_vector = (info.typename_.find("VectorType") != std::string::npos);
+    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
 
     // if src type is Tensor
     if (!is_vector) {
From 38f74537495272727f95fda09277e49d1772b457 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Tue, 30 May 2023 08:36:39 +0000
Subject: [PATCH 17/43] fix op normalizer bug

---
 paddle/fluid/translator/op_compat_gen.py | 26 ++++-----
 paddle/fluid/translator/op_compat_info.h | 16 ++++-
 paddle/fluid/translator/op_translator.cc | 74 ++++++++++++++--------
 3 files changed, 75 insertions(+), 41 deletions(-)

diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py
index e8d32774cb4517..edf85211b8cb16 100644
--- a/paddle/fluid/translator/op_compat_gen.py
+++ b/paddle/fluid/translator/op_compat_gen.py
@@ -53,9 +53,9 @@ def to_phi_and_fluid_op_name(op_item):
     def insert_new_mappings(op_name_str: str) -> str:
         normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str)
         if normalized_name == legacy_name:
-            return normalized_name
+            return normalized_name, legacy_name
         op_name_mappings[legacy_name] = normalized_name
-        return normalized_name
+        return normalized_name, legacy_name
 
     def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
         if op_name is None:
@@ -64,30 +64,26 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
             op_arg_name_mappings[op_name] = {}
         op_arg_name_mappings[op_name].update(arg_mapping)
 
-    normalized_op_name = insert_new_mappings(op_compat_item["op"])
-    normailized_backward_op_name = None
+    _, legacy_name = insert_new_mappings(op_compat_item["op"])
+    legacy_backward_op_name = None
     if "backward" in op_compat_item:
-        normailized_backward_op_name = insert_new_mappings(
+        _, legacy_backward_op_name = insert_new_mappings(
             op_compat_item["backward"]
         )
     if "inputs" in op_compat_item:
+        insert_new_arg_mappings(legacy_name, op_compat_item["inputs"])
         insert_new_arg_mappings(
-            normalized_op_name, op_compat_item["inputs"]
-        )
-        insert_new_arg_mappings(
-            normailized_backward_op_name, op_compat_item["inputs"]
+            legacy_backward_op_name, op_compat_item["inputs"]
         )
     if "attrs" in op_compat_item:
-        insert_new_arg_mappings(normalized_op_name, op_compat_item["attrs"])
+        insert_new_arg_mappings(legacy_name, op_compat_item["attrs"])
         insert_new_arg_mappings(
-            normailized_backward_op_name, op_compat_item["attrs"]
+            legacy_backward_op_name, op_compat_item["attrs"]
        )
     if "outputs" in op_compat_item:
+        insert_new_arg_mappings(legacy_name, op_compat_item["outputs"])
        insert_new_arg_mappings(
-            normalized_op_name, op_compat_item["outputs"]
-        )
-        insert_new_arg_mappings(
-            normailized_backward_op_name, op_compat_item["outputs"]
+            legacy_backward_op_name, op_compat_item["outputs"]
        )
 
     # special op mappings
diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h
index 52258a5b08b9ba..38ae7c0ae68233 100644
--- a/paddle/fluid/translator/op_compat_info.h
+++
b/paddle/fluid/translator/op_compat_info.h @@ -58,7 +58,21 @@ class OpNameNormalizer { if (arg_mappings.find(arg_name) == arg_mappings.end()) { return UnderscoreToCamelCase(arg_name); } - return arg_mappings.at(op_type); + return arg_mappings.at(arg_name); + } + + std::string GetLegacyAttrName(const std::string& op_type, + const std::string& arg_name) { + if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { + VLOG(10) << "[" << op_type << "] not found"; + return arg_name; + } + auto& arg_mappings = op_arg_name_mappings[op_type]; + if (arg_mappings.find(arg_name) == arg_mappings.end()) { + VLOG(10) << "[" << op_type << "][" << arg_name << "] not found"; + return arg_name; + } + return arg_mappings.at(arg_name); } }; diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 42a9d9b05748cf..f9268c009cd925 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -205,7 +205,7 @@ inline std::vector GenerateOperationInput( for (const auto& info : input_infos) { std::string legacy_input_name = - op_normalizer.GetLegacyArgName(normalized_op_name, info.name); + op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); // return empty type if this arg is optional and not shown in OpDesc // TODO(lyk): HasInput doesnot consider variadic attribute @@ -289,38 +289,61 @@ inline std::tuple GenerateOperationOutput( return {op_output_types, arg_to_idx}; } -inline ir::AttributeMap TranslateOpAttribute(const OpDesc& op_desc) { +inline ir::AttributeMap TranslateOpAttribute( + std::string normalized_op_name, + const OpAttributeInfoList& op_attr_infos, + const OpDesc& op_desc) { auto& attribute_translator = AttributeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); ir::AttributeMap attribute_map = {}; - for (auto attr_in_op_desc : op_desc.GetAttrMap()) { - const auto& attr_name = attr_in_op_desc.first; - const auto& attr_value = attr_in_op_desc.second; - VLOG(0) << "attribute in " << op_desc.Type() << " name: " << attr_name - << " " << attr_value.index(); - ir::Attribute new_attr = attribute_translator[attr_value]; - attribute_map[attr_name] = new_attr; + + for (const auto& info : op_attr_infos) { + auto legacy_attr_name = + op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); + paddle::framework::Attribute legacy_attr = + op_desc.GetAttr(legacy_attr_name); + VLOG(10) << "attribute in " << op_desc.Type() + << " name: " << legacy_attr_name << " " << legacy_attr.index(); + ir::Attribute new_attr = attribute_translator[legacy_attr]; + attribute_map[info.name] = new_attr; if (!new_attr) { VLOG(0) << "empty attribute in " << op_desc.Type() - << " name: " << attr_name; + << " name: " << info.name; } else { VLOG(10) << "new attribute in " << op_desc.Type() - << " name: " << attr_name << " " << new_attr.storage(); + << " name: " << info.name << " " << new_attr.storage(); } } - for (auto attr_in_op_desc : op_desc.GetRuntimeAttrMap()) { - const auto& attr_name = attr_in_op_desc.first; - const auto& attr_value = attr_in_op_desc.second; - ir::Attribute new_attr = attribute_translator[attr_value]; - attribute_map[attr_name] = new_attr; - if (!new_attr) { - VLOG(0) << "empty runtime attribute in " << op_desc.Type() - << " name: " << attr_name; - } else { - VLOG(10) << "new runtime attribute in " << op_desc.Type() - << " name: " << attr_name << " " << new_attr.storage(); - } - } + // for (auto attr_in_op_desc : op_desc.GetAttrMap()) { + // const auto& attr_name = 
attr_in_op_desc.first; + // const auto& attr_value = attr_in_op_desc.second; + // VLOG(0) << "attribute in " << op_desc.Type() << " name: " << attr_name + // << " " << attr_value.index(); + // ir::Attribute new_attr = attribute_translator[attr_value]; + // attribute_map[attr_name] = new_attr; + // if (!new_attr) { + // VLOG(0) << "empty attribute in " << op_desc.Type() + // << " name: " << attr_name; + // } else { + // VLOG(10) << "new attribute in " << op_desc.Type() + // << " name: " << attr_name << " " << new_attr.storage(); + // } + // } + + // for (auto attr_in_op_desc : op_desc.GetRuntimeAttrMap()) { + // const auto& attr_name = attr_in_op_desc.first; + // const auto& attr_value = attr_in_op_desc.second; + // ir::Attribute new_attr = attribute_translator[attr_value]; + // attribute_map[attr_name] = new_attr; + // if (!new_attr) { + // VLOG(0) << "empty runtime attribute in " << op_desc.Type() + // << " name: " << attr_name; + // } else { + // VLOG(10) << "new runtime attribute in " << op_desc.Type() + // << " name: " << attr_name << " " << new_attr.storage(); + // } + // } return std::move(attribute_map); } @@ -370,7 +393,8 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); - auto attribute_map = TranslateOpAttribute(op_desc); + auto attribute_map = + TranslateOpAttribute(op_info.name(), attr_infos, op_desc); VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; ir::Operation* operation = From e8c623459ba2ed21928154015d519ae39b1d0cef Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 30 May 2023 09:57:01 +0000 Subject: [PATCH 18/43] refactor attribute translator --- .../fluid/translator/attribute_translator.cc | 133 +++++++++++++++--- .../fluid/translator/attribute_translator.h | 10 +- paddle/fluid/translator/op_translator.cc | 9 +- 3 files changed, 131 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc index 554b1f3ac7f3b3..4390ba68616d56 100644 --- a/paddle/fluid/translator/attribute_translator.cc +++ b/paddle/fluid/translator/attribute_translator.cc @@ -18,6 +18,10 @@ #include #include "paddle/fluid/dialect/pd_attribute.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/place.h" #include "paddle/phi/common/scalar.h" #include "paddle/utils/variant.h" @@ -31,25 +35,38 @@ class AttributeVisitor { ~AttributeVisitor() {} public: - ir::Attribute operator()(int i) { return ir::Int32_tAttribute::get(ctx, i); } + virtual ir::Attribute operator()(int i) { + VLOG(10) << "translating int"; + return ir::Int32_tAttribute::get(ctx, i); + } - ir::Attribute operator()(float f) { return ir::FloatAttribute::get(ctx, f); } + virtual ir::Attribute operator()(float f) { + VLOG(10) << "translating float"; + return ir::FloatAttribute::get(ctx, f); + } - ir::Attribute operator()(bool b) { return ir::BoolAttribute::get(ctx, b); } + virtual ir::Attribute operator()(bool b) { + VLOG(10) << "translating bool"; + return ir::BoolAttribute::get(ctx, b); + } - ir::Attribute operator()(double d) { + virtual ir::Attribute operator()(double d) { + VLOG(10) << "translating double"; return ir::DoubleAttribute::get(ctx, d); } - ir::Attribute operator()(std::string str) { + virtual ir::Attribute operator()(std::string str) { + VLOG(10) << "translating string"; return ir::StrAttribute::get(ctx, 
                                     str);
   }
 
-  ir::Attribute operator()(const paddle::experimental::Scalar& scalar) {
+  virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) {
+    VLOG(10) << "translating scalar";
     return paddle::dialect::ScalarAttribute::get(ctx, scalar);
   }
 
-  ir::Attribute operator()(const std::vector<std::string>& strs) {
+  virtual ir::Attribute operator()(const std::vector<std::string>& strs) {
+    VLOG(10) << "translating vector<std::string>";
     std::vector<ir::Attribute> attrs;
     attrs.reserve(strs.size());
     for (const auto& v : strs) {
@@ -58,7 +75,8 @@ class AttributeVisitor {
     return ir::ArrayAttribute::get(ctx, attrs);
   }
 
-  ir::Attribute operator()(const std::vector<float>& fs) {
+  virtual ir::Attribute operator()(const std::vector<float>& fs) {
+    VLOG(10) << "translating vector<float>";
     std::vector<ir::Attribute> attrs;
     attrs.reserve(fs.size());
     for (const auto& v : fs) {
@@ -67,7 +85,8 @@ class AttributeVisitor {
     return ir::ArrayAttribute::get(ctx, attrs);
   }
 
-  ir::Attribute operator()(const std::vector<int>& is) {
+  virtual ir::Attribute operator()(const std::vector<int>& is) {
+    VLOG(10) << "translating vector<int>";
     std::vector<ir::Attribute> attrs;
     attrs.reserve(is.size());
     for (const auto& v : is) {
@@ -76,7 +95,8 @@ class AttributeVisitor {
     return ir::ArrayAttribute::get(ctx, attrs);
   }
 
-  ir::Attribute operator()(const std::vector<bool>& bs) {
+  virtual ir::Attribute operator()(const std::vector<bool>& bs) {
+    VLOG(10) << "translating vector<bool>";
     std::vector<ir::Attribute> attrs;
     attrs.reserve(bs.size());
     for (const auto& v : bs) {
@@ -85,7 +105,8 @@ class AttributeVisitor {
     return ir::ArrayAttribute::get(ctx, attrs);
   }
 
-  ir::Attribute operator()(const std::vector<int64_t>& i64s) {
+  virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
+    VLOG(10) << "translating vector<int64_t>";
     std::vector<ir::Attribute> attrs;
     attrs.reserve(i64s.size());
     for (const auto& v : i64s) {
@@ -94,7 +115,8 @@ class AttributeVisitor {
     return ir::ArrayAttribute::get(ctx, attrs);
   }
 
-  ir::Attribute operator()(const std::vector<double>& ds) {
+  virtual ir::Attribute operator()(const std::vector<double>& ds) {
+    VLOG(10) << "translating vector<double>";
     std::vector<ir::Attribute> attrs;
     attrs.reserve(ds.size());
     for (const auto& v : ds) {
@@ -103,8 +125,9 @@ class AttributeVisitor {
     return ir::ArrayAttribute::get(ctx, attrs);
   }
 
-  ir::Attribute operator()(
+  virtual ir::Attribute operator()(
       const std::vector<paddle::experimental::Scalar>& ss) {
+    VLOG(10) << "translating vector<paddle::experimental::Scalar>";
     std::vector<ir::Attribute> attrs;
     attrs.reserve(ss.size());
     for (const auto& v : ss) {
@@ -113,17 +136,95 @@ class AttributeVisitor {
     return ir::ArrayAttribute::get(ctx, attrs);
   }
 
+  virtual ir::Attribute operator()(const paddle::blank& blank) {
+    VLOG(10) << "translating paddle::blank";
+    return ir::Attribute(nullptr);
+  }
+
   template <typename T>
   ir::Attribute operator()(T attr) {
+    VLOG(10) << "translating null type";
     return ir::Attribute(nullptr);
   }
 };
 
-AttributeTranslator::AttributeTranslator() { visitor = new AttributeVisitor(); }
+class IntArrayAttributeVisitor : public AttributeVisitor {
+ public:
+  using AttributeVisitor::AttributeVisitor;
+  ir::Attribute operator()(const std::vector<int>& is) override {
+    VLOG(10) << "translating vector<int> to IntArray";
+    phi::IntArray data(is);
+    return paddle::dialect::IntArrayAttribute::get(ctx, data);
+  }
+
+  ir::Attribute operator()(const std::vector<int64_t>& is) override {
+    VLOG(10) << "translating vector<int64_t> to IntArray";
+    phi::IntArray data(is);
+    return paddle::dialect::IntArrayAttribute::get(ctx, data);
+  }
+};
+
+class ScalarAttributeVisitor : public AttributeVisitor {
+ public:
+  using AttributeVisitor::AttributeVisitor;
+  ir::Attribute operator()(int i) override {
+    VLOG(10) << "translating int to Scalar";
+    phi::Scalar data(i);
+    return paddle::dialect::ScalarAttribute::get(ctx, data);
+  }
 
-ir::Attribute AttributeTranslator::operator[](
+  ir::Attribute operator()(float f) override {
+    VLOG(10) << "translating float to Scalar";
+    phi::Scalar data(f);
+    return paddle::dialect::ScalarAttribute::get(ctx, data);
+  }
+};
+
+class DataTypeAttributeVisitor : public AttributeVisitor {
+ public:
+  using AttributeVisitor::AttributeVisitor;
+  ir::Attribute operator()(int i) override {
+    VLOG(10) << "translating int to DataType: " << i;
+    phi::DataType data = static_cast<phi::DataType>(i);
+    return paddle::dialect::DataTypeAttribute::get(ctx, data);
+  }
+};
+
+class PlaceAttributeVisitor : public AttributeVisitor {
+ public:
+  using AttributeVisitor::AttributeVisitor;
+
+  ir::Attribute operator()(const paddle::blank& blank) override {
+    VLOG(10) << "translating paddle::blank";
+    phi::Place data(phi::AllocationType::CPU);
+    return paddle::dialect::PlaceAttribute::get(ctx, data);
+  }
+};
+
+AttributeTranslator::AttributeTranslator() {
+  general_visitor = new AttributeVisitor();
+  special_visitors["paddle::dialect::IntArrayAttribute"] =
+      new IntArrayAttributeVisitor();
+  special_visitors["paddle::dialect::ScalarAttribute"] =
+      new ScalarAttributeVisitor();
+  special_visitors["paddle::dialect::DataTypeAttribute"] =
+      new DataTypeAttributeVisitor();
+  special_visitors["paddle::dialect::PlaceAttribute"] =
+      new PlaceAttributeVisitor();
+}
+
+ir::Attribute AttributeTranslator::operator()(
     const framework::Attribute& attr) {
-  return paddle::visit(*visitor, attr);
+  return paddle::visit(*general_visitor, attr);
+}
+
+ir::Attribute AttributeTranslator::operator()(
+    const std::string& target_type, const framework::Attribute& attr) {
+  if (special_visitors.find(target_type) == special_visitors.end()) {
+    VLOG(10) << "[" << target_type << "] not found";
+    return paddle::visit(*general_visitor, attr);
+  }
+  return paddle::visit(*(special_visitors.at(target_type)), attr);
 }
 
 }  // namespace translator
diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h
index ed64c27b120eb1..ea509c7e346736 100644
--- a/paddle/fluid/translator/attribute_translator.h
+++ b/paddle/fluid/translator/attribute_translator.h
@@ -12,6 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <string>
+#include <unordered_map>
+
 #include "paddle/fluid/framework/attribute.h"
 #include "paddle/fluid/framework/type_defs.h"
 #include "paddle/ir/core/attribute.h"
@@ -28,7 +31,8 @@ class AttributeVisitor;
 class AttributeTranslator {
  private:
   AttributeTranslator();
-  AttributeVisitor* visitor;
+  AttributeVisitor* general_visitor;
+  std::unordered_map<std::string, AttributeVisitor*> special_visitors;
 
  public:
   AttributeTranslator(const AttributeTranslator&) = delete;
@@ -41,7 +45,9 @@ class AttributeTranslator {
     return attribute_translator;
   }
 
-  ir::Attribute operator[](const framework::Attribute& attr);
+  ir::Attribute operator()(const framework::Attribute& attr);
+  ir::Attribute operator()(const std::string& target_type,
+                           const framework::Attribute& attr);
 };
 
 }  // namespace translator
diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc
index f9268c009cd925..9cb0a3e41916c0 100644
--- a/paddle/fluid/translator/op_translator.cc
+++ b/paddle/fluid/translator/op_translator.cc
@@ -300,11 +300,14 @@ inline ir::AttributeMap TranslateOpAttribute(
   for (const auto& info : op_attr_infos) {
     auto legacy_attr_name =
         op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
-    paddle::framework::Attribute legacy_attr =
-        op_desc.GetAttr(legacy_attr_name);
+
+    paddle::framework::Attribute legacy_attr;
+    if (op_desc.HasAttr(legacy_attr_name)) {
+      legacy_attr = op_desc.GetAttr(legacy_attr_name);
+    }
     VLOG(10) << "attribute in " << op_desc.Type()
              << " name: " << legacy_attr_name << " " << legacy_attr.index();
-    ir::Attribute new_attr = attribute_translator[legacy_attr];
+    ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr);
     attribute_map[info.name] = new_attr;
     if (!new_attr) {
       VLOG(0) << "empty attribute in " << op_desc.Type()
From 261fe9729692d066d1c1f2ee2d401e46b78efc09 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Tue, 30 May 2023 13:07:05 +0000
Subject: [PATCH 19/43] fix bug

---
 paddle/fluid/translator/op_compat_info.h | 18 +++++++++++++
 paddle/fluid/translator/op_translator.cc | 33 +-----------------------
 paddle/fluid/translator/utils.h          |  1 +
 paddle/ir/core/builtin_dialect.cc        |  2 +-
 paddle/ir/core/operation.cc              |  6 +++++
 paddle/ir/core/value.cc                  |  3 +++
 paddle/phi/api/yaml/op_compat.yaml       |  2 ++
 7 files changed, 32 insertions(+), 33 deletions(-)

diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h
index 38ae7c0ae68233..654329dbbe357e 100644
--- a/paddle/fluid/translator/op_compat_info.h
+++ b/paddle/fluid/translator/op_compat_info.h
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
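The attribute dispatch introduced by the refactor above is two-level: one general visitor supplies a default translation for every legacy attribute alternative, and per-target-type visitors (IntArray, Scalar, DataType, Place) override only the conversions that differ. A compressed, self-contained analogue of that design, with std::variant standing in for framework::Attribute and string tags standing in for the real visitor keys (a sketch, not the Paddle API):

#include <iostream>
#include <string>
#include <unordered_map>
#include <variant>

using Attr = std::variant<int, float>;  // stand-in for framework::Attribute

struct Visitor {  // general visitor: default translation per alternative
  virtual ~Visitor() = default;
  virtual std::string operator()(int i) { return "Int32(" + std::to_string(i) + ")"; }
  virtual std::string operator()(float f) { return "Float(" + std::to_string(f) + ")"; }
};

struct DataTypeVisitor : Visitor {  // special visitor: overrides one case
  std::string operator()(int i) override { return "DataType(#" + std::to_string(i) + ")"; }
};

std::string Translate(const std::string& target_type, const Attr& attr) {
  static Visitor general;
  static std::unordered_map<std::string, Visitor*> special = {
      {"DataTypeAttribute", new DataTypeVisitor}};  // leaked for brevity
  auto it = special.find(target_type);
  Visitor* v = (it == special.end()) ? &general : it->second;
  return std::visit(*v, attr);  // virtual dispatch picks the right override
}

int main() {
  std::cout << Translate("", Attr{3}) << "\n";                   // Int32(3)
  std::cout << Translate("DataTypeAttribute", Attr{3}) << "\n";  // DataType(#3)
  return 0;
}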
+#include #include #include @@ -51,6 +52,23 @@ class OpNameNormalizer { std::string GetLegacyArgName(const std::string& op_type, const std::string& arg_name) { + bool is_grad_op = (op_type.find("grad") != std::string::npos); + bool is_grad_arg = (arg_name.find("grad") != std::string::npos); + if (is_grad_op && is_grad_arg) { + std::string target = "_grad"; + std::string data = "@GRAD"; + + size_t first_grad_pos = arg_name.find_first_of(target); + std::string legacy_name = + this->GetLegacyArgName(op_type, arg_name.substr(0, first_grad_pos)); + legacy_name += arg_name.substr(first_grad_pos); + for (size_t pos = 0; + legacy_name.npos != (pos = legacy_name.find(target, pos)); + pos += data.length()) { + legacy_name.replace(pos, target.length(), data); + } + return legacy_name; + } if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { return UnderscoreToCamelCase(arg_name); } diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 9cb0a3e41916c0..a2cceaf282176b 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -214,8 +214,7 @@ inline std::vector GenerateOperationInput( "Op %s arg %s should be optional if it can be empty", op_desc.Type(), legacy_input_name); - auto* constant_op = InsertConstantOperationForOptionalArg(ctx, program); - op_inputs.push_back(constant_op->GetResultByIndex(0)); + op_inputs.push_back(ir::OpResult(nullptr)); continue; } @@ -318,36 +317,6 @@ inline ir::AttributeMap TranslateOpAttribute( } } - // for (auto attr_in_op_desc : op_desc.GetAttrMap()) { - // const auto& attr_name = attr_in_op_desc.first; - // const auto& attr_value = attr_in_op_desc.second; - // VLOG(0) << "attribute in " << op_desc.Type() << " name: " << attr_name - // << " " << attr_value.index(); - // ir::Attribute new_attr = attribute_translator[attr_value]; - // attribute_map[attr_name] = new_attr; - // if (!new_attr) { - // VLOG(0) << "empty attribute in " << op_desc.Type() - // << " name: " << attr_name; - // } else { - // VLOG(10) << "new attribute in " << op_desc.Type() - // << " name: " << attr_name << " " << new_attr.storage(); - // } - // } - - // for (auto attr_in_op_desc : op_desc.GetRuntimeAttrMap()) { - // const auto& attr_name = attr_in_op_desc.first; - // const auto& attr_value = attr_in_op_desc.second; - // ir::Attribute new_attr = attribute_translator[attr_value]; - // attribute_map[attr_name] = new_attr; - // if (!new_attr) { - // VLOG(0) << "empty runtime attribute in " << op_desc.Type() - // << " name: " << attr_name; - // } else { - // VLOG(10) << "new runtime attribute in " << op_desc.Type() - // << " name: " << attr_name << " " << new_attr.storage(); - // } - // } - return std::move(attribute_map); } diff --git a/paddle/fluid/translator/utils.h b/paddle/fluid/translator/utils.h index 5711f1f550fb14..7065f46992c6aa 100644 --- a/paddle/fluid/translator/utils.h +++ b/paddle/fluid/translator/utils.h @@ -15,6 +15,7 @@ #pragma once #include +#include namespace paddle { namespace translator { diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index 600efd4bc240a7..eecc5dc05c154b 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -44,7 +44,7 @@ void BuiltinDialect::initialize() { Int64_tAttribute, ArrayAttribute>(); - RegisterOps(); + RegisterOps(); } } // namespace ir diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 4f9575c03d349a..d6f2f4592fc42c 100644 --- 
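The grad-name logic in the op_compat_info.h hunk above is the subtle piece of this patch: a new-IR argument such as "x_grad" must become the legacy "X@GRAD", and stacked suffixes like "out_grad_grad" must become "Out@GRAD@GRAD". A standalone rendition of that rewrite (camel-casing is stubbed to a single toupper, and plain find locates the first "_grad"; illustrative only):

#include <cctype>
#include <iostream>
#include <string>

std::string ToLegacyGradName(const std::string& arg_name) {
  const std::string target = "_grad";
  const std::string data = "@GRAD";
  size_t pos = arg_name.find(target);
  if (pos == std::string::npos || pos == 0) return arg_name;
  // Stand-in for the recursive GetLegacyArgName/UnderscoreToCamelCase call.
  std::string base = arg_name.substr(0, pos);
  base[0] = static_cast<char>(std::toupper(base[0]));
  std::string legacy = base + arg_name.substr(pos);  // e.g. "X_grad"
  for (size_t p = 0; (p = legacy.find(target, p)) != std::string::npos;
       p += data.length()) {
    legacy.replace(p, target.length(), data);  // each "_grad" -> "@GRAD"
  }
  return legacy;
}

int main() {
  std::cout << ToLegacyGradName("x_grad") << "\n";         // X@GRAD
  std::cout << ToLegacyGradName("out_grad_grad") << "\n";  // Out@GRAD@GRAD
  return 0;
}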
a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -64,6 +64,7 @@ Operation *Operation::create(const std::vector &inputs, // 2. Malloc memory. char *base_ptr = reinterpret_cast(aligned_malloc(base_size, 8)); // 3.1. Construct OpResults. + VLOG(10) << "debug"; for (size_t idx = num_results; idx > 0; idx--) { if (idx > max_inline_result_num) { new (base_ptr) @@ -73,7 +74,9 @@ Operation *Operation::create(const std::vector &inputs, new (base_ptr) detail::OpInlineResultImpl(output_types[idx - 1], idx - 1); base_ptr += sizeof(detail::OpInlineResultImpl); } + VLOG(10) << "debug " << idx; } + VLOG(10) << "debug"; // 3.2. Construct Operation. Operation *op = new (base_ptr) Operation(attribute, op_info, num_results, num_operands, num_regions); @@ -82,10 +85,13 @@ Operation *Operation::create(const std::vector &inputs, if ((reinterpret_cast(base_ptr) & 0x7) != 0) { throw("The address of OpOperandImpl must be divisible by 8."); } + VLOG(10) << "debug"; for (size_t idx = 0; idx < num_operands; idx++) { new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); + VLOG(10) << "debug " << idx; base_ptr += sizeof(detail::OpOperandImpl); } + VLOG(10) << "debug"; // 3.4. Construct Regions if (num_regions > 0) { op->regions_ = reinterpret_cast(base_ptr); diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 631f0ba7adfd2a..30e920e089fd70 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -112,6 +112,9 @@ void OpOperandImpl::release_source() { source_ = nullptr; } OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner) : source_(source), owner_(owner) { + if (!source) { + return; + } prev_use_addr_ = source.impl()->first_use_addr(); next_use_ = source.impl()->first_use(); if (next_use_) { diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index a1aa11bc39d581..eea7491b0633ca 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -462,6 +462,8 @@ - op : conv2d backward : conv2d_grad + inputs : + out : Output extra : attrs : [bool is_test = false, bool use_cudnn = true, bool use_mkldnn = false, bool use_addto = false, bool force_fp32_output = false, From 6d57a5f5cbbd8e53d76dd2ce495cde6fe7cd6774 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 31 May 2023 08:17:01 +0000 Subject: [PATCH 20/43] refactor output translator --- paddle/fluid/dialect/op_gen.py | 29 +++++-- paddle/fluid/translator/op_translator.cc | 87 +++++++++++++++------ paddle/ir/core/operation.cc | 6 -- paddle/ir/core/printer.cc | 28 +++++-- paddle/ir/core/value.cc | 3 +- paddle/phi/api/yaml/op_compat.yaml | 1 + test/cpp/ir/core/program_translator_test.cc | 8 +- 7 files changed, 110 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index acb3b2721278f3..01914a5951f220 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -141,6 +141,14 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ }} """ +GRAD_OP_VERIFY_TEMPLATE = """ +void {op_name}::verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ + (void)inputs; + (void)outputs; + (void)attributes; +}} +""" + INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true, phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); """ @@ -703,14 +711,19 @@ def OpGenerator( ) # generate op verify function - op_verify_str = OP_VERIFY_TEMPLATE.format( - 
op_name=op_class_name, - inputs_size=len(op_input_type_list), - outputs_size=len(op_output_type_list), - inputs_type_check=inputs_type_check_str, - outputs_type_check=outputs_type_check_str, - attributes_check=attributes_check_str, - ) + if "Grad" in op_class_name: + op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + ) + else: + op_verify_str = OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + inputs_size=len(op_input_type_list), + outputs_size=len(op_output_type_list), + inputs_type_check=inputs_type_check_str, + outputs_type_check=outputs_type_check_str, + attributes_check=attributes_check_str, + ) ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index a2cceaf282176b..e7d8112577c7ba 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -207,7 +207,7 @@ inline std::vector GenerateOperationInput( std::string legacy_input_name = op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); - // return empty type if this arg is optional and not shown in OpDesc + // return empty OpResult if this arg is optional and not shown in OpDesc // TODO(lyk): HasInput doesnot consider variadic attribute if (!op_desc.HasInput(legacy_input_name)) { PADDLE_ENFORCE(info.optional, @@ -225,6 +225,7 @@ inline std::vector GenerateOperationInput( if (!is_vector) { auto defining_info = (*param_map)[legacy_input_vars[0]]; op_inputs.push_back(defining_info.value); + // if src type is Vector , need an additional `CombineOp` to // assemble them. } else { @@ -238,48 +239,75 @@ inline std::vector GenerateOperationInput( } inline std::tuple GenerateOperationOutput( - ir::IrContext* ctx, const OpDesc& op_desc) { + ir::IrContext* ctx, + const OpDesc& op_desc, + const OpOutputInfoList& output_infos) { OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; auto& type_translator = TypeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); const BlockDesc* block = op_desc.Block(); - for (const auto& n : op_desc.Outputs()) { - auto& name = n.first; - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name; - auto& args = n.second; + for (const auto& info : output_infos) { size_t cur_output_idx = op_output_types.size(); + std::string legacy_output_name = + op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); - // if src type is Tensor or a Vector with size <= 1 - if (args.size() <= 1) { - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name - << " " << var->GetType(); + // return empty type if this arg is optional and not shown in OpDesc + // TODO(lyk): HasOutput doesnot consider variadic attribute + if (!op_desc.HasOutput(legacy_output_name)) { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "] optional " << info.name << " :" + << info.type_name << " " << legacy_output_name; + PADDLE_ENFORCE(info.optional, + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_output_name); + op_output_types.push_back(ir::Type(nullptr)); + continue; + } - ir::Type translated_var_type = - type_translator[var->GetType()](ctx, *var); + const auto& legacy_output_vars = op_desc.Output(legacy_output_name); + bool is_vector = (info.type_name.find("VectorType") != std::string::npos); - arg_to_idx[arg_name] = 
cur_output_idx; - op_output_types.push_back(translated_var_type); + // if src type is Tensor + if (!is_vector) { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " :" + << info.type_name << " " << legacy_output_name; + if (legacy_output_vars.size() == 0) { + op_output_types.push_back(ir::Type(nullptr)); + continue; } + auto& var_name = legacy_output_vars[0]; + VarDesc* var = block->FindVarRecursive(var_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " " << var_name + << " " << var->GetType(); + + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); + + arg_to_idx[var_name] = cur_output_idx; + op_output_types.push_back(translated_var_type); + // if src type is Vector } else { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " :" + << info.type_name << " " << legacy_output_name; std::vector types; - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); + for (const auto& var_name : legacy_output_vars) { + VarDesc* var = block->FindVarRecursive(var_name); VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name + << "[" << op_desc.Type() << "]" << info.name << " " << var_name << " " << var->GetType(); ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); types.push_back(translated_var_type); - arg_to_idx[arg_name] = cur_output_idx; + arg_to_idx[var_name] = cur_output_idx; } ir::Type vec_type = ir::VectorType::get(ctx, types); op_output_types.push_back(vec_type); @@ -363,7 +391,8 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; - std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); + std::tie(op_output_types, arg_to_idx) = + GenerateOperationOutput(ctx, op_desc, output_infos); auto attribute_map = TranslateOpAttribute(op_info.name(), attr_infos, op_desc); @@ -385,11 +414,21 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, ir::Program* program, const OpDesc& op_desc) { auto op_info = LoopkUpOpInfo(ctx, op_desc); + + auto* op_info_concept = + op_info.GetInterfaceImpl(); + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos) = + op_info_concept->get_op_info_(); + std::vector op_inputs; OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; - std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); + std::tie(op_output_types, arg_to_idx) = + GenerateOperationOutput(ctx, op_desc, output_infos); ir::AttributeMap attribute_map = { {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, }; diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index d6f2f4592fc42c..4f9575c03d349a 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -64,7 +64,6 @@ Operation *Operation::create(const std::vector &inputs, // 2. Malloc memory. char *base_ptr = reinterpret_cast(aligned_malloc(base_size, 8)); // 3.1. Construct OpResults. - VLOG(10) << "debug"; for (size_t idx = num_results; idx > 0; idx--) { if (idx > max_inline_result_num) { new (base_ptr) @@ -74,9 +73,7 @@ Operation *Operation::create(const std::vector &inputs, new (base_ptr) detail::OpInlineResultImpl(output_types[idx - 1], idx - 1); base_ptr += sizeof(detail::OpInlineResultImpl); } - VLOG(10) << "debug " << idx; } - VLOG(10) << "debug"; // 3.2. 
Construct Operation. Operation *op = new (base_ptr) Operation(attribute, op_info, num_results, num_operands, num_regions); @@ -85,13 +82,10 @@ Operation *Operation::create(const std::vector &inputs, if ((reinterpret_cast(base_ptr) & 0x7) != 0) { throw("The address of OpOperandImpl must be divisible by 8."); } - VLOG(10) << "debug"; for (size_t idx = 0; idx < num_operands; idx++) { new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); - VLOG(10) << "debug " << idx; base_ptr += sizeof(detail::OpOperandImpl); } - VLOG(10) << "debug"; // 3.4. Construct Regions if (num_regions > 0) { op->regions_ = reinterpret_cast(base_ptr); diff --git a/paddle/ir/core/printer.cc b/paddle/ir/core/printer.cc index 5dc91142fb5e28..8788365b08e40f 100644 --- a/paddle/ir/core/printer.cc +++ b/paddle/ir/core/printer.cc @@ -50,6 +50,11 @@ class Printer { explicit Printer(std::ostream& os) : os(os) {} void PrintType(ir::Type type) { + if (!type) { + os << ""; + return; + } + if (type.isa()) { os << "f16"; } else if (type.isa()) { @@ -82,10 +87,6 @@ class Printer { }; void Type::print(std::ostream& os) const { - if (!*this) { - os << ""; - return; - } Printer p(os); p.PrintType(*this); } @@ -104,6 +105,10 @@ class ProgramPrinter : public Printer { } void PrintValue(ir::Value v) { + if (!v) { + os << "<>"; + return; + } const void* key = static_cast(v.impl()); auto ret = aliases.find(key); if (ret != aliases.end()) { @@ -175,8 +180,12 @@ class ProgramPrinter : public Printer { std::vector op_operand_types; op_operand_types.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { - op_operand_types.push_back( - op->GetOperandByIndex(idx).impl()->source().type()); + auto op_operand = op->GetOperandByIndex(idx); + if (op_operand) { + op_operand_types.push_back(op_operand.impl()->source().type()); + } else { + op_operand_types.push_back(ir::Type(nullptr)); + } } PrintInterleave( op_operand_types.begin(), @@ -190,7 +199,12 @@ class ProgramPrinter : public Printer { std::vector op_result_types; op_result_types.reserve(num_op_result); for (size_t idx = 0; idx < num_op_result; idx++) { - op_result_types.push_back(op->GetResultByIndex(idx).type()); + auto op_result = op->GetResultByIndex(idx); + if (op_result) { + op_result_types.push_back(op_result.type()); + } else { + op_result_types.push_back(ir::Type(nullptr)); + } } PrintInterleave( op_result_types.begin(), diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 30e920e089fd70..022d1cc105a575 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -42,7 +42,7 @@ bool OpOperand::operator!=(OpOperand other) const { bool OpOperand::operator!() const { return impl_ == nullptr; } -OpOperand::operator bool() const { return impl_; } +OpOperand::operator bool() const { return impl_ && impl_->source(); } detail::OpOperandImpl *OpOperand::impl() const { return impl_; } @@ -124,6 +124,7 @@ OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner) } void OpOperandImpl::remove_from_ud_chain() { + if (!source_) return; if (!prev_use_addr_) return; if (prev_use_addr_ == source_.impl()->first_use_addr()) { /// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index eea7491b0633ca..ad21e9a5c74ac5 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1931,6 +1931,7 @@ x : X outputs: out : Out + xshape: XShape int_array: shape : data_type : int diff --git 
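The printer hunks above rely on a PrintInterleave helper that takes a per-element print callback and a separator callback; only its call pattern is visible in this series, so the following is its assumed shape (a sketch, not the actual Paddle implementation):

#include <iostream>
#include <vector>

// Print each element, emitting the separator between consecutive elements.
template <typename ForwardIterator, typename PrintFunctor, typename BetweenFunctor>
void PrintInterleave(ForwardIterator begin,
                     ForwardIterator end,
                     PrintFunctor print_func,
                     BetweenFunctor between_func) {
  if (begin == end) return;
  print_func(*begin);
  for (++begin; begin != end; ++begin) {
    between_func();
    print_func(*begin);
  }
}

int main() {
  std::vector<int> xs = {1, 2, 3};
  PrintInterleave(
      xs.begin(),
      xs.end(),
      [](int x) { std::cout << x; },
      []() { std::cout << ", "; });
  std::cout << "\n";  // prints: 1, 2, 3
  return 0;
}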
a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index 81ac1d07ea44da..b98bedd016fca8 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -47,7 +47,6 @@ ProgramDesc load_from_file(const std::string &file_name) { } TEST(PaddleDialectTest, Translator) { - LOG(WARNING) << "TODO"; auto p = load_from_file("restnet50_main.prog"); EXPECT_EQ(p.Size(), 1u); @@ -58,10 +57,7 @@ TEST(PaddleDialectTest, Translator) { size_t op_size = program->block()->size(); // ops.size() = op size in BlockDesc + get_parameter_op + combine op - EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20); + EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); - std::ofstream ostrm( - "/home/lvyongkang/Paddle/build/log_resnet_after_translated", - std::ios::out); - ostrm << *program; + std::cout << *program << std::endl; } From 0717f1105462a9fc1a219df5ba81cc4e3a1c54f5 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 31 May 2023 08:29:03 +0000 Subject: [PATCH 21/43] fix typo --- paddle/phi/api/yaml/op_compat.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index ad21e9a5c74ac5..96e1752f6eb489 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1780,6 +1780,9 @@ backward : pool2d_grad attrs: kernel_size: ksize + extra : + attrs : [bool use_mkldnn = false, bool use_quantizer = false, + str mkldnn_data_type = "float32", bool is_test = false] - op : pool3d backward : pool3d_grad From c9d7c63a936c3aef071ce629404f67a389338d8f Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 31 May 2023 10:59:36 +0000 Subject: [PATCH 22/43] fix --- paddle/fluid/translator/op_translator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index e7d8112577c7ba..bac7ea2efb1d1a 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -345,7 +345,7 @@ inline ir::AttributeMap TranslateOpAttribute( } } - return std::move(attribute_map); + return attribute_map; } inline void RecordOpResultMapping(TranslationContext* param_map, From 2fb54d8ee98072656e4ff1df427d28777a7e8b4a Mon Sep 17 00:00:00 2001 From: kangguangli Date: Thu, 1 Jun 2023 07:07:21 +0000 Subject: [PATCH 23/43] fix approval error --- paddle/fluid/translator/op_translator.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index bac7ea2efb1d1a..6d86108f2cfb36 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -211,9 +211,10 @@ inline std::vector GenerateOperationInput( // TODO(lyk): HasInput doesnot consider variadic attribute if (!op_desc.HasInput(legacy_input_name)) { PADDLE_ENFORCE(info.optional, - "Op %s arg %s should be optional if it can be empty", - op_desc.Type(), - legacy_input_name); + platform::errors::PreconditionNotMet( + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_input_name)); op_inputs.push_back(ir::OpResult(nullptr)); continue; } @@ -262,9 +263,10 @@ inline std::tuple GenerateOperationOutput( << "[" << op_desc.Type() << "] optional " << info.name << " :" << info.type_name << " " << legacy_output_name; PADDLE_ENFORCE(info.optional, - "Op %s arg %s should be optional if 
it can be empty", - op_desc.Type(), - legacy_output_name); + platform::errors::PreconditionNotMet( + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_output_name)); op_output_types.push_back(ir::Type(nullptr)); continue; } From e945a413f81ade27ac3eab8d9e0024c242e8f983 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Thu, 1 Jun 2023 07:10:54 +0000 Subject: [PATCH 24/43] fix coverage --- test/cpp/ir/core/ir_program_test.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 8b78b739a65943..5eb6e4e11d0980 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -248,13 +248,10 @@ TEST(program_test, slice_combine_test) { ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info); program.InsertOp(op1); - // (5) Def b = GetParameterOp("b") - std::string op2_name = std::string(ir::GetParameterOp::name()); + // (5) Def b = Constant("b") + std::string op2_name = std::string(ir::ConstantOp::name()); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); - std::unordered_map op2_attribute{ - {"parameter_name", ir::StrAttribute::get(ctx, "b")}}; - ir::Operation *op2 = - ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info); + ir::Operation *op2 = ir::Operation::create({}, {}, {fp32_dtype}, op2_info); program.InsertOp(op2); // (6) Def combine_op = CombineOp("a", "b") From c9358f472a62f54bcd11433e94db732961f67596 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Thu, 1 Jun 2023 08:24:14 +0000 Subject: [PATCH 25/43] fix op_compat parser --- paddle/fluid/translator/op_compat_gen.py | 28 +++++++++++++----------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index edf85211b8cb16..5bc9df7ee8b34b 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -34,7 +34,7 @@ def OpNameNormalizerInitialization( op_compat_yaml_file: str = "", output_source_file: str = "" ) -> None: def to_phi_and_fluid_op_name(op_item): - # Templat: - op : phi_name (fluid_name) + # Template: - op : phi_name (fluid_name) names = op_item.split('(') if len(names) == 1: phi_fluid_name = names[0].strip() @@ -65,26 +65,28 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): op_arg_name_mappings[op_name].update(arg_mapping) _, legacy_name = insert_new_mappings(op_compat_item["op"]) - legacy_backward_op_name = None + legacy_backward_op_names = [] if "backward" in op_compat_item: - _, legacy_backward_op_name = insert_new_mappings( - op_compat_item["backward"] + backward_op_name_mapping_paris = op_compat_item["backward"].split( + "," ) + for pair in backward_op_name_mapping_paris: + _, legacy_backward_op_name = insert_new_mappings(pair) + legacy_backward_op_names.append(legacy_backward_op_name) + if "inputs" in op_compat_item: insert_new_arg_mappings(legacy_name, op_compat_item["inputs"]) - insert_new_arg_mappings( - legacy_backward_op_name, op_compat_item["inputs"] - ) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["inputs"]) + if "attrs" in op_compat_item: insert_new_arg_mappings(legacy_name, op_compat_item["attrs"]) - insert_new_arg_mappings( - legacy_backward_op_name, op_compat_item["attrs"] - ) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["attrs"]) if "outputs" in 
op_compat_item: insert_new_arg_mappings(legacy_name, op_compat_item["outputs"]) - insert_new_arg_mappings( - legacy_backward_op_name, op_compat_item["outputs"] - ) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["outputs"]) # special op mappings op_name_mappings["fetch_v2"] = "fetch" From b785d35a97df8e6a9abc612b2316304e0919b7bf Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 03:01:54 +0000 Subject: [PATCH 26/43] fix merge conflicts --- paddle/fluid/translator/op_translator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index e8e1f10865a67b..205dc88d3daaa2 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -166,7 +166,7 @@ inline ir::Operation* InsertConstantOperationForOptionalArg( ir::Type null_type = ir::Type(nullptr); ir::Operation* operation = ir::Operation::create({}, {}, {null_type}, op_info); - program->InsertOp(operation); + program->block()->push_back(operation); return operation; } From a0f21e6677705103fad6da4eae52bf0e83241a49 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 03:21:25 +0000 Subject: [PATCH 27/43] fix merge conflicts --- paddle/fluid/translator/op_translator.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 205dc88d3daaa2..5dca36db1c8b59 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -385,7 +385,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, OpInputInfoList input_infos; OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos) = + std::tie(input_infos, attr_infos, output_infos, std::ignore) = op_info_concept->get_op_info_(); auto op_inputs = GenerateOperationInput( @@ -422,7 +422,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, OpInputInfoList input_infos; OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos) = + std::tie(input_infos, attr_infos, output_infos, std::ignore) = op_info_concept->get_op_info_(); std::vector op_inputs; @@ -454,7 +454,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, OpInputInfoList input_infos; OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos) = + std::tie(input_infos, attr_infos, output_infos, std::ignore) = op_info_concept->get_op_info_(); auto op_inputs = GenerateOperationInput( From c2ede6faaca4de8fba7b11ffc1d69671e9375c23 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 06:21:31 +0000 Subject: [PATCH 28/43] fix merge conflicts --- paddle/fluid/dialect/op_gen.py | 2 +- paddle/fluid/dialect/pd_interface.h | 5 +---- paddle/ir/core/builtin_dialect.cc | 7 ++++++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 0e085734f42fbf..0d8c4d336f1329 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -1186,7 +1186,7 @@ def OpGenerator( ) # generate op verify function - if "Grad" in op_class_name: + if "GradOp" in op_class_name or "Grad_Op" in op_class_name: op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format( op_name=op_class_name, ) diff --git a/paddle/fluid/dialect/pd_interface.h b/paddle/fluid/dialect/pd_interface.h 
index 9ca4871e346747..f49ba35fcc3e77 100644 --- a/paddle/fluid/dialect/pd_interface.h +++ b/paddle/fluid/dialect/pd_interface.h @@ -36,10 +36,7 @@ class GetOpInfoInterface : public ir::OpInterfaceBase { struct Model : public Concept { static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); } - Model() : Concept(GetOpInfo) { - static_assert(sizeof(Model) == sizeof(Concept), - "sizeof(Model) != sizeof(Concept)"); - } + Model() : Concept(GetOpInfo) {} }; GetOpInfoInterface(ir::Operation *op, Concept *impl) diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index f3e1dbe77fe723..f3a5183cd6ca61 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -45,7 +45,12 @@ void BuiltinDialect::initialize() { Int64_tAttribute, ArrayAttribute>(); - RegisterOps(); + RegisterOps(); } } // namespace ir From 7f97345c4e7c853c13d92da7eb330a42c5cde4a3 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 06:31:44 +0000 Subject: [PATCH 29/43] fix merge conflicts --- paddle/ir/core/value.cc | 2 -- paddle/ir/core/value.h | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 3a4ddb8b9cd4b2..92537a17a9dd08 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -38,8 +38,6 @@ Value OpOperand::source() const { return impl_->source(); } Operation *OpOperand::owner() const { return impl_->owner(); } -OpOperand::operator bool() const { return impl_ && impl_->source(); } - // Value Value::Value(const detail::ValueImpl *impl) : impl_(const_cast(impl)) {} diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index 795aa6f5a80663..ee2221c1bebc5d 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -49,7 +49,7 @@ class OpOperand { bool operator!() const { return impl_ == nullptr; } - operator bool() const { return impl_; } + operator bool() const { return impl_ && impl_->source(); } OpOperand next_use() const; From 5f81713280bd5b3d10a9d01328306cc2072a7649 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 2 Jun 2023 06:35:02 +0000 Subject: [PATCH 30/43] fix merge conflicts --- paddle/ir/core/value.cc | 1 + paddle/ir/core/value.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 92537a17a9dd08..d9032e8ae90a63 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -31,6 +31,7 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { impl_ = const_cast(impl); return *this; } +OpOperand::operator bool() const { return impl_ && impl_->source(); } OpOperand OpOperand::next_use() const { return impl_->next_use(); } diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index ee2221c1bebc5d..0825dd06cbd9bf 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -49,7 +49,7 @@ class OpOperand { bool operator!() const { return impl_ == nullptr; } - operator bool() const { return impl_ && impl_->source(); } + operator bool() const; OpOperand next_use() const; From f7bd3ab27bd70b99954bb3513942636597efe340 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 5 Jun 2023 03:24:33 +0000 Subject: [PATCH 31/43] refactor scalar attribute --- paddle/fluid/ir/dialect/pd_attribute.cc | 4 -- paddle/fluid/ir/dialect/pd_attribute.h | 13 ++++--- .../fluid/ir/dialect/pd_attribute_storage.h | 23 ----------- paddle/fluid/ir/dialect/pd_dialect.cc | 1 - .../translator/attribute_translator.cc | 27 +------------ test/cpp/ir/core/CMakeLists.txt | 
9 +++++ test/cpp/ir/core/scalar_attribute_test.cc | 38 +++++++++++++++++++ 7 files changed, 56 insertions(+), 59 deletions(-) create mode 100644 test/cpp/ir/core/scalar_attribute_test.cc diff --git a/paddle/fluid/ir/dialect/pd_attribute.cc b/paddle/fluid/ir/dialect/pd_attribute.cc index c151aaab8012f2..4829010609d557 100644 --- a/paddle/fluid/ir/dialect/pd_attribute.cc +++ b/paddle/fluid/ir/dialect/pd_attribute.cc @@ -18,10 +18,6 @@ namespace paddle { namespace dialect { phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); } -paddle::experimental::Scalar ScalarAttribute::data() const { - return storage()->GetAsKey(); -} - phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); } phi::Place PlaceAttribute::data() const { return storage()->GetAsKey(); } diff --git a/paddle/fluid/ir/dialect/pd_attribute.h b/paddle/fluid/ir/dialect/pd_attribute.h index 22825443a7d6f8..67093f965539f2 100644 --- a/paddle/fluid/ir/dialect/pd_attribute.h +++ b/paddle/fluid/ir/dialect/pd_attribute.h @@ -16,6 +16,7 @@ #include "paddle/fluid/ir/dialect/pd_attribute_storage.h" #include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/builtin_attribute.h" namespace paddle { namespace dialect { @@ -37,13 +38,13 @@ class ScalarAttribute : public ir::Attribute { public: using Attribute::Attribute; - DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ScalarAttribute, ScalarAttributeStorage); - - bool operator<(const ScalarAttribute &right) const { - return storage() < right.storage(); + static bool classof(ir::Attribute val) { + return (val.type_id() == ir::BoolAttribute::type_id()) || + (val.type_id() == ir::FloatAttribute::type_id()) || + (val.type_id() == ir::DoubleAttribute::type_id()) || + (val.type_id() == ir::Int32_tAttribute::type_id()) || + (val.type_id() == ir::Int64_tAttribute::type_id()); } - - paddle::experimental::Scalar data() const; }; class DataTypeAttribute : public ir::Attribute { diff --git a/paddle/fluid/ir/dialect/pd_attribute_storage.h b/paddle/fluid/ir/dialect/pd_attribute_storage.h index 4bf4f3427c7f38..6270791f725d4f 100644 --- a/paddle/fluid/ir/dialect/pd_attribute_storage.h +++ b/paddle/fluid/ir/dialect/pd_attribute_storage.h @@ -20,7 +20,6 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/common/place.h" -#include "paddle/phi/common/scalar.h" namespace paddle { namespace dialect { @@ -54,28 +53,6 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage { phi::IntArray data_; }; -struct ScalarAttributeStorage : public ir::AttributeStorage { - using ParamKey = paddle::experimental::Scalar; - - explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; } - - static ScalarAttributeStorage *Construct(ParamKey key) { - return new ScalarAttributeStorage(key); - } - - static std::size_t HashValue(const ParamKey &key) { - return ir::hash_combine(std::hash()(key.ToString()), - std::hash()(key.FromTensor())); - } - - bool operator==(const ParamKey &key) const { return data_ == key; } - - ParamKey GetAsKey() const { return ParamKey(data_); } - - private: - paddle::experimental::Scalar data_; -}; - struct DataTypeAttributeStorage : public ir::AttributeStorage { using ParamKey = phi::DataType; diff --git a/paddle/fluid/ir/dialect/pd_dialect.cc b/paddle/fluid/ir/dialect/pd_dialect.cc index ccc089fda18b02..d7b4b599b55fe6 100644 --- a/paddle/fluid/ir/dialect/pd_dialect.cc +++ b/paddle/fluid/ir/dialect/pd_dialect.cc @@ -92,7 +92,6 @@ void PaddleDialect::initialize() { RegisterTypes(); RegisterAttributes(); 
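After this refactor, ScalarAttribute carries no storage of its own: isa<ScalarAttribute>() succeeds exactly when classof accepts the underlying builtin attribute, so "scalar-ness" is a predicate rather than a distinct attribute kind. A minimal usage sketch (the helper name is an illustration, not an API in the patch):

#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/ir_context.h"

// True for any of the five builtin kinds enumerated in ScalarAttribute::classof.
bool IsScalarLike(ir::Attribute attr) {
  return attr && attr.isa<paddle::dialect::ScalarAttribute>();
}

// IsScalarLike(ir::Int32_tAttribute::get(ctx, 7))   -> true  (via classof)
// IsScalarLike(ir::StrAttribute::get(ctx, "name"))  -> false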
diff --git a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc
index 67c323a42c679a..abee76c592ba75 100644
--- a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc
@@ -62,7 +62,7 @@ class AttributeVisitor {
   virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) {
     VLOG(10) << "translating scalar";
-    return paddle::dialect::ScalarAttribute::get(ctx, scalar);
+    throw(std::invalid_argument("not implemented now"));
   }
 
   virtual ir::Attribute operator()(const std::vector<std::string>& strs) {
@@ -128,12 +128,7 @@ class AttributeVisitor {
   virtual ir::Attribute operator()(
       const std::vector<paddle::experimental::Scalar>& ss) {
     VLOG(10) << "translating vector<Scalar>";
-    std::vector<ir::Attribute> attrs;
-    attrs.reserve(ss.size());
-    for (const auto& v : ss) {
-      attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v));
-    }
-    return ir::ArrayAttribute::get(ctx, attrs);
+    throw(std::invalid_argument("not implemented now"));
   }
 
   virtual ir::Attribute operator()(const paddle::blank& blank) {
@@ -164,22 +159,6 @@ class IntArrayAttributeVisitor : public AttributeVisitor {
   }
 };
 
-class ScalarAttributeVisitor : public AttributeVisitor {
- public:
-  using AttributeVisitor::AttributeVisitor;
-  ir::Attribute operator()(int i) override {
-    VLOG(10) << "translating int to Scalar";
-    phi::Scalar data(i);
-    return paddle::dialect::ScalarAttribute::get(ctx, data);
-  }
-
-  ir::Attribute operator()(float f) override {
-    VLOG(10) << "translating float to Scalar";
-    phi::Scalar data(f);
-    return paddle::dialect::ScalarAttribute::get(ctx, data);
-  }
-};
-
 class DataTypeAttributeVisitor : public AttributeVisitor {
  public:
   using AttributeVisitor::AttributeVisitor;
@@ -205,8 +184,6 @@ AttributeTranslator::AttributeTranslator() {
   general_visitor = new AttributeVisitor();
   special_visitors["paddle::dialect::IntArrayAttribute"] =
       new IntArrayAttributeVisitor();
-  special_visitors["paddle::dialect::ScalarAttribute"] =
-      new ScalarAttributeVisitor();
   special_visitors["paddle::dialect::DataTypeAttribute"] =
       new DataTypeAttributeVisitor();
   special_visitors["paddle::dialect::PlaceAttribute"] =
diff --git a/test/cpp/ir/core/CMakeLists.txt b/test/cpp/ir/core/CMakeLists.txt
index d19366e3b8ba0c..b42bdc17078d8d 100644
--- a/test/cpp/ir/core/CMakeLists.txt
+++ b/test/cpp/ir/core/CMakeLists.txt
@@ -22,6 +22,15 @@ cc_test_old(
   phi
   gtest)
 
+cc_test_old(
+  scalar_attribute_test
+  SRCS
+  scalar_attribute_test.cc
+  DEPS
+  pd_dialect
+  new_ir
+  gtest)
+
 file(
   DOWNLOAD
   https://paddle-ci.gz.bcebos.com/ir_translator_test/restnet50_main.prog
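The hunks above make the generic visitor refuse raw phi::Scalar values; the series instead moves toward mapping each Scalar onto the concrete builtin attribute for its runtime dtype (this lands later in the series as paddle::dialect::TransToIrAttribute). A condensed sketch of that mapping, covering only three of the dtypes handled in-tree; the real version also covers FLOAT64/INT64 and reports errors through PADDLE_THROW:

#include <stdexcept>

#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/phi/common/scalar.h"

// Dispatch on the Scalar's runtime dtype and build the matching builtin
// attribute, so no dedicated ScalarAttribute storage is needed.
ir::Attribute ScalarToBuiltinAttribute(const phi::Scalar& scalar,
                                       ir::IrContext* ctx) {
  switch (scalar.dtype()) {
    case phi::DataType::FLOAT32:
      return ir::FloatAttribute::get(ctx, scalar.to<float>());
    case phi::DataType::INT32:
      return ir::Int32_tAttribute::get(ctx, scalar.to<int>());
    case phi::DataType::BOOL:
      return ir::BoolAttribute::get(ctx, scalar.to<bool>());
    default:
      throw std::invalid_argument("not implemented now");
  }
}

The new test registered above exercises exactly this correspondence between builtin attributes and the ScalarAttribute view: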
diff --git a/test/cpp/ir/core/scalar_attribute_test.cc b/test/cpp/ir/core/scalar_attribute_test.cc
new file mode 100644
index 00000000000000..9be8aaf07491ba
--- /dev/null
+++ b/test/cpp/ir/core/scalar_attribute_test.cc
@@ -0,0 +1,38 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/ir/dialect/pd_attribute.h"
+#include "paddle/ir/core/attribute.h"
+#include "paddle/ir/core/builtin_attribute.h"
+#include "paddle/ir/core/builtin_dialect.h"
+#include "paddle/ir/core/dialect.h"
+#include "paddle/ir/core/ir_context.h"
+
+TEST(ScalarTest, base) {
+  using ScalarAttribute = paddle::dialect::ScalarAttribute;
+
+  ir::IrContext *ctx = ir::IrContext::Instance();
+
+  ir::Attribute bool_scalar = ir::BoolAttribute::get(ctx, false);
+  EXPECT_TRUE(bool_scalar.isa<ScalarAttribute>());
+
+  EXPECT_TRUE(bool_scalar.isa<ir::BoolAttribute>());
+  ir::BoolAttribute pure_bool = bool_scalar.dyn_cast<ir::BoolAttribute>();
+  EXPECT_TRUE(pure_bool.isa<ScalarAttribute>());
+  ScalarAttribute scalar_from_bool = bool_scalar.dyn_cast<ScalarAttribute>();
+  EXPECT_TRUE(scalar_from_bool.isa<ir::BoolAttribute>());
+  EXPECT_NO_THROW(scalar_from_bool.dyn_cast<ir::BoolAttribute>());
+}

From 4a54f065b85cba83be4962a28da78813d28afabf Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Mon, 5 Jun 2023 03:33:26 +0000
Subject: [PATCH 32/43] draft

---
 cmake/flags.cmake            |  1 +
 paddle/ir/core/ir_printer.cc | 52 +++++++++++++++++++++++++++++++++---
 2 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/cmake/flags.cmake b/cmake/flags.cmake
index 7112e20e6d2df8..fe39bab6987d16 100644
--- a/cmake/flags.cmake
+++ b/cmake/flags.cmake
@@ -150,6 +150,7 @@ if(NOT WIN32)
       -Wextra
       -Wno-unused-parameter
       -Wno-unused-function
+      -Wno-unknown-pragmas
       -Wno-error=array-bounds #Warning in Eigen, gcc 12.2
       -Wno-error=ignored-attributes # Warnings in Eigen, gcc 6.3
       -Wno-error=int-in-bool-context # Warning in Eigen gcc 7.2
diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc
index a7a962bd204294..89576892f66dc7 100644
--- a/paddle/ir/core/ir_printer.cc
+++ b/paddle/ir/core/ir_printer.cc
@@ -83,9 +83,20 @@ class BasicIRPrinter {
     }
   }
 
-  void PrintAttribute(ir::Operation* op) { os << " { ATTRIBUTE }"; }
+  void PrintAttribute(const ir::Attribute& attr) {
+    if (!attr) {
+      os << "<#AttrNull>";
+      return;
+    }
+
+    if (auto s = attr.dyn_cast<ir::StrAttribute>()) {
+      os << s.data();
+    } else if (auto b = attr.dyn_cast<ir::BoolAttribute>()) {
+      os << b.data();
+    }
+  }
 
- protected:
+ public:
   std::ostream& os;
 };
 
@@ -120,7 +131,7 @@
     // TODO(lyk): add API to get operands directly
     PrintOpOperands(op);
 
-    PrintAttribute(op);
+    PrintAttributeMap(op);
     os << " :";
 
     // PrintOpSingature
@@ -156,6 +167,25 @@
     os << new_name;
   }
 
+  /// @brief print operation
+  /// @param op
+  /// @example
+  void PrintOperation(ir::Operation* op) {
+    PrintOpResult(op);  // TODO(lyk): add API to get opresults directly
+    os << " = ";
+
+    os << "\"" << op->op_name() << "\"";
+    PrintOpOperands(op);  // TODO(lyk): add API to get operands directly
+
+    PrintAttribute(op);
+    os << " : ";
+
+    // PrintOpSingature
+    PrintOperandsType(op);
+    os << " -> ";
+    PrintOpReturnType(op);  // TODO(lyk): add API to get opresults directly
+  }
+
   void PrintOpResult(ir::Operation* op) {
     os << " (";
     auto num_op_result = op->num_results();
@@ -172,6 +202,22 @@
     os << ")";
   }
 
+  void PrintAttributeMap(ir::Operation* op) {
+    os << "{";
+
+    PrintInterleave(
+        op->attributes().begin(),
+        op->attributes().end(),
+        [this](std::pair<std::string, ir::Attribute> it) {
+          this->os << it.first;
+          this->os << ":";
+          this->PrintAttribute(it.second);
+        },
+        [this]() { this->os << ","; });
+
+    os << "}";
+  }
+
   void PrintOpOperands(ir::Operation* op) {
     os << " (";
     auto num_op_operands = op->num_operands();
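PrintAttributeMap leans on a PrintInterleave helper that is not shown in this diff. Its behavior can be sketched as follows; this is an illustrative stand-in that assumes the real helper has this begin/end/print/between shape:

// Print each element with print_func, emitting between_func (e.g. a comma)
// only between consecutive elements, never before the first or after the last.
template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
void PrintInterleave(ForwardIterator begin,
                     ForwardIterator end,
                     UnaryFunctor print_func,
                     NullFunctor between_func) {
  if (begin == end) return;
  print_func(*begin);
  for (++begin; begin != end; ++begin) {
    between_func();
    print_func(*begin);
  }
}

With the map routed through this helper, an operation's attributes render as {name1:value1,name2:value2} between the operand list and the type signature.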
From 17edac41946ea98736401d5a36f1f45121c3c351 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Mon, 5 Jun 2023 03:34:44 +0000
Subject: [PATCH 33/43] fix

---
 cmake/flags.cmake | 1 -
 1 file changed, 1 deletion(-)

diff --git a/cmake/flags.cmake b/cmake/flags.cmake
index fe39bab6987d16..7112e20e6d2df8 100644
--- a/cmake/flags.cmake
+++ b/cmake/flags.cmake
@@ -150,7 +150,6 @@ if(NOT WIN32)
       -Wextra
       -Wno-unused-parameter
       -Wno-unused-function
-      -Wno-unknown-pragmas
       -Wno-error=array-bounds #Warning in Eigen, gcc 12.2
       -Wno-error=ignored-attributes # Warnings in Eigen, gcc 6.3
       -Wno-error=int-in-bool-context # Warning in Eigen gcc 7.2

From dc150e9e53692ac3694a940afe37c0e9ace90170 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Mon, 5 Jun 2023 07:11:09 +0000
Subject: [PATCH 34/43] fix op build

---
 paddle/fluid/ir/dialect/op_gen.py | 16 ++++++----------
 paddle/fluid/ir/dialect/utils.h   | 31 ++++++++++++++++++++++++++---
 2 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py
index e9fa3b9ca85b56..c0b4853bcddd7a 100644
--- a/paddle/fluid/ir/dialect/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_gen.py
@@ -303,10 +303,10 @@ def __init__(self, op_yaml_item, op_compat_item):
         self.attr_types_map = {
             'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
             'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
-            'Scalar(int)': ['paddle::dialect::ScalarAttribute', 'int'],
-            'Scalar(int64_t)': ['paddle::dialect::ScalarAttribute', 'int64_t'],
-            'Scalar(float)': ['paddle::dialect::ScalarAttribute', 'float'],
-            'Scalar(dobule)': ['paddle::dialect::ScalarAttribute', 'dobule'],
+            'Scalar(int)': ['ir::Int32_tAttribute', 'int'],
+            'Scalar(int64_t)': ['ir::Int64_tAttribute', 'int64_t'],
+            'Scalar(float)': ['ir::FloatAttribute', 'float'],
+            'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'],
             'Scalar[]': [
                 'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
                 'std::vector<Scalar>',
             ],
@@ -627,7 +627,7 @@ def GenBuildInputs(op_input_name_list):
 def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list):
     INTARRAY_STR_TEMPLATE = """  ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr}));
 """
-    SCALAR_STR_TEMPLATE = """  ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::Scalar({attr}));
+    SCALAR_STR_TEMPLATE = """  ir::Attribute attr_{attr_name} = TransToIrDataType({attr}, ir::IrContext::Instance());
 """
     STR_TEMPLATE = """  ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr});
 """
@@ -682,11 +682,7 @@ def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list):
             )
         elif op_attribute_type_list[idx] == "paddle::dialect::ScalarAttribute":
-            attr_str += SCALAR_STR_TEMPLATE.format(
-                attr_name=op_attribute_name_list[idx],
-                op_attribute_type=op_attribute_type_list[idx],
-                attr=op_attribute_name_list[idx],
-            )
+            raise Exception("should have a concrete type instead of scalar")
         else:
             attr_str += STR_TEMPLATE.format(
                 attr_name=op_attribute_name_list[idx],
diff --git a/paddle/fluid/ir/dialect/utils.h b/paddle/fluid/ir/dialect/utils.h
index 26781d7346c90c..1df18610fb7c68 100644
--- a/paddle/fluid/ir/dialect/utils.h
+++ b/paddle/fluid/ir/dialect/utils.h
@@ -18,13 +18,14 @@
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/ir/dialect/pd_type_storage.h"
 #include "paddle/ir/core/builtin_type.h"
+#include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
 
 namespace paddle {
 namespace dialect {
 // TODO(zhangbo): The builtin type needs to cover all data types of
 // phi::DataType.
-inline phi::DataType TransToPhiDataType(ir::Type dtype) { +static inline phi::DataType TransToPhiDataType(ir::Type dtype) { if (dtype.isa()) { return phi::DataType::FLOAT16; } else if (dtype.isa()) { @@ -44,8 +45,8 @@ inline phi::DataType TransToPhiDataType(ir::Type dtype) { } } -inline ir::Type TransToIrDataType(phi::DataType dtype, - ir::IrContext *ctx = nullptr) { +static inline ir::Type TransToIrDataType(phi::DataType dtype, + ir::IrContext *ctx = nullptr) { if (ctx == nullptr) { ctx = ir::IrContext::Instance(); } @@ -70,6 +71,30 @@ inline ir::Type TransToIrDataType(phi::DataType dtype, } } +static inline ir::Type TransToIrAttribute(phi::Scalar scalar, + ir::IrContext *ctx = nullptr) { + if (ctx == nullptr) { + ctx = ir::IrContext::Instance(); + } + switch (scalar.dtype()) { + case phi::DataType::FLOAT32: + return ir::FloatAttribute::get(ctx, scalar.to()); + case phi::DataType::FLOAT64: + return ir::DoubleAttribute::get(ctx, scalar.to()); + case phi::DataType::INT32: + return ir::Int32_tAttribute::get(ctx, scalar.to()); + case phi::DataType::INT64: + return ir::Int64_tAttribute::get(ctx, scalar.to()); + case phi::DataType::BOOL: + return ir::BoolAttribute::get(ctx, scalar.to()); + default: + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported phi data type `%s` when casting it into " + "ir attribute.", + dtype)); + } +} + // inline phi::DataLayout TransToPhiDataLayout( // DenseTensorTypeStorage::DataLayout data_layout) { // switch (data_layout) { From e885d45cc817c2c191e80518dbb40f74747650eb Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 5 Jun 2023 08:01:08 +0000 Subject: [PATCH 35/43] fix op build --- paddle/fluid/ir/dialect/op_gen.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py index c0b4853bcddd7a..c048aea9da7b76 100644 --- a/paddle/fluid/ir/dialect/op_gen.py +++ b/paddle/fluid/ir/dialect/op_gen.py @@ -658,7 +658,6 @@ def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list): attr_size=op_attribute_name_list[idx] + ".size()", create_attribute=SCALAR_STR_TEMPLATE.format( attr_name=op_attribute_name_list[idx], - op_attribute_type=inner_attribute_type, attr=op_attribute_name_list[idx] + "[i]", ), ) @@ -682,7 +681,10 @@ def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list): ) elif op_attribute_type_list[idx] == "paddle::dialect::ScalarAttribute": - raise Exception("should have a concrete type instead of scalar") + attr_str += SCALAR_STR_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + attr=op_attribute_name_list[idx], + ) else: attr_str += STR_TEMPLATE.format( attr_name=op_attribute_name_list[idx], From 34faeb95970deabbd8ed1a58a61d524f27b1d0bb Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 5 Jun 2023 13:08:57 +0000 Subject: [PATCH 36/43] temporarily save --- paddle/fluid/ir/dialect/op_gen.py | 17 +- paddle/fluid/ir/dialect/pd_op.yaml | 2 +- paddle/fluid/ir/dialect/utils.h | 25 ++ .../ir_adaptor/translator/op_compat_info.h | 23 ++ .../ir_adaptor/translator/op_translator.cc | 61 ++++- .../fluid/translator/attribute_translator.cc | 231 ------------------ .../fluid/translator/attribute_translator.h | 54 ---- paddle/fluid/translator/utils.h | 42 ---- paddle/ir/core/ir_printer.cc | 36 ++- test/cpp/ir/core/program_translator_test.cc | 12 +- 10 files changed, 144 insertions(+), 359 deletions(-) delete mode 100644 paddle/fluid/translator/attribute_translator.cc delete mode 100644 paddle/fluid/translator/attribute_translator.h delete mode 
100644 paddle/fluid/translator/utils.h diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py index 7458b5f1912784..d7b1cfbf9734ac 100644 --- a/paddle/fluid/ir/dialect/op_gen.py +++ b/paddle/fluid/ir/dialect/op_gen.py @@ -249,6 +249,15 @@ def to_phi_and_fluid_op_name(op_item): return phi_name, fluid_name +scalar_type_maps = { + 'int': 'ir::Int32_tAttribute', + 'int64_t': 'ir::Int64_tAttribute', + 'float': 'ir::FloatAttribute', + 'dobule': 'ir::DoubleAttribute', + 'bool': 'ir::BoolAttribute', +} + + # ===================================== # Parse Op Compat From Yaml # ===================================== @@ -707,7 +716,7 @@ def GenBuildAttributes( ): INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr})); """ - SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = TransToIrDataType({attr}, ir::IrContext::Instance()); + SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = TransToIrAttribute({attr}, ir::IrContext::Instance()); """ STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr}); """ @@ -826,7 +835,7 @@ def GenBuildOutputs( }} """ CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector {name} = {name}_.owner()->dyn_cast().value().dyn_cast().data().GetData(); (void){name};\n""" - CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().value().dyn_cast().data().to<{dtype}>(); (void){name};\n""" + CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().value().dyn_cast<{ir_type}>().data(); (void){name};\n""" CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE = """ std::string {name} = {name}_.owner()->dyn_cast().value().dyn_cast().data(); (void){name};\n""" CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; @@ -867,7 +876,9 @@ def GenBuildOutputs( # scalar elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format( - name=op_mutable_attribute_name_list[idx], dtype=attr_dtype[1] + name=op_mutable_attribute_name_list[idx], + dtype=attr_dtype[1], + ir_type=scalar_type_maps[attr_dtype[1]], ) # string elif attr_dtype[0] == "ir::StrAttribute": diff --git a/paddle/fluid/ir/dialect/pd_op.yaml b/paddle/fluid/ir/dialect/pd_op.yaml index c662b1c3d0f3fe..e0ee5e68372c99 100644 --- a/paddle/fluid/ir/dialect/pd_op.yaml +++ b/paddle/fluid/ir/dialect/pd_op.yaml @@ -1,7 +1,7 @@ - name: feed inputs: [] attrs: - - {typename: str, name: name} + - {typename: str, name: str} outputs: - {typename: Tensor, name: out, optional: false, intermediate: false} no_need_buffer: null diff --git a/paddle/fluid/ir/dialect/utils.h b/paddle/fluid/ir/dialect/utils.h index ce642e98009c73..65f68db2c6a586 100644 --- a/paddle/fluid/ir/dialect/utils.h +++ b/paddle/fluid/ir/dialect/utils.h @@ -17,6 +17,7 @@ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/ir/dialect/pd_type_storage.h" +#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" @@ -71,6 +72,30 @@ static inline ir::Type TransToIrDataType(phi::DataType dtype, } } +static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar, + ir::IrContext *ctx = nullptr) { + if (ctx == nullptr) { + ctx = ir::IrContext::Instance(); + } + switch (scalar.dtype()) { + case 
phi::DataType::FLOAT32: + return ir::FloatAttribute::get(ctx, scalar.to()); + case phi::DataType::FLOAT64: + return ir::DoubleAttribute::get(ctx, scalar.to()); + case phi::DataType::INT32: + return ir::Int32_tAttribute::get(ctx, scalar.to()); + case phi::DataType::INT64: + return ir::Int64_tAttribute::get(ctx, scalar.to()); + case phi::DataType::BOOL: + return ir::BoolAttribute::get(ctx, scalar.to()); + default: + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported phi data type `%s` when casting it into " + "ir attribute.", + scalar.dtype())); + } +} + struct OpInputInfo { std::string name; std::string type_name; diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_info.h b/paddle/fluid/ir_adaptor/translator/op_compat_info.h index f2ccba28eb72d6..f1a29f711db066 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_info.h +++ b/paddle/fluid/ir_adaptor/translator/op_compat_info.h @@ -15,6 +15,7 @@ #include #include #include +#include #include "glog/logging.h" @@ -25,6 +26,8 @@ namespace paddle { namespace translator { +using MutableAttributeInfo = std::vector; + class OpNameNormalizer { private: OpNameNormalizer(); // Disallow instantiation outside of the class. @@ -32,6 +35,12 @@ class OpNameNormalizer { std::unordered_map> op_arg_name_mappings; + std::unordered_map> + op_mutable_attribute_infos; + std::unordered_map> + op_mutable_attributes; + public: OpNameNormalizer(const OpNameNormalizer&) = delete; OpNameNormalizer& operator=(const OpNameNormalizer&) = delete; @@ -50,6 +59,20 @@ class OpNameNormalizer { return op_name_mappings.at(op_type); } + bool HasMutableAttribute(const std::string& op_type) { + return (op_mutable_attributes.find(op_type) != op_mutable_attributes.end()); + } + + const std::unordered_set& GetMutableAttributes( + const std::string& op_type) { + return op_mutable_attributes.at(op_type); + } + + const MutableAttributeInfo& GetMutableAttributeInfos( + const std::string& op_type, const std::string& arg_name) { + return op_mutable_attribute_infos.at(op_type).at(arg_name); + } + std::string GetLegacyArgName(const std::string& op_type, const std::string& arg_name) { bool is_grad_op = (op_type.find("grad") != std::string::npos); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 2495e71653af95..56544ecc1f0f30 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -22,7 +22,6 @@ #include #include -#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" @@ -160,13 +159,18 @@ inline ir::Operation* InsertCombineOperationForTarget( } inline ir::Operation* InsertConstantOperationForOptionalArg( - ir::IrContext* ctx, ir::Program* program) { + ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) { std::string constant_op_name(ir::ConstantOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); - ir::Type null_type = ir::Type(nullptr); + ir::Type null_type = paddle::dialect::DenseTensorType::get( + ctx, + ir::Type(nullptr), + paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED, + {}, + 0); // TODO(lyk): to be done ir::Operation* operation = - ir::Operation::create({}, {}, {null_type}, op_info); + ir::Operation::create({}, {{"value", attr}}, {null_type}, op_info); program->block()->push_back(operation); return operation; } @@ -237,6 
+241,45 @@ inline std::vector GenerateOperationInput( } } + // TODO(lyk): need optimization + if (!op_normalizer.HasMutableAttribute(op_desc.Type())) { + return op_inputs; + } + + VLOG(10) << "[handle mutable attribute]"; + + const auto& mutable_attributes = + op_normalizer.GetMutableAttributes(op_desc.Type()); + for (const auto& attr_name : mutable_attributes) { + const auto& candidate_var_names = + op_normalizer.GetMutableAttributeInfos(op_desc.Type(), attr_name); + + VLOG(10) << "[handle mutable attribute][" << attr_name << "]"; + for (const auto& var_name : candidate_var_names) { + VLOG(10) << "[handle mutable attribute][" << attr_name << "][" << var_name + << "]"; + if (op_desc.HasInput(var_name)) { + const auto& legacy_input_vars = op_desc.Input(var_name, true); + if (legacy_input_vars.size() < 1) continue; + bool is_vector = false; // TODO(lyk): need to judge by tensor/tensors + + // if src type is Tensor + if (!is_vector) { + auto defining_info = (*param_map)[legacy_input_vars[0]]; + op_inputs.push_back(defining_info.value); + + // if src type is Vector , need an additional `CombineOp` to + // assemble them. + } else { + auto* combine_op = InsertCombineOperationForTarget( + ctx, param_map, program, legacy_input_vars); + op_inputs.push_back(combine_op->GetResultByIndex(0)); + } + break; + } + } + } + return op_inputs; } @@ -327,6 +370,13 @@ inline ir::AttributeMap TranslateOpAttribute( auto& op_normalizer = OpNameNormalizer::instance(); ir::AttributeMap attribute_map = {}; + // TODO(lyk): need optimization + VLOG(10) << "[handle mutable attribute]"; + std::unordered_set mutable_attributes = {}; + if (op_normalizer.HasMutableAttribute(op_desc.Type())) { + mutable_attributes = op_normalizer.GetMutableAttributes(op_desc.Type()); + } + for (const auto& info : op_attr_infos) { auto legacy_attr_name = op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); @@ -338,6 +388,9 @@ inline ir::AttributeMap TranslateOpAttribute( VLOG(10) << "attribute in " << op_desc.Type() << " name: " << legacy_attr_name << " " << legacy_attr.index(); ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr); + if (mutable_attributes.count(info.name) != 0) { + continue; + } attribute_map[info.name] = new_attr; if (!new_attr) { VLOG(0) << "empty attribute in " << op_desc.Type() diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc deleted file mode 100644 index 4390ba68616d56..00000000000000 --- a/paddle/fluid/translator/attribute_translator.cc +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/translator/attribute_translator.h" - -#include -#include - -#include "paddle/fluid/dialect/pd_attribute.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/int_array.h" -#include "paddle/phi/common/layout.h" -#include "paddle/phi/common/place.h" -#include "paddle/phi/common/scalar.h" -#include "paddle/utils/variant.h" - -namespace paddle { -namespace translator { - -class AttributeVisitor { - public: - ir::IrContext* ctx; - AttributeVisitor() { ctx = ir::IrContext::Instance(); } - ~AttributeVisitor() {} - - public: - virtual ir::Attribute operator()(int i) { - VLOG(10) << "translating int"; - return ir::Int32_tAttribute::get(ctx, i); - } - - virtual ir::Attribute operator()(float f) { - VLOG(10) << "translating float"; - return ir::FloatAttribute::get(ctx, f); - } - - virtual ir::Attribute operator()(bool b) { - VLOG(10) << "translating bool"; - return ir::BoolAttribute::get(ctx, b); - } - - virtual ir::Attribute operator()(double d) { - VLOG(10) << "translating double"; - return ir::DoubleAttribute::get(ctx, d); - } - - virtual ir::Attribute operator()(std::string str) { - VLOG(10) << "translating string"; - return ir::StrAttribute::get(ctx, str); - } - - virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { - VLOG(10) << "translating scalar"; - return paddle::dialect::ScalarAttribute::get(ctx, scalar); - } - - virtual ir::Attribute operator()(const std::vector& strs) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(strs.size()); - for (const auto& v : strs) { - attrs.push_back(ir::StrAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& fs) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(fs.size()); - for (const auto& v : fs) { - attrs.push_back(ir::FloatAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& is) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(is.size()); - for (const auto& v : is) { - attrs.push_back(ir::Int32_tAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& bs) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(bs.size()); - for (const auto& v : bs) { - attrs.push_back(ir::BoolAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& i64s) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(i64s.size()); - for (const auto& v : i64s) { - attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& ds) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(ds.size()); - for (const auto& v : ds) { - attrs.push_back(ir::DoubleAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()( - const std::vector& ss) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(ss.size()); - for (const auto& v : ss) { - attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const paddle::blank& blank) { - VLOG(10) << "translating paddle::blank"; - return ir::Attribute(nullptr); - } 
- - template - ir::Attribute operator()(T attr) { - VLOG(10) << "translating null type"; - return ir::Attribute(nullptr); - } -}; - -class IntArrayAttributeVisitor : public AttributeVisitor { - public: - using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(const std::vector& is) override { - VLOG(10) << "translating vector to IntArray"; - phi::IntArray data(is); - return paddle::dialect::IntArrayAttribute::get(ctx, data); - } - - ir::Attribute operator()(const std::vector& is) override { - VLOG(10) << "translating vector to IntArray"; - phi::IntArray data(is); - return paddle::dialect::IntArrayAttribute::get(ctx, data); - } -}; - -class ScalarAttributeVisitor : public AttributeVisitor { - public: - using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(int i) override { - VLOG(10) << "translating int to Scalar"; - phi::Scalar data(i); - return paddle::dialect::ScalarAttribute::get(ctx, data); - } - - ir::Attribute operator()(float f) override { - VLOG(10) << "translating float to Scalar"; - phi::Scalar data(f); - return paddle::dialect::ScalarAttribute::get(ctx, data); - } -}; - -class DataTypeAttributeVisitor : public AttributeVisitor { - public: - using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(int i) override { - VLOG(10) << "translating int to DataType: " << i; - phi::DataType data = static_cast(i); - return paddle::dialect::DataTypeAttribute::get(ctx, data); - } -}; - -class PlaceAttributeVisitor : public AttributeVisitor { - public: - using AttributeVisitor::AttributeVisitor; - - ir::Attribute operator()(const paddle::blank& blank) override { - VLOG(10) << "translating paddle::blank"; - phi::Place data(phi::AllocationType::CPU); - return paddle::dialect::PlaceAttribute::get(ctx, data); - } -}; - -AttributeTranslator::AttributeTranslator() { - general_visitor = new AttributeVisitor(); - special_visitors["paddle::dialect::IntArrayAttribute"] = - new IntArrayAttributeVisitor(); - special_visitors["paddle::dialect::ScalarAttribute"] = - new ScalarAttributeVisitor(); - special_visitors["paddle::dialect::DataTypeAttribute"] = - new DataTypeAttributeVisitor(); - special_visitors["paddle::dialect::PlaceAttribute"] = - new PlaceAttributeVisitor(); -} - -ir::Attribute AttributeTranslator::operator()( - const framework::Attribute& attr) { - return paddle::visit(*general_visitor, attr); -} - -ir::Attribute AttributeTranslator::operator()( - const std::string& target_type, const framework::Attribute& attr) { - if (special_visitors.find(target_type) == special_visitors.end()) { - VLOG(10) << "[" << target_type << "] not found"; - return paddle::visit(*general_visitor, attr); - } - return paddle::visit(*(special_visitors.at(target_type)), attr); -} - -} // namespace translator -} // namespace paddle diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h deleted file mode 100644 index ea509c7e346736..00000000000000 --- a/paddle/fluid/translator/attribute_translator.h +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/fluid/framework/attribute.h" -#include "paddle/fluid/framework/type_defs.h" -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/ir_context.h" - -#pragma once - -namespace paddle { -namespace translator { - -class AttributeVisitor; - -class AttributeTranslator { - private: - AttributeTranslator(); - AttributeVisitor* general_visitor; - std::unordered_map special_visitors; - - public: - AttributeTranslator(const AttributeTranslator&) = delete; - AttributeTranslator& operator=(const AttributeTranslator&) = delete; - AttributeTranslator(AttributeTranslator&&) = delete; - AttributeTranslator& operator=(AttributeTranslator&&) = delete; - - static auto& instance() { - static AttributeTranslator attribute_translator; - return attribute_translator; - } - - ir::Attribute operator()(const framework::Attribute& attr); - ir::Attribute operator()(const std::string& target_type, - const framework::Attribute& attr); -}; - -} // namespace translator -} // namespace paddle diff --git a/paddle/fluid/translator/utils.h b/paddle/fluid/translator/utils.h deleted file mode 100644 index 7065f46992c6aa..00000000000000 --- a/paddle/fluid/translator/utils.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include - -namespace paddle { -namespace translator { - -static std::string UnderscoreToCamelCase(std::string str) { - std::string camel_case; - bool next_upper = true; - for (char c : str) { - if (c == '_') { - next_upper = true; - } else { - if (next_upper) { - camel_case += toupper(c); - next_upper = false; - } else { - camel_case += c; - } - } - } - return camel_case; -} - -} // namespace translator -} // namespace paddle diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 89576892f66dc7..40c4a0c1224f6e 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -93,6 +93,23 @@ class BasicIRPrinter { os << s.data(); } else if (auto b = attr.dyn_cast()) { os << b.data(); + } else if (auto f = attr.dyn_cast()) { + os << f.data(); + } else if (auto d = attr.dyn_cast()) { + os << d.data(); + } else if (auto i = attr.dyn_cast()) { + os << i.data(); + } else if (auto i = attr.dyn_cast()) { + os << i.data(); + } else if (auto arr = attr.dyn_cast()) { + const auto& vec = arr.data(); + PrintInterleave( + vec.begin(), + vec.end(), + [this](ir::Attribute v) { this->PrintAttribute(v); }, + [this]() { this->os << ", "; }); + } else { + os << "<#CustomTODO>"; } } @@ -167,25 +184,6 @@ class IRPrinter : public BasicIRPrinter { os << new_name; } - /// @brief print operation - /// @param op - /// @example - void PrintOperation(ir::Operation* op) { - PrintOpResult(op); // TODO(lyk): add API to get opresults directly - os << " = "; - - os << "\"" << op->op_name() << "\""; - PrintOpOperands(op); // TODO(lyk): add API to get operands directly - - PrintAttribute(op); - os << " : "; - - // PrintOpSingature - PrintOperandsType(op); - os << " -> "; - PrintOpReturnType(op); // TODO(lyk): add API to get opresults directly - } - void PrintOpResult(ir::Operation* op) { os << " ("; auto num_op_result = op->num_results(); diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index 811fcf4c8b68f9..65815a237f08c5 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -48,16 +48,18 @@ ProgramDesc load_from_file(const std::string &file_name) { TEST(PaddleDialectTest, Translator) { auto p = load_from_file("restnet50_main.prog"); + // auto p = + // load_from_file("/home/lvyongkang/Paddle/build/test/cpp/lm_main_program"); EXPECT_EQ(p.Size(), 1u); ir::IrContext *ctx = ir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - // auto program = paddle::TranslateLegacyProgramToProgram(p); + auto program = paddle::TranslateLegacyProgramToProgram(p); - // size_t op_size = program->block()->size(); - // // ops.size() = op size in BlockDesc + get_parameter_op + combine op - // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); + size_t op_size = program->block()->size(); + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); - // program->Print(std::cout); + program->Print(std::cout); } From 63c7dd70c5f0b56a97fc1185ecae6c882e42bc54 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 6 Jun 2023 11:08:01 +0000 Subject: [PATCH 37/43] adpat mutable attribute --- paddle/fluid/ir/dialect/op_gen.py | 22 ++- paddle/fluid/ir/dialect/pd_op.yaml | 2 +- .../ir_adaptor/translator/op_compat_info.h | 5 +- .../ir_adaptor/translator/op_translator.cc | 174 ++++++++++++------ test/cpp/ir/core/program_translator_test.cc | 13 +- 5 
files changed, 145 insertions(+), 71 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py index d7b1cfbf9734ac..c5bd1cc16f1ddf 100644 --- a/paddle/fluid/ir/dialect/op_gen.py +++ b/paddle/fluid/ir/dialect/op_gen.py @@ -1193,8 +1193,8 @@ def OpGenerator( # generate get op info funciton: inputs inputs_info_str = "" + input_info_list = [] if len(op_input_name_list) > 0: - input_info_list = [] for idx in range(len(op_input_name_list)): input_info_list.append( CONSTRUCT_INPUT_INFO_TEMPLATE.format( @@ -1204,7 +1204,19 @@ def OpGenerator( no_need_buffer=op_input_no_need_buffer_list[idx], ) ) - inputs_info_str = ", ".join(input_info_list) + + # add mutable attribute as input + if len(op_mutable_attribute_name_list) > 0: + for idx in range(len(op_mutable_attribute_name_list)): + input_info_list.append( + CONSTRUCT_INPUT_INFO_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx], + typename=op_mutable_attribute_type_list[idx], + optional='false', + no_need_buffer='false', + ) + ) + inputs_info_str = ", ".join(input_info_list) # generate get op info funciton: outputs outputs_info_str = "" @@ -1223,12 +1235,16 @@ def OpGenerator( # generate get op info funciton: attributes attribute_info_str = "" + op_mutable_attribute_name_set = set(op_mutable_attribute_name_list) if len(op_attribute_name_list) > 0: attribute_info_list = [] for idx in range(len(op_attribute_name_list)): + attribute_name = op_attribute_name_list[idx] + if attribute_name in op_mutable_attribute_name_set: + continue attribute_info_list.append( CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( - name=op_attribute_name_list[idx], + name=attribute_name, typename=op_attribute_type_list[idx], data_type=op_attribute_data_type_list[idx], ) diff --git a/paddle/fluid/ir/dialect/pd_op.yaml b/paddle/fluid/ir/dialect/pd_op.yaml index e0ee5e68372c99..c662b1c3d0f3fe 100644 --- a/paddle/fluid/ir/dialect/pd_op.yaml +++ b/paddle/fluid/ir/dialect/pd_op.yaml @@ -1,7 +1,7 @@ - name: feed inputs: [] attrs: - - {typename: str, name: str} + - {typename: str, name: name} outputs: - {typename: Tensor, name: out, optional: false, intermediate: false} no_need_buffer: null diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_info.h b/paddle/fluid/ir_adaptor/translator/op_compat_info.h index f1a29f711db066..799e62c7544e3f 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_info.h +++ b/paddle/fluid/ir_adaptor/translator/op_compat_info.h @@ -63,9 +63,10 @@ class OpNameNormalizer { return (op_mutable_attributes.find(op_type) != op_mutable_attributes.end()); } - const std::unordered_set& GetMutableAttributes( + const std::unordered_set* GetMutableAttributes( const std::string& op_type) { - return op_mutable_attributes.at(op_type); + if (!HasMutableAttribute(op_type)) return nullptr; + return &op_mutable_attributes.at(op_type); } const MutableAttributeInfo& GetMutableAttributeInfos( diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 56544ecc1f0f30..3fa7930604a45a 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -23,11 +23,13 @@ #include #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" 
#include "paddle/fluid/ir_adaptor/translator/type_translator.h" +#include "paddle/ir/core/builder.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/ir_context.h" @@ -35,6 +37,10 @@ #include "paddle/ir/core/value.h" #include "paddle/phi/core/enforce.h" +// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in +// paddle/fluid/ir/dialect/CMakeLists.txt. +#include "paddle/fluid/ir/dialect/pd_op.h" + namespace paddle { namespace translator { @@ -66,8 +72,13 @@ inline bool IsInplace(const OpDesc& op_desc) { } auto input_names = op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); + if (input_names.size() == 0 || output_names.size() == 0) { + return inplace; + } std::vector name_intersection; + std::sort(input_names.begin(), input_names.end()); + std::sort(output_names.begin(), output_names.end()); std::set_intersection(input_names.begin(), input_names.end(), output_names.begin(), @@ -158,7 +169,29 @@ inline ir::Operation* InsertCombineOperationForTarget( return operation; } -inline ir::Operation* InsertConstantOperationForOptionalArg( +inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, + ir::Program* program, + ir::Attribute attr) { + float data = 0.0f; + if (attr.isa()) { + data = attr.dyn_cast().data(); + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); + } + ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block()); + paddle::dialect::FullOp full_op = builder.create( + std::vector{1}, data, phi::DataType::FLOAT32, phi::CPUPlace()); + + return full_op.operation(); +} + +inline ir::Operation* InsertFullArrayOperationForAttributeInput( ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) { std::string constant_op_name(ir::ConstantOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); @@ -166,8 +199,9 @@ inline ir::Operation* InsertConstantOperationForOptionalArg( ir::Type null_type = paddle::dialect::DenseTensorType::get( ctx, ir::Type(nullptr), + phi::DDim{}, paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED, - {}, + phi::LoD{}, 0); // TODO(lyk): to be done ir::Operation* operation = ir::Operation::create({}, {{"value", attr}}, {null_type}, op_info); @@ -175,6 +209,41 @@ inline ir::Operation* InsertConstantOperationForOptionalArg( return operation; } +inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx, + ir::Program* program, + const OpDesc& op_desc, + const OpInputInfo& input_info) { + auto& attribute_translator = AttributeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); + + auto legacy_attr_name = + op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name); + + if (!op_desc.HasAttr(legacy_attr_name)) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Op %s arg %s should not be zero size", + op_desc.Type(), + legacy_attr_name)); + } + paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name); + VLOG(10) << "attribute in " << op_desc.Type() << " name: " << legacy_attr_name + << " " << legacy_attr.index(); + ir::Attribute new_attr = + attribute_translator(input_info.type_name, legacy_attr); + + ir::Operation* defining_op = nullptr; + bool is_int_array = (input_info.type_name.find("IntArrayAttribute") != 
+ input_info.type_name.npos); + if (is_int_array) { + defining_op = + InsertFullArrayOperationForAttributeInput(ctx, program, new_attr); + } else { + defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr); + } + + return defining_op->GetResultByIndex(0); +} + inline std::vector GenerateOperationInput( ir::IrContext* ctx, TranslationContext* param_map, @@ -207,25 +276,59 @@ inline std::vector GenerateOperationInput( std::vector op_inputs; auto& op_normalizer = OpNameNormalizer::instance(); + const auto* mutable_attributes = + op_normalizer.GetMutableAttributes(op_desc.Type()); for (const auto& info : input_infos) { std::string legacy_input_name = op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); + VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " " + << legacy_input_name; + + std::vector legacy_input_vars; // return empty OpResult if this arg is optional and not shown in OpDesc // TODO(lyk): HasInput doesnot consider variadic attribute - if (!op_desc.HasInput(legacy_input_name)) { - PADDLE_ENFORCE(info.optional, - platform::errors::PreconditionNotMet( - "Op %s arg %s should be optional if it can be empty", - op_desc.Type(), - legacy_input_name)); - op_inputs.push_back(ir::OpResult(nullptr)); - continue; + if (op_desc.HasInput(legacy_input_name)) { + legacy_input_vars = op_desc.Input(legacy_input_name, true); + } + + if (legacy_input_vars.size() == 0) { + if (info.optional) { + op_inputs.push_back(ir::OpResult(nullptr)); + continue; + } + } + + VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " " + << legacy_input_name << " " << legacy_input_vars.size(); + + if (legacy_input_vars.size() == 0 && mutable_attributes != nullptr && + mutable_attributes->count(info.name) != 0) { + const auto& candidate_var_names = + op_normalizer.GetMutableAttributeInfos(op_desc.Type(), info.name); + bool found_candidate_var = false; + for (const auto& var_name : candidate_var_names) { + VLOG(10) << "[handle mutable attribute][" << info.name << "][" + << var_name << "]"; + if (op_desc.HasInput(var_name)) { + legacy_input_vars = op_desc.Input(var_name, true); + if (legacy_input_vars.size() == 0) continue; + found_candidate_var = true; + break; + } + } + + if (!found_candidate_var) { + auto attribute_input = GetAttributeAsInput(ctx, program, op_desc, info); + op_inputs.push_back(attribute_input); + continue; + } } - const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true); bool is_vector = (info.type_name.find("VectorType") != std::string::npos); + VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " " + << is_vector << " " << info.type_name; // if src type is Tensor if (!is_vector) { @@ -241,45 +344,6 @@ inline std::vector GenerateOperationInput( } } - // TODO(lyk): need optimization - if (!op_normalizer.HasMutableAttribute(op_desc.Type())) { - return op_inputs; - } - - VLOG(10) << "[handle mutable attribute]"; - - const auto& mutable_attributes = - op_normalizer.GetMutableAttributes(op_desc.Type()); - for (const auto& attr_name : mutable_attributes) { - const auto& candidate_var_names = - op_normalizer.GetMutableAttributeInfos(op_desc.Type(), attr_name); - - VLOG(10) << "[handle mutable attribute][" << attr_name << "]"; - for (const auto& var_name : candidate_var_names) { - VLOG(10) << "[handle mutable attribute][" << attr_name << "][" << var_name - << "]"; - if (op_desc.HasInput(var_name)) { - const auto& legacy_input_vars = op_desc.Input(var_name, true); - if (legacy_input_vars.size() < 1) continue; - bool is_vector = 
false; // TODO(lyk): need to judge by tensor/tensors - - // if src type is Tensor - if (!is_vector) { - auto defining_info = (*param_map)[legacy_input_vars[0]]; - op_inputs.push_back(defining_info.value); - - // if src type is Vector , need an additional `CombineOp` to - // assemble them. - } else { - auto* combine_op = InsertCombineOperationForTarget( - ctx, param_map, program, legacy_input_vars); - op_inputs.push_back(combine_op->GetResultByIndex(0)); - } - break; - } - } - } - return op_inputs; } @@ -370,13 +434,6 @@ inline ir::AttributeMap TranslateOpAttribute( auto& op_normalizer = OpNameNormalizer::instance(); ir::AttributeMap attribute_map = {}; - // TODO(lyk): need optimization - VLOG(10) << "[handle mutable attribute]"; - std::unordered_set mutable_attributes = {}; - if (op_normalizer.HasMutableAttribute(op_desc.Type())) { - mutable_attributes = op_normalizer.GetMutableAttributes(op_desc.Type()); - } - for (const auto& info : op_attr_infos) { auto legacy_attr_name = op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); @@ -388,9 +445,6 @@ inline ir::AttributeMap TranslateOpAttribute( VLOG(10) << "attribute in " << op_desc.Type() << " name: " << legacy_attr_name << " " << legacy_attr.index(); ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr); - if (mutable_attributes.count(info.name) != 0) { - continue; - } attribute_map[info.name] = new_attr; if (!new_attr) { VLOG(0) << "empty attribute in " << op_desc.Type() diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index 65815a237f08c5..6a810015c1fb51 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -47,9 +47,10 @@ ProgramDesc load_from_file(const std::string &file_name) { } TEST(PaddleDialectTest, Translator) { - auto p = load_from_file("restnet50_main.prog"); - // auto p = - // load_from_file("/home/lvyongkang/Paddle/build/test/cpp/lm_main_program"); + // auto p = load_from_file("restnet50_main.prog"); + auto p = load_from_file( + "/home/lvyongkang/Paddle/test_program/" + "resnet50_main_no_merged_momentum.prog"); EXPECT_EQ(p.Size(), 1u); ir::IrContext *ctx = ir::IrContext::Instance(); @@ -58,8 +59,10 @@ TEST(PaddleDialectTest, Translator) { auto program = paddle::TranslateLegacyProgramToProgram(p); size_t op_size = program->block()->size(); - // ops.size() = op size in BlockDesc + get_parameter_op + combine op - EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int + // array op + full op + EXPECT_EQ(op_size, + p.Block(0).OpSize() + program->parameters_num() + 16 + 1 + 8); program->Print(std::cout); } From 85167e2c862810d8a3bcb5658795cffa8678af54 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 6 Jun 2023 12:42:58 +0000 Subject: [PATCH 38/43] refine op_comat_gen process --- .../ir_adaptor/translator/CMakeLists.txt | 4 +- .../ir_adaptor/translator/op_compat_gen.py | 37 +++++++++++++++++-- .../translator/op_compat_info.cc.j2 | 31 ++++++++++++++++ 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/CMakeLists.txt b/paddle/fluid/ir_adaptor/translator/CMakeLists.txt index c2a66c1e71318f..2f0014c69f74c4 100644 --- a/paddle/fluid/ir_adaptor/translator/CMakeLists.txt +++ b/paddle/fluid/ir_adaptor/translator/CMakeLists.txt @@ -5,12 +5,14 @@ set(PD_PROGRAM_TRANSLATOR_BINARY_DIR set(op_gen_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_gen.py) 
set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) set(op_compat_source_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc) +set(op_compat_templat_file + ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc.j2) add_custom_command( OUTPUT ${op_compat_source_file} COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_compat_yaml_file ${op_compat_yaml_file} --output_source_file ${op_compat_source_file} - DEPENDS ${op_gen_file} ${op_compat_yaml_file} + DEPENDS ${op_gen_file} ${op_compat_yaml_file} ${op_compat_templat_file} VERBATIM) file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc") diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 5bc9df7ee8b34b..9f56a324c259e1 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -14,7 +14,7 @@ import argparse from pathlib import Path -from typing import Dict +from typing import Dict, List, Set import yaml from jinja2 import Environment, FileSystemLoader, StrictUndefined @@ -46,8 +46,11 @@ def to_phi_and_fluid_op_name(op_item): with open(op_compat_yaml_file, "r") as f: op_compat_infos = yaml.safe_load(f) - op_name_mappings = {} - op_arg_name_mappings = {} + op_name_mappings: Dict[str, str] = {} + op_arg_name_mappings: Dict[str, Dict[str, str]] = {} + op_mutable_attribues: Dict[str, Set[str]] = {} + op_mutable_attribute_infos: Dict[str, Dict[str, List[str]]] = {} + for op_compat_item in op_compat_infos: def insert_new_mappings(op_name_str: str) -> str: @@ -64,6 +67,24 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): op_arg_name_mappings[op_name] = {} op_arg_name_mappings[op_name].update(arg_mapping) + def insert_new_mutable_attributes( + op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]] + ): + op_mutable_attribues[op_name] = set() + op_mutable_attribute_infos[op_name] = {} + for ( + attribute_name, + mutable_attribute_info, + ) in mutable_attribute_infos.items(): + print(attribute_name, mutable_attribute_info) + op_mutable_attribues[op_name].add(attribute_name) + op_mutable_attribute_infos[op_name][attribute_name] = [] + for k, v in mutable_attribute_info.items(): + if k == 'tensor_name' or k == 'tensors_name': + op_mutable_attribute_infos[op_name][ + attribute_name + ].append(v) + _, legacy_name = insert_new_mappings(op_compat_item["op"]) legacy_backward_op_names = [] if "backward" in op_compat_item: @@ -88,6 +109,14 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): for backward_op in legacy_backward_op_names: insert_new_arg_mappings(backward_op, op_compat_item["outputs"]) + if "int_array" in op_compat_item: + insert_new_mutable_attributes( + legacy_name, op_compat_item["int_array"] + ) + + if "scalar" in op_compat_item: + insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) + # special op mappings op_name_mappings["fetch_v2"] = "fetch" @@ -96,6 +125,8 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): op_compat_definition = op_name_normailzer_template.render( op_name_pairs=op_name_mappings, op_arg_name_pairs=op_arg_name_mappings, + op_mutable_attributes=op_mutable_attribues, + op_mutable_attribute_infos=op_mutable_attribute_infos, ) f.write(op_compat_definition) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2 b/paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2 index bfc80986a34c98..e7b7812fe61bea 100644 --- 
a/paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2 +++ b/paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2 @@ -21,6 +21,37 @@ OpNameNormalizer::OpNameNormalizer() { }, {% endfor %} }; + op_mutable_attributes = { + {% for op_name, mutable_attributes in op_mutable_attributes.items() %} + { + "{{op_name}}", + { + {% for attribute_name in mutable_attributes %} + "{{attribute_name}}", + {% endfor %} + }, + }, + {% endfor %} + }; + op_mutable_attribute_infos = { + {% for op_name, mutable_attribute_infos in op_mutable_attribute_infos.items() %} + { + "{{op_name}}", + { + {% for attribute_name, attribute_info in mutable_attribute_infos.items() %} + { + "{{attribute_name}}", + { + {% for candidate_var_name in attribute_info %} + "{{candidate_var_name}}", + {% endfor %} + }, + }, + {% endfor %} + }, + }, + {% endfor %} + }; } } // namespace translator From 623c2a3c9a10967d900ccde678266e68b6485bf6 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 7 Jun 2023 03:59:14 +0000 Subject: [PATCH 39/43] fix merge conflicts --- .../ir_adaptor/translator/op_translator.cc | 22 +------------------ test/cpp/ir/core/ir_exe_test.cc | 6 ++--- test/cpp/ir/core/phi_kernel_adaptor.h | 3 --- test/cpp/ir/core/program_translator_test.cc | 2 +- 4 files changed, 4 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 1028d447736902..6f0e01b52a94e0 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -185,7 +185,7 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, data = static_cast(attr.dyn_cast().data()); } ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block()); - paddle::dialect::FullOp full_op = builder.Create( + paddle::dialect::FullOp full_op = builder.Build( std::vector{1}, data, phi::DataType::FLOAT32, phi::CPUPlace()); return full_op.operation(); @@ -325,26 +325,6 @@ inline std::vector GenerateOperationInput( continue; } } - } - - VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " " - << legacy_input_name << " " << legacy_input_vars.size(); - - if (legacy_input_vars.size() == 0 && mutable_attributes != nullptr && - mutable_attributes->count(info.name) != 0) { - const auto& candidate_var_names = - op_normalizer.GetMutableAttributeInfos(op_desc.Type(), info.name); - bool found_candidate_var = false; - for (const auto& var_name : candidate_var_names) { - VLOG(10) << "[handle mutable attribute][" << info.name << "][" << var_name - << "]"; - if (op_desc.HasInput(var_name)) { - legacy_input_vars = op_desc.Input(var_name, true); - if (legacy_input_vars.size() == 0) continue; - found_candidate_var = true; - break; - } - } bool is_vector = (info.type_name.find("VectorType") != std::string::npos); VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " " diff --git a/test/cpp/ir/core/ir_exe_test.cc b/test/cpp/ir/core/ir_exe_test.cc index e8227982fbeaf2..abb814d2890fb4 100644 --- a/test/cpp/ir/core/ir_exe_test.cc +++ b/test/cpp/ir/core/ir_exe_test.cc @@ -74,10 +74,8 @@ TEST(program_test, program) { ctx, std::vector({2, 2})); ir::Attribute data_type = paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32); - ir::Attribute min = - paddle::dialect::ScalarAttribute::get(ctx, phi::Scalar(0.0)); - ir::Attribute max = - paddle::dialect::ScalarAttribute::get(ctx, phi::Scalar(1.0)); + ir::Attribute min = ir::FloatAttribute::get(ctx, 0.0f); + ir::Attribute max 
From 623c2a3c9a10967d900ccde678266e68b6485bf6 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Wed, 7 Jun 2023 03:59:14 +0000
Subject: [PATCH 39/43] fix merge conflicts

---
 .../ir_adaptor/translator/op_translator.cc   | 22 +------------------
 test/cpp/ir/core/ir_exe_test.cc              |  6 ++---
 test/cpp/ir/core/phi_kernel_adaptor.h        |  3 ---
 test/cpp/ir/core/program_translator_test.cc  |  2 +-
 4 files changed, 4 insertions(+), 29 deletions(-)

diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 1028d447736902..6f0e01b52a94e0 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -185,7 +185,7 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
     data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
   }
   ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
-  paddle::dialect::FullOp full_op = builder.Create<paddle::dialect::FullOp>(
+  paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
       std::vector<int64_t>{1}, data, phi::DataType::FLOAT32, phi::CPUPlace());
 
   return full_op.operation();
@@ -325,26 +325,6 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
         continue;
       }
     }
-  }
-
-  VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
-           << legacy_input_name << " " << legacy_input_vars.size();
-
-  if (legacy_input_vars.size() == 0 && mutable_attributes != nullptr &&
-      mutable_attributes->count(info.name) != 0) {
-    const auto& candidate_var_names =
-        op_normalizer.GetMutableAttributeInfos(op_desc.Type(), info.name);
-    bool found_candidate_var = false;
-    for (const auto& var_name : candidate_var_names) {
-      VLOG(10) << "[handle mutable attribute][" << info.name << "][" << var_name
-               << "]";
-      if (op_desc.HasInput(var_name)) {
-        legacy_input_vars = op_desc.Input(var_name, true);
-        if (legacy_input_vars.size() == 0) continue;
-        found_candidate_var = true;
-        break;
-      }
-    }
 
     bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
     VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
diff --git a/test/cpp/ir/core/ir_exe_test.cc b/test/cpp/ir/core/ir_exe_test.cc
index e8227982fbeaf2..abb814d2890fb4 100644
--- a/test/cpp/ir/core/ir_exe_test.cc
+++ b/test/cpp/ir/core/ir_exe_test.cc
@@ -74,10 +74,8 @@ TEST(program_test, program) {
       ctx, std::vector({2, 2}));
   ir::Attribute data_type =
       paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32);
-  ir::Attribute min =
-      paddle::dialect::ScalarAttribute::get(ctx, phi::Scalar(0.0));
-  ir::Attribute max =
-      paddle::dialect::ScalarAttribute::get(ctx, phi::Scalar(1.0));
+  ir::Attribute min = ir::FloatAttribute::get(ctx, 0.0f);
+  ir::Attribute max = ir::FloatAttribute::get(ctx, 1.0f);
   ir::Attribute seed = ir::Int32_tAttribute::get(ctx, 2);
   ir::Attribute uni_place = paddle::dialect::PlaceAttribute::get(
       ctx, phi::Place(phi::AllocationType::CPU));
diff --git a/test/cpp/ir/core/phi_kernel_adaptor.h b/test/cpp/ir/core/phi_kernel_adaptor.h
index 601e37649c5fa5..4916582dfbf0b3 100644
--- a/test/cpp/ir/core/phi_kernel_adaptor.h
+++ b/test/cpp/ir/core/phi_kernel_adaptor.h
@@ -143,9 +143,6 @@ void build_context(ir::Operation* op,
     } else if (type_name == "paddle::dialect::DataTypeAttribute") {
       ctx->EmplaceBackAttr(
           attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
-    } else if (type_name == "paddle::dialect::ScalarAttribute") {
-      ctx->EmplaceBackAttr(
-          attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
     } else if (type_name == "ir::Int32_tAttribute") {
       ctx->EmplaceBackAttr(
           attr_map[t].dyn_cast<ir::Int32_tAttribute>().data());
diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc
index 6a810015c1fb51..65bcd599bf4a66 100644
--- a/test/cpp/ir/core/program_translator_test.cc
+++ b/test/cpp/ir/core/program_translator_test.cc
@@ -62,7 +62,7 @@ TEST(PaddleDialectTest, Translator) {
   // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
   // array op + full op
   EXPECT_EQ(op_size,
-            p.Block(0).OpSize() + program->parameters_num() + 16 + 1 + 8);
+            p.Block(0).OpSize() + program->parameters_num() + 16 + 3 + 8);
 
   program->Print(std::cout);
 }
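The Create-to-Build rename in patch 39 tracks the ir::Builder API used throughout the series. A minimal usage sketch; the template argument and the int64_t shape element type are reconstructions, since the original hunk's angle-bracketed text was lost in transit:

    // Build a pd.full op at the end of the program's block, assuming the
    // FullOp::Build overload shown in the hunk above.
    ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
    paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
        std::vector<int64_t>{1},   // shape
        0.0f,                      // value
        phi::DataType::FLOAT32,    // dtype
        phi::CPUPlace());          // place
    ir::OpResult result = full_op.operation()->GetResultByIndex(0);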
From c7332359cf921a6f521c7f873b9a0c5627722d12 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Wed, 7 Jun 2023 04:02:13 +0000
Subject: [PATCH 40/43] fix merge conflicts

---
 paddle/fluid/ir_adaptor/translator/op_compat_gen.py | 1 -
 paddle/fluid/ir_adaptor/translator/op_translator.cc | 4 ++--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py
index 9f56a324c259e1..5a852754aed1ec 100644
--- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py
+++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py
@@ -76,7 +76,6 @@ def insert_new_mutable_attributes(
             attribute_name,
             mutable_attribute_info,
         ) in mutable_attribute_infos.items():
-            print(attribute_name, mutable_attribute_info)
             op_mutable_attribues[op_name].add(attribute_name)
             op_mutable_attribute_infos[op_name][attribute_name] = []
             for k, v in mutable_attribute_info.items():
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 6f0e01b52a94e0..d684aa718ae72a 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -226,8 +226,8 @@ inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
         legacy_attr_name));
   }
   paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
-  VLOG(10) << "attribute in " << op_desc.Type() << " name: " << legacy_attr_name
-           << " " << legacy_attr.index();
+  VLOG(10) << "[" << op_desc.Type() << "][attribute]"
+           << " name: " << legacy_attr_name << " " << legacy_attr.index();
 
   ir::Attribute new_attr =
       attribute_translator(input_info.type_name, legacy_attr);

From e1c16e7ca7e81f13606fabbbc0e9790e48e6d92c Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Wed, 7 Jun 2023 06:19:38 +0000
Subject: [PATCH 41/43] fix merge conflicts

---
 test/cpp/ir/core/program_translator_test.cc | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc
index 65bcd599bf4a66..d922a35d40b76f 100644
--- a/test/cpp/ir/core/program_translator_test.cc
+++ b/test/cpp/ir/core/program_translator_test.cc
@@ -47,10 +47,7 @@ ProgramDesc load_from_file(const std::string &file_name) {
 }
 
 TEST(PaddleDialectTest, Translator) {
-  // auto p = load_from_file("restnet50_main.prog");
-  auto p = load_from_file(
-      "/home/lvyongkang/Paddle/test_program/"
-      "resnet50_main_no_merged_momentum.prog");
+  auto p = load_from_file("restnet50_main.prog");
   EXPECT_EQ(p.Size(), 1u);
 
   ir::IrContext *ctx = ir::IrContext::Instance();
From 84636fe25b0ac41e2b118d48866277676017e5b4 Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Wed, 7 Jun 2023 08:23:17 +0000
Subject: [PATCH 42/43] complete dialect attribute printer and refine ir_throw

---
 paddle/fluid/ir/dialect/pd_dialect.cc        | 24 +++++++++-
 paddle/fluid/ir/dialect/pd_dialect.h         |  3 +-
 .../ir_adaptor/translator/op_translator.cc   | 38 ++++++++-------
 paddle/ir/core/attribute.h                   |  6 +++
 paddle/ir/core/dialect.h                     | 10 +++-
 paddle/ir/core/enforce.h                     | 48 +++++++++++--------
 paddle/ir/core/ir_printer.cc                 | 20 +++++++-
 paddle/ir/core/type.cc                       |  6 ---
 test/cpp/ir/core/program_translator_test.cc  |  2 +-
 9 files changed, 108 insertions(+), 49 deletions(-)

diff --git a/paddle/fluid/ir/dialect/pd_dialect.cc b/paddle/fluid/ir/dialect/pd_dialect.cc
index d7b4b599b55fe6..6c9d25e0d92f5a 100644
--- a/paddle/fluid/ir/dialect/pd_dialect.cc
+++ b/paddle/fluid/ir/dialect/pd_dialect.cc
@@ -23,6 +23,7 @@
 #include "paddle/fluid/ir/dialect/pd_type_storage.h"
 #include "paddle/fluid/ir/dialect/utils.h"
 #include "paddle/ir/core/dialect_interface.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/dense_tensor.h"
 
 namespace paddle {
@@ -107,7 +108,7 @@ void PaddleDialect::initialize() {
   RegisterInterfaces<ParameterConvertInterface>();
 }
 
-void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
+void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const {
   DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
 
   os << "tensor<";
@@ -119,5 +120,26 @@
   os << ">";
 }
 
+void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
+  if (auto int_array_attr = attr.dyn_cast<IntArrayAttribute>()) {
+    phi::IntArray data = int_array_attr.data();
+    os << "IntArray<";
+    const auto &inner_data = data.GetData();
+    for (const auto &i : inner_data) {
+      os << i;
+      os << " ";
+    }
+    os << ">";
+  } else if (auto data_type_attr = attr.dyn_cast<DataTypeAttribute>()) {
+    os << data_type_attr.data();
+  } else if (auto place_type_attr = attr.dyn_cast<PlaceAttribute>()) {
+    os << place_type_attr.data();
+  } else if (auto data_layout_attr = attr.dyn_cast<DataLayoutAttribute>()) {
+    os << data_layout_attr.data();
+  } else {
+    os << "<#AttrNotImplemented>";
+  }
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/ir/dialect/pd_dialect.h b/paddle/fluid/ir/dialect/pd_dialect.h
index 069827bedcf9a1..b8782c156d8851 100644
--- a/paddle/fluid/ir/dialect/pd_dialect.h
+++ b/paddle/fluid/ir/dialect/pd_dialect.h
@@ -39,7 +39,8 @@ class PaddleDialect : public ir::Dialect {
 
   static const char* name() { return "pd"; }
 
-  void PrintType(ir::Type type, std::ostream& os);
+  void PrintType(ir::Type type, std::ostream& os) const;
+  void PrintAttribute(ir::Attribute attr, std::ostream& os) const;
 
  private:
  void initialize();
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 4fb77c8466fde5..3e2c8117897cee 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -114,10 +114,9 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) {
            << target_op_name;
   auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
   if (!op_info) {
-    IR_THROW(platform::errors::PreconditionNotMet(
-        "Op %d should have corresponding OpInfo %d",
-        op_desc.Type(),
-        target_op_name));
+    IR_THROW("Op %s should have corresponding OpInfo %s",
+             op_desc.Type(),
+             target_op_name);
   }
 
   return op_info;
@@ -173,20 +172,26 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
                                                            ir::Program* program,
                                                            ir::Attribute attr) {
   float data = 0.0f;
+  phi::DataType dtype = phi::DataType::UNDEFINED;
   if (attr.isa<ir::FloatAttribute>()) {
     data = attr.dyn_cast<ir::FloatAttribute>().data();
+    dtype = phi::DataType::FLOAT32;
   } else if (attr.isa<ir::DoubleAttribute>()) {
     data = static_cast<float>(attr.dyn_cast<ir::DoubleAttribute>().data());
+    dtype = phi::DataType::FLOAT64;
  } else if (attr.isa<ir::Int32_tAttribute>()) {
     data = static_cast<float>(attr.dyn_cast<ir::Int32_tAttribute>().data());
+    dtype = phi::DataType::INT32;
   } else if (attr.isa<ir::Int64_tAttribute>()) {
     data = static_cast<float>(attr.dyn_cast<ir::Int64_tAttribute>().data());
+    dtype = phi::DataType::INT64;
   } else if (attr.isa<ir::BoolAttribute>()) {
     data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
+    dtype = phi::DataType::BOOL;
   }
   ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
   paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
-      std::vector<int64_t>{1}, data, phi::DataType::FLOAT32, phi::CPUPlace());
+      std::vector<int64_t>{1}, data, dtype, phi::CPUPlace());
 
   return full_op.operation();
 }
@@ -220,10 +225,9 @@ inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
       op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);
 
   if (!op_desc.HasAttr(legacy_attr_name)) {
-    IR_THROW(platform::errors::PreconditionNotMet(
-        "Op %s arg %s should not be zero size",
-        op_desc.Type(),
-        legacy_attr_name));
+    IR_THROW("Op %s arg %s should not be zero size",
+             op_desc.Type(),
+             legacy_attr_name);
   }
   paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
   VLOG(10) << "[" << op_desc.Type() << "][attribute]"
            << " name: " << legacy_attr_name << " " << legacy_attr.index();
@@ -259,11 +263,10 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
 
     for (const auto& arg_name : args) {
       IR_ENFORCE(param_map->count(arg_name) != 0,
-                 platform::errors::PreconditionNotMet(
-                     "arg %s.%s as input should be exists before prasing %s",
-                     name,
-                     arg_name,
-                     op_desc.Type()));
+                 "arg %s.%s as input should exist before parsing %s",
+                 name,
+                 arg_name,
+                 op_desc.Type());
       auto defining_info = (*param_map)[arg_name];
       if (defining_info.generated_by_vector) {
         InsertSliceOperationForTarget(
@@ -369,10 +372,9 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
               << "[" << op_desc.Type() << "] optional " << info.name << " :"
               << info.type_name << " " << legacy_output_name;
       IR_ENFORCE(info.optional,
-                 platform::errors::PreconditionNotMet(
-                     "Op %s arg %s should be optional if it can be empty",
-                     op_desc.Type(),
-                     legacy_output_name));
+                 "Op %s arg %s should be optional if it can be empty",
+                 op_desc.Type(),
+                 legacy_output_name);
       op_output_types.push_back(ir::Type(nullptr));
       continue;
     }
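The op_translator.cc hunks above drop the platform::errors wrapper because IR_THROW and IR_ENFORCE now take printf-style arguments directly (see the enforce.h hunk below). A short usage sketch mirroring the converted call sites:

    // Condition-checked variant; the message is formatted via
    // paddle::string::Sprintf only when the condition fails.
    IR_ENFORCE(param_map->count(arg_name) != 0,
               "arg %s as input should exist before parsing %s",
               arg_name,
               op_desc.Type());

    // Unconditional variant for unsupported or unreachable cases.
    IR_THROW("Op %s should have corresponding OpInfo %s",
             op_desc.Type(),
             target_op_name);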
   ///
@@ -80,6 +84,8 @@ class Attribute {
  protected:
   const Storage *storage_{nullptr};
 };
+
+std::ostream &operator<<(std::ostream &os, Attribute attr);
 }  // namespace ir
 
 namespace std {
diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h
index 3421b9d942f6dd..5cd932920056d0 100644
--- a/paddle/ir/core/dialect.h
+++ b/paddle/ir/core/dialect.h
@@ -16,8 +16,10 @@
 
 #include 
 
+#include "paddle/ir/core/attribute.h"
 #include "paddle/ir/core/attribute_base.h"
 #include "paddle/ir/core/dialect_interface.h"
+#include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/ir_context.h"
 #include "paddle/ir/core/op_base.h"
 #include "paddle/ir/core/type_base.h"
@@ -130,8 +132,12 @@ class Dialect {
     return *interface;
   }
 
-  virtual void PrintType(ir::Type type, std::ostream &os) {
-    throw std::logic_error("dialect has no registered type printing hook");
+  virtual void PrintType(ir::Type type, std::ostream &os) const {
+    IR_THROW("dialect has no registered type printing hook");
+  }
+
+  virtual void PrintAttribute(ir::Attribute attr, std::ostream &os) const {
+    IR_THROW("dialect has no registered attribute printing hook");
   }
 
  private:
diff --git a/paddle/ir/core/enforce.h b/paddle/ir/core/enforce.h
index b5c48c22a83dc9..e87ac0c41a07ee 100644
--- a/paddle/ir/core/enforce.h
+++ b/paddle/ir/core/enforce.h
@@ -17,6 +17,8 @@
 #include 
 #include 
 
+#include "paddle/utils/string/printf.h"
+
 #if !defined(_WIN32)
 #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
 #else
@@ -37,27 +39,35 @@ class IrNotMetException : public std::exception {
   std::string err_str_;
 };
 
-#define IR_THROW(...)                           \
-  do {                                          \
-    try {                                       \
-      throw ir::IrNotMetException(__VA_ARGS__); \
-    } catch (const std::exception& e) {         \
-      std::cout << e.what() << std::endl;       \
-      throw;                                    \
-    }                                           \
+#define IR_THROW(...)                                                     \
+  do {                                                                    \
+    try {                                                                 \
+      throw ir::IrNotMetException(                                        \
+          paddle::string::Sprintf("Error occurred at: %s:%d :\n%s",       \
+                                  __FILE__,                               \
+                                  __LINE__,                               \
+                                  paddle::string::Sprintf(__VA_ARGS__))); \
+    } catch (const std::exception& e) {                                   \
+      std::cout << e.what() << std::endl;                                 \
+      throw;                                                              \
+    }                                                                     \
   } while (0)
 
-#define IR_ENFORCE(COND, ...)                     \
-  do {                                            \
-    auto __cond__ = (COND);                       \
-    if (UNLIKELY(is_error(__cond__))) {           \
-      try {                                       \
-        throw ir::IrNotMetException(__VA_ARGS__); \
-      } catch (const std::exception& e) {         \
-        std::cout << e.what() << std::endl;       \
-        throw;                                    \
-      }                                           \
-    }                                             \
+#define IR_ENFORCE(COND, ...)                                                \
+  do {                                                                       \
+    auto __cond__ = (COND);                                                  \
+    if (UNLIKELY(is_error(__cond__))) {                                      \
+      try {                                                                  \
+        throw ir::IrNotMetException(                                         \
+            paddle::string::Sprintf("Error occurred at: %s:%d :\n%s",        \
+                                    __FILE__,                                \
+                                    __LINE__,                                \
+                                    paddle::string::Sprintf(__VA_ARGS__)));  \
+      } catch (const std::exception& e) {                                    \
+        std::cout << e.what() << std::endl;                                  \
+        throw;                                                               \
+      }                                                                      \
+    }                                                                        \
   } while (0)
 
 }  // namespace ir
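With PrintType and PrintAttribute now virtual const hooks on ir::Dialect, a dialect opts into custom printing by overriding them; unhandled kinds fall back to the IR_THROW above. A hedged sketch, where MyDialect and its output format are hypothetical:

    class MyDialect : public ir::Dialect {
     public:
      explicit MyDialect(ir::IrContext* ctx)
          : ir::Dialect(name(), ctx, ir::TypeId::get<MyDialect>()) {}
      static const char* name() { return "my"; }
      void PrintAttribute(ir::Attribute attr, std::ostream& os) const override {
        os << "<#MyAttr>";  // a real dialect would dispatch on the attribute kind
      }
    };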
diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc
index 40c4a0c1224f6e..f809a20253bd73 100644
--- a/paddle/ir/core/ir_printer.cc
+++ b/paddle/ir/core/ir_printer.cc
@@ -103,13 +103,16 @@ class BasicIRPrinter {
       os << i.data();
     } else if (auto arr = attr.dyn_cast<ir::ArrayAttribute>()) {
       const auto& vec = arr.data();
+      os << "array<";
       PrintInterleave(
           vec.begin(),
           vec.end(),
           [this](ir::Attribute v) { this->PrintAttribute(v); },
           [this]() { this->os << ", "; });
+      os << ">";
     } else {
-      os << "<#CustomTODO>";
+      auto& dialect = attr.dialect();
+      dialect.PrintAttribute(attr, os);
     }
   }
 
@@ -292,4 +295,19 @@ void Type::Print(std::ostream& os) const {
   printer.PrintType(*this);
 }
 
+void Attribute::Print(std::ostream& os) const {
+  BasicIRPrinter printer(os);
+  printer.PrintAttribute(*this);
+}
+
+std::ostream& operator<<(std::ostream& os, Type type) {
+  type.Print(os);
+  return os;
+}
+
+std::ostream& operator<<(std::ostream& os, Attribute attr) {
+  attr.Print(os);
+  return os;
+}
+
 }  // namespace ir
diff --git a/paddle/ir/core/type.cc b/paddle/ir/core/type.cc
index e93d9f63e8c6f5..8b1451fa76fb72 100644
--- a/paddle/ir/core/type.cc
+++ b/paddle/ir/core/type.cc
@@ -17,10 +17,4 @@
 namespace ir {
 
 IrContext* Type::ir_context() const { return dialect().ir_context(); }
-
-std::ostream& operator<<(std::ostream& os, Type type) {
-  type.Print(os);
-  return os;
-}
-
 }  // namespace ir
diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc
index d922a35d40b76f..c8824a5f7c8998 100644
--- a/test/cpp/ir/core/program_translator_test.cc
+++ b/test/cpp/ir/core/program_translator_test.cc
@@ -59,7 +59,7 @@ TEST(PaddleDialectTest, Translator) {
   // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
   // array op + full op
   EXPECT_EQ(op_size,
-            p.Block(0).OpSize() + program->parameters_num() + 16 + 3 + 8);
+            p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
 
   program->Print(std::cout);
 }
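Because patch 42 places operator<< for both Type and Attribute next to their Print methods, either kind can now be streamed directly. A small sketch using builtin types from the printer above:

    ir::IrContext* ctx = ir::IrContext::Instance();
    ir::Type f32 = ir::Float32Type::get(ctx);
    ir::Attribute seed = ir::Int32_tAttribute::get(ctx, 2);
    std::cout << f32 << " " << seed << std::endl;  // prints: f32 2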
From d731403c507af1fd863f8424be66f70b2748ab2b Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Wed, 7 Jun 2023 10:33:56 +0000
Subject: [PATCH 43/43] polish code

---
 paddle/fluid/ir/dialect/pd_dialect.cc |  14 ++--
 paddle/ir/core/dialect.h              |  14 ++--
 paddle/ir/core/ir_printer.cc          | 106 +++++++++++---------------
 paddle/ir/core/utils.h                |  14 ++++
 4 files changed, 74 insertions(+), 74 deletions(-)

diff --git a/paddle/fluid/ir/dialect/pd_dialect.cc b/paddle/fluid/ir/dialect/pd_dialect.cc
index 6c9d25e0d92f5a..b347d85d2a1cf0 100644
--- a/paddle/fluid/ir/dialect/pd_dialect.cc
+++ b/paddle/fluid/ir/dialect/pd_dialect.cc
@@ -23,6 +23,7 @@
 #include "paddle/fluid/ir/dialect/pd_type_storage.h"
 #include "paddle/fluid/ir/dialect/utils.h"
 #include "paddle/ir/core/dialect_interface.h"
+#include "paddle/ir/core/utils.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/dense_tensor.h"
 
@@ -123,13 +124,14 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
   if (auto int_array_attr = attr.dyn_cast<IntArrayAttribute>()) {
     phi::IntArray data = int_array_attr.data();
-    os << "IntArray<";
+    os << "IntArray[";
     const auto &inner_data = data.GetData();
-    for (const auto &i : inner_data) {
-      os << i;
-      os << " ";
-    }
-    os << ">";
+    ir::PrintInterleave(
+        inner_data.begin(),
+        inner_data.end(),
+        [&os](int64_t i) { os << i; },
+        [&os]() { os << ","; });
+    os << "]";
   } else if (auto data_type_attr = attr.dyn_cast<DataTypeAttribute>()) {
     os << data_type_attr.data();
   } else if (auto place_type_attr = attr.dyn_cast<PlaceAttribute>()) {
     os << place_type_attr.data();
diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h
index 5cd932920056d0..1eabc8010d670d 100644
--- a/paddle/ir/core/dialect.h
+++ b/paddle/ir/core/dialect.h
@@ -35,15 +35,15 @@ class DialectInterface;
 ///
 class Dialect {
  public:
-  Dialect(std::string name, ir::IrContext *context, ir::TypeId id);
+  Dialect(std::string name, IrContext *context, TypeId id);
 
   virtual ~Dialect();
 
   const std::string &name() const { return name_; }
 
-  ir::IrContext *ir_context() const { return context_; }
+  IrContext *ir_context() const { return context_; }
 
-  ir::TypeId id() const { return id_; }
+  TypeId id() const { return id_; }
 
   ///
   /// \brief Register all types contained in the template parameter Args.
@@ -132,11 +132,11 @@ class Dialect {
     return *interface;
   }
 
-  virtual void PrintType(ir::Type type, std::ostream &os) const {
+  virtual void PrintType(Type type, std::ostream &os) const {
     IR_THROW("dialect has no registered type printing hook");
   }
 
-  virtual void PrintAttribute(ir::Attribute attr, std::ostream &os) const {
+  virtual void PrintAttribute(Attribute attr, std::ostream &os) const {
     IR_THROW("dialect has no registered attribute printing hook");
   }
 
@@ -147,9 +147,9 @@ class Dialect {
 
   std::string name_;
 
-  ir::IrContext *context_;  // not owned
+  IrContext *context_;  // not owned
 
-  ir::TypeId id_;
+  TypeId id_;
 
   std::unordered_map<TypeId, std::unique_ptr<DialectInterface>>
       registered_interfaces_;
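The pd_dialect hunk above swaps a hand-rolled loop for ir::PrintInterleave, which this patch hoists into utils.h (see the end of the series). Its behavior in isolation:

    // Prints "2x4x8": the separator functor runs only between elements.
    std::vector<int64_t> dims = {2, 4, 8};
    ir::PrintInterleave(
        dims.begin(),
        dims.end(),
        [](int64_t d) { std::cout << d; },
        []() { std::cout << "x"; });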
diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc
index f809a20253bd73..fd0a41fbdae309 100644
--- a/paddle/ir/core/ir_printer.cc
+++ b/paddle/ir/core/ir_printer.cc
@@ -23,93 +23,79 @@
 #include "paddle/ir/core/dialect.h"
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/program.h"
+#include "paddle/ir/core/utils.h"
 #include "paddle/ir/core/value.h"
 
 namespace ir {
 
 namespace {
 constexpr char newline[] = "\n";
-
-template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
-void PrintInterleave(ForwardIterator begin,
-                     ForwardIterator end,
-                     UnaryFunctor print_func,
-                     NullFunctor between_func) {
-  if (begin == end) return;
-  print_func(*begin);
-  begin++;
-  for (; begin != end; begin++) {
-    between_func();
-    print_func(*begin);
-  }
-}
-
 }  // namespace
 
 class BasicIRPrinter {
  public:
   explicit BasicIRPrinter(std::ostream& os) : os(os) {}
 
-  void PrintType(ir::Type type) {
+  void PrintType(Type type) {
     if (!type) {
       os << "<>";
       return;
     }
 
-    if (type.isa<ir::Float16Type>()) {
+    if (type.isa<Float16Type>()) {
       os << "f16";
-    } else if (type.isa<ir::Float32Type>()) {
+    } else if (type.isa<Float32Type>()) {
       os << "f32";
-    } else if (type.isa<ir::Float64Type>()) {
+    } else if (type.isa<Float64Type>()) {
       os << "f64";
-    } else if (type.isa<ir::Int16Type>()) {
+    } else if (type.isa<Int16Type>()) {
       os << "i16";
-    } else if (type.isa<ir::Int32Type>()) {
+    } else if (type.isa<Int32Type>()) {
       os << "i32";
-    } else if (type.isa<ir::Int64Type>()) {
+    } else if (type.isa<Int64Type>()) {
       os << "i64";
-    } else if (type.isa<ir::VectorType>()) {
-      os << "vec<";
-      auto inner_types = type.dyn_cast<ir::VectorType>().data();
+    } else if (type.isa<VectorType>()) {
+      os << "vec[";
+      auto inner_types = type.dyn_cast<VectorType>().data();
       PrintInterleave(
           inner_types.begin(),
          inner_types.end(),
-          [this](ir::Type v) { this->PrintType(v); },
-          [this]() { this->os << ", "; });
-      os << ">";
+          [this](Type v) { this->PrintType(v); },
+          [this]() { this->os << ","; });
+      os << "]";
     } else {
       auto& dialect = type.dialect();
       dialect.PrintType(type, os);
     }
   }
 
-  void PrintAttribute(const ir::Attribute& attr) {
+  void PrintAttribute(const Attribute& attr) {
     if (!attr) {
       os << "<#AttrNull>";
       return;
     }
 
-    if (auto s = attr.dyn_cast<ir::StrAttribute>()) {
+    if (auto s = attr.dyn_cast<StrAttribute>()) {
       os << s.data();
-    } else if (auto b = attr.dyn_cast<ir::BoolAttribute>()) {
+    } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
       os << b.data();
-    } else if (auto f = attr.dyn_cast<ir::FloatAttribute>()) {
+    } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
      os << f.data();
-    } else if (auto d = attr.dyn_cast<ir::DoubleAttribute>()) {
+    } else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
       os << d.data();
-    } else if (auto i = attr.dyn_cast<ir::Int32_tAttribute>()) {
+    } else if (auto i = attr.dyn_cast<Int32_tAttribute>()) {
       os << i.data();
-    } else if (auto i = attr.dyn_cast<ir::Int64_tAttribute>()) {
+    } else if (auto i = attr.dyn_cast<Int64_tAttribute>()) {
       os << i.data();
-    } else if (auto arr = attr.dyn_cast<ir::ArrayAttribute>()) {
+    } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
       const auto& vec = arr.data();
-      os << "array<";
+      os << "array[";
       PrintInterleave(
           vec.begin(),
           vec.end(),
-          [this](ir::Attribute v) { this->PrintAttribute(v); },
-          [this]() { this->os << ", "; });
-      os << ">";
+          [this](Attribute v) { this->PrintAttribute(v); },
+          [this]() { this->os << ","; });
+      os << "]";
     } else {
       auto& dialect = attr.dialect();
       dialect.PrintAttribute(attr, os);
@@ -127,14 +113,12 @@ class IRPrinter : public BasicIRPrinter {
   /// @brief print program
   /// @param program
   /// @example
-  void PrintProgram(ir::Program* program) {
-    PrintOperation(program->module_op());
-  }
+  void PrintProgram(Program* program) { PrintOperation(program->module_op()); }
 
   /// @brief print operation
   /// @param op
   /// @example
-  void PrintOperation(ir::Operation* op) {
+  void PrintOperation(Operation* op) {
     for (size_t i = 0; i < op->num_regions(); ++i) {
       auto& region = op->GetRegion(i);
       for (auto it = region.begin(); it != region.end(); ++it) {
@@ -169,7 +153,7 @@ class IRPrinter : public BasicIRPrinter {
   }
 
  private:
-  void PrintValue(ir::Value v) {
+  void PrintValue(Value v) {
     if (!v) {
       os << "<>";
       return;
@@ -187,10 +171,10 @@ class IRPrinter : public BasicIRPrinter {
     os << new_name;
   }
 
-  void PrintOpResult(ir::Operation* op) {
+  void PrintOpResult(Operation* op) {
     os << " (";
     auto num_op_result = op->num_results();
-    std::vector<ir::OpResult> op_results;
+    std::vector<OpResult> op_results;
     op_results.reserve(num_op_result);
     for (size_t idx = 0; idx < num_op_result; idx++) {
       op_results.push_back(op->GetResultByIndex(idx));
@@ -198,13 +182,13 @@ class IRPrinter : public BasicIRPrinter {
     PrintInterleave(
         op_results.begin(),
         op_results.end(),
-        [this](ir::Value v) { this->PrintValue(v); },
+        [this](Value v) { this->PrintValue(v); },
         [this]() { this->os << ", "; });
     os << ")";
   }
 
-  void PrintAttributeMap(ir::Operation* op) {
-    os << "{";
+  void PrintAttributeMap(Operation* op) {
+    os << " {";
 
     PrintInterleave(
         op->attributes().begin(),
@@ -219,10 +203,10 @@ class IRPrinter : public BasicIRPrinter {
     os << "}";
   }
 
-  void PrintOpOperands(ir::Operation* op) {
+  void PrintOpOperands(Operation* op) {
     os << " (";
     auto num_op_operands = op->num_operands();
-    std::vector<ir::Value> op_operands;
+    std::vector<Value> op_operands;
     op_operands.reserve(num_op_operands);
     for (size_t idx = 0; idx < num_op_operands; idx++) {
       op_operands.push_back(op->GetOperandByIndex(idx).source());
@@ -230,48 +214,48 @@ class IRPrinter : public BasicIRPrinter {
     PrintInterleave(
         op_operands.begin(),
         op_operands.end(),
-        [this](ir::Value v) { this->PrintValue(v); },
+        [this](Value v) { this->PrintValue(v); },
         [this]() { this->os << ", "; });
     os << ")";
   }
 
-  void PrintOperandsType(ir::Operation* op) {
+  void PrintOperandsType(Operation* op) {
     auto num_op_operands = op->num_operands();
-    std::vector<ir::Type> op_operand_types;
+    std::vector<Type> op_operand_types;
     op_operand_types.reserve(num_op_operands);
     for (size_t idx = 0; idx < num_op_operands; idx++) {
       auto op_operand = op->GetOperandByIndex(idx);
       if (op_operand) {
         op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
       } else {
-        op_operand_types.push_back(ir::Type(nullptr));
+        op_operand_types.push_back(Type(nullptr));
       }
     }
     os << " (";
     PrintInterleave(
         op_operand_types.begin(),
         op_operand_types.end(),
-        [this](ir::Type t) { this->PrintType(t); },
+        [this](Type t) { this->PrintType(t); },
         [this]() { this->os << ", "; });
     os << ")";
   }
 
-  void PrintOpReturnType(ir::Operation* op) {
+  void PrintOpReturnType(Operation* op) {
     auto num_op_result = op->num_results();
-    std::vector<ir::Type> op_result_types;
+    std::vector<Type> op_result_types;
     op_result_types.reserve(num_op_result);
     for (size_t idx = 0; idx < num_op_result; idx++) {
       auto op_result = op->GetResultByIndex(idx);
       if (op_result) {
         op_result_types.push_back(op_result.type());
       } else {
-        op_result_types.push_back(ir::Type(nullptr));
+        op_result_types.push_back(Type(nullptr));
       }
     }
     PrintInterleave(
         op_result_types.begin(),
         op_result_types.end(),
-        [this](ir::Type t) { this->PrintType(t); },
+        [this](Type t) { this->PrintType(t); },
         [this]() { this->os << ", "; });
   }
diff --git a/paddle/ir/core/utils.h b/paddle/ir/core/utils.h
index f4316a7e57e446..b619bc065fef57 100644
--- a/paddle/ir/core/utils.h
+++ b/paddle/ir/core/utils.h
@@ -120,4 +120,18 @@ struct Filter {
   using Type = std::tuple<>;
 };
 
+template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
+void PrintInterleave(ForwardIterator begin,
+                     ForwardIterator end,
+                     UnaryFunctor print_func,
+                     NullFunctor between_func) {
+  if (begin == end) return;
+  print_func(*begin);
+  begin++;
+  for (; begin != end; begin++) {
+    between_func();
+    print_func(*begin);
+  }
+}
+
 }  // namespace ir
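Taken together, the series leaves program_translator_test.cc printing a translated program through the reworked printer. A hedged end-to-end sketch; the translation entry point is assumed from the translator headers touched earlier in the series and is not shown in these hunks:

    auto p = load_from_file("restnet50_main.prog");    // as in the test above
    ir::IrContext* ctx = ir::IrContext::Instance();
    ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
    auto program = paddle::TranslateLegacyProgram(p);  // assumed entry point
    program->Print(std::cout);  // exercises the new type/attribute printers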