Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] New ir supprt cond part 1 #57265

Merged
merged 58 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
2895246
[IR] add control flow dialect
winter-wang Aug 30, 2023
f33b088
fix error
winter-wang Aug 30, 2023
24b512d
add error
winter-wang Aug 30, 2023
46cda49
test
winter-wang Aug 30, 2023
6b58650
fix
winter-wang Aug 30, 2023
45f1bb0
fix
winter-wang Aug 30, 2023
90c756f
fix
winter-wang Aug 30, 2023
5caccaf
fix
winter-wang Aug 30, 2023
a01232f
fix
winter-wang Aug 31, 2023
9a843ba
fix
winter-wang Aug 31, 2023
f416966
update
phlrain Sep 2, 2023
d333660
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 2, 2023
b71bf92
fix
winter-wang Sep 2, 2023
583a18d
update
phlrain Sep 4, 2023
8674952
Merge commit 'refs/pull/56799/head' of https://github.com/PaddlePaddl…
phlrain Sep 4, 2023
586e673
update
phlrain Sep 5, 2023
a73c075
refactor pd op to kernel pass
phlrain Sep 5, 2023
1f70f79
Merge commit 'refs/pull/56984/head' of https://github.com/PaddlePaddl…
phlrain Sep 5, 2023
40e6299
fix bug
phlrain Sep 6, 2023
31ce7d7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 6, 2023
adf8d00
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 6, 2023
5b45c5b
update
phlrain Sep 6, 2023
31a9a05
add code
phlrain Sep 6, 2023
aaf53a0
update
phlrain Sep 6, 2023
a08f368
Merge commit 'refs/pull/56984/head' of https://github.com/PaddlePaddl…
phlrain Sep 6, 2023
4d150e3
fix merge bug
phlrain Sep 6, 2023
01e2853
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 6, 2023
70d3f48
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 10, 2023
55852be
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 11, 2023
38f54e1
fix bug
phlrain Sep 11, 2023
f2f001a
update
phlrain Sep 11, 2023
cafad45
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 11, 2023
179fd45
update
phlrain Sep 11, 2023
79b5ff7
update
phlrain Sep 11, 2023
e9cf31f
Merge commit 'refs/pull/56984/head' of https://github.com/PaddlePaddl…
phlrain Sep 11, 2023
71aac0a
fix bug
phlrain Sep 11, 2023
ca5b16b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 11, 2023
ec9683d
update
phlrain Sep 11, 2023
a1ecdd9
update
phlrain Sep 11, 2023
d8ab9eb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 11, 2023
3d079b7
fix bug
phlrain Sep 11, 2023
81022e6
update
phlrain Sep 12, 2023
ca4b437
update
phlrain Sep 12, 2023
0ba7246
fix bug
phlrain Sep 12, 2023
d2fd876
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 12, 2023
4053e96
remove useless code
phlrain Sep 12, 2023
bc5812e
update
phlrain Sep 12, 2023
76d210a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 12, 2023
9672e89
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 13, 2023
bafe0d2
update
phlrain Sep 13, 2023
9edc708
update
phlrain Sep 13, 2023
5001c48
new ir support cond part 1
phlrain Sep 13, 2023
5c0549f
fix compile bug
phlrain Sep 13, 2023
5e620ff
fix compile bug
phlrain Sep 13, 2023
bcde14f
fix compile bug
phlrain Sep 13, 2023
11db5a8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Sep 13, 2023
91cc6a4
fix bug
phlrain Sep 14, 2023
fc1a0d9
open test case
phlrain Sep 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/fluid/framework/executor_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,12 @@ std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
std::shared_ptr<InterpreterCore> core = nullptr;

core.reset(new InterpreterCore(
place, {}, std::move(ir_program), scope, execution_config));
place, {}, ir_program->block(), scope, execution_config));

auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, scope, is_grad);
cached_value.core_ = core;
cached_value.ir_prog_ = std::move(ir_program);
return core;
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/executor_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class InterpreterCoreInfo {
struct CacheValue {
std::shared_ptr<InterpreterCore> core_{nullptr};
std::set<std::string> skip_eager_delete_vars_;
std::unique_ptr<::pir::Program> ir_prog_{nullptr};
};

bool IsAvailable(bool is_grad) {
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/new_executor/interpreter/plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ std::shared_ptr<::pir::Program> Plan::IrProgram(
return type_to_ir_program_.at(job_type);
}

void Plan::UpdateIrProgram(const std::string& job_type,
std::shared_ptr<::pir::Program> ir_prog) {
type_to_ir_program_[job_type] = ir_prog;
}

int64_t Plan::MicroBatchNum() const { return micro_batch_num_; }

} // namespace interpreter
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/new_executor/interpreter/plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ class Plan final {
const ProgramDesc* Program(const std::string& job_type) const;
std::shared_ptr<::pir::Program> IrProgram(const std::string& job_type) const;

void UpdateIrProgram(const std::string& job_type,
std::shared_ptr<::pir::Program> ir_prog);

int64_t MicroBatchNum() const;

private:
const std::vector<std::shared_ptr<Job>> job_list_;
const std::unordered_map<std::string, ProgramDesc*> type_to_program_;
const std::unordered_map<std::string, std::shared_ptr<::pir::Program>>
std::unordered_map<std::string, std::shared_ptr<::pir::Program>>
type_to_ir_program_;
int64_t micro_batch_num_;
};
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
InterpreterCore::InterpreterCore(
const platform::Place& place,
const std::vector<std::string>& fetch_var_names,
std::unique_ptr<::pir::Program> ir_prog,
const ::pir::Block* ir_block,
framework::Scope* scope,
const ExecutionConfig& execution_config) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place;
impl_ = std::make_unique<NewIRInterpreter>(
place, fetch_var_names, std::move(ir_prog), scope, execution_config);
place, fetch_var_names, ir_block, scope, execution_config);
}

InterpreterCore::~InterpreterCore() {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
PD_DECLARE_bool(new_executor_use_local_scope);

namespace pir {
class Program;
class Block;
} // namespace pir

namespace paddle {
Expand All @@ -38,7 +38,7 @@ class InterpreterCore {
// This constructor is for New IR.
InterpreterCore(const platform::Place& place,
const std::vector<std::string>& fetch_var_names,
std::unique_ptr<::pir::Program> ir_prog,
const ::pir::Block* ir_prog,
Scope* scope,
const ExecutionConfig& execution_config = ExecutionConfig());
~InterpreterCore();
Expand Down
21 changes: 13 additions & 8 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ namespace framework {
NewIRInterpreter::NewIRInterpreter(
const platform::Place& place,
const std::vector<std::string>& fetch_var_names,
std::unique_ptr<::pir::Program> ir_prog,
const ::pir::Block* ir_block,
framework::Scope* scope,
const ExecutionConfig& execution_config)
: place_(place),
execution_config_(execution_config),
var_scope_(scope),
scope_(scope),
ir_program_(std::move(ir_prog)),
ir_block_(ir_block),
ir_stream_analyzer_(place),
fetch_var_names_(fetch_var_names) {
VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_;
Expand All @@ -92,8 +92,7 @@ NewIRInterpreter::NewIRInterpreter(
// TODO(zhangbo): delete var_scope
var_scope_.SetLocalScope(local_scope_);

execution_config_.AnalyzeThreadPoolConfig(place,
ir_program_->block()->size());
execution_config_.AnalyzeThreadPoolConfig(place, 1);
execution_config_.Log(/*log_level=*/8);

ir_instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
Expand Down Expand Up @@ -334,6 +333,10 @@ Scope* NewIRInterpreter::InnerScope() {
return local_scope_ != nullptr ? local_scope_ : scope_;
}

std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const {
return value_2_var_name_.at(value);
}

void NewIRInterpreter::UpdateSyncOpNum() {
int64_t sync_op_num = 0;
for (auto& ins : vec_instruction_base_) {
Expand Down Expand Up @@ -501,13 +504,15 @@ void NewIRInterpreter::BuildInstruction() {
VLOG(6) << "Build Instructions for new ir ... ";
vec_instruction_base_.clear();
size_t op_idx = 0;
for (auto& op : *ir_program_->block()) {
for (auto& op : *ir_block_) {
VLOG(6) << "Build Instruction for op: " << op_idx;
if (op->dialect()->name() == "builtin") {
if (interpreter::GetSpecialOpNames().count(op->name())) {
VLOG(6) << "skip process " << op->name();
continue;
}
} else if (op->dialect()->name() == "cf") {
continue;
} else if (op->dialect()->name() == "pd_kernel") {
auto op_name = op->attributes()
.at("op_name")
Expand Down Expand Up @@ -919,13 +924,14 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
// Build
std::stringstream ss;
ss << this;
::pir::BuildScope(*ir_program_->block(),
::pir::BuildScope(*ir_block_,
InnerScope(),
ss.str(),
&value_2_var_name_,
&variable_2_var_name_,
&var_name_2_id_,
&variable_list_);
&variable_list_,
&sub_blocks_);

interpreter::BuildId2VarName(var_name_2_id_, &id_2_var_name_);

Expand Down Expand Up @@ -973,7 +979,6 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
if (HasLocalScope()) {
ClearLoDTensorArrayInLocalScope();
}

// return Fetch Tensors
Scope* inner_scope = InnerScope();
if (FLAGS_enable_new_ir_in_executor) {
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "paddle/pir/core/value.h"

namespace ir {
class Program;
class Block;
} // namespace ir

namespace paddle {
Expand All @@ -36,7 +36,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
public:
NewIRInterpreter(const platform::Place& place,
const std::vector<std::string>& fetch_var_names,
std::unique_ptr<::pir::Program> ir_prog,
const ::pir::Block* ir_block,
Scope* scope,
const ExecutionConfig& execution_config = ExecutionConfig());

Expand Down Expand Up @@ -81,6 +81,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {

int GetIdByName(const std::string& name) const;

std::string GetNameByValue(::pir::Value value) const;

private:
// build graph
void UpdateSyncOpNum();
Expand Down Expand Up @@ -198,7 +200,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {

InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;

std::unique_ptr<::pir::Program> ir_program_{nullptr};
const ::pir::Block* ir_block_{nullptr};

std::vector<std::unique_ptr<InstructionBase>> vec_instruction_base_;

Expand All @@ -211,6 +213,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
std::unordered_map<int, std::string> id_2_var_name_;

std::vector<Variable*> variable_list_;
std::map<pir::Block*, paddle::framework::Scope*> sub_blocks_;

std::vector<int> var_ref_count_;

Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/new_executor/standalone_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,19 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place);
std::shared_ptr<pir::Program> shared_program = std::move(kernel_program);
plan_.UpdateIrProgram("base", shared_program);

if (FLAGS_new_ir_apply_inplace_pass) {
pir::PassManager pm(pir::IrContext::Instance(), 3);
pm.AddPass(pir::CreateInplacePass());
pm.Run(kernel_program.get());
pm.Run(shared_program.get());
}

interpretercores_.emplace_back(
std::make_shared<InterpreterCore>(place_,
fetch_var_names_,
std::move(kernel_program),
shared_program->block(),
scope_,
execution_config));
} else {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/standalone_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class StandaloneExecutor {
private:
bool is_interpretercore_build_result_shared_{false};
const platform::Place place_;
const interpreter::Plan plan_;
interpreter::Plan plan_;

std::vector<framework::Scope*> micro_batch_scopes_;
std::vector<std::shared_ptr<InterpreterCore>> interpretercores_;
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class PhiKernelAdaptor {
variable_2_var_name;
std::map<std::string, int> var_name_2_id;
std::vector<paddle::framework::Variable*> variable_list;
std::map<pir::Block*, paddle::framework::Scope*> sub_blocks;
std::stringstream ss;
ss << this;

Expand All @@ -69,7 +70,8 @@ class PhiKernelAdaptor {
&value_2_var_name,
&variable_2_var_name,
&var_name_2_id,
&variable_list);
&variable_list,
&sub_blocks);
pir::IrContext* ctx = pir::IrContext::Instance();

ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
Expand Down
59 changes: 48 additions & 11 deletions paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,21 @@
#include "glog/logging.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"

namespace pir {

const std::unordered_set<std::string> SpecialOps = {"pd_op.feed",
"pd_op.fetch",
"builtin.combine",
"builtin.set_parameter",
"builtin.get_parameter",
"builtin.slice",
"builtin.split",
"pd_op.data",
"pd_op.shadow_output",
"pd_op.if"};

void AddNewData(pir::Value value,
std::string name,
paddle::framework::Variable* var,
Expand Down Expand Up @@ -226,7 +238,8 @@ void HandleForSpecialOp(
std::unordered_map<const paddle::framework::Variable*, std::string>*
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
std::vector<paddle::framework::Variable*>* variable_list,
std::map<pir::Block*, paddle::framework::Scope*>* sub_blocks) {
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
Expand Down Expand Up @@ -425,6 +438,33 @@ void HandleForSpecialOp(
value_2_var_name->emplace(out_value, var_name);
}
}

if (op_name == "pd_op.if") {
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();

auto true_block = if_op.true_block();

auto false_block = if_op.false_block();

auto& true_branch_scope = inner_scope->NewScope();
sub_blocks->emplace(true_block, &true_branch_scope);

auto& false_branch_scope = inner_scope->NewScope();
sub_blocks->emplace(false_block, &false_branch_scope);

for (size_t i = 0; i < if_op->num_results(); ++i) {
// auto true_value = true_yeid_op->operand_source(i);

auto if_op_out_value = if_op->result(i);
BuildValue(if_op_out_value,
inner_scope,
var_name_prefix,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
}
}

void HandleForInplaceOp(
Expand Down Expand Up @@ -483,16 +523,17 @@ void HandleForInplaceOp(
}
}

// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other is
// created in inner_scope.
// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other
// is created in inner_scope.
void BuildScope(const pir::Block& block,
paddle::framework::Scope* inner_scope,
const std::string& var_name_prefix,
std::unordered_map<pir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
std::vector<paddle::framework::Variable*>* variable_list,
std::map<pir::Block*, paddle::framework::Scope*>* sub_blocks) {
VLOG(4) << "***** [before build] scope"
<< "(" << inner_scope << ") ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
Expand All @@ -507,19 +548,15 @@ void BuildScope(const pir::Block& block,
.AsString();
}
VLOG(4) << "build op:" << op_name;

if (op_name == "pd_op.feed" || op_name == "pd_op.fetch" ||
op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "builtin.split" || op_name == "pd_op.data" ||
op_name == "pd_op.shadow_output") {
if (SpecialOps.count(op_name)) {
HandleForSpecialOp(op,
inner_scope,
var_name_prefix,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
variable_list,
sub_blocks);
continue;
}

Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ void BuildScope(const pir::Block& block,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list);
std::vector<paddle::framework::Variable*>* variable_list,
std::map<pir::Block*, paddle::framework::Scope*>* sub_blocks);

void BuildRuntimeContext(
pir::Operation* op,
Expand Down
13 changes: 7 additions & 6 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,13 @@ class ConstantFoldingPattern : public pir::RewritePattern {

// Execute program
exe_config_.create_local_scope = false;
paddle::framework::InterpreterCore core(
phi::CPUPlace{},
fetch_var_names,
paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
&scope_,
exe_config_);
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(temp_program.get());
paddle::framework::InterpreterCore core(phi::CPUPlace{},
fetch_var_names,
kernel_program->block(),
&scope_,
exe_config_);

paddle::framework::FetchList fetch_list = core.Run({});

Expand Down
Loading