diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 2ffe86c9f1d2f..7cd2153e5d047 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -677,7 +677,7 @@ void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx, pir::OpResult value = operation->result(idx_in_op); bool generated_by_vector = value.type().isa(); - param_map->insert( + param_map->UpdateValue( arg_name, VariableDefiningInfo( value, @@ -1439,7 +1439,7 @@ pir::OpResult TranslateNumClassesForOneHot( "%s should be existed in one_hot_v2 as input depth_tensor.", legacy_vars[0]); auto defining_info = param_map->at(legacy_vars[0]); - return defining_info.value; + return defining_info.value.dyn_cast(); } auto& attribute_translator = AttributeTranslator::instance(); @@ -1529,7 +1529,7 @@ struct ElementwiseTranscriber : public OpTranscriber { ctx, param_map, block, x_defining_info, x_name); x_defining_info = param_map->at(x_name); } - pir::OpResult x_value = x_defining_info.value; + pir::OpResult x_value = x_defining_info.value.dyn_cast(); IR_ENFORCE(x_value, "Expected op[%s]'s input %s is not null", op_desc.Type(), @@ -1560,7 +1560,7 @@ struct ElementwiseTranscriber : public OpTranscriber { ctx, param_map, block, y_defining_info, y_name); y_defining_info = param_map->at(y_name); } - pir::OpResult y_value = y_defining_info.value; + pir::OpResult y_value = y_defining_info.value.dyn_cast(); IR_ENFORCE(y_value, "Expected op[%s]'s input %s is not null", op_desc.Type(), @@ -1680,7 +1680,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber { op_desc.Type(), y_name); auto y_defining_info = param_map->at(y_name); - pir::OpResult y_value = y_defining_info.value; + pir::OpResult y_value = y_defining_info.value.dyn_cast(); IR_ENFORCE(y_value, "Expected op[%s]'s input %s is not null", op_desc.Type(), @@ -1698,8 +1698,8 @@ struct ElementwiseGradTranscriber : public OpTranscriber { pir::OpResult value = operation->result(idx_in_op); pir::Builder builder(ctx, operation->GetParent()); auto reshape_op = builder.Build(value, y_shape); - param_map->insert(y_grad_var_name, - VariableDefiningInfo(reshape_op.out(), false, -1)); + param_map->UpdateValue(y_grad_var_name, + VariableDefiningInfo(reshape_op.out(), false, -1)); } }; @@ -1771,7 +1771,7 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber { ctx, param_map, block, defining_info, var_name); defining_info = param_map->at(var_name).value; } - return defining_info.value; + return defining_info.value.dyn_cast(); }; } }; @@ -1866,8 +1866,8 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber { auto output_var = output_vars[0]; auto fused_feedforward_op = operation->dyn_cast(); - param_map->insert(output_var, - VariableDefiningInfo{fused_feedforward_op.out()}); + param_map->UpdateValue(output_var, + VariableDefiningInfo{fused_feedforward_op.out()}); } } }; diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index f53a93ac3a102..dba2bae8dc911 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -55,7 +55,6 @@ const std::unordered_set ProgramTranslator::no_cast_var_names = { const std::unordered_set ProgramTranslator::unsupported_ops = { "conditional_block_grad", - "while", "while_grad", }; @@ -255,9 +254,20 @@ size_t TranslationContext::count(const TCKey& key) const { return values.size(); } -void TranslationContext::insert(const TCKey& key, const TCValue& value) { +void TranslationContext::PushValue(const Key& key, const Value& value) { container_[key].push_back(value); } +void TranslationContext::PopValue(const Key& key) { + container_[key].pop_back(); +} + +void TranslationContext::UpdateValue(const Key& key, const Value& value) { + auto& vec = container_[key]; + if (vec.empty()) + vec.push_back(value); + else + vec.back() = value; +} ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, pir::Program* program) @@ -329,6 +339,8 @@ void ProgramTranslator::TranslateBlock( translate_completed[cond_id] = true; } VLOG(10) << "[op translated][conditional_block]" << if_op; + } else if (op->Type() == "while") { + TranslateWhileOperation(op, dest_block); } else { if (for_cond_block && op->Type() == "assign" && std::count(skip_cond_assign.begin(), @@ -382,8 +394,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( op_inputs, attribute_map, op_output_types, op_info, 2); for (size_t i = 0; i < output_vardescs.size(); i++) { - param_map_.insert(output_vardescs[i]->Name(), - VariableDefiningInfo(operation->result(i))); + param_map_.PushValue(output_vardescs[i]->Name(), + VariableDefiningInfo(operation->result(i))); } dest_block->push_back(operation); @@ -422,6 +434,71 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( return operation; } +void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, + pir::Block* dest_block) { + VLOG(8) << "=============>Start to translate while op:" << op; + auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); + int index = static_cast(sub_block.OpSize()) - 1; + std::vector> loop_vars_reverse; + while (index >= 0) { + auto sub_op = sub_block.Op(index); + if (sub_op->Type() == "assign" && + param_map_.count(sub_op->Output("Out")[0]) > 0) { + loop_vars_reverse.emplace_back(sub_op->Output("Out")[0], + sub_op->Input("X")[0]); + --index; + } else { + break; + } + } + PADDLE_ENFORCE(!loop_vars_reverse.empty(), + platform::errors::PreconditionNotMet( + "While op must has condition value input")); + PADDLE_ENFORCE(loop_vars_reverse.front().first == op->Input("Condition")[0], + platform::errors::PreconditionNotMet( + "The last op in sub_block of While op must used to assign " + "condition var")); + auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::WhileOp::name()); + std::vector op_inputs{ + param_map_.at(loop_vars_reverse[0].first).value}; + std::vector op_outputs_type; + auto body_block = new pir::Block(); + for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { + auto& name = loop_vars_reverse[idx].first; + auto val = param_map_.at(name).value; + auto val_type = val.type(); + op_inputs.push_back(val); + op_outputs_type.push_back(val_type); + param_map_.PushValue(name, body_block->AddArgument(val_type)); + } + pir::Operation* while_op = + pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1); + dest_block->push_back(while_op); + while_op->region(0).push_back(body_block); + TranslateBlock(sub_block, 0, index + 1, body_block); + + auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); + std::vector yeild_inputs{ + param_map_.at(loop_vars_reverse[0].second).value}; + for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { + auto& name = loop_vars_reverse[idx].second; + yeild_inputs.push_back(param_map_.at(name).value); + } + body_block->push_back( + pir::Operation::Create(yeild_inputs, {}, {}, yeild_info)); + + for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { + auto& name = loop_vars_reverse[idx].first; + param_map_.PopValue(name); + } + auto name_iter = loop_vars_reverse.rbegin(); + for (size_t idx = 0; idx < while_op->num_results(); ++idx) { + param_map_.UpdateValue(name_iter++->first, while_op->result(idx)); + } + while_op->Verify(); + VLOG(8) << "=============>end to translate while op:" << op; +} + void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op, pir::Block* dest_block) { auto& op_translator = OpTranslator::instance(); @@ -432,7 +509,7 @@ void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op, } } pir::Operation* operation = fn(ctx_, ¶m_map_, *src_op, dest_block); - VLOG(10) << "[op translated][special]" << operation << "end"; + VLOG(10) << "[op translated][general]" << operation << "end"; } inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx, @@ -451,7 +528,7 @@ inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx, } inline pir::Operation* InsertSetParamaterOp(pir::IrContext* ctx, - pir::OpResult defining_op_result, + pir::Value defining_op_result, const VarDesc* var) { std::string set_parameter_op_name(pir::SetParameterOp::name()); pir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name); @@ -502,7 +579,7 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { "VarDesc of [%s] can not be nullptr", var_name)); pir::Operation* op = InsertGetParamaterOp(ctx_, var_desc); program_->block()->push_back(op); - param_map_.insert(var_name, VariableDefiningInfo(op->result(0))); + param_map_.PushValue(var_name, VariableDefiningInfo(op->result(0))); VLOG(10) << "[op translated][get parameter]" << var_name; program_->SetParameter(var_name, nullptr); @@ -554,7 +631,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { need_set_parameter_op &= (param_map_.count(var_name) != 0); need_set_parameter_op &= (!set_input_var_names.count(var_name)); if (need_set_parameter_op) { - pir::OpResult defining_op_result = param_map_[var_name].value; + pir::OpResult defining_op_result = + param_map_[var_name].value.dyn_cast(); if (!defining_op_result) { continue; } @@ -565,7 +643,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { program_->block(), param_map_[var_name], var_name); - defining_op_result = param_map_.at(var_name).value; + defining_op_result = + param_map_.at(var_name).value.dyn_cast(); } pir::Operation* op = InsertSetParamaterOp( @@ -604,7 +683,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue( continue; } for (const auto& value_info : value_list) { - pir::OpResult value = value_info.value; + pir::OpResult value = value_info.value.dyn_cast(); if (!value) { PADDLE_THROW(phi::errors::PreconditionNotMet( "Value of [%s] can not ber None", var_name)); @@ -645,7 +724,7 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue( continue; } for (const auto& value_info : value_list) { - pir::OpResult value = value_info.value; + pir::OpResult value = value_info.value.dyn_cast(); if (!value) { PADDLE_THROW(phi::errors::PreconditionNotMet( "Value of [%s] can not ber None", var_name)); diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 1215ae603a85d..9d9e1b99552af 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -29,7 +29,7 @@ namespace paddle { namespace translator { struct VariableDefiningInfo { - VariableDefiningInfo(pir::OpResult value, + VariableDefiningInfo(pir::Value value, bool generated_by_vector = false, int idx_in_vector = -1) : value(value), @@ -37,7 +37,7 @@ struct VariableDefiningInfo { idx_in_vector(idx_in_vector) {} VariableDefiningInfo() {} - pir::OpResult value; + pir::Value value; bool generated_by_vector = false; // true if target variable is generated by Vector @@ -78,7 +78,9 @@ class TranslationContext { size_t count(const Key& key) const; // Caution: not exactly same as count in stl library - void insert(const Key& key, const Value& value); + void UpdateValue(const Key& key, const Value& value); + void PushValue(const Key& key, const Value& value); + void PopValue(const Key& key); Conatiner::const_iterator begin() const { return container_.begin(); } Conatiner::const_iterator end() const { return container_.end(); } @@ -136,6 +138,7 @@ class ProgramTranslator { /// Translate methods for control flow ops. pir::Operation* TranslateCondIfOperation( const ConditionBlockCombination& cond_ops, pir::Block* dest_block); + void TranslateWhileOperation(const OpDesc* op, pir::Block* dest_block); }; } // namespace translator diff --git a/paddle/fluid/ir_adaptor/translator/utils.cc b/paddle/fluid/ir_adaptor/translator/utils.cc index 9582b3b78a56b..e8102e4e686a2 100644 --- a/paddle/fluid/ir_adaptor/translator/utils.cc +++ b/paddle/fluid/ir_adaptor/translator/utils.cc @@ -59,7 +59,7 @@ pir::Operation* InsertSliceOperationForTarget( op_info); block->push_back(operation); pir::OpResult target_op_result = operation->result(0); - param_map->insert(arg_name, VariableDefiningInfo(target_op_result)); + param_map->UpdateValue(arg_name, VariableDefiningInfo(target_op_result)); return operation; } diff --git a/test/cpp/pir/core/CMakeLists.txt b/test/cpp/pir/core/CMakeLists.txt index 0f0ec568bb50a..ca71cb8fe9eef 100644 --- a/test/cpp/pir/core/CMakeLists.txt +++ b/test/cpp/pir/core/CMakeLists.txt @@ -65,6 +65,11 @@ file( ${CMAKE_CURRENT_BINARY_DIR}/conditional_block_test.prog EXPECTED_MD5 cf9dc869ca7f69e2d57b38dbf8427134) +file( + DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/while_op_test.prog + ${CMAKE_CURRENT_BINARY_DIR}/while_op_test.prog + EXPECTED_MD5 290164ae52a496332b0be5829fc93bcd) + copy_if_different(${CMAKE_CURRENT_SOURCE_DIR}/TestParserText.txt ${CMAKE_CURRENT_BINARY_DIR}/TestParserText.txt) diff --git a/test/cpp/pir/core/program_translator_test.cc b/test/cpp/pir/core/program_translator_test.cc index 483299c206129..ba85e396d41b7 100644 --- a/test/cpp/pir/core/program_translator_test.cc +++ b/test/cpp/pir/core/program_translator_test.cc @@ -266,3 +266,75 @@ TEST(IrParserTest, StartupProgram) { EXPECT_TRUE(ssp.str() == ss.str()); } + +TEST(OperatorDialectTest, WhileOpProgram) { + auto p = load_from_file("while_op_test.prog"); + EXPECT_EQ(p.Size(), 3u); + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + std::stringstream ss; + program->Print(ss); + + LOG(INFO) << ss.str(); + + EXPECT_EQ(program->block()->size(), 4u); + size_t id = 0; + for (auto &op : *program->block()) { + if (id == 0 || id == 1) { + EXPECT_TRUE(op->isa()); + } + if (id == 2) { + EXPECT_TRUE(op->isa()); + } + if (id == 3) { + EXPECT_TRUE(op->isa()); + EXPECT_EQ(op->num_regions(), 1u); + // body block + pir::Block *body_block = + op->dyn_cast().body_block(); + size_t body_id = 0; + for (auto &op1 : *body_block) { + if (body_id == 0) { + EXPECT_TRUE(op1->isa()); + } + if (body_id == 1) { + EXPECT_TRUE(op1->isa()); + } + if (body_id == 2) { + EXPECT_TRUE(op1->isa()); + } + if (body_id == 3) { + pir::Block *body_body_block = + op1->dyn_cast().body_block(); + size_t body_body_id = 0; + for (auto &op2 : *body_body_block) { + if (body_body_id == 0) { + EXPECT_TRUE(op2->isa()); + } + if (body_body_id == 1) { + EXPECT_TRUE(op2->isa()); + } + if (body_body_id == 2) { + EXPECT_TRUE(op2->isa()); + } + if (body_body_id == 3) { + EXPECT_TRUE(op2->isa()); + } + body_body_id++; + } + } + if (body_id == 4) { + EXPECT_TRUE(op1->isa()); + } + if (body_id == 5) { + EXPECT_TRUE(op1->isa()); + } + body_id++; + } + } + id++; + } +}