[PIR] Support while op translation. (PaddlePaddle#58098)
winter-wang authored and jiahy0825 committed Oct 26, 2023
1 parent 6c051f8 commit 7386e26
Showing 6 changed files with 184 additions and 25 deletions.
20 changes: 10 additions & 10 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -677,7 +677,7 @@ void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx,
pir::OpResult value = operation->result(idx_in_op);
bool generated_by_vector = value.type().isa<pir::VectorType>();

param_map->insert(
param_map->UpdateValue(
arg_name,
VariableDefiningInfo(
value,
@@ -1439,7 +1439,7 @@ pir::OpResult TranslateNumClassesForOneHot(
"%s should be existed in one_hot_v2 as input depth_tensor.",
legacy_vars[0]);
auto defining_info = param_map->at(legacy_vars[0]);
return defining_info.value;
return defining_info.value.dyn_cast<pir::OpResult>();
}

auto& attribute_translator = AttributeTranslator::instance();
@@ -1529,7 +1529,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
ctx, param_map, block, x_defining_info, x_name);
x_defining_info = param_map->at(x_name);
}
pir::OpResult x_value = x_defining_info.value;
pir::OpResult x_value = x_defining_info.value.dyn_cast<pir::OpResult>();
IR_ENFORCE(x_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
@@ -1560,7 +1560,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
ctx, param_map, block, y_defining_info, y_name);
y_defining_info = param_map->at(y_name);
}
pir::OpResult y_value = y_defining_info.value;
pir::OpResult y_value = y_defining_info.value.dyn_cast<pir::OpResult>();
IR_ENFORCE(y_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
@@ -1680,7 +1680,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
op_desc.Type(),
y_name);
auto y_defining_info = param_map->at(y_name);
pir::OpResult y_value = y_defining_info.value;
pir::OpResult y_value = y_defining_info.value.dyn_cast<pir::OpResult>();
IR_ENFORCE(y_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
@@ -1698,8 +1698,8 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
pir::OpResult value = operation->result(idx_in_op);
pir::Builder builder(ctx, operation->GetParent());
auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape);
param_map->insert(y_grad_var_name,
VariableDefiningInfo(reshape_op.out(), false, -1));
param_map->UpdateValue(y_grad_var_name,
VariableDefiningInfo(reshape_op.out(), false, -1));
}
};

@@ -1771,7 +1771,7 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber {
ctx, param_map, block, defining_info, var_name);
defining_info = param_map->at(var_name).value;
}
return defining_info.value;
return defining_info.value.dyn_cast<pir::OpResult>();
};
}
};
@@ -1866,8 +1866,8 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber {
auto output_var = output_vars[0];
auto fused_feedforward_op =
operation->dyn_cast<dialect::FusedFeedforwardOp>();
param_map->insert(output_var,
VariableDefiningInfo{fused_feedforward_op.out()});
param_map->UpdateValue(output_var,
VariableDefiningInfo{fused_feedforward_op.out()});
}
}
};
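The recurring change in this file pairs with the header change further down: VariableDefiningInfo::value is widened from pir::OpResult to pir::Value, so call sites that still need an operation-produced result recover it with dyn_cast<pir::OpResult>() and guard against a null handle, since a value may now instead be a block argument of a while body. Below is a minimal standalone sketch of that handle/downcast idiom; ValueImpl and DynCastToOpResult are made-up stand-in types for illustration, not the real pir API.

#include <cassert>
#include <memory>

// Simplified stand-ins for pir::Value / pir::OpResult (illustration only).
struct ValueImpl {
  bool is_op_result;
};

struct Value {
  std::shared_ptr<ValueImpl> impl;
  explicit operator bool() const { return impl != nullptr; }
};

struct OpResult : Value {};

// Mirrors the dyn_cast<pir::OpResult>() calls above: yields a null handle
// when the value is not produced by an operation (e.g. a block argument).
OpResult DynCastToOpResult(const Value& v) {
  OpResult r;
  if (v.impl && v.impl->is_op_result) {
    r.impl = v.impl;
  }
  return r;
}

int main() {
  Value op_result_value{std::make_shared<ValueImpl>(ValueImpl{true})};
  Value block_arg_value{std::make_shared<ValueImpl>(ValueImpl{false})};

  assert(DynCastToOpResult(op_result_value));   // usable as an OpResult
  assert(!DynCastToOpResult(block_arg_value));  // null handle, must be guarded
  return 0;
}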
101 changes: 90 additions & 11 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
@@ -55,7 +55,6 @@ const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {

const std::unordered_set<std::string> ProgramTranslator::unsupported_ops = {
"conditional_block_grad",
"while",
"while_grad",
};

@@ -255,9 +254,20 @@ size_t TranslationContext::count(const TCKey& key) const {
return values.size();
}

void TranslationContext::insert(const TCKey& key, const TCValue& value) {
void TranslationContext::PushValue(const Key& key, const Value& value) {
container_[key].push_back(value);
}
void TranslationContext::PopValue(const Key& key) {
container_[key].pop_back();
}

void TranslationContext::UpdateValue(const Key& key, const Value& value) {
auto& vec = container_[key];
if (vec.empty())
vec.push_back(value);
else
vec.back() = value;
}
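PushValue, PopValue, and UpdateValue give TranslationContext a stack of definitions per variable name: a nested block (such as a while body) can push a shadowing definition and pop it when translation of that block finishes, while UpdateValue rewrites only the innermost definition. A minimal standalone sketch of that discipline follows; ScopedNameMap is a made-up class with int placeholders in place of VariableDefiningInfo, not the real Paddle container.

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified illustration of the stack discipline behind PushValue /
// PopValue / UpdateValue (not the real TranslationContext).
class ScopedNameMap {
 public:
  void PushValue(const std::string& key, int value) {
    container_[key].push_back(value);
  }
  void PopValue(const std::string& key) { container_[key].pop_back(); }
  void UpdateValue(const std::string& key, int value) {
    auto& vec = container_[key];
    if (vec.empty()) {
      vec.push_back(value);
    } else {
      vec.back() = value;  // overwrite only the innermost definition
    }
  }
  int at(const std::string& key) const { return container_.at(key).back(); }

 private:
  std::unordered_map<std::string, std::vector<int>> container_;
};

int main() {
  ScopedNameMap map;
  map.UpdateValue("x", 1);   // outer definition of "x"
  map.PushValue("x", 2);     // shadow it inside a sub-block (e.g. while body)
  assert(map.at("x") == 2);
  map.PopValue("x");         // leave the sub-block
  assert(map.at("x") == 1);  // outer definition visible again
  return 0;
}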

ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
pir::Program* program)
@@ -329,6 +339,8 @@ void ProgramTranslator::TranslateBlock(
translate_completed[cond_id] = true;
}
VLOG(10) << "[op translated][conditional_block]" << if_op;
} else if (op->Type() == "while") {
TranslateWhileOperation(op, dest_block);
} else {
if (for_cond_block && op->Type() == "assign" &&
std::count(skip_cond_assign.begin(),
@@ -382,8 +394,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
op_inputs, attribute_map, op_output_types, op_info, 2);

for (size_t i = 0; i < output_vardescs.size(); i++) {
param_map_.insert(output_vardescs[i]->Name(),
VariableDefiningInfo(operation->result(i)));
param_map_.PushValue(output_vardescs[i]->Name(),
VariableDefiningInfo(operation->result(i)));
}

dest_block->push_back(operation);
@@ -422,6 +434,71 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
return operation;
}

void ProgramTranslator::TranslateWhileOperation(const OpDesc* op,
pir::Block* dest_block) {
VLOG(8) << "=============>Start to translate while op:" << op;
auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block"));
int index = static_cast<int>(sub_block.OpSize()) - 1;
std::vector<std::pair<std::string, std::string>> loop_vars_reverse;
while (index >= 0) {
auto sub_op = sub_block.Op(index);
if (sub_op->Type() == "assign" &&
param_map_.count(sub_op->Output("Out")[0]) > 0) {
loop_vars_reverse.emplace_back(sub_op->Output("Out")[0],
sub_op->Input("X")[0]);
--index;
} else {
break;
}
}
PADDLE_ENFORCE(!loop_vars_reverse.empty(),
platform::errors::PreconditionNotMet(
"While op must has condition value input"));
PADDLE_ENFORCE(loop_vars_reverse.front().first == op->Input("Condition")[0],
platform::errors::PreconditionNotMet(
"The last op in sub_block of While op must used to assign "
"condition var"));
auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::WhileOp::name());
std::vector<pir::Value> op_inputs{
param_map_.at(loop_vars_reverse[0].first).value};
std::vector<pir::Type> op_outputs_type;
auto body_block = new pir::Block();
for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].first;
auto val = param_map_.at(name).value;
auto val_type = val.type();
op_inputs.push_back(val);
op_outputs_type.push_back(val_type);
param_map_.PushValue(name, body_block->AddArgument(val_type));
}
pir::Operation* while_op =
pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1);
dest_block->push_back(while_op);
while_op->region(0).push_back(body_block);
TranslateBlock(sub_block, 0, index + 1, body_block);

auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name());
std::vector<pir::Value> yeild_inputs{
param_map_.at(loop_vars_reverse[0].second).value};
for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].second;
yeild_inputs.push_back(param_map_.at(name).value);
}
body_block->push_back(
pir::Operation::Create(yeild_inputs, {}, {}, yeild_info));

for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) {
auto& name = loop_vars_reverse[idx].first;
param_map_.PopValue(name);
}
auto name_iter = loop_vars_reverse.rbegin();
for (size_t idx = 0; idx < while_op->num_results(); ++idx) {
param_map_.UpdateValue(name_iter++->first, while_op->result(idx));
}
while_op->Verify();
VLOG(8) << "=============>end to translate while op:" << op;
}
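TranslateWhileOperation discovers the loop-carried variables by walking the legacy sub_block backwards and peeling off trailing assign ops whose outputs are already known in the enclosing scope; each (Out, X) pair becomes a loop variable, and the first pair recovered is the condition. A standalone sketch of that reverse pairing scan follows; LegacyOp and CollectLoopVarsReverse are toy stand-ins for OpDesc and the in-function loop above, not the Paddle API.

#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Simplified legacy op: only what the scan needs (illustration only).
struct LegacyOp {
  std::string type;
  std::string input_x;     // "X" of an assign op
  std::string output_out;  // "Out" of an assign op
};

// Mirrors the reverse scan in TranslateWhileOperation: collect trailing
// assign ops whose outputs are visible in the enclosing scope.
std::vector<std::pair<std::string, std::string>> CollectLoopVarsReverse(
    const std::vector<LegacyOp>& sub_block,
    const std::set<std::string>& outer_scope_names) {
  std::vector<std::pair<std::string, std::string>> loop_vars_reverse;
  for (int i = static_cast<int>(sub_block.size()) - 1; i >= 0; --i) {
    const LegacyOp& op = sub_block[i];
    if (op.type == "assign" && outer_scope_names.count(op.output_out) > 0) {
      loop_vars_reverse.emplace_back(op.output_out, op.input_x);
    } else {
      break;  // stop at the first non-assign (or unknown-output) op
    }
  }
  return loop_vars_reverse;
}

int main() {
  std::vector<LegacyOp> body = {
      {"scale", "i", "i_next"},
      {"less_than", "i_next", "cond_next"},
      {"assign", "i_next", "i"},        // loop-carried variable
      {"assign", "cond_next", "cond"},  // condition, assigned last
  };
  auto vars = CollectLoopVarsReverse(body, {"i", "cond"});
  assert(vars.size() == 2);
  assert(vars.front().first == "cond");  // front() pair is the condition
  assert(vars.back().first == "i");
  return 0;
}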

void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op,
pir::Block* dest_block) {
auto& op_translator = OpTranslator::instance();
@@ -432,7 +509,7 @@ void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op,
}
}
pir::Operation* operation = fn(ctx_, &param_map_, *src_op, dest_block);
VLOG(10) << "[op translated][special]" << operation << "end";
VLOG(10) << "[op translated][general]" << operation << "end";
}

inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx,
@@ -451,7 +528,7 @@ inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx,
}

inline pir::Operation* InsertSetParamaterOp(pir::IrContext* ctx,
pir::OpResult defining_op_result,
pir::Value defining_op_result,
const VarDesc* var) {
std::string set_parameter_op_name(pir::SetParameterOp::name());
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name);
@@ -502,7 +579,7 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
"VarDesc of [%s] can not be nullptr", var_name));
pir::Operation* op = InsertGetParamaterOp(ctx_, var_desc);
program_->block()->push_back(op);
param_map_.insert(var_name, VariableDefiningInfo(op->result(0)));
param_map_.PushValue(var_name, VariableDefiningInfo(op->result(0)));
VLOG(10) << "[op translated][get parameter]" << var_name;

program_->SetParameter(var_name, nullptr);
@@ -554,7 +631,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
need_set_parameter_op &= (param_map_.count(var_name) != 0);
need_set_parameter_op &= (!set_input_var_names.count(var_name));
if (need_set_parameter_op) {
pir::OpResult defining_op_result = param_map_[var_name].value;
pir::OpResult defining_op_result =
param_map_[var_name].value.dyn_cast<pir::OpResult>();
if (!defining_op_result) {
continue;
}
@@ -565,7 +643,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
program_->block(),
param_map_[var_name],
var_name);
defining_op_result = param_map_.at(var_name).value;
defining_op_result =
param_map_.at(var_name).value.dyn_cast<pir::OpResult>();
}

pir::Operation* op = InsertSetParamaterOp(
@@ -604,7 +683,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
continue;
}
for (const auto& value_info : value_list) {
pir::OpResult value = value_info.value;
pir::OpResult value = value_info.value.dyn_cast<pir::OpResult>();
if (!value) {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Value of [%s] can not ber None", var_name));
@@ -645,7 +724,7 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue(
continue;
}
for (const auto& value_info : value_list) {
pir::OpResult value = value_info.value;
pir::OpResult value = value_info.value.dyn_cast<pir::OpResult>();
if (!value) {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Value of [%s] can not ber None", var_name));
9 changes: 6 additions & 3 deletions paddle/fluid/ir_adaptor/translator/program_translator.h
@@ -29,15 +29,15 @@ namespace paddle {
namespace translator {

struct VariableDefiningInfo {
VariableDefiningInfo(pir::OpResult value,
VariableDefiningInfo(pir::Value value,
bool generated_by_vector = false,
int idx_in_vector = -1)
: value(value),
generated_by_vector(generated_by_vector),
idx_in_vector(idx_in_vector) {}
VariableDefiningInfo() {}

pir::OpResult value;
pir::Value value;

bool generated_by_vector =
false; // true if target variable is generated by Vector<Tensor>
@@ -78,7 +78,9 @@ class TranslationContext {
size_t count(const Key& key)
const; // Caution: not exactly same as count in stl library

void insert(const Key& key, const Value& value);
void UpdateValue(const Key& key, const Value& value);
void PushValue(const Key& key, const Value& value);
void PopValue(const Key& key);

Conatiner::const_iterator begin() const { return container_.begin(); }
Conatiner::const_iterator end() const { return container_.end(); }
@@ -136,6 +138,7 @@ class ProgramTranslator {
/// Translate methods for control flow ops.
pir::Operation* TranslateCondIfOperation(
const ConditionBlockCombination& cond_ops, pir::Block* dest_block);
void TranslateWhileOperation(const OpDesc* op, pir::Block* dest_block);
};

} // namespace translator
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/utils.cc
@@ -59,7 +59,7 @@ pir::Operation* InsertSliceOperationForTarget(
op_info);
block->push_back(operation);
pir::OpResult target_op_result = operation->result(0);
param_map->insert(arg_name, VariableDefiningInfo(target_op_result));
param_map->UpdateValue(arg_name, VariableDefiningInfo(target_op_result));
return operation;
}

5 changes: 5 additions & 0 deletions test/cpp/pir/core/CMakeLists.txt
@@ -65,6 +65,11 @@ file(
${CMAKE_CURRENT_BINARY_DIR}/conditional_block_test.prog
EXPECTED_MD5 cf9dc869ca7f69e2d57b38dbf8427134)

file(
DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/while_op_test.prog
${CMAKE_CURRENT_BINARY_DIR}/while_op_test.prog
EXPECTED_MD5 290164ae52a496332b0be5829fc93bcd)

copy_if_different(${CMAKE_CURRENT_SOURCE_DIR}/TestParserText.txt
${CMAKE_CURRENT_BINARY_DIR}/TestParserText.txt)

72 changes: 72 additions & 0 deletions test/cpp/pir/core/program_translator_test.cc
@@ -266,3 +266,75 @@ TEST(IrParserTest, StartupProgram) {

EXPECT_TRUE(ssp.str() == ss.str());
}

TEST(OperatorDialectTest, WhileOpProgram) {
auto p = load_from_file("while_op_test.prog");
EXPECT_EQ(p.Size(), 3u);
pir::IrContext *ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<OperatorDialect>();
ctx->GetOrRegisterDialect<pir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p);

std::stringstream ss;
program->Print(ss);

LOG(INFO) << ss.str();

EXPECT_EQ(program->block()->size(), 4u);
size_t id = 0;
for (auto &op : *program->block()) {
if (id == 0 || id == 1) {
EXPECT_TRUE(op->isa<paddle::dialect::FullOp>());
}
if (id == 2) {
EXPECT_TRUE(op->isa<paddle::dialect::LessThanOp>());
}
if (id == 3) {
EXPECT_TRUE(op->isa<paddle::dialect::WhileOp>());
EXPECT_EQ(op->num_regions(), 1u);
// body block
pir::Block *body_block =
op->dyn_cast<paddle::dialect::WhileOp>().body_block();
size_t body_id = 0;
for (auto &op1 : *body_block) {
if (body_id == 0) {
EXPECT_TRUE(op1->isa<paddle::dialect::FullOp>());
}
if (body_id == 1) {
EXPECT_TRUE(op1->isa<paddle::dialect::ScaleOp>());
}
if (body_id == 2) {
EXPECT_TRUE(op1->isa<paddle::dialect::LessThanOp>());
}
if (body_id == 3) {
pir::Block *body_body_block =
op1->dyn_cast<paddle::dialect::WhileOp>().body_block();
size_t body_body_id = 0;
for (auto &op2 : *body_body_block) {
if (body_body_id == 0) {
EXPECT_TRUE(op2->isa<paddle::dialect::FullOp>());
}
if (body_body_id == 1) {
EXPECT_TRUE(op2->isa<paddle::dialect::ScaleOp>());
}
if (body_body_id == 2) {
EXPECT_TRUE(op2->isa<paddle::dialect::LessThanOp>());
}
if (body_body_id == 3) {
EXPECT_TRUE(op2->isa<pir::YieldOp>());
}
body_body_id++;
}
}
if (body_id == 4) {
EXPECT_TRUE(op1->isa<paddle::dialect::LessThanOp>());
}
if (body_id == 5) {
EXPECT_TRUE(op1->isa<pir::YieldOp>());
}
body_id++;
}
}
id++;
}
}
