【New IR】Add no-grad semantic info in ir_backward.py (PaddlePaddle#57206)
* first test

* modify combine op bug

* modify combine op bug
xiaoguoguo626807 authored Sep 12, 2023
1 parent f98a3c0 commit a6411b7
Showing 3 changed files with 21 additions and 26 deletions.
@@ -46,6 +46,7 @@
"tanh",
"mean",
"divide",
"sum",
"add",
"concat",
"split",
Expand Down
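The hunk above carries no file name in this capture, but it reads like the registry of ops whose VJP rules are emitted by the code generator; adding "sum" here is what makes the hand-written SumOp::Vjp below redundant. A minimal sketch of that gating pattern, using hypothetical names (CODEGEN_VJP_OPS, needs_manual_vjp) rather than Paddle's real ones:

```python
# Hypothetical sketch: ops in this list get an auto-generated Vjp, so no
# manual C++ Vjp needs to be kept for them. Names are illustrative only.
CODEGEN_VJP_OPS = ["tanh", "mean", "divide", "sum", "add", "concat", "split"]

def needs_manual_vjp(op_name: str) -> bool:
    """True if an op still requires a hand-written Vjp in manual_op_vjp.cc."""
    return op_name not in CODEGEN_VJP_OPS

assert not needs_manual_vjp("sum")  # after this commit, sum is code-generated
```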
24 changes: 0 additions & 24 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
@@ -27,29 +27,5 @@ namespace paddle {
 namespace dialect {
 using IntArray = paddle::experimental::IntArray;
 
-std::vector<std::vector<pir::OpResult>> SumOp::Vjp(
-    pir::Operation* op,
-    const std::vector<std::vector<pir::OpResult>>& out_grads,
-    const std::vector<std::vector<bool>>& stop_gradients) {
-  SumOp op_obj = op->dyn_cast<SumOp>();
-  Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
-  Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));
-
-  Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
-
-  bool keepdim = op->attribute("keepdim").dyn_cast<pir::BoolAttribute>().data();
-  bool reduce_all = false;
-  std::vector<std::vector<Tensor>> tensor_res = primitive::sum_vjp(
-      x, out_grad, axis, keepdim, reduce_all, stop_gradients);
-  std::vector<std::vector<pir::OpResult>> res(2, std::vector<pir::OpResult>(1));
-  if (tensor_res[0][0].defined()) {
-    res[0][0] =
-        std::static_pointer_cast<primitive::LazyTensor>(tensor_res[0][0].impl())
-            ->value()
-            .dyn_cast<pir::OpResult>();
-  }
-  return res;
-}
-
 }  // namespace dialect
 }  // namespace paddle
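With "sum" moved onto the generated-VJP path, the hand-written SumOp::Vjp above can be deleted. For orientation only, here is a minimal NumPy sketch (not Paddle code) of the sum reverse rule that the deleted wrapper forwarded to primitive::sum_vjp: the output gradient is broadcast back over the reduced axes, and the axis input itself gets no gradient.

```python
# NumPy sketch of the sum backward rule, for intuition only (not Paddle's
# implementation): d(sum(x, axis, keepdim))/dx broadcasts out_grad back to
# x's shape; the axis argument carries no gradient semantics.
import numpy as np

def sum_grad(x, out_grad, axis=None, keepdim=False):
    if axis is not None and not keepdim:
        out_grad = np.expand_dims(out_grad, axis)  # restore the reduced dim
    return np.broadcast_to(out_grad, x.shape)      # gradient has x's shape

x = np.random.rand(2, 3)
g = np.ones(2)                      # upstream grad of sum(x, axis=1)
assert sum_grad(x, g, axis=1).shape == (2, 3)
```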
22 changes: 20 additions & 2 deletions python/paddle/autograd/ir_backward.py
@@ -409,7 +409,15 @@ def make_output_grad(op):

     def make_input_stopgradient(op):
         input_grad_stopgradients = []
-        for input in op.operands_source():
+        if op.name() == "builtin.combine":
+            grad_semantic_info = [True for _ in range(op.num_operands())]
+        else:
+            grad_semantic_info = op.get_input_grad_semantics()
+        for input, grad_semantic in zip(
+            op.operands_source(), grad_semantic_info
+        ):
+            if not grad_semantic:
+                continue
             if input.get_defining_op().name() == "builtin.combine":
                 stop_gradient = make_input_stopgradient(input.get_defining_op())
                 input_grad_stopgradients.append(
@@ -423,7 +431,16 @@ def make_input_stopgradient(op):
         return input_grad_stopgradients
 
     def update_input_grad_map(op, input_grads):
-        for i, input in enumerate(op.operands_source()):
+        i = 0
+        if op.name() == "builtin.combine":
+            grad_semantic_info = [True for _ in range(op.num_operands())]
+        else:
+            grad_semantic_info = op.get_input_grad_semantics()
+        for input, grad_semantic in zip(
+            op.operands_source(), grad_semantic_info
+        ):
+            if not grad_semantic:
+                continue
             if input.get_defining_op().name() == "builtin.combine":
                 update_input_grad_map(input.get_defining_op(), input_grads[i])
             else:
@@ -432,6 +449,7 @@ def update_input_grad_map(op, input_grads):
                     state.value_to_valuegrad[input].append(input_grad)
                 else:
                     state.value_to_valuegrad[input].append([input_grad])
+            i += 1
 
     # there are four patterns:
     # [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType)
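The two rewritten helpers above follow the same pattern: query each op for per-input gradient semantics, fall back to all-True for builtin.combine (which exposes no get_input_grad_semantics in the diff above), and skip inputs that carry no gradient, such as the axis operand of sum. A standalone sketch of that filtering, using a stand-in class rather than Paddle's pir objects:

```python
# Standalone sketch of the grad-semantic filtering; Op is a stand-in for a
# pir operation, not Paddle's real API.
from dataclasses import dataclass

@dataclass
class Op:
    name: str
    operands: list        # operand values, in definition order
    grad_semantics: list  # True where the matching operand needs a gradient

    def get_input_grad_semantics(self):
        return self.grad_semantics

def inputs_needing_grad(op):
    # builtin.combine exposes no grad-semantic table: assume all-True.
    if op.name == "builtin.combine":
        semantics = [True] * len(op.operands)
    else:
        semantics = op.get_input_grad_semantics()
    # Drop operands without gradient semantics (e.g. sum's axis input).
    return [v for v, sem in zip(op.operands, semantics) if sem]

sum_op = Op("pd_op.sum", operands=["x", "axis"], grad_semantics=[True, False])
assert inputs_needing_grad(sum_op) == ["x"]
```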
