Skip to content

Commit

Permalink
[PIR] separate the translation context between different blocks (Padd…
Browse files Browse the repository at this point in the history
…lePaddle#58198)

* seperate the translation context bewteen different blocks

* remove debug code

* remove debug code
  • Loading branch information
kangguangli authored and jiahy0825 committed Oct 26, 2023
1 parent d4cd048 commit ec654d2
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 50 deletions.
93 changes: 51 additions & 42 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ using VarDesc = ::paddle::framework::VarDesc;

using TCKey = TranslationContext::Key;
using TCValue = TranslationContext::Value;
using TCConatiner = TranslationContext::Conatiner;
using TCContainer = TranslationContext::Container;

const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"feed",
Expand Down Expand Up @@ -227,6 +227,9 @@ const TCValue& TranslationContext::operator[](const TCKey& key) const {

const TCValue& TranslationContext::at(const TCKey& key) const {
auto it = container_.find(key);
if (it == container_.end() && parent_) {
return parent_->at(key);
}
PADDLE_ENFORCE_NE(it,
container_.end(),
platform::errors::InvalidArgument(
Expand All @@ -243,12 +246,13 @@ const TCValue& TranslationContext::at(const TCKey& key) const {
size_t TranslationContext::count(const TCKey& key) const {
auto it = container_.find(key);
if (it == container_.end()) {
if (parent_) return parent_->count(key);
return 0u;
}
const auto& values = it->second;
PADDLE_ENFORCE_NE(
values.size(),
0,
0u,
platform::errors::InvalidArgument(
"param %s should have size > 0, but get:%d", key, values.size()));
return values.size();
Expand All @@ -261,6 +265,11 @@ void TranslationContext::PopValue(const Key& key) {
container_[key].pop_back();
}

TranslationContext* TranslationContext::CreateInnerContext() {
sons_.emplace_back(std::make_unique<TranslationContext>(this));
return sons_.back().get();
}

ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
pir::Program* program)
: legacy_program_(legacy_program), program_(program) {
Expand All @@ -274,6 +283,7 @@ void ProgramTranslator::Translate() {
TranslateBlock(legacy_program_->Block(0),
0,
legacy_program_->Block(0).OpSize(),
&param_map_,
program_->block());

SetParameterFromSingleBlock(legacy_program_->Block(0));
Expand All @@ -293,6 +303,7 @@ void ProgramTranslator::TranslateBlock(
const BlockDesc& src_block,
uint64_t start_id,
uint64_t end_id,
TranslationContext* translation_ctx,
pir::Block* dest_block,
bool for_cond_block,
std::vector<std::string> skip_cond_assign) {
Expand Down Expand Up @@ -325,14 +336,14 @@ void ProgramTranslator::TranslateBlock(
std::vector<const OpDesc*> cond_op_list = {op};
std::vector<uint64_t> cond_op_ids = GetCondOpIds(src_block, op_id);
ConditionBlockCombination cond_op_combination(src_block, cond_op_ids);
pir::Operation* if_op =
TranslateCondIfOperation(cond_op_combination, dest_block);
pir::Operation* if_op = TranslateCondIfOperation(
cond_op_combination, translation_ctx, dest_block);
for (auto cond_id : cond_op_ids) {
translate_completed[cond_id] = true;
}
VLOG(10) << "[op translated][conditional_block]" << if_op;
} else if (op->Type() == "while") {
TranslateWhileOperation(op, dest_block);
TranslateWhileOperation(op, translation_ctx, dest_block);
} else {
if (for_cond_block && op->Type() == "assign" &&
std::count(skip_cond_assign.begin(),
Expand All @@ -341,7 +352,7 @@ void ProgramTranslator::TranslateBlock(
assign_inputs.push_back(op->Input("X")[0]);
translate_completed[op_id] = true;
} else {
TranslateGeneralOperation(op, dest_block);
TranslateGeneralOperation(op, translation_ctx, dest_block);
translate_completed[op_id] = true;
}
}
Expand All @@ -351,7 +362,7 @@ void ProgramTranslator::TranslateBlock(
if (for_cond_block) {
std::vector<pir::Value> yeild_inputs;
for (size_t id = 0; id < assign_inputs.size(); id++) {
yeild_inputs.emplace_back(param_map_[assign_inputs[id]].value);
yeild_inputs.emplace_back((*translation_ctx)[assign_inputs[id]].value);
}
pir::AttributeMap attribute_map;
auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name());
Expand All @@ -362,11 +373,13 @@ void ProgramTranslator::TranslateBlock(
}

pir::Operation* ProgramTranslator::TranslateCondIfOperation(
const ConditionBlockCombination& cond_ops, pir::Block* dest_block) {
const ConditionBlockCombination& cond_ops,
TranslationContext* translation_ctx,
pir::Block* dest_block) {
auto& type_translator = TypeTranslator::instance();
auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name());
std::vector<pir::Value> op_inputs = {
param_map_[cond_ops.CondVarName()].value};
(*translation_ctx)[cond_ops.CondVarName()].value};

// NOTE(zhangbo): Now paddle::dialect::IfOp has 0 attribute
pir::AttributeMap attribute_map;
Expand All @@ -386,8 +399,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
op_inputs, attribute_map, op_output_types, op_info, 2);

for (size_t i = 0; i < output_vardescs.size(); i++) {
param_map_.PushValue(output_vardescs[i]->Name(),
VariableDefiningInfo(operation->result(i)));
translation_ctx->PushValue(output_vardescs[i]->Name(),
VariableDefiningInfo(operation->result(i)));
}

dest_block->push_back(operation);
Expand All @@ -398,9 +411,13 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
legacy_program_->Block(cond_ops.TrueBlockId());
pir::Region& true_region = operation->region(0);
if (true_region.empty()) true_region.emplace_back();

auto* true_block_context = translation_ctx->CreateInnerContext();

TranslateBlock(true_sub_block,
0,
true_sub_block.OpSize(),
true_block_context,
true_region.front(),
true,
cond_ops.TrueBlockOutputVarNames());
Expand All @@ -412,9 +429,11 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
legacy_program_->Block(cond_ops.FalseBlockId());
pir::Region& false_region = operation->region(1);
if (false_region.empty()) false_region.emplace_back();
auto* false_block_context = translation_ctx->CreateInnerContext();
TranslateBlock(false_sub_block,
0,
false_sub_block.OpSize(),
false_block_context,
false_region.front(),
true,
cond_ops.FalseBlockOutputVarNames());
Expand All @@ -426,16 +445,18 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
return operation;
}

void ProgramTranslator::TranslateWhileOperation(const OpDesc* op,
pir::Block* dest_block) {
void ProgramTranslator::TranslateWhileOperation(
const OpDesc* op,
TranslationContext* translation_ctx,
pir::Block* dest_block) {
VLOG(8) << "=============>Start to translate while op:" << op;
auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block"));
int index = static_cast<int>(sub_block.OpSize()) - 1;
std::vector<std::pair<std::string, std::string>> loop_vars_reverse;
while (index >= 0) {
auto sub_op = sub_block.Op(index);
if (sub_op->Type() == "assign" &&
param_map_.count(sub_op->Output("Out")[0]) > 0) {
translation_ctx->count(sub_op->Output("Out")[0]) > 0) {
loop_vars_reverse.emplace_back(sub_op->Output("Out")[0],
sub_op->Input("X")[0]);
--index;
Expand All @@ -452,58 +473,60 @@ void ProgramTranslator::TranslateWhileOperation(const OpDesc* op,
"condition var"));
auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::WhileOp::name());
std::vector<pir::Value> op_inputs{
param_map_.at(loop_vars_reverse[0].first).value};
translation_ctx->at(loop_vars_reverse[0].first).value};
std::vector<pir::Type> op_outputs_type;
auto body_block = new pir::Block();
std::vector<TCValue> param_map_status;
std::vector<TCValue> param_status;
for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].first;
auto& tc_value = param_map_.at(name);
auto& tc_value = translation_ctx->at(name);
auto val_type = tc_value.value.type();
op_inputs.push_back(tc_value.value);
op_outputs_type.push_back(val_type);
param_map_status.emplace_back(tc_value);
param_map_.PushValue(name, body_block->AddArgument(val_type));
param_status.emplace_back(tc_value);
translation_ctx->PushValue(name, body_block->AddArgument(val_type));
}
pir::Operation* while_op =
pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1);
dest_block->push_back(while_op);
while_op->region(0).push_back(body_block);
TranslateBlock(sub_block, 0, index + 1, body_block);
TranslateBlock(sub_block, 0, index + 1, translation_ctx, body_block);

auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name());
std::vector<pir::Value> yeild_inputs{
param_map_.at(loop_vars_reverse[0].second).value};
translation_ctx->at(loop_vars_reverse[0].second).value};
for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].second;
yeild_inputs.push_back(param_map_.at(name).value);
yeild_inputs.push_back(translation_ctx->at(name).value);
}
body_block->push_back(
pir::Operation::Create(yeild_inputs, {}, {}, yeild_info));

index = 0;
for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].first;
param_map_.PushValue(name, param_map_status[index++]);
translation_ctx->PushValue(name, param_status[index++]);
}
auto name_iter = loop_vars_reverse.rbegin();
for (size_t idx = 0; idx < while_op->num_results(); ++idx) {
param_map_.PushValue(name_iter++->first, while_op->result(idx));
translation_ctx->PushValue(name_iter++->first, while_op->result(idx));
}
while_op->Verify();
VLOG(8) << "=============>end to translate while op:" << op;
}

void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op,
pir::Block* dest_block) {
void ProgramTranslator::TranslateGeneralOperation(
const OpDesc* src_op,
TranslationContext* translation_ctx,
pir::Block* dest_block) {
auto& op_translator = OpTranslator::instance();
OpTranslateFn& fn = op_translator[src_op->Type()];
if (src_op->Type() == "shadow_output") {
if (!param_map_.count(src_op->Input("x")[0])) {
if (!translation_ctx->count(src_op->Input("x")[0])) {
return;
}
}
pir::Operation* operation = fn(ctx_, &param_map_, *src_op, dest_block);
pir::Operation* operation = fn(ctx_, translation_ctx, *src_op, dest_block);
VLOG(10) << "[op translated][general]" << operation << "end";
}

Expand Down Expand Up @@ -592,20 +615,6 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
}
}

void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
auto& op_translator = OpTranslator::instance();
for (auto op : block.AllOps()) {
OpTranslateFn& fn = op_translator[op->Type()];
if (op->Type() == "shadow_output") {
if (!param_map_.count(op->Input("x")[0])) {
continue;
}
}
pir::Operation* operation = fn(ctx_, &param_map_, *op, program_->block());
VLOG(10) << "[op translated][special]" << operation;
}
}

void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
const auto& ops = block.AllOps();
for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) {
Expand Down
28 changes: 20 additions & 8 deletions paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
Expand Down Expand Up @@ -68,9 +69,10 @@ class TranslationContext {
using Key = std::string;
using Value = VariableDefiningInfo;
using ValueList = std::vector<Value>;
using Conatiner = std::unordered_map<Key, ValueList>;
using Container = std::unordered_map<Key, ValueList>;

TranslationContext() {}
explicit TranslationContext(TranslationContext* parent) : parent_(parent) {}
~TranslationContext() {}

const Value& operator[](const Key& key) const;
Expand All @@ -80,12 +82,16 @@ class TranslationContext {

void PushValue(const Key& key, const Value& value);
void PopValue(const Key& key);
TranslationContext* CreateInnerContext();

Conatiner::const_iterator begin() const { return container_.begin(); }
Conatiner::const_iterator end() const { return container_.end(); }
Container::const_iterator begin() const { return container_.begin(); }
Container::const_iterator end() const { return container_.end(); }

private:
Conatiner container_;
Container container_;
TranslationContext* parent_ = nullptr;
std::vector<std::unique_ptr<TranslationContext>>
sons_; // used to seperate different block
};

class ProgramTranslator {
Expand Down Expand Up @@ -124,20 +130,26 @@ class ProgramTranslator {
void TranslateBlock(const BlockDesc& src_block,
uint64_t start_id,
uint64_t end_id,
TranslationContext* translation_ctx,
pir::Block* dest_block,
bool for_cond_block = false,
std::vector<std::string> skip_cond_assign = {});
void TranslateGeneralOperation(const OpDesc* src_op, pir::Block* dest_block);
void TranslateGeneralOperation(const OpDesc* src_op,
TranslationContext* translation_ctx,
pir::Block* dest_block);
void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block);
void SetStopGradientAttributeForAllValue(const BlockDesc& block);
void SetIsPersisableAttributeForAllValue(const BlockDesc& block);

/// Translate methods for control flow ops.
pir::Operation* TranslateCondIfOperation(
const ConditionBlockCombination& cond_ops, pir::Block* dest_block);
void TranslateWhileOperation(const OpDesc* op, pir::Block* dest_block);
const ConditionBlockCombination& cond_ops,
TranslationContext* translation_ctx,
pir::Block* dest_block);
void TranslateWhileOperation(const OpDesc* op,
TranslationContext* translation_ctx,
pir::Block* dest_block);
};

} // namespace translator
Expand Down
Loading

0 comments on commit ec654d2

Please sign in to comment.