[PIR] modify backward for controlflow op #59020
Merged
Changes from all commits (8 commits):
- 1653bca: add refresh stopgradint (xiaoguoguo626807)
- 913c229: add refresh stopgradint (xiaoguoguo626807)
- 730382b: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807)
- 5a7bc04: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807)
- e23bb6d: modofy (xiaoguoguo626807)
- b8cd27e: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807)
- 5cd48ef: modify backward (xiaoguoguo626807)
- 8bbfea9: modify (xiaoguoguo626807)
```diff
@@ -317,7 +317,12 @@ def inverse_sort_op(ops):

 def append_backward_ops(
-    block, effective_forward_ops, no_grad_set, backward_ops, state
+    fwd_block,
+    bwd_block,
+    effective_forward_ops,
+    no_grad_set,
+    backward_ops,
+    state,
 ):
     '''
     add grad_op in order of topological inverse sort
```
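The core idea of the signature change above is that `append_backward_ops` now reads forward ops from one block while appending the generated grad ops into another, so a control-flow branch can build its backward ops in a separate block. A minimal sketch of that pattern, using hypothetical `Op`/`Block` stand-ins rather than Paddle's real PIR types:

```python
# Hypothetical stand-ins for PIR Operation/Block, for illustration only.
class Op:
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


class Block:
    def __init__(self):
        self.ops = []


def append_backward_ops(fwd_block, bwd_block, effective_forward_ops):
    # Read forward ops (taken from fwd_block), but append the generated
    # grad ops into bwd_block. For the flat, no-control-flow case the
    # caller simply passes the same block for both parameters.
    for op in effective_forward_ops:
        bwd_block.ops.append(Op(op.name() + "_grad"))


fwd = Block()
fwd.ops = [Op("pd_op.matmul"), Op("pd_op.add")]
bwd = Block()  # a distinct block, as a control-flow branch would use
append_backward_ops(fwd, bwd, fwd.ops)
print([op.name() for op in bwd.ops])  # grad ops land in bwd, not fwd
```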
```diff
@@ -351,11 +356,12 @@ def append_backward_ops(
         value_to_valuegrad[v12] = [[v12_g]]
         value_to_valuegrad[v2] = [[v2_g]]

-    if op don't has grad_op, if it don't has input and it's output has more than
-    one output_grad, add sumop for grad aggregation.
+    if op don't has grad_op:
+        if it don't has input and it's output has more than
+        one output_grad, add sumop for grad aggregation.
         (eg: full op and get_parameter op etc.)

     else continue to next op.
     '''
```

**Review comment** (suggested change on this docstring): this can also be cleaned up in a follow-up PR.

**Reply:** OK.
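The docstring describes the `value_to_valuegrad` bookkeeping: each forward value maps to a list of gradients, and when a value has accumulated more than one gradient, a sum op is inserted to aggregate them. A toy sketch of that logic with plain Python containers (Paddle actually inserts `builtin.combine` plus `add_n` ops; here a simple `sum()` stands in):

```python
# Toy model of value_to_valuegrad: value -> list of gradient entries.
value_to_valuegrad = {
    "v1": [[3.0]],          # single gradient: used as-is
    "v2": [[1.0], [2.0]],   # two gradients: needs aggregation
}


def get_grad(value):
    grads = value_to_valuegrad[value]
    if len(grads) > 1:
        # Stands in for the combine + add_n ops the real code appends.
        total = sum(g[0] for g in grads)
        value_to_valuegrad[value] = [[total]]
    return value_to_valuegrad[value][0][0]


print(get_grad("v2"))  # 3.0; the map now holds the single summed grad
```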
```diff
@@ -374,8 +380,8 @@ def make_output_with_output_grad(op):
                     paddle.add_n(
                         [item[0] for item in state.value_to_valuegrad[value]]
                     )
-                    combineop = block.ops[len(block.ops) - 2]
-                    sumop = block.ops[len(block.ops) - 1]
+                    combineop = bwd_block.ops[len(bwd_block.ops) - 2]
+                    sumop = bwd_block.ops[len(bwd_block.ops) - 1]
                     update_bwdop_structure(
                         backward_ops, state.op_to_opgrad[op], combineop
                     )
```
```diff
@@ -507,16 +513,19 @@ def update_input_grad_map(op, input_grads):
     # [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType)
     # [builtin.combine , op3 , buitin.split] (op3's one input and one output are vectorType)
     # [op4] (op4's inputs and outputs are not vectorType)
     # einsum has twp vectorType outputs, special pattern

+    inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
     clear_effective_forward_ops = []

-    for op in effective_forward_ops:
+    for op in inverse_effective_forward_ops:
         if op.name() != "builtin.combine" and op.name() != "builtin.split":
             clear_effective_forward_ops.append(op)

+    # with bwd_block:
     for op in clear_effective_forward_ops:
         if paddle.framework.core.has_vjp(op):
+            if op.name() == "pd_op.if" or op.name() == "pd_op.while":
+                continue

             # prepare output_grad
             zero_flag, outputs, output_grads = make_output_with_output_grad(op)
```
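This hunk moves the `inverse_sort_op` call inside `append_backward_ops`: grad ops must be generated in inverse topological order, so an op is processed only after every op that consumes its outputs. The real function works on PIR ops; the following is a hedged illustration on a plain `{op: [ops it feeds]}` dict:

```python
# Hedged sketch of inverse topological sorting: an op is emitted only
# after all of its consumers, giving the reverse dependency order that
# backward generation needs. Toy data, not Paddle's real op graph.
def inverse_sort(ops, feeds):
    visited, order = set(), []

    def visit(op):
        if op in visited:
            return
        visited.add(op)
        for consumer in feeds.get(op, []):
            visit(consumer)
        order.append(op)  # emitted only after all consumers

    for op in ops:
        visit(op)
    return order


ops = ["full", "matmul", "add"]
feeds = {"full": ["matmul"], "matmul": ["add"]}
print(inverse_sort(ops, feeds))  # ['add', 'matmul', 'full']
```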
```diff
@@ -532,16 +541,16 @@ def update_input_grad_map(op, input_grads):
             ) = make_input_with_input_stopgradient(op)

             # create grad_op
-            before_ops_num = len(block.ops)
+            before_ops_num = len(bwd_block.ops)
             input_grads = paddle.framework.core.call_vjp(
                 op, inputs, outputs, output_grads, input_grad_stopgradients
             )
-            after_ops_num = len(block.ops)
+            after_ops_num = len(bwd_block.ops)

             # update grad_op structure
             for i in range(before_ops_num, after_ops_num):
                 update_bwdop_structure(
-                    backward_ops, state.op_to_opgrad[op], block.ops[i]
+                    backward_ops, state.op_to_opgrad[op], bwd_block.ops[i]
                 )

             # update input_grad map
```
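The before/after op-count pattern around `call_vjp` is how the code discovers which ops the VJP call appended: everything in `bwd_block.ops[before:after]` is new and gets registered as this op's grad ops. A small self-contained sketch (the `Block` class and `fake_call_vjp` are hypothetical stand-ins, not Paddle's API):

```python
# Sketch of the before/after op-count pattern used around call_vjp.
class Block:
    def __init__(self):
        self.ops = []


def fake_call_vjp(block):  # stand-in for paddle.framework.core.call_vjp
    # The real call appends grad ops (possibly combine/sum ops too).
    block.ops += ["x_grad_op", "combine_op", "sum_op"]


bwd_block = Block()
bwd_block.ops = ["existing_op"]
before = len(bwd_block.ops)
fake_call_vjp(bwd_block)
after = len(bwd_block.ops)
new_grad_ops = bwd_block.ops[before:after]  # only the freshly appended ops
print(new_grad_ops)
```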
```diff
@@ -558,8 +567,8 @@ def update_input_grad_map(op, input_grads):
                         for item in state.value_to_valuegrad[value]
                     ]
                 )
-                combineop = block.ops[len(block.ops) - 2]
-                sumop = block.ops[len(block.ops) - 1]
+                combineop = bwd_block.ops[len(bwd_block.ops) - 2]
+                sumop = bwd_block.ops[len(bwd_block.ops) - 1]
                 update_bwdop_structure(
                     backward_ops, state.op_to_opgrad[op], combineop
                 )
```
```diff
@@ -651,10 +660,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
         block, effective_forward_ops, no_grad_set, inputs, complete_outputs
     )

-    inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
-
     append_backward_ops(
-        block, inverse_effective_forward_ops, no_grad_set, backward_ops, state
+        block, block, effective_forward_ops, no_grad_set, backward_ops, state
     )
     # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue)
     outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
```
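The updated call site passes the same block twice, so the flat (no control flow) path keeps its old behavior: appending grad ops to the backward block grows the one and only block. A hypothetical illustration of why the two-parameter signature is backward compatible:

```python
# When fwd_block and bwd_block are the same object, appending grad ops
# through the bwd_block parameter also extends the forward block, which
# is exactly the pre-refactor behavior. Toy types, not Paddle's API.
class Block:
    def __init__(self):
        self.ops = []


def append_backward_ops(fwd_block, bwd_block, forward_ops):
    for op in forward_ops:
        bwd_block.ops.append(op + "_grad")


block = Block()
block.ops = ["pd_op.tanh"]
append_backward_ops(block, block, list(block.ops))  # same block twice
print(block.ops)  # forward op followed by its grad op
```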
**Review comment** (on the `append_backward_ops` signature): Isn't `fwd_block` actually used anywhere inside the function?

**Reply:** The control-flow branch will use it; that branch is not implemented yet because a required interface is still missing.