Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2stat] Fix Using Tuple for Transpose in Dy2stat #28574

Merged
merged 4 commits into from
Nov 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5459,7 +5459,7 @@ def transpose(x, perm, name=None):

Args:
x (Variable): The input Tensor. It is a N-D Tensor of data types float32, float64, int32.
perm (list): Permute the input according to the data of perm.
perm (list|tuple): Permute the input according to the data of perm.
name (str): The name of this layer. It is optional.

Returns:
Expand Down Expand Up @@ -5492,14 +5492,12 @@ def transpose(x, perm, name=None):

.. code-block:: python

# use append_batch_size=False to avoid prepending extra
# batch size in shape
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[2, 3, 4],
dtype='float32', append_batch_size=False)
x_transposed = fluid.layers.transpose(x, perm=[1, 0, 2])
print x_transposed.shape
#(3L, 2L, 4L)
import paddle

x = paddle.randn([2, 3, 4])
x_transposed = paddle.transpose(x, perm=[1, 0, 2])
print(x_transposed.shape)
# [3L, 2L, 4L]

"""
if in_dygraph_mode():
Expand All @@ -5509,8 +5507,9 @@ def transpose(x, perm, name=None):
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
'transpose')
check_type(perm, 'perm', list, 'transpose')

check_type(perm, 'perm', (list, tuple), 'transpose')
if isinstance(perm, tuple):
perm = list(perm)
if len(perm) != len(x.shape):
raise ValueError(
"Input(perm) is the permutation of dimensions of Input(x), "
Expand Down
35 changes: 35 additions & 0 deletions python/paddle/fluid/tests/unittests/test_transpose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

paddle.enable_static()

class TestTransposeOp(OpTest):
def setUp(self):
Expand Down Expand Up @@ -113,6 +114,7 @@ def initTestCase(self):

class TestTransposeOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float64')

Expand Down Expand Up @@ -149,6 +151,39 @@ def test_each_elem_value_check():

self.assertRaises(ValueError, test_each_elem_value_check)

class TestTransposeApi(unittest.TestCase):
def test_static_out(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name='x', shape=[2, 3, 4], dtype='float32')
x_trans1 = paddle.transpose(x, perm=[1, 0, 2])
x_trans2 = paddle.transpose(x, perm=(2, 1, 0))
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
x_np = np.random.random([2, 3, 4]).astype("float32")
result1, result2 = exe.run(feed={"x": x_np}, fetch_list=[x_trans1, x_trans2])
expected_result1 = np.transpose(x_np, [1, 0, 2])
expected_result2 = np.transpose(x_np, (2, 1, 0))

np.testing.assert_array_equal(result1, expected_result1)
np.testing.assert_array_equal(result2, expected_result2)

def test_dygraph_out(self):
# This is an old test before 2.0 API so we need to disable static
# to trigger dygraph
paddle.disable_static()
x = paddle.randn([2, 3, 4])
x_trans1 = paddle.transpose(x, perm=[1, 0, 2])
x_trans2 = paddle.transpose(x, perm=(2, 1, 0))
x_np = x.numpy()
expected_result1 = np.transpose(x_np, [1, 0, 2])
expected_result2 = np.transpose(x_np, (2, 1, 0))

np.testing.assert_array_equal(x_trans1.numpy(), expected_result1)
np.testing.assert_array_equal(x_trans2.numpy(), expected_result2)
# This is an old test before 2.0 API so we enable static again after
# dygraph test
paddle.enable_static()

class TestTAPI(unittest.TestCase):
def test_out(self):
Expand Down