Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Adapt startup program #54452

Merged
merged 3 commits into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions paddle/fluid/ir_adaptor/translator/op_compat_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
def insert_new_mutable_attributes(
op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]]
):
op_mutable_attribues[op_name] = set()
op_mutable_attribute_infos[op_name] = {}
if op_name not in op_mutable_attribues:
op_mutable_attribues[op_name] = set()
if op_name not in op_mutable_attribute_infos:
op_mutable_attribute_infos[op_name] = {}
for (
attribute_name,
mutable_attribute_info,
Expand Down Expand Up @@ -116,6 +118,14 @@ def insert_new_mutable_attributes(
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])

if "int_array" in op_compat_item:
insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"]
)

if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])

# special op mappings
op_name_mappings["fetch_v2"] = "fetch"

Expand Down
112 changes: 95 additions & 17 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,24 @@
#include "glog/logging.h"

#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/ir_adaptor/translator/op_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace translator {

using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc;

ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
ir::Program* program)
Expand All @@ -55,38 +59,77 @@ void ProgramTranslator::Translate() {

for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx);
ExtractParameterFromSingleBlock(block);
GetParameterForSingleBlock(block);
}

for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx);
InsertOperationToSingleBlock(block);
}

for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) {
const BlockDesc& block = legacy_program->Block(block_idx);
SetParameterFromSingleBlock(block);
}
}

void ProgramTranslator::ExtractParameterFromSingleBlock(
const BlockDesc& block) {
inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
const VarDesc* var) {
auto& type_translator = TypeTranslator::instance();
std::string get_parameter_op_name(ir::GetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};

ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::Create(
{}, op_attribute_map, {translated_var_type}, op_info);
return operation;
}

inline ir::Operation* InsertSetParamaterOp(ir::IrContext* ctx,
ir::OpResult defining_op_result,
const VarDesc* var) {
std::string set_parameter_op_name(ir::SetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};

ir::Operation* operation = ir::Operation::Create(
{defining_op_result}, op_attribute_map, {}, op_info);
return operation;
}

void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
for (auto& var : block.AllVars()) {
if (!var->Persistable()) continue;
if (param_map.count(var->Name()) != 0) continue;
if (no_cast_var_names.count(var->Name()) != 0) continue;

std::string get_parameter_op_name(ir::GetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::Create(
{}, op_attribute_map, {translated_var_type}, op_info);
program->block()->push_back(operation);
param_map[var->Name()] =
VariableDefiningInfo(operation->GetResultByIndex(0));
VLOG(10) << "[op translated][get parameter]" << operation;

program->SetParameter(var->Name(), nullptr);
parameter_name_mappings[var->Name()] = var;
}

for (auto op_desc : block.AllOps()) {
for (const auto& n : op_desc->Inputs()) {
const auto& input_var_names = n.second;
for (const auto& var_name : input_var_names) {
bool need_get_parameter_op = (parameter_name_mappings.find(var_name) !=
parameter_name_mappings.end());
need_get_parameter_op &= (parameter_visited.count(var_name) == 0);
if (need_get_parameter_op) {
ir::Operation* op =
InsertGetParamaterOp(ctx, parameter_name_mappings[var_name]);
program->block()->push_back(op);
param_map[var_name] = VariableDefiningInfo(op->GetResultByIndex(0));
VLOG(10) << "[op translated][get parameter]" << op;

program->SetParameter(var_name, nullptr);
parameter_visited.insert(var_name);
}
}
}
}
}

Expand All @@ -99,5 +142,40 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
}
}

void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
const auto& ops = block.AllOps();
for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) {
for (const auto& n : (*op_desc)->Outputs()) {
const auto& output_var_names = n.second;
for (const auto& var_name : output_var_names) {
bool need_set_parameter_op = (parameter_name_mappings.find(var_name) !=
parameter_name_mappings.end());
need_set_parameter_op &= (parameter_visited.count(var_name) == 0);
if (need_set_parameter_op) {
ir::OpResult defining_op_result = param_map[var_name].value;
ir::Operation* op = InsertSetParamaterOp(
ctx, defining_op_result, parameter_name_mappings[var_name]);

ir::Block* block = program->block();
ir::Block::iterator insert_pos = std::find(
block->begin(), block->end(), defining_op_result.owner());

IR_ENFORCE(
insert_pos != block->end(),
"Parameter %s must have corresponding its defining operation",
var_name);
insert_pos++;

block->insert(insert_pos, op);
VLOG(10) << "[op translated][set parameter]" << op;

program->SetParameter(var_name, nullptr);
parameter_visited.insert(var_name);
}
}
}
}
}

} // namespace translator
} // namespace paddle
13 changes: 9 additions & 4 deletions paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ using TranslationContext =
class ProgramTranslator {
using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc;

public:
explicit ProgramTranslator(const ProgramDesc* legacy_program,
Expand All @@ -58,10 +59,13 @@ class ProgramTranslator {
void Translate();

private:
const ProgramDesc* legacy_program;
ir::Program* program;
const ProgramDesc* legacy_program; // not owned
ir::Program* program; // not owned
ir::IrContext* ctx; // not owned

TranslationContext param_map;
ir::IrContext* ctx;
std::unordered_map<std::string, VarDesc*> parameter_name_mappings;
std::unordered_set<std::string> parameter_visited;

/// In the legacy program desc, there are two special named varibales:
/// 1. "feed", the input variable of feed op
Expand All @@ -71,8 +75,9 @@ class ProgramTranslator {
/// `ExtractParameterFromSingleBlock`
static const std::unordered_set<std::string> no_cast_var_names;

void ExtractParameterFromSingleBlock(const BlockDesc& block);
void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block);
};

} // namespace translator
Expand Down
11 changes: 8 additions & 3 deletions test/cpp/ir/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@ cc_test_old(
gtest)

file(
DOWNLOAD
https://paddle-ci.gz.bcebos.com/ir_translator_test/restnet50_main.prog
${CMAKE_CURRENT_BINARY_DIR}/restnet50_main.prog
DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_main.prog
${CMAKE_CURRENT_BINARY_DIR}/resnet50_main.prog
EXPECTED_MD5 b64c0ad3c96d99fc37d12094623ce1ad)

file(
DOWNLOAD
https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_startup.prog
${CMAKE_CURRENT_BINARY_DIR}/resnet50_startup.prog
EXPECTED_MD5 6affc5f40f0f0bb84d956919b95eaf50)

cc_test_old(
program_translator_test
SRCS
Expand Down
22 changes: 20 additions & 2 deletions test/cpp/ir/core/program_translator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ ProgramDesc load_from_file(const std::string &file_name) {
return ProgramDesc(buffer);
}

TEST(PaddleDialectTest, Translator) {
auto p = load_from_file("restnet50_main.prog");
TEST(PaddleDialectTest, MainProgram) {
auto p = load_from_file("resnet50_main.prog");
EXPECT_EQ(p.Size(), 1u);

ir::IrContext *ctx = ir::IrContext::Instance();
Expand All @@ -63,3 +63,21 @@ TEST(PaddleDialectTest, Translator) {

program->Print(std::cout);
}

TEST(PaddleDialectTest, StartupProgram) {
auto p = load_from_file("resnet50_startup.prog");
EXPECT_EQ(p.Size(), 1u);

ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p);

size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op +
// consant_op_for_uniform
// + consant_op for guassian
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 3 + 53);

program->Print(std::cout);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest delete cout.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
program->Print(std::cout);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,之后单独提PR修改。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in #54499

}