
【complex op】No.19 add complex support for triangular_solve #59529

Merged: 14 commits, Jan 3, 2024
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
@@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/triangular_solve_kernel.cc
@@ -82,4 +82,6 @@ PD_REGISTER_KERNEL(triangular_solve,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/blas/blas_impl.h
@@ -877,7 +877,7 @@ struct CBlas<phi::dtype::complex<float>> {
const phi::dtype::complex<float> alpha,
const phi::dtype::complex<float> *A,
const int lda,
phi::dtype::complex<double> *B,
phi::dtype::complex<float> *B,
Contributor:
Why was this line changed? Have you looked at the PR that added it? Why did it originally use phi::dtype::complex<double>?

Contributor Author:
This function calls cblas_ctrsm, the single-precision complex CBLAS routine for solving triangular systems, and just below it cblas_ztrsm, the double-precision complex version, is called in the same way, so A and B should have the same type. I believe the original was a typo.

Contributor Author:
Hello, could you help review this again?

const int ldb) {
cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
}
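
For context, a minimal sketch of the point above, using SciPy's BLAS wrappers rather than this PR's code: ctrsm is the single-precision complex triangular solve (ztrsm is the double-precision one), so the triangular matrix A and the right-hand side B must share one dtype.

import numpy as np
from scipy.linalg import blas

rng = np.random.default_rng(0)
# Both operands complex64, matching the single-precision routine.
a = np.tril(rng.random((4, 4)) + 1j * rng.random((4, 4))).astype(np.complex64)
b = (rng.random((4, 3)) + 1j * rng.random((4, 3))).astype(np.complex64)

# Solve A @ X = 1.0 * B using the lower triangle of A.
x = blas.ctrsm(1.0, a, b, lower=1)
np.testing.assert_allclose(a @ x, b, rtol=1e-4, atol=1e-5)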
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_reduce.cc
@@ -54,6 +54,8 @@ class MatrixReduceSumFunctor<T, CPUContext> {

template class MatrixReduceSumFunctor<float, CPUContext>;
template class MatrixReduceSumFunctor<double, CPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<float>, CPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<double>, CPUContext>;

} // namespace funcs
} // namespace phi
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_reduce.cu
@@ -52,6 +52,8 @@ class MatrixReduceSumFunctor<T, GPUContext> {

template class MatrixReduceSumFunctor<float, GPUContext>;
template class MatrixReduceSumFunctor<double, GPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<float>, GPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<double>, GPUContext>;

} // namespace funcs
} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
@@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/triangular_solve_kernel.cu
@@ -128,4 +128,6 @@ PD_REGISTER_KERNEL(triangular_solve,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
14 changes: 10 additions & 4 deletions python/paddle/tensor/linalg.py
@@ -3186,9 +3186,9 @@ def triangular_solve(

Args:
x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or
more batch dimensions. Its data type should be float32 or float64.
more batch dimensions. Its data type should be float32, float64, complex64 or complex128.
y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is
zero or more batch dimensions. Its data type should be float32 or float64.
zero or more batch dimensions. Its data type should be float32, float64, complex64 or complex128.
upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the lower-triangular
system of equations. Default: True.
transpose (bool, optional): whether `x` should be transposed before calculation. Default: False.
@@ -3227,10 +3227,16 @@
inputs = {"X": [x], "Y": [y]}
helper = LayerHelper("triangular_solve", **locals())
check_variable_and_dtype(
x, 'x', ['float32', 'float64'], 'triangular_solve'
x,
'x',
['float32', 'float64', 'complex64', 'complex128'],
'triangular_solve',
)
check_variable_and_dtype(
y, 'y', ['float32', 'float64'], 'triangular_solve'
y,
'y',
['float32', 'float64', 'complex64', 'complex128'],
'triangular_solve',
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)

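As an aside, a minimal usage sketch of the extended Python API (illustrative values, not taken from the PR):

import paddle

# Complex lower-triangular coefficient matrix and right-hand side.
x = paddle.to_tensor([[1 + 1j, 0 + 0j],
                      [2 - 1j, 3 + 0j]], dtype='complex64')
y = paddle.to_tensor([[1 + 0j],
                      [5 + 2j]], dtype='complex64')

# Solves x @ out = y for the lower-triangular system.
out = paddle.linalg.triangular_solve(x, y, upper=False)
print(out)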
61 changes: 57 additions & 4 deletions test/legacy_test/test_triangular_solve_op.py
@@ -51,10 +51,23 @@ def setUp(self):
self.python_api = paddle.tensor.linalg.triangular_solve
self.config()

self.inputs = {
'X': np.random.random(self.x_shape).astype(self.dtype),
'Y': np.random.random(self.y_shape).astype(self.dtype),
}
if np.dtype(self.dtype).kind == 'c':  # complex64/complex128; works for string or numpy dtype specs
self.inputs = {
'X': (
np.random.random(self.x_shape)
+ 1j * np.random.random(self.x_shape)
).astype(self.dtype),
'Y': (
np.random.random(self.y_shape)
+ 1j * np.random.random(self.y_shape)
).astype(self.dtype),
}
else:
self.inputs = {
'X': np.random.random(self.x_shape).astype(self.dtype),
'Y': np.random.random(self.y_shape).astype(self.dtype),
}

self.attrs = {
'upper': self.upper,
'transpose': self.transpose,
@@ -248,6 +261,46 @@ def set_output(self):
self.output = np.matmul(np.linalg.inv(x), y)


# 3D(broadcast) + 3D complex64
Contributor:
Only the 3D(broadcast) + 3D case is tested here. Consider testing whether the other cases work as well, as was done for float64.


Contributor Author:
Some shapes do not pass: the max gradient diff reaches just over 0.007, exceeding the 0.005 threshold.

Contributor:
> Some shapes do not pass: the max gradient diff reaches just over 0.007, exceeding the 0.005 threshold.

Why do some shapes fail? Beyond the larger error, is it possible that different shapes hit logic that does not work at all?

Contributor Author:
It should not have much to do with shape: for the same shape, a case that fails with complex64 passes with complex128, so the same logic should be reachable in every case. Could it be related to this link? I see some operators have also relaxed their tolerance to 0.06.

Contributor:
> It should not have much to do with shape: for the same shape, a case that fails with complex64 passes with complex128, so the same logic should be reachable in every case.

Have you checked the execution logic? If you have confirmed it is the same logic, the tolerance can be relaxed, since some operators do take different code paths for different shape combinations. Also, remember to fill in the tests for the other shapes.

Contributor Author:
I have checked; the logic is the same across shapes.

Contributor:
> I have checked; the logic is the same across shapes.

OK. The max gradient diff was just over 0.007, but you have now set the tolerance to 0.05, which may be relaxing it too much. A case that reaches just over 0.007 can be relaxed to 0.008, and similarly for the rest. Also, do not add only the complex64 cases; add the other-shape tests for complex128 as well.

class TestTriangularSolveOpCp64(TestTriangularSolveOp):
"""
case complex64
"""

def config(self):
self.x_shape = [1, 10, 10]
self.y_shape = [6, 10, 12]
self.upper = False
self.transpose = False
self.unitriangular = False
self.dtype = "complex64"

def set_output(self):
x = np.tril(self.inputs['X'])
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)


# 3D(broadcast) + 3D complex128
class TestTriangularSolveOpCp128(TestTriangularSolveOp):
"""
case complex128
"""

def config(self):
self.x_shape = [1, 10, 10]
self.y_shape = [6, 10, 12]
self.upper = False
self.transpose = False
self.unitriangular = False
self.dtype = "complex128"

def set_output(self):
x = np.tril(self.inputs['X'])
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)
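
Following the review discussion, a case whose max gradient diff is just over 0.007 could relax only the gradient check rather than the global tolerance. A hedged sketch (this class is hypothetical and not part of the PR; the 'Out' key and max_relative_error follow OpTest conventions):

class TestTriangularSolveOpCp64LooseGrad(TestTriangularSolveOpCp64):
    """Hypothetical complex64 case relaxing only the gradient tolerance."""

    def test_check_grad_normal(self):
        # 0.008 is the tightest bound above the observed ~0.007 diff.
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008)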


class TestTriangularSolveAPI(unittest.TestCase):
def setUp(self):
np.random.seed(2021)
Expand Down