-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【complex op】No.12、14 add complex support for square & reciprocal #60821
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
研发大哥 麻烦review一下~ |
@@ -294,8 +294,10 @@ HOSTDEVICE inline complex<T>& operator*=(complex<T>& a, // NOLINT | |||
thrust::complex<T>(b.real, b.imag)); | |||
return a; | |||
#else | |||
a.real = a.real * b.real - a.imag * b.imag; | |||
a.imag = a.imag * b.real + b.imag * a.real; | |||
T r = a.real * b.real - a.imag * b.imag; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看着之前的是有点问题,你是否找一些测试案例,测试一下这种情况?看*=能否产生正确的结果,如果不能请展示一下。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是我学习cumprod的代码时发现的,在cpu端梯度反向传播时会进行复数的*=
操作,然后产生错误。比如说这张图中,x_grad[0][0]的梯度应该是conj(1 + x[1][0]) = conj(1 + 2 + 3j)=3-3j
,但是这里用了*=
,相应的计算逻辑是conj(1 + 1 *= x[1][0])=conj(1 + [(1 * 2 - 0*3) + (0*2+3*2)j])=conj(3+6j)
。应该就是因为在计算a.imag时使用了新的a.real导致的。
self.check_grad(['X'], 'Out', max_relative_error=0.01, check_pir=True) | ||
if self.dtype == np.complex64 or self.dtype == np.complex128: | ||
self.check_grad( | ||
['X'], 'Out', max_relative_error=0.03, check_pir=True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我看这里对于复数类型的进行了特判,将设置max_relatice_error=0.03
,绝对误差相差多少呢,max_relatice_error=0.02
能否通过呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里我在本机测试的是0.02多一点,所以扩大到0.03了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR types
Others
PR changes
OPs
Description
*+
操作好像有点问题,在pr里修改了下【complex op】paddlepaddle 支持复数 #56145