Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2Stat]Allow ifelse return buildin type in paddle cond #37888

Merged
merged 5 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,6 @@ def _remove_no_value_return_var(out):

def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
                     return_vars):
    """Execute a converted if/else in static-graph mode via `control_flow.cond`.

    Args:
        pred: The branch condition; cast to a bool Variable if necessary.
        true_fn: Callable run when ``pred`` is true.
        false_fn: Callable run when ``pred`` is false.
        true_args: Positional arguments forwarded to ``true_fn``.
        false_args: Positional arguments forwarded to ``false_fn``.
        return_vars: Variables the branches return. Kept for the converter's
            calling convention; the previous to_static_variable pre-conversion
            of return vars was removed because `cond` now handles built-in
            return types itself (see select_input_with_buildin_type).

    Returns:
        The merged branch outputs produced by `control_flow.cond`.
    """
    pred = cast_bool_if_necessary(pred)
    return control_flow.cond(pred, lambda: true_fn(*true_args),
                             lambda: false_fn(*false_args))
Expand Down
41 changes: 38 additions & 3 deletions python/paddle/fluid/layers/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,41 @@ def select_input(inputs, mask):
return out


def select_input_with_buildin_type(inputs, mask):
    """Merge the two branch outputs of a `cond`, additionally allowing
    Python built-in scalars (bool/int/float) besides Variables.

    Args:
        inputs (list): ``[false_var, true_var]``, the outputs of the false
            and true branches. Each may be a Variable or a built-in scalar.
        mask (Variable): int32 tensor selecting which branch output to take.

    Returns:
        The selected output. If both branches return the same built-in
        scalar, that scalar is returned directly without building any op;
        otherwise built-ins are promoted to static Variables and merged
        through `select_input`.

    Raises:
        TypeError: If the branch outputs are of unsupported or mismatched
            non-promotable types.
    """
    # Local import to avoid a circular dependency between control_flow and
    # the dygraph_to_static package.
    from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable

    support_ret_buildin_type = (bool, float, six.integer_types)
    false_var, true_var = inputs

    if isinstance(false_var, Variable) and isinstance(true_var, Variable):
        return select_input(inputs, mask)
    elif (isinstance(false_var, support_ret_buildin_type) and
          isinstance(false_var, type(true_var))):
        if false_var == true_var:
            # Both branches return the identical constant: no selection op
            # is needed, return the Python value as-is.
            return false_var
        else:
            inputs = [
                to_static_variable(false_var), to_static_variable(true_var)
            ]
    # Deal with the situations like this: false_var is int and true_var is
    # Variable. Promote the scalar so select_input can merge them, but warn
    # because the two branches disagree on return type.
    elif ((isinstance(false_var, support_ret_buildin_type) and
           isinstance(true_var, Variable)) or
          (isinstance(true_var, support_ret_buildin_type) and
           isinstance(false_var, Variable))):
        inputs = [to_static_variable(false_var), to_static_variable(true_var)]
        warnings.warn(
            "Return results from different branches in cond are not same type: "
            "false_var returned by false_fn is '{}' and true_var of true_fn is "
            "'{}'".format(type(false_var), type(true_var)))
    else:
        raise TypeError(
            "Unsupported return type of true_fn and false_fn in cond: false_var "
            "returned by false_fn is '{}' and true_var of true_fn is '{}'".
            format(type(false_var), type(true_var)))

    return select_input(inputs, mask)


def split_lod_tensor(input, mask, level=0):
"""
This function takes in an input that contains the complete lod information,
Expand Down Expand Up @@ -2282,8 +2317,8 @@ def append_conditional_block_grad(self, parent_block, inside_block,


def copy_var_to_parent_block(var, layer_helper):
if var is None:
return None
if not isinstance(var, Variable):
return var
prog = layer_helper.main_program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0, "Got wrong parent block index when assigning var to parent scope in control_flow"
Expand Down Expand Up @@ -2466,7 +2501,7 @@ def false_func():
format(e))

mask = cast(pred, dtype='int32')
merge_func = lambda false_var, true_var : select_input([false_var, true_var], mask)
merge_func = lambda false_var, true_var : select_input_with_buildin_type([false_var, true_var], mask)
merged_output = map_structure(merge_func, false_output, true_output)
return merged_output

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,53 @@ def if_tensor_case(x):
x += 1

return x


def dyfunc_ifelse_ret_int1(x):
    """Dy2stat fixture: both branches return a (Tensor, Python int) pair,
    so the converted cond must merge a built-in int across branches.
    NOTE(review): the exact statement shape is what the AST transformer is
    tested against — do not restructure."""
    index = 0
    pred = paddle.to_tensor([1])
    if pred:
        y = x[index] + 1
        index = index + 1
        return y, index
    else:
        y = x[index] + 2
        index = index + 1
        return y, index


def dyfunc_ifelse_ret_int2(x):
    """Dy2stat fixture: the true branch returns two values while the false
    branch returns one, exercising branches with mismatched return arity.
    (Per TestDy2StIfElseRetInt2, both merged outputs come back as Tensors —
    presumably the converter pads/promotes the missing value; confirm in
    the transformer.)"""
    index = 0
    pred = paddle.to_tensor([1])
    if pred:
        y = x[index] + 1
        index = index + 1
        return y, index
    else:
        y = x[index] + 2
        index = index + 1
        return y


def dyfunc_ifelse_ret_int3(x):
    """Dy2stat fixture: one branch returns a Python int, the other a
    Tensor, exercising the built-in-vs-Variable promotion path of
    select_input_with_buildin_type (merged result is a Tensor per
    TestDy2StIfElseRetInt3)."""
    index = 0
    pred = paddle.to_tensor([1])
    if pred:
        y = x[index] + 1
        index = index + 1
        return index
    else:
        y = x[index] + 2
        return y


def dyfunc_ifelse_ret_int4(x):
    """Dy2stat fixture: one branch returns a str, which is NOT a supported
    cond return type (only bool/int/float built-ins are), so conversion is
    expected to raise TypeError (see TestDy2StIfElseRetInt4)."""
    index = 0
    pred = paddle.to_tensor([1])
    if pred:
        y = x[index] + 1
        index = index + 1
        return 'unsupport ret'
    else:
        y = x[index] + 2
        return y
Original file line number Diff line number Diff line change
Expand Up @@ -365,5 +365,60 @@ def case_func(training):
self.assertEqual(paddle.jit.to_static(case_func)(True), -2)


class TestDy2StIfElseRetInt1(unittest.TestCase):
    """Checks that a to_static-converted if/else may return a Tensor
    together with a plain Python int."""

    def setUp(self):
        self.x = np.random.random([5]).astype('float32')
        self.dyfunc = dyfunc_ifelse_ret_int1
        self.out = self.get_dy2stat_out()

    def get_dy2stat_out(self):
        # Convert under static mode, run once, then restore dynamic mode.
        ProgramTranslator().enable(True)
        result = paddle.jit.to_static(self.dyfunc)(self.x)
        ProgramTranslator().enable(False)
        return result

    def test_ast_to_func(self):
        tensor_part, int_part = self.out[0], self.out[1]
        self.assertIsInstance(tensor_part, paddle.Tensor)
        self.assertIsInstance(int_part, int)


class TestDy2StIfElseRetInt2(TestDy2StIfElseRetInt1):
    """Branches return different numbers of values; both merged outputs
    are expected to come back as Tensors."""

    def setUp(self):
        self.dyfunc = dyfunc_ifelse_ret_int2
        self.x = np.random.random([5]).astype('float32')
        self.out = self.get_dy2stat_out()

    def test_ast_to_func(self):
        for item in (self.out[0], self.out[1]):
            self.assertIsInstance(item, paddle.Tensor)


class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1):
    """One branch returns an int, the other a Tensor; the merged single
    output is expected to be a Tensor."""

    def setUp(self):
        self.dyfunc = dyfunc_ifelse_ret_int3
        self.x = np.random.random([5]).astype('float32')
        self.out = self.get_dy2stat_out()

    def test_ast_to_func(self):
        self.assertIsInstance(self.out, paddle.Tensor)


class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
    """Returning a str from one branch is unsupported and must raise
    TypeError during conversion/execution."""

    def setUp(self):
        self.x = np.random.random([5]).astype('float32')
        self.dyfunc = dyfunc_ifelse_ret_int4

    def test_ast_to_func(self):
        ProgramTranslator().enable(True)
        with self.assertRaises(TypeError):
            static_func = paddle.jit.to_static(self.dyfunc)
            out = static_func(self.x)

    def tearDown(self):
        # Always restore dynamic mode, even when the test body raises.
        # The original version did this in __del__ and then called
        # super().__del__(), but unittest.TestCase defines no __del__, so
        # that call raised AttributeError during garbage collection, and
        # __del__-based cleanup is nondeterministic anyway.
        ProgramTranslator().enable(False)


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()