-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【complex op】No.19 add complex support for triangular_solve #59529
Changes from 4 commits
1be84c7
22e495b
ea76c11
d6d3b55
dab4109
7ad9317
bda2de6
c94bf35
fc4012d
499a1f4
dcc9c1b
e28882c
bc14276
bf83456
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,10 +51,23 @@ def setUp(self): | |
self.python_api = paddle.tensor.linalg.triangular_solve | ||
self.config() | ||
|
||
self.inputs = { | ||
'X': np.random.random(self.x_shape).astype(self.dtype), | ||
'Y': np.random.random(self.y_shape).astype(self.dtype), | ||
} | ||
if self.dtype is np.complex64 or self.dtype is np.complex128: | ||
self.inputs = { | ||
'X': ( | ||
np.random.random(self.x_shape) | ||
+ 1j * np.random.random(self.x_shape) | ||
).astype(self.dtype), | ||
'Y': ( | ||
np.random.random(self.y_shape) | ||
+ 1j * np.random.random(self.y_shape) | ||
).astype(self.dtype), | ||
} | ||
else: | ||
self.inputs = { | ||
'X': np.random.random(self.x_shape).astype(self.dtype), | ||
'Y': np.random.random(self.y_shape).astype(self.dtype), | ||
} | ||
|
||
self.attrs = { | ||
'upper': self.upper, | ||
'transpose': self.transpose, | ||
|
@@ -248,6 +261,46 @@ def set_output(self): | |
self.output = np.matmul(np.linalg.inv(x), y) | ||
|
||
|
||
# 3D(broadcast) + 3D complex64 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里只测试了3D+3D的,考虑和 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有的形状不行, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
为什么有的形状不行,不同的形状除了误差增大,是否有可能会有走不通的逻辑? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 应该和形状没太大关系,因为相同的形状下, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
你看过执行的逻辑了吗?确认是同一个逻辑的话可以放宽,因为有的算子会针对不同的shape组合走不同的逻辑。另外记得把其他形状的测试补齐哦 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看过了,不同形状的逻辑是一样的。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
ok,max gradient diff到了0.007多,你现在调整为0.05了,是否放宽的太多了,可以0.007多的,可以放宽到0.008,其余类似。不能只是加 |
||
class TestTriangularSolveOpCp64(TestTriangularSolveOp): | ||
""" | ||
case complex64 | ||
""" | ||
|
||
def config(self): | ||
self.x_shape = [1, 10, 10] | ||
self.y_shape = [6, 10, 12] | ||
self.upper = False | ||
self.transpose = False | ||
self.unitriangular = False | ||
self.dtype = "complex64" | ||
|
||
def set_output(self): | ||
x = np.tril(self.inputs['X']) | ||
y = self.inputs['Y'] | ||
self.output = np.linalg.solve(x, y) | ||
|
||
|
||
# 3D(broadcast) + 3D complex128 | ||
class TestTriangularSolveCp128(TestTriangularSolveOp): | ||
""" | ||
case complex128 | ||
""" | ||
|
||
def config(self): | ||
self.x_shape = [1, 10, 10] | ||
self.y_shape = [6, 10, 12] | ||
self.upper = False | ||
self.transpose = False | ||
self.unitriangular = False | ||
self.dtype = "complex128" | ||
|
||
def set_output(self): | ||
x = np.tril(self.inputs['X']) | ||
y = self.inputs['Y'] | ||
self.output = np.linalg.solve(x, y) | ||
|
||
|
||
class TestTriangularSolveAPI(unittest.TestCase): | ||
def setUp(self): | ||
np.random.seed(2021) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么改动了,有看过添加这个pr吗,为什么之前要用phi::dtype::complex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数调用的是
cblas
中用于解方程组的单精度复数版本的cblas_ctrsm
,同时下面也有调用
cblas
中用于解方程组的双精度复数版本的cblas_ztrsm
,所以说A,B的类型应该是一样的。个人感觉这里应该是笔误。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
您好,可以再帮忙review一下吗~