Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2stat] Remove no_value by using var.name for ifelse #36513

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME


def convert_while_loop(cond, body, loop_vars):
Expand Down Expand Up @@ -204,10 +205,45 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars):

"""
if isinstance(pred, Variable):
return _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
return_vars)
out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
return_vars)
else:
return _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)

return _remove_no_value_return_var(out)


def _remove_no_value_return_var(out):
    """Strip placeholder "no-value" return variables from an ifelse result.

    The return transformer pads the branches of a converted if/else so both
    return the same number of variables; missing values are filled with
    variables whose name contains ``RETURN_NO_VALUE_VAR_NAME``. This helper
    removes those placeholders so the caller sees the values the user's code
    actually returned.

    Args:
        out: The raw output of ``_run_paddle_cond`` / ``_run_py_ifelse``.
            Only non-empty tuples are processed; anything else is returned
            unchanged.

    Returns:
        ``None`` if everything was a placeholder, a single value if exactly
        one real value remains, or a tuple of the remaining values.
    """
    if out and isinstance(out, tuple):
        processed_out = out
        align_ret = out[0]
        if isinstance(align_ret, tuple):
            # The first element is itself a tuple of aligned return values:
            # drop the placeholder and everything after it within that tuple.
            for index, item in enumerate(align_ret):
                if isinstance(item, Variable) and (
                        RETURN_NO_VALUE_VAR_NAME in item.name):
                    if index == 0:
                        # All aligned values are placeholders -> user wrote
                        # ``return None`` in this branch.
                        processed_out = (None, ) + out[1:]
                    elif index == 1:
                        # Only one real value; unwrap it from the tuple.
                        processed_out = align_ret[:1] + out[1:]
                    else:
                        # Keep the real prefix as a tuple.
                        processed_out = (align_ret[:index], ) + out[1:]
                    break

        # Drop any trailing placeholder variables at the top level. Placeholders
        # only appear as a suffix, so we can stop at the first one found.
        for index, item in enumerate(processed_out):
            if isinstance(item, Variable) and (
                    RETURN_NO_VALUE_VAR_NAME in item.name):
                processed_out = processed_out[:index]
                break

        if not processed_out:
            return None
        elif len(processed_out) == 1:
            return processed_out[0]
        else:
            return processed_out

    else:
        return out


def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ def create_fill_constant_node(name, value):
func_code = "{} = paddle.fluid.layers.fill_constant(shape=[1], ".format(
name)
if isinstance(value, bool):
func_code += "dtype='bool', value={})".format(value)
func_code += "dtype='bool', value={}, name='{}')".format(value, name)
return gast.parse(func_code).body[0]
if isinstance(value, float):
func_code += "dtype='float64', value={})".format(value)
func_code += "dtype='float64', value={}, name='{}')".format(value, name)
return gast.parse(func_code).body[0]

if isinstance(value, int):
func_code += "dtype='int64', value={})".format(value)
func_code += "dtype='int64', value={}, name='{}')".format(value, name)
return gast.parse(func_code).body[0]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,5 +261,100 @@ def test_tensor_shape(self):
self.assertTrue(np.array_equal(out.numpy(), x.numpy()))


class TestIfElseNoValue(unittest.TestCase):
    """Tests for `@paddle.jit.to_static` if/else branches whose returns differ.

    Each case pairs a branch returning multiple values with a branch returning
    fewer values (or ``None``), and checks that the converted function yields
    exactly what the taken branch returned — i.e. that the padding "no-value"
    variables inserted during return alignment are stripped from the output.
    NOTE(review): the exact shape of the inner functions (if/else with returns
    in both branches, with or without a variable common to both) is what is
    under test here; do not restructure their control flow.
    """

    def test_else_ret_none(self):
        # Else branch returns None while the if branch returns two values.
        input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])

        @paddle.jit.to_static
        def with_common_value(x, use_cache=False):
            # `z` is assigned in both branches (a common value).
            if use_cache:
                y = x + 1
                z = x + 2
                return y, z
            else:
                c = x + 1
                z = x - 1
                return None

        @paddle.jit.to_static
        def without_common_value(x, use_cache=False):
            # No variable is shared between the two branches.
            if use_cache:
                y = x + 1
                z = x + 2
                return y, z
            else:
                c = x + 1
                return None

        # The else branch is taken: the result must be None, not padded vars.
        out = with_common_value(input_x, False)
        self.assertIsNone(out)
        out = without_common_value(input_x, False)
        self.assertIsNone(out)

    def test_else_ret_c(self):
        # Else branch returns a single value; if branch returns two.
        input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])

        @paddle.jit.to_static
        def with_common_value(x, use_cache=False):
            if use_cache:
                y = x + 1
                z = x + 2
                return y, z
            else:
                c = x + 1
                z = x - 1
                return c

        @paddle.jit.to_static
        def without_common_value(x, use_cache=False):
            if use_cache:
                y = x + 1
                z = x + 2
                return y, z
            else:
                c = x + 1
                return c

        # Else branch taken: a single tensor (x + 1) must come back unwrapped.
        out = with_common_value(input_x, False)
        self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1))
        out = without_common_value(input_x, False)
        self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1))
        # If branch taken: both returned tensors survive unchanged.
        y, z = with_common_value(input_x, True)
        self.assertListEqual(paddle.tolist(y), paddle.tolist(input_x + 1))
        self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x + 2))

    def test_else_ret_cz(self):
        # Else branch returns two values; if branch returns three.
        input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])

        @paddle.jit.to_static
        def with_common_value(x, use_cache=False):
            if use_cache:
                y = x + 1
                z = x + 2
                return y, z, 1
            else:
                c = x + 1
                z = x - 1
                return c, z

        @paddle.jit.to_static
        def without_common_value(x, use_cache=False):
            if use_cache:
                y = x + 1
                z = x + 2
                return y, z, 1
            else:
                c = x + 1
                d = x - 1
                return c, d

        # Else branch taken: exactly two tensors, no padding third value.
        c, z = with_common_value(input_x, False)
        self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1))
        self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x - 1))
        c, d = without_common_value(input_x, False)
        self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1))
        self.assertListEqual(paddle.tolist(d), paddle.tolist(input_x - 1))


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_source_code(func):
class StaticCode1():
def dyfunc_with_if_else(x_v, label=None):
__return_value_init_0 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0)
shape=[1], dtype='float64', value=0.0, name='__return_value_init_0')
__return_value_0 = __return_value_init_0

def true_fn_0(x_v):
Expand Down Expand Up @@ -116,7 +116,7 @@ class StaticCode2():
# TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None):
__return_value_init_1 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0)
shape=[1], dtype='float64', value=0.0, name='__return_value_init_1')
__return_value_1 = __return_value_init_1

def true_fn_3(x_v):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,22 @@ def test_feed_mismatch_shape(self):
class TestVariableTransFunc(unittest.TestCase):
def test_create_fill_constant_node(self):
node = create_fill_constant_node("a", 1.0)
source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0)"
self.assertEqual(ast_to_source_code(node).strip(), source)
source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0, name='a')"
self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', ''))

node = create_fill_constant_node("b", True)
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)"
self.assertEqual(ast_to_source_code(node).strip(), source)
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True, name='b')"
self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', ''))

node = create_fill_constant_node("c", 4293)
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
self.assertEqual(ast_to_source_code(node).strip(), source)
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293, name='c')"
self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', ''))

self.assertIsNone(create_fill_constant_node("e", None))
self.assertIsNone(create_fill_constant_node("e", []))
Expand Down