-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[API/OP] Support FP16/BF16 in paddle.nonzero API/OP #51640
[API/OP] Support FP16/BF16 in paddle.nonzero API/OP #51640
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
phi::NonZeroKernel, | ||
int, | ||
bool, | ||
phi::dtype::float16, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xpu无需修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/tensor/search.py
Outdated
'int16', | ||
'int32', | ||
'int64', | ||
'bfloat16', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
暂时使用uint16
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
} | ||
|
||
def return_outputs(self): | ||
return {'Out': np.transpose(np.nonzero(self.inputs['Condition']))} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么要transpose呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numpy.nonzero
默认返回的是 tuple 类型,可以理解为所有非0元素的x下标和y下标分别放置在不同 array 中,即返回由2个 array 组成的 tuple ;此功能对应 paddle API 中参数as_tuple
为True
的case(默认为False
)。因此numpy.nonzero(x)
等价于paddle.nonzero(x, as_tuple=True)
,numpy.transpose (numpy.nonzero(x))
等价于paddle.nonzero(x)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Bug fixes
PR changes
OPs
Describe
[API/OP] Support FP16/BF16 in paddle.nonzero API/OP