Skip to content

Commit

Permalink
[pir/dy2static] support infer mode in pir dy2static. (PaddlePaddle#58588
Browse files Browse the repository at this point in the history
)

* [pir/dy2static] support infer mode in pir dy2static. Adding RunableProgram interface to simplify partial program layer.

* fix
  • Loading branch information
2742195759 authored and zeroRains committed Nov 8, 2023
1 parent 1ff3d63 commit 29ef87e
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 487 deletions.
12 changes: 8 additions & 4 deletions paddle/fluid/eager/to_static/run_program_op_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,15 +456,19 @@ inline void NewIRRunProgramAPI(

auto *forward_program =
forward_global_block->GetParentOp()->GetParentProgram();
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();

if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "ForwardProgram is :\n";
forward_program->Print(print_stream);
print_stream << "BackwardProgram is:\n";
backward_program->Print(print_stream);
if (!is_test) {
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
print_stream << "BackwardProgram is:\n";
backward_program->Print(print_stream);
} else {
print_stream << "BackwardProgram is empty in test mode.\n";
}
std::cout << "Program (fwd | bwd): \n" << print_stream.str() << std::endl;
}

Expand Down
44 changes: 27 additions & 17 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -835,20 +835,6 @@ Operation *BuildOpFrom(
return cloned_op;
}

std::shared_ptr<Program> ProgramClone(const Program &program) {
// Limitation of this function:
// 1. don't support Parameters.
// 2. don't support Regions in operator.
pir::IrContext *ctx = pir::IrContext::Instance();
auto cloned_program = std::make_shared<Program>(ctx);
std::unordered_map<pir::Value, pir::Value> value_map;
for (auto &op : *program.block()) {
auto *cloned_op = BuildOpFrom(op, value_map);
cloned_program->block()->push_back(cloned_op);
}
return cloned_program;
}

std::list<Operation *>::const_iterator list_offset(const Block *block,
int start_idx) {
auto it = block->begin();
Expand Down Expand Up @@ -964,7 +950,31 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
no_need_buffer_values.end());
}

SplitedResult ForwardBackwardSplit(
using OpResultMap = std::unordered_map<pir::OpResult, pir::OpResult>;
std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
const Program &program,
const std::vector<pir::OpResult> &op_result_forward_inputs,
const std::vector<pir::OpResult> &op_result_forward_params,
const std::vector<pir::OpResult> &op_result_forward_outputs) {
// Limitation of this function:
// 1. don't support Parameters.
// 2. don't support Regions in operator.
pir::IrContext *ctx = pir::IrContext::Instance();
auto cloned_program = std::make_shared<Program>(ctx);
std::unordered_map<pir::Value, pir::Value> value_map;
for (auto &op : *program.block()) {
auto *cloned_op = BuildOpFrom(op, value_map);
cloned_program->block()->push_back(cloned_op);
}
std::unordered_map<pir::OpResult, pir::OpResult> op_result_map;
for (auto &pair : value_map) {
op_result_map[pair.first.dyn_cast<pir::OpResult>()] =
pair.second.dyn_cast<pir::OpResult>();
}
return std::make_pair(cloned_program, op_result_map);
}

SplitedResult SplitForwardBackward(
const Program &program,
const std::vector<pir::OpResult> &op_result_forward_inputs,
const std::vector<pir::OpResult> &op_result_forward_params,
Expand Down Expand Up @@ -1201,8 +1211,8 @@ SplitedResult ForwardBackwardSplit(
}

void BindUtils(pybind11::module *m) {
m->def("program_clone", ProgramClone);
m->def("program_split", ForwardBackwardSplit);
m->def("clone_program", CloneProgram);
m->def("split_program", SplitForwardBackward);
m->def("fake_op_result", FakeOpResult);
m->def("is_fake_op_result", IsFakeOpResult);
m->def("set_global_program",
Expand Down
10 changes: 10 additions & 0 deletions paddle/pir/core/op_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,18 @@ class IR_API OpResult : public Value {
OpResult(detail::OpResultImpl *impl); // NOLINT
// Access classof annd dyn_cast_from.
friend Value;
friend struct std::hash<OpResult>;
static bool classof(Value value);
static OpResult dyn_cast_from(Value value);
};

} // namespace pir

namespace std {
template <>
struct hash<pir::OpResult> {
std::size_t operator()(const pir::OpResult &obj) const {
return std::hash<pir::Value>()(obj);
}
};
} // namespace std
Loading

0 comments on commit 29ef87e

Please sign in to comment.