From ec654d2390f8324c2d521a39d40664718be90759 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 18 Oct 2023 23:11:11 +0800 Subject: [PATCH] [PIR] separate the translation context between different blocks (#58198) * separate the translation context between different blocks * remove debug code * remove debug code --- .../translator/program_translator.cc | 93 ++++++++++--------- .../translator/program_translator.h | 28 ++++-- test/ir/new_ir/test_special_op_translator.py | 63 +++++++++++++ 3 files changed, 134 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 4dbf9d8707409..11c2743117586 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -46,7 +46,7 @@ using VarDesc = ::paddle::framework::VarDesc; using TCKey = TranslationContext::Key; using TCValue = TranslationContext::Value; -using TCConatiner = TranslationContext::Conatiner; +using TCContainer = TranslationContext::Container; const std::unordered_set ProgramTranslator::no_cast_var_names = { "feed", @@ -227,6 +227,9 @@ const TCValue& TranslationContext::operator[](const TCKey& key) const { const TCValue& TranslationContext::at(const TCKey& key) const { auto it = container_.find(key); + if (it == container_.end() && parent_) { + return parent_->at(key); + } PADDLE_ENFORCE_NE(it, container_.end(), platform::errors::InvalidArgument( @@ -243,12 +246,13 @@ const TCValue& TranslationContext::at(const TCKey& key) const { size_t TranslationContext::count(const TCKey& key) const { auto it = container_.find(key); if (it == container_.end()) { + if (parent_) return parent_->count(key); return 0u; } const auto& values = it->second; PADDLE_ENFORCE_NE( values.size(), - 0, + 0u, platform::errors::InvalidArgument( "param %s should have size > 0, but get:%d", key, values.size())); return values.size(); @@ -261,6 +265,11 @@ void 
TranslationContext::PopValue(const Key& key) { container_[key].pop_back(); } +TranslationContext* TranslationContext::CreateInnerContext() { + sons_.emplace_back(std::make_unique<TranslationContext>(this)); + return sons_.back().get(); +} + ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, pir::Program* program) : legacy_program_(legacy_program), program_(program) { @@ -274,6 +283,7 @@ void ProgramTranslator::Translate() { TranslateBlock(legacy_program_->Block(0), 0, legacy_program_->Block(0).OpSize(), + &param_map_, program_->block()); SetParameterFromSingleBlock(legacy_program_->Block(0)); @@ -293,6 +303,7 @@ void ProgramTranslator::TranslateBlock( const BlockDesc& src_block, uint64_t start_id, uint64_t end_id, + TranslationContext* translation_ctx, pir::Block* dest_block, bool for_cond_block, std::vector skip_cond_assign) { @@ -325,14 +336,14 @@ void ProgramTranslator::TranslateBlock( std::vector cond_op_list = {op}; std::vector cond_op_ids = GetCondOpIds(src_block, op_id); ConditionBlockCombination cond_op_combination(src_block, cond_op_ids); - pir::Operation* if_op = - TranslateCondIfOperation(cond_op_combination, dest_block); + pir::Operation* if_op = TranslateCondIfOperation( + cond_op_combination, translation_ctx, dest_block); for (auto cond_id : cond_op_ids) { translate_completed[cond_id] = true; } VLOG(10) << "[op translated][conditional_block]" << if_op; } else if (op->Type() == "while") { - TranslateWhileOperation(op, dest_block); + TranslateWhileOperation(op, translation_ctx, dest_block); } else { if (for_cond_block && op->Type() == "assign" && std::count(skip_cond_assign.begin(), @@ -341,7 +352,7 @@ void ProgramTranslator::TranslateBlock( assign_inputs.push_back(op->Input("X")[0]); translate_completed[op_id] = true; } else { - TranslateGeneralOperation(op, dest_block); + TranslateGeneralOperation(op, translation_ctx, dest_block); translate_completed[op_id] = true; } } @@ -351,7 +362,7 @@ void ProgramTranslator::TranslateBlock( if (for_cond_block) { 
std::vector yeild_inputs; for (size_t id = 0; id < assign_inputs.size(); id++) { - yeild_inputs.emplace_back(param_map_[assign_inputs[id]].value); + yeild_inputs.emplace_back((*translation_ctx)[assign_inputs[id]].value); } pir::AttributeMap attribute_map; auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); @@ -362,11 +373,13 @@ void ProgramTranslator::TranslateBlock( } pir::Operation* ProgramTranslator::TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, pir::Block* dest_block) { + const ConditionBlockCombination& cond_ops, + TranslationContext* translation_ctx, + pir::Block* dest_block) { auto& type_translator = TypeTranslator::instance(); auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); std::vector op_inputs = { - param_map_[cond_ops.CondVarName()].value}; + (*translation_ctx)[cond_ops.CondVarName()].value}; // NOTE(zhangbo): Now paddle::dialect::IfOp has 0 attribute pir::AttributeMap attribute_map; @@ -386,8 +399,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( op_inputs, attribute_map, op_output_types, op_info, 2); for (size_t i = 0; i < output_vardescs.size(); i++) { - param_map_.PushValue(output_vardescs[i]->Name(), - VariableDefiningInfo(operation->result(i))); + translation_ctx->PushValue(output_vardescs[i]->Name(), + VariableDefiningInfo(operation->result(i))); } dest_block->push_back(operation); @@ -398,9 +411,13 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( legacy_program_->Block(cond_ops.TrueBlockId()); pir::Region& true_region = operation->region(0); if (true_region.empty()) true_region.emplace_back(); + + auto* true_block_context = translation_ctx->CreateInnerContext(); + TranslateBlock(true_sub_block, 0, true_sub_block.OpSize(), + true_block_context, true_region.front(), true, cond_ops.TrueBlockOutputVarNames()); @@ -412,9 +429,11 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( legacy_program_->Block(cond_ops.FalseBlockId()); pir::Region& 
false_region = operation->region(1); if (false_region.empty()) false_region.emplace_back(); + auto* false_block_context = translation_ctx->CreateInnerContext(); TranslateBlock(false_sub_block, 0, false_sub_block.OpSize(), + false_block_context, false_region.front(), true, cond_ops.FalseBlockOutputVarNames()); @@ -426,8 +445,10 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( return operation; } -void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, - pir::Block* dest_block) { +void ProgramTranslator::TranslateWhileOperation( + const OpDesc* op, + TranslationContext* translation_ctx, + pir::Block* dest_block) { VLOG(8) << "=============>Start to translate while op:" << op; auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); int index = static_cast(sub_block.OpSize()) - 1; @@ -435,7 +456,7 @@ void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, while (index >= 0) { auto sub_op = sub_block.Op(index); if (sub_op->Type() == "assign" && - param_map_.count(sub_op->Output("Out")[0]) > 0) { + translation_ctx->count(sub_op->Output("Out")[0]) > 0) { loop_vars_reverse.emplace_back(sub_op->Output("Out")[0], sub_op->Input("X")[0]); --index; @@ -452,31 +473,31 @@ void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, "condition var")); auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::WhileOp::name()); std::vector op_inputs{ - param_map_.at(loop_vars_reverse[0].first).value}; + translation_ctx->at(loop_vars_reverse[0].first).value}; std::vector op_outputs_type; auto body_block = new pir::Block(); - std::vector param_map_status; + std::vector param_status; for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { auto& name = loop_vars_reverse[idx].first; - auto& tc_value = param_map_.at(name); + auto& tc_value = translation_ctx->at(name); auto val_type = tc_value.value.type(); op_inputs.push_back(tc_value.value); op_outputs_type.push_back(val_type); - 
param_map_status.emplace_back(tc_value); - param_map_.PushValue(name, body_block->AddArgument(val_type)); + param_status.emplace_back(tc_value); + translation_ctx->PushValue(name, body_block->AddArgument(val_type)); } pir::Operation* while_op = pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1); dest_block->push_back(while_op); while_op->region(0).push_back(body_block); - TranslateBlock(sub_block, 0, index + 1, body_block); + TranslateBlock(sub_block, 0, index + 1, translation_ctx, body_block); auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); std::vector yeild_inputs{ - param_map_.at(loop_vars_reverse[0].second).value}; + translation_ctx->at(loop_vars_reverse[0].second).value}; for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { auto& name = loop_vars_reverse[idx].second; - yeild_inputs.push_back(param_map_.at(name).value); + yeild_inputs.push_back(translation_ctx->at(name).value); } body_block->push_back( pir::Operation::Create(yeild_inputs, {}, {}, yeild_info)); @@ -484,26 +505,28 @@ void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, index = 0; for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { auto& name = loop_vars_reverse[idx].first; - param_map_.PushValue(name, param_map_status[index++]); + translation_ctx->PushValue(name, param_status[index++]); } auto name_iter = loop_vars_reverse.rbegin(); for (size_t idx = 0; idx < while_op->num_results(); ++idx) { - param_map_.PushValue(name_iter++->first, while_op->result(idx)); + translation_ctx->PushValue(name_iter++->first, while_op->result(idx)); } while_op->Verify(); VLOG(8) << "=============>end to translate while op:" << op; } -void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op, - pir::Block* dest_block) { +void ProgramTranslator::TranslateGeneralOperation( + const OpDesc* src_op, + TranslationContext* translation_ctx, + pir::Block* dest_block) { auto& op_translator = OpTranslator::instance(); OpTranslateFn& fn = 
op_translator[src_op->Type()]; if (src_op->Type() == "shadow_output") { - if (!param_map_.count(src_op->Input("x")[0])) { + if (!translation_ctx->count(src_op->Input("x")[0])) { return; } } - pir::Operation* operation = fn(ctx_, &param_map_, *src_op, dest_block); + pir::Operation* operation = fn(ctx_, translation_ctx, *src_op, dest_block); VLOG(10) << "[op translated][general]" << operation << "end"; } @@ -592,20 +615,6 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { } } -void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { - auto& op_translator = OpTranslator::instance(); - for (auto op : block.AllOps()) { - OpTranslateFn& fn = op_translator[op->Type()]; - if (op->Type() == "shadow_output") { - if (!param_map_.count(op->Input("x")[0])) { - continue; - } - } - pir::Operation* operation = fn(ctx_, &param_map_, *op, program_->block()); - VLOG(10) << "[op translated][special]" << operation; - } -} - void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { const auto& ops = block.AllOps(); for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) { diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 668f4db2c9682..97c7ae1ec8687 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -18,6 +18,7 @@ #include #include #include + #include "paddle/fluid/framework/op_call_stack.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/program_desc.h" @@ -68,9 +69,10 @@ class TranslationContext { using Key = std::string; using Value = VariableDefiningInfo; using ValueList = std::vector; - using Conatiner = std::unordered_map; + using Container = std::unordered_map; TranslationContext() {} + explicit TranslationContext(TranslationContext* parent) : parent_(parent) {} ~TranslationContext() {} const Value& 
operator[](const Key& key) const; @@ -80,12 +82,16 @@ class TranslationContext { void PushValue(const Key& key, const Value& value); void PopValue(const Key& key); + TranslationContext* CreateInnerContext(); - Conatiner::const_iterator begin() const { return container_.begin(); } - Conatiner::const_iterator end() const { return container_.end(); } + Container::const_iterator begin() const { return container_.begin(); } + Container::const_iterator end() const { return container_.end(); } private: - Conatiner container_; + Container container_; + TranslationContext* parent_ = nullptr; + std::vector<std::unique_ptr<TranslationContext>> + sons_; // used to separate different blocks }; class ProgramTranslator { @@ -124,20 +130,26 @@ class ProgramTranslator { void TranslateBlock(const BlockDesc& src_block, uint64_t start_id, uint64_t end_id, + TranslationContext* translation_ctx, pir::Block* dest_block, bool for_cond_block = false, std::vector skip_cond_assign = {}); - void TranslateGeneralOperation(const OpDesc* src_op, pir::Block* dest_block); + void TranslateGeneralOperation(const OpDesc* src_op, + TranslationContext* translation_ctx, + pir::Block* dest_block); void GetParameterForSingleBlock(const BlockDesc& block); - void InsertOperationToSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block); void SetStopGradientAttributeForAllValue(const BlockDesc& block); void SetIsPersisableAttributeForAllValue(const BlockDesc& block); /// Translate methods for control flow ops. 
pir::Operation* TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, pir::Block* dest_block); - void TranslateWhileOperation(const OpDesc* op, pir::Block* dest_block); + const ConditionBlockCombination& cond_ops, + TranslationContext* translation_ctx, + pir::Block* dest_block); + void TranslateWhileOperation(const OpDesc* op, + TranslationContext* translation_ctx, + pir::Block* dest_block); }; } // namespace translator diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py index 30757fbeb95bd..032f451bda842 100644 --- a/test/ir/new_ir/test_special_op_translator.py +++ b/test/ir/new_ir/test_special_op_translator.py @@ -39,6 +39,69 @@ def test_op(self): assert len(str(mappings)) > 0, "no mapping found" +class TestCondWithInplace(unittest.TestCase): + def test_op(self): + def cond_with_inplace(): + x = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + y = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + running_mean = paddle.to_tensor([0], dtype="float32") + running_variance = paddle.to_tensor([1], dtype="float32") + weight = paddle.to_tensor([2], dtype="float32") + bias = paddle.to_tensor([1], dtype="float32") + if x > y: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + + legacy_program = paddle.jit.to_static( + cond_with_inplace, + input_spec=[], + ) + + l = pir.translate_to_new_ir(legacy_program.main_program.desc) + assert l is not None + + def test_nested_op(self): + def cond_with_inplace(): + x = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + y = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + z = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + running_mean = paddle.to_tensor([0], dtype="float32") + running_variance = paddle.to_tensor([1], dtype="float32") + weight = paddle.to_tensor([2], dtype="float32") + bias = paddle.to_tensor([1], 
dtype="float32") + if x > y: + if y > z: + z = paddle.nn.functional.batch_norm( + z, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + else: + if y > z: + z = paddle.nn.functional.batch_norm( + z, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + + legacy_program = paddle.jit.to_static( + cond_with_inplace, + input_spec=[], + ) + + l = pir.translate_to_new_ir(legacy_program.main_program.desc) + assert l is not None + + class TestElementwiseOpTranscriber(unittest.TestCase): def test_elementwise_without_y_grad(self): place = core.Place()