Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR]support while op translate #58098

Merged
merged 1 commit into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx,
pir::OpResult value = operation->result(idx_in_op);
bool generated_by_vector = value.type().isa<pir::VectorType>();

param_map->insert(
param_map->UpdateValue(
arg_name,
VariableDefiningInfo(
value,
Expand Down Expand Up @@ -1439,7 +1439,7 @@ pir::OpResult TranslateNumClassesForOneHot(
"%s should be existed in one_hot_v2 as input depth_tensor.",
legacy_vars[0]);
auto defining_info = param_map->at(legacy_vars[0]);
return defining_info.value;
return defining_info.value.dyn_cast<pir::OpResult>();
}

auto& attribute_translator = AttributeTranslator::instance();
Expand Down Expand Up @@ -1529,7 +1529,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
ctx, param_map, block, x_defining_info, x_name);
x_defining_info = param_map->at(x_name);
}
pir::OpResult x_value = x_defining_info.value;
pir::OpResult x_value = x_defining_info.value.dyn_cast<pir::OpResult>();
IR_ENFORCE(x_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
Expand Down Expand Up @@ -1560,7 +1560,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
ctx, param_map, block, y_defining_info, y_name);
y_defining_info = param_map->at(y_name);
}
pir::OpResult y_value = y_defining_info.value;
pir::OpResult y_value = y_defining_info.value.dyn_cast<pir::OpResult>();
IR_ENFORCE(y_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
Expand Down Expand Up @@ -1680,7 +1680,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
op_desc.Type(),
y_name);
auto y_defining_info = param_map->at(y_name);
pir::OpResult y_value = y_defining_info.value;
pir::OpResult y_value = y_defining_info.value.dyn_cast<pir::OpResult>();
IR_ENFORCE(y_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
Expand All @@ -1698,8 +1698,8 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
pir::OpResult value = operation->result(idx_in_op);
pir::Builder builder(ctx, operation->GetParent());
auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape);
param_map->insert(y_grad_var_name,
VariableDefiningInfo(reshape_op.out(), false, -1));
param_map->UpdateValue(y_grad_var_name,
VariableDefiningInfo(reshape_op.out(), false, -1));
}
};

Expand Down Expand Up @@ -1771,7 +1771,7 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber {
ctx, param_map, block, defining_info, var_name);
defining_info = param_map->at(var_name).value;
}
return defining_info.value;
return defining_info.value.dyn_cast<pir::OpResult>();
};
}
};
Expand Down Expand Up @@ -1866,8 +1866,8 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber {
auto output_var = output_vars[0];
auto fused_feedforward_op =
operation->dyn_cast<dialect::FusedFeedforwardOp>();
param_map->insert(output_var,
VariableDefiningInfo{fused_feedforward_op.out()});
param_map->UpdateValue(output_var,
VariableDefiningInfo{fused_feedforward_op.out()});
}
}
};
Expand Down
101 changes: 90 additions & 11 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {

const std::unordered_set<std::string> ProgramTranslator::unsupported_ops = {
"conditional_block_grad",
"while",
"while_grad",
};

Expand Down Expand Up @@ -255,9 +254,20 @@ size_t TranslationContext::count(const TCKey& key) const {
return values.size();
}

void TranslationContext::insert(const TCKey& key, const TCValue& value) {
void TranslationContext::PushValue(const Key& key, const Value& value) {
container_[key].push_back(value);
}
void TranslationContext::PopValue(const Key& key) {
container_[key].pop_back();
}

void TranslationContext::UpdateValue(const Key& key, const Value& value) {
auto& vec = container_[key];
if (vec.empty())
vec.push_back(value);
else
vec.back() = value;
}

ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
pir::Program* program)
Expand Down Expand Up @@ -329,6 +339,8 @@ void ProgramTranslator::TranslateBlock(
translate_completed[cond_id] = true;
}
VLOG(10) << "[op translated][conditional_block]" << if_op;
} else if (op->Type() == "while") {
TranslateWhileOperation(op, dest_block);
} else {
if (for_cond_block && op->Type() == "assign" &&
std::count(skip_cond_assign.begin(),
Expand Down Expand Up @@ -382,8 +394,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
op_inputs, attribute_map, op_output_types, op_info, 2);

for (size_t i = 0; i < output_vardescs.size(); i++) {
param_map_.insert(output_vardescs[i]->Name(),
VariableDefiningInfo(operation->result(i)));
param_map_.PushValue(output_vardescs[i]->Name(),
VariableDefiningInfo(operation->result(i)));
}

dest_block->push_back(operation);
Expand Down Expand Up @@ -422,6 +434,71 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
return operation;
}

void ProgramTranslator::TranslateWhileOperation(const OpDesc* op,
pir::Block* dest_block) {
VLOG(8) << "=============>Start to translate while op:" << op;
auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block"));
int index = static_cast<int>(sub_block.OpSize()) - 1;
std::vector<std::pair<std::string, std::string>> loop_vars_reverse;
while (index >= 0) {
auto sub_op = sub_block.Op(index);
if (sub_op->Type() == "assign" &&
param_map_.count(sub_op->Output("Out")[0]) > 0) {
loop_vars_reverse.emplace_back(sub_op->Output("Out")[0],
sub_op->Input("X")[0]);
--index;
} else {
break;
}
}
PADDLE_ENFORCE(!loop_vars_reverse.empty(),
platform::errors::PreconditionNotMet(
"While op must has condition value input"));
PADDLE_ENFORCE(loop_vars_reverse.front().first == op->Input("Condition")[0],
platform::errors::PreconditionNotMet(
"The last op in sub_block of While op must used to assign "
"condition var"));
auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::WhileOp::name());
std::vector<pir::Value> op_inputs{
param_map_.at(loop_vars_reverse[0].first).value};
std::vector<pir::Type> op_outputs_type;
auto body_block = new pir::Block();
for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].first;
auto val = param_map_.at(name).value;
auto val_type = val.type();
op_inputs.push_back(val);
op_outputs_type.push_back(val_type);
param_map_.PushValue(name, body_block->AddArgument(val_type));
}
pir::Operation* while_op =
pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1);
dest_block->push_back(while_op);
while_op->region(0).push_back(body_block);
TranslateBlock(sub_block, 0, index + 1, body_block);

auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name());
std::vector<pir::Value> yeild_inputs{
param_map_.at(loop_vars_reverse[0].second).value};
for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].second;
yeild_inputs.push_back(param_map_.at(name).value);
}
body_block->push_back(
pir::Operation::Create(yeild_inputs, {}, {}, yeild_info));

for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].first;
param_map_.PopValue(name);
}
auto name_iter = loop_vars_reverse.rbegin();
for (size_t idx = 0; idx < while_op->num_results(); ++idx) {
param_map_.UpdateValue(name_iter++->first, while_op->result(idx));
}
while_op->Verify();
VLOG(8) << "=============>end to translate while op:" << op;
}

void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op,
pir::Block* dest_block) {
auto& op_translator = OpTranslator::instance();
Expand All @@ -432,7 +509,7 @@ void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op,
}
}
pir::Operation* operation = fn(ctx_, &param_map_, *src_op, dest_block);
VLOG(10) << "[op translated][special]" << operation << "end";
VLOG(10) << "[op translated][general]" << operation << "end";
}

inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx,
Expand All @@ -451,7 +528,7 @@ inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx,
}

inline pir::Operation* InsertSetParamaterOp(pir::IrContext* ctx,
pir::OpResult defining_op_result,
pir::Value defining_op_result,
const VarDesc* var) {
std::string set_parameter_op_name(pir::SetParameterOp::name());
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name);
Expand Down Expand Up @@ -502,7 +579,7 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
"VarDesc of [%s] can not be nullptr", var_name));
pir::Operation* op = InsertGetParamaterOp(ctx_, var_desc);
program_->block()->push_back(op);
param_map_.insert(var_name, VariableDefiningInfo(op->result(0)));
param_map_.PushValue(var_name, VariableDefiningInfo(op->result(0)));
VLOG(10) << "[op translated][get parameter]" << var_name;

program_->SetParameter(var_name, nullptr);
Expand Down Expand Up @@ -554,7 +631,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
need_set_parameter_op &= (param_map_.count(var_name) != 0);
need_set_parameter_op &= (!set_input_var_names.count(var_name));
if (need_set_parameter_op) {
pir::OpResult defining_op_result = param_map_[var_name].value;
pir::OpResult defining_op_result =
param_map_[var_name].value.dyn_cast<pir::OpResult>();
if (!defining_op_result) {
continue;
}
Expand All @@ -565,7 +643,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
program_->block(),
param_map_[var_name],
var_name);
defining_op_result = param_map_.at(var_name).value;
defining_op_result =
param_map_.at(var_name).value.dyn_cast<pir::OpResult>();
}

pir::Operation* op = InsertSetParamaterOp(
Expand Down Expand Up @@ -604,7 +683,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
continue;
}
for (const auto& value_info : value_list) {
pir::OpResult value = value_info.value;
pir::OpResult value = value_info.value.dyn_cast<pir::OpResult>();
if (!value) {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Value of [%s] can not ber None", var_name));
Expand Down Expand Up @@ -645,7 +724,7 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue(
continue;
}
for (const auto& value_info : value_list) {
pir::OpResult value = value_info.value;
pir::OpResult value = value_info.value.dyn_cast<pir::OpResult>();
if (!value) {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Value of [%s] can not ber None", var_name));
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ namespace paddle {
namespace translator {

struct VariableDefiningInfo {
VariableDefiningInfo(pir::OpResult value,
VariableDefiningInfo(pir::Value value,
bool generated_by_vector = false,
int idx_in_vector = -1)
: value(value),
generated_by_vector(generated_by_vector),
idx_in_vector(idx_in_vector) {}
VariableDefiningInfo() {}

pir::OpResult value;
pir::Value value;

bool generated_by_vector =
false; // true if target variable is generated by Vector<Tensor>
Expand Down Expand Up @@ -78,7 +78,9 @@ class TranslationContext {
size_t count(const Key& key)
const; // Caution: not exactly same as count in stl library

void insert(const Key& key, const Value& value);
void UpdateValue(const Key& key, const Value& value);
void PushValue(const Key& key, const Value& value);
void PopValue(const Key& key);

Conatiner::const_iterator begin() const { return container_.begin(); }
Conatiner::const_iterator end() const { return container_.end(); }
Expand Down Expand Up @@ -136,6 +138,7 @@ class ProgramTranslator {
/// Translate methods for control flow ops.
pir::Operation* TranslateCondIfOperation(
const ConditionBlockCombination& cond_ops, pir::Block* dest_block);
void TranslateWhileOperation(const OpDesc* op, pir::Block* dest_block);
};

} // namespace translator
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pir::Operation* InsertSliceOperationForTarget(
op_info);
block->push_back(operation);
pir::OpResult target_op_result = operation->result(0);
param_map->insert(arg_name, VariableDefiningInfo(target_op_result));
param_map->UpdateValue(arg_name, VariableDefiningInfo(target_op_result));
return operation;
}

Expand Down
5 changes: 5 additions & 0 deletions test/cpp/pir/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ file(
${CMAKE_CURRENT_BINARY_DIR}/conditional_block_test.prog
EXPECTED_MD5 cf9dc869ca7f69e2d57b38dbf8427134)

file(
DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/while_op_test.prog
${CMAKE_CURRENT_BINARY_DIR}/while_op_test.prog
EXPECTED_MD5 290164ae52a496332b0be5829fc93bcd)

copy_if_different(${CMAKE_CURRENT_SOURCE_DIR}/TestParserText.txt
${CMAKE_CURRENT_BINARY_DIR}/TestParserText.txt)

Expand Down
72 changes: 72 additions & 0 deletions test/cpp/pir/core/program_translator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,75 @@ TEST(IrParserTest, StartupProgram) {

EXPECT_TRUE(ssp.str() == ss.str());
}

TEST(OperatorDialectTest, WhileOpProgram) {
auto p = load_from_file("while_op_test.prog");
EXPECT_EQ(p.Size(), 3u);
pir::IrContext *ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<OperatorDialect>();
ctx->GetOrRegisterDialect<pir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p);

std::stringstream ss;
program->Print(ss);

LOG(INFO) << ss.str();

EXPECT_EQ(program->block()->size(), 4u);
size_t id = 0;
for (auto &op : *program->block()) {
if (id == 0 || id == 1) {
EXPECT_TRUE(op->isa<paddle::dialect::FullOp>());
}
if (id == 2) {
EXPECT_TRUE(op->isa<paddle::dialect::LessThanOp>());
}
if (id == 3) {
EXPECT_TRUE(op->isa<paddle::dialect::WhileOp>());
EXPECT_EQ(op->num_regions(), 1u);
// body block
pir::Block *body_block =
op->dyn_cast<paddle::dialect::WhileOp>().body_block();
size_t body_id = 0;
for (auto &op1 : *body_block) {
if (body_id == 0) {
EXPECT_TRUE(op1->isa<paddle::dialect::FullOp>());
}
if (body_id == 1) {
EXPECT_TRUE(op1->isa<paddle::dialect::ScaleOp>());
}
if (body_id == 2) {
EXPECT_TRUE(op1->isa<paddle::dialect::LessThanOp>());
}
if (body_id == 3) {
pir::Block *body_body_block =
op1->dyn_cast<paddle::dialect::WhileOp>().body_block();
size_t body_body_id = 0;
for (auto &op2 : *body_body_block) {
if (body_body_id == 0) {
EXPECT_TRUE(op2->isa<paddle::dialect::FullOp>());
}
if (body_body_id == 1) {
EXPECT_TRUE(op2->isa<paddle::dialect::ScaleOp>());
}
if (body_body_id == 2) {
EXPECT_TRUE(op2->isa<paddle::dialect::LessThanOp>());
}
if (body_body_id == 3) {
EXPECT_TRUE(op2->isa<pir::YieldOp>());
}
body_body_id++;
}
}
if (body_id == 4) {
EXPECT_TRUE(op1->isa<paddle::dialect::LessThanOp>());
}
if (body_id == 5) {
EXPECT_TRUE(op1->isa<pir::YieldOp>());
}
body_id++;
}
}
id++;
}
}