diff --git a/paddle/fluid/operators/partial_concat_op.cc b/paddle/fluid/operators/partial_concat_op.cc index 3132a7bf27260..b9c1f6789d645 100644 --- a/paddle/fluid/operators/partial_concat_op.cc +++ b/paddle/fluid/operators/partial_concat_op.cc @@ -209,7 +209,9 @@ PD_REGISTER_STRUCT_KERNEL(partial_concat, float, double, int, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_STRUCT_KERNEL(partial_concat_grad, CPU, ALL_LAYOUT, @@ -217,4 +219,6 @@ PD_REGISTER_STRUCT_KERNEL(partial_concat_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/fluid/operators/partial_concat_op.cu b/paddle/fluid/operators/partial_concat_op.cu index ffef094fa96dd..fb746b2944acc 100644 --- a/paddle/fluid/operators/partial_concat_op.cu +++ b/paddle/fluid/operators/partial_concat_op.cu @@ -240,7 +240,9 @@ PD_REGISTER_STRUCT_KERNEL(partial_concat, double, int, int64_t, - plat::float16) {} + plat::float16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_STRUCT_KERNEL(partial_concat_grad, GPU, ALL_LAYOUT, @@ -249,4 +251,6 @@ PD_REGISTER_STRUCT_KERNEL(partial_concat_grad, double, int, int64_t, - plat::float16) {} + plat::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/python/paddle/incubate/layers/nn.py b/python/paddle/incubate/layers/nn.py index 95187f646534e..ee0a1dc69297f 100644 --- a/python/paddle/incubate/layers/nn.py +++ b/python/paddle/incubate/layers/nn.py @@ -532,7 +532,7 @@ def partial_concat(input, start_index=0, length=-1): Args: input(list): List of input Tensors with data type float32, float64, int32, - int64. + int64, complex64, complex128. start_index(int32, optional): The start index of each instance for partial concatenation. Default is 0. length(int32, optional): The length of each instance for partial concatenation. Default is -1. @@ -560,7 +560,15 @@ def partial_concat(input, start_index=0, length=-1): check_variable_and_dtype( x, 'input[' + str(id) + ']', - ['float16', 'float32', 'float64', 'int32', 'int64'], + [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'complex64', + 'complex128', + ], 'partial_concat', ) check_type(start_index, 'start_index', (int), 'partial_concat') diff --git a/test/legacy_test/test_partial_concat_op.py b/test/legacy_test/test_partial_concat_op.py index 61a201970402a..fe5eae7fa4867 100644 --- a/test/legacy_test/test_partial_concat_op.py +++ b/test/legacy_test/test_partial_concat_op.py @@ -48,6 +48,15 @@ def setUp(self): np.random.random((self.batch_size, self.column)).astype(self.dtype) for num in range(self.var_num) ] + if self.dtype == np.complex64 or self.dtype == np.complex128: + self.vars = [ + ( + np.random.uniform(-1, 1, (self.batch_size, self.column)) + + 1j + * np.random.uniform(-1, 1, (self.batch_size, self.column)) + ).astype(self.dtype) + for num in range(self.var_num) + ] self.inputs = {'X': list(zip(self.var_names, self.vars))} self.attrs = {'start_index': self.start_index, 'length': self.length} y = np_partial_concat(self.vars[:], self.start_index, self.length) @@ -98,5 +107,77 @@ def init_para(self): self.var_num = 1 +class TestPartialConcatOp2_Complex64(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = -5 + self.length = -1 + self.var_num = 3 + + def init_kernel_type(self): + self.dtype = np.complex64 + + +class TestPartialConcatOp3_Complex64(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = 10 + self.length = 20 + self.var_num = 2 + + def init_kernel_type(self): + self.dtype = np.complex64 + + +class TestPartialConcatOp4_Complex64(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = -1 + self.length = -1 + self.var_num = 1 + + def init_kernel_type(self): + self.dtype = np.complex64 + + +class TestPartialConcatOp2_Complex128(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = -5 + self.length = -1 + self.var_num = 3 + + def init_kernel_type(self): + self.dtype = np.complex128 + + +class TestPartialConcatOp3_Complex128(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = 10 + self.length = 20 + self.var_num = 2 + + def init_kernel_type(self): + self.dtype = np.complex128 + + +class TestPartialConcatOp4_Complex128(TestPartialConcatOp): + def init_para(self): + self.batch_size = random.randint(1, 10) + self.column = random.randint(101, 200) + self.start_index = -1 + self.length = -1 + self.var_num = 1 + + def init_kernel_type(self): + self.dtype = np.complex128 + + if __name__ == '__main__': unittest.main()