diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 05f410dc67d886..3061fdeb556d77 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -58,21 +58,36 @@ const std::unordered_set ProgramTranslator::unsupported_ops = { static std::vector GetCondOpIds(const BlockDesc& src_block, uint64_t first_id) { std::vector op_list = {first_id}; - if (src_block.Op(static_cast(first_id + 1))->Type() == "logical_not") { + if (((first_id + 1) < src_block.OpSize()) && + (src_block.Op(static_cast(first_id + 1))->Type() == "logical_not")) { op_list.emplace_back(first_id + 1); } - if (src_block.Op(static_cast(first_id + 2))->Type() == - "conditional_block") { + if (((first_id + 2) < src_block.OpSize()) && + (src_block.Op(static_cast(first_id + 2))->Type() == + "conditional_block")) { op_list.emplace_back(first_id + 2); } - if (src_block.Op(static_cast(first_id + 3))->Type() == "cast") { + if (((first_id + 3) < src_block.OpSize()) && + (src_block.Op(static_cast(first_id + 3))->Type() == "cast")) { op_list.emplace_back(first_id + 3); } - size_t output_size = - src_block.Op(static_cast(first_id))->Output("Out").size(); + // Note(zhangbo): Some output variables are input, without select_input op. + std::vector output_names = + src_block.Op(static_cast(first_id))->Output("Out"); + std::vector input_names = + src_block.Op(static_cast(first_id))->Input("Input"); + std::vector diffs(output_names.size()); + auto iter = std::set_difference(output_names.begin(), + output_names.end(), + input_names.begin(), + input_names.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + size_t output_size = diffs.size(); for (size_t i = 0; i < output_size; i++) { - if (src_block.Op(static_cast(first_id + 4 + i))->Type() == - "select_input") { + if (((first_id + 4 + i) < src_block.OpSize()) && + (src_block.Op(static_cast(first_id + 4 + i))->Type() == + "select_input")) { op_list.emplace_back(first_id + 4 + i); } } @@ -97,7 +112,16 @@ const std::string& ConditionBlockCombination::CondVarName() const { } size_t ConditionBlockCombination::OutputSize() const { - return op_list_[0]->Output("Out").size(); + std::vector output_names = op_list_[0]->Output("Out"); + std::vector input_names = op_list_[0]->Input("Input"); + std::vector diffs(output_names.size()); + auto iter = std::set_difference(output_names.begin(), + output_names.end(), + input_names.begin(), + input_names.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + return diffs.size(); } std::vector<::paddle::framework::VarDesc*> @@ -112,23 +136,41 @@ ConditionBlockCombination::OutputVars() const { return outputs; } -const std::vector& -ConditionBlockCombination::TrueBlockOutputVarNames() const { - return op_list_[0]->Output("Out"); -} - -int ConditionBlockCombination::TrueBlockId() const { - return op_list_[0]->GetBlockAttrId("sub_block"); +std::vector ConditionBlockCombination::TrueBlockOutputVarNames() + const { + std::vector output_names = op_list_[0]->Output("Out"); + std::vector input_names = op_list_[0]->Input("Input"); + std::vector diffs(output_names.size()); + auto iter = std::set_difference(output_names.begin(), + output_names.end(), + input_names.begin(), + input_names.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + return diffs; } std::vector ConditionBlockCombination::FalseBlockOutputVarNames() const { if (op_list_.size() > 1) { - return op_list_[2]->Output("Out"); + std::vector output_names = op_list_[2]->Output("Out"); + std::vector input_names = op_list_[2]->Input("Input"); + std::vector diffs(output_names.size()); + auto iter = std::set_difference(output_names.begin(), + output_names.end(), + input_names.begin(), + input_names.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + return diffs; } return {""}; } +int ConditionBlockCombination::TrueBlockId() const { + return op_list_[0]->GetBlockAttrId("sub_block"); +} + int ConditionBlockCombination::FalseBlockId() const { if (op_list_.size() > 1) { return op_list_[2]->GetBlockAttrId("sub_block"); @@ -143,9 +185,6 @@ bool ConditionBlockCombination::Verify( if (op_list[id]->Type() != "conditional_block") { return false; } - if (op_list.size() == 1 && op_list[id]->Output("Out").size() != 0) { - return false; - } } else if (id == 1) { if (op_list[id]->Type() != "logical_not") { return false; @@ -207,11 +246,13 @@ void ProgramTranslator::Translate() { } } -void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, - uint64_t start_id, - uint64_t end_id, - pir::Block* dest_block, - bool for_cond_block) { +void ProgramTranslator::TranslateBlock( + const BlockDesc& src_block, + uint64_t start_id, + uint64_t end_id, + pir::Block* dest_block, + bool for_cond_block, + std::vector skip_cond_assign) { VLOG(8) << "=============>start to translate a block"; PADDLE_ENFORCE( (src_block.OpSize() >= end_id) && (start_id <= end_id), @@ -223,10 +264,12 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, src_block.OpSize())); std::unordered_map translate_completed; + std::vector assign_inputs; for (uint64_t op_id = start_id; op_id < end_id; op_id++) { if (translate_completed.count(op_id) && translate_completed.at(op_id)) { continue; } + auto op = src_block.Op(static_cast(op_id)); VLOG(8) << "=============>start to translate a op: " << op->Type(); @@ -246,20 +289,24 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, } VLOG(10) << "[op translated][conditional_block]" << if_op; } else { - TranslateGeneralOperation(op, dest_block); - translate_completed[op_id] = true; + if (for_cond_block && op->Type() == "assign" && + std::count(skip_cond_assign.begin(), + skip_cond_assign.end(), + op->Output("Out")[0])) { + assign_inputs.push_back(op->Input("X")[0]); + translate_completed[op_id] = true; + } else { + TranslateGeneralOperation(op, dest_block); + translate_completed[op_id] = true; + } } } // NOTE(zhangbo): If conditional_block operator has output, the cf.yeild // operator needs to be inserted if (for_cond_block) { std::vector yeild_inputs; - for (size_t id = end_id; id < src_block.OpSize(); id++) { - PADDLE_ENFORCE( - src_block.Op(id)->Type() == "assign", - "The operator at the end of the sub block needs to be assign"); - yeild_inputs.emplace_back( - param_map_[src_block.Op(static_cast(id))->Input("X")[0]].value); + for (size_t id = 0; id < assign_inputs.size(); id++) { + yeild_inputs.emplace_back(param_map_[assign_inputs[id]].value); } pir::AttributeMap attribute_map; auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); @@ -308,9 +355,10 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( if (true_region.empty()) true_region.emplace_back(); TranslateBlock(true_sub_block, 0, - true_sub_block.OpSize() - cond_ops.OutputSize(), + true_sub_block.OpSize(), true_region.front(), - true); + true, + cond_ops.TrueBlockOutputVarNames()); } VLOG(4) << "[general op][conditional_block] IfOp true block translate end."; @@ -321,9 +369,10 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( if (false_region.empty()) false_region.emplace_back(); TranslateBlock(false_sub_block, 0, - false_sub_block.OpSize() - cond_ops.OutputSize(), + false_sub_block.OpSize(), false_region.front(), - true); + true, + cond_ops.FalseBlockOutputVarNames()); } VLOG(4) << "[general op][conditional_block] IfOp false block translate end."; VLOG(4) << "[general op][conditional_block] IfOp translate end."; diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index a59f4b34a5adaa..fcfbef609be686 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -50,12 +50,12 @@ class ConditionBlockCombination { ConditionBlockCombination(const ::paddle::framework::BlockDesc& src_block, const std::vector& op_ids); const std::string& CondVarName() const; + int TrueBlockId() const; + int FalseBlockId() const; size_t OutputSize() const; std::vector<::paddle::framework::VarDesc*> OutputVars() const; - const std::vector& TrueBlockOutputVarNames() const; - int TrueBlockId() const; + std::vector TrueBlockOutputVarNames() const; std::vector FalseBlockOutputVarNames() const; - int FalseBlockId() const; private: bool Verify(const std::vector<::paddle::framework::OpDesc*>& op_list); @@ -101,7 +101,8 @@ class ProgramTranslator { uint64_t start_id, uint64_t end_id, pir::Block* dest_block, - bool for_cond_block = false); + bool for_cond_block = false, + std::vector skip_cond_assign = {}); void TranslateGeneralOperation(const OpDesc* src_op, pir::Block* dest_block); void GetParameterForSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block);