【complex op No.25】add complex support for cross (PaddlePaddle#63207)
* add complex dtype for cross

* remove temp var when dtype is not complex
zbt78 authored and co63oc committed Apr 8, 2024
1 parent 9437361 commit 6751f6d
Showing 6 changed files with 142 additions and 24 deletions.
28 changes: 25 additions & 3 deletions paddle/phi/kernels/cpu/cross_grad_kernel.cc
@@ -18,6 +18,8 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {

@@ -81,9 +83,27 @@ void CrossGradKernel(const Context &dev_ctx,
slice_size *= static_cast<int>(input_x_dims[i]);
}

int64_t numel = x.numel();
DenseTensor x_conj, y_conj;
DenseTensorMeta meta_xy(x.dtype(), x.dims());
x_conj.set_meta(meta_xy);
y_conj.set_meta(meta_xy);

auto *input_x_conj_data = dev_ctx.template Alloc<T>(&x_conj);

auto *input_y_conj_data = dev_ctx.template Alloc<T>(&y_conj);

phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::ConjFunctor<T> functor_x(
input_x.data<T>(), numel, input_x_conj_data);
phi::funcs::ConjFunctor<T> functor_y(
input_y.data<T>(), numel, input_y_conj_data);
for_range(functor_x);
for_range(functor_y);

std::vector<T> input_x_vec, input_y_vec, input_dout_vec;
phi::TensorToVector(input_x, dev_ctx, &input_x_vec);
phi::TensorToVector(input_y, dev_ctx, &input_y_vec);
phi::TensorToVector(x_conj, dev_ctx, &input_x_vec);
phi::TensorToVector(y_conj, dev_ctx, &input_y_vec);
phi::TensorToVector(input_out_grad, dev_ctx, &input_dout_vec);
std::vector<T> out_dx_vec(output_x_grad->numel());
std::vector<T> out_dy_vec(output_y_grad->numel());
@@ -120,4 +140,6 @@ PD_REGISTER_KERNEL(cross_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
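
The complex backward path above reuses the existing real-valued gradient formulas, but feeds them conjugated copies of x and y (built with ConjFunctor over ForRange), which corresponds to the conjugate-Jacobian vector product used for complex gradients. A minimal NumPy sketch of the per-slice rule, for illustration only (cross_grad_slice is a hypothetical helper, not part of the kernel):

import numpy as np

def cross_grad_slice(x, y, grad_out):
    # x, y, grad_out: length-3 complex vectors, one slice along the cross axis.
    # Same formulas as the real case, applied to conjugated operands.
    grad_x = np.cross(np.conj(y), grad_out)
    grad_y = np.cross(grad_out, np.conj(x))
    return grad_x, grad_y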
12 changes: 10 additions & 2 deletions paddle/phi/kernels/cpu/cross_kernel.cc
@@ -105,5 +105,13 @@ void CrossKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(
cross, CPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(cross,
CPU,
ALL_LAYOUT,
phi::CrossKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
61 changes: 47 additions & 14 deletions paddle/phi/kernels/gpu/cross_grad_kernel.cu
@@ -18,6 +18,8 @@
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/index_calculator.h"

namespace phi {
@@ -162,27 +164,56 @@ void CrossGradKernel(const Context& dev_ctx,

const auto* input_x_data = input_x.data<T>();
const auto* input_y_data = input_y.data<T>();
int64_t numel = x.numel();
const auto* input_out_grad_data = input_out_grad.data<T>();
auto* output_x_grad_data = dev_ctx.template Alloc<T>(x_grad);
auto* output_y_grad_data = dev_ctx.template Alloc<T>(y_grad);
auto index_calculator = phi::funcs::IndexCalculator(
merged_dims.size() - 1, cal_dims, left_strides, full_strides);

int64_t numel = x.numel();
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);

CrossGrad<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_x_data,
input_y_data,
input_out_grad_data,
output_x_grad_data,
output_y_grad_data,
full_strides[merge_axis],
numel / 3,
index_calculator);
if (IsComplexType(x.dtype())) {
DenseTensor x_conj, y_conj;
DenseTensorMeta meta_xy(x.dtype(), x.dims());
x_conj.set_meta(meta_xy);
y_conj.set_meta(meta_xy);

auto* input_x_conj_data = dev_ctx.template Alloc<T>(&x_conj);
auto* input_y_conj_data = dev_ctx.template Alloc<T>(&y_conj);

phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::ConjFunctor<T> functor_x(
input_x_data, numel, input_x_conj_data);
phi::funcs::ConjFunctor<T> functor_y(
input_y_data, numel, input_y_conj_data);
for_range(functor_x);
for_range(functor_y);

CrossGrad<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_x_conj_data,
input_y_conj_data,
input_out_grad_data,
output_x_grad_data,
output_y_grad_data,
full_strides[merge_axis],
numel / 3,
index_calculator);
} else {
CrossGrad<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_x_data,
input_y_data,
input_out_grad_data,
output_x_grad_data,
output_y_grad_data,
full_strides[merge_axis],
numel / 3,
index_calculator);
}
}
} // namespace phi

@@ -195,4 +226,6 @@ PD_REGISTER_KERNEL(cross_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
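
Here the x_conj/y_conj temporaries are only allocated and filled when the dtype is actually complex, so real and integer inputs keep the original single launch (the "remove temp var when dtype is not complex" point in the commit message). A quick, illustrative way to exercise the new complex path end to end in dynamic graph mode — an ad-hoc check against a build containing these kernels, not part of this commit's tests:

import numpy as np
import paddle

x_np = (np.random.rand(4, 3) + 1j * np.random.rand(4, 3)).astype(np.complex64)
y_np = (np.random.rand(4, 3) + 1j * np.random.rand(4, 3)).astype(np.complex64)

x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.to_tensor(y_np, stop_gradient=False)

out = paddle.cross(x, y, axis=1)   # forward: new complex registration
loss = paddle.real(out).sum()      # real scalar loss
loss.backward()                    # backward: conjugate-operand kernels

print(out.dtype, x.grad.dtype)     # both remain complex64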
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/cross_kernel.cu
@@ -172,4 +172,6 @@ PD_REGISTER_KERNEL(cross,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
26 changes: 22 additions & 4 deletions python/paddle/tensor/linalg.py
@@ -1900,8 +1900,8 @@ def cross(x, y, axis=9, name=None):
If `axis` is not given, it defaults to the first axis found with the length 3.
Args:
x (Tensor): The first input tensor, the data type is float16, float32, float64, int32, int64.
y (Tensor): The second input tensor, the data type is float16, float32, float64, int32, int64.
x (Tensor): The first input tensor, the data type is float16, float32, float64, int32, int64, complex64, complex128.
y (Tensor): The second input tensor, the data type is float16, float32, float64, int32, int64, complex64, complex128.
axis (int, optional): The axis along which to compute the cross product. It defaults to be 9 which indicates using the first axis found with the length 3.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1941,13 +1941,31 @@ def cross(x, y, axis=9, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
[
'float16',
'uint16',
'float32',
'float64',
"int32",
"int64",
"complex64",
"complex128",
],
'cross',
)
check_variable_and_dtype(
y,
'y',
['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
[
'float16',
'uint16',
'float32',
'float64',
"int32",
"int64",
"complex64",
"complex128",
],
'cross',
)
helper = LayerHelper("cross", **locals())
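
With complex64 and complex128 added to the dtype checks above, paddle.cross accepts complex inputs directly; a small illustrative example (values chosen so the result is easy to verify by hand):

import paddle

x = paddle.to_tensor([[1.0 + 2.0j, 0.0 + 0.0j, 0.0 + 0.0j],
                      [0.0 + 0.0j, 1.0 - 1.0j, 0.0 + 0.0j]])
y = paddle.to_tensor([[0.0 + 0.0j, 1.0 + 0.0j, 0.0 + 0.0j],
                      [0.0 + 0.0j, 0.0 + 0.0j, 1.0 + 0.0j]])

z = paddle.cross(x, y, axis=1)
# row 0: (1+2j, 0, 0) x (0, 1, 0) = (0, 0, 1+2j)
# row 1: (0, 1-1j, 0) x (0, 0, 1) = (1-1j, 0, 0)
print(z)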
35 changes: 35 additions & 0 deletions test/legacy_test/test_cross_op.py
@@ -32,6 +32,17 @@ def setUp(self):
'X': np.random.random(self.shape).astype(self.dtype),
'Y': np.random.random(self.shape).astype(self.dtype),
}
if self.dtype is np.complex64 or self.dtype is np.complex128:
self.inputs = {
'X': (
np.random.random(self.shape)
+ 1j * np.random.random(self.shape)
).astype(self.dtype),
'Y': (
np.random.random(self.shape)
+ 1j * np.random.random(self.shape)
).astype(self.dtype),
}
self.init_output()

def initTestCase(self):
@@ -81,6 +92,30 @@ def init_output(self):
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}


class TestCrossComplex64Op(TestCrossOp):
def initTestCase(self):
self.shape = (2048, 3)
self.dtype = np.complex64

def init_output(self):
z_list = []
for i in range(2048):
z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}


class TestCrossComplex128Op(TestCrossOp):
def initTestCase(self):
self.shape = (2048, 3)
self.dtype = np.complex128

def init_output(self):
z_list = []
for i in range(2048):
z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
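
The expected outputs for the new complex test classes are built with np.cross, which applies the same algebraic formula to complex operands as to real ones. A standalone sanity check of that reference (illustrative only):

import numpy as np

x = np.array([1 + 2j, 3 - 1j, 0.5j])
y = np.array([2 - 1j, 1j, 4.0 + 0j])
manual = np.array([
    x[1] * y[2] - x[2] * y[1],  # out[0]
    x[2] * y[0] - x[0] * y[2],  # out[1]
    x[0] * y[1] - x[1] * y[0],  # out[2]
])
assert np.allclose(np.cross(x, y), manual)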
