【pir】add fused_linear_param_grad_add_pass #58401
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
A few small suggestions; just fixing them along the way next time is fine~
if (y_trans) {
  return false;
} else {
  return true;
}
Suggested change:

return !y_trans;
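The single negated return expresses the constraint directly; the same simplification applies to the dtype checks below.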
if (match_ctx.Tensor("dweight").Dtype() ==
    match_ctx.Tensor("weight_grad").Dtype()) {
  return false;
} else {
  return true;
}
Suggested change:

return match_ctx.Tensor("dweight").Dtype() !=
       match_ctx.Tensor("weight_grad").Dtype();
res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool {
  if (match_ctx.Tensor("dweight").Dtype() ==
      match_ctx.Tensor("weight_grad").Dtype()) {
    return false;
  } else {
    return true;
  }

Suggested change:

res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool {
  return match_ctx.Tensor("dweight").Dtype() !=
         match_ctx.Tensor("weight_grad").Dtype();
#include "paddle/pir/pass/pass.h" | ||
#include "paddle/pir/pass/pass_registry.h" | ||
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" | ||
namespace { |
You could add a blank line here~
Thanks! #58420
LGTM
* fused_linear_param_grad_add_pass * modify * skip cpu ci * add muti_presion modify * modify * modify * modify * modify
PR types
others
PR changes
others
Description
pcard-67164
Add the following patterns for fused_linear_param_grad_add_pass (a minimal sketch of one pattern follows the list):
add_grad + matmul_grad + add_ -> matmul + fused_linear_param_grad_add
matmul_grad + add_ -> matmul + fused_linear_param_grad_add
matmul + 0 = add_(0,1) -> fused_linear_param_grad_add
matmul + 1 = add_(1,0) -> fused_linear_param_grad_add
add_grad + matmul + 0 = add_(0,1) -> fused_linear_param_grad_add
add_grad + matmul + 1 = add_(1,0) -> fused_linear_param_grad_add
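For illustration, here is a minimal sketch of one of the matmul + add_ patterns in the DRR style this pass uses. The "pd_op.*" op-name strings, the DrrPatternBase scaffolding, and the simplified operand list (the optional dbias input/output and the has_bias attribute are omitted) are assumptions for this sketch, not the merged code; only the res.Attr dtype derivation is taken from the review comments above.

// Hedged sketch, not the merged pass: op names and the simplified operand
// list are assumptions; the multi_precision derivation mirrors the
// res.Attr lambda shown in the review comments above.
class FusedMatmulAddPatternSketch
    : public pir::drr::DrrPatternBase<FusedMatmulAddPatternSketch> {
 public:
  void operator()(pir::drr::DrrPatternContext *ctx) const override {
    // Source pattern: weight_grad = matmul(x, out_grad);
    //                 add_out     = add_(dweight, weight_grad)
    pir::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &matmul = pat.Op("pd_op.matmul",
                                {{"transpose_x", pat.Attr("trans_x")},
                                 {"transpose_y", pat.Attr("trans_y")}});
    const auto &add_ = pat.Op("pd_op.add_");
    pat.Tensor("weight_grad") = matmul(pat.Tensor("x"), pat.Tensor("out_grad"));
    pat.Tensor("add_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad"));

    // Result pattern: a single fused_linear_param_grad_add whose
    // multi_precision attribute is derived from the dtype comparison.
    pir::drr::ResultPattern res = pat.ResultPattern();
    const auto &multi_precision_attr =
        res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool {
          return match_ctx.Tensor("dweight").Dtype() !=
                 match_ctx.Tensor("weight_grad").Dtype();
        });
    const auto &fused = res.Op("pd_op.fused_linear_param_grad_add",
                               {{"multi_precision", multi_precision_attr}});
    // The real op also takes an optional dbias and a has_bias attribute,
    // both elided in this sketch.
    fused({&res.Tensor("x"), &res.Tensor("out_grad"), &res.Tensor("dweight")},
          {&res.Tensor("add_out")});
  }
};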
Notes:
1. The multi_precision attribute of fused_linear_param_grad_add is inferred as: false when dweight.dtype (the accumulated gradient) equals weight_grad.dtype (the matmul result), true otherwise. For example, under AMP a float32 master-gradient dweight paired with a float16 weight_grad yields multi_precision = true.
2. The decomposition-and-fusion of matmul_grad currently only matches trans_x=false; trans_x=true is not supported (see the hypothetical constraint sketch below). The inplace tensor of add_ must be the accumulated gradient; making the matmul_grad result the inplace tensor is not supported.
3. Pattern matching requires the weight to be the second operand of the forward matmul; a weight as the first operand is not supported.
4. Because graph construction under the new IR produces add_n as the gradient-accumulation op, the patterns above cannot be built from the Python side. They are therefore tested on the C++ side, verifying only the graph rewrite; execution results are not tested.
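As a rough illustration of the trans_x restriction in note 2, the source pattern could reject unsupported matches with a native-call constraint. RequireNativeCall and the attribute name trans_x are assumptions about the DRR API here, not code quoted from this PR:

// Hypothetical constraint: only accept matches where the forward
// matmul does not transpose x.
pat.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
  return !match_ctx.Attr<bool>("trans_x");
});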