
【pir】add fused_linear_param_grad_add_pass #58401

Conversation

@xiaoguoguo626807 (Contributor) commented Oct 26, 2023

PR types

others

PR changes

others

Description

pcard-67164
Add the patterns for fused_linear_param_grad_add_pass:
add_grad + matmul_grad + add_ -> matmul + fused_linear_param_grad_add
matmul_grad + add_ -> matmul + fused_linear_param_grad_add
matmul + 0 = add_(0,1) -> fused_linear_param_grad_add
matmul + 1 = add_(1,0) -> fused_linear_param_grad_add
add_grad + matmul + 0 = add_(0,1) -> fused_linear_param_grad_add
add_grad + matmul + 1 = add_(1,0) -> fused_linear_param_grad_add

Notes:
1. The multi_precision attribute of fused_linear_param_grad_add is deduced as follows: it is false when dweight.dtype (the other operand) equals weight_grad.dtype (the matmul result), and true otherwise (see the code sketch after these notes).
2. The decomposition-and-fusion of matmul_grad currently only matches trans_x=false; trans_x=true is not supported.
   The in-place tensor of add_ must be the other operand; using the matmul_grad result as the in-place tensor is not supported.
3. Pattern matching requires the second input of the forward matmul to be the weight; the weight as the first input is not supported.
4. Because gradient accumulation in networks built with the new IR produces an add_n op, the patterns above cannot be constructed from the Python side. The tests are therefore written on the C++ side and only verify the graph rewriting; execution results are not checked.
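A minimal sketch, grounded in the DRR excerpts quoted in the review comments below, of how the multi_precision deduction from note 1 and the transpose constraint look in code. The names res, match_ctx, y_trans, dweight, and weight_grad come from those excerpts; the multi_precision_attr variable name is only illustrative, and the fragment assumes the surrounding result-pattern code from the pass source file rather than being a standalone program.

    // One matching constraint in the reviewed code rejects a transposed weight
    // input; in the reviewers' simplified form its body is just:
    //   return !y_trans;

    // Note 1: multi_precision is deduced from dtypes. It is false when dweight
    // (the other operand of the in-place add_) and weight_grad (the matmul
    // result) already share a dtype, and true otherwise.
    const auto &multi_precision_attr =
        res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool {
          return match_ctx.Tensor("dweight").Dtype() !=
                 match_ctx.Tensor("weight_grad").Dtype();
        });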

@paddle-bot commented Oct 26, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot added the contributor (External developers) label Oct 26, 2023
@yuanlehome (Contributor) left a comment

A few small suggestions; it's fine to address them in a follow-up PR.

Comment on lines +56 to +60

    if (y_trans) {
      return false;
    } else {
      return true;
    }

Suggested change:

    return !y_trans;

Comment on lines +65 to +70
if (match_ctx.Tensor("dweight").Dtype() ==
match_ctx.Tensor("weight_grad").Dtype()) {
return false;
} else {
return true;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (match_ctx.Tensor("dweight").Dtype() ==
match_ctx.Tensor("weight_grad").Dtype()) {
return false;
} else {
return true;
}
return match_ctx.Tensor("dweight").Dtype() != match_ctx.Tensor("weight_grad").Dtype();

Comment on lines +123 to +127

    if (y_trans) {
      return false;
    } else {
      return true;
    }

Suggested change:

    return !y_trans;

Comment on lines +132 to +137

    if (match_ctx.Tensor("dweight").Dtype() ==
        match_ctx.Tensor("weight_grad").Dtype()) {
      return false;
    } else {
      return true;
    }

Suggested change:

    return match_ctx.Tensor("dweight").Dtype() !=
           match_ctx.Tensor("weight_grad").Dtype();

Comment on lines +187 to +192

    if (match_ctx.Tensor("dweight").Dtype() ==
        match_ctx.Tensor("weight_grad").Dtype()) {
      return false;
    } else {
      return true;
    }

Suggested change:

    return match_ctx.Tensor("dweight").Dtype() !=
           match_ctx.Tensor("weight_grad").Dtype();

Comment on lines +236 to +241

    if (match_ctx.Tensor("dweight").Dtype() ==
        match_ctx.Tensor("weight_grad").Dtype()) {
      return false;
    } else {
      return true;
    }

Suggested change:

    return match_ctx.Tensor("dweight").Dtype() !=
           match_ctx.Tensor("weight_grad").Dtype();

Comment on lines +288 to +294

    res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool {
      if (match_ctx.Tensor("dweight").Dtype() ==
          match_ctx.Tensor("weight_grad").Dtype()) {
        return false;
      } else {
        return true;
      }

Suggested change:

    res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool {
      return match_ctx.Tensor("dweight").Dtype() !=
             match_ctx.Tensor("weight_grad").Dtype();

Comment on lines +338 to +343

    if (match_ctx.Tensor("dweight").Dtype() ==
        match_ctx.Tensor("weight_grad").Dtype()) {
      return false;
    } else {
      return true;
    }

Suggested change:

    return match_ctx.Tensor("dweight").Dtype() !=
           match_ctx.Tensor("weight_grad").Dtype();

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
namespace {
@yuanlehome commented Oct 26, 2023

A blank line could be added here.

@xiaoguoguo626807 (Author) replied Oct 26, 2023

Thanks! #58420

@wanghuancoder (Contributor) left a comment

LGTM

@xiaoguoguo626807 merged commit 6f7cd08 into PaddlePaddle:develop Oct 26, 2023
@xiaoguoguo626807 deleted the fused_linear_param_grad_add branch Oct 26, 2023 11:16
xiaoguoguo626807 added a commit that referenced this pull request Oct 27, 2023
* fused_linear_param_grad_add_pass

* modify

* skip cpu ci

* add muti_presion modify

* modify

* modify

* modify

* modify
cxxly pushed eight commits to cxxly/Paddle that referenced this pull request between Oct 30 and Oct 31, 2023, each repeating the commit message above.
@paddle-bot removed the contributor (External developers) label Nov 3, 2023
danleifeng pushed two commits to danleifeng/Paddle that referenced this pull request on Nov 14, 2023, repeating the commit message above.
Labels: None yet
Projects: None yet
3 participants