[PIR] Modify call_vjp interface for control flow grad #58277

Merged: 32 commits, Oct 23, 2023

Commits (32), changes from all commits:
bce9b3b  tmp (xiaoguoguo626807, Aug 30, 2023)
c2341a5  fix conflict (xiaoguoguo626807, Aug 30, 2023)
4d30fdd  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Aug 31, 2023)
c94252d  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 19, 2023)
3aa6686  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 22, 2023)
3b3b5ea  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 25, 2023)
cae57c1  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 7, 2023)
d52fe87  [PIR]Migrate maximum into pir (Oct 8, 2023)
9e5a0b1  Polish code (Oct 9, 2023)
2218be2  add ir_grad of static_gradient (xiaoguoguo626807, Oct 9, 2023)
b190b2f  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 9, 2023)
2ce9d92  Merge commit 'refs/pull/57929/head' of https://github.com/PaddlePaddl… (xiaoguoguo626807, Oct 9, 2023)
02040b1  add test (xiaoguoguo626807, Oct 9, 2023)
615c487  modify bug (xiaoguoguo626807, Oct 9, 2023)
b48c163  modify (xiaoguoguo626807, Oct 10, 2023)
3932dda  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 10, 2023)
c96c27c  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 11, 2023)
f3c6854  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 11, 2023)
17f00d9  add mean fill_constant test (xiaoguoguo626807, Oct 11, 2023)
15eb924  modify cpu int32 test (xiaoguoguo626807, Oct 12, 2023)
520840e  get_shape_tensor (xiaoguoguo626807, Oct 13, 2023)
5200fff  fix conflict (xiaoguoguo626807, Oct 13, 2023)
ce2e736  delete (xiaoguoguo626807, Oct 13, 2023)
678b046  add default place (xiaoguoguo626807, Oct 13, 2023)
8ff657c  fix conflict (xiaoguoguo626807, Oct 16, 2023)
fc76f73  modify grad (xiaoguoguo626807, Oct 16, 2023)
538ef4f  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 19, 2023)
438ab52  modify call_vjp (xiaoguoguo626807, Oct 20, 2023)
d08befc  Update python/paddle/autograd/ir_backward.py (xiaoguoguo626807, Oct 20, 2023)
20b0273  fix name (xiaoguoguo626807, Oct 20, 2023)
41f7567  Merge branch 'call_vjp' of https://github.com/xiaoguoguo626807/Paddle… (xiaoguoguo626807, Oct 20, 2023)
46547cb  new ci (xiaoguoguo626807, Oct 23, 2023)
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
@@ -111,7 +111,7 @@
}"""

OP_VJP_DEFINE_TEMPLATE = """
-std::vector<std::vector<pir::OpResult>> {op_class_name}::Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
+std::vector<std::vector<pir::OpResult>> {op_class_name}::Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::OpResult>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); (void)op_obj;

VLOG(6) << "Prepare inputs of {op_grad_name}";
@@ -317,7 +317,7 @@ def gen_exclusive_interface_str(op_info, op_info_items):
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)
if op_info.op_phi_name[0] not in vjp_interface_black_list:
-exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
+exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::OpResult>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
if op_info.op_phi_name[0] in decomp_interface_declare_gen_op_list:
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op);"
return exclusive_interface_str
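
For orientation, the template above is what the op generator formats once per operator; below is a minimal standalone sketch of that formatting step, using TanhOp and tanh_grad as assumed example names (the real generator fills in a full body, this only shows how the new five-argument signature is produced):

```python
# Standalone sketch: apply the updated VJP definition template the same way
# the generator does, for a hypothetical op class name / grad op name.
OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<pir::OpResult>> {op_class_name}::Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::OpResult>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
  {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); (void)op_obj;
  VLOG(6) << "Prepare inputs of {op_grad_name}";
  // ... remainder of the generated body elided ...
}}"""

# The {{ / }} pairs escape literal braces; the named fields are substituted.
print(OP_VJP_DEFINE_TEMPLATE.format(op_class_name="TanhOp", op_grad_name="tanh_grad"))
```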
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/operator/interface/interface.cc
@@ -20,6 +20,8 @@ namespace paddle {
namespace dialect {
std::vector<std::vector<pir::OpResult>> VjpInterface::Vjp(
pir::Operation* op,
+const std::vector<std::vector<pir::Value>>& inputs,
+const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<pir::Value>> out_grads_value;
@@ -30,7 +32,7 @@ std::vector<std::vector<pir::OpResult>> VjpInterface::Vjp(
}
out_grads_value.emplace_back(std::move(grad_value));
}
-return impl_->vjp_(op, out_grads_value, stop_gradients);
+return impl_->vjp_(op, inputs, outputs, out_grads_value, stop_gradients);
}
} // namespace dialect
} // namespace paddle
14 changes: 12 additions & 2 deletions paddle/fluid/pir/dialect/operator/interface/vjp.h
@@ -22,11 +22,15 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {
struct Concept {
explicit Concept(std::vector<std::vector<pir::OpResult>> (*vjp)(
pir::Operation* op,
+const std::vector<std::vector<pir::Value>>& inputs,
+const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients))
: vjp_(vjp) {}
std::vector<std::vector<pir::OpResult>> (*vjp_)(
pir::Operation* op,
+const std::vector<std::vector<pir::Value>>& inputs,
+const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);
};
@@ -35,9 +39,11 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {
struct Model : public Concept {
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
+const std::vector<std::vector<pir::Value>>& inputs,
+const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
-return ConcreteOp::Vjp(op, out_grads, stop_gradients);
+return ConcreteOp::Vjp(op, inputs, outputs, out_grads, stop_gradients);
}

Model() : Concept(Vjp) {}
@@ -49,13 +55,17 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {

std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
+const std::vector<std::vector<pir::Value>>& inputs,
+const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
-return impl_->vjp_(op, out_grads, stop_gradients);
+return impl_->vjp_(op, inputs, outputs, out_grads, stop_gradients);
}

std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
+const std::vector<std::vector<pir::Value>>& inputs,
+const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);

2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -51,6 +51,8 @@ class AddNOp : public pir::Op<AddNOp,
static void InferMeta(phi::InferMetaContext *infer_meta);
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation *op,
+const std::vector<std::vector<pir::Value>> &inputs_,
+const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
};
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
@@ -29,6 +29,8 @@ using IntArray = paddle::experimental::IntArray;

std::vector<std::vector<pir::OpResult>> AddNOp::Vjp(
pir::Operation* op,
+const std::vector<std::vector<pir::Value>>& inputs_,
+const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
AddNOp op_obj = op->dyn_cast<AddNOp>();
6 changes: 4 additions & 2 deletions paddle/fluid/pybind/pybind.cc
@@ -695,6 +695,8 @@ void BindVjp(pybind11::module *m) {
m->def(
"call_vjp",
[](pir::Operation &fwd_op,
+const std::vector<std::vector<pir::Value>> &inputs,
+const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::OpResult>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients) {
py::list res;
@@ -704,8 +706,8 @@ void BindVjp(pybind11::module *m) {
vjp_interface,
phi::errors::InvalidArgument(
"The vjp function is not registered in %s op ", fwd_op.name()));
-std::vector<std::vector<pir::OpResult>> vjp_res =
-vjp_interface.Vjp(&fwd_op, out_grads, stop_gradients);
+std::vector<std::vector<pir::OpResult>> vjp_res = vjp_interface.Vjp(
+&fwd_op, inputs, outputs, out_grads, stop_gradients);
PADDLE_ENFORCE_EQ(
stop_gradients.size(),
vjp_res.size(),
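
With the updated binding, call_vjp receives the forward op's inputs and outputs in addition to the output gradients and stop-gradient flags. Below is a minimal sketch of the new Python-side calling convention; the wrapper name run_call_vjp and the values v2, v3, v3_g are illustrative placeholders, not part of this PR, and only has_vjp / call_vjp are taken from the diff above.

```python
import paddle

def run_call_vjp(fwd_op, v2, v3, v3_g):
    """Sketch: drive the extended call_vjp for a single-input, single-output op.

    fwd_op is a forward pir operation whose operand is v2 and whose result is
    v3; v3_g is the gradient value of v3 (all already built in the current block).
    """
    if not paddle.framework.core.has_vjp(fwd_op):
        return None
    # New argument order: op, inputs, outputs, out_grads, stop_gradients.
    # Each list argument is a list of lists, one sub-list per operand/result.
    return paddle.framework.core.call_vjp(
        fwd_op,
        [[v2]],     # forward inputs
        [[v3]],     # forward outputs
        [[v3_g]],   # gradients of the forward outputs
        [[False]],  # stop_gradient flags for the produced input grads
    )
```

The binding enforces that the result has the same outer length as stop_gradients, and ir_backward.py then distributes the returned gradients back onto the forward op's operands via update_input_grad_map.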
44 changes: 30 additions & 14 deletions python/paddle/autograd/ir_backward.py
@@ -328,7 +328,7 @@ def append_backward_ops(
if op has grad_op, prepare its grad_op's inputs by value_to_valuegrad,
eg:
value_to_valuegrad[v3] = [[v3_g]];
-v2_g = call_vjp(op3, [v3_g], [v2_stopgradient])
+v2_g = call_vjp(op3, [[v2]], [[v3]],[[v3_g]], [[v2_stopgradient]])


special pattern 1:
@@ -339,7 +339,7 @@

v1 is inside python api, we don't describe it in backward process(state)
so v1_grad is inside vjp, we don't describe it in backward process(state)
-[[v11_g, v12_g], v2_g] = call_vjp(combine_op, [v3_g], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient)
+[[v11_g, v12_g], v2_g] = call_vjp(combine_op, [[v11, v12]], [[v3]],[[v3_g]], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient])


op_vjp is:
@@ -358,8 +358,9 @@
else continue to next op.
'''

-def make_output_grad(op):
+def make_output_with_output_grad(op):
zero_flag = [False] * op.num_results()
+outputs = []
output_grads = []
for i, value in enumerate(op.results()):
if (
@@ -396,12 +397,15 @@ def make_output_grad(op):
# pattern case:
# this fwd_op's output is vectorType, it will split to
# Type by builtin.split op, so need get from split op's ouput
-split_zero_flag, split_output_grad = make_output_grad(
-value.first_use().owner()
-)
+(
+split_zero_flag,
+split_outputs,
+split_output_grad,
+) = make_output_with_output_grad(value.first_use().owner())
zero_flag[i] = all(split_zero_flag)
grad_values = [value[0] for value in split_output_grad]
state.value_to_valuegrad[value] = [grad_values]
+outputs.append([info[0] for info in split_outputs])
else:
# first case:
# this fwd_op's output didn't used by other fwd_op,
Expand All @@ -424,16 +428,19 @@ def make_output_grad(op):

state.value_to_valuegrad[value] = [[grad_value]]

+outputs.append([value])
output_grads.append(state.value_to_valuegrad[value][0])

-return zero_flag, output_grads
+return zero_flag, outputs, output_grads

-def make_input_stopgradient(op):
+def make_input_with_input_stopgradient(op):
+inputs = []
input_grad_stopgradients = []
if op.name() == "builtin.combine":
grad_semantic_info = [True for _ in range(op.num_operands())]
else:
grad_semantic_info = op.get_input_grad_semantics()

for input, grad_semantic in zip(
op.operands_source(), grad_semantic_info
):
Expand All @@ -443,16 +450,22 @@ def make_input_stopgradient(op):
input.get_defining_op() is not None
and input.get_defining_op().name() == "builtin.combine"
):
-stop_gradient = make_input_stopgradient(input.get_defining_op())
+(
+combine_inputs,
+combine_stop_gradient,
+) = make_input_with_input_stopgradient(input.get_defining_op())
+inputs.append([info[0] for info in combine_inputs])
input_grad_stopgradients.append(
-[info[0] for info in stop_gradient]
+[info[0] for info in combine_stop_gradient]
)
else:
+inputs.append([input])
if input.get_defining_op() is None or input in no_grad_set:
input_grad_stopgradients.append([True])
else:
input_grad_stopgradients.append([False])
-return input_grad_stopgradients

+return inputs, input_grad_stopgradients

def update_input_grad_map(op, input_grads):
i = 0
@@ -494,20 +507,23 @@ def update_input_grad_map(op, input_grads):
for op in clear_effective_forward_ops:
if paddle.framework.core.has_vjp(op):
# prepare output_grad
-zero_flag, output_grads = make_output_grad(op)
+zero_flag, outputs, output_grads = make_output_with_output_grad(op)

# all(zero_flag) support this op has no contribution for grad
# should be delete (prune sub_graph)
if len(output_grads) == 0 or all(zero_flag):
continue

# prepare input_grad stop_gradient info.
-input_grad_stopgradients = make_input_stopgradient(op)
+(
+inputs,
+input_grad_stopgradients,
+) = make_input_with_input_stopgradient(op)

# create grad_op
before_ops_num = len(block.ops)
input_grads = paddle.framework.core.call_vjp(
-op, output_grads, input_grad_stopgradients
+op, inputs, outputs, output_grads, input_grad_stopgradients
)
after_ops_num = len(block.ops)

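
The docstring's special pattern 1 and make_input_with_input_stopgradient above define how operands assembled by builtin.combine are handled: the combined operand contributes one sub-list holding all of its source values, a plain operand contributes a single-element sub-list, and stop_gradients mirrors that nesting. A hedged sketch of the resulting call, with every name a placeholder and the exact layout following the helper above:

```python
import paddle

def grad_of_combine_pattern(fwd_op, v11, v12, v2, v3, v3_g):
    """Sketch: call_vjp for an op whose first operand is builtin.combine(v11, v12)
    and whose second operand is the plain value v2, producing v3 with grad v3_g."""
    return paddle.framework.core.call_vjp(
        fwd_op,
        [[v11, v12], [v2]],         # combined operand flattened, then plain operand
        [[v3]],                     # forward outputs
        [[v3_g]],                   # output gradients
        [[False, False], [False]],  # stop_gradient flags, same nesting as inputs
    )
```

The gradients come back with one entry per operand, e.g. [v11_g, v12_g] for the combined operand and the gradient of v2 for the plain one, as in the docstring.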
34 changes: 27 additions & 7 deletions test/cpp/prim/test_vjp.cc
@@ -59,12 +59,15 @@ TEST(VJP, TanhBackwardTest) {
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}};
+std::vector<std::vector<pir::Value>> inputs{{op1.out()}};
+std::vector<std::vector<pir::OpResult>> outputs{{op2.out()}};
std::vector<std::vector<pir::Value>> out_grads{{op3.out()}};

pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh");
auto tanh_vjp_interface_impl =
op2_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
-tanh_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
+tanh_vjp_interface_impl->vjp_(
+op2.operation(), inputs, outputs, out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

@@ -114,12 +117,15 @@ TEST(VJP, Tanh_BackwardTest) {
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}};
+std::vector<std::vector<pir::Value>> inputs{{op1.out()}};
+std::vector<std::vector<pir::OpResult>> outputs{{op2.out()}};
std::vector<std::vector<pir::Value>> out_grads{{op3.out()}};

pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh_");
auto tanh_vjp_interface_impl =
op2_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
-tanh_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
+tanh_vjp_interface_impl->vjp_(
+op2.operation(), inputs, outputs, out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

@@ -169,12 +175,15 @@ TEST(VJP, MeanBackwardTest) {
std::vector<int64_t>{}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}};
+std::vector<std::vector<pir::Value>> inputs{{op1.out()}};
+std::vector<std::vector<pir::OpResult>> outputs{{op2.out()}};
std::vector<std::vector<pir::Value>> out_grads{{op3.out()}};

pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.mean");
auto mean_vjp_interface_impl =
op2_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
-mean_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
+mean_vjp_interface_impl->vjp_(
+op2.operation(), inputs, outputs, out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

@@ -227,11 +236,14 @@ TEST(VJP, ConcatBackwardTest) {
paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
std::vector<std::vector<bool>> stop_gradients{{false, false}};
+std::vector<std::vector<pir::Value>> inputs{{op1.out(), op1.out()}};
+std::vector<std::vector<pir::OpResult>> outputs{{op3.out()}};
std::vector<std::vector<pir::Value>> out_grads{{op4.out()}};
pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.concat");
auto concat_vjp_interface_impl =
op2_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
-concat_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients);
+concat_vjp_interface_impl->vjp_(
+op3.operation(), inputs, outputs, out_grads, stop_gradients);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

auto place = platform::CPUPlace();
@@ -291,12 +303,15 @@ TEST(VJP, AddBackwardTest) {
std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
+std::vector<std::vector<pir::Value>> inputs{{op1.out(), op2.out()}};
+std::vector<std::vector<pir::OpResult>> outputs{{op3.out()}};
std::vector<std::vector<pir::Value>> out_grads{{op4.out()}};

pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add");
auto add_vjp_interface_impl =
op3_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
-add_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients);
+add_vjp_interface_impl->vjp_(
+op3.operation(), inputs, outputs, out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

@@ -356,13 +371,15 @@ TEST(VJP, Add_BackwardTest) {
std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
+std::vector<std::vector<pir::Value>> inputs{{op1.out(), op2.out()}};
+std::vector<std::vector<pir::OpResult>> outputs{{op3.out()}};
std::vector<std::vector<pir::Value>> out_grads{{op4.out()}};

pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add_");
auto add_inplace_vjp_interface_impl =
op3_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
add_inplace_vjp_interface_impl->vjp_(
-op3.operation(), out_grads, stop_gradients);
+op3.operation(), inputs, outputs, out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

@@ -422,13 +439,16 @@ TEST(VJP, SplitBackwardTest) {
std::vector<int64_t>{1, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}};
+std::vector<std::vector<pir::Value>> inputs{{op1.out()}};
+std::vector<std::vector<pir::OpResult>> outputs{{op3.outputs()}};
std::vector<std::vector<pir::Value>> out_grads{{op3.result(0), op4.out()}};
pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.split");

auto concat_vjp_interface_impl =
op2_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();

-concat_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
+concat_vjp_interface_impl->vjp_(
+op2.operation(), inputs, outputs, out_grads, stop_gradients);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

auto place = platform::CPUPlace();