【PIR】modify backward for controlflow op (#59020)
* add refresh stopgradient

* add refresh stopgradient

* modify

* modify backward

* modify
xiaoguoguo626807 authored Nov 16, 2023
1 parent 499c8d2 commit dc9e4ed
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions python/paddle/autograd/ir_backward.py
@@ -317,7 +317,12 @@ def inverse_sort_op(ops):


def append_backward_ops(
block, effective_forward_ops, no_grad_set, backward_ops, state
fwd_block,
bwd_block,
effective_forward_ops,
no_grad_set,
backward_ops,
state,
):
'''
add grad_op in order of topological inverse sort
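The signature change above splits the single `block` argument into `fwd_block` (the block whose ops are differentiated) and `bwd_block` (the block that receives the generated grad ops), which is what makes room for handling control-flow sub-blocks. Below is a minimal sketch of the idea, using a hypothetical stand-in class rather than Paddle's real PIR block; for the top-level program the two arguments are the same block, matching the updated call `append_backward_ops(block, block, ...)` in the last hunk of this diff.

    # Hypothetical stand-in for a PIR block; not Paddle's real class.
    class Block:
        def __init__(self):
            self.ops = []  # ops appended in program order

    def append_backward_ops_sketch(fwd_block, bwd_block):
        # Forward ops are read from fwd_block in reverse order ...
        for op in reversed(list(fwd_block.ops)):
            # ... and every generated grad op is appended to bwd_block.
            bwd_block.ops.append(op + "_grad")

    main = Block()
    main.ops = ["pd_op.matmul", "pd_op.relu"]
    # Top-level case: forward and backward ops share one block.
    append_backward_ops_sketch(main, main)
    print(main.ops)
    # ['pd_op.matmul', 'pd_op.relu', 'pd_op.relu_grad', 'pd_op.matmul_grad']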
@@ -351,11 +356,12 @@ def append_backward_ops(
value_to_valuegrad[v12] = [[v12_g]]
value_to_valuegrad[v2] = [[v2_g]]
if op don't has grad_op, if it don't has input and it's output has more than
one output_grad, add sumop for grad aggregation.
if op don't has grad_op:
if it don't has input and it's output has more than
one output_grad, add sumop for grad aggregation.
(eg: full op and get_parameter op etc.)
else continue to next op.
else continue to next op.
'''

def make_output_with_output_grad(op):
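The docstring above describes the aggregation rule: when one forward value has collected more than one gradient (because several ops consume it), the pending gradients are summed into a single value with `paddle.add_n` before being propagated further. A standalone illustration of that summation, assuming an installed Paddle in eager mode (the real pass operates on PIR values, not eager tensors):

    import paddle

    # Two gradient contributions flowing back into the same forward value.
    grad_from_branch_1 = paddle.to_tensor([1.0, 2.0])
    grad_from_branch_2 = paddle.to_tensor([0.5, 0.5])

    # add_n sums a list of tensors element-wise; in ir_backward.py this is what
    # collapses state.value_to_valuegrad[value] down to a single gradient entry.
    total_grad = paddle.add_n([grad_from_branch_1, grad_from_branch_2])
    print(total_grad.numpy())  # [1.5 2.5]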
@@ -374,8 +380,8 @@ def make_output_with_output_grad(op):
paddle.add_n(
[item[0] for item in state.value_to_valuegrad[value]]
)
combineop = block.ops[len(block.ops) - 2]
sumop = block.ops[len(block.ops) - 1]
combineop = bwd_block.ops[len(bwd_block.ops) - 2]
sumop = bwd_block.ops[len(bwd_block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], combineop
)
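The two lookups above depend on an ordering detail implied by the surrounding code: under PIR, `paddle.add_n` over a Python list first emits a `builtin.combine` op (packing the list into one vector value) and then the sum op itself, so the last two entries of `bwd_block.ops` are the combine op and the sum op. A toy sketch of that "take the last two appended ops" pattern, with a hypothetical recorder in place of a real block (the op names are illustrative):

    # Hypothetical recorder standing in for a PIR block.
    class Block:
        def __init__(self):
            self.ops = []

    bwd_block = Block()
    bwd_block.ops.append("pd_op.matmul_grad")  # grad op from an earlier step
    # add_n over a list lowers to two ops, appended in this order:
    bwd_block.ops.append("builtin.combine")
    bwd_block.ops.append("pd_op.add_n")

    combineop = bwd_block.ops[len(bwd_block.ops) - 2]  # i.e. bwd_block.ops[-2]
    sumop = bwd_block.ops[len(bwd_block.ops) - 1]      # i.e. bwd_block.ops[-1]
    print(combineop, sumop)  # builtin.combine pd_op.add_n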
@@ -507,16 +513,19 @@ def update_input_grad_map(op, input_grads):
# [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType)
# [builtin.combine , op3 , buitin.split] (op3's one input and one output are vectorType)
# [op4] (op4's inputs and outputs are not vectorType)
# einsum has two vectorType outputs, special pattern

inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
clear_effective_forward_ops = []

for op in effective_forward_ops:
for op in inverse_effective_forward_ops:
if op.name() != "builtin.combine" and op.name() != "builtin.split":
clear_effective_forward_ops.append(op)

# with bwd_block:
for op in clear_effective_forward_ops:
if paddle.framework.core.has_vjp(op):
if op.name() == "pd_op.if" or op.name() == "pd_op.while":
continue

# prepare output_grad
zero_flag, outputs, output_grads = make_output_with_output_grad(op)
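A condensed sketch of the traversal in this hunk: the forward ops are inverse-sorted inside the function (the call was moved here from `calc_gradient_helper`, see the final hunk), the `builtin.combine`/`builtin.split` glue ops are filtered out, and — the new behavior in this commit — `pd_op.if` and `pd_op.while` are skipped so the ordinary `call_vjp` path is not applied to control-flow ops. Simplified, with ops modeled as strings and a hypothetical `has_vjp` predicate:

    def has_vjp(op_name):
        # Hypothetical stand-in for paddle.framework.core.has_vjp.
        return not op_name.startswith("builtin.")

    def ops_to_differentiate(effective_forward_ops):
        # inverse_sort_op in the real code; plain reversal is enough for a sketch.
        inverse_ops = list(reversed(effective_forward_ops))

        # Drop the combine/split glue ops that only repack vector-typed values.
        clear_ops = [
            op for op in inverse_ops
            if op not in ("builtin.combine", "builtin.split")
        ]

        for op in clear_ops:
            if has_vjp(op):
                # Control-flow ops get their backward built elsewhere, not via call_vjp.
                if op in ("pd_op.if", "pd_op.while"):
                    continue
                yield op

    print(list(ops_to_differentiate(
        ["pd_op.tanh", "builtin.combine", "pd_op.while", "pd_op.matmul"]
    )))  # ['pd_op.matmul', 'pd_op.tanh']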

@@ -532,16 +541,16 @@ def update_input_grad_map(op, input_grads):
) = make_input_with_input_stopgradient(op)

# create grad_op
before_ops_num = len(block.ops)
before_ops_num = len(bwd_block.ops)
input_grads = paddle.framework.core.call_vjp(
op, inputs, outputs, output_grads, input_grad_stopgradients
)
after_ops_num = len(block.ops)
after_ops_num = len(bwd_block.ops)

# update grad_op structure
for i in range(before_ops_num, after_ops_num):
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], block.ops[i]
backward_ops, state.op_to_opgrad[op], bwd_block.ops[i]
)

# update input_grad map
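The `before_ops_num` / `after_ops_num` bookkeeping above works because `call_vjp` appends the grad ops it builds to the end of `bwd_block`; measuring the block length before and after the call pins down exactly which ops the vjp rule created, so each one can be registered into `backward_ops` and `state.op_to_opgrad[op]`. A small self-contained sketch of that pattern, with a plain list standing in for `bwd_block.ops` and a dummy vjp builder:

    bwd_ops = []                        # stands in for bwd_block.ops
    bwd_ops.append("pd_op.relu_grad")   # grad op from an earlier iteration

    def call_vjp_sketch(op_name):
        # Dummy builder: the real call_vjp emits the op's grad ops into the
        # block and returns the input gradients.
        bwd_ops.append(op_name + "_grad")
        return ["input_grad_of_" + op_name]

    before_ops_num = len(bwd_ops)
    input_grads = call_vjp_sketch("pd_op.matmul")
    after_ops_num = len(bwd_ops)

    # Every op appended between the two measurements belongs to this grad_op.
    new_grad_ops = [bwd_ops[i] for i in range(before_ops_num, after_ops_num)]
    print(new_grad_ops)   # ['pd_op.matmul_grad']
    print(input_grads)    # ['input_grad_of_pd_op.matmul']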
@@ -558,8 +567,8 @@ def update_input_grad_map(op, input_grads):
for item in state.value_to_valuegrad[value]
]
)
combineop = block.ops[len(block.ops) - 2]
sumop = block.ops[len(block.ops) - 1]
combineop = bwd_block.ops[len(bwd_block.ops) - 2]
sumop = bwd_block.ops[len(bwd_block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], combineop
)
@@ -651,10 +660,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
block, effective_forward_ops, no_grad_set, inputs, complete_outputs
)

inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)

append_backward_ops(
block, inverse_effective_forward_ops, no_grad_set, backward_ops, state
block, block, effective_forward_ops, no_grad_set, backward_ops, state
)
# now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue)
outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
