diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 02ccf7a71fa10..9ad445d62ff1c 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -317,7 +317,12 @@ def inverse_sort_op(ops):
 
 
 def append_backward_ops(
-    block, effective_forward_ops, no_grad_set, backward_ops, state
+    fwd_block,
+    bwd_block,
+    effective_forward_ops,
+    no_grad_set,
+    backward_ops,
+    state,
 ):
     '''
     add grad_op in order of topological inverse sort
@@ -351,11 +356,12 @@ def append_backward_ops(
         value_to_valuegrad[v12] = [[v12_g]]
         value_to_valuegrad[v2] = [[v2_g]]
 
-    if op don't has grad_op, if it don't has input and it's output has more than
-    one output_grad, add sumop for grad aggregation.
+    if op don't has grad_op:
+        if it don't has input and it's output has more than
+        one output_grad, add sumop for grad aggregation.
         (eg: full op and get_parameter op etc.)
 
-    else continue to next op.
+        else continue to next op.
     '''
 
     def make_output_with_output_grad(op):
@@ -374,8 +380,8 @@ def make_output_with_output_grad(op):
                 paddle.add_n(
                     [item[0] for item in state.value_to_valuegrad[value]]
                 )
-                combineop = block.ops[len(block.ops) - 2]
-                sumop = block.ops[len(block.ops) - 1]
+                combineop = bwd_block.ops[len(bwd_block.ops) - 2]
+                sumop = bwd_block.ops[len(bwd_block.ops) - 1]
                 update_bwdop_structure(
                     backward_ops, state.op_to_opgrad[op], combineop
                 )
@@ -507,16 +513,19 @@ def update_input_grad_map(op, input_grads):
     # [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType)
     # [builtin.combine , op3 , buitin.split] (op3's one input and one output are vectorType)
     # [op4] (op4's inputs and outputs are not vectorType)
-    # einsum has twp vectorType outputs, special pattern
+    inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
 
     clear_effective_forward_ops = []
 
-    for op in effective_forward_ops:
+    for op in inverse_effective_forward_ops:
         if op.name() != "builtin.combine" and op.name() != "builtin.split":
             clear_effective_forward_ops.append(op)
-
+    # with bwd_block:
     for op in clear_effective_forward_ops:
         if paddle.framework.core.has_vjp(op):
+            if op.name() == "pd_op.if" or op.name() == "pd_op.while":
+                continue
+
             # prepare output_grad
             zero_flag, outputs, output_grads = make_output_with_output_grad(op)
 
@@ -532,16 +541,16 @@ def update_input_grad_map(op, input_grads):
             ) = make_input_with_input_stopgradient(op)
 
             # create grad_op
-            before_ops_num = len(block.ops)
+            before_ops_num = len(bwd_block.ops)
             input_grads = paddle.framework.core.call_vjp(
                 op, inputs, outputs, output_grads, input_grad_stopgradients
             )
-            after_ops_num = len(block.ops)
+            after_ops_num = len(bwd_block.ops)
 
             # update grad_op structure
             for i in range(before_ops_num, after_ops_num):
                 update_bwdop_structure(
-                    backward_ops, state.op_to_opgrad[op], block.ops[i]
+                    backward_ops, state.op_to_opgrad[op], bwd_block.ops[i]
                 )
 
             # update input_grad map
@@ -558,8 +567,8 @@ def update_input_grad_map(op, input_grads):
                                 for item in state.value_to_valuegrad[value]
                             ]
                         )
-                        combineop = block.ops[len(block.ops) - 2]
-                        sumop = block.ops[len(block.ops) - 1]
+                        combineop = bwd_block.ops[len(bwd_block.ops) - 2]
+                        sumop = bwd_block.ops[len(bwd_block.ops) - 1]
                         update_bwdop_structure(
                             backward_ops, state.op_to_opgrad[op], combineop
                         )
@@ -651,10 +660,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
         block, effective_forward_ops, no_grad_set, inputs, complete_outputs
     )
 
-    inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
-
     append_backward_ops(
-        block, inverse_effective_forward_ops, no_grad_set, backward_ops, state
+        block, block, effective_forward_ops, no_grad_set, backward_ops, state
    )
     # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue)
     outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
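
For context on the new signature: the last hunk keeps the existing call site working by passing the same block twice, `append_backward_ops(block, block, effective_forward_ops, ...)`. The sketch below is a minimal, self-contained illustration of the intent of separating a forward block (which forward ops are read from) from a backward block (which newly created grad ops are appended to and located by op-count deltas); `Block`, `fake_call_vjp`, and `append_backward_ops_sketch` are hypothetical stand-ins, not Paddle's real API.

# Minimal sketch with hypothetical names, not Paddle's real classes.
class Block:
    def __init__(self):
        self.ops = []


def fake_call_vjp(bwd_block, op_name):
    # stand-in for paddle.framework.core.call_vjp: appends one grad op per forward op
    bwd_block.ops.append(op_name + "_grad")


def append_backward_ops_sketch(fwd_block, bwd_block, effective_forward_ops):
    for op_name in reversed(effective_forward_ops):  # topological inverse order
        before_ops_num = len(bwd_block.ops)
        fake_call_vjp(bwd_block, op_name)
        after_ops_num = len(bwd_block.ops)
        # grad ops created for this forward op are bwd_block.ops[before_ops_num:after_ops_num]
        print(op_name, "->", bwd_block.ops[before_ops_num:after_ops_num])


fwd = Block()
fwd.ops = ["pd_op.matmul", "pd_op.relu"]
bwd = fwd  # same block when there is no control flow, mirroring `block, block` above
append_backward_ops_sketch(fwd, bwd, list(fwd.ops))

Presumably the split matters once control-flow ops (pd_op.if / pd_op.while, which the patch skips in the has_vjp branch) get their own backward blocks; in this patch both arguments are still the same block.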