[OneDNN] Fc elementwise add fusion #58276
base: develop
Conversation
Your PR has been submitted successfully. Thank you for your contribution to this open source project!
Force-pushed from 1a16786 to a6016d5.
Force-pushed from 88d8b8b to 436e6df.
Sorry to inform you that 436e6df's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.
auto residual_data_md = dnnl::memory::desc(
    {MB, OC}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
Recommend getting the memory descriptor from the residual input tensor instead of assuming its shape here.
For the inner product primitive, the dst is always NC (as mentioned in the doc). That's why we cannot use the residual memory descriptor directly; we have to make sure the residual md shape is NC.
When would the residual not be NC? BTW, the data type can also be bf16, right?
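For context on the exchange above, here is a minimal sketch (illustrative helper name, not the PR's exact code) of building an NC-shaped residual descriptor and attaching it as src1 of the binary-add post-op; the data type argument would be bf16 instead of f32 for bf16 kernels:

#include "dnnl.hpp"

// The inner product primitive always produces a 2-D N x C destination,
// so the residual operand of the binary post-op is described with the
// same {MB, OC} shape rather than reusing the residual tensor's own md.
dnnl::primitive_attr MakeResidualAddAttr(int64_t MB, int64_t OC,
                                         dnnl::memory::data_type dt) {
  dnnl::memory::desc residual_md(
      {MB, OC}, dt, dnnl::memory::format_tag::ab);
  dnnl::post_ops post_operations;
  post_operations.append_binary(dnnl::algorithm::binary_add, residual_md);
  dnnl::primitive_attr attr;
  attr.set_post_ops(post_operations);
  return attr;
}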
@@ -506,6 +623,10 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
    ip_cache->src_mem = *src_memory_p;
    ip_cache->weights_mem = *weights_memory_p;
    ip_cache->dst_mem = *dst_memory_p;
    if (residual_data && residual_data_memory_p) {
      ip_cache->residual_data = *residual_data;
Why do we need to cache residual_data?
# 'Scale_in_eltwise': self.residual_scale,
# 'fuse_residual_connection': True
Is it expected that these attributes are commented out?
@LLee233 Please help with a review, thanks :)
out->dims(),
residual_param->dims(),
phi::errors::InvalidArgument(
    "Output and elementwise parameter need to have the "
After switching to the binary_add post-op, do we still need to force the residual data and dst to have the same dims? I think binary add should support broadcasting.
Force-pushed from 8950d76 to 300015a.
By the way, should we still keep the name "fc_eltwise_add" now that it has become "binary_add" (with an extra input)?
// For Inner Product primitives, the destination is always N * C
auto residual_data = ctx.Input<phi::DenseTensor>("ResidualData");
auto residual_data_md =
    dnnl::memory::desc({MB, OC},
Does residual_data_md need a fixed shape here? Since binary_add supports broadcasting, maybe {1, 1} or {MB, 1} would also work?
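A sketch of the broadcast variant suggested above, assuming the kernel is allowed to rely on the binary post-op's broadcast rules (the helper name is hypothetical):

#include "dnnl.hpp"

// src1 of a binary-add post-op may broadcast over the destination, so a
// per-row residual could be described as {MB, 1} (or even {1, 1}) and
// still be added to the {MB, OC} inner-product output.
dnnl::primitive_attr MakeBroadcastAddAttr(int64_t MB) {
  dnnl::memory::desc residual_md(
      {MB, 1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
  dnnl::post_ops post_operations;
  post_operations.append_binary(dnnl::algorithm::binary_add, residual_md);
  dnnl::primitive_attr attr;
  attr.set_post_ops(post_operations);
  return attr;
}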
Force-pushed from c1da50a to f70bfbc.
@XieYunshen, hi, could you please help approve the TIMEOUT property settings?
@XiaoguangHu01, hi, could you help approve the CI check on 'the usage of const_cast'? Thanks.
Sorry to inform you that f70bfbc's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.
PR types
New features
PR changes
Others
Description
Paddle does not support in-place computation at the moment, so fc_elementwise_add is implemented with fc + a binary_add post-op. It is also a re-implementation of the previous PR #55504, which directly deleted the pass to solve the accuracy problem.
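For illustration, a hedged sketch of the fc + binary_add idea described above, using generic oneDNN v3-style API calls rather than the actual FCMKLDNNKernel code (the function name and argument passing are illustrative):

#include "dnnl.hpp"

// The residual tensor is added through a binary post-op attached to the
// inner product primitive, so no in-place sum on the output is needed.
void FcWithResidual(dnnl::engine& eng, dnnl::stream& strm,
                    dnnl::memory& src, dnnl::memory& weights,
                    dnnl::memory& bias, dnnl::memory& residual,
                    dnnl::memory& dst) {
  dnnl::post_ops post_operations;
  post_operations.append_binary(dnnl::algorithm::binary_add,
                                residual.get_desc());
  dnnl::primitive_attr attr;
  attr.set_post_ops(post_operations);

  auto pd = dnnl::inner_product_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_inference, src.get_desc(),
      weights.get_desc(), bias.get_desc(), dst.get_desc(), attr);

  dnnl::inner_product_forward(pd).execute(
      strm,
      {{DNNL_ARG_SRC, src},
       {DNNL_ARG_WEIGHTS, weights},
       {DNNL_ARG_BIAS, bias},
       {DNNL_ARG_DST, dst},
       // src1 of post-op #0 carries the residual (ResidualData) tensor.
       {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, residual}});
}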