【pir】modify ir_backward to build If grad #59520

Merged
merged 4 commits into from
Nov 30, 2023


@xiaoguoguo626807 xiaoguoguo626807 commented Nov 29, 2023

PR types

others

PR changes

others

Description

pcard-67164

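Initial support for building the backward graph of the control-flow if op. This makes substantial changes to the original backward logic.
Design notes:

  1. The state data structure is now keyed to a forward block rather than to a program. Different sub-blocks use parent-block variables at the same time, and operating on them changes the relations stored in state, so a sub-block's state must be created by copying the parent block's state before it is operated on.
  2. call_vjp for the push op returns the outputs of the pop op rather than the gradients of its inputs. Its semantics differ from other call_vjp calls, so it is handled separately.
  3. The parameter format of the if op's call_vjp interface is aligned with other ops: a two-level list, where the first level indexes the inputs and the second level indexes the elements of a vector input.
  4. The outputs returned by if_op's call_vjp are the gradient of cond followed by the gradients of the other inputs. After the sub-block backward pass, yield separately assigns a fake_opresult to the first output. For the other inputs, a gradient is passed out if one exists and skipped otherwise. If cond's stop_gradient is later changed to True, this logic will need to be adapted.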
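Design note 1 can be illustrated with a small plain-Python model. This is only a sketch: the `State` class, its two mappings, and the string stand-ins for blocks and values are hypothetical simplifications and are not taken from the PR's code. It shows why each sub-block works on a copy of the parent state: gradients recorded while differentiating one branch must not leak into the parent or into the sibling branch.

```python
from collections import defaultdict


class State:
    """Hypothetical, simplified stand-in for the backward-building state."""

    def __init__(self, block):
        self.block = block  # the forward block this state belongs to
        self.value_to_valuegrad = defaultdict(list)  # forward value -> grads
        self.op_to_opgrad = defaultdict(list)  # forward op -> grad ops

    def copy(self, new_block):
        """Return a state for `new_block` seeded with this state's mappings."""
        child = State(new_block)
        # Copy the outer dicts and the inner lists so that appends made while
        # differentiating the sub-block never mutate the parent's records.
        child.value_to_valuegrad = defaultdict(
            list, {k: list(v) for k, v in self.value_to_valuegrad.items()}
        )
        child.op_to_opgrad = defaultdict(
            list, {k: list(v) for k, v in self.op_to_opgrad.items()}
        )
        return child


# Usage: both branches start from the parent's view of "x".
parent = State("parent_block")
parent.value_to_valuegrad["x"].append("x_grad_from_parent")

true_state = parent.copy("true_block")
false_state = parent.copy("false_block")

# Only the true branch records an extra gradient for the shared value.
true_state.value_to_valuegrad["x"].append("x_grad_from_true_block")
```

After this runs, `parent` and `false_state` still hold a single gradient for `"x"`, while `true_state` holds two.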

TODO:

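  1. The if_op build function does not propagate stop_gradient. It should be set from the stop_gradient of the corresponding op outputs in the block. For now it is modified manually after the graph is built so that the unit tests pass.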
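One possible shape for the propagation described in this TODO is sketched below in plain Python. The `Value` class and `propagate_stop_gradient` are hypothetical names used only for illustration. The combination rule is also an assumption that the PR does not state: an if-op output is treated as needing a gradient when the value yielded by either branch needs one.

```python
from dataclasses import dataclass


@dataclass
class Value:
    """Hypothetical stand-in for an IR value carrying a stop_gradient flag."""

    name: str
    stop_gradient: bool = True


def propagate_stop_gradient(if_outputs, true_yields, false_yields):
    """Set each if-op output's flag from the values yielded by both branches.

    Assumed rule: gradients stop at an output only if they stop at the
    corresponding yielded value in BOTH branches.
    """
    if not (len(if_outputs) == len(true_yields) == len(false_yields)):
        raise ValueError("each branch must yield one value per if-op output")
    for out, t, f in zip(if_outputs, true_yields, false_yields):
        out.stop_gradient = t.stop_gradient and f.stop_gradient
    return if_outputs


# Usage: out0 is differentiable in the true branch only, out1 in neither.
outs = [Value("out0"), Value("out1")]
propagate_stop_gradient(
    outs,
    true_yields=[Value("a", stop_gradient=False), Value("b", stop_gradient=True)],
    false_yields=[Value("c", stop_gradient=True), Value("d", stop_gradient=True)],
)
```

With this rule, `out0` ends up with `stop_gradient=False` and `out1` keeps `stop_gradient=True`.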

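Possible optimizations:

  1. Unify the interface for obtaining the external inputs used by the if op, moving from get_used_external_value(op) to op.operands_source().
  2. Configure op yaml infermeta for if, combine, and while so that they can call get_input_grad_semantics().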

@xiaoguoguo626807 xiaoguoguo626807 changed the title If grad 【pir】modify ir_backward to build If grad Nov 30, 2023

@winter-wang winter-wang left a comment


LGTM


@changeyoung98 changeyoung98 left a comment


LGTM

std::vector<pir::OpResult>{nullptr}};
for (size_t i = 0u; i < pop_op.num_results(); ++i) {
res[0].push_back(pop_op.result(i));
std::vector<std::vector<pir::OpResult>> res{inputs.size()};

I suggest adding an enforce_eq check here that inputs.size() == num_results() + 1. Otherwise problems will be hard to track down when something goes wrong.


For now, inputs.size() == num_results(). This will be re-adapted once the change to cond's stop_gradient is complete.
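The early check discussed in this thread can be modeled in plain Python as follows. This is a sketch rather than the C++ code under review, and `check_vjp_arity`, `build_result_slots`, and the `extra_inputs` parameter are hypothetical names. The offset is a parameter because the thread mentions two relations: equality today, and `num_results() + 1` as the reviewer's suggestion.

```python
def check_vjp_arity(num_inputs, num_results, extra_inputs=0):
    """Fail early when the vjp input list does not match the op's results.

    extra_inputs=0 models the relation the author reports for now
    (inputs.size() == num_results()); extra_inputs=1 models the relation
    the reviewer suggests checking (inputs.size() == num_results() + 1).
    """
    expected = num_results + extra_inputs
    if num_inputs != expected:
        raise ValueError(
            f"vjp expects {expected} inputs "
            f"(num_results={num_results}, extra_inputs={extra_inputs}), "
            f"got {num_inputs}"
        )


def build_result_slots(inputs, num_results, extra_inputs=0):
    """Allocate one empty inner list per input after validating the sizes."""
    check_vjp_arity(len(inputs), num_results, extra_inputs)
    # Two-level layout: outer index = input, inner list = its gradient values.
    return [[] for _ in inputs]


# Usage: two inputs against two results passes with the current relation.
slots = build_result_slots(["in0", "in1"], num_results=2)
```

Calling `build_result_slots(["in0", "in1"], num_results=2, extra_inputs=1)` raises a `ValueError` whose message names both sizes, which is the kind of diagnostic the reviewer is asking for.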

@xiaoguoguo626807 xiaoguoguo626807 merged commit f07e6f5 into PaddlePaddle:develop Nov 30, 2023
29 checks passed
@xiaoguoguo626807 xiaoguoguo626807 deleted the if_grad branch November 30, 2023 06:52

paddle-bot bot commented Dec 1, 2023

Your PR has been submitted. Thanks for your contribution to this open-source project!
Please wait for the CI results first. See the Paddle CI Manual for details.
