【PIR】modify backward for controlflow op #59020

Merged: 8 commits, Nov 16, 2023
41 changes: 24 additions & 17 deletions python/paddle/autograd/ir_backward.py
Expand Up @@ -317,7 +317,12 @@ def inverse_sort_op(ops):


def append_backward_ops(
block, effective_forward_ops, no_grad_set, backward_ops, state
fwd_block,
Contributor

Isn't fwd_block used anywhere inside this function?

Contributor Author

It will be used by the control flow branches; that branch isn't implemented yet because the required interface is still missing.

bwd_block,
effective_forward_ops,
no_grad_set,
backward_ops,
state,
):
'''
add grad_op in order of topological inverse sort
Expand Down Expand Up @@ -351,11 +356,12 @@ def append_backward_ops(
value_to_valuegrad[v12] = [[v12_g]]
value_to_valuegrad[v2] = [[v2_g]]

if op don't has grad_op, if it don't has input and it's output has more than
one output_grad, add sumop for grad aggregation.
if op don't has grad_op:
if it don't has input and it's output has more than
Contributor

Suggested change
if it don't has input and it's output has more than
if it doesn't has input and its output has more than

This can be fixed along the way in a follow-up PR.

Contributor Author

OK.

one output_grad, add sumop for grad aggregation.
(eg: full op and get_parameter op etc.)

else continue to next op.
else continue to next op.
'''

def make_output_with_output_grad(op):
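Note: the docstring above describes the aggregation rule this helper applies: when a forward value has collected more than one candidate gradient in `value_to_valuegrad`, the entries must be summed into a single gradient (the real implementation emits `paddle.add_n`, which appends a `builtin.combine` op plus a sum op to the block). A minimal, framework-free sketch of that rule, using plain Python numbers in place of PIR values:

```python
# Illustrative sketch only: models the aggregation described in the
# docstring, not Paddle's actual PIR data structures.
def aggregate_value_grads(value_to_valuegrad):
    """Collapse multiple candidate gradients of a value into one entry.

    value_to_valuegrad maps a forward value to a list such as
    [[g1], [g2], ...]; more than one entry means the value fed several
    consumers, so its gradients must be summed (paddle.add_n in the
    real code, plain addition here).
    """
    for value, grads in value_to_valuegrad.items():
        if len(grads) > 1:
            summed = sum(g[0] for g in grads)  # stand-in for paddle.add_n
            value_to_valuegrad[value] = [[summed]]
    return value_to_valuegrad


# Toy usage: v2 receives gradients from two consumers, v1 from one.
grads = {"v1": [[1.0]], "v2": [[0.5], [2.0]]}
print(aggregate_value_grads(grads))  # {'v1': [[1.0]], 'v2': [[2.5]]}
```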
Expand All @@ -374,8 +380,8 @@ def make_output_with_output_grad(op):
paddle.add_n(
[item[0] for item in state.value_to_valuegrad[value]]
)
combineop = block.ops[len(block.ops) - 2]
sumop = block.ops[len(block.ops) - 1]
combineop = bwd_block.ops[len(bwd_block.ops) - 2]
sumop = bwd_block.ops[len(bwd_block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], combineop
)
Expand Down Expand Up @@ -507,16 +513,19 @@ def update_input_grad_map(op, input_grads):
# [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType)
# [builtin.combine , op3 , builtin.split] (op3's one input and one output are vectorType)
# [op4] (op4's inputs and outputs are not vectorType)
# einsum has two vectorType outputs, special pattern

inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
clear_effective_forward_ops = []

for op in effective_forward_ops:
for op in inverse_effective_forward_ops:
if op.name() != "builtin.combine" and op.name() != "builtin.split":
clear_effective_forward_ops.append(op)

# with bwd_block:
for op in clear_effective_forward_ops:
if paddle.framework.core.has_vjp(op):
if op.name() == "pd_op.if" or op.name() == "pd_op.while":
continue

# prepare output_grad
zero_flag, outputs, output_grads = make_output_with_output_grad(op)
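Note: the comment block above enumerates the wrapper-op layouts around vector-type values; `builtin.combine` and `builtin.split` only pack and unpack values, so they are dropped from the traversal, and with this PR `pd_op.if` / `pd_op.while` are skipped as well because their backward will be assembled separately per sub-block. A rough sketch of that filtering, with a made-up `FakeOp` class standing in for PIR operations:

```python
# Illustrative sketch: mirror the filtering done before call_vjp.
# FakeOp is invented for this example; it is not a Paddle class.
class FakeOp:
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


def select_ops_for_vjp(inverse_effective_forward_ops):
    wrappers = {"builtin.combine", "builtin.split"}  # no vjp of their own
    control_flow = {"pd_op.if", "pd_op.while"}       # handled separately
    selected = []
    for op in inverse_effective_forward_ops:
        if op.name() in wrappers or op.name() in control_flow:
            continue
        selected.append(op)
    return selected


ops = [FakeOp(n) for n in ("builtin.combine", "pd_op.matmul",
                           "builtin.split", "pd_op.if", "pd_op.tanh")]
print([op.name() for op in select_ops_for_vjp(ops)])
# ['pd_op.matmul', 'pd_op.tanh']
```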

Expand All @@ -532,16 +541,16 @@ def update_input_grad_map(op, input_grads):
) = make_input_with_input_stopgradient(op)

# create grad_op
before_ops_num = len(block.ops)
before_ops_num = len(bwd_block.ops)
input_grads = paddle.framework.core.call_vjp(
op, inputs, outputs, output_grads, input_grad_stopgradients
)
after_ops_num = len(block.ops)
after_ops_num = len(bwd_block.ops)

# update grad_op structure
for i in range(before_ops_num, after_ops_num):
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], block.ops[i]
backward_ops, state.op_to_opgrad[op], bwd_block.ops[i]
)

# update input_grad map
Expand All @@ -558,8 +567,8 @@ def update_input_grad_map(op, input_grads):
for item in state.value_to_valuegrad[value]
]
)
combineop = block.ops[len(block.ops) - 2]
sumop = block.ops[len(block.ops) - 1]
combineop = bwd_block.ops[len(bwd_block.ops) - 2]
sumop = bwd_block.ops[len(bwd_block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], combineop
)
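Note: the `before_ops_num` / `after_ops_num` bookkeeping above registers exactly the ops that `call_vjp` appended to `bwd_block` against the forward op. A small, self-contained sketch of that record-by-length-delta pattern; the list-of-strings "block" and the `emit_grad_ops` callback are stand-ins, not Paddle APIs:

```python
# Illustrative sketch: capture whatever ops a call appends to a block by
# comparing the block's length before and after the call. Not Paddle code.
def call_and_record(block_ops, emit_grad_ops, op_to_opgrad, fwd_op):
    before_ops_num = len(block_ops)
    emit_grad_ops(block_ops)            # may append any number of grad ops
    after_ops_num = len(block_ops)
    for i in range(before_ops_num, after_ops_num):
        op_to_opgrad.setdefault(fwd_op, []).append(block_ops[i])


block = ["pd_op.matmul"]                # pretend the forward op is already here
op_to_opgrad = {}
call_and_record(
    block,
    lambda ops: ops.extend(["builtin.combine", "pd_op.matmul_grad"]),
    op_to_opgrad,
    "pd_op.matmul",
)
print(op_to_opgrad)
# {'pd_op.matmul': ['builtin.combine', 'pd_op.matmul_grad']}
```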
Expand Down Expand Up @@ -651,10 +660,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
block, effective_forward_ops, no_grad_set, inputs, complete_outputs
)

inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)

append_backward_ops(
block, inverse_effective_forward_ops, no_grad_set, backward_ops, state
block, block, effective_forward_ops, no_grad_set, backward_ops, state
)
# now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue)
outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
Expand Down