-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PTen]elementwise_sub kernel refactor #37260
[PTen]elementwise_sub kernel refactor #37260
Conversation
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -96,6 +96,40 @@ PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y) { | |||
return out; | |||
} | |||
|
|||
PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) { | |||
// 1. Get kernel signature and kernel | |||
auto kernel_key_set = ParseKernelKeyByInputArgs(x); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前对于add和sub,这里应该是不仅根据x来确定kernel类型,可以参考原先elementwise的GetExpectedKernelType
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,这个我记个TODO,elementwise算子迁移后改一下
void ElementwiseSub(const CPUContext& dev_ctx, | ||
const DenseTensor& x, | ||
const DenseTensor& y, | ||
int axis, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个axis参数会在原先哪些场景用到,API已经没有这个参数了,我们能否去掉?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不行,axis是用来broadcast用的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for PR-CI-OP-benchmark
PR types
Others
PR changes
OPs
Describe
elementwise_sub Kernel refactor