[Dy2St] Unify PT flags in dy2st and run PT in AST #60410

Merged · 4 commits · Dec 28, 2023
paddle/fluid/eager/to_static/run_program_op_node.h (81 additions, 39 deletions)
@@ -488,8 +488,11 @@ inline void PirRunProgramAPI(
       paddle::framework::InterpreterCoreInfoCache::Instance();
   std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
       nullptr;
-  if (!interpretercore_info_cache.Has(
-          program_id, global_inner_scope, place_hash_key, /*is_grad=*/false)) {
+  if (!interpretercore_info_cache.Has(program_id,
+                                      global_inner_scope,
+                                      place_hash_key,
+                                      /*is_grad=*/false,
+                                      /*in_pir_mode=*/true)) {
     paddle::platform::RecordEvent record_event(
         "create_new_interpretercore",
         paddle::platform::TracerEventType::UserDefined,
@@ -555,8 +558,12 @@ inline void PirRunProgramAPI(
         1);
     VLOG(2) << "Get interpretercore cache by program:" << program_id;
     // Step 1. get cache interpretercore
-    auto &cached_value = interpretercore_info_cache.GetMutable(
-        program_id, global_inner_scope, place_hash_key, /*is_grad=*/false);
+    auto &cached_value =
+        interpretercore_info_cache.GetMutable(program_id,
+                                              global_inner_scope,
+                                              place_hash_key,
+                                              /*is_grad=*/false,
+                                              /*in_pir_mode=*/true);
     interpreter_core = cached_value.core_;
     // Step 2. update scope for cache interpretercore
     details::ShareTensorsIntoScopeByValue(
@@ -631,6 +638,12 @@ inline void RunProgramAPI(
   int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id"));
   auto place = egr::Controller::Instance().GetExpectedPlace();

+  bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st;
+  if (attrs.count("in_pir_pt_mode")) {
+    in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode"));
+  }
+  in_pir_pt_mode = in_pir_pt_mode || FLAGS_enable_pir_in_executor;
+
   // NOTE(chenweihang): In order not to add new variable type, use vector
   // here. Originally, here can use scope directly.
   auto *out_scope_vec = &step_scope;
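
This hunk is the core of the change: the PT-mode decision is made once, at the top of RunProgramAPI, instead of being re-derived inside the cache-miss branch (see the removal at -702 below). A minimal sketch of the resulting precedence, using the flag and attribute names from the hunk above; the standalone helper is hypothetical, for illustration only:

    // Precedence: the global dy2st flag supplies the default, a per-program
    // "in_pir_pt_mode" attribute overrides it, and FLAGS_enable_pir_in_executor
    // forces PT mode on regardless of both.
    bool ResolveInPirPtMode(const paddle::framework::AttributeMap& attrs) {
      bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st;
      if (attrs.count("in_pir_pt_mode")) {
        in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode"));
      }
      return in_pir_pt_mode || FLAGS_enable_pir_in_executor;
    }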
@@ -688,8 +701,11 @@ inline void RunProgramAPI(
       paddle::framework::InterpreterCoreInfoCache::Instance();
   std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
       nullptr;
-  if (!interpretercore_info_cache.Has(
-          program_id, global_inner_scope, place_hash_key, /*is_grad=*/false)) {
+  if (!interpretercore_info_cache.Has(program_id,
+                                      global_inner_scope,
+                                      place_hash_key,
+                                      /*is_grad=*/false,
+                                      /*in_pir_mode=*/in_pir_pt_mode)) {
     paddle::platform::RecordEvent record_event(
         "create_new_interpretercore",
         paddle::platform::TracerEventType::UserDefined,
@@ -702,12 +718,7 @@ inline void RunProgramAPI(
     details::ShareTensorsIntoScope(params, global_inner_scope);
     // Step 2. create new interpretercore

-    bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st;
-    if (attrs.count("in_pir_pt_mode")) {
-      in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode"));
-    }
-
-    if (FLAGS_enable_pir_in_executor || in_pir_pt_mode) {
+    if (in_pir_pt_mode) {
       // build new ir program
       auto ir_program =
           paddle::framework::ConstructFowardIrProgram(forward_global_block,
@@ -765,6 +776,7 @@ inline void RunProgramAPI(
         global_inner_scope,
         place_hash_key,
         false,
+        in_pir_pt_mode,
         skip_eager_delete_vars);
     VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
   } else {
@@ -774,8 +786,12 @@ inline void RunProgramAPI(
         1);
     VLOG(2) << "Get interpretercore cahce by program:" << program_id;
     // Step 1. get cache interpretercore
-    auto &cached_value = interpretercore_info_cache.GetMutable(
-        program_id, global_inner_scope, place_hash_key, /*is_grad=*/false);
+    auto &cached_value =
+        interpretercore_info_cache.GetMutable(program_id,
+                                              global_inner_scope,
+                                              place_hash_key,
+                                              /*is_grad=*/false,
+                                              /*in_pir_mode=*/in_pir_pt_mode);
     interpreter_core = cached_value.core_;
     // Step 2. update scope for cache interpretercore
     details::ShareTensorsIntoScopeWithName(x, input_names, global_inner_scope);
@@ -840,6 +856,12 @@ inline void RunProgramGradAPI(

   int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id"));

+  bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st;
+  if (attrs.count("in_pir_pt_mode")) {
+    in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode"));
+  }
+  in_pir_pt_mode = in_pir_pt_mode || FLAGS_enable_pir_in_executor;
+
   auto place = egr::Controller::Instance().GetExpectedPlace();
   VLOG(2) << "RunProgramGradOp use interpretercore to execute program.";

@@ -858,8 +880,11 @@ inline void RunProgramGradAPI(
       paddle::framework::InterpreterCoreInfoCache::Instance();
   std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
       nullptr;
-  if (!interpretercore_info_cache.Has(
-          program_id, global_inner_scope, place_hash_key, /*is_grad=*/true)) {
+  if (!interpretercore_info_cache.Has(program_id,
+                                      global_inner_scope,
+                                      place_hash_key,
+                                      /*is_grad=*/true,
+                                      /*in_pir_mode=*/in_pir_pt_mode)) {
     paddle::platform::RecordEvent record_event(
         "create_new_interpretercore",
         paddle::platform::TracerEventType::UserDefined,
@@ -869,12 +894,7 @@ inline void RunProgramGradAPI(
             << program_id;
     details::ShareTensorsIntoScope(out_grad, global_inner_scope);

-    bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st;
-    if (attrs.count("in_pir_pt_mode")) {
-      in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode"));
-    }
-
-    if (FLAGS_enable_pir_in_executor || in_pir_pt_mode) {
+    if (in_pir_pt_mode) {
       auto res =
           paddle::framework::ConstructBackwardIrProgram(backward_global_block,
                                                         out_grad,
@@ -904,14 +924,19 @@ inline void RunProgramGradAPI(
     // share threadpool
     // NOTE(zhiqiu): this only works interpreter_core is executed strictly
     // after the related fwd_interpreter_core.
-    if (interpretercore_info_cache.Has(
-            program_id, global_inner_scope, place_hash_key, false)) {
-      auto fwd_interpreter_core = interpretercore_info_cache
-                                      .GetMutable(program_id,
-                                                  global_inner_scope,
-                                                  place_hash_key,
-                                                  /*is_grad=*/false)
-                                      .core_;
+    if (interpretercore_info_cache.Has(program_id,
+                                       global_inner_scope,
+                                       place_hash_key,
+                                       /*is_grad=*/false,
+                                       /*in_pir_mode=*/in_pir_pt_mode)) {
+      auto fwd_interpreter_core =
+          interpretercore_info_cache
+              .GetMutable(program_id,
+                          global_inner_scope,
+                          place_hash_key,
+                          /*is_grad=*/false,
+                          /*in_pir_mode=*/in_pir_pt_mode)
+              .core_;
       interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
       VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
               << interpreter_core.get();
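
Note that this forward-core lookup passes the same /*in_pir_mode=*/in_pir_pt_mode as the enclosing backward entry. Because the mode is now part of the cache key (see the executor_cache.h changes below), that agreement is load-bearing; a short illustration of the failure it avoids, with names as in the hunk:

    // If the forward core had been cached under in_pir_mode=true but the
    // backward pass probed with false, Has() would miss and the
    // ShareWorkQueueFrom() optimization would silently be skipped.
    bool fwd_cached = interpretercore_info_cache.Has(program_id,
                                                     global_inner_scope,
                                                     place_hash_key,
                                                     /*is_grad=*/false,
                                                     /*in_pir_mode=*/in_pir_pt_mode);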
@@ -938,6 +963,7 @@ inline void RunProgramGradAPI(
         global_inner_scope,
         place_hash_key,
         /*is_grad=*/true,
+        in_pir_pt_mode,
         skip_eager_delete_vars);
     VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
   } else {
@@ -946,8 +972,12 @@ inline void RunProgramGradAPI(
         paddle::platform::TracerEventType::UserDefined,
         1);
     VLOG(2) << "Get interpretercore cahce by program:" << program_id;
-    auto &cached_value = interpretercore_info_cache.GetMutable(
-        program_id, global_inner_scope, place_hash_key, /*is_grad=*/true);
+    auto &cached_value =
+        interpretercore_info_cache.GetMutable(program_id,
+                                              global_inner_scope,
+                                              place_hash_key,
+                                              /*is_grad=*/true,
+                                              /*in_pir_mode=*/in_pir_pt_mode);
     interpreter_core = cached_value.core_;

     // update scope
@@ -1054,8 +1084,11 @@ inline void PirRunProgramGradAPI(
       paddle::framework::InterpreterCoreInfoCache::Instance();
   std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
       nullptr;
-  if (!interpretercore_info_cache.Has(
-          program_id, global_inner_scope, place_hash_key, /*is_grad=*/true)) {
+  if (!interpretercore_info_cache.Has(program_id,
+                                      global_inner_scope,
+                                      place_hash_key,
+                                      /*is_grad=*/true,
+                                      /*in_pir_mode=*/true)) {
     paddle::platform::RecordEvent record_event(
         "create_new_interpretercore",
         paddle::platform::TracerEventType::UserDefined,
@@ -1080,13 +1113,17 @@ inline void PirRunProgramGradAPI(
     // share threadpool
     // NOTE(zhiqiu): this only works interpreter_core is executed strictly
     // after the related fwd_interpreter_core.
-    if (interpretercore_info_cache.Has(
-            program_id, global_inner_scope, place_hash_key, false)) {
+    if (interpretercore_info_cache.Has(program_id,
+                                       global_inner_scope,
+                                       place_hash_key,
+                                       /*is_grad=*/false,
+                                       /*in_pir_mode=*/true)) {
       auto fwd_interpreter_core = interpretercore_info_cache
                                       .GetMutable(program_id,
                                                   global_inner_scope,
                                                   place_hash_key,
-                                                  /*is_grad=*/false)
+                                                  /*is_grad=*/false,
+                                                  /*in_pir_mode=*/true)
                                       .core_;
       interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
       VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
@@ -1107,6 +1144,7 @@ inline void PirRunProgramGradAPI(
         global_inner_scope,
         place_hash_key,
         /*is_grad=*/true,
+        /*in_pir_mode=*/true,
         skip_eager_delete_vars);
     VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
     details::print_collection(skip_eager_delete_vars);
@@ -1116,8 +1154,12 @@ inline void PirRunProgramGradAPI(
         paddle::platform::TracerEventType::UserDefined,
         1);
     VLOG(2) << "Get interpretercore cahce by program:" << program_id;
-    auto &cached_value = interpretercore_info_cache.GetMutable(
-        program_id, global_inner_scope, place_hash_key, /*is_grad=*/true);
+    auto &cached_value =
+        interpretercore_info_cache.GetMutable(program_id,
+                                              global_inner_scope,
+                                              place_hash_key,
+                                              /*is_grad=*/true,
+                                              /*in_pir_mode=*/true);
     interpreter_core = cached_value.core_;

     if (interpreter_core->GetVariableScope()->GetMutableScope() !=
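
Across the four entry points in this file the pattern is now uniform: PirRunProgramAPI and PirRunProgramGradAPI hard-code /*in_pir_mode=*/true, while RunProgramAPI and RunProgramGradAPI thread the resolved in_pir_pt_mode through every Has, GetMutable, and UpdateSkipEagerDeleteVars call. A condensed sketch of that shared shape; the wrapper function is illustrative, not code from the file:

    // Illustrative only: every cache interaction passes the same mode bit, so
    // lookup, creation, and skip-GC bookkeeping agree on one mode-salted key.
    void RunWithCache(int64_t program_id,
                      const paddle::framework::Scope* scope,
                      const int64_t& place_hash_key,
                      bool is_grad,
                      bool in_pir_mode) {
      auto& cache = paddle::framework::InterpreterCoreInfoCache::Instance();
      if (!cache.Has(program_id, scope, place_hash_key, is_grad, in_pir_mode)) {
        // miss: build the (PT or PIR) program, create an interpreter core via
        // executor_cache.cc, and register it under the same key
      } else {
        auto& cached = cache.GetMutable(
            program_id, scope, place_hash_key, is_grad, in_pir_mode);
        (void)cached;  // the real code reuses cached.core_
      }
    }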
paddle/fluid/framework/executor_cache.cc (2 additions, 2 deletions)
@@ -326,7 +326,7 @@ std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache(
       place, program_desc.Block(0), scope, execution_config));

   auto &cached_value = interpretercore_info_cache.GetMutable(
-      program_id, scope, place_hash_key, is_grad);
+      program_id, scope, place_hash_key, is_grad, /*in_pir_mode=*/false);
   cached_value.core_ = core;
   return core;
 }
@@ -355,7 +355,7 @@ std::shared_ptr<InterpreterCore> CreatePirInterpreterCoreInfoToCache(
       place, {}, ir_program->block(), scope, execution_config));

   auto &cached_value = interpretercore_info_cache.GetMutable(
-      program_id, scope, place_hash_key, is_grad);
+      program_id, scope, place_hash_key, is_grad, /*in_pir_mode=*/true);
   cached_value.core_ = core;
   cached_value.ir_prog_ = std::move(ir_program);
   return core;
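
These two factories are where a core first enters the cache, and each now pins the mode at registration time: the ProgramDesc path registers under in_pir_mode=false, the PIR path under true. A hedged sketch of the resulting visibility rule (variable names assumed, not from the diff):

    // A core registered by CreateProgramInterpreterCoreInfoToCache is visible
    // only to probes that also pass false; CreatePirInterpreterCoreInfoToCache
    // pins true. Mixing modes yields a clean cache miss, never a wrong core.
    auto& cache = paddle::framework::InterpreterCoreInfoCache::Instance();
    bool hit_pt = cache.Has(program_id, scope, place_hash_key,
                            /*is_grad=*/false, /*in_pir_mode=*/false);
    bool hit_pir = cache.Has(program_id, scope, place_hash_key,
                             /*is_grad=*/false, /*in_pir_mode=*/true);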
paddle/fluid/framework/executor_cache.h (12 additions, 6 deletions)
@@ -196,8 +196,9 @@ class InterpreterCoreInfoCache {
   bool Has(int64_t program_id,
            const framework::Scope* scope,
            const int64_t& place_hash_key,
-           bool is_grad) {
-    if (FLAGS_enable_pir_in_executor || FLAGS_enable_pir_with_pt_in_dy2st) {
+           bool is_grad,
+           bool in_pir_mode) {
+    if (in_pir_mode) {
       int64_t scope_i = reinterpret_cast<int64_t>(scope);
       program_id = hash_with_seed(program_id, scope_i);
       program_id = hash_with_seed(program_id, place_hash_key);
@@ -209,8 +210,9 @@ class InterpreterCoreInfoCache {
   InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id,
                                               const framework::Scope* scope,
                                               const int64_t& place_hash_key,
-                                              bool is_grad) {
-    if (FLAGS_enable_pir_in_executor || FLAGS_enable_pir_with_pt_in_dy2st) {
+                                              bool is_grad,
+                                              bool in_pir_mode) {
+    if (in_pir_mode) {
       int64_t scope_i = reinterpret_cast<int64_t>(scope);
       program_id = hash_with_seed(program_id, scope_i);
       program_id = hash_with_seed(program_id, place_hash_key);
@@ -222,16 +224,20 @@ class InterpreterCoreInfoCache {
                              const framework::Scope* scope,
                              const int64_t& place_hash_key,
                              bool is_grad,
+                             bool in_pir_mode,
                              const std::set<std::string>& skip_vars) {
-    auto& cached_value = GetMutable(program_id, scope, place_hash_key, is_grad);
+    auto& cached_value =
+        GetMutable(program_id, scope, place_hash_key, is_grad, in_pir_mode);
     cached_value.skip_eager_delete_vars_ = std::move(skip_vars);
   }

   std::set<std::string>& GetSkipEagerDeleteVars(int64_t program_id,
                                                 const framework::Scope* scope,
                                                 const int64_t& place_hash_key,
+                                                bool in_pir_mode,
                                                 bool is_grad) {
-    auto& cached_value = GetMutable(program_id, scope, place_hash_key, is_grad);
+    auto& cached_value =
+        GetMutable(program_id, scope, place_hash_key, is_grad, in_pir_mode);
     return cached_value.skip_eager_delete_vars_;
   }

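The key change in this header is that key salting is now driven by the explicit in_pir_mode argument rather than by the cache reading the two global flags itself, which is what lets PT and PIR cores coexist under one cache. A sketch of the derivation Has and GetMutable perform in PIR mode, assuming hash_with_seed is the hash-combine helper this class already uses; the standalone function is illustrative:

    // In PIR mode the effective key mixes the program id with the scope
    // pointer and the place hash, so the same program gets distinct cache
    // entries per scope and per place.
    int64_t PirCacheKey(int64_t program_id,
                        const paddle::framework::Scope* scope,
                        const int64_t& place_hash_key) {
      int64_t scope_i = reinterpret_cast<int64_t>(scope);
      program_id = hash_with_seed(program_id, scope_i);
      program_id = hash_with_seed(program_id, place_hash_key);
      return program_id;
    }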