Skip to content

Commit

Permalink
[Dy2stat]Support for i in [1,2,3] statements in dy2stat (#37259) (#…
Browse files Browse the repository at this point in the history
…37356)

该PR使得动转静模块能够正确转换如下的for i in [1, 2, 3]语句。
  • Loading branch information
0x45f authored Nov 19, 2021
1 parent b559475 commit 44db219
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/paddle/fluid/dygraph/dygraph_to_static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,8 @@ def is_for_range_iter(self):
gast.Name) and self.node.iter.func.id == "range"

def is_for_iter(self):
if isinstance(self.node.iter, (gast.Name, gast.Attribute)):
if isinstance(self.node.iter,
(gast.Name, gast.Attribute, gast.List, gast.Tuple)):
return True
elif isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,24 @@ def forward(self, x):
return z


# 21. for original list
@paddle.jit.to_static
def for_original_list():
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in [1, 2, 3]:
z = z + x
return z


# 22. for original tuple
@paddle.jit.to_static
def for_original_tuple():
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in (1, 2, 3):
z = z + x
return z


class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
Expand Down Expand Up @@ -344,6 +362,13 @@ def transformed_result_compare(self):
self.assertTrue(np.allclose(x.numpy(), y.numpy()))


class TestTransformForOriginalList(TestTransform):
def _run(self, to_static):
program_translator.enable(to_static)
with fluid.dygraph.guard():
return self.dygraph_func()


class TestTransformError(TestTransformBase):
def transformed_error(self, etype):
with self.assertRaises(etype):
Expand Down Expand Up @@ -471,5 +496,21 @@ def set_test_func(self):
self.dygraph_func = ForwardContainsForLayer()


class TestForOriginalList(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = for_original_list

def test_transformed_result_compare(self):
self.transformed_result_compare()


class TestForOriginalTuple(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = for_original_tuple

def test_transformed_result_compare(self):
self.transformed_result_compare()


if __name__ == '__main__':
unittest.main()

0 comments on commit 44db219

Please sign in to comment.