-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support control flow for static build [Step 2: support conditional_block] #56696
Support control flow for static build [Step 2: support conditional_block] #56696
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
… control_flow_support-step-2
…he last op in the block
… control_flow_support-step-2
@@ -1272,6 +1272,8 @@ set_tests_properties( | |||
set_tests_properties( | |||
test_cuda_graph_static_mode_error | |||
PROPERTIES ENVIRONMENT "FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1") | |||
set_tests_properties(test_conditional_block |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加个注释说明下这个单测的情况吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -14,6 +14,7 @@ | |||
|
|||
#pragma once | |||
|
|||
#include "paddle/fluid/framework/new_executor/new_executor_defs.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个头文件中的符号在static_build.h
中不需要使用,加到static_build.cc
中防止头文件污染。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
!FLAGS_new_executor_use_cuda_graph && | ||
interpreter::BlockCanBeStaticBuilt(block); | ||
|
||
for (auto& op : block.AllOps()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个逻辑合并到BlockCanBeStaticBuilt
里吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -361,6 +365,37 @@ void FakeInitializeOutputsForOperatorBase(const OperatorBase& op, | |||
*dev_ctx, target_place, dtype, out_tensor->layout(), out_tensor); | |||
} | |||
} | |||
} else if (op_type == "conditional_block") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
多个并行分支处理不同OP,建议按OP名称字典序排序
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -185,6 +196,34 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names, | |||
} | |||
} | |||
|
|||
void ProgramInterpreter::PreStaticBuild() { | |||
SetDeviceId(place_); | |||
#ifdef PADDLE_WITH_DNNL |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段代码在Run函数里有类似的实现,建议合并成一处,统一实现成Build
接口
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/fluid/framework/operator.h
Outdated
@@ -104,6 +104,27 @@ constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@"; | |||
constexpr char kAllKernelsMustComputeRuntimeShape[] = | |||
"@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@"; | |||
|
|||
struct VarInfo { | |||
std::string name_; | |||
std::string dtype_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dtype和place类型建议用框架定义的数据类型,不用std::string
。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里用string应该可以更好的兼容没有init的情况,像 GetDtype 这类接口的返回值也是 string,用string也更方便后面打印debug log
paddle/fluid/framework/operator.cc
Outdated
@@ -953,6 +953,66 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { | |||
return ss.str(); | |||
} | |||
|
|||
std::vector<VarInfo> OperatorBase::InputVarsInfo(const Scope* scope) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
InputVarsInfo
和OutputVarsInfo
代码逻辑非常类似,考虑重复代码是否可以合并成一处,一些差异化的信息通过函数参数传入,减少重复代码。no_need_buffer_vars
的处理,也是适用于output的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
bool is_skip_fake_init = false; | ||
std::unordered_set<std::string> following_input_vars; | ||
|
||
if (static_build && op->Type() == "conditional_block") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个逻辑代码建议放到HandleOperatorBase
处理控制流算子的分支中,而不需要在这里处理一些控制流算子特有的信息后再通过函数参数传入。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
// Note(sonder): static_build is not supported if the output of | ||
// conditional_block is changed after static build. | ||
if (out_var_info_before_build.size() != out_var_info_after_build.size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在什么情况下会出现static_build前后var数量不一致?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没遇到过这种情况,现在已经删除了这个判断
if (out_var_info_before_build[i] != out_var_info_after_build[i]) { | ||
auto var_name = out_var_info_before_build[i].name_; | ||
if (following_input_vars.count(var_name)) { | ||
PADDLE_THROW(phi::errors::PreconditionNotMet( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议在报错信息中强调下是输出的dtype或place发生了改变,直接写output改变容易被误解成数据改变。另外,建议报错信息中输出具体改变的变量名称,以及out_var_info_before_build和out_var_info_after_build信息,在用户出错时容易通过log判断原因。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
… control_flow_support-step-2
if (!OpsCanSkipedFakeAllocInStaticBuild.count(op_type)) { | ||
if (in_black_list || | ||
(is_operator_base && | ||
!OperatorBasesHandledInStaticBuild.count(op_type)) || | ||
is_custom_op || use_mkldnn) { | ||
is_custom_op || use_mkldnn || !is_sub_block_static_build) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议在下面The following OPs are unable to static build:
的log里也加上sub_block_static_build
的信息输出,方便通过log判断static_build是因为哪种情况退化的。另外,为了和其它情况统一,布尔值为True
表示满足不支持的情况,建议将is_sub_block_static_build
改成sub_block_can_not_static_build
。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
bool is_skip_fake_init = false; | ||
std::unordered_set<std::string> following_input_vars; | ||
|
||
if (op->Type() == "conditional_block") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的逻辑好像没改?这个逻辑代码建议放到HandleOperatorBase处理控制流算子的分支中,而不需要在这里处理一些控制流算子特有的信息后再通过函数参数传入。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个已经是在HandleOperatorBase中了吧,之前是在 HandleOperatorBase 外处理好传入进 HandleOperatorBase 中的
PADDLE_THROW(phi::errors::PreconditionNotMet( | ||
"The outputs' dtype/place of conditional_block is " | ||
"changed after static build. Befer static build, the " | ||
"output %s's dtype is %s, place is %s. After static " | ||
"build, the output %s's dtype is %s, place is %s.", | ||
var_name, | ||
out_var_info_before_build[i].dtype_, | ||
out_var_info_before_build[i].place_, | ||
var_name, | ||
out_var_info_after_build[i].dtype_, | ||
out_var_info_after_build[i].place_)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_THROW(phi::errors::PreconditionNotMet( | |
"The outputs' dtype/place of conditional_block is " | |
"changed after static build. Befer static build, the " | |
"output %s's dtype is %s, place is %s. After static " | |
"build, the output %s's dtype is %s, place is %s.", | |
var_name, | |
out_var_info_before_build[i].dtype_, | |
out_var_info_before_build[i].place_, | |
var_name, | |
out_var_info_after_build[i].dtype_, | |
out_var_info_after_build[i].place_)); | |
PADDLE_THROW(phi::errors::PreconditionNotMet( | |
"The output %s s' dtype/place of conditional_block is " | |
"changed after static build. Befer static build, the " | |
"dtype is %s, place is %s. After static " | |
"build, the dtype is %s, place is %s.", | |
var_name, | |
out_var_info_before_build[i].dtype_, | |
out_var_info_before_build[i].place_, | |
out_var_info_after_build[i].dtype_, | |
out_var_info_after_build[i].place_)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…ock] (PaddlePaddle#56696) * add conditional_block to OperatorBasesHandledInStaticBuild * run op in FakeInitializeOutputsForOperatorBase * add init_success judge * fix build error * fix * add SetSubBlockCore func * add PreStaticRun func * add PreStaticRun to interpreter_base and new_ir_inter * recover codes * add PreStaticBuild and BlockCanBeStaticBuilt * fix logic about RunPreStaticBuild * change CreateOpFromOpDesc type * fix build error * fix build error * remove IsOperatorBasesHandledInStaticBuild * recover BlockCanBeStaticBuilt * add logic about conditional_block run static build * recover codes * recover BlockCanBeStaticBuilt * support static build condational block op when condational block is the last op in the block * fix error * fix logic about last op * fit for sub block can't open static build * add IsStaticBuild * fix build error * fit logic when sub block can't open static build * close static build when sub_block don't support static_build * recover third party * add is_skil_fake_init logic * set the backend of the lamb * change start index * add if conditional for cal is_skip_fake_init * change name * close static_build for test_conditional_block * add static buiild support for conditional block in case of the output's dtype/place is changed but the following op is not use this output * fix logic error * fix timeout error * fix * remove useless codes * fix * fix * fix build error * move GetVarsInfo and RunPreStaticBuild from opeartor to static_build * fix lamb backend registe * fix build error * fix build error * remove lamp op test from new_ir_op_test_white_list * fix * move generating following_input_vars logic to static_build.cc * remove HasInfo * fix build error * recover codes and turn off the flag
…ock] (PaddlePaddle#56696) * add conditional_block to OperatorBasesHandledInStaticBuild * run op in FakeInitializeOutputsForOperatorBase * add init_success judge * fix build error * fix * add SetSubBlockCore func * add PreStaticRun func * add PreStaticRun to interpreter_base and new_ir_inter * recover codes * add PreStaticBuild and BlockCanBeStaticBuilt * fix logic about RunPreStaticBuild * change CreateOpFromOpDesc type * fix build error * fix build error * remove IsOperatorBasesHandledInStaticBuild * recover BlockCanBeStaticBuilt * add logic about conditional_block run static build * recover codes * recover BlockCanBeStaticBuilt * support static build condational block op when condational block is the last op in the block * fix error * fix logic about last op * fit for sub block can't open static build * add IsStaticBuild * fix build error * fit logic when sub block can't open static build * close static build when sub_block don't support static_build * recover third party * add is_skil_fake_init logic * set the backend of the lamb * change start index * add if conditional for cal is_skip_fake_init * change name * close static_build for test_conditional_block * add static buiild support for conditional block in case of the output's dtype/place is changed but the following op is not use this output * fix logic error * fix timeout error * fix * remove useless codes * fix * fix * fix build error * move GetVarsInfo and RunPreStaticBuild from opeartor to static_build * fix lamb backend registe * fix build error * fix build error * remove lamp op test from new_ir_op_test_white_list * fix * move generating following_input_vars logic to static_build.cc * remove HasInfo * fix build error * recover codes and turn off the flag
…ock] (PaddlePaddle#56696) * add conditional_block to OperatorBasesHandledInStaticBuild * run op in FakeInitializeOutputsForOperatorBase * add init_success judge * fix build error * fix * add SetSubBlockCore func * add PreStaticRun func * add PreStaticRun to interpreter_base and new_ir_inter * recover codes * add PreStaticBuild and BlockCanBeStaticBuilt * fix logic about RunPreStaticBuild * change CreateOpFromOpDesc type * fix build error * fix build error * remove IsOperatorBasesHandledInStaticBuild * recover BlockCanBeStaticBuilt * add logic about conditional_block run static build * recover codes * recover BlockCanBeStaticBuilt * support static build condational block op when condational block is the last op in the block * fix error * fix logic about last op * fit for sub block can't open static build * add IsStaticBuild * fix build error * fit logic when sub block can't open static build * close static build when sub_block don't support static_build * recover third party * add is_skil_fake_init logic * set the backend of the lamb * change start index * add if conditional for cal is_skip_fake_init * change name * close static_build for test_conditional_block * add static buiild support for conditional block in case of the output's dtype/place is changed but the following op is not use this output * fix logic error * fix timeout error * fix * remove useless codes * fix * fix * fix build error * move GetVarsInfo and RunPreStaticBuild from opeartor to static_build * fix lamb backend registe * fix build error * fix build error * remove lamp op test from new_ir_op_test_white_list * fix * move generating following_input_vars logic to static_build.cc * remove HasInfo * fix build error * recover codes and turn off the flag
…ock] (PaddlePaddle#56696) * add conditional_block to OperatorBasesHandledInStaticBuild * run op in FakeInitializeOutputsForOperatorBase * add init_success judge * fix build error * fix * add SetSubBlockCore func * add PreStaticRun func * add PreStaticRun to interpreter_base and new_ir_inter * recover codes * add PreStaticBuild and BlockCanBeStaticBuilt * fix logic about RunPreStaticBuild * change CreateOpFromOpDesc type * fix build error * fix build error * remove IsOperatorBasesHandledInStaticBuild * recover BlockCanBeStaticBuilt * add logic about conditional_block run static build * recover codes * recover BlockCanBeStaticBuilt * support static build condational block op when condational block is the last op in the block * fix error * fix logic about last op * fit for sub block can't open static build * add IsStaticBuild * fix build error * fit logic when sub block can't open static build * close static build when sub_block don't support static_build * recover third party * add is_skil_fake_init logic * set the backend of the lamb * change start index * add if conditional for cal is_skip_fake_init * change name * close static_build for test_conditional_block * add static buiild support for conditional block in case of the output's dtype/place is changed but the following op is not use this output * fix logic error * fix timeout error * fix * remove useless codes * fix * fix * fix build error * move GetVarsInfo and RunPreStaticBuild from opeartor to static_build * fix lamb backend registe * fix build error * fix build error * remove lamp op test from new_ir_op_test_white_list * fix * move generating following_input_vars logic to static_build.cc * remove HasInfo * fix build error * recover codes and turn off the flag
PR types
Performance optimization
PR changes
Others
Description
当下 conditional_block 是直接采用跑算子的方法,本 PR 的核心目的是将 conditional_block 加入到 OperatorBasesHandledInStaticBuild 中并使其支持 static_build
相关Issue: