-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Dy2Stat]Allow ifelse return buildin type in paddle cond #37888
[Dy2Stat]Allow ifelse return buildin type in paddle cond #37888
Conversation
Thanks for your contribution! |
int
in paddle cond# false_args = [ | ||
# to_static_variable(var) if id(var) in return_var_ids else var | ||
# for var in false_args | ||
# ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete unused codes.
@@ -102,6 +102,28 @@ def select_input(inputs, mask): | |||
return out | |||
|
|||
|
|||
def select_input_with_buildin_type(inputs, mask): | |||
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import 放在这里是因为有循环引用么?
@@ -2284,6 +2306,8 @@ def append_conditional_block_grad(self, parent_block, inside_block, | |||
def copy_var_to_parent_block(var, layer_helper): | |||
if var is None: | |||
return None | |||
if isinstance(var, (bool, float, six.integer_types)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if isinstance(var, (bool, float, six.integer_types)): | |
if not isinstance(var, Variable): |
只要不是variable类型都直接返回。
isinstance(true_var, Variable)) or (isinstance(true_var, ( | ||
bool, float, six.integer_types)) and isinstance(false_var, | ||
Variable)): | ||
inputs = [to_static_variable(false_var), to_static_variable(true_var)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的处理逻辑是:
- true_var、false_var 都不是variable,则直接判断是否相等(包括list、dict等其他类型)
- true_var、false_var 其中有一个是Variable,PR的逻辑是会将其中一个转为Variable返回,这个合理么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…e#37888) * allow ifelse return `int` in paddle cond * add test and refine code * polish code, add test * code format
PR types
Others
PR changes
Others
Describe
PR的修改如下: