Support control flow for static build [Step 3: support while] #57616
Merged: From00 merged 65 commits from AndSonder:control_flow_support-step-3 into PaddlePaddle:develop on Sep 23, 2023.
Commits (all authored by AndSonder):
- 9b469eb add conditional_block to OperatorBasesHandledInStaticBuild
- db8a69c run op in FakeInitializeOutputsForOperatorBase
- 8770ed9 add init_success judge
- 483ea72 fix build error
- 224ccfa fix
- 7fc6735 add SetSubBlockCore func
- 7ced658 add PreStaticRun func
- 32ec0cf add PreStaticRun to interpreter_base and new_ir_inter
- 274cc0e recover codes
- 889e7b8 add PreStaticBuild and BlockCanBeStaticBuilt
- 8a8225b Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle
- 8902d01 fix logic about RunPreStaticBuild
- 11a9a5d change CreateOpFromOpDesc type
- 4b3a0f1 fix build error
- 5e40452 fix build error
- 8171ef7 remove IsOperatorBasesHandledInStaticBuild
- bab22a8 recover BlockCanBeStaticBuilt
- 522a005 add logic about conditional_block run static build
- 6fb3fc7 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 9b872c8 recover codes
- 9c3a0f4 recover BlockCanBeStaticBuilt
- 3b5c905 support static build condational block op when condational block is t…
- 32b5d4c fix error
- 87538c7 fix logic about last op
- b56f7fa fit for sub block can't open static build
- c4b3d82 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 1f68f61 add IsStaticBuild
- 9456c05 fix build error
- 68af0ba fit logic when sub block can't open static build
- 5eddb19 close static build when sub_block don't support static_build
- 3af4c06 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle
- 2c7e2ce Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 42151f9 recover third party
- 4653175 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle
- d6f3d83 add is_skil_fake_init logic
- 18f4b2c set the backend of the lamb
- 35e6a96 change start index
- 845ccd8 add if conditional for cal is_skip_fake_init
- 9676fc5 change name
- 3350eea close static_build for test_conditional_block
- dffbdeb add static buiild support for conditional block in case of the output…
- a592249 fix logic error
- 78d5312 fix timeout error
- a9e8a7f Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle
- 3f3494f fix
- 0e777f4 remove useless codes
- c1a2123 fix
- d1dddce fix
- 71fc516 fix build error
- d13bc09 move GetVarsInfo and RunPreStaticBuild from opeartor to static_build
- 4795474 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 611c035 fix lamb backend registe
- b180105 fix build error
- a640ee4 fix build error
- 90292fb remove lamp op test from new_ir_op_test_white_list
- e7baf2c Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 7cd51be fix
- 222008c move generating following_input_vars logic to static_build.cc
- 914ec7b remove HasInfo
- e3ed639 fix build error
- 4d8dfea recover codes and turn off the flag
- 11d66c9 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 41cac94 add support for while
- 26fa4bd Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle
- ed21e0c fix
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,13 +20,20 @@ | |
#include "paddle/fluid/framework/reader.h" | ||
#include "paddle/fluid/operators/reader/buffered_reader.h" | ||
|
||
#include "paddle/fluid/operators/controlflow/control_flow_op_helper.h" | ||
#include "paddle/fluid/operators/controlflow/while_op_helper.h" | ||
|
||
#ifdef PADDLE_WITH_DNNL | ||
#include "paddle/fluid/platform/mkldnn_helper.h" | ||
#endif | ||
|
||
#include "paddle/fluid/platform/flags.h" | ||
|
||
PHI_DECLARE_bool(cache_inference_while_scope); | ||
|
||
// These ops are OperatorBase, but we have handled them in static build | ||
std::set<std::string> OperatorBasesHandledInStaticBuild = {"read", | ||
"conditional_block"}; | ||
std::set<std::string> OperatorBasesHandledInStaticBuild = { | ||
"read", "conditional_block", "while"}; | ||
|
||
std::set<std::string> OperatorBasesMustRunInStaticBuild = { | ||
"create_double_buffer_reader", "create_py_reader"}; | ||
|
@@ -386,9 +393,9 @@ void FakeInitializeTensorBase(const platform::DeviceContext& dev_ctx, | |
} | ||
} | ||
|
||
void RunPreStaticBuild(const framework::Scope& scope, | ||
const platform::Place& dev_place, | ||
const OperatorBase& op) { | ||
void RunConditionalBlockPreStaticBuild(const framework::Scope& scope, | ||
const platform::Place& dev_place, | ||
const OperatorBase& op) { | ||
auto* scope_var = scope.FindVar(op.Output("Scope")); | ||
PADDLE_ENFORCE_NOT_NULL( | ||
scope_var, | ||
|
@@ -434,6 +441,193 @@ void RunPreStaticBuild(const framework::Scope& scope, | |
core->Build({}, &op_func_nodes); | ||
} | ||
|
||
void RunWhileBlockPreStaticBuild(const framework::Scope& scope, | ||
const platform::Place& dev_place, | ||
const OperatorBase& op) { | ||
PADDLE_ENFORCE_NOT_NULL( | ||
scope.FindVar(op.Input("Condition")), | ||
platform::errors::NotFound("Input(Condition) of WhileOp is not found.")); | ||
|
||
#ifdef PADDLE_WITH_DNNL | ||
// Executor on being destroyed clears oneDNN cache and resets | ||
// registered model data layout. This is unwanted for nested | ||
// Executors (executors declared inside control ops) | ||
platform::DontClearMKLDNNCache(dev_place); | ||
#endif | ||
auto* block = op.Attr<framework::BlockDesc*>("sub_block"); | ||
|
||
// get device context from pool | ||
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); | ||
auto& dev_ctx = *pool.Get(dev_place); | ||
|
||
bool is_test = op.Attr<bool>("is_test"); | ||
|
||
std::set<std::string> no_copy_var_names; | ||
if (!is_test) { | ||
// set all persistable parameters into no_copy_var_names. | ||
auto* global_block = block; | ||
|
||
while (global_block->ID() != 0) global_block = global_block->ParentBlock(); | ||
auto all_vars = global_block->AllVars(); | ||
std::for_each(all_vars.begin(), | ||
all_vars.end(), | ||
[&no_copy_var_names](framework::VarDesc* var) { | ||
if (var->IsParameter()) | ||
no_copy_var_names.insert(var->Name()); | ||
}); | ||
|
||
const std::vector<framework::OpDesc*>& all_ops = block->AllOps(); | ||
for (const framework::OpDesc* item : all_ops) { | ||
const framework::VariableNameMap& input_var_names = item->Inputs(); | ||
const framework::VariableNameMap& output_var_names = item->Outputs(); | ||
for (auto& ipt : input_var_names) { | ||
for (const std::string& var_name : ipt.second) { | ||
if (operators::StrInVaraiableNameMap(var_name, output_var_names)) { | ||
no_copy_var_names.insert(var_name); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
auto step_scopes = scope.FindVar(op.Output("StepScopes")) | ||
->GetMutable<std::vector<framework::Scope*>>(); | ||
|
||
if (!step_scopes->empty()) { | ||
platform::DeviceContextPool::Instance().Get(dev_place)->Wait(); | ||
for (auto& s : *step_scopes) { | ||
if (scope.HasKid(s)) { | ||
scope.DeleteScope(s); | ||
} | ||
} | ||
step_scopes->clear(); | ||
} | ||
|
||
PADDLE_ENFORCE_EQ(step_scopes->size(), | ||
0, | ||
platform::errors::PreconditionNotMet( | ||
"The Output(StepScope) of WhileOp should be empty.")); | ||
|
||
auto& skip_vars = | ||
op.Attr<std::vector<std::string>>("skip_eager_deletion_vars"); | ||
|
||
// note(lvyongkang): The assign op in a while loop may change the place of | ||
// a variable. However, InterpreterCore fixes the kernel of every op during | ||
// its first run, so a CPU tensor may become a GPU tensor after the first | ||
// run, which leads to a segmentation fault when it is used in a CPU | ||
// kernel. Here we record the place of every input and restore the places | ||
// after InterpreterCore.run(). | ||
std::map<std::string, phi::Place> input_var_original_places; | ||
for (const auto& in_name : op.Inputs("X")) { | ||
framework::Variable* var = scope.FindVar(in_name); | ||
if (var == nullptr) { | ||
VLOG(4) << "[while op]" | ||
<< "input not found: " << in_name; | ||
continue;  // skip the dereference below when the input is missing | ||
} | ||
|
||
if (var->Type() == framework::proto::VarType::LOD_TENSOR) { | ||
input_var_original_places[in_name] = | ||
(var->Get<phi::DenseTensor>()).place(); | ||
} else { | ||
VLOG(10) << "[while op]" | ||
<< "skip backup input " << in_name << " type:" | ||
<< framework::TransToPhiDataType( | ||
framework::ToVarType(var->Type())); | ||
} | ||
} | ||
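The record-and-restore pattern described in the note above can be sketched in isolation. This is a hypothetical illustration, not Paddle's code: `Place` and `Tensor` below are stand-ins for `phi::Place` and `phi::DenseTensor`.

```cpp
#include <map>
#include <string>

// Stand-in types for this sketch only.
enum class Place { kCPU, kGPU };
struct Tensor {
  Place place = Place::kCPU;
};

// Record the original place of every input tensor before the sub-block runs.
std::map<std::string, Place> RecordPlaces(
    const std::map<std::string, Tensor>& inputs) {
  std::map<std::string, Place> original;
  for (const auto& kv : inputs) original[kv.first] = kv.second.place;
  return original;
}

// Move any tensor that migrated during the run back to its recorded place.
void RestorePlaces(std::map<std::string, Tensor>* inputs,
                   const std::map<std::string, Place>& original) {
  for (auto& kv : *inputs) {
    auto it = original.find(kv.first);
    if (it != original.end()) kv.second.place = it->second;
  }
}
```

In the real code the restore step goes through `operators::TransferVariablePlace`, which performs an actual device copy rather than flipping a field.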
|
||
LOG_FIRST_N(INFO, 1) << "[ControlFlow][WhileOp] New Executor is Running."; | ||
std::unique_ptr<InterpreterCore> core; | ||
|
||
framework::Scope placeholder;  // contents don't matter; used only to | ||
// initialize InterpreterCore | ||
framework::interpreter::ExecutionConfig execution_config; | ||
execution_config.create_local_scope = false; | ||
execution_config.used_for_control_flow_op = true; | ||
execution_config.skip_gc_vars = | ||
std::set<std::string>(skip_vars.begin(), skip_vars.end()); | ||
|
||
core.reset(new framework::InterpreterCore( | ||
dev_place, *block, &placeholder, execution_config)); | ||
|
||
if (!is_test) { | ||
auto& current_scope = scope.NewScope(); | ||
step_scopes->push_back(¤t_scope); | ||
|
||
std::vector<std::string> rename_vars; | ||
for (const std::string& input_var_name : op.Inputs("X")) { | ||
if (no_copy_var_names.find(input_var_name) == no_copy_var_names.end()) { | ||
std::string input_var_rename = input_var_name + "@TMP_COPY"; | ||
framework::Variable* input_var = scope.FindVar(input_var_name); | ||
if (input_var->IsType<phi::DenseTensor>()) { | ||
rename_vars.push_back(input_var_rename); | ||
auto input_var_tensor = input_var->Get<phi::DenseTensor>(); | ||
auto* rename_input_var_tensor = current_scope.Var(input_var_rename) | ||
->GetMutable<phi::DenseTensor>(); | ||
framework::TensorCopy( | ||
input_var_tensor, dev_place, rename_input_var_tensor); | ||
rename_input_var_tensor->set_lod(input_var_tensor.lod()); | ||
} | ||
} | ||
} | ||
|
||
operators::BuildScopeForControlFlowOp(*core, *block, ¤t_scope); | ||
core->reset_scope(¤t_scope); | ||
|
||
std::vector<paddle::framework::OpFuncNode> op_func_nodes; | ||
core->Build({}, &op_func_nodes); | ||
|
||
// restore inputs place | ||
for (const auto& n : input_var_original_places) { | ||
const std::string& in_name = n.first; | ||
const phi::Place& original_place = n.second; | ||
// input vars exist in `scope` not `current_scope` | ||
operators::TransferVariablePlace( | ||
&scope, in_name, original_place, dev_ctx); | ||
} | ||
|
||
for (auto& var_rename : rename_vars) { | ||
std::string input_var_name = | ||
var_rename.substr(0, var_rename.size() - strlen("@TMP_COPY")); | ||
current_scope.Rename(var_rename, input_var_name); | ||
} | ||
} else { | ||
framework::Scope* current_scope = nullptr; | ||
if (!FLAGS_cache_inference_while_scope) { | ||
current_scope = &(scope.NewScope()); | ||
operators::BuildScopeForControlFlowOp(*core, *block, current_scope); | ||
core->reset_scope(current_scope); | ||
} else { | ||
auto cached_inference_scope = &(scope.NewScope()); | ||
operators::BuildScopeForControlFlowOp( | ||
*core, *block, cached_inference_scope); | ||
core->reset_scope(cached_inference_scope); | ||
current_scope = cached_inference_scope; | ||
} | ||
|
||
for (auto& name : current_scope->LocalVarNames()) { | ||
auto* var = current_scope->Var(name); | ||
if (var->IsType<phi::DenseTensor>()) { | ||
// Clear all lod information for all lod_tensors. | ||
auto* t = var->GetMutable<phi::DenseTensor>(); | ||
framework::LoD empty_lod; | ||
t->set_lod(empty_lod); | ||
} else if (var->IsType<framework::LoDTensorArray>()) { | ||
// Clear elements of all tensor arrays. | ||
auto* t = var->GetMutable<framework::LoDTensorArray>(); | ||
t->clear(); | ||
} | ||
} | ||
|
||
std::vector<paddle::framework::OpFuncNode> op_func_nodes; | ||
core->Build({}, &op_func_nodes); | ||
|
||
if (!FLAGS_cache_inference_while_scope) { | ||
scope.DeleteScope(current_scope); | ||
} | ||
} | ||
} | ||
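The "@TMP_COPY" rename scheme used in the training branch above (copy each mutable input under a suffixed name, then rename it back after the build) reduces to a pair of string helpers. A minimal sketch, with hypothetical helper names:

```cpp
#include <string>

const std::string kTmpSuffix = "@TMP_COPY";

// Name under which a copied input is stored in the step scope.
std::string ToTmpName(const std::string& var) { return var + kTmpSuffix; }

// Recover the original variable name when renaming back after the build.
std::string StripTmpSuffix(const std::string& renamed) {
  return renamed.substr(0, renamed.size() - kTmpSuffix.size());
}
```

The diff computes the strip with `strlen("@TMP_COPY")`; the effect is the same.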
|
||
void FakeInitializeOutputsForOperatorBase( | ||
const OperatorBase& op, | ||
const phi::Place& place, | ||
|
@@ -447,7 +641,7 @@ void FakeInitializeOutputsForOperatorBase( | |
phi::DeviceContext* dev_ctx = | ||
platform::DeviceContextPool::Instance().Get(place); | ||
|
||
if (op_type == "conditional_block") { | ||
if (op_type == "conditional_block" || op_type == "while") { | ||
// Note(sonder): skip fake init for conditional_block and while ops when | ||
// there is no op with a kernel after them. | ||
bool skip_fake_init = true; | ||
|
@@ -456,7 +650,7 @@ void FakeInitializeOutputsForOperatorBase( | |
for (size_t i = 0; i < following_ops.size(); ++i) { | ||
if (dynamic_cast<framework::OperatorWithKernel*>( | ||
following_ops[i].get()) != nullptr) { | ||
VLOG(4) << "Find op with kernel after conditional_block : " | ||
VLOG(4) << "Find op with kernel after " << op_type << ": " | ||
<< following_ops[i]->Type(); | ||
skip_fake_init = false; | ||
auto input_vars_info = GetVarsInfo( | ||
|
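The skip-fake-init check in the hunk above scans the ops that follow the control-flow op and bails out of skipping as soon as one of them has a kernel. As an isolated sketch (with `Op`/`KernelOp` as stand-ins for `OperatorBase`/`OperatorWithKernel`):

```cpp
#include <memory>
#include <vector>

// Stand-in op hierarchy; dynamic_cast needs a polymorphic base.
struct Op { virtual ~Op() = default; };
struct KernelOp : Op {};

// Fake initialization of a control-flow op's outputs is only needed if a
// later op actually runs a kernel that might consume them.
bool SkipFakeInit(const std::vector<std::unique_ptr<Op>>& following_ops) {
  for (const auto& op : following_ops) {
    if (dynamic_cast<KernelOp*>(op.get()) != nullptr) return false;
  }
  return true;
}
```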
@@ -474,7 +668,14 @@ void FakeInitializeOutputsForOperatorBase( | |
const std::vector<VarMetaInfo> out_var_info_before_build = | ||
GetVarsInfo(scope, op.Outputs(), op); | ||
|
||
RunPreStaticBuild(*scope, place, op); | ||
VLOG(3) << "debug1: " << op.DebugStringEx(scope); | ||
Review comment on the VLOG lines above (translated from Chinese 不要写): "Don't write [this]." |
||
if (op_type == "conditional_block") { | ||
RunConditionalBlockPreStaticBuild(*scope, place, op); | ||
} else { | ||
RunWhileBlockPreStaticBuild(*scope, place, op); | ||
} | ||
VLOG(3) << "debug2: " << op.DebugStringEx(scope); | ||
|
||
const std::vector<VarMetaInfo> out_var_info_after_build = | ||
GetVarsInfo(scope, op.Outputs(), op); | ||
|
||
|
@@ -487,10 +688,11 @@ void FakeInitializeOutputsForOperatorBase( | |
auto var_name = out_var_info_before_build[i].name_; | ||
if (following_input_vars.count(var_name)) { | ||
PADDLE_THROW(phi::errors::PreconditionNotMet( | ||
"The output %s s' dtype/place of conditional_block is " | ||
"The output %s s' dtype/place of %s is " | ||
"changed after static build. Befer static build, the " | ||
"dtype is %s, place is %s. After static " | ||
"build, the dtype is %s, place is %s.", | ||
op_type, | ||
var_name, | ||
out_var_info_before_build[i].dtype_, | ||
out_var_info_before_build[i].place_, | ||
|
Review comment (translated from Chinese): In-project header includes should be sorted lexicographically, and includes wrapped in macro guards should be placed after those that are not guarded.
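Applied to the includes added in this diff, the reviewer's guideline would give an ordering like the following (illustrative only; this is not the committed layout):

```cpp
// Unguarded in-project includes first, in lexicographic order.
#include "paddle/fluid/operators/controlflow/control_flow_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/platform/flags.h"

// Macro-guarded includes come after all unguarded ones.
#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
```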