From 4b21c66d001f05fe9dd2aa9730646010d4671097 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Thu, 10 Aug 2023 11:42:08 +0000 Subject: [PATCH 01/30] support ir api for prim --- paddle/fluid/ir/dialect/pd_api.cc | 14 +++++ paddle/fluid/ir/dialect/pd_api.h | 4 ++ .../fluid/primitive/backend/static_backend.cc | 51 +++++++++++++++++++ .../fluid/primitive/backend/static_backend.h | 22 ++++++++ 4 files changed, 91 insertions(+) diff --git a/paddle/fluid/ir/dialect/pd_api.cc b/paddle/fluid/ir/dialect/pd_api.cc index df88dd9cc7348..e7bc5ae5b7124 100644 --- a/paddle/fluid/ir/dialect/pd_api.cc +++ b/paddle/fluid/ir/dialect/pd_api.cc @@ -63,6 +63,20 @@ ir::OpResult full(std::vector shape, return full_op.out(); } +ir::OpResult reshape(ir::OpResult x, std::vector shape) { + paddle::dialect::ReshapeOp reshape_op = + APIBuilder::Instance().GetBuilder()->Build( + x, shape); + return reshape_op.out(); +} + +ir::OpResult tile(ir::OpResult x, std::vector repeat_times) { + paddle::dialect::TileOp tile_op = + APIBuilder::Instance().GetBuilder()->Build( + x, repeat_times); + return tile_op.out(); +} + ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out) { paddle::dialect::TanhGradOp tanh_grad_op = APIBuilder::Instance().GetBuilder()->Build( diff --git a/paddle/fluid/ir/dialect/pd_api.h b/paddle/fluid/ir/dialect/pd_api.h index a44c8bb83a76a..df9db936ba85b 100644 --- a/paddle/fluid/ir/dialect/pd_api.h +++ b/paddle/fluid/ir/dialect/pd_api.h @@ -35,6 +35,10 @@ ir::OpResult sum(ir::OpResult x, ir::OpResult divide(ir::OpResult x, ir::OpResult y); +ir::OpResult reshape(ir::OpResult x, std::vector shape); + +ir::OpResult tile(ir::OpResult x, std::vector repeat_times = {}); + ir::OpResult full(std::vector shape, float value, phi::DataType dtype = phi::DataType::FLOAT32, diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index b0a515c0d75af..30fcfa36404c8 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -59,6 +59,57 @@ Tensor mean_grad(const Tensor& x, return Tensor(std::make_shared(op_res)); } +template <> +Tensor divide(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::divide(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor sum(const Tensor& x, + std::vector axis, + phi::DataType dtype, + bool keepdim) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::sum(x_res, axis, dtype, keepdim); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor full(std::vector shape, + float value, + phi::DataType dtype, + phi::Place place) { + ir::OpResult op_res = paddle::dialect::full(shape, value, dtype, place); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor reshape(const Tensor& x, std::vector shape) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::reshape(x_res, shape); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor tile(const Tensor& x, std::vector repeat_times) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::tile(x_res, repeat_times); + return 
Tensor(std::make_shared(op_res)); +} + } // namespace experimental } // namespace backend } // namespace primitive diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index bd1fb737b8658..38e4bc0da7fdf 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -35,6 +35,28 @@ Tensor mean_grad(const Tensor& x, std::vector axis = {}, bool keepdim = false, bool reduce_all = false); + +template +Tensor divide(const Tensor& x, const Tensor& y); + +template +Tensor sum(const Tensor& x, + std::vector axis = {}, + phi::DataType dtype = phi::DataType::UNDEFINED, + bool keepdim = false); + +template +Tensor full(std::vector shape, + float value, + phi::DataType dtype = phi::DataType::FLOAT32, + phi::Place place = phi::CPUPlace()); + +template +Tensor reshape(const Tensor& x, std::vector shape); + +template +Tensor tile(const Tensor& x, std::vector repeat_times = {}); + } // namespace experimental } // namespace backend } // namespace primitive From 37125f9f807c258ee20cf67e09eb3af056751aa3 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Thu, 10 Aug 2023 12:08:45 +0000 Subject: [PATCH 02/30] convert vector of int to intarray --- paddle/fluid/ir/dialect/pd_api.cc | 7 +++++ paddle/fluid/ir/dialect/pd_api.h | 2 ++ .../fluid/primitive/backend/static_backend.cc | 29 +++++++++++++------ .../fluid/primitive/backend/static_backend.h | 13 +++++---- 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_api.cc b/paddle/fluid/ir/dialect/pd_api.cc index e7bc5ae5b7124..09e71cc958b84 100644 --- a/paddle/fluid/ir/dialect/pd_api.cc +++ b/paddle/fluid/ir/dialect/pd_api.cc @@ -70,6 +70,13 @@ ir::OpResult reshape(ir::OpResult x, std::vector shape) { return reshape_op.out(); } +ir::OpResult expand(ir::OpResult x, std::vector shape) { + paddle::dialect::ExpandOp expand_op = + APIBuilder::Instance().GetBuilder()->Build( + x, shape); + return expand_op.out(); +} + ir::OpResult tile(ir::OpResult x, std::vector repeat_times) { paddle::dialect::TileOp tile_op = APIBuilder::Instance().GetBuilder()->Build( diff --git a/paddle/fluid/ir/dialect/pd_api.h b/paddle/fluid/ir/dialect/pd_api.h index df9db936ba85b..acf679239a285 100644 --- a/paddle/fluid/ir/dialect/pd_api.h +++ b/paddle/fluid/ir/dialect/pd_api.h @@ -37,6 +37,8 @@ ir::OpResult divide(ir::OpResult x, ir::OpResult y); ir::OpResult reshape(ir::OpResult x, std::vector shape); +ir::OpResult expand(ir::OpResult x, std::vector shape = {}); + ir::OpResult tile(ir::OpResult x, std::vector repeat_times = {}); ir::OpResult full(std::vector shape, diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index 30fcfa36404c8..3cd83c3958aea 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -73,40 +73,51 @@ Tensor divide(const Tensor& x, const Tensor& y) { template <> Tensor sum(const Tensor& x, - std::vector axis, + const IntArray& axis, phi::DataType dtype, bool keepdim) { ir::OpResult x_res = std::static_pointer_cast(x.impl()) ->getValue() .dyn_cast(); - ir::OpResult op_res = paddle::dialect::sum(x_res, axis, dtype, keepdim); + ir::OpResult op_res = + paddle::dialect::sum(x_res, axis.GetData(), dtype, keepdim); return Tensor(std::make_shared(op_res)); } template <> -Tensor full(std::vector shape, - float value, +Tensor full(const IntArray& shape, + const Scalar& value, phi::DataType dtype, phi::Place place) 
{ - ir::OpResult op_res = paddle::dialect::full(shape, value, dtype, place); + ir::OpResult op_res = + paddle::dialect::full(shape.GetData(), value.to(), dtype, place); return Tensor(std::make_shared(op_res)); } template <> -Tensor reshape(const Tensor& x, std::vector shape) { +Tensor reshape(const Tensor& x, const IntArray& shape) { ir::OpResult x_res = std::static_pointer_cast(x.impl()) ->getValue() .dyn_cast(); - ir::OpResult op_res = paddle::dialect::reshape(x_res, shape); + ir::OpResult op_res = paddle::dialect::reshape(x_res, shape.GetData()); return Tensor(std::make_shared(op_res)); } template <> -Tensor tile(const Tensor& x, std::vector repeat_times) { +Tensor expand(const Tensor& x, const IntArray& shape) { ir::OpResult x_res = std::static_pointer_cast(x.impl()) ->getValue() .dyn_cast(); - ir::OpResult op_res = paddle::dialect::tile(x_res, repeat_times); + ir::OpResult op_res = paddle::dialect::expand(x_res, shape.GetData()); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor tile(const Tensor& x, const IntArray& repeat_times) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::tile(x_res, repeat_times.GetData()); return Tensor(std::make_shared(op_res)); } diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index 38e4bc0da7fdf..95f689751bc7a 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -41,21 +41,24 @@ Tensor divide(const Tensor& x, const Tensor& y); template Tensor sum(const Tensor& x, - std::vector axis = {}, + const IntArray& axis = {}, phi::DataType dtype = phi::DataType::UNDEFINED, bool keepdim = false); template -Tensor full(std::vector shape, - float value, +Tensor full(const IntArray& shape, + const Scalar& value, phi::DataType dtype = phi::DataType::FLOAT32, phi::Place place = phi::CPUPlace()); template -Tensor reshape(const Tensor& x, std::vector shape); +Tensor reshape(const Tensor& x, const IntArray& shape); template -Tensor tile(const Tensor& x, std::vector repeat_times = {}); +Tensor expand(const Tensor& x, const IntArray& shape); + +template +Tensor tile(const Tensor& x, const IntArray& repeat_times = {}); } // namespace experimental } // namespace backend From 2c0166cee77679dcd934e1bf65cb923159d6bcd1 Mon Sep 17 00:00:00 2001 From: wangruting Date: Fri, 11 Aug 2023 07:40:44 +0000 Subject: [PATCH 03/30] add reference of lbfgs --- python/paddle/optimizer/lbfgs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/optimizer/lbfgs.py b/python/paddle/optimizer/lbfgs.py index d309e67772b4c..b44ef80f3cfbd 100644 --- a/python/paddle/optimizer/lbfgs.py +++ b/python/paddle/optimizer/lbfgs.py @@ -143,6 +143,8 @@ def _strong_wolfe( a_lo = aj; end(repeat) + + referance: https://github.com/pytorch/pytorch """ d_norm = d.abs().max() @@ -275,7 +277,6 @@ def _strong_wolfe( # Armijo condition not satisfied or not lower than lowest point bracket[high_pos] = alpha bracket_f[high_pos] = loss_new - # bracket_g[high_pos] = grad_new.clone(memory_format=torch.contiguous_format) bracket_g[high_pos] = grad_new.clone() bracket_gtd[high_pos] = gtd_new low_pos, high_pos = ( From 37883b2697816d0af4c1401ead3851f6068bed51 Mon Sep 17 00:00:00 2001 From: wangruting Date: Fri, 11 Aug 2023 07:44:02 +0000 Subject: [PATCH 04/30] add reference of lbfgs --- python/paddle/optimizer/lbfgs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/python/paddle/optimizer/lbfgs.py b/python/paddle/optimizer/lbfgs.py index b44ef80f3cfbd..7afd59e65ade1 100644 --- a/python/paddle/optimizer/lbfgs.py +++ b/python/paddle/optimizer/lbfgs.py @@ -144,7 +144,7 @@ def _strong_wolfe( a_lo = aj; end(repeat) - referance: https://github.com/pytorch/pytorch + reference: https://github.com/pytorch/pytorch """ d_norm = d.abs().max() From 8834e65007afe00e2bf26c5fc54747ebd81f6590 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Thu, 10 Aug 2023 11:42:08 +0000 Subject: [PATCH 05/30] support ir api for prim --- paddle/fluid/ir/dialect/pd_api.cc | 52 ++++++++ paddle/fluid/ir/dialect/pd_api.h | 17 +++ .../fluid/primitive/backend/static_backend.cc | 111 ++++++++++++++++++ .../fluid/primitive/backend/static_backend.h | 40 +++++++ 4 files changed, 220 insertions(+) diff --git a/paddle/fluid/ir/dialect/pd_api.cc b/paddle/fluid/ir/dialect/pd_api.cc index df88dd9cc7348..e5b0c702df39f 100644 --- a/paddle/fluid/ir/dialect/pd_api.cc +++ b/paddle/fluid/ir/dialect/pd_api.cc @@ -53,6 +53,37 @@ ir::OpResult divide(ir::OpResult x, ir::OpResult y) { return divide_op.out(); } +ir::OpResult add(ir::OpResult x, ir::OpResult y) { + paddle::dialect::AddOp add_op = + APIBuilder::Instance().GetBuilder()->Build(x, y); + return add_op.out(); +} + +ir::OpResult multiply(ir::OpResult x, ir::OpResult y) { + paddle::dialect::MultiplyOp multiply_op = + APIBuilder::Instance().GetBuilder()->Build( + x, y); + return multiply_op.out(); +} + +ir::OpResult elementwise_pow(ir::OpResult x, ir::OpResult y) { + paddle::dialect::ElementwisePowOp elementwise_pow_op = + APIBuilder::Instance() + .GetBuilder() + ->Build(x, y); + return elementwise_pow_op.out(); +} + +ir::OpResult scale(ir::OpResult x, + float scale, + float bias, + bool bias_after_scale) { + paddle::dialect::ScaleOp scale_op = + APIBuilder::Instance().GetBuilder()->Build( + x, scale, bias, bias_after_scale); + return scale_op.out(); +} + ir::OpResult full(std::vector shape, float value, phi::DataType dtype, @@ -63,6 +94,27 @@ ir::OpResult full(std::vector shape, return full_op.out(); } +ir::OpResult reshape(ir::OpResult x, std::vector shape) { + paddle::dialect::ReshapeOp reshape_op = + APIBuilder::Instance().GetBuilder()->Build( + x, shape); + return reshape_op.out(); +} + +ir::OpResult expand(ir::OpResult x, std::vector shape) { + paddle::dialect::ExpandOp expand_op = + APIBuilder::Instance().GetBuilder()->Build( + x, shape); + return expand_op.out(); +} + +ir::OpResult tile(ir::OpResult x, std::vector repeat_times) { + paddle::dialect::TileOp tile_op = + APIBuilder::Instance().GetBuilder()->Build( + x, repeat_times); + return tile_op.out(); +} + ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out) { paddle::dialect::TanhGradOp tanh_grad_op = APIBuilder::Instance().GetBuilder()->Build( diff --git a/paddle/fluid/ir/dialect/pd_api.h b/paddle/fluid/ir/dialect/pd_api.h index a44c8bb83a76a..942298753e385 100644 --- a/paddle/fluid/ir/dialect/pd_api.h +++ b/paddle/fluid/ir/dialect/pd_api.h @@ -33,8 +33,25 @@ ir::OpResult sum(ir::OpResult x, phi::DataType dtype = phi::DataType::UNDEFINED, bool keepdim = false); +ir::OpResult add(ir::OpResult x, ir::OpResult y); + ir::OpResult divide(ir::OpResult x, ir::OpResult y); +ir::OpResult multiply(ir::OpResult x, ir::OpResult y); + +ir::OpResult elementwise_pow(ir::OpResult x, ir::OpResult y); + +ir::OpResult scale(ir::OpResult x, + float scale = 1.0, + float bias = 0.0, + bool bias_after_scale = true); + +ir::OpResult reshape(ir::OpResult x, std::vector shape); + +ir::OpResult 
expand(ir::OpResult x, std::vector shape = {}); + +ir::OpResult tile(ir::OpResult x, std::vector repeat_times = {}); + ir::OpResult full(std::vector shape, float value, phi::DataType dtype = phi::DataType::FLOAT32, diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index b0a515c0d75af..710251f2094ee 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -59,6 +59,117 @@ Tensor mean_grad(const Tensor& x, return Tensor(std::make_shared(op_res)); } +template <> +Tensor divide(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::divide(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor add(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::add(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor multiply(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::multiply(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor elementwise_pow(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::elementwise_pow(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor scale(const Tensor& x, + const Scalar& scale, + float bias, + bool bias_after_scale) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = + paddle::dialect::scale(x_res, scale.to(), bias, bias_after_scale); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor sum(const Tensor& x, + const IntArray& axis, + phi::DataType dtype, + bool keepdim) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = + paddle::dialect::sum(x_res, axis.GetData(), dtype, keepdim); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor full(const IntArray& shape, + const Scalar& value, + phi::DataType dtype, + phi::Place place) { + ir::OpResult op_res = + paddle::dialect::full(shape.GetData(), value.to(), dtype, place); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor reshape(const Tensor& x, const IntArray& shape) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::reshape(x_res, shape.GetData()); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor expand(const Tensor& x, const IntArray& shape) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::expand(x_res, shape.GetData()); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor tile(const Tensor& x, const IntArray& repeat_times) 
{ + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::tile(x_res, repeat_times.GetData()); + return Tensor(std::make_shared(op_res)); +} + } // namespace experimental } // namespace backend } // namespace primitive diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index bd1fb737b8658..131674ea274bf 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -35,6 +35,46 @@ Tensor mean_grad(const Tensor& x, std::vector axis = {}, bool keepdim = false, bool reduce_all = false); + +template +Tensor divide(const Tensor& x, const Tensor& y); + +template +Tensor add(const Tensor& x, const Tensor& y); + +template +Tensor multiply(const Tensor& x, const Tensor& y); + +template +Tensor elementwise_pow(const Tensor& x, const Tensor& y); + +template +Tensor scale(const Tensor& x, + const Scalar& scale = 1.0, + float bias = 0.0, + bool bias_after_scale = true); + +template +Tensor sum(const Tensor& x, + const IntArray& axis = {}, + phi::DataType dtype = phi::DataType::UNDEFINED, + bool keepdim = false); + +template +Tensor full(const IntArray& shape, + const Scalar& value, + phi::DataType dtype = phi::DataType::FLOAT32, + phi::Place place = phi::CPUPlace()); + +template +Tensor reshape(const Tensor& x, const IntArray& shape); + +template +Tensor expand(const Tensor& x, const IntArray& shape); + +template +Tensor tile(const Tensor& x, const IntArray& repeat_times = {}); + } // namespace experimental } // namespace backend } // namespace primitive From 0cc4336a5437e9d53f17762c3d98605f091b6626 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Tue, 15 Aug 2023 03:36:38 +0000 Subject: [PATCH 06/30] Add more gen api --- .../fluid/ir/dialect/op_generator/api_gen.py | 97 ++++++++++++++----- .../fluid/ir/dialect/op_generator/op_gen.py | 2 +- 2 files changed, 76 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_generator/api_gen.py b/paddle/fluid/ir/dialect/op_generator/api_gen.py index 4dbe7d33540c9..7680ddfb1228c 100644 --- a/paddle/fluid/ir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/api_gen.py @@ -14,6 +14,7 @@ import argparse import os +import re import yaml from op_gen import OpCompatParser, OpInfoParser, to_pascal_case @@ -27,6 +28,7 @@ #include "paddle/ir/core/value.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/common/scalar.h" {body} @@ -62,18 +64,38 @@ {in_combine} {compute_op} {out_slice} - {out_combine} {return_result} }} """ -COMBINE_OP_TEMPLATE = """auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" - -COMPUTE_OP_TEMPLATE = """paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build({args});""" - - -API_LIST = ['add_n', 'mean', 'sum', 'divide', 'full', 'tanh_grad', 'mean_grad'] +COMBINE_OP_TEMPLATE = """ + auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" + +SLICE_OP_TEMPLATE = """ + auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" + +COMPUTE_OP_TEMPLATE = """ + paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build({args});""" + +API_LIST = [ + 'add_n', + 'mean', + 'sum', + 'divide', + 'full', + 'tanh_grad', + 'mean_grad', + 'concat', + 'add', + 'multiply', + 'elementwise_pow', + 'scale', + 'reshape', + 'expand', + 'tile', + 'add_grad', +] OP_RESULT = 
'ir::OpResult' VECTOR_TYPE = 'ir::VectorType' @@ -86,6 +108,7 @@ class CodeGen: def __init__(self) -> None: self._type_map = { 'paddle::dialect::DenseTensorType': 'ir::OpResult', + 'paddle::dialect::SelectedRowsType': 'ir::OpResult', 'ir::VectorType': 'std::vector', } @@ -126,6 +149,8 @@ def _gen_api_attrs(self, op_info, with_default): name_list, type_list, default_value_list ): if with_default and default_value is not None: + if type in ['float', 'double']: + default_value = default_value.strip('"') ret.append( '{type} {name} = {default_value}'.format( type=type, name=name, default_value=default_value @@ -140,9 +165,19 @@ def _gen_api_args(self, op_info, with_default_attr): attrs = self._gen_api_attrs(op_info, with_default_attr) return (inputs + ', ' + attrs).strip(', ') + def _gen_ret_type(self, op_info): + type_list = op_info.output_type_list + assert len(type_list) >= 1 + if len(type_list) > 1: + return 'std::tuple<{}>'.format( + ', '.join([self._type_map[type] for type in type_list]) + ) + elif len(type_list) == 1: + return self._type_map[type_list[0]] + def _gen_one_declare(self, op_info, op_name): return API_DECLARE_TEMPLATE.format( - ret_type=OP_RESULT, + ret_type=self._gen_ret_type(op_info), api_name=op_name, args=self._gen_api_args(op_info, True), ) @@ -205,33 +240,51 @@ def _gen_compute_op(self, op_info, op_name, in_combine_op_list): op_inst_name, ) - def _gen_out_slice(self): - return '' + def _gen_out_slice_and_ret_list(self, op_info, op_inst_name): + name_list = op_info.output_name_list + type_list = op_info.output_type_list - def _gen_out_combine(self): - return '' + slice_op_str = '' + ret_list = [] + for i, (name, type) in enumerate(zip(name_list, type_list)): + if VECTOR_TYPE in type: + slice_op_name = f'{name}_slice_op' + slice_op_str += SLICE_OP_TEMPLATE.format( + op_name=slice_op_name, in_name=f'{op_inst_name}.result({i})' + ) + ret_list.append(f'{slice_op_name}.outputs()') + else: + ret_list.append(f'{op_inst_name}.result({i})') + return slice_op_str, ret_list - def _gen_return_result(self, op_info, op_inst_name): - output_name_list = op_info.output_name_list - assert len(output_name_list) == 1 - return f'return {op_inst_name}.result(0);' + def _gen_return_result(self, ret_list): + assert len(ret_list) >= 1 + if len(ret_list) > 1: + return 'return std::make_tuple({});'.format(', '.join(ret_list)) + else: + return f'return {ret_list[0]};' def _gen_one_impl(self, op_info, op_name): in_combine, in_combine_op_list = self._gen_in_combine(op_info) compute_op, op_inst_name = self._gen_compute_op( op_info, op_name, in_combine_op_list ) + out_slice, ret_list = self._gen_out_slice_and_ret_list( + op_info, op_inst_name + ) - return API_IMPL_TEMPLATE.format( - ret_type=OP_RESULT, + ret = API_IMPL_TEMPLATE.format( + ret_type=self._gen_ret_type(op_info), api_name=op_name, args=self._gen_api_args(op_info, False), in_combine=in_combine, compute_op=compute_op, - out_slice=self._gen_out_slice(), - out_combine=self._gen_out_combine(), - return_result=self._gen_return_result(op_info, op_inst_name), - ).replace(' \n', '') + out_slice=out_slice, + return_result=self._gen_return_result(ret_list), + ) + + ret = re.sub(r' +\n', '', ret) + return ret def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path): impl_str = '' diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index a204d64b00f48..5bbb5c80c0693 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -259,7 
+259,7 @@ def __init__(self, op_yaml_item, op_compat_item): 'bool': ['ir::BoolAttribute', 'bool'], 'bool[]': [ 'ir::ArrayAttribute', - 'const std::vecot&', + 'const std::vector&', ], 'str': ['ir::StrAttribute', 'const std::string&'], 'str[]': [ From 0afe2edb26d71b5f40718cc94750ddb6cf26e973 Mon Sep 17 00:00:00 2001 From: wangruting Date: Tue, 15 Aug 2023 08:22:06 +0000 Subject: [PATCH 07/30] concat python api to concat_grad --- .../op_generator/vjp_interface_gen_op_list.py | 2 +- .../dialect/{pd_api.cc => pd_manual_api.cc} | 24 +++++++++- .../ir/dialect/{pd_api.h => pd_manual_api.h} | 6 +++ paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 33 +++++++++++++ .../fluid/primitive/backend/static_backend.cc | 33 ++++++++++++- .../fluid/primitive/backend/static_backend.h | 5 ++ paddle/fluid/primitive/rule/vjp/vjp.cc | 47 ++++++++++++++++++- paddle/fluid/primitive/rule/vjp/vjp.h | 6 +++ paddle/fluid/pybind/ops_api.cc | 8 ++++ paddle/fluid/pybind/static_op_function.cc | 22 ++++++++- paddle/fluid/pybind/static_op_function.h | 1 + paddle/ir/core/builtin_op.cc | 12 +++++ paddle/ir/core/builtin_op.h | 21 ++++++++- python/paddle/tensor/manipulation.py | 4 ++ test/ir/new_ir/test_build_op.py | 20 ++++++++ 15 files changed, 238 insertions(+), 6 deletions(-) rename paddle/fluid/ir/dialect/{pd_api.cc => pd_manual_api.cc} (76%) rename paddle/fluid/ir/dialect/{pd_api.h => pd_manual_api.h} (87%) diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py index 3201651e4696c..674d7e790b090 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -21,4 +21,4 @@ # TODO(wanghao107) # remove this file and support Vjp methods # code gen. -vjp_interface_gen_op_list = ["tanh", "mean"] +vjp_interface_gen_op_list = ["tanh", "mean", "concat"] diff --git a/paddle/fluid/ir/dialect/pd_api.cc b/paddle/fluid/ir/dialect/pd_manual_api.cc similarity index 76% rename from paddle/fluid/ir/dialect/pd_api.cc rename to paddle/fluid/ir/dialect/pd_manual_api.cc index 6405f7dce7e80..c4832a3465844 100644 --- a/paddle/fluid/ir/dialect/pd_api.cc +++ b/paddle/fluid/ir/dialect/pd_manual_api.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/ir/dialect/pd_api.h" +#include "paddle/fluid/ir/dialect/pd_manual_api.h" #include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_op.h" #include "paddle/ir/core/builder.h" @@ -55,6 +55,15 @@ ir::OpResult divide(ir::OpResult x, ir::OpResult y) { return divide_op.out(); } +ir::OpResult concat(std::vector x, float axis) { + auto combine_op = + APIBuilder::Instance().GetBuilder()->Build(x); + auto concat_op = + APIBuilder::Instance().GetBuilder()->Build( + combine_op.out(), axis); + return concat_op.out(); +} + ir::OpResult full(const std::vector& shape, float value, phi::DataType dtype, @@ -83,5 +92,18 @@ ir::OpResult mean_grad(ir::OpResult x, return mean_grad_op.result(0); } +std::vector concat_grad(std::vector x, + ir::OpResult out_grad, + ir::OpResult axis) { + auto combine_op = + APIBuilder::Instance().GetBuilder()->Build(x); + + paddle::dialect::ConcatGradOp concat_grad_op = + APIBuilder::Instance().GetBuilder()->Build( + combine_op.out(), out_grad, axis); + auto slice_op = APIBuilder::Instance().GetBuilder()->Build( + concat_grad_op.result(0)); + return slice_op.outputs(); +} } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_api.h b/paddle/fluid/ir/dialect/pd_manual_api.h similarity index 87% rename from paddle/fluid/ir/dialect/pd_api.h rename to paddle/fluid/ir/dialect/pd_manual_api.h index 9581e0a4e7ee1..7155fb535ef9e 100644 --- a/paddle/fluid/ir/dialect/pd_api.h +++ b/paddle/fluid/ir/dialect/pd_manual_api.h @@ -40,6 +40,8 @@ ir::OpResult full(const std::vector& shape, phi::DataType dtype = phi::DataType::FLOAT32, const phi::Place& place = phi::CPUPlace()); +ir::OpResult concat(std::vector x, float axis); + ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out); ir::OpResult mean_grad(ir::OpResult x, @@ -47,5 +49,9 @@ ir::OpResult mean_grad(ir::OpResult x, const std::vector& axis = {}, bool keepdim = false, bool reduce_all = false); + +std::vector concat_grad(std::vector x, + ir::OpResult out_grad, + ir::OpResult axis); } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index be43ddd60491c..6dbc3e478976c 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/ir/dialect/pd_op.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" #include "paddle/fluid/primitive/type/desc_tensor.h" +#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/op_base.h" #include "paddle/phi/common/int_array.h" @@ -98,5 +99,37 @@ std::vector> MeanOp::Vjp( } return res; } + +std::vector> ConcatOp::Vjp( + ir::Operation* op, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + ConcatOp op_obj = op->dyn_cast(); + ir::CombineOp combine_op_obj = + op_obj.x().GetDefiningOp()->dyn_cast(); + std::vector x; + for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) { + x.emplace_back(std::make_shared( + combine_op_obj.inputs()[idx])); + } + + Tensor out_grad( + std::make_shared(out_grads[0][0])); + + Tensor axis( + std::make_shared(op_obj.axis())); + + std::vector> tensor_res = + primitive::experimental::concat_vjp(x, out_grad, axis, stop_gradients); + std::vector> res(1, std::vector(1)); + if (tensor_res[0][0].defined()) { + res[0][0] = std::static_pointer_cast( + tensor_res[0][0].impl()) + ->getValue() + .dyn_cast(); + } + return res; +} + } // namespace dialect } // namespace paddle diff --git 
a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index b041d3710c25d..c1da28e8145d1 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/primitive/backend/static_backend.h" -#include "paddle/fluid/ir/dialect/pd_api.h" +#include "paddle/fluid/ir/dialect/pd_manual_api.h" #include "paddle/fluid/primitive/primitive/primitive.h" #include "paddle/fluid/primitive/type/desc_tensor.h" @@ -59,6 +59,37 @@ Tensor mean_grad(const Tensor& x, return Tensor(std::make_shared(op_res)); } +template <> +std::vector concat_grad(const std::vector& x, + const Tensor& out_grad, + const Tensor& axis) { + std::vector x_res; + for (uint64_t idx = 0; idx < x.size(); idx++) { + x_res.emplace_back(std::static_pointer_cast(x[idx].impl()) + ->getValue() + .dyn_cast()); + } + + ir::OpResult out_grad_res = + std::static_pointer_cast(out_grad.impl()) + ->getValue() + .dyn_cast(); + + ir::OpResult axis_res = std::static_pointer_cast(axis.impl()) + ->getValue() + .dyn_cast(); + + std::vector op_res = + paddle::dialect::concat_grad(x_res, out_grad_res, axis_res); + + std::vector op_result; + for (uint64_t idx = 0; idx < op_res.size(); idx++) { + op_result.emplace_back( + std::make_shared(op_res[idx])); + } + return op_result; +} + } // namespace experimental } // namespace backend } // namespace primitive diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index 09835bb759674..1b65cd62553db 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -37,6 +37,11 @@ Tensor mean_grad(const Tensor& x, const IntArray& axis = {}, bool keepdim = false, bool reduce_all = false); + +template +std::vector concat_grad(const std::vector& x, + const Tensor& out_grad, + const Tensor& axis); } // namespace experimental } // namespace backend } // namespace primitive diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index b5f0acf98c1d8..c1aada25fa80e 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/primitive/rule/vjp/vjp.h" -#include "paddle/fluid/ir/dialect/pd_api.h" +#include "paddle/fluid/ir/dialect/pd_manual_api.h" #include "paddle/fluid/primitive/backend/static_backend.h" #include "paddle/fluid/primitive/type/desc_tensor.h" #include "paddle/ir/core/operation.h" @@ -111,6 +111,51 @@ std::vector> mean_vjp( return vjp_res; } +std::vector> concat_vjp( + const std::vector& x, + const Tensor& out_grad, + const Tensor& axis, + const std::vector>& stop_gradients) { + std::vector> vjp_res( + 1, std::vector(1)); + // get concat_grad res. + std::vector op_res = + backend::experimental::concat_grad( + x, out_grad, axis); + + // set op stop_gradient info + // TODO(wanghao107): Replace with more generic code. + // Support set stop_gradients for all ops. 
+ ir::Operation* grad_op = + std::static_pointer_cast( + op_res[0].impl()) + ->getValue() + .dyn_cast() + .owner(); + uint32_t num_res = grad_op->num_results(); + std::vector ir_stop_gradients(num_res); + for (size_t i = 0; i < num_res; i++) { + if (stop_gradients[0][i]) { + ir_stop_gradients[i] = + ir::BoolAttribute::get(ir::IrContext::Instance(), true); + } else { + ir_stop_gradients[i] = + ir::BoolAttribute::get(ir::IrContext::Instance(), false); + } + } + grad_op->set_attribute( + "stop_gradient", + ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients)); + + // construct vjp result by op result and stop_gradients info + for (auto idx = 0; idx <= op_res[0].size(); idx++) { + if (!stop_gradients[0][idx]) { + vjp_res[0][idx] = op_res[idx]; + } + } + return vjp_res; +} + } // namespace experimental } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/vjp.h b/paddle/fluid/primitive/rule/vjp/vjp.h index 48bc2affa9db4..ab1fcc2f26d2e 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.h +++ b/paddle/fluid/primitive/rule/vjp/vjp.h @@ -46,6 +46,12 @@ std::vector> mean_vjp( bool reduce_all, const std::vector>& stop_gradients); +std::vector> concat_vjp( + const std::vector& x, + const Tensor& out_grad, + const Tensor& axis, + const std::vector>& stop_gradients); + namespace details { // NOTE: this namespace will store // primitive ops grad composite rules. diff --git a/paddle/fluid/pybind/ops_api.cc b/paddle/fluid/pybind/ops_api.cc index 56998d621c736..27cbf36aecbcb 100644 --- a/paddle/fluid/pybind/ops_api.cc +++ b/paddle/fluid/pybind/ops_api.cc @@ -40,6 +40,10 @@ static PyObject *divide(PyObject *self, PyObject *args, PyObject *kwargs) { return static_api_divide(self, args, kwargs); } +static PyObject *concat(PyObject *self, PyObject *args, PyObject *kwargs) { + return static_api_concat(self, args, kwargs); +} + static PyMethodDef OpsAPI[] = {{"add_n", (PyCFunction)(void (*)(void))add_n, METH_VARARGS | METH_KEYWORDS, @@ -56,6 +60,10 @@ static PyMethodDef OpsAPI[] = {{"add_n", (PyCFunction)(void (*)(void))divide, METH_VARARGS | METH_KEYWORDS, "C++ interface function for divide."}, + {"concat", + (PyCFunction)(void (*)(void))concat, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for concat."}, {"full", (PyCFunction)(void (*)(void))full, METH_VARARGS | METH_KEYWORDS, diff --git a/paddle/fluid/pybind/static_op_function.cc b/paddle/fluid/pybind/static_op_function.cc index ad992fab4972f..153a3131f9263 100644 --- a/paddle/fluid/pybind/static_op_function.cc +++ b/paddle/fluid/pybind/static_op_function.cc @@ -13,7 +13,7 @@ // limitations under the License. 
#include "paddle/fluid/pybind/static_op_function.h" -#include "paddle/fluid/ir/dialect/pd_api.h" +#include "paddle/fluid/ir/dialect/pd_manual_api.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/op_function_common.h" @@ -109,6 +109,26 @@ PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs) { } } +PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs) { + try { + VLOG(6) << "Add concat op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + // Get OpResult from args + PyObject *x_obj = PyTuple_GET_ITEM(args, 0); + auto x = CastPyArg2VectorOfOpResult("concat", x_obj, 0); + + PyObject *axis_obj = PyTuple_GET_ITEM(args, 1); + paddle::experimental::Scalar axis = CastPyArg2Scalar(axis_obj, "concat", 1); + + // Call ir static api + auto out = paddle::dialect::concat(x, axis.to()); + return ToPyObject(out); + } catch (...) { + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) { try { VLOG(6) << "Add full op into program"; diff --git a/paddle/fluid/pybind/static_op_function.h b/paddle/fluid/pybind/static_op_function.h index 22bee5c344837..02d4777eeef05 100644 --- a/paddle/fluid/pybind/static_op_function.h +++ b/paddle/fluid/pybind/static_op_function.h @@ -28,6 +28,7 @@ PyObject *static_api_add_n(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_mean(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_sum(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs); +PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs); } // namespace pybind diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index 8aff2f1f1909c..bb2acd9606399 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -207,6 +207,18 @@ void SliceOp::Verify() const { output_type); } +void SliceOp::Build(Builder &builder, + OperationArgument &argument, + const ir::OpResult &input) { + argument.inputs = {input}; + std::vector outputs_types; + for (size_t idx = 0; idx < input.type().dyn_cast().size(); + ++idx) { + argument.output_types.emplace_back( + input.type().dyn_cast()[idx]); + } +} + const char *ConstantOp::attributes_name[attributes_num] = {"value"}; // NOLINT void ConstantOp::Build(Builder &builder, diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index fe5b7116a29dd..b840fd9cf0e98 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -93,6 +93,13 @@ class IR_API CombineOp : public ir::Op { const std::vector &inputs); void Verify() const; + std::vector inputs() { + std::vector inputs; + for (uint32_t idx = 0; idx < num_operands(); idx++) { + inputs.push_back(operand_source(static_cast(idx))); + } + return inputs; + } ir::OpResult out() { return result(0); } }; @@ -108,8 +115,20 @@ class IR_API SliceOp : public ir::Op { static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const ir::OpResult &input); + void Verify() const; - ir::OpResult out() { return result(0); } + ir::Value input() { return operand_source(0); } + std::vector outputs() { + std::vector 
outputs; + for (uint32_t idx = 0; idx < num_results(); idx++) { + outputs.push_back(result(static_cast(idx))); + } + return outputs; + } }; class IR_API ConstantLikeTrait : public OpTraitBase { diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 8053c86cba9de..98da3330238a1 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1120,6 +1120,10 @@ def concat(x, axis=0, name=None): input = [t for t in input if t.shape.count(0) == 0] return _C_ops.concat(input, axis) else: + if paddle.ir.core._use_new_ir_api(): + if not isinstance(input, Variable): + input = [t for t in input if t.shape.count(0) == 0] + return paddle._ir_ops.concat(input, axis) check_type(input, 'input', (list, tuple, Variable), 'concat') if not isinstance(input, Variable): for id, x in enumerate(input): diff --git a/test/ir/new_ir/test_build_op.py b/test/ir/new_ir/test_build_op.py index c49b0ae14939c..d2b5fff3fd0e8 100644 --- a/test/ir/new_ir/test_build_op.py +++ b/test/ir/new_ir/test_build_op.py @@ -102,5 +102,25 @@ def test_insertion_point(self): paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) +class TestBuildOp4(unittest.TestCase): + def test_build_concat_op(self): + newir_program = get_ir_program() + tanh_out = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + out = paddle.concat([tanh_out, tanh_out], 0) + print(newir_program) + self.assertEqual(out.get_defining_op().name(), "pd.concat") + self.assertEqual( + out.get_defining_op() + .operands()[0] + .source() + .get_defining_op() + .name(), + "builtin.combine", + ) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + + if __name__ == "__main__": unittest.main() From afd61eaadd5af09cb581072a8ba6d606833b2845 Mon Sep 17 00:00:00 2001 From: wangruting Date: Tue, 15 Aug 2023 09:15:58 +0000 Subject: [PATCH 08/30] fix gen conflict --- paddle/fluid/ir/dialect/pd_manual_api.cc | 9 --------- paddle/fluid/ir/dialect/pd_manual_api.h | 1 - test/ir/new_ir/test_build_op.py | 1 - 3 files changed, 11 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_manual_api.cc b/paddle/fluid/ir/dialect/pd_manual_api.cc index b4accf70427d4..df80e1639940e 100644 --- a/paddle/fluid/ir/dialect/pd_manual_api.cc +++ b/paddle/fluid/ir/dialect/pd_manual_api.cc @@ -20,15 +20,6 @@ namespace paddle { namespace dialect { -ir::OpResult concat(std::vector x, float axis) { - auto combine_op = - APIBuilder::Instance().GetBuilder()->Build(x); - auto concat_op = - APIBuilder::Instance().GetBuilder()->Build( - combine_op.out(), axis); - return concat_op.out(); -} - std::vector concat_grad(std::vector x, ir::OpResult out_grad, ir::OpResult axis) { diff --git a/paddle/fluid/ir/dialect/pd_manual_api.h b/paddle/fluid/ir/dialect/pd_manual_api.h index d6d1c1ba2c7b3..dff38ef565cb2 100644 --- a/paddle/fluid/ir/dialect/pd_manual_api.h +++ b/paddle/fluid/ir/dialect/pd_manual_api.h @@ -22,7 +22,6 @@ namespace paddle { namespace dialect { -ir::OpResult concat(std::vector x, float axis); std::vector concat_grad(std::vector x, ir::OpResult out_grad, diff --git a/test/ir/new_ir/test_build_op.py b/test/ir/new_ir/test_build_op.py index d2b5fff3fd0e8..e54e493b99a77 100644 --- a/test/ir/new_ir/test_build_op.py +++ b/test/ir/new_ir/test_build_op.py @@ -109,7 +109,6 @@ def test_build_concat_op(self): paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) with paddle.ir.core.program_guard(newir_program): out = 
paddle.concat([tanh_out, tanh_out], 0) - print(newir_program) self.assertEqual(out.get_defining_op().name(), "pd.concat") self.assertEqual( out.get_defining_op() From 2f3a72a17bf27d278a5b62a3c387521b9469cf52 Mon Sep 17 00:00:00 2001 From: Charles-hit Date: Wed, 16 Aug 2023 05:40:29 +0000 Subject: [PATCH 09/30] support vjp prim mode in new ir --- paddle/fluid/framework/type_info.cc | 3 +- .../fluid/ir/dialect/op_generator/api_gen.py | 2 + .../op_generator/vjp_interface_gen_op_list.py | 2 +- paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 98 +++++++++---- .../fluid/primitive/backend/eager_backend.cc | 4 +- .../fluid/primitive/backend/eager_backend.h | 4 +- .../fluid/primitive/backend/static_backend.cc | 40 ++--- .../fluid/primitive/backend/static_backend.h | 4 +- paddle/fluid/primitive/primitive/primitive.h | 63 +++++++- .../fluid/primitive/rule/vjp/CMakeLists.txt | 16 +- paddle/fluid/primitive/rule/vjp/details.h | 137 ++++++++++++++++++ .../fluid/primitive/rule/vjp/eager_utils.cc | 26 ++++ .../fluid/primitive/rule/vjp/static_utils.cc | 25 ++++ paddle/fluid/primitive/rule/vjp/utils.h | 103 +++++++++++++ paddle/fluid/primitive/rule/vjp/vjp.cc | 77 +++++++--- paddle/fluid/primitive/rule/vjp/vjp.h | 27 ++-- paddle/fluid/primitive/type/desc_tensor.h | 8 +- test/prim/new_ir_prim/test_vjp_prim.py | 61 ++++++++ 18 files changed, 602 insertions(+), 98 deletions(-) create mode 100644 paddle/fluid/primitive/rule/vjp/details.h create mode 100644 paddle/fluid/primitive/rule/vjp/eager_utils.cc create mode 100644 paddle/fluid/primitive/rule/vjp/static_utils.cc create mode 100644 paddle/fluid/primitive/rule/vjp/utils.h create mode 100644 test/prim/new_ir_prim/test_vjp_prim.py diff --git a/paddle/fluid/framework/type_info.cc b/paddle/fluid/framework/type_info.cc index 35fc167e49746..4545c2455f460 100644 --- a/paddle/fluid/framework/type_info.cc +++ b/paddle/fluid/framework/type_info.cc @@ -41,8 +41,7 @@ template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; -template class TypeInfoTraits; +template class TypeInfoTraits; template class TypeInfoTraits; diff --git a/paddle/fluid/ir/dialect/op_generator/api_gen.py b/paddle/fluid/ir/dialect/op_generator/api_gen.py index 7680ddfb1228c..6e31326705966 100644 --- a/paddle/fluid/ir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/api_gen.py @@ -95,6 +95,8 @@ 'expand', 'tile', 'add_grad', + 'divide_grad', + 'sum_grad', ] OP_RESULT = 'ir::OpResult' VECTOR_TYPE = 'ir::VectorType' diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py index 3201651e4696c..985a8b4011e68 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -21,4 +21,4 @@ # TODO(wanghao107) # remove this file and support Vjp methods # code gen. 
-vjp_interface_gen_op_list = ["tanh", "mean"] +vjp_interface_gen_op_list = ["tanh", "mean", "divide", "sum"] diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index be43ddd60491c..c173c59eaed5f 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -31,18 +31,16 @@ std::vector> TanhOp::Vjp( const std::vector>& out_grads, const std::vector>& stop_gradients) { TanhOp op_obj = op->dyn_cast(); - Tensor out( - std::make_shared(op_obj.out())); - Tensor grad_out( - std::make_shared(out_grads[0][0])); + Tensor out(std::make_shared(op_obj.out())); + Tensor grad_out(std::make_shared(out_grads[0][0])); std::vector> tensor_res = - primitive::experimental::tanh_vjp(out, grad_out, stop_gradients); + primitive::tanh_vjp(out, grad_out, stop_gradients); std::vector> res(1, std::vector(1)); if (tensor_res[0][0].defined()) { - res[0][0] = std::static_pointer_cast( - tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); + res[0][0] = + std::static_pointer_cast(tensor_res[0][0].impl()) + ->getValue() + .dyn_cast(); } return res; } @@ -56,18 +54,16 @@ std::vector> Tanh_Op::Vjp( // so use the non-inplace version instead currently. // Support inplace in the future. Tanh_Op op_obj = op->dyn_cast(); - Tensor out( - std::make_shared(op_obj.out())); - Tensor grad_out( - std::make_shared(out_grads[0][0])); + Tensor out(std::make_shared(op_obj.out())); + Tensor grad_out(std::make_shared(out_grads[0][0])); std::vector> tensor_res = - primitive::experimental::tanh_vjp(out, grad_out, stop_gradients); + primitive::tanh_vjp(out, grad_out, stop_gradients); std::vector> res(1, std::vector(1)); if (tensor_res[0][0].defined()) { - res[0][0] = std::static_pointer_cast( - tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); + res[0][0] = + std::static_pointer_cast(tensor_res[0][0].impl()) + ->getValue() + .dyn_cast(); } return res; } @@ -77,24 +73,72 @@ std::vector> MeanOp::Vjp( const std::vector>& out_grads, const std::vector>& stop_gradients) { MeanOp op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor out_grad( - std::make_shared(out_grads[0][0])); + Tensor x(std::make_shared(op_obj.x())); + Tensor out_grad(std::make_shared(out_grads[0][0])); IntArray axis = op->attribute("axis") .dyn_cast() .data(); bool keepdim = op->attribute("keepdim").dyn_cast().data(); bool reduce_all = false; + std::vector> tensor_res = primitive::mean_vjp( + x, out_grad, axis, keepdim, reduce_all, stop_gradients); + std::vector> res(1, std::vector(1)); + if (tensor_res[0][0].defined()) { + res[0][0] = + std::static_pointer_cast(tensor_res[0][0].impl()) + ->getValue() + .dyn_cast(); + } + return res; +} + +std::vector> DivideOp::Vjp( + ir::Operation* op, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + DivideOp op_obj = op->dyn_cast(); + Tensor x(std::make_shared(op_obj.x())); + Tensor y(std::make_shared(op_obj.y())); + Tensor out(std::make_shared(op_obj.out())); + Tensor out_grad(std::make_shared(out_grads[0][0])); + + int axis = -1; std::vector> tensor_res = - primitive::experimental::mean_vjp( - x, out_grad, axis, keepdim, reduce_all, stop_gradients); + primitive::divide_vjp(x, y, out, out_grad, axis, stop_gradients); + std::vector> res(2, std::vector(1)); + for (size_t i = 0; i < 2; ++i) { + if (tensor_res[i][0].defined()) { + res[i][0] = std::static_pointer_cast( + tensor_res[i][0].impl()) + ->getValue() + .dyn_cast(); + } + } + return res; +} + +std::vector> SumOp::Vjp( + ir::Operation* 
op, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + SumOp op_obj = op->dyn_cast(); + Tensor x(std::make_shared(op_obj.x())); + Tensor out_grad(std::make_shared(out_grads[0][0])); + + IntArray axis = op->attribute("axis") + .dyn_cast() + .data(); + bool keepdim = op->attribute("keepdim").dyn_cast().data(); + bool reduce_all = false; + std::vector> tensor_res = primitive::sum_vjp( + x, out_grad, axis, keepdim, reduce_all, stop_gradients); std::vector> res(1, std::vector(1)); if (tensor_res[0][0].defined()) { - res[0][0] = std::static_pointer_cast( - tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); + res[0][0] = + std::static_pointer_cast(tensor_res[0][0].impl()) + ->getValue() + .dyn_cast(); } return res; } diff --git a/paddle/fluid/primitive/backend/eager_backend.cc b/paddle/fluid/primitive/backend/eager_backend.cc index 5c06c0143f65e..ca2184c49a6f9 100644 --- a/paddle/fluid/primitive/backend/eager_backend.cc +++ b/paddle/fluid/primitive/backend/eager_backend.cc @@ -19,8 +19,6 @@ namespace paddle { namespace primitive { -namespace backend { -namespace experimental {} // namespace experimental -} // namespace backend +namespace backend {} // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/eager_backend.h b/paddle/fluid/primitive/backend/eager_backend.h index 1522bd1dfc31e..094487bb2b188 100644 --- a/paddle/fluid/primitive/backend/eager_backend.h +++ b/paddle/fluid/primitive/backend/eager_backend.h @@ -21,8 +21,6 @@ namespace paddle { namespace primitive { -namespace backend { -namespace experimental {} // namespace experimental -} // namespace backend +namespace backend {} // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index 040b68336d2d1..ab651a802e2f7 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -20,9 +20,8 @@ namespace paddle { namespace primitive { namespace backend { -namespace experimental { -using DescTensor = paddle::primitive::experimental::DescTensor; +using DescTensor = paddle::primitive::DescTensor; template <> Tensor tanh_grad(const Tensor& out, const Tensor& grad_out) { @@ -36,7 +35,7 @@ Tensor tanh_grad(const Tensor& out, const Tensor& grad_out) { ir::OpResult op_res = paddle::dialect::tanh_grad(out_res, grad_out_res); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> @@ -54,9 +53,9 @@ Tensor mean_grad(const Tensor& x, .dyn_cast(); ir::OpResult op_res = paddle::dialect::mean_grad( - x_res, out_grad_res, axis, keepdim, reduce_all); + x_res, out_grad_res, axis.GetData(), keepdim, reduce_all); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> @@ -68,7 +67,7 @@ Tensor divide(const Tensor& x, const Tensor& y) { ->getValue() .dyn_cast(); ir::OpResult op_res = paddle::dialect::divide(x_res, y_res); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> @@ -80,7 +79,7 @@ Tensor add(const Tensor& x, const Tensor& y) { ->getValue() .dyn_cast(); ir::OpResult op_res = paddle::dialect::add(x_res, y_res); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> @@ -92,11 +91,12 @@ Tensor multiply(const Tensor& x, const Tensor& y) { ->getValue() .dyn_cast(); ir::OpResult op_res = 
paddle::dialect::multiply(x_res, y_res); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> Tensor elementwise_pow(const Tensor& x, const Tensor& y) { + VLOG(3) << "elementwise_pow static api"; ir::OpResult x_res = std::static_pointer_cast(x.impl()) ->getValue() .dyn_cast(); @@ -104,7 +104,7 @@ Tensor elementwise_pow(const Tensor& x, const Tensor& y) { ->getValue() .dyn_cast(); ir::OpResult op_res = paddle::dialect::elementwise_pow(x_res, y_res); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> @@ -117,7 +117,7 @@ Tensor scale(const Tensor& x, .dyn_cast(); ir::OpResult op_res = paddle::dialect::scale(x_res, scale.to(), bias, bias_after_scale); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> @@ -130,7 +130,7 @@ Tensor sum(const Tensor& x, .dyn_cast(); ir::OpResult op_res = paddle::dialect::sum(x_res, axis.GetData(), dtype, keepdim); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> @@ -138,18 +138,23 @@ Tensor full(const IntArray& shape, const Scalar& value, phi::DataType dtype, phi::Place place) { + VLOG(3) << "full static api"; ir::OpResult op_res = paddle::dialect::full(shape.GetData(), value.to(), dtype, place); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> -Tensor reshape(const Tensor& x, const IntArray& shape) { +std::tuple reshape(const Tensor& x, + const IntArray& shape) { ir::OpResult x_res = std::static_pointer_cast(x.impl()) ->getValue() .dyn_cast(); - ir::OpResult op_res = paddle::dialect::reshape(x_res, shape.GetData()); - return Tensor(std::make_shared(op_res)); + std::tuple op_res = + paddle::dialect::reshape(x_res, shape.GetData()); + return std::make_tuple( + Tensor(std::make_shared(std::get<0>(op_res))), + Tensor(std::make_shared(std::get<1>(op_res)))); } template <> @@ -158,7 +163,7 @@ Tensor expand(const Tensor& x, const IntArray& shape) { ->getValue() .dyn_cast(); ir::OpResult op_res = paddle::dialect::expand(x_res, shape.GetData()); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> @@ -167,10 +172,9 @@ Tensor tile(const Tensor& x, const IntArray& repeat_times) { ->getValue() .dyn_cast(); ir::OpResult op_res = paddle::dialect::tile(x_res, repeat_times.GetData()); - return Tensor(std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } -} // namespace experimental } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index 443f29366b3a2..81eef7e529466 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -23,7 +23,6 @@ namespace paddle { namespace primitive { namespace backend { -namespace experimental { using Tensor = paddle::Tensor; using IntArray = paddle::experimental::IntArray; @@ -69,7 +68,7 @@ Tensor full(const IntArray& shape, phi::Place place = phi::CPUPlace()); template -Tensor reshape(const Tensor& x, const IntArray& shape); +std::tuple reshape(const Tensor& x, const IntArray& shape); template Tensor expand(const Tensor& x, const IntArray& shape); @@ -77,7 +76,6 @@ Tensor expand(const Tensor& x, const IntArray& shape); template Tensor tile(const Tensor& x, const IntArray& repeat_times = {}); -} // namespace experimental } // namespace backend } // namespace 
primitive } // namespace paddle diff --git a/paddle/fluid/primitive/primitive/primitive.h b/paddle/fluid/primitive/primitive/primitive.h index a15334851c87d..80510ed921c04 100644 --- a/paddle/fluid/primitive/primitive/primitive.h +++ b/paddle/fluid/primitive/primitive/primitive.h @@ -18,12 +18,71 @@ namespace paddle { namespace primitive { -namespace experimental { // why exist this file? // We provide this file to divide // the primitive ops set in the backend. // It will be called by the vjp composite // rules and composite ops rules. -} // namespace experimental +using Tensor = paddle::Tensor; +using IntArray = paddle::experimental::IntArray; + +template +Tensor divide(const Tensor& x, const Tensor& y) { + return backend::divide(x, y); +} + +template +Tensor add(const Tensor& x, const Tensor& y) { + return backend::add(x, y); +} + +template +Tensor multiply(const Tensor& x, const Tensor& y) { + return backend::multiply(x, y); +} + +template +Tensor elementwise_pow(const Tensor& x, const Tensor& y) { + return backend::elementwise_pow(x, y); +} + +template +Tensor scale(const Tensor& x, + const Scalar& scale = 1.0, + float bias = 0.0, + bool bias_after_scale = true) { + return backend::scale(x, scale, bias, bias_after_scale); +} + +template +Tensor sum(const Tensor& x, + const IntArray& axis = {}, + phi::DataType dtype = phi::DataType::UNDEFINED, + bool keepdim = false) { + return backend::sum(x, axis, dtype, keepdim); +} + +template +Tensor full(const IntArray& shape, + const Scalar& value, + phi::DataType dtype = phi::DataType::FLOAT32, + phi::Place place = phi::CPUPlace()) { + return backend::full(shape, value, dtype, place); +} + +template +std::tuple reshape(const Tensor& x, const IntArray& shape) { + return backend::reshape(x, shape); +} + +template +Tensor expand(const Tensor& x, const IntArray& shape) { + return backend::expand(x, shape); +} + +template +Tensor tile(const Tensor& x, const IntArray& repeat_times = {}) { + return backend::tile(x, repeat_times); +} } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt index eb72b0c9ecc65..2a8c401a77eb0 100644 --- a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt +++ b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt @@ -1,7 +1,17 @@ -file(GLOB VJP_SRCS "*.cc") - +file(GLOB VJP_SRCS "vjp.cc") +if(WITH_PYTHON OR NOT ON_INFER) + cc_library( + primitive_eager_utils_experimental + SRCS eager_utils.cc + DEPS phi common_infer_shape_functions) +endif() +cc_library( + primitive_static_utils_experimental + SRCS static_utils.cc + DEPS phi common_infer_shape_functions) cc_library( primitive_vjp_experimental SRCS ${VJP_SRCS} - DEPS primitive_backend_static_experimental) + DEPS primitive_backend_static_experimental static_global_utils + primitive_static_utils_experimental) add_dependencies(primitive_vjp_experimental pd_dialect) diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h new file mode 100644 index 0000000000000..8c1d12be0451f --- /dev/null +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -0,0 +1,137 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif + +#include +#include + +#include "paddle/fluid/primitive/primitive/primitive.h" +#include "paddle/fluid/primitive/rule/vjp/utils.h" +#include "paddle/fluid/primitive/type/desc_tensor.h" + +namespace paddle { +namespace primitive { +namespace details { + +template +void divide_grad(const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis, + Tensor* dx, + Tensor* dy) { + if (dy) { + // dy = -(x/y^2) * dout + auto denominator = + elementwise_pow(y, full(y.shape(), 2.0, y.dtype(), y.place())); + auto dy_res = scale( + multiply(divide(x, denominator), out_grad), -1.0, 0.0, true); + if (x.dims() != y.dims()) { + // Maybe need reduce here + phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + if (!reduce_dim.size()) { + set_output(dy_res, dy); + } else { + auto dy_reduce_res = + sum(dy_res, phi::vectorize(reduce_dim), y.dtype(), false); + auto reshape_res = reshape(dy_reduce_res, phi::vectorize(y.dims())); + auto dy_tmp = std::get<0>(reshape_res); + set_output(dy_tmp, dy); + } + } else { + set_output(dy_res, dy); + } + } // indicate we will compute dy + if (dx) { + // dx = (1/y) * dout + auto one_tensor = full(phi::vectorize(y.dims()), 1.0, y.dtype()); + auto dx_res = multiply(divide(one_tensor, y), out_grad); + if (y.dims() != x.dims()) { + // Maybe need reduce here + auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + if (!reduce_dim.size()) { + set_output(dx_res, dx); + } else { + auto dx_reduce_res = + sum(dx_res, phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_reduce_reshape_res = + reshape(dx_reduce_res, phi::vectorize(x.dims())); + auto dx_tmp = std::get<0>(dx_reduce_reshape_res); + set_output(dx_tmp, dx); + } + + } else { + set_output(dx_res, dx); + } + } // indicate we will compute dx +} + +template +void sum_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all, + Tensor* x_grad) { + if (!x_grad) { + return; + } + std::vector x_dim = phi::vectorize(x.dims()); + int64_t axis_size = axis.size(); + int64_t x_dim_size = x_dim.size(); + reduce_all = false; + if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { + reduce_all = true; + } else { + reduce_all = false; + } + auto x_grad_tmp = Tensor(); + if (x_dim_size == 1) { + x_grad_tmp = expand(out_grad, IntArray(x_dim)); + } else { + if (!keepdim) { + auto axis_ = std::vector(); + if (reduce_all) { + for (int64_t i = 0; i < x_dim_size; i++) { + axis_.push_back(i); + } + } else { + axis_ = axis.GetData(); + for (int64_t i = 0; i < axis_size; i++) { + if (axis[i] < 0) { + axis_[i] = axis[i] + x_dim_size; + } + } + } + auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); + auto out_grad_reshape_res = reshape(out_grad, out_grad_shape); + auto out_grad_ = std::get<0>(out_grad_reshape_res); + x_grad_tmp = expand(out_grad, IntArray(x_dim)); + } else { + x_grad_tmp = expand(out_grad, IntArray(x_dim)); + } + } + + set_output(x_grad_tmp, x_grad); +} + +} // namespace details +} // namespace primitive +} // namespace paddle diff 
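For reference, the composite divide_grad above is the usual quotient-rule VJP assembled from the primitive ops. Writing \bar{o} for out_grad, the element-wise gradients it builds are (notation is ours, not part of the patch):

    \frac{\partial L}{\partial x} = \frac{\bar{o}}{y}, \qquad
    \frac{\partial L}{\partial y} = -\frac{x \, \bar{o}}{y^{2}}

When x and y have different shapes, each result is additionally summed over the broadcast axes returned by get_reduce_dims and reshaped back to the shape of its operand, which is what the sum/reshape pair in each branch implements. sum_grad works in the opposite direction: it restores the reduced axes of out_grad through get_unsqueeze_dims plus reshape and then expands the result to the shape of x.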
--git a/paddle/fluid/primitive/rule/vjp/eager_utils.cc b/paddle/fluid/primitive/rule/vjp/eager_utils.cc new file mode 100644 index 0000000000000..46a1310a6a351 --- /dev/null +++ b/paddle/fluid/primitive/rule/vjp/eager_utils.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" +#include "paddle/fluid/primitive/rule/vjp/utils.h" + +namespace paddle { +namespace primitive { +template <> +void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) { + x->set_impl(x_tmp.impl()); + x->set_autograd_meta(x_tmp.mutable_autograd_meta()); +} + +} // namespace primitive +} // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/static_utils.cc b/paddle/fluid/primitive/rule/vjp/static_utils.cc new file mode 100644 index 0000000000000..0d8725d0d398b --- /dev/null +++ b/paddle/fluid/primitive/rule/vjp/static_utils.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/primitive/rule/vjp/utils.h" +#include "paddle/fluid/primitive/type/desc_tensor.h" + +namespace paddle { +namespace primitive { +template <> +void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) { + x->set_impl(x_tmp.impl()); +} + +} // namespace primitive +} // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/utils.h b/paddle/fluid/primitive/rule/vjp/utils.h new file mode 100644 index 0000000000000..b5d21667e54f8 --- /dev/null +++ b/paddle/fluid/primitive/rule/vjp/utils.h @@ -0,0 +1,103 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
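The two set_output specialisations above are the only place where a composite rule touches the execution mode: the eager build re-points the tensor impl and also copies the autograd meta, while the static build only swaps the impl so the gradient remains a graph value. A minimal sketch of how a rule in details.h is expected to hand its result back; the explicit template parameter and the helper name write_grad are assumptions, since the angle brackets in this patch text are stripped:

// Sketch only: inside a composite rule such as divide_grad<T>, dx_res is the
// Tensor produced by the primitive calls and dx is the caller-provided slot
// (nullptr when stop_gradient was requested for that input).
template <typename T>
void write_grad(const paddle::Tensor& dx_res, paddle::Tensor* dx) {
  if (dx) {
    // Dispatches to eager_utils.cc or static_utils.cc depending on T.
    paddle::primitive::set_output<T>(dx_res, dx);
  }
}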
+ +#pragma once +#include + +#include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/ddim.h" + +namespace paddle { +namespace primitive { + +template +void set_output(const Tensor& x_tmp, Tensor* x); + +// This fucction compute unsqueeze dims for reshape to replace unsqueeze. +static std::vector get_unsqueeze_dims( + const Tensor& origin, const std::vector& axis) { + auto origin_dims = origin.shape(); + auto total_shape_size = origin_dims.size() + axis.size(); + std::vector result; + size_t j = 0, k = 0; + for (size_t i = 0; i < total_shape_size; ++i) { + if (j < axis.size() && axis[j] == int64_t(i)) { + result.push_back(1); + j++; + } else { + PADDLE_ENFORCE_LT( + k, + origin_dims.size(), + platform::errors::OutOfRange("Your index [%lu] exceeds the number of " + "elements in origin_dims[%lu].", + k, + origin_dims.size())); + result.push_back(origin_dims[k]); + k++; + } + } + return result; +} + +// These method don't need to be specified +static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims, + const phi::DDim& in_dims) { + std::vector result; + int bat = dout_dims.size() - in_dims.size(); + for (int i = 0; i < bat; ++i) { + result.push_back(i); + } + for (int i = 0; i < in_dims.size(); ++i) { + if (in_dims[i] == 1) { + result.push_back(i + bat); + } else { + PADDLE_ENFORCE_EQ( + in_dims[i], + dout_dims[i + bat], + platform::errors::InvalidArgument( + "ReduceDims dimension mismatch. Operands could " + "not be broadcast together with the shape of dout = [%s] and " + "the shape of in_dims = [%s]. Received [%d] in X is not equal to " + "[%d] in Y at i:%d.", + dout_dims, + in_dims, + dout_dims[i + bat], + in_dims[i], + i)); + } + } + return phi::make_ddim(result); +} + +static phi::DDim get_reduce_dims(const phi::DDim& x_dims, + const phi::DDim& y_dims) { + auto out_dims = paddle::operators::details::BroadcastTwoDims(x_dims, y_dims); + return get_reduce_dims_from_out(out_dims, x_dims); +} + +static std::vector get_details_outputs_ptr( + const std::vector& outputs, + const std::vector& stop_gradients) { + std::vector outputs_ptr(outputs.size(), nullptr); + for (size_t i = 0; i < outputs.size(); i++) { + if (!stop_gradients[i]) { + outputs_ptr[i] = &outputs[i]; + } + } + return outputs_ptr; +} + +} // namespace primitive +} // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index b5f0acf98c1d8..ad18d9172a6c4 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "paddle/fluid/primitive/rule/vjp/vjp.h" -#include "paddle/fluid/ir/dialect/pd_api.h" +#include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/fluid/primitive/backend/static_backend.h" +#include "paddle/fluid/primitive/rule/vjp/details.h" +#include "paddle/fluid/primitive/rule/vjp/utils.h" #include "paddle/fluid/primitive/type/desc_tensor.h" #include "paddle/ir/core/operation.h" // TODO(wanghao107): @@ -22,7 +24,6 @@ namespace paddle { namespace primitive { -namespace experimental { std::vector> tanh_vjp( const Tensor& out, @@ -31,19 +32,15 @@ std::vector> tanh_vjp( std::vector> vjp_res( 1, std::vector(1)); // get tanh_grad res. - Tensor op_res = - backend::experimental::tanh_grad( - out, grad_out); + Tensor op_res = backend::tanh_grad(out, grad_out); // set op stop_gradient info // TODO(wanghao107): Replace with more generic code. 
// Support set stop_gradients for all ops. - ir::Operation* grad_op = - std::static_pointer_cast( - op_res.impl()) - ->getValue() - .dyn_cast() - .owner(); + ir::Operation* grad_op = std::static_pointer_cast(op_res.impl()) + ->getValue() + .dyn_cast() + .owner(); uint32_t num_res = grad_op->num_results(); std::vector ir_stop_gradients(num_res); for (size_t i = 0; i < num_res; i++) { @@ -77,18 +74,15 @@ std::vector> mean_vjp( 1, std::vector(1)); // get mean_grad res. Tensor op_res = - backend::experimental::mean_grad( - x, out_grad, axis, keepdim, reduce_all); + backend::mean_grad(x, out_grad, axis, keepdim, reduce_all); // set op stop_gradient info // TODO(wanghao107): Replace with more generic code. // Support set stop_gradients for all ops. - ir::Operation* grad_op = - std::static_pointer_cast( - op_res.impl()) - ->getValue() - .dyn_cast() - .owner(); + ir::Operation* grad_op = std::static_pointer_cast(op_res.impl()) + ->getValue() + .dyn_cast() + .owner(); uint32_t num_res = grad_op->num_results(); std::vector ir_stop_gradients(num_res); for (size_t i = 0; i < num_res; i++) { @@ -111,6 +105,49 @@ std::vector> mean_vjp( return vjp_res; } -} // namespace experimental +std::vector> divide_vjp( + const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis, + const std::vector>& stop_gradients) { + // TODO(wanghao107): support prim and no prim + // mode in this function when pd_api code_gen + // is ready; + VLOG(3) + << std::boolalpha << "IsBwdPrimEnabled: " + << paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled(); + std::vector> vjp_res( + 2, std::vector(1)); + // get divide_grad prim mode res. + Tensor* dx = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; + Tensor* dy = !stop_gradients[1][0] ? &vjp_res[0][0] : nullptr; + details::divide_grad(x, y, out, out_grad, axis, dx, dy); + + return vjp_res; +} + +std::vector> sum_vjp( + const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all, + const std::vector>& stop_gradients) { + // TODO(wanghao107): support prim and no prim + // mode in this function when pd_api code_gen + // is ready; + VLOG(3) + << std::boolalpha << "IsBwdPrimEnabled: " + << paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled(); + std::vector> vjp_res( + 1, std::vector(1)); + // get divide_grad prim mode res. + Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; + details::sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad); + return vjp_res; +} + } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/vjp.h b/paddle/fluid/primitive/rule/vjp/vjp.h index 48bc2affa9db4..59be5cdf6c356 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.h +++ b/paddle/fluid/primitive/rule/vjp/vjp.h @@ -14,13 +14,6 @@ #pragma once -#ifndef _USE_MATH_DEFINES -#define _USE_MATH_DEFINES -#endif - -#include -#include - #include "paddle/fluid/primitive/primitive/primitive.h" #include "paddle/ir/core/value.h" #include "paddle/phi/api/include/tensor.h" @@ -28,7 +21,6 @@ namespace paddle { namespace primitive { -namespace experimental { using IntArray = paddle::experimental::IntArray; // TODO(wanghao107): @@ -46,11 +38,20 @@ std::vector> mean_vjp( bool reduce_all, const std::vector>& stop_gradients); -namespace details { -// NOTE: this namespace will store -// primitive ops grad composite rules. 
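The divide_vjp and sum_vjp entry points declared below are consumed by the dialect-level Vjp wrappers (pd_op_vjp_manual.cc and, later in this series, the generated interface code), which translate between ir::OpResult values and primitive Tensors. A condensed sketch of that calling convention; the helper name and the DescTensor (later renamed LazyTensor) and ir::OpResult template arguments are assumptions, since the angle brackets in the patch text are stripped:

#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/desc_tensor.h"
#include "paddle/ir/core/value.h"

// Hypothetical wrapper-side helper: stop_gradients decides which slots get
// filled, and defined() marks the gradients that were actually produced.
std::vector<std::vector<ir::OpResult>> CallDivideVjp(
    const paddle::Tensor& x,
    const paddle::Tensor& y,
    const paddle::Tensor& out,
    const paddle::Tensor& out_grad,
    const std::vector<std::vector<bool>>& stop_gradients) {
  auto tensor_res = paddle::primitive::divide_vjp(
      x, y, out, out_grad, /*axis=*/-1, stop_gradients);
  std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
  for (size_t i = 0; i < tensor_res.size(); ++i) {
    res[i].resize(tensor_res[i].size());
    for (size_t j = 0; j < tensor_res[i].size(); ++j) {
      if (tensor_res[i][j].defined()) {
        res[i][j] = std::static_pointer_cast<paddle::primitive::DescTensor>(
                        tensor_res[i][j].impl())
                        ->getValue()
                        .dyn_cast<ir::OpResult>();
      }
    }
  }
  return res;
}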
+std::vector> divide_vjp( + const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis, + const std::vector>& stop_gradients); -} // namespace details -} // namespace experimental +std::vector> sum_vjp( + const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all, + const std::vector>& stop_gradients); } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/type/desc_tensor.h b/paddle/fluid/primitive/type/desc_tensor.h index 650b00e58ba7d..b4547bb6603a3 100644 --- a/paddle/fluid/primitive/type/desc_tensor.h +++ b/paddle/fluid/primitive/type/desc_tensor.h @@ -22,7 +22,6 @@ namespace paddle { namespace primitive { -namespace experimental { class DescTensor : public phi::ExtendedTensor, public phi::TypeInfoTraits { @@ -38,18 +37,21 @@ class DescTensor : public phi::ExtendedTensor, int64_t numel() const override { return product(dims()); } DataType dtype() const override { - return paddle::dialect::TransToPhiDataType(value_.type()); + return paddle::dialect::TransToPhiDataType( + value_.type().dyn_cast().dtype()); } ir::Value getValue() const { return value_; } + const phi::Place& place() const override { return place_; } + bool initialized() const override { return value_.impl() != nullptr; } private: ir::Value value_; mutable phi::DDim dims_; + phi::Place place_; }; -} // namespace experimental } // namespace primitive } // namespace paddle diff --git a/test/prim/new_ir_prim/test_vjp_prim.py b/test/prim/new_ir_prim/test_vjp_prim.py new file mode 100644 index 0000000000000..204b48788775a --- /dev/null +++ b/test/prim/new_ir_prim/test_vjp_prim.py @@ -0,0 +1,61 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import paddle +from paddle import ir +from paddle.fluid.core import call_vjp, has_vjp + +paddle.enable_static() + + +def get_ir_program(): + main_program, start_program = ( + paddle.static.Program(), + paddle.static.Program(), + ) + with paddle.static.program_guard(main_program, start_program): + x = paddle.tensor.fill_constant(shape=[4], dtype='float32', value=2.0) + x.stop_gradient = False + y = paddle.tensor.fill_constant(shape=[4], dtype='float32', value=1.0) + y.stop_gradiable = False + dout = paddle.tensor.fill_constant( + shape=[4], dtype='float32', value=1.0 + ) + dout.stop_gradiable = False + out = paddle.divide(x, y) + newir_program = ir.translate_to_new_ir(main_program.desc) + return newir_program + + +class TesBackward(unittest.TestCase): + def test_1(self): + newir_program = get_ir_program() + # input = newir_program.block().ops[-1].operand(0).source() + y = newir_program.block().ops[-3].result(0) + dout = newir_program.block().ops[-2].result(0) + out_grads = [[dout]] + stop_gradients = [[False], [False]] + divide_op = newir_program.block().ops[-1] + with paddle.ir.core.program_guard(newir_program): + print(y.dtype) + print(newir_program) + print(has_vjp(divide_op)) + grad_outs = call_vjp(divide_op, out_grads, stop_gradients) + print(newir_program) + + +if __name__ == "__main__": + unittest.main() From 6b6fc8eee5d40f40a4f3c0e4f7a93c95916fbe4b Mon Sep 17 00:00:00 2001 From: Charles-hit Date: Wed, 16 Aug 2023 08:18:11 +0000 Subject: [PATCH 10/30] remove useless code --- paddle/fluid/primitive/rule/vjp/utils.h | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/paddle/fluid/primitive/rule/vjp/utils.h b/paddle/fluid/primitive/rule/vjp/utils.h index b5d21667e54f8..e1765357aa9f8 100644 --- a/paddle/fluid/primitive/rule/vjp/utils.h +++ b/paddle/fluid/primitive/rule/vjp/utils.h @@ -87,17 +87,5 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims, return get_reduce_dims_from_out(out_dims, x_dims); } -static std::vector get_details_outputs_ptr( - const std::vector& outputs, - const std::vector& stop_gradients) { - std::vector outputs_ptr(outputs.size(), nullptr); - for (size_t i = 0; i < outputs.size(); i++) { - if (!stop_gradients[i]) { - outputs_ptr[i] = &outputs[i]; - } - } - return outputs_ptr; -} - } // namespace primitive } // namespace paddle From bfbb0e866599c79a15add02012a4108a0bc63501 Mon Sep 17 00:00:00 2001 From: chenzhiyang <1792266893@qq.com> Date: Wed, 16 Aug 2023 12:37:18 +0000 Subject: [PATCH 11/30] add vjp autogen v1.0 --- .../fluid/ir/dialect/op_generator/op_gen.py | 32 ++++- .../dialect/op_generator/op_interface_gen.py | 125 +++++++++++------ paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 127 +----------------- 3 files changed, 113 insertions(+), 171 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 5bbb5c80c0693..7a5968ae42cb4 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -20,6 +20,7 @@ from op_interface_gen import ( gen_exclusive_interface_str, gen_op_infer_meta_str, + gen_op_vjp_str, vjp_interface_gen_op_list, ) from op_member_func_gen import gen_op_get_inputs_outputs_str @@ -286,6 +287,9 @@ def __init__(self, op_yaml_item, op_compat_item): self.attribute_build_arg_type_list = ( self.parse_attribute_build_arg_type_list() ) + self.attribute_gen_arg_type_list = ( + self.parse_attribute_gen_arg_type_list() + ) self.attribute_data_type_list = self.parse_attribute_data_type_list() 
self.attribute_default_value_list = ( self.parse_attribute_default_value_list() @@ -584,6 +588,17 @@ def parse_attribute_build_arg_type_list(self): type_list.append(self.get_phi_dtype_name(temp_type)) return type_list + def parse_attribute_gen_arg_type_list(self): + type_list = [] + for attribute_info in self.op_yaml_item['attrs']: + assert ( + attribute_info['typename'] in self.attr_types_map + ), f"{self.op_phi_name} : Attr type error." + + temp_type = self.attr_types_map[attribute_info['typename']][1] + type_list.append(self.get_phi_dtype_name(temp_type)) + return type_list + def parse_attribute_type_list(self): type_list = [] for attribute_info in self.op_yaml_item['attrs']: @@ -1038,12 +1053,17 @@ def OpGenerator( op_vjp_str = '' # TODO(chenzhiyang) add vjp gen code - # if op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_gen_op_list: - # op_vjp_str = gen_op_vjp_str(op_class_name, - # op_info.backward_name, - # op_name, - # op_info_items[op_info.op_phi_name[0]], - # op_info_items[op_info.backward_name]) + if ( + op_info.backward_name + and op_info.op_phi_name[0] in vjp_interface_gen_op_list + ): + op_vjp_str = gen_op_vjp_str( + op_class_name, + op_info.backward_name, + op_name, + op_info_items[op_info.op_phi_name[0]], + op_info_items[op_info.backward_name], + ) ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py index ef5f2e1b4ccab..368730bb27989 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py @@ -23,57 +23,61 @@ """ OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """ - {input_type} {input_name}(std::make_shared(op_obj.{input_name}())); -""" + {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """ - Tensor {output_grad_name}(std::make_shared((out_grads[{idx1}][{idx2}]); -""" + Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ - std::vector {output_grad_name}(std::make_shared((out_grads[{idx1}]); -""" + std::vector {output_grad_name}(std::make_shared(out_grads[{idx1}]));""" -OP_VJP_CALL_VJP_TEMPLATE = """ - Tensor std::vector> tensor_res = - primitive::experimental::{op_phi_name}_vjp({inputs_list}, stop_gradients); -""" +OP_VJP_ATTRIBUTE_TEMPLATE = """ + {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();""" -OP_VJP_STOPGRADIENT_TEMPLATE = """ - if(!stop_gradients[{idx1}][{idx2}]){{ - res[{idx1}][{idx2}] = std::static_pointer_cast( - tensor_res[idx1][idx2].impl()) - ->getValue() - .dyn_cast(); - }} -""" +OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """ + {attr_type} {attr_name} = {default_value};""" -OP_VJP_DEFINE_TEMPLATE = """ -std::vector> {op_class_name}::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients){{ - {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); - VLOG(6) << "Prepare inputs of {op_grad_name}"; +OP_VJP_CALL_VJP_TEMPLATE = """ std::vector> tensor_res = + primitive::experimental::{op_phi_name}_vjp( + {inputs_list}stop_gradients);""" - {forward_input_code} - {forward_output_code} - {forward_output_grad_code} +OP_VJP_STOPGRADIENT_TEMPLATE = """ + std::vector> res(tensor_res.size()); + for (size_t i = 0; i < tensor_res.size(); ++i) {{ + res[i].resize(tensor_res[i].size()); + for (size_t j = 0; j < tensor_res[i].size(); 
++j) {{ + if(tensor_res[i][j].defined()){{ + res[i][j] = std::static_pointer_cast(tensor_res[i][j].impl())->getValue().dyn_cast(); + }} + }} + }}""" + +OP_VJP_DEFINE_TEMPLATE = """ +std::vector> {op_class_name}::Vjp(ir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients){{ + {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); - VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; - {attribute_code} +VLOG(6) << "Prepare inputs of {op_grad_name}"; +{forward_input_code} +{forward_output_grad_code} - VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface"; - {call_vjp_code} +VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; +{attribute_code} - std::vector> res(1, std::vector(1)); - {stop_gradient_input_grad_code} +VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface"; +{call_vjp_code} - return res; +VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}"; +{stop_gradient_input_grad_code} + return res; }} """ +input_types_map = { + 'paddle::dialect::DenseTensorType': 'Tensor', + 'ir::VectorType': 'Tensor[]', +} + def gen_op_vjp_str( op_class_name, @@ -82,19 +86,62 @@ def gen_op_vjp_str( op_info, op_grad_info, ): + bw_input_list = op_grad_info.input_name_list forward_input_code = '' - forward_output_code = '' forward_output_grad_code = '' + build_args_str = '' + grad_idx = -1 + for idx in range(len(bw_input_list)): + build_args_str += bw_input_list[idx] + ", " + if ( + bw_input_list[idx] in op_info.input_name_list + or bw_input_list[idx] in op_info.output_name_list + ): + forward_input_code += ( + OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( + input_type=input_types_map[ + op_grad_info.input_type_list[idx] + ], + input_name=bw_input_list[idx], + ) + ) + else: + grad_idx += 1 + forward_output_grad_code += ( + OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format( + output_grad_name=bw_input_list[idx], idx1=grad_idx, idx2=0 + ) + ) + op_attribute_list = op_grad_info.attribute_name_list attribute_code = '' - call_vjp_code = '' - stop_gradient_input_grad_code = '' + for idx in range(len(op_attribute_list)): + build_args_str += op_attribute_list[idx] + ", " + if op_attribute_list[idx] in op_info.attribute_name_list: + attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format( + attr_type=op_grad_info.attribute_gen_arg_type_list[idx], + attr_name=op_attribute_list[idx], + attr_parse_type=op_grad_info.attribute_type_list[idx], + ) + else: + attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format( + attr_type=op_grad_info.attribute_gen_arg_type_list[idx], + attr_name=op_attribute_list[idx], + default_value=op_grad_info.attribute_default_value_list[idx], + ) + if op_phi_name[-1] == '_': + op_phi_name = op_phi_name[:-1] + call_vjp_code = OP_VJP_CALL_VJP_TEMPLATE.format( + op_phi_name=op_phi_name, + inputs_list=build_args_str, + ) + stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE str = OP_VJP_DEFINE_TEMPLATE.format( op_class_name=op_class_name, op_grad_name=op_grad_name, op_phi_name=op_phi_name, + res_size=len(op_info.input_name_list), forward_input_code=forward_input_code, - forward_output_code=forward_output_code, forward_output_grad_code=forward_output_grad_code, attribute_code=attribute_code, call_vjp_code=call_vjp_code, diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index ee03e826b652d..c9201bc03c54f 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -23,130 +23,5 @@ // this file will be generated in pd_op.cc namespace 
paddle { -namespace dialect { -using IntArray = paddle::experimental::IntArray; - -std::vector> TanhOp::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - TanhOp op_obj = op->dyn_cast(); - Tensor out( - std::make_shared(op_obj.out())); - Tensor grad_out( - std::make_shared(out_grads[0][0])); - std::vector> tensor_res = - primitive::experimental::tanh_vjp(out, grad_out, stop_gradients); - std::vector> res(1, std::vector(1)); - if (tensor_res[0][0].defined()) { - res[0][0] = std::static_pointer_cast( - tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); - } - return res; -} - -std::vector> Tanh_Op::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - // TODO(wanghao107) - // we don't support inplace now, - // so use the non-inplace version instead currently. - // Support inplace in the future. - Tanh_Op op_obj = op->dyn_cast(); - Tensor out( - std::make_shared(op_obj.out())); - Tensor grad_out( - std::make_shared(out_grads[0][0])); - std::vector> tensor_res = - primitive::experimental::tanh_vjp(out, grad_out, stop_gradients); - std::vector> res(1, std::vector(1)); - if (tensor_res[0][0].defined()) { - res[0][0] = std::static_pointer_cast( - tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); - } - return res; -} - -std::vector> MeanOp::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - MeanOp op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor out_grad( - std::make_shared(out_grads[0][0])); - - IntArray axis = op->attribute("axis") - .dyn_cast() - .data(); - bool keepdim = op->attribute("keepdim").dyn_cast().data(); - bool reduce_all = false; - std::vector> tensor_res = - primitive::experimental::mean_vjp( - x, out_grad, axis, keepdim, reduce_all, stop_gradients); - std::vector> res(1, std::vector(1)); - if (tensor_res[0][0].defined()) { - res[0][0] = std::static_pointer_cast( - tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); - } - return res; -} - -std::vector> AddOp::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - AddOp op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor y(std::make_shared(op_obj.y())); - Tensor out_grad( - std::make_shared(out_grads[0][0])); - int axis = -1; - - std::vector> tensor_res = - primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients); - std::vector> res(2, std::vector(1)); - for (size_t i = 0; i < 2; ++i) { - if (tensor_res[i][0].defined()) { - res[i][0] = std::static_pointer_cast( - tensor_res[i][0].impl()) - ->getValue() - .dyn_cast(); - } - } - return res; -} - -std::vector> Add_Op::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - Add_Op op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor y(std::make_shared(op_obj.y())); - Tensor out_grad( - std::make_shared(out_grads[0][0])); - int axis = -1; - - std::vector> tensor_res = - primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients); - std::vector> res(2, std::vector(1)); - for (size_t i = 0; i < 2; ++i) { - if (tensor_res[i][0].defined()) { - res[i][0] = std::static_pointer_cast( - tensor_res[i][0].impl()) - ->getValue() - .dyn_cast(); - } - } - return res; -} -} // namespace dialect +namespace dialect {} // namespace dialect } // namespace paddle From 065c0119cbb6c57e8ce782b8ef739b9d36e2aa31 Mon Sep 17 00:00:00 2001 From: Charles-hit Date: Thu, 17 
Aug 2023 02:41:13 +0000 Subject: [PATCH 12/30] add test for prim --- paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 4 +- .../fluid/primitive/backend/static_backend.cc | 48 ++++++- .../fluid/primitive/backend/static_backend.h | 14 ++ paddle/fluid/primitive/rule/vjp/vjp.cc | 49 ++++--- test/prim/new_ir_prim/test_vjp_prim.py | 132 ++++++++++++++++-- 5 files changed, 210 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index 7c123189c7009..c19a72fcd6924 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -126,7 +126,9 @@ std::vector> SumOp::Vjp( Tensor x(std::make_shared(op_obj.x())); Tensor out_grad(std::make_shared(out_grads[0][0])); - IntArray axis = op->attribute("axis") + IntArray axis = op_obj.axis() + .GetDefiningOp() + ->attribute("value") .dyn_cast() .data(); bool keepdim = op->attribute("keepdim").dyn_cast().data(); diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index b6611b78bad58..1d55453805553 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -96,7 +96,6 @@ Tensor multiply(const Tensor& x, const Tensor& y) { template <> Tensor elementwise_pow(const Tensor& x, const Tensor& y) { - VLOG(3) << "elementwise_pow static api"; ir::OpResult x_res = std::static_pointer_cast(x.impl()) ->getValue() .dyn_cast(); @@ -138,7 +137,6 @@ Tensor full(const IntArray& shape, const Scalar& value, phi::DataType dtype, phi::Place place) { - VLOG(3) << "full static api"; ir::OpResult op_res = paddle::dialect::full(shape.GetData(), value.to(), dtype, place); return Tensor(std::make_shared(op_res)); @@ -198,6 +196,52 @@ std::tuple add_grad(const Tensor& x, Tensor(std::make_shared(std::get<0>(op_res))), Tensor(std::make_shared(std::get<1>(op_res)))); } + +template <> +std::tuple divide_grad(const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult out_res = std::static_pointer_cast(out.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult out_grad_res = + std::static_pointer_cast(out_grad.impl()) + ->getValue() + .dyn_cast(); + + std::tuple op_res = + paddle::dialect::divide_grad(x_res, y_res, out_res, out_grad_res, axis); + + return std::make_tuple( + Tensor(std::make_shared(std::get<0>(op_res))), + Tensor(std::make_shared(std::get<1>(op_res)))); +} + +template <> +Tensor sum_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult out_grad_res = + std::static_pointer_cast(out_grad.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::sum_grad( + x_res, out_grad_res, axis.GetData(), keepdim, reduce_all); + return Tensor(std::make_shared(op_res)); +} } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index 95110184704f6..1e484aa35e676 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -82,6 +82,20 @@ Tensor expand(const Tensor& x, const 
IntArray& shape); template Tensor tile(const Tensor& x, const IntArray& repeat_times = {}); +template +std::tuple divide_grad(const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis); + +template +Tensor sum_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all); + } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index 83b4397b5bd87..d0d901ddd4f69 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -116,7 +116,7 @@ std::vector> add_vjp( const std::vector>& stop_gradients) { std::vector> vjp_res( 2, std::vector(1)); - // get mean_grad res. + // get add_grad res. std::tuple op_res = backend::add_grad(x, y, out_grad, axis); @@ -155,19 +155,21 @@ std::vector> divide_vjp( const Tensor& out_grad, int axis, const std::vector>& stop_gradients) { - // TODO(wanghao107): support prim and no prim - // mode in this function when pd_api code_gen - // is ready; - VLOG(3) - << std::boolalpha << "IsBwdPrimEnabled: " - << paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled(); std::vector> vjp_res( 2, std::vector(1)); - // get divide_grad prim mode res. - Tensor* dx = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; - Tensor* dy = !stop_gradients[1][0] ? &vjp_res[0][0] : nullptr; - details::divide_grad(x, y, out, out_grad, axis, dx, dy); - + if (!paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) { + // get divide_grad res. + std::tuple op_res = + backend::divide_grad(x, y, out, out_grad, axis); + // construct vjp result by op result and stop_gradients info + vjp_res[0][0] = !stop_gradients[0][0] ? std::get<0>(op_res) : vjp_res[0][0]; + vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0]; + } else { + // get divide_grad prim mode res. + Tensor* dx = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; + Tensor* dy = !stop_gradients[1][0] ? &vjp_res[1][0] : nullptr; + details::divide_grad(x, y, out, out_grad, axis, dx, dy); + } return vjp_res; } @@ -178,17 +180,22 @@ std::vector> sum_vjp( bool keepdim, bool reduce_all, const std::vector>& stop_gradients) { - // TODO(wanghao107): support prim and no prim - // mode in this function when pd_api code_gen - // is ready; - VLOG(3) - << std::boolalpha << "IsBwdPrimEnabled: " - << paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled(); std::vector> vjp_res( 1, std::vector(1)); - // get divide_grad prim mode res. - Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; - details::sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad); + if (!paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) { + // get sum_grad res. + Tensor op_res = backend::sum_grad( + x, out_grad, axis, keepdim, reduce_all); + // construct vjp result by op result and stop_gradients info + if (!stop_gradients[0][0]) { + vjp_res[0][0] = op_res; + } + } else { + // get divide_grad prim mode res. + Tensor* x_grad = !stop_gradients[0][0] ? 
&vjp_res[0][0] : nullptr; + details::sum_grad( + x, out_grad, axis, keepdim, reduce_all, x_grad); + } return vjp_res; } diff --git a/test/prim/new_ir_prim/test_vjp_prim.py b/test/prim/new_ir_prim/test_vjp_prim.py index 204b48788775a..8c2fd4ebd76df 100644 --- a/test/prim/new_ir_prim/test_vjp_prim.py +++ b/test/prim/new_ir_prim/test_vjp_prim.py @@ -16,23 +16,25 @@ import paddle from paddle import ir -from paddle.fluid.core import call_vjp, has_vjp +from paddle.fluid.core import call_vjp paddle.enable_static() -def get_ir_program(): +def get_ir_program_0(): main_program, start_program = ( paddle.static.Program(), paddle.static.Program(), ) with paddle.static.program_guard(main_program, start_program): - x = paddle.tensor.fill_constant(shape=[4], dtype='float32', value=2.0) + x = paddle.tensor.fill_constant( + shape=[1, 4], dtype='float32', value=2.0 + ) x.stop_gradient = False y = paddle.tensor.fill_constant(shape=[4], dtype='float32', value=1.0) y.stop_gradiable = False dout = paddle.tensor.fill_constant( - shape=[4], dtype='float32', value=1.0 + shape=[1, 4], dtype='float32', value=1.0 ) dout.stop_gradiable = False out = paddle.divide(x, y) @@ -40,21 +42,125 @@ def get_ir_program(): return newir_program -class TesBackward(unittest.TestCase): - def test_1(self): - newir_program = get_ir_program() - # input = newir_program.block().ops[-1].operand(0).source() - y = newir_program.block().ops[-3].result(0) +def get_ir_program_1(): + main_program, start_program = ( + paddle.static.Program(), + paddle.static.Program(), + ) + with paddle.static.program_guard(main_program, start_program): + x = paddle.tensor.fill_constant( + shape=[4, 5], dtype='float32', value=2.0 + ) + x.stop_gradient = False + dout = paddle.tensor.fill_constant( + shape=[1], dtype='float32', value=1.0 + ) + dout.stop_gradiable = False + out = paddle.sum(x) + newir_program = ir.translate_to_new_ir(main_program.desc) + return newir_program + + +class TestVjpPrim(unittest.TestCase): + def test_divide_grad_prim_case1(self): + newir_program = get_ir_program_0() + paddle.fluid.core._set_prim_backward_enabled(True) dout = newir_program.block().ops[-2].result(0) out_grads = [[dout]] stop_gradients = [[False], [False]] divide_op = newir_program.block().ops[-1] with paddle.ir.core.program_guard(newir_program): - print(y.dtype) - print(newir_program) - print(has_vjp(divide_op)) grad_outs = call_vjp(divide_op, out_grads, stop_gradients) - print(newir_program) + reshape_op2 = newir_program.block().ops[-1] + reshape_op1 = newir_program.block().ops[-8] + self.assertEqual(len(grad_outs), 2) + self.assertEqual(len(newir_program.block().ops), 21) + self.assertEqual(reshape_op2.result(0), grad_outs[0][0]) + self.assertEqual(reshape_op1.result(0), grad_outs[1][0]) + all_op_names = [ + "pd.full", + "pd.full", + "pd.full", + "pd.divide", + "pd.full", + "pd.elementwise_pow", + "pd.divide", + "pd.multiply", + "pd.full", + "pd.scale", + "pd.full_int_array", + "pd.sum", + "pd.full_int_array", + "pd.reshape", + "pd.full", + "pd.divide", + "pd.multiply", + "pd.full_int_array", + "pd.sum", + "pd.full_int_array", + "pd.reshape", + ] + for idx, op in enumerate(newir_program.block().ops): + self.assertEqual(op.name(), all_op_names[idx]) + + def test_divide_grad_no_prim(self): + newir_program = get_ir_program_0() + paddle.fluid.core._set_prim_backward_enabled(False) + dout = newir_program.block().ops[-2].result(0) + out_grads = [[dout]] + stop_gradients = [[False], [False]] + divide_op = newir_program.block().ops[-1] + with 
paddle.ir.core.program_guard(newir_program): + grad_outs = call_vjp(divide_op, out_grads, stop_gradients) + self.assertEqual(len(grad_outs), 2) + self.assertEqual( + grad_outs[0][0].get_defining_op().name(), "pd.divide_grad" + ) + self.assertEqual( + grad_outs[1][0].get_defining_op().name(), "pd.divide_grad" + ) + self.assertEqual(len(newir_program.block().ops), 5) + + def test_sum_grad_prim(self): + newir_program = get_ir_program_1() + paddle.fluid.core._set_prim_backward_enabled(True) + dout = newir_program.block().ops[-2].result(0) + out_grads = [[dout]] + stop_gradients = [[False]] + sum_op = newir_program.block().ops[-1] + with paddle.ir.core.program_guard(newir_program): + grad_outs = call_vjp(sum_op, out_grads, stop_gradients) + expand_op = newir_program.block().ops[-1] + self.assertEqual(len(grad_outs), 1) + self.assertEqual(len(newir_program.block().ops), 8) + self.assertEqual(expand_op.result(0), grad_outs[0][0]) + all_op_names = [ + "pd.full", + "pd.full", + "pd.full_int_array", + "pd.sum", + "pd.full_int_array", + "pd.reshape", + "pd.full_int_array", + "pd.expand", + ] + for idx, op in enumerate(newir_program.block().ops): + self.assertEqual(op.name(), all_op_names[idx]) + + def test_sum_grad_no_prim(self): + newir_program = get_ir_program_1() + paddle.fluid.core._set_prim_backward_enabled(False) + dout = newir_program.block().ops[-2].result(0) + out_grads = [[dout]] + stop_gradients = [[False]] + sum_op = newir_program.block().ops[-1] + with paddle.ir.core.program_guard(newir_program): + grad_outs = call_vjp(sum_op, out_grads, stop_gradients) + self.assertEqual(len(grad_outs), 1) + self.assertEqual( + grad_outs[0][0].get_defining_op().name(), "pd.sum_grad" + ) + self.assertEqual(len(newir_program.block().ops), 6) if __name__ == "__main__": From 0c8db6eb6922263415563af0dcf087826a31c908 Mon Sep 17 00:00:00 2001 From: chenzhiyang <1792266893@qq.com> Date: Thu, 17 Aug 2023 03:13:14 +0000 Subject: [PATCH 13/30] resolve type conflict --- paddle/fluid/ir/dialect/op_generator/op_interface_gen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py index 368730bb27989..7c04fa14033d9 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py @@ -23,13 +23,13 @@ """ OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """ - {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" + {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """ - Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" + Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ - std::vector {output_grad_name}(std::make_shared(out_grads[{idx1}]));""" + std::vector {output_grad_name}(std::make_shared(out_grads[{idx1}]));""" OP_VJP_ATTRIBUTE_TEMPLATE = """ {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();""" @@ -48,7 +48,7 @@ res[i].resize(tensor_res[i].size()); for (size_t j = 0; j < tensor_res[i].size(); ++j) {{ if(tensor_res[i][j].defined()){{ - res[i][j] = std::static_pointer_cast(tensor_res[i][j].impl())->getValue().dyn_cast(); + res[i][j] = std::static_pointer_cast(tensor_res[i][j].impl())->getValue().dyn_cast(); }} }} }}""" From ffc71f295d0adc667e754a7a62af8523c4b3212b Mon Sep 17 00:00:00 2001 From: Charles-hit Date: 
Thu, 17 Aug 2023 04:00:17 +0000 Subject: [PATCH 14/30] modify utils --- paddle/fluid/primitive/CMakeLists.txt | 1 + .../rule/{vjp => utils}/eager_utils.cc | 2 +- .../rule/{vjp => utils}/static_utils.cc | 2 +- .../fluid/primitive/rule/vjp/CMakeLists.txt | 10 ------- paddle/fluid/primitive/rule/vjp/details.h | 2 +- paddle/fluid/primitive/rule/vjp/vjp.cc | 2 +- paddle/fluid/primitive/utils/CMakeLists.txt | 10 +++++++ paddle/fluid/primitive/utils/eager_utils.cc | 26 +++++++++++++++++++ paddle/fluid/primitive/utils/static_utils.cc | 25 ++++++++++++++++++ .../primitive/{rule/vjp => utils}/utils.h | 0 10 files changed, 66 insertions(+), 14 deletions(-) rename paddle/fluid/primitive/rule/{vjp => utils}/eager_utils.cc (95%) rename paddle/fluid/primitive/rule/{vjp => utils}/static_utils.cc (94%) create mode 100644 paddle/fluid/primitive/utils/CMakeLists.txt create mode 100644 paddle/fluid/primitive/utils/eager_utils.cc create mode 100644 paddle/fluid/primitive/utils/static_utils.cc rename paddle/fluid/primitive/{rule/vjp => utils}/utils.h (100%) diff --git a/paddle/fluid/primitive/CMakeLists.txt b/paddle/fluid/primitive/CMakeLists.txt index 5134cb0134989..aab7919dfe49d 100644 --- a/paddle/fluid/primitive/CMakeLists.txt +++ b/paddle/fluid/primitive/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(utils) add_subdirectory(backend) add_subdirectory(rule) diff --git a/paddle/fluid/primitive/rule/vjp/eager_utils.cc b/paddle/fluid/primitive/rule/utils/eager_utils.cc similarity index 95% rename from paddle/fluid/primitive/rule/vjp/eager_utils.cc rename to paddle/fluid/primitive/rule/utils/eager_utils.cc index 46a1310a6a351..e9ad10407e32a 100644 --- a/paddle/fluid/primitive/rule/vjp/eager_utils.cc +++ b/paddle/fluid/primitive/rule/utils/eager_utils.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" -#include "paddle/fluid/primitive/rule/vjp/utils.h" +#include "paddle/fluid/primitive/utils/utils.h" namespace paddle { namespace primitive { diff --git a/paddle/fluid/primitive/rule/vjp/static_utils.cc b/paddle/fluid/primitive/rule/utils/static_utils.cc similarity index 94% rename from paddle/fluid/primitive/rule/vjp/static_utils.cc rename to paddle/fluid/primitive/rule/utils/static_utils.cc index c08c755c933a7..40cbbc8d21e89 100644 --- a/paddle/fluid/primitive/rule/vjp/static_utils.cc +++ b/paddle/fluid/primitive/rule/utils/static_utils.cc @@ -11,8 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/primitive/rule/vjp/utils.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" namespace paddle { namespace primitive { diff --git a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt index 2a8c401a77eb0..3243228d1127d 100644 --- a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt +++ b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt @@ -1,14 +1,4 @@ file(GLOB VJP_SRCS "vjp.cc") -if(WITH_PYTHON OR NOT ON_INFER) - cc_library( - primitive_eager_utils_experimental - SRCS eager_utils.cc - DEPS phi common_infer_shape_functions) -endif() -cc_library( - primitive_static_utils_experimental - SRCS static_utils.cc - DEPS phi common_infer_shape_functions) cc_library( primitive_vjp_experimental SRCS ${VJP_SRCS} diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index f42a084974167..6ee9c5880b6d6 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -22,8 +22,8 @@ #include #include "paddle/fluid/primitive/primitive/primitive.h" -#include "paddle/fluid/primitive/rule/vjp/utils.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" namespace paddle { namespace primitive { diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index 613436fa6feab..59fabfc87cfa2 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -17,8 +17,8 @@ #include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/fluid/primitive/backend/static_backend.h" #include "paddle/fluid/primitive/rule/vjp/details.h" -#include "paddle/fluid/primitive/rule/vjp/utils.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" #include "paddle/ir/core/operation.h" // TODO(wanghao107): // op's vjp will be auto generated. diff --git a/paddle/fluid/primitive/utils/CMakeLists.txt b/paddle/fluid/primitive/utils/CMakeLists.txt new file mode 100644 index 0000000000000..044198c827fb2 --- /dev/null +++ b/paddle/fluid/primitive/utils/CMakeLists.txt @@ -0,0 +1,10 @@ +if(WITH_PYTHON OR NOT ON_INFER) + cc_library( + primitive_eager_utils_experimental + SRCS eager_utils.cc + DEPS phi common_infer_shape_functions) +endif() +cc_library( + primitive_static_utils_experimental + SRCS static_utils.cc + DEPS phi common_infer_shape_functions) diff --git a/paddle/fluid/primitive/utils/eager_utils.cc b/paddle/fluid/primitive/utils/eager_utils.cc new file mode 100644 index 0000000000000..e9ad10407e32a --- /dev/null +++ b/paddle/fluid/primitive/utils/eager_utils.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" +#include "paddle/fluid/primitive/utils/utils.h" + +namespace paddle { +namespace primitive { +template <> +void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) { + x->set_impl(x_tmp.impl()); + x->set_autograd_meta(x_tmp.mutable_autograd_meta()); +} + +} // namespace primitive +} // namespace paddle diff --git a/paddle/fluid/primitive/utils/static_utils.cc b/paddle/fluid/primitive/utils/static_utils.cc new file mode 100644 index 0000000000000..40cbbc8d21e89 --- /dev/null +++ b/paddle/fluid/primitive/utils/static_utils.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" + +namespace paddle { +namespace primitive { +template <> +void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) { + x->set_impl(x_tmp.impl()); +} + +} // namespace primitive +} // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/utils.h b/paddle/fluid/primitive/utils/utils.h similarity index 100% rename from paddle/fluid/primitive/rule/vjp/utils.h rename to paddle/fluid/primitive/utils/utils.h From c161003978267c179d5e49ab956f6317fa3f0fe9 Mon Sep 17 00:00:00 2001 From: Charles-hit Date: Thu, 17 Aug 2023 04:03:55 +0000 Subject: [PATCH 15/30] remove useless code --- .../fluid/primitive/rule/utils/eager_utils.cc | 26 ------------------- .../primitive/rule/utils/static_utils.cc | 25 ------------------ 2 files changed, 51 deletions(-) delete mode 100644 paddle/fluid/primitive/rule/utils/eager_utils.cc delete mode 100644 paddle/fluid/primitive/rule/utils/static_utils.cc diff --git a/paddle/fluid/primitive/rule/utils/eager_utils.cc b/paddle/fluid/primitive/rule/utils/eager_utils.cc deleted file mode 100644 index e9ad10407e32a..0000000000000 --- a/paddle/fluid/primitive/rule/utils/eager_utils.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" -#include "paddle/fluid/primitive/utils/utils.h" - -namespace paddle { -namespace primitive { -template <> -void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) { - x->set_impl(x_tmp.impl()); - x->set_autograd_meta(x_tmp.mutable_autograd_meta()); -} - -} // namespace primitive -} // namespace paddle diff --git a/paddle/fluid/primitive/rule/utils/static_utils.cc b/paddle/fluid/primitive/rule/utils/static_utils.cc deleted file mode 100644 index 40cbbc8d21e89..0000000000000 --- a/paddle/fluid/primitive/rule/utils/static_utils.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "paddle/fluid/primitive/type/lazy_tensor.h" -#include "paddle/fluid/primitive/utils/utils.h" - -namespace paddle { -namespace primitive { -template <> -void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) { - x->set_impl(x_tmp.impl()); -} - -} // namespace primitive -} // namespace paddle From 9fcfe39a5131b8f6970068e9044ccb793160d226 Mon Sep 17 00:00:00 2001 From: wangruting Date: Thu, 17 Aug 2023 09:25:57 +0000 Subject: [PATCH 16/30] add split op and modify some bug of vectorType --- paddle/fluid/ir/dialect/pd_manual_api.cc | 5 +- paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 19 ++--- .../ir/phi_kernel_adaptor/phi_kernel_util.cc | 34 ++++++++- .../ir/phi_kernel_adaptor/phi_kernel_util.h | 2 +- .../ir/transforms/pd_op_to_kernel_pass.cc | 72 ++++++++++++++++++- .../fluid/primitive/backend/static_backend.cc | 2 - paddle/fluid/primitive/rule/vjp/vjp.cc | 8 +-- paddle/ir/core/builtin_dialect.cc | 1 + paddle/ir/core/builtin_op.cc | 31 +++++++- paddle/ir/core/builtin_op.h | 22 ++++++ test/cpp/prim/test_vjp.cc | 61 ++++++++++++++++ 11 files changed, 233 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_manual_api.cc b/paddle/fluid/ir/dialect/pd_manual_api.cc index df80e1639940e..10e25ea883b91 100644 --- a/paddle/fluid/ir/dialect/pd_manual_api.cc +++ b/paddle/fluid/ir/dialect/pd_manual_api.cc @@ -25,13 +25,12 @@ std::vector concat_grad(std::vector x, ir::OpResult axis) { auto combine_op = APIBuilder::Instance().GetBuilder()->Build(x); - paddle::dialect::ConcatGradOp concat_grad_op = APIBuilder::Instance().GetBuilder()->Build( combine_op.out(), out_grad, axis); - auto slice_op = APIBuilder::Instance().GetBuilder()->Build( + auto split_op = APIBuilder::Instance().GetBuilder()->Build( concat_grad_op.result(0)); - return slice_op.outputs(); + return split_op.outputs(); } } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index 6dbc3e478976c..87c5a95e370be 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -115,18 +115,21 @@ std::vector> ConcatOp::Vjp( Tensor out_grad( std::make_shared(out_grads[0][0])); - Tensor axis( 
std::make_shared(op_obj.axis())); - std::vector> tensor_res = primitive::experimental::concat_vjp(x, out_grad, axis, stop_gradients); - std::vector> res(1, std::vector(1)); - if (tensor_res[0][0].defined()) { - res[0][0] = std::static_pointer_cast( - tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); + std::vector> res(1, std::vector()); + std::cout << "ConcatOp::Vjp called 4" << std::endl; + res[0] = std::vector(tensor_res[0].size()); + for (uint64_t idx = 0; idx < tensor_res[0].size(); idx++) { + if (tensor_res[0][idx].defined()) { + res[0][idx] = + std::static_pointer_cast( + tensor_res[0][idx].impl()) + ->getValue() + .dyn_cast(); + } } return res; } diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc index b5bf6e123ac5e..c7573c04e83ad 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc @@ -54,7 +54,10 @@ void AddNewData(ir::Value value, std::string>* variable_2_var_name, std::map* var_name_2_id, std::vector* variable_list) { - value_2_var_name->emplace(value, name); + if (value_2_var_name->count(value) == 0) { + value_2_var_name->emplace(value, name); + } + variable_2_var_name->emplace(var, name); if (var_name_2_id->count(name) == 0) { auto id = var_name_2_id->size(); @@ -174,7 +177,6 @@ void BuildValue(ir::Value value, var_name_2_id, variable_list); } - // Only support DenseTensor or Vector if (!value.type()) { var->GetMutable(); @@ -200,6 +202,7 @@ void BuildValue(ir::Value value, variable_2_var_name, var_name_2_id, variable_list); + var_i->GetMutable(); tensor_array->emplace_back(var_i); } @@ -412,6 +415,30 @@ void HandleForSpecialOp( std::string var_name = variable_2_var_name->at(variable_array[index]); value_2_var_name->emplace(out_value, var_name); } + + if (op_name == "builtin.split") { + VLOG(6) << "Handle for builtin.split"; + auto in_value = op->operand_source(0); + PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value), + true, + phi::errors::PreconditionNotMet( + "input of buildin split not in name map")); + + auto in_var = inner_scope->FindVar(value_2_var_name->at(in_value)); + auto variable_array = in_var->Get(); + + for (uint64_t idx = 0; idx < variable_array.size(); ++idx) { + auto out_value = op->result(idx); + PADDLE_ENFORCE_EQ( + variable_2_var_name->count(variable_array[idx]), + true, + phi::errors::PreconditionNotMet("[%d] the variable in build split " + "input MUST in variable name map", + idx)); + std::string var_name = variable_2_var_name->at(variable_array[idx]); + value_2_var_name->emplace(out_value, var_name); + } + } } void HandleForInplaceOp( @@ -498,7 +525,8 @@ void BuildScope(const ir::Block& block, if (op_name == "pd.feed" || op_name == "pd.fetch" || op_name == "builtin.combine" || op_name == "builtin.set_parameter" || op_name == "builtin.get_parameter" || op_name == "builtin.slice" || - op_name == "pd.data" || op_name == "pd.shadow_output") { + op_name == "builtin.split" || op_name == "pd.data" || + op_name == "pd.shadow_output") { HandleForSpecialOp(op, inner_scope, var_name_prefix, diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h index f59b8d927cbdd..2b024f4786893 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h @@ -311,7 +311,7 @@ void BuildPhiContext(ir::Operation* op, ->Get())))); } else if (out_type.isa()) { OutListType outputs; - auto& variable_array = 
scope->FindVar(name_map.at(out_ptr)) + auto& variable_array = inner_scope->FindVar(name_map.at(out_ptr)) ->Get(); for (size_t i = 0; i < variable_array.size(); ++i) { outputs.emplace_back(OutType(const_cast( diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc index 3bb7cd161ff81..45e257b493a08 100644 --- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc @@ -53,6 +53,7 @@ const std::unordered_set UnchangeOutputOps = { "pd.data", "builtin.combine", "builtin.slice", + "builtin.split", "pd.feed", "pd.fetch", "builtin.set_parameter", @@ -509,7 +510,76 @@ std::unique_ptr PdOpLowerToKernelPass(ir::Program* prog, op_output_types.push_back(allocated_dense_tensor_dtype); } else { PADDLE_THROW(phi::errors::Unimplemented( - "builtin.combine Result type only support DenseTensorType")); + "builtin.slice Result type only support DenseTensorType")); + } + } + } + // Get op info + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); + // Generate new op + ir::Operation* op = ir::Operation::Create( + vec_inputs, op_item->attributes(), op_output_types, op_info); + program->block()->push_back(op); + map_op_pair[op_item] = op; + // only deal with single output + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + map_value_pair[op_item->result(i)] = op->result(i); + } + } + VLOG(6) << "Deep copy a new builtin op: " << op_item->name(); + continue; + } + + if (op_item->name() == "builtin.split") { + phi::Place out_place = place; + // Copy op inputs + std::vector vec_inputs; + if (op_item->num_operands() > 0) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in), + true, + phi::errors::PreconditionNotMet( + "[%d]'s input of [%s] op MUST in map pair", + i, + op_item->name())); + auto new_in = map_value_pair.at(cur_in); + vec_inputs.push_back(new_in); + + if (new_in.type().isa()) { + auto vec_types = new_in.type().dyn_cast().data(); + out_place = + vec_types[0] + .dyn_cast() + .place(); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("only support vector type for now")); + } + } + } + // Copy op output type + std::vector op_output_types; + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + auto result_type = op_item->result(i).type(); + if (!result_type) { + op_output_types.push_back(result_type); + } else if (result_type.isa()) { + auto allocated_dense_tensor_dtype = + paddle::dialect::AllocatedDenseTensorType::get( + ctx, + out_place, + result_type.dyn_cast()); + op_output_types.push_back(allocated_dense_tensor_dtype); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "builtin.split Result type only support DenseTensorType")); } } } diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index b1f0b85015d32..518fe467beda0 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -52,10 +52,8 @@ Tensor mean_grad(const Tensor& x, std::static_pointer_cast(out_grad.impl()) ->getValue() .dyn_cast(); - ir::OpResult op_res = paddle::dialect::mean_grad( x_res, out_grad_res, axis.GetData(), keepdim, reduce_all); - return Tensor(std::make_shared(op_res)); } diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc 
b/paddle/fluid/primitive/rule/vjp/vjp.cc index c1aada25fa80e..e420406db9f23 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -116,13 +116,11 @@ std::vector> concat_vjp( const Tensor& out_grad, const Tensor& axis, const std::vector>& stop_gradients) { - std::vector> vjp_res( - 1, std::vector(1)); + std::vector> vjp_res(1, std::vector()); // get concat_grad res. std::vector op_res = backend::experimental::concat_grad( x, out_grad, axis); - // set op stop_gradient info // TODO(wanghao107): Replace with more generic code. // Support set stop_gradients for all ops. @@ -146,9 +144,9 @@ std::vector> concat_vjp( grad_op->set_attribute( "stop_gradient", ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients)); - // construct vjp result by op result and stop_gradients info - for (auto idx = 0; idx <= op_res[0].size(); idx++) { + vjp_res[0] = std::vector(op_res.size()); + for (uint64_t idx = 0; idx < op_res.size(); idx++) { if (!stop_gradients[0][idx]) { vjp_res[0][idx] = op_res[idx]; } diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index 3284a96c8b519..375bf90d2b8fd 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -55,6 +55,7 @@ void BuiltinDialect::initialize() { SetParameterOp, CombineOp, SliceOp, + SplitOp, ConstantOp>(); } diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index bb2acd9606399..cee89a04f6104 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -207,7 +207,35 @@ void SliceOp::Verify() const { output_type); } -void SliceOp::Build(Builder &builder, +void SplitOp::Verify() const { + // inputs.size() == 1 + IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1."); + + // input_type == Vector + auto input_type = (*this)->operand(0).type().dyn_cast(); + IR_ENFORCE(input_type, "The type of inputs[0] must be equal to VectorType."); + + // inputs[0].size() == outputs.size() + auto output_num = num_results(); + IR_ENFORCE(input_type.size() == output_num, + "The size %d of output must be equal to size %d of inputs.", + output_num, + input_type.size()); + + // for all i in outputs.size(): outputs[i].type == inputs[0][i].type + for (size_t i = 0; i < output_num; ++i) { + auto type = (*this)->result(i).type(); + IR_ENFORCE(input_type[i] == type, + "The type %s of inputs[0][%d] must be " + "equal to type %s of outputs[%d].", + input_type[i], + i, + type, + i); + } +} + +void SplitOp::Build(Builder &builder, OperationArgument &argument, const ir::OpResult &input) { argument.inputs = {input}; @@ -244,5 +272,6 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::GetParameterOp) IR_DEFINE_EXPLICIT_TYPE_ID(ir::SetParameterOp) IR_DEFINE_EXPLICIT_TYPE_ID(ir::CombineOp) IR_DEFINE_EXPLICIT_TYPE_ID(ir::SliceOp) +IR_DEFINE_EXPLICIT_TYPE_ID(ir::SplitOp) IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantOp) diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index b840fd9cf0e98..ae1b748bd339d 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -116,6 +116,27 @@ class IR_API SliceOp : public ir::Op { static const char *attributes_name[attributes_num]; + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const ir::OpResult &input); + + void Verify() const; + ir::Value input() { return operand_source(0); } +}; + +/// +/// \brief SplitOp: SplitOp(OpOperand) +/// +class IR_API SplitOp : public ir::Op { + 
public: + using Op::Op; + + static const char *name() { return "builtin.split"; } + + static constexpr uint32_t attributes_num = 0; + + static constexpr const char **attributes_name = nullptr; + static void Build(Builder &builder, // NOLINT OperationArgument &argument, // NOLINT const ir::OpResult &input); @@ -165,5 +186,6 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::GetParameterOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SetParameterOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::CombineOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SliceOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SplitOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ConstantOp) diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index 9eb865a579765..4821df76fd284 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -204,5 +204,66 @@ TEST(VJP, MeanBackwardTest) { ASSERT_EQ(grad_out_tensor.data()[3], 0.25); } +TEST(VJP, ConcatBackwardTest) { + ir::IrContext* ctx = ir::IrContext::Instance(); + ir::Program program((ctx)); + paddle::dialect::APIBuilder::Instance().SetProgram(&program); + + std::shared_ptr builder = + paddle::dialect::APIBuilder::Instance().GetBuilder(); + paddle::dialect::FullOp op1 = builder->Build( + std::vector{1, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace()); + std::vector combine_input{{op1.out(), op1.out()}}; + ir::CombineOp op2 = builder->Build(combine_input); + paddle::dialect::ConcatOp op3 = + builder->Build(op2.out(), 0); + + paddle::dialect::FullOp op4 = builder->Build( + std::vector{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); + std::vector> stop_gradients{{false, false}}; + std::vector> out_grads{{op4.out()}}; + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.concat"); + auto concat_vjp_interface_impl = + op2_info.GetInterfaceImpl(); + concat_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); + + auto place = platform::CPUPlace(); + Scope scope; + ProgramDesc prog_desc; + InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + std::stringstream os; + os << reinterpret_cast( + const_cast(test_core.Impl())); + std::string prefix_str = os.str(); + test_core.SetSkipGcVars({prefix_str + "_inner_var_3", + prefix_str + "_inner_var_7", + prefix_str + "_inner_var_8"}); + test_core.Run({}); + auto out_tensor = + test_core.local_scope() == nullptr + ? scope.FindVar(prefix_str + "_inner_var_3")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_3") + ->Get(); + auto grad_out_tensor_0 = + test_core.local_scope() == nullptr + ? scope.FindVar(prefix_str + "_inner_var_7")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_7") + ->Get(); + auto grad_out_tensor_1 = + test_core.local_scope() == nullptr + ? 
scope.FindVar(prefix_str + "_inner_var_8")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_8") + ->Get(); + ASSERT_EQ(out_tensor.data()[0], 2.0); + ASSERT_EQ(grad_out_tensor_0.data()[0], 1.0); + ASSERT_EQ(grad_out_tensor_0.data()[1], 1.0); + ASSERT_EQ(grad_out_tensor_1.data()[0], 1.0); + ASSERT_EQ(grad_out_tensor_1.data()[1], 1.0); +} + } // namespace framework } // namespace paddle From ad8ea1630cab1997f1978d5a4b2bbb062f2112ae Mon Sep 17 00:00:00 2001 From: wangruting Date: Thu, 17 Aug 2023 12:20:19 +0000 Subject: [PATCH 17/30] fix conflict --- paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 26 +++++++++---------- .../fluid/primitive/backend/static_backend.cc | 10 +++---- paddle/fluid/primitive/rule/vjp/vjp.cc | 9 ++----- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index 5223036a33c3a..e5af670822f10 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -36,26 +36,24 @@ std::vector> ConcatOp::Vjp( op_obj.x().GetDefiningOp()->dyn_cast(); std::vector x; for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) { - x.emplace_back(std::make_shared( - combine_op_obj.inputs()[idx])); + x.emplace_back( + std::make_shared(combine_op_obj.inputs()[idx])); } - Tensor out_grad( - std::make_shared(out_grads[0][0])); - Tensor axis( - std::make_shared(op_obj.axis())); + Tensor out_grad(std::make_shared(out_grads[0][0])); + Tensor axis(std::make_shared(op_obj.axis())); + std::vector> tensor_res = - primitive::experimental::concat_vjp(x, out_grad, axis, stop_gradients); + primitive::concat_vjp(x, out_grad, axis, stop_gradients); + std::vector> res(1, std::vector()); - std::cout << "ConcatOp::Vjp called 4" << std::endl; - res[0] = std::vector(tensor_res[0].size()); + res[0].resize(tensor_res[0].size()); for (uint64_t idx = 0; idx < tensor_res[0].size(); idx++) { if (tensor_res[0][idx].defined()) { - res[0][idx] = - std::static_pointer_cast( - tensor_res[0][idx].impl()) - ->getValue() - .dyn_cast(); + res[0][idx] = std::static_pointer_cast( + tensor_res[0][idx].impl()) + ->getValue() + .dyn_cast(); } } return res; diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index 13c45b088c0b6..3bb9c616a7819 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -244,22 +244,22 @@ Tensor sum_grad(const Tensor& x, } template <> -std::vector concat_grad(const std::vector& x, +std::vector concat_grad(const std::vector& x, const Tensor& out_grad, const Tensor& axis) { std::vector x_res; for (uint64_t idx = 0; idx < x.size(); idx++) { - x_res.emplace_back(std::static_pointer_cast(x[idx].impl()) + x_res.emplace_back(std::static_pointer_cast(x[idx].impl()) ->getValue() .dyn_cast()); } ir::OpResult out_grad_res = - std::static_pointer_cast(out_grad.impl()) + std::static_pointer_cast(out_grad.impl()) ->getValue() .dyn_cast(); - ir::OpResult axis_res = std::static_pointer_cast(axis.impl()) + ir::OpResult axis_res = std::static_pointer_cast(axis.impl()) ->getValue() .dyn_cast(); @@ -269,7 +269,7 @@ std::vector concat_grad(const std::vector& x, std::vector op_result; for (uint64_t idx = 0; idx < op_res.size(); idx++) { op_result.emplace_back( - std::make_shared(op_res[idx])); + std::make_shared(op_res[idx])); } return op_result; } diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc 
b/paddle/fluid/primitive/rule/vjp/vjp.cc index 96dcfb0cf454e..308da083c15d7 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -156,15 +156,10 @@ std::vector> concat_vjp( std::vector> vjp_res(1, std::vector()); // get concat_grad res. std::vector op_res = - backend::experimental::concat_grad( - x, out_grad, axis); + backend::concat_grad(x, out_grad, axis); - // set op stop_gradient info - // TODO(wanghao107): Replace with more generic code. - // Support set stop_gradients for all ops. ir::Operation* grad_op = - std::static_pointer_cast( - op_res[0].impl()) + std::static_pointer_cast(op_res[0].impl()) ->getValue() .dyn_cast() .owner(); From 011c611a8881b46c7cb901b20cd8b9faee000c0b Mon Sep 17 00:00:00 2001 From: wangruting Date: Fri, 18 Aug 2023 06:34:21 +0000 Subject: [PATCH 18/30] add concat python test --- paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 16 ++-- paddle/fluid/primitive/rule/vjp/vjp.cc | 25 +----- paddle/fluid/pybind/ir.cc | 10 ++- python/paddle/autograd/backward.py | 97 +++++++++------------ test/ir/new_ir/test_ir_backward.py | 38 +++++++- 5 files changed, 96 insertions(+), 90 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index e5af670822f10..a68d0ee505816 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -45,15 +45,17 @@ std::vector> ConcatOp::Vjp( std::vector> tensor_res = primitive::concat_vjp(x, out_grad, axis, stop_gradients); - - std::vector> res(1, std::vector()); - res[0].resize(tensor_res[0].size()); - for (uint64_t idx = 0; idx < tensor_res[0].size(); idx++) { - if (tensor_res[0][idx].defined()) { - res[0][idx] = std::static_pointer_cast( - tensor_res[0][idx].impl()) + std::vector> res(tensor_res.size(), + std::vector()); + for (uint64_t i = 0; i < tensor_res.size(); i++) { + res[i].resize(tensor_res[i].size()); + for (uint64_t j = 0; j < tensor_res[i].size(); j++) { + if (tensor_res[i][j].defined()) { + res[i][j] = std::static_pointer_cast( + tensor_res[i][j].impl()) ->getValue() .dyn_cast(); + } } } return res; diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index 308da083c15d7..1ea264ad29101 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -153,37 +153,20 @@ std::vector> concat_vjp( const Tensor& out_grad, const Tensor& axis, const std::vector>& stop_gradients) { - std::vector> vjp_res(1, std::vector()); + std::vector> vjp_res(2, std::vector()); // get concat_grad res. 
std::vector op_res = backend::concat_grad(x, out_grad, axis); - ir::Operation* grad_op = - std::static_pointer_cast(op_res[0].impl()) - ->getValue() - .dyn_cast() - .owner(); - uint32_t num_res = grad_op->num_results(); - std::vector ir_stop_gradients(num_res); - for (size_t i = 0; i < num_res; i++) { - if (stop_gradients[0][i]) { - ir_stop_gradients[i] = - ir::BoolAttribute::get(ir::IrContext::Instance(), true); - } else { - ir_stop_gradients[i] = - ir::BoolAttribute::get(ir::IrContext::Instance(), false); - } - } - grad_op->set_attribute( - "stop_gradient", - ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients)); // construct vjp result by op result and stop_gradients info - vjp_res[0] = std::vector(op_res.size()); + vjp_res[0].resize(op_res.size()); for (uint64_t idx = 0; idx < op_res.size(); idx++) { if (!stop_gradients[0][idx]) { vjp_res[0][idx] = op_res[idx]; } } + // vjp_res[1] is axis's grad which is attribute (no grad). + vjp_res[1].resize(1); return vjp_res; } diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index a6da23bc78e0f..05295243b124c 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -249,6 +249,7 @@ void BindValue(py::module *m) { .def("get_defining_op", &Value::GetDefiningOp, return_value_policy::reference) + .def("first_use", &Value::first_use, return_value_policy::reference) .def("__eq__", &Value::operator==) .def("__eq__", [](Value &self, OpResult &other) { @@ -272,9 +273,11 @@ void BindOpOperand(py::module *m) { op_operand .def("source", [](OpOperand &self) { return self.source().dyn_cast(); }) - .def("set_source", [](OpOperand &self, const OpResult &result) { - self.set_source(result); - }); + .def("set_source", + [](OpOperand &self, const OpResult &result) { + self.set_source(result); + }) + .def("owner", &OpOperand::owner, return_value_policy::reference); } bool GetStopGradient(const OpResult &self) { @@ -331,6 +334,7 @@ void BindOpResult(py::module *m) { .def("get_defining_op", &OpResult::GetDefiningOp, return_value_policy::reference) + .def("first_use", &OpResult::first_use, return_value_policy::reference) .def("use_empty", &OpResult::use_empty) .def("type", &OpResult::type) .def_property( diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py index 671182f7c3040..67fafbae389c5 100644 --- a/python/paddle/autograd/backward.py +++ b/python/paddle/autograd/backward.py @@ -216,11 +216,11 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): return effective_ops, uneffective_ops -def update_no_grad_set_after_purne( +def update_no_grad_set_after_prune( block, effective_forward_op, no_grad_set, inputs, outputs ): ''' - update no_grad_set after forward purne + update no_grad_set after forward prune from inputs to outputs add value not in the path to no_grad_set, from outputs to inputs add value not in the path to no_grad_set, @@ -338,19 +338,19 @@ def append_backward_ops( else continue to next op. 
''' - def make_output_grad(op, split_op): + def make_output_grad(op): zero_flag = [False] * op.num_results() for i, value in enumerate(op.results()): if ( value not in state.value_to_valuegrad or state.value_to_valuegrad[value] is None ): - if split_op is not None and value == split_op.operand_source(0): + if value.first_use().owner().name() == "builtin.split": # pattern case: # this fwd_op's output is vectorType, it will split to # Type by builtin.split op, so need get from split op's ouput split_zero_flag, split_output_grad = make_output_grad( - split_op, None + value.first_use().owner() ) zero_flag[i] = all(split_zero_flag) grad_value = [op_list[0] for op_list in split_output_grad] @@ -400,11 +400,11 @@ def make_output_grad(op, split_op): output_grad = state.value_to_valuegrad[value][0] return zero_flag, output_grad - def make_input_stopgradient(combine_op, op): + def make_input_stopgradient(op): input_grad_stopgradient_list = [] for input in op.operands_source(): - if combine_op is not None and input == combine_op.result(0): - stop_gradient = make_input_stopgradient(None, combine_op) + if input.get_defining_op().name() == "builtin.combine": + stop_gradient = make_input_stopgradient(input.get_defining_op()) input_grad_stopgradient_list.append( [info[0] for info in stop_gradient] ) @@ -413,13 +413,14 @@ def make_input_stopgradient(combine_op, op): input_grad_stopgradient_list.append([True]) else: input_grad_stopgradient_list.append([False]) - return input_grad_stopgradient_list - def update_input_grad_map(combine_op, op, input_grad_list): + def update_input_grad_map(op, input_grad_list): for i, input in enumerate(op.operands_source()): - if combine_op is not None and input == combine_op.reslut(0): - update_input_grad_map(None, combine_op, input_grad_list[i]) + if input.get_defining_op().name() == "builtin.combine": + update_input_grad_map( + input.get_defining_op(), input_grad_list[i] + ) else: input_grad = input_grad_list[i] if isinstance(input_grad, list): @@ -427,48 +428,24 @@ def update_input_grad_map(combine_op, op, input_grad_list): else: state.value_to_valuegrad[input].append([input_grad]) - # make op to op pattern, there are four patterns: + # there are four patterns: # [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType) # [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType) # [builtin.combine , op3 , buitin.split] (op3's one input and one output are vectorType) # [op4] (op4's inputs and outputs are not vectorType) # einsum has twp vectorType outputs, special pattern - pattern_effective_op_list = [] - for idx, op in enumerate(effective_forward_op): - if op.name() == "builtin.combine": - pattern_effective_op_list.append([op]) - pattern_effective_op_list[-1].append(effective_forward_op[idx + 1]) - elif op.name() == "builtin.split": - pattern_effective_op_list[-1].append(op) - else: - if ( - not pattern_effective_op_list - or op not in pattern_effective_op_list[-1] - ): - pattern_effective_op_list.append([op]) - - for op_pattern in pattern_effective_op_list: - combine_op = None - split_op = None - if len(op_pattern) == 1: - op = op_pattern[0] - elif len(op_pattern) == 2: - if op_pattern[0] == 'builtin.combine': - combine_op = op_pattern[0] - op = op_pattern[1] - else: - op = op_pattern[0] - split_op = op_pattern[1] - else: - combine_op = op_pattern[0] - op = op_pattern[1] - split_op = op_pattern[2] + clear_effective_forward_op = [] + + for op in effective_forward_op: + if op.name() != "builtin.combine" and op.name() != 
"builtin.split": + clear_effective_forward_op.append(op) + for op in clear_effective_forward_op: if paddle.framework.core.has_vjp(op): # prepare output_grad output_grad_list = [] # (opresult) - zero_flag, output_grad = make_output_grad(op, split_op) + zero_flag, output_grad = make_output_grad(op) output_grad_list.append(output_grad) # all(zero_flag) support this op has no contribution for grad @@ -477,9 +454,7 @@ def update_input_grad_map(combine_op, op, input_grad_list): continue # prepare input_grad stop_gradient info. - input_grad_stopgradient_list = make_input_stopgradient( - combine_op, op - ) + input_grad_stopgradient_list = make_input_stopgradient(op) # create grad_op before_ops_num = len(block.ops) @@ -495,7 +470,7 @@ def update_input_grad_map(combine_op, op, input_grad_list): ) # update input_grad map - update_input_grad_map(combine_op, op, input_grad_list) + update_input_grad_map(op, input_grad_list) else: if op.num_operands() == 0 and op.num_results() != 0: @@ -526,17 +501,23 @@ def update_input_grad_map(combine_op, op, input_grad_list): state.op_to_opgrad[op] = [] -def create_backward_purne_set(inputs, outputs, no_grad_set, state): +def create_backward_prune_set(inputs, outputs, no_grad_set, state): outputs_set = set() for input in inputs: - if state.value_to_valuegrad[input] != []: - outputs_set.add(state.value_to_valuegrad[input][0][0]) - + for item in input.first_use().owner().operands_source(): + if state.value_to_valuegrad[item] != []: + outputs_set.add(state.value_to_valuegrad[item][0][0]) inputs_set = set() for output in outputs: if state.value_to_valuegrad[output] != []: inputs_set.add(state.value_to_valuegrad[output][0][0]) + inputs_set_tmp = set() + for out_grad in inputs_set: + for item in out_grad.first_use().owner().operands_source(): + inputs_set_tmp.add(item) + inputs_set.update(inputs_set_tmp) + no_gradvar_set = set() # grad_value of value in no_grad_set for key in state.value_to_valuegrad: if key in no_grad_set: @@ -590,31 +571,31 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): effective_forward_op, _ = prune_ops( block.ops, inputs_set, outputs_set, no_grad_set ) - update_no_grad_set_after_purne( + update_no_grad_set_after_prune( block, effective_forward_op, no_grad_set, inputs, complete_outputs ) - sorted_effective_forward_op = inverse_sort_op(effective_forward_op) + inverse_effective_forward_op = inverse_sort_op(effective_forward_op) append_backward_ops( - block, sorted_effective_forward_op, no_grad_set, backward_ops, state + block, inverse_effective_forward_op, no_grad_set, backward_ops, state ) # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) - outputs_set, inputs_set, no_gradvar_set = create_backward_purne_set( + outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( inputs, complete_outputs, no_grad_set, state ) _, remove_ops = prune_ops( backward_ops, inputs_set, outputs_set, no_gradvar_set ) - state.turn_map() + state.turn_map() for bwd_op in inverse_sort_op(remove_ops): remove_op(block, bwd_op, state) + state.turn_map() input_grad_map = state.value_to_valuegrad - state.turn_map() return input_grad_map diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py index e6b47bbcd106e..63e5bdbc9e4c7 100644 --- a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/new_ir/test_ir_backward.py @@ -116,7 +116,6 @@ def get_ir_program_1(): class TesBackward_2(unittest.TestCase): def test_add_n(self): - # test add_n op newir_program = get_ir_program_1() 
input_x = newir_program.block().ops[-3].operand(0).source() @@ -130,6 +129,43 @@ def test_add_n(self): self.assertEqual( newir_program.block().ops[-2].name(), "builtin.combine" ) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + + def test_concat(self): + newir_program = get_ir_program_1() + input_x = newir_program.block().ops[-3].operand(0).source() + + add_out = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + out = paddle.concat([add_out, add_out]) + input_grad = grad(out, input_x) + + ops_name = [ + "pd.data", + "pd.data", + "pd.tanh", + "pd.tanh", + "pd.add", + "builtin.combine", + "pd.full", + "pd.concat", + "pd.full", + "builtin.combine", + "pd.concat_grad", + "builtin.split", + "builtin.combine", + "pd.add_n", + "pd.add_grad", + "pd.tanh_grad", + "pd.tanh_grad", + "builtin.combine", + "pd.add_n", + ] + for i, op in enumerate(newir_program.block().ops): + self.assertEqual(op.name(), ops_name[i]) + + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) if __name__ == "__main__": From add6d90a3e27c61d513e5415272590f4776dda74 Mon Sep 17 00:00:00 2001 From: wangruting Date: Mon, 21 Aug 2023 06:15:12 +0000 Subject: [PATCH 19/30] add split python api to vjp --- paddle/fluid/ir/dialect/pd_manual_api.cc | 12 +++ paddle/fluid/ir/dialect/pd_manual_api.h | 3 + paddle/fluid/ir/dialect/pd_manual_op.cc | 93 +++++++++++++++++++ paddle/fluid/ir/dialect/pd_manual_op.h | 18 ++++ paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 33 +++++++ .../fluid/primitive/backend/static_backend.cc | 20 ++++ .../fluid/primitive/backend/static_backend.h | 3 + paddle/fluid/primitive/rule/vjp/vjp.cc | 22 +++++ paddle/fluid/primitive/rule/vjp/vjp.h | 4 + paddle/fluid/pybind/ops_api.cc | 10 +- paddle/fluid/pybind/static_op_function.cc | 24 +++++ paddle/fluid/pybind/static_op_function.h | 1 + 12 files changed, 242 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/ir/dialect/pd_manual_api.cc b/paddle/fluid/ir/dialect/pd_manual_api.cc index 10e25ea883b91..9b451e2074ab4 100644 --- a/paddle/fluid/ir/dialect/pd_manual_api.cc +++ b/paddle/fluid/ir/dialect/pd_manual_api.cc @@ -32,5 +32,17 @@ std::vector concat_grad(std::vector x, concat_grad_op.result(0)); return split_op.outputs(); } + +ir::OpResult split_grad(std::vector out_grads, + ir::OpResult axis) { + auto combine_op = + APIBuilder::Instance().GetBuilder()->Build(out_grads); + paddle::dialect::SplitGradOp split_grad_op = + APIBuilder::Instance().GetBuilder()->Build( + combine_op.out(), axis); + + return split_grad_op.out(); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_manual_api.h b/paddle/fluid/ir/dialect/pd_manual_api.h index dff38ef565cb2..0a448a037ecae 100644 --- a/paddle/fluid/ir/dialect/pd_manual_api.h +++ b/paddle/fluid/ir/dialect/pd_manual_api.h @@ -26,5 +26,8 @@ namespace dialect { std::vector concat_grad(std::vector x, ir::OpResult out_grad, ir::OpResult axis); + +ir::OpResult split_grad(std::vector out_grads, ir::OpResult axis); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_manual_op.cc b/paddle/fluid/ir/dialect/pd_manual_op.cc index 7f95a6fdf3d3d..68e7075fe168b 100644 --- a/paddle/fluid/ir/dialect/pd_manual_op.cc +++ b/paddle/fluid/ir/dialect/pd_manual_op.cc @@ -145,6 +145,99 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +OpInfoTuple SplitGradOp::GetOpInfo() { + std::vector inputs = { + 
OpInputInfo("out_grad", + "ir::VectorType", + false, + false, + false), + OpInputInfo( + "axis", "paddle::dialect::ScalarAttribute", false, false, true)}; + std::vector attributes = {}; + std::vector outputs = { + OpOutputInfo("x_grad", "paddle::dialect::DenseTensorType", false, false)}; + paddle::dialect::OpRunTimeInfo run_time_info = + OpRunTimeInfo("ConcatInferMeta", + {"out_grad", "axis"}, + {"concat"}, + {"out_grad", "axis"}, + {"out_grad"}, + {}, + {}); + + return std::make_tuple(inputs, attributes, outputs, run_time_info); +} + +void SplitGradOp::Build(ir::Builder &builder, + ir::OperationArgument &argument, + ir::OpResult out_grad_, + ir::OpResult axis_) { + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {out_grad_, axis_}; + argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + + VLOG(4) << "Builder construction attributes"; + + VLOG(4) << "Builder construction outputs"; + ir::VectorType out_grad = out_grad_.type().dyn_cast(); + (void)out_grad; + int axis = axis_.owner() + ->dyn_cast() + .attributes() + .at("value") + .dyn_cast() + .data() + .to(); + (void)axis; + + std::vector vec_dense_out_grad; + for (size_t i = 0; i < static_cast(out_grad.size()); i++) { + vec_dense_out_grad.push_back(phi::DenseTensor( + std::make_unique( + paddle::platform::CPUPlace()) + .get(), + phi::DenseTensorMeta( + TransToPhiDataType(out_grad[i] + .dyn_cast() + .dtype()), + out_grad[i].dyn_cast().dims(), + out_grad[i] + .dyn_cast() + .data_layout(), + out_grad[i].dyn_cast().lod(), + out_grad[i] + .dyn_cast() + .offset()))); + } + std::vector vec_meta_out_grad; + for (size_t i = 0; i < vec_dense_out_grad.size(); i++) { + vec_meta_out_grad.push_back(phi::MetaTensor(&vec_dense_out_grad[i])); + } + + std::vector meta_out_grad; + for (size_t i = 0; i < static_cast(vec_meta_out_grad.size()); i++) { + meta_out_grad.push_back(&vec_meta_out_grad[i]); + } + phi::DenseTensor dense_x_grad; + phi::MetaTensor meta_out(&dense_x_grad); + + phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_gard); + + std::vector argument_outputs; + ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + ir::IrContext::Instance(), + TransToIrDataType(dense_x_grad.dtype()), + dense_x_grad.dims(), + dense_x_grad.layout(), + dense_x_grad.lod(), + dense_x_grad.offset()); + argument_outputs.push_back(x_grad_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); +} + +void SplitGradOp::Verify() {} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_manual_op.h b/paddle/fluid/ir/dialect/pd_manual_op.h index ff055ea6edf8a..0f41ca572e70e 100644 --- a/paddle/fluid/ir/dialect/pd_manual_op.h +++ b/paddle/fluid/ir/dialect/pd_manual_op.h @@ -51,6 +51,24 @@ class AddNOp : public ir::Op { static void InferMeta(phi::InferMetaContext *infer_meta); }; +class SplitGradOp : public ir::Op { + public: + using Op::Op; + static const char *name() { return "pd.split_grad"; } + static const char *attributes_name[1]; + static constexpr uint32_t attributes_num = 1; + static OpInfoTuple GetOpInfo(); + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult out_grad_, + ir::OpResult axis_); + + void Verify(); + ir::Value out_grad() { return operand_source(0); } + ir::Value axis() { return operand_source(1); } + ir::OpResult x_grad() { return result(0); } +}; + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc 
b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index a68d0ee505816..426bc2d8bf8ed 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -87,5 +87,38 @@ std::vector> SumOp::Vjp( } return res; } + +std::vector> SplitOp::Vjp( + ir::Operation* op, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + SplitOp op_obj = op->dyn_cast(); + + Tensor axis(std::make_shared(op_obj.axis())); + std::vector out_grads_; + for (size_t idx = 0; idx < out_grads().size(); idx++) { + out_grads_.emplace_back( + std::make_shared(out_grads[0][idx])); + } + + std::vector> tensor_res = + primitive::split_vjp(out_grads_, axis, stop_gradients); + + std::vector> res(tensor_res.size(), + std::vector()); + for (uint64_t i = 0; i < tensor_res.size(); i++) { + res[i].resize(tensor_res[i].size()); + for (uint64_t j = 0; j < tensor_res[i].size(); j++) { + if (tensor_res[i][j].defined()) { + res[i][j] = std::static_pointer_cast( + tensor_res[i][j].impl()) + ->getValue() + .dyn_cast(); + } + } + } + return res; +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index 3bb9c616a7819..afa164752878c 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -274,6 +274,26 @@ std::vector concat_grad(const std::vector& x, return op_result; } +template <> +Tensor split_grad(const std::vector& out_grads, + const Tensor& axis) { + std::vector out_grads_res; + for (uint64_t idx = 0; idx < out_grads.size(); idx++) { + out_grads_res.emplace_back( + std::static_pointer_cast(out_grads[idx].impl()) + ->getValue() + .dyn_cast()); + } + + ir::OpResult axis_res = std::static_pointer_cast(axis.impl()) + ->getValue() + .dyn_cast(); + + ir::OpResult op_res = paddle::dialect::split_grad(out_grads_res, axis_res); + + return Tensor(std::make_shared(op_res)); +} + } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index ba608a01eeab3..b52f12a379872 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -101,6 +101,9 @@ Tensor sum_grad(const Tensor& x, bool keepdim, bool reduce_all); +template +Tensor split_grad(const std::vector& out_grads, const Tensor& axis); + } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index 364c590a6edbb..08a97cce8bda5 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -221,5 +221,27 @@ std::vector> sum_vjp( return vjp_res; } +std::vector> split_vjp( + const std::vector& out_grads, + const Tensor& axis, + const std::vector>& stop_gradients) { + std::vector> vjp_res(3, std::vector()); + // get concat_grad res. + std::vector op_res = + backend::split_grad(out_grads, axis); + + // construct vjp result by op result and stop_gradients info + vjp_res[0].resize(op_res.size()); + for (uint64_t idx = 0; idx < op_res.size(); idx++) { + if (!stop_gradients[0][idx]) { + vjp_res[0][idx] = op_res[idx]; + } + } + // vjp_res[1] is sections's grad which is attribute (no grad). + // vjp_res[2] is axis's grad which is attribute (no grad). 
+ vjp_res[1].resize(stop_gradients[1].size()) vjp_res[2].resize( + stop_gradients[2].szie()); + return vjp_res; +} } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/vjp.h b/paddle/fluid/primitive/rule/vjp/vjp.h index eace3d3cb5bdf..72df75814ce13 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.h +++ b/paddle/fluid/primitive/rule/vjp/vjp.h @@ -67,5 +67,9 @@ std::vector> sum_vjp( bool reduce_all, const std::vector>& stop_gradients); +std::vector> split_vjp( + const std::vector& out_grads, + const Tensor& axis, + const std::vector>& stop_gradients); } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/pybind/ops_api.cc b/paddle/fluid/pybind/ops_api.cc index 27cbf36aecbcb..9dc9d8752b6a3 100644 --- a/paddle/fluid/pybind/ops_api.cc +++ b/paddle/fluid/pybind/ops_api.cc @@ -44,7 +44,11 @@ static PyObject *concat(PyObject *self, PyObject *args, PyObject *kwargs) { return static_api_concat(self, args, kwargs); } -static PyMethodDef OpsAPI[] = {{"add_n", +static PyObject *split(PyObject *self, PyObject *args, PyObject *kwargs) { + return static_api_split(self, args, kwargs); +} + +static PyMethodDef OpsAPI[] = {{"add_n", // NOLINT (PyCFunction)(void (*)(void))add_n, METH_VARARGS | METH_KEYWORDS, "C++ interface function for add_n."}, @@ -68,6 +72,10 @@ static PyMethodDef OpsAPI[] = {{"add_n", (PyCFunction)(void (*)(void))full, METH_VARARGS | METH_KEYWORDS, "C++ interface function for full."}, + {"split", + (PyCFunction)(void (*)(void))full, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for full."}, {nullptr, nullptr, 0, nullptr}}; void BindOpsAPI(pybind11::module *module) { diff --git a/paddle/fluid/pybind/static_op_function.cc b/paddle/fluid/pybind/static_op_function.cc index 632f7044c4617..5081cd66597f1 100644 --- a/paddle/fluid/pybind/static_op_function.cc +++ b/paddle/fluid/pybind/static_op_function.cc @@ -155,5 +155,29 @@ PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) { } } +PyObject *static_api_split(PyObject *self, PyObject *args, PyObject *kwargs) { + try { + VLOG(6) << "Add split op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + // Get OpResult from args + PyObject *x_obj = PyTuple_GET_ITEM(args, 0); + auto x = CastPyArg2OpResult("split", x_obj, 0); + + PyObject *sections_obj = PyTuple_GET_ITEM(args, 1); + auto sections = CastPyArg2IntArray(sections_obj, "split", 1); + + PyObject *axis_obj = PyTuple_GET_ITEM(args, 2); + paddle::experimental::Scalar axis = CastPyArg2Scalar(axis_obj, "split", 2); + + // Call ir static api + auto out = paddle::dialect::split( + x, sections.to < std::vector(), axis.to()); + return ToPyObject(out); + } catch (...) 
{ + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/static_op_function.h b/paddle/fluid/pybind/static_op_function.h index 02d4777eeef05..5149267ec5603 100644 --- a/paddle/fluid/pybind/static_op_function.h +++ b/paddle/fluid/pybind/static_op_function.h @@ -30,6 +30,7 @@ PyObject *static_api_sum(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs); PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs); +PyObject *static_api_split(PyObject *self, PyObject *args, PyObject *kwargs); } // namespace pybind } // namespace paddle From 4ef0db162359fdea27f4314ea3f29a05ebe095b9 Mon Sep 17 00:00:00 2001 From: wangruting Date: Mon, 21 Aug 2023 12:08:30 +0000 Subject: [PATCH 20/30] modify build bug --- .../fluid/ir/dialect/op_generator/op_gen.py | 2 +- .../op_generator/vjp_interface_gen_op_list.py | 1 + paddle/fluid/ir/dialect/pd_dialect.cc | 5 +- paddle/fluid/ir/dialect/pd_manual_api.cc | 13 +++- paddle/fluid/ir/dialect/pd_manual_api.h | 3 + paddle/fluid/ir/dialect/pd_manual_op.cc | 6 +- paddle/fluid/ir/dialect/pd_manual_op.h | 3 +- paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 2 +- paddle/fluid/primitive/rule/vjp/vjp.cc | 15 ++-- paddle/fluid/pybind/eager_utils.cc | 10 +++ paddle/fluid/pybind/eager_utils.h | 2 +- paddle/fluid/pybind/static_op_function.cc | 3 +- test/cpp/prim/test_vjp.cc | 69 +++++++++++++++++++ 13 files changed, 115 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 423b47ae44ed0..8e37e099c0a89 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -154,7 +154,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ 'bool': 'ir::BoolAttribute', } -_NO_NEED_GEN_OPS = {'add_n'} +_NO_NEED_GEN_OPS = {'add_n', 'split_grad'} def to_phi_and_fluid_op_name(op_item): diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py index 8077801bf235f..af1bf4c948e36 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -29,5 +29,6 @@ "sum", "add", "concat", + "split", ] vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"] diff --git a/paddle/fluid/ir/dialect/pd_dialect.cc b/paddle/fluid/ir/dialect/pd_dialect.cc index 5ebdf7611a8dc..2082dbca65231 100644 --- a/paddle/fluid/ir/dialect/pd_dialect.cc +++ b/paddle/fluid/ir/dialect/pd_dialect.cc @@ -108,7 +108,10 @@ void PaddleDialect::initialize() { #define GET_OP_LIST #include "paddle/fluid/ir/dialect/pd_op.h" // NOLINT >(); - RegisterOp(); + RegisterOps< +#define GET_MANUAL_OP_LIST +#include "paddle/fluid/ir/dialect/pd_manual_op.h" // NOLINT + >(); RegisterInterfaces(); } diff --git a/paddle/fluid/ir/dialect/pd_manual_api.cc b/paddle/fluid/ir/dialect/pd_manual_api.cc index 9b451e2074ab4..0c1d2fed677a0 100644 --- a/paddle/fluid/ir/dialect/pd_manual_api.cc +++ b/paddle/fluid/ir/dialect/pd_manual_api.cc @@ -20,6 +20,17 @@ namespace paddle { namespace dialect { +std::vector split(ir::OpResult x, + const std::vector& sections, + int axis) { + paddle::dialect::SplitOp pd_split_op = + 
APIBuilder::Instance().GetBuilder()->Build( + x, sections, axis); + auto builtin_split_op = + APIBuilder::Instance().GetBuilder()->Build( + pd_split_op.result(0)); + return builtin_split_op.outputs(); +} std::vector concat_grad(std::vector x, ir::OpResult out_grad, ir::OpResult axis) { @@ -41,7 +52,7 @@ ir::OpResult split_grad(std::vector out_grads, APIBuilder::Instance().GetBuilder()->Build( combine_op.out(), axis); - return split_grad_op.out(); + return split_grad_op.x_grad(); } } // namespace dialect diff --git a/paddle/fluid/ir/dialect/pd_manual_api.h b/paddle/fluid/ir/dialect/pd_manual_api.h index 0a448a037ecae..b952f58f0b03e 100644 --- a/paddle/fluid/ir/dialect/pd_manual_api.h +++ b/paddle/fluid/ir/dialect/pd_manual_api.h @@ -22,6 +22,9 @@ namespace paddle { namespace dialect { +std::vector split(ir::OpResult x, + const std::vector& sections, + int axis); std::vector concat_grad(std::vector x, ir::OpResult out_grad, diff --git a/paddle/fluid/ir/dialect/pd_manual_op.cc b/paddle/fluid/ir/dialect/pd_manual_op.cc index 68e7075fe168b..9eed0f4d78800 100644 --- a/paddle/fluid/ir/dialect/pd_manual_op.cc +++ b/paddle/fluid/ir/dialect/pd_manual_op.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/ir/dialect/pd_manual_op.h" #include "paddle/fluid/ir/dialect/pd_attribute.h" +#include "paddle/fluid/ir/dialect/pd_op.h" #include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_op.h" @@ -220,9 +221,9 @@ void SplitGradOp::Build(ir::Builder &builder, meta_out_grad.push_back(&vec_meta_out_grad[i]); } phi::DenseTensor dense_x_grad; - phi::MetaTensor meta_out(&dense_x_grad); + phi::MetaTensor meta_x_grad(&dense_x_grad); - phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_gard); + phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad); std::vector argument_outputs; ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( @@ -242,3 +243,4 @@ void SplitGradOp::Verify() {} } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) diff --git a/paddle/fluid/ir/dialect/pd_manual_op.h b/paddle/fluid/ir/dialect/pd_manual_op.h index 0f41ca572e70e..f514f459c9bab 100644 --- a/paddle/fluid/ir/dialect/pd_manual_op.h +++ b/paddle/fluid/ir/dialect/pd_manual_op.h @@ -14,7 +14,7 @@ #ifdef GET_MANUAL_OP_LIST #undef GET_MANUAL_OP_LIST -paddle::dialect::AddNOp +paddle::dialect::AddNOp, paddle::dialect::SplitGradOp #else @@ -73,5 +73,6 @@ class SplitGradOp : public ir::Op { } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) #endif diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index 426bc2d8bf8ed..61b67fa2fe50d 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -96,7 +96,7 @@ std::vector> SplitOp::Vjp( Tensor axis(std::make_shared(op_obj.axis())); std::vector out_grads_; - for (size_t idx = 0; idx < out_grads().size(); idx++) { + for (size_t idx = 0; idx < out_grads.size(); idx++) { out_grads_.emplace_back( std::make_shared(out_grads[0][idx])); } diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index 08a97cce8bda5..5a97d41578448 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -227,20 +227,17 @@ std::vector> split_vjp( const std::vector>& stop_gradients) { std::vector> vjp_res(3, 
std::vector()); // get concat_grad res. - std::vector op_res = - backend::split_grad(out_grads, axis); + Tensor op_res = backend::split_grad(out_grads, axis); // construct vjp result by op result and stop_gradients info - vjp_res[0].resize(op_res.size()); - for (uint64_t idx = 0; idx < op_res.size(); idx++) { - if (!stop_gradients[0][idx]) { - vjp_res[0][idx] = op_res[idx]; - } + if (!stop_gradients[0][0]) { + vjp_res[0][0] = op_res; } + // vjp_res[1] is sections's grad which is attribute (no grad). // vjp_res[2] is axis's grad which is attribute (no grad). - vjp_res[1].resize(stop_gradients[1].size()) vjp_res[2].resize( - stop_gradients[2].szie()); + vjp_res[1].resize(stop_gradients[1].size()); + vjp_res[2].resize(stop_gradients[2].size()); return vjp_res; } } // namespace primitive diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 175ca84cec3bc..3e33590c4da06 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -884,6 +884,16 @@ PyObject* ToPyObject(const ir::OpResult& value) { return obj.ptr(); } +PyObject* ToPyObject(const std::vector& value) { + PyObject* result = PyList_New((Py_ssize_t)value.size()); + + for (size_t i = 0; i < value.size(); i++) { + PyList_SET_ITEM(result, static_cast(i), ToPyObject(value[i])); + } + + return result; +} + #ifdef PADDLE_WITH_DISTRIBUTE PyObject* ToPyObject(const phi::distributed::DistTensor* value) { auto obj = ::pybind11::cast(value, py::return_value_policy::reference); diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 24d8e42a22fef..3663370d03464 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -136,7 +136,7 @@ PyObject* ToPyObject(const paddle::framework::Vocab& value); PyObject* ToPyObject(std::shared_ptr grad_node); PyObject* ToPyObject(const ir::OpResult& value); - +PyObject* ToPyObject(const std::vector& value); class PyTensorHook : public egr::TensorHook { public: explicit PyTensorHook(PyObject* func) : py_func_(func) { diff --git a/paddle/fluid/pybind/static_op_function.cc b/paddle/fluid/pybind/static_op_function.cc index 5081cd66597f1..56f9ca6a2b217 100644 --- a/paddle/fluid/pybind/static_op_function.cc +++ b/paddle/fluid/pybind/static_op_function.cc @@ -170,8 +170,7 @@ PyObject *static_api_split(PyObject *self, PyObject *args, PyObject *kwargs) { paddle::experimental::Scalar axis = CastPyArg2Scalar(axis_obj, "split", 2); // Call ir static api - auto out = paddle::dialect::split( - x, sections.to < std::vector(), axis.to()); + auto out = paddle::dialect::split(x, sections.GetData(), axis.to()); return ToPyObject(out); } catch (...) 
{ ThrowExceptionToPython(std::current_exception()); diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index 7ceb38ffcbfb3..3608b74e04aac 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -395,5 +395,74 @@ TEST(VJP, Add_BackwardTest) { ASSERT_EQ(dx.data()[0], 1.0); ASSERT_EQ(dy.data()[0], 1.0); } + +TEST(VJP, SplitBackwardTest) { + ir::IrContext* ctx = ir::IrContext::Instance(); + ir::Program program((ctx)); + paddle::dialect::APIBuilder::Instance().SetProgram(&program); + + std::shared_ptr builder = + paddle::dialect::APIBuilder::Instance().GetBuilder(); + paddle::dialect::FullOp op1 = builder->Build( + std::vector{2, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::SplitOp op2 = builder->Build( + op1.out(), std::vector{2}, 0); + + ir::SplitOp op3 = builder->Build(op2.out()); + + paddle::dialect::FullOp op4 = builder->Build( + std::vector{1, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); + + std::vector> stop_gradients{{false}, {true}, {true}}; + std::vector> out_grads{{op3.result(0)}, + {op4.out()}}; + ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.split"); + auto concat_vjp_interface_impl = + op2_info.GetInterfaceImpl(); + concat_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); + + auto place = platform::CPUPlace(); + Scope scope; + + ProgramDesc prog_desc; + InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + std::stringstream os; + os << reinterpret_cast( + const_cast(test_core.Impl())); + std::string prefix_str = os.str(); + test_core.SetSkipGcVars({prefix_str + "_inner_var_4", + prefix_str + "_inner_var_5", + prefix_str + "_inner_var_8"}); + test_core.Run({}); + auto out_tensor_0 = + test_core.local_scope() == nullptr + ? scope.FindVar(prefix_str + "_inner_var_4")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_4") + ->Get(); + auto out_tensor_1 = + test_core.local_scope() == nullptr + ? scope.FindVar(prefix_str + "_inner_var_5")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_5") + ->Get(); + auto grad_out_tensor_0 = + test_core.local_scope() == nullptr + ? 
scope.FindVar(prefix_str + "_inner_var_8")->Get() + : test_core.local_scope() + ->FindVar(prefix_str + "_inner_var_8") + ->Get(); + ASSERT_EQ(out_tensor_0.data()[0], 2.0); + ASSERT_EQ(out_tensor_0.data()[1], 2.0); + ASSERT_EQ(out_tensor_1.data()[0], 2.0); + ASSERT_EQ(out_tensor_1.data()[1], 2.0); + ASSERT_EQ(grad_out_tensor_0.data()[0], 1.0); + ASSERT_EQ(grad_out_tensor_0.data()[1], 1.0); + ASSERT_EQ(grad_out_tensor_0.data()[2], 1.0); + ASSERT_EQ(grad_out_tensor_0.data()[3], 1.0); +} + } // namespace framework } // namespace paddle From 453deed2dc7f0faa8dfbfcc2286eb681ec74449d Mon Sep 17 00:00:00 2001 From: wangruting Date: Tue, 22 Aug 2023 11:29:05 +0000 Subject: [PATCH 21/30] modify run bug --- paddle/fluid/ir/dialect/pd_dialect.cc | 6 +----- paddle/fluid/ir/dialect/pd_manual_op.cc | 2 ++ paddle/fluid/ir/dialect/pd_op_vjp_manual.cc | 3 ++- .../fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc | 1 - paddle/fluid/primitive/backend/static_backend.cc | 3 --- paddle/fluid/primitive/rule/vjp/vjp.cc | 2 +- python/paddle/tensor/manipulation.py | 3 +++ test/cpp/prim/test_vjp.cc | 12 ++++++------ test/ir/new_ir/test_ir_backward.py | 13 +++++++++++++ 9 files changed, 28 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_dialect.cc b/paddle/fluid/ir/dialect/pd_dialect.cc index 2082dbca65231..b61e8d4220326 100644 --- a/paddle/fluid/ir/dialect/pd_dialect.cc +++ b/paddle/fluid/ir/dialect/pd_dialect.cc @@ -103,15 +103,11 @@ void PaddleDialect::initialize() { // generated by op_gen.py, see details in // paddle/fluid/ir/dialect/CMakeLists.txt. // NOTE(Ruting)GET_MANUAL_OP_LIST is define in pd_manual_op.h" - // use RegisterOps when list has more than two ops. RegisterOps< #define GET_OP_LIST #include "paddle/fluid/ir/dialect/pd_op.h" // NOLINT >(); - RegisterOps< -#define GET_MANUAL_OP_LIST -#include "paddle/fluid/ir/dialect/pd_manual_op.h" // NOLINT - >(); + RegisterOps(); RegisterInterfaces(); } diff --git a/paddle/fluid/ir/dialect/pd_manual_op.cc b/paddle/fluid/ir/dialect/pd_manual_op.cc index 9eed0f4d78800..74f2296d186ea 100644 --- a/paddle/fluid/ir/dialect/pd_manual_op.cc +++ b/paddle/fluid/ir/dialect/pd_manual_op.cc @@ -146,6 +146,8 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +const char *SplitGradOp::attributes_name[1] = {"axis"}; + OpInfoTuple SplitGradOp::GetOpInfo() { std::vector inputs = { OpInputInfo("out_grad", diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index 61b67fa2fe50d..35cd2887b37f0 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -96,7 +96,7 @@ std::vector> SplitOp::Vjp( Tensor axis(std::make_shared(op_obj.axis())); std::vector out_grads_; - for (size_t idx = 0; idx < out_grads.size(); idx++) { + for (size_t idx = 0; idx < out_grads[0].size(); idx++) { out_grads_.emplace_back( std::make_shared(out_grads[0][idx])); } @@ -106,6 +106,7 @@ std::vector> SplitOp::Vjp( std::vector> res(tensor_res.size(), std::vector()); + for (uint64_t i = 0; i < tensor_res.size(); i++) { res[i].resize(tensor_res[i].size()); for (uint64_t j = 0; j < tensor_res[i].size(); j++) { diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc index cadd9a29519ab..e3ab3c38d05c6 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc @@ -57,7 +57,6 @@ void AddNewData(ir::Value value, if 
(value_2_var_name->count(value) == 0) { value_2_var_name->emplace(value, name); } - variable_2_var_name->emplace(var, name); if (var_name_2_id->count(name) == 0) { auto id = var_name_2_id->size(); diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index afa164752878c..baa89cc6cbd76 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -284,13 +284,10 @@ Tensor split_grad(const std::vector& out_grads, ->getValue() .dyn_cast()); } - ir::OpResult axis_res = std::static_pointer_cast(axis.impl()) ->getValue() .dyn_cast(); - ir::OpResult op_res = paddle::dialect::split_grad(out_grads_res, axis_res); - return Tensor(std::make_shared(op_res)); } diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index 5a97d41578448..60500fb4e01b3 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -225,7 +225,7 @@ std::vector> split_vjp( const std::vector& out_grads, const Tensor& axis, const std::vector>& stop_gradients) { - std::vector> vjp_res(3, std::vector()); + std::vector> vjp_res(3, std::vector(1)); // get concat_grad res. Tensor op_res = backend::split_grad(out_grads, axis); diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 98da3330238a1..0a6246e74e8b2 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1976,6 +1976,9 @@ def split(x, num_or_sections, axis=0, name=None): else: return _C_ops.split(input, num_or_sections, dim) else: + if paddle.ir.core._use_new_ir_api(): + return paddle._ir_ops.split(input, num_or_sections, dim) + check_variable_and_dtype( input, 'input', diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index 3608b74e04aac..61018fa935e15 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -407,7 +407,7 @@ TEST(VJP, SplitBackwardTest) { std::vector{2, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace()); paddle::dialect::SplitOp op2 = builder->Build( - op1.out(), std::vector{2}, 0); + op1.out(), std::vector{1, 1}, 0); ir::SplitOp op3 = builder->Build(op2.out()); @@ -415,17 +415,17 @@ TEST(VJP, SplitBackwardTest) { std::vector{1, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}, {true}, {true}}; - std::vector> out_grads{{op3.result(0)}, - {op4.out()}}; + std::vector> out_grads{{op3.result(0), op4.out()}}; ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.split"); + auto concat_vjp_interface_impl = op2_info.GetInterfaceImpl(); + concat_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); std::stringstream os; @@ -458,8 +458,8 @@ TEST(VJP, SplitBackwardTest) { ASSERT_EQ(out_tensor_0.data()[1], 2.0); ASSERT_EQ(out_tensor_1.data()[0], 2.0); ASSERT_EQ(out_tensor_1.data()[1], 2.0); - ASSERT_EQ(grad_out_tensor_0.data()[0], 1.0); - ASSERT_EQ(grad_out_tensor_0.data()[1], 1.0); + ASSERT_EQ(grad_out_tensor_0.data()[0], 2.0); + ASSERT_EQ(grad_out_tensor_0.data()[1], 2.0); ASSERT_EQ(grad_out_tensor_0.data()[2], 1.0); ASSERT_EQ(grad_out_tensor_0.data()[3], 1.0); } diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py index 63e5bdbc9e4c7..02a62425221a2 100644 --- 
a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/new_ir/test_ir_backward.py @@ -94,6 +94,19 @@ def test_no_grad_set(self): self.assertEqual(newir_program.block().ops[-1].name(), "pd.mean") paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + def test_split(self): + # test create output_grad in backward use full op + newir_program = get_ir_program_0() + input = newir_program.block().ops[-1].operand(0).source() + tanh_out = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + out = paddle.split(tanh_out, [1, 1], 0) + input_grad = grad(out, input) + + print(newir_program) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + def get_ir_program_1(): x = paddle.randn([2, 2]) From c21b912229c3ae082412b7de10b69e8fd73abc8a Mon Sep 17 00:00:00 2001 From: wangruting Date: Wed, 23 Aug 2023 02:26:46 +0000 Subject: [PATCH 22/30] fix conflict bug --- .../dialect/paddle_dialect/ir/pd_dialect.cc | 2 +- .../transforms/param_to_variable.cc | 71 ------------------- .../ir/phi_kernel_adaptor/phi_kernel_util.cc | 1 + 3 files changed, 2 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc index ddc117cb22c19..19b8b133559b7 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc @@ -48,7 +48,7 @@ void PaddleDialect::initialize() { #define GET_OP_LIST #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" // NOLINT >(); - RegisterOp(); + RegisterOps(); RegisterInterfaces(); } diff --git a/paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.cc b/paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.cc index 74a0332131e2d..0113e38b8fd5e 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.cc @@ -79,77 +79,6 @@ std::unique_ptr ParameterConvertInterface::VariableToParameter( } } -PaddleDialect::PaddleDialect(ir::IrContext *context) - : ir::Dialect(name(), context, ir::TypeId::get()) { - initialize(); -} - -void PaddleDialect::initialize() { - RegisterTypes(); - RegisterTypes(); - - RegisterAttributes(); - - // NOTE(zhangbo9674): GET_OP_LIST is defined in pd_op.h which is - // generated by op_gen.py, see details in - // paddle/fluid/ir/dialect/CMakeLists.txt. 
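// A hedged sketch, written as C++ comments because this point sits inside the
// C++ diff: with this duplicated PaddleDialect definition deleted from
// param_to_variable.cc, the single initialize() kept in pd_dialect.cc is the
// one that registers the manually written ops. Assuming AddNOp and SplitGradOp
// are the manual ops defined in pd_manual_op.h at this point in the series,
// that registration is one variadic call:
//
//   RegisterOps<paddle::dialect::AddNOp, paddle::dialect::SplitGradOp>();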
- // NOTE(Ruting)GET_MANUAL_OP_LIST is define in pd_manual_op.h" - RegisterOps< -#define GET_OP_LIST -#include "paddle/fluid/ir/dialect/pd_op.h" // NOLINT - >(); - RegisterOps(); - - RegisterInterfaces(); -} - -void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const { - os << type.dialect().name(); - os << '.'; - if (auto tensor_type = type.dyn_cast()) { - os << "tensor<"; - for (auto d : phi::vectorize(tensor_type.dims())) { - os << d; - os << "x"; - } - tensor_type.dtype().Print(os); - os << ">"; - } else if (auto selected_rows_type = type.dyn_cast()) { - os << "selectedrows<"; - for (auto d : phi::vectorize(selected_rows_type.dims())) { - os << d; - os << "x"; - } - selected_rows_type.dtype().Print(os); - os << ">"; - } -} - -void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { - if (auto int_array_attr = attr.dyn_cast()) { - phi::IntArray data = int_array_attr.data(); - os << "IntArray["; - const auto &inner_data = data.GetData(); - ir::PrintInterleave( - inner_data.begin(), - inner_data.end(), - [&os](int64_t i) { os << i; }, - [&os]() { os << ","; }); - os << "]"; - } else if (auto data_type_attr = attr.dyn_cast()) { - os << data_type_attr.data(); - } else if (auto place_type_attr = attr.dyn_cast()) { - os << place_type_attr.data(); - } else if (auto data_layout_attr = attr.dyn_cast()) { - os << data_layout_attr.data(); - } else { - os << "<#AttrNotImplemented>"; - } -} - } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc index 75af15e7edc60..5cc29e7c38767 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc @@ -57,6 +57,7 @@ void AddNewData(ir::Value value, if (value_2_var_name->count(value) == 0) { value_2_var_name->emplace(value, name); } + variable_2_var_name->emplace(var, name); if (var_name_2_id->count(name) == 0) { auto id = var_name_2_id->size(); From 98e6bebcc3f13a82e8bbedc36bc91869e9fe9898 Mon Sep 17 00:00:00 2001 From: wangruting Date: Wed, 23 Aug 2023 11:29:49 +0000 Subject: [PATCH 23/30] build bug fix --- paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc index 8af1c458b64ff..c34bed7c1f622 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc @@ -169,7 +169,8 @@ OpInfoTuple SplitGradOp::GetOpInfo() { {}, {}); - return std::make_tuple(inputs, attributes, outputs, run_time_info); + return std::make_tuple( + inputs, attributes, outputs, run_time_info, "split_grad"); } void SplitGradOp::Build(ir::Builder &builder, From f5d60fba3ba932850f7db241a891e1cba8e4198b Mon Sep 17 00:00:00 2001 From: wangruting Date: Thu, 24 Aug 2023 01:55:34 +0000 Subject: [PATCH 24/30] modify python api bug --- paddle/fluid/pybind/ops_api.cc | 2 +- python/paddle/tensor/manipulation.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pybind/ops_api.cc b/paddle/fluid/pybind/ops_api.cc index a1610c06015bf..fd9c83a8b9247 100644 --- a/paddle/fluid/pybind/ops_api.cc +++ b/paddle/fluid/pybind/ops_api.cc @@ -73,7 +73,7 @@ static PyMethodDef OpsAPI[] = {{"add_n", METH_VARARGS | METH_KEYWORDS, "C++ interface function for full."}, {"split", - (PyCFunction)(void (*)(void))full, 
+ (PyCFunction)(void (*)(void))split, METH_VARARGS | METH_KEYWORDS, "C++ interface function for full."}, {nullptr, nullptr, 0, nullptr}}; diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 0a6246e74e8b2..cdb464bc5c35d 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1977,7 +1977,12 @@ def split(x, num_or_sections, axis=0, name=None): return _C_ops.split(input, num_or_sections, dim) else: if paddle.ir.core._use_new_ir_api(): - return paddle._ir_ops.split(input, num_or_sections, dim) + if not isinstance(num_or_sections, int): + return paddle._ir_ops.split(input, num_or_sections, dim) + else: + raise NotImplementedError( + "_ir_ops.split_with_num is not implemented, please change sections as list" + ) check_variable_and_dtype( input, From 47097e86724844ba30d3b4477a79671316392df1 Mon Sep 17 00:00:00 2001 From: wangruting Date: Thu, 24 Aug 2023 03:10:14 +0000 Subject: [PATCH 25/30] modify test --- paddle/fluid/pybind/ops_api.cc | 4 +- python/paddle/fluid/backward.py | 3180 +++++--------------------- python/paddle/tensor/manipulation.py | 7 +- test/ir/new_ir/test_build_op.py | 3 +- test/ir/new_ir/test_ir_backward.py | 3 - 5 files changed, 593 insertions(+), 2604 deletions(-) diff --git a/paddle/fluid/pybind/ops_api.cc b/paddle/fluid/pybind/ops_api.cc index a1610c06015bf..ca33bc7305c09 100644 --- a/paddle/fluid/pybind/ops_api.cc +++ b/paddle/fluid/pybind/ops_api.cc @@ -73,9 +73,9 @@ static PyMethodDef OpsAPI[] = {{"add_n", METH_VARARGS | METH_KEYWORDS, "C++ interface function for full."}, {"split", - (PyCFunction)(void (*)(void))full, + (PyCFunction)(void (*)(void))split, METH_VARARGS | METH_KEYWORDS, - "C++ interface function for full."}, + "C++ interface function for split."}, {nullptr, nullptr, 0, nullptr}}; void BindOpsAPI(pybind11::module *module) { diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 9b09ec11cd3ab..5bf723be06c1b 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,2745 +12,733 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .proto import framework_pb2 - -from paddle.fluid import framework as framework -from paddle.fluid import program_guard -from . import core import collections -import copy -import logging -from . import unique_name -from . 
import log_helper -import paddle.fluid -from .data_feeder import check_type -import warnings - from collections.abc import Sequence -import re - -__all__ = [ - 'append_backward', - 'gradients', -] - -_logger = log_helper.get_logger( - __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' -) - - -class ProgramStats: - def __init__(self, block, ops): - self.block = block - self.ops = ops - self.op_deps = {} # op-> in_ops, out_ops - self.var_op_deps = {} # var as input op, var as output op - - def get_input_nodes(self): - input_names = [] - for name in self.var_op_deps: - if ( - len(self.var_op_deps[name]["var_as_output_ops"]) == 0 - and len(self.var_op_deps[name]["var_as_input_ops"]) > 0 - ): - if self.block.var(name).persistable: - continue - input_names.append(name) - for op in self.ops: - if op.desc.type() == "read": - input_names.extend(op.desc.output_arg_names()) - return input_names - - def get_reserved_vars(self): - var_name = [] - for op in self.ops: - if op.desc.type() == "seed": - var_name.extend(op.desc.output_arg_names()) - return var_name - - def get_out_of_subgraph_vars(self, begin_op_idx, end_op_idx): - var_name = [] - for i in range(begin_op_idx, end_op_idx, 1): - for name in self.ops[i].desc.output_arg_names(): - if name in self.var_op_deps: - for idx in self.var_op_deps[name]["var_as_input_ops"]: - if idx >= end_op_idx: - var_name.append(name) - for name in self.ops[i].desc.input_arg_names(): - if name in self.var_op_deps: - for idx in self.var_op_deps[name]["var_as_output_ops"]: - if idx < begin_op_idx: - var_name.append(name) - return var_name - - def is_subgraph(self, var_group1, var_group2): - # should traverse from var_group1 to var_group2 - # max op idx in var_group2 - # min op idx in var_group1 - min_op_idx = len(self.ops) - max_op_idx = -1 - for name in var_group1: - if name not in self.var_op_deps: - return False, min_op_idx, max_op_idx - for name in var_group2: - if name not in self.var_op_deps: - return False, min_op_idx, max_op_idx - for name in var_group1: - op_idx = self.var_op_deps[name]["var_as_input_ops"] - for idx in op_idx: - min_op_idx = min(min_op_idx, idx) - for name in var_group2: - op_idx = self.var_op_deps[name]["var_as_output_ops"] - for idx in op_idx: - max_op_idx = max(max_op_idx, idx) - if min_op_idx >= max_op_idx: - return False, min_op_idx, max_op_idx - - return True, min_op_idx, max_op_idx - - def _update_segment_start(self, min_idx, pre_segment_end_idx): - """ - persist vars of amp-related cast should be included in recompute segment - """ - - def is_amp_cast(op): - return ( - op.desc.type() == 'cast' - and self.block.var(op.desc.input_arg_names()[0]).persistable - ) - - idx_ = min_idx - 1 - updated_min_idx = min_idx - while idx_ > pre_segment_end_idx: - if is_amp_cast(self.ops[idx_]): - _logger.info( - "found amp-cast op: {}, : {}".format( - self.ops[idx_].desc.type(), - self.ops[idx_].desc.input_arg_names()[0], - ) - ) - updated_min_idx = idx_ - idx_ -= 1 - else: - break - - return updated_min_idx - - def build_stats(self): - for i, op in enumerate(self.ops): - self.op_deps[i] = {"in_ops": [], "out_ops": []} - for j, name in enumerate(op.desc.input_arg_names()): - if name in self.var_op_deps: - self.op_deps[i]["in_ops"].extend( - self.var_op_deps[name]["var_as_output_ops"] - ) - for j, name in enumerate(op.desc.input_arg_names()): - if name in self.var_op_deps: - self.var_op_deps[name]["var_as_input_ops"].extend([i]) - else: - self.var_op_deps[name] = {} - self.var_op_deps[name]["var_as_input_ops"] = [i] - 
self.var_op_deps[name]["var_as_output_ops"] = [] - - for j, name in enumerate(op.desc.output_arg_names()): - if name in self.var_op_deps: - self.var_op_deps[name]["var_as_output_ops"].extend([i]) - else: - self.var_op_deps[name] = {} - self.var_op_deps[name]["var_as_input_ops"] = [] - self.var_op_deps[name]["var_as_output_ops"] = [i] - - for op_idx in self.op_deps[i]["in_ops"]: - self.op_deps[op_idx]["out_ops"].extend([i]) - - def sort_checkpoints(self, checkpoints_name): - sorted_checkpoints = [] - for name in checkpoints_name: - if name not in self.var_op_deps: - _logger.info( - "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." - % name - ) - elif self.var_op_deps[name]["var_as_output_ops"] == []: - # input nodes - sorted_checkpoints.append((name, -1)) - else: - sorted_checkpoints.append( - (name, max(self.var_op_deps[name]["var_as_output_ops"])) - ) - sorted_checkpoints = sorted(sorted_checkpoints, key=lambda x: x[1]) - return [x[0] for x in sorted_checkpoints] - - def modify_forward_desc_for_recompute(self): - op_types = [op.desc.type() for op in self.ops] - if "dropout" not in op_types: - return - - op_idx = 0 - while op_idx < len(self.ops): - op = self.ops[op_idx] - if op.desc.type() != "dropout": - op_idx += 1 - continue - # already insert seed op before dropout - if op.input('Seed') is not None and len(op.input('Seed')) == 1: - op_idx += 1 - continue - # add a seed op so that the two dropout op can generate same output - op_unique_name = unique_name.generate("seed") - var_unique_name = unique_name.generate_with_ignorable_key( - ".".join([op_unique_name, 'tmp']) - ) - added_var = self.block.create_var( - name=var_unique_name, - dtype='int32', - type=core.VarDesc.VarType.LOD_TENSOR, - persistable=False, - stop_gradient=False, - ) - seed = 0 if op.attr("fix_seed") is False else int(op.attr("seed")) +import paddle.ir +from paddle.autograd.backward_utils import State - op_device_attr_name = ( - core.op_proto_and_checker_maker.kOpDeviceAttrName() - ) - op_device = "" - if op.desc.has_attr(op_device_attr_name): - op_device = op.desc.attr(op_device_attr_name) - - # Setting the force_cpu of seed to true will make the output of seed in cpu memory, - # reduce the synchronous copy from GPU to CPU in dropout, and reduce the communication hang - added_op = self.block._insert_op( - index=op.idx, - type='seed', - inputs={}, - outputs={'Out': [added_var]}, - attrs={'seed': seed, 'op_device': op_device, 'force_cpu': True}, - ) - self.ops.insert(op_idx, added_op) - # modify dropout op desc so that it accept a seed var as input - op.desc.set_input("Seed", [var_unique_name]) - op.desc.remove_attr("fix_seed") - op.desc.remove_attr("seed") - self.block._sync_with_cpp() - op_idx += 2 - - -def _pretty_op_desc_(op_desc, prefix): - out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % ( - prefix + "_op", - str(op_desc.type()), - prefix + "_input", - " ".join(op_desc.input_arg_names()), - prefix + "_output", - " ".join(op_desc.output_arg_names()), - ) - return out_s +""" + grad: for templete test, will combine in paddle.grad . + calc_gradient: for internal use, optest, parallel etc . + calc_gradient_helper: for dygraph to static . 
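    A minimal usage sketch of the new-IR entry point this module now exposes,
    based on the test_split case added to test/ir/new_ir/test_ir_backward.py in
    this series; get_ir_program_0 is the helper defined in that test and the
    import path for grad is an assumption, so treat this as an illustration
    rather than patch text:

        import paddle
        from paddle.fluid.backward import grad  # assumed import path

        newir_program = get_ir_program_0()  # new-IR program ending in pd.tanh
        input = newir_program.block().ops[-1].operand(0).source()
        tanh_out = newir_program.block().ops[-1].result(0)
        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
        with paddle.ir.core.program_guard(newir_program):
            out = paddle.split(tanh_out, [1, 1], 0)
            input_grad = grad(out, input)
        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})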
+""" +__all__ = ['grad', 'calc_gradient', 'calc_gradient_helper'] -def _add_needed_descs_to_block( - descs, block, main_block, in_memory_vars, grad_op_id_to_fwd_op=None -): - if len(descs) == 0: - return [] - result_descs = [] - op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() - backward = core.op_proto_and_checker_maker.OpRole.Backward - for desc in descs: - origin_desc = desc - origin_is_operator = False - if isinstance(desc, framework.Operator): - desc = desc.desc - origin_is_operator = True - if isinstance(desc, tuple): - desc = desc[0] - is_needed = False - for name in desc.output_arg_names(): - if main_block.has_var(name) and main_block.var(name).persistable: - continue - if name not in in_memory_vars: - is_needed = True - if is_needed: - if origin_is_operator and grad_op_id_to_fwd_op is not None: - grad_op_id_to_fwd_op[desc.original_id()] = origin_desc - new_op_desc = block.desc.append_op() - new_op_desc.copy_from(desc) - new_op_desc._set_attr(op_role_attr_name, backward) - if desc.has_attr('op_device'): - new_op_desc._set_attr('op_device', desc.attr('op_device')) - result_descs.append(new_op_desc) - return result_descs - - -def _add_descs_to_block(descs, block, grad_op_id_to_fwd_op=None): - if len(descs) == 0: - return [] - result_descs = [] - op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() - backward = core.op_proto_and_checker_maker.OpRole.Backward - for desc in descs: - if isinstance(desc, framework.Operator): - # for recompute, should record recompute ops - if grad_op_id_to_fwd_op is not None: - grad_op_id_to_fwd_op[desc.desc.original_id()] = desc - desc = desc.desc - if isinstance(desc, tuple): - desc = desc[0] - new_op_desc = block.desc.append_op() - new_op_desc.copy_from(desc) - new_op_desc._set_attr(op_role_attr_name, backward) - if desc.has_attr('op_device'): - new_op_desc._set_attr('op_device', desc.attr('op_device')) - result_descs.append(new_op_desc) - return result_descs - - -def _find_loss_op_(loss): - for op in reversed(loss.block.ops): - assert isinstance(op, framework.Operator) - if ( - len(op.output_arg_names) == 1 - and op.output_arg_names[0] == loss.name - ): - loss.op = op - break - if loss.op is None: - raise ValueError("loss.op is None. Should not happen") - - -def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): - """ - Traverse all ops in op_descs[begin_idx : end_idx], - if any op has inputs/outputs named "old_name", rename it as 'new_name' - """ - if begin_idx is None: - begin_idx = 0 - if end_idx is None: - end_idx = len(op_descs) - if isinstance(op_descs, (list, tuple)): - for i in range(begin_idx, end_idx): - op_desc = op_descs[i] - if isinstance(op_desc, tuple): - op_desc = op_desc[0] - op_desc._rename_input(old_name, new_name) - op_desc._rename_output(old_name, new_name) - if isinstance(op_descs, collections.OrderedDict): - for key, value in op_descs.items(): - if isinstance(value, (list, tuple)): - for op_desc in value: - op_desc._rename_input(old_name, new_name) - op_desc._rename_output(old_name, new_name) - - -def _create_op_desc_(op_type, inputs, outputs, attrs): - """ - Create a C++ OpDesc object with specified inputs, outputs and attributes. 
- """ - op_desc = core.OpDesc() - op_desc.set_type(op_type) - for para, args in inputs.items(): - op_desc.set_input( - para, - list( - map( - lambda arg: arg.decode() if isinstance(arg, bytes) else arg, - args, - ) - ), +def check_type(input, input_name, expected_type, op_name, extra_message=''): + if not isinstance(input, expected_type): + raise TypeError( + f"The type of '{input_name}' in {op_name} must be {expected_type}, but received {type(input)}. {extra_message}" ) - for para, args in outputs.items(): - op_desc.set_output( - para, - list( - map( - lambda arg: arg.decode() if isinstance(arg, bytes) else arg, - args, - ) - ), - ) - op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() - op_device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() - - if op_role_attr_name not in attrs: - attrs[ - op_role_attr_name - ] = core.op_proto_and_checker_maker.OpRole.Backward - if op_device_attr_name not in attrs: - attrs[op_device_attr_name] = "" - for name, val in attrs.items(): - if isinstance(val, framework.Block): - op_desc.set_block_attr(name, val.desc) - else: - op_desc._set_attr(name, val) - return op_desc -def _create_loss_op_desc_(loss): - # 0-D Tensor or 0-Size Tensor - if len(loss.shape) == 0 or 0 in loss.shape: - create_shape = loss.shape - else: - create_shape = [1] - op_desc = _create_op_desc_( - "fill_constant", - {}, - {"Out": [_append_grad_suffix_(loss.name)]}, - { - "shape": create_shape, - "value": 1.0, - "dtype": loss.dtype, - "force_cpu": False, - core.op_proto_and_checker_maker.kOpRoleAttrName(): int( - core.op_proto_and_checker_maker.OpRole.Backward - ) - | int(core.op_proto_and_checker_maker.OpRole.Loss), - core.op_proto_and_checker_maker.kOpDeviceAttrName(): loss.op.attr( - core.op_proto_and_checker_maker.kOpDeviceAttrName() - ), - }, - ) - return op_desc +def _as_list(x): + if x is None: + return [] + return list(x) if isinstance(x, Sequence) else [x] -def _infer_var_data_type_shape_(grad_var_name, block): - """ - Infer the data type and shape of given grad variable - """ - grad_var = block.desc.find_var(grad_var_name.encode()) - fwd_name = _strip_grad_suffix_(grad_var_name) - if block.desc.has_var_recursive(fwd_name.encode()): - fwd_var = block.desc.find_var_recursive(fwd_name.encode()) - grad_var.set_dtype(fwd_var.dtype()) - grad_var.set_shape(fwd_var.shape()) - else: - # TODO(jiabin): Maybe we should not to this to cause some unexpected error on dtype - warnings.warn( - "Set grad var: {} dtype to default FP32, since we can't find its related forward var".format( - grad_var_name +def check_all_puts(block, inputs, outputs): + for output in outputs: + if output.get_defining_op().get_parent_block() != block: + raise ValueError("all outputs must be in the same block") + for input in inputs: + if input.get_defining_op().get_parent_block() != block: + raise ValueError( + "all inputs must be in the same block with outputs" ) - ) - grad_var.set_dtype(core.VarDesc.VarType.FP32) - -def _all_in_set_(cands, s): - """ - Test if all elements of 'cands' are in set 's' - """ - if len(cands) == 0: - return False - for c in cands: - if not c in s: - return False - return True - - -def _some_in_set_(cands, s): - """ - Test if some elements of 'cands' are in set 's' - """ - if len(cands) == 0: - return False - for c in cands: - if c in s: - return True - return False +def update_no_grad_set_by_stopgradient(block, no_grad_set): + for op in block.ops: + for opresult_idx in range(op.num_results()): + value = op.result(opresult_idx) + if value.stop_gradient and 
value not in no_grad_set: + no_grad_set.add(value) -def _strip_grad_suffix_(name): - """ - Strip the grad suffix from the given variable name - e.g. x@GRAD ==> x - x@GRAD@GRAD ==> x - y@GRAD@RENAME@1 ==> y - z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0 - grad/grad/z@GRAD@RENAME@block0@1@GRAD ==> z - """ - pos = re.search(f'{core.grad_var_suffix()}+@', name) or re.search( - f'{core.grad_var_suffix()}$', name - ) - new_name = name[: pos.start()] if pos is not None else name - new_pos = name.rfind('grad/') - return new_name[new_pos + 5 :] if new_pos != -1 else new_name - -def _append_grad_suffix_(name): - """ - Append grad suffix to the given variable name - e.g. x ==> x@GRAD - """ - return name + core.grad_var_suffix() +def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op): + backward_ops.append(grad_op) + op_to_opgrad_list.append(grad_op) -def _accumulate_gradients_by_sum_op_( - var_name, renamed_vars, pending_sum_ops, op_idx, op_device="" +def prepare_grad_outputs( + block, grad_outputs, outputs, value_to_valuegrad, op_to_opgrad ): """ - Use sum op to accumulate_gradients, the gradients are stored in renamed_vars. - """ - if op_idx not in pending_sum_ops.keys(): - pending_sum_ops[op_idx] = [] - pending_sum_ops[op_idx].append( - _create_op_desc_( - "sum", - {"X": renamed_vars[var_name]}, - {"Out": [var_name]}, - {"use_mkldnn": False, "op_device": op_device}, - ) - ) - renamed_vars[var_name] = [var_name] - + if grad_outputs is none, add fill_1 op to create grad_outputs, + else check whether outputs shape and dtype is same to grad_outputs, otherwise raise error. -def _accumulate_gradients_by_add_ops_( - var_name, renamed_vars, pending_sum_ops, op_idx, op_device="" -): - """ - Use several inplace add op to accumulate_gradients, the gradients are stored in renamed_vars. - """ - if op_idx not in pending_sum_ops.keys(): - pending_sum_ops[op_idx] = [] - out_name = renamed_vars[var_name][0] - for i in range(1, len(renamed_vars[var_name])): - x_name = out_name - y_name = renamed_vars[var_name][i] - if i != len(renamed_vars[var_name]) - 1: - out_name = var_name + '@ADD@' + str(i) - else: - out_name = var_name - pending_sum_ops[op_idx].append( - _create_op_desc_( - "grad_add", - {"X": [x_name], "Y": [y_name]}, - {"Out": [out_name]}, - {"use_mkldnn": False, "op_device": op_device}, - ) - ) - renamed_vars[var_name] = [var_name] + if only part of op's outputs in outputs, add fill_0 op to create other grad_outputs. + eg: split. + update value_to_valuegrad and op_to_opgrad. -def _addup_repetitive_outputs_( - op_descs, block_idx, grad_var_to_var=None, grad_op_id_to_fwd_op=None -): - """ - In backward part, an variable may be the output of more than one ops. - And one op may yield its multiple outputs to the same variable. - In these cases, the variable should be the accumulation of all the outputs. - `sum_op`s are added to implement the accumulate. + return complete_outputs and complete_gradoutputs, backward_ops. - Args: - grad_var_to_var(dict): used to build the mapping between grad var name and forward var name. - Only for auto parallel. 
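# Illustration of the fill_1 / fill_0 default described in the
# prepare_grad_outputs docstring above; in the patch these paddle.full calls
# build ops in the new-IR program, while this standalone snippet runs them in
# eager mode only to make the shape/dtype handling concrete (not patch text).
import paddle

out = paddle.rand([2, 3], dtype="float32")                # stand-in for a forward output
ones_grad = paddle.full(out.shape, 1.0, dtype=out.dtype)  # default grad_output for `out`
zeros_pad = paddle.full(out.shape, 0.0, dtype=out.dtype)  # placeholder for sibling results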
""" + if not grad_outputs: + grad_outputs = [None] * len(outputs) - _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add'] - # pending_sum_ops = [] - pending_sum_ops = collections.OrderedDict() - var_rename_count = collections.defaultdict(int) - renamed_vars = collections.defaultdict(list) - renamed_var_start_idx = collections.defaultdict(list) - var_device = collections.defaultdict(str) - for idx, op_desc in enumerate(op_descs): - op_device_attr_name = ( - core.op_proto_and_checker_maker.kOpDeviceAttrName() + if len(grad_outputs) != len(outputs): + raise ValueError( + "grad_outputs should have the same length of as outputs." ) - op_device = "" - if op_desc.has_attr(op_device_attr_name): - op_device = op_desc.attr(op_device_attr_name) - for var_name in op_desc.input_arg_names(): - if "@GRAD" not in var_name: - continue - if len(renamed_vars[var_name]) > 1: - if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: - _accumulate_gradients_by_sum_op_( - var_name, - renamed_vars, - pending_sum_ops, - idx, - var_device[var_name], - ) - else: - _accumulate_gradients_by_add_ops_( - var_name, - renamed_vars, - pending_sum_ops, - idx, - var_device[var_name], - ) - - for param_idx, param_name in enumerate(op_desc.output_names()): - arg_names = op_desc.output(param_name) - for arg_idx, var_name in enumerate(arg_names): - if "@GRAD" not in var_name: - continue - # if "@RENAME@" in var_name: - # continue - if ( - var_name == core.empty_var_name() - or var_name in op_desc.input_arg_names() - ): - # empty variable or inplace op - continue - if len(renamed_vars[var_name]) == 0: - # it's the first time we get the variable - renamed_vars[var_name] = [var_name] - renamed_var_start_idx[var_name] = idx - else: - if len(renamed_vars[var_name]) == 1: - new_name = ( - var_name - + "@RENAME@block" - + str(block_idx) - + "@" - + str(var_rename_count[var_name]) - ) - var_rename_count[var_name] += 1 - # Build the mapping between the new_name and var_name (Only for auto parallel) - if grad_var_to_var is not None: - if var_name in grad_var_to_var: - grad_var_to_var[new_name] = grad_var_to_var[ - var_name - ] - else: - grad_var_to_var[new_name] = var_name - # rename original var_name - renamed_vars[var_name][0] = new_name - # before change: _rename_arg_(op_descs, var_name, - # new_name, 0, idx) - # rename arg from idx of the first appearance - # in backward, not always from 0 - _rename_arg_( - op_descs, - var_name, - new_name, - renamed_var_start_idx[var_name], - idx, - ) - _rename_arg_(pending_sum_ops, var_name, new_name) - - for p in op_desc.output_names()[:param_idx]: - p_arg_names = op_desc.output(p) - if var_name in p_arg_names: - op_desc.set_output( - p, - [ - new_name if x == var_name else x - for x in p_arg_names - ], - ) - - arg_names = [ - new_name if x == var_name else x - for x in arg_names[:arg_idx] - ] + arg_names[arg_idx:] - - new_name = ( - var_name - + "@RENAME@block" - + str(block_idx) - + "@" - + str(var_rename_count[var_name]) - ) - var_rename_count[var_name] += 1 - # Build the mapping between the new_name and var_name (Only for auto parallel) - if grad_var_to_var is not None: - if var_name in grad_var_to_var: - grad_var_to_var[new_name] = grad_var_to_var[ - var_name - ] - else: - grad_var_to_var[new_name] = var_name - arg_names[arg_idx] = new_name - op_desc.set_output(param_name, arg_names) - renamed_vars[var_name].append(new_name) - # record the latest device - var_device[var_name] = op_device - - for var_name, inputs in renamed_vars.items(): - if len(renamed_vars[var_name]) > 1: - if 
len(renamed_vars[var_name]) > _MAX_ADD_NUM_: - _accumulate_gradients_by_sum_op_( - var_name, - renamed_vars, - pending_sum_ops, - len(op_descs), - var_device[var_name], - ) - else: - _accumulate_gradients_by_add_ops_( - var_name, - renamed_vars, - pending_sum_ops, - len(op_descs), - var_device[var_name], - ) - - op_descs_len = len(op_descs) - # sum_op descs are sorted according to their insert position - for key, value in collections.OrderedDict( - reversed(list(pending_sum_ops.items())) - ).items(): - # NOTE(zhiqiu): Since reversed, the idx of op_descs to be inserted will remains correct. - # For example, [0, 1, 2], and we want to insert 'a' at idx 1, 'b' at idx 2, and the expected result is [0, 1, 'a', 2, 'b']. - # If reversed, we first insert 'b' at idx 2, it becomes [0, 1, 2, 'b'], and then insert 'a' at idx 1, it becomes [0, 1, 'a', 2, 'b']. - # If not reverse, we first insert 'a' at idx 1, it becomes [0, 1, 'a', 2], and then insert 'b' at idx 2, it becomes [0, 1, 'a', 'b', 2]. - idx = key - for i, op in enumerate(value): - # update the mapping between fwd and bwd - target_idx = idx - 1 if idx == op_descs_len else idx + i - if ( - grad_op_id_to_fwd_op is not None - and grad_op_id_to_fwd_op.get( - op_descs[target_idx].original_id(), None - ) - is not None - ): - grad_op_id_to_fwd_op[op.original_id()] = grad_op_id_to_fwd_op[ - op_descs[target_idx].original_id() - ] - op_descs.insert(idx + i, op) - - return op_descs - - -def _remove_no_grad_branch_( - op_descs, no_grad_set, grad_op_id_to_fwd_op=None, target_vars=[] -): - """ - Remove unnecessary grad ops - A grad op can be removed in two cases: - 1. all outputs of the grad op are in 'no_grad_set' - 2. all grad inputs of the grad op are in 'no_grad_set' - NOTE: we will skip target_vars's grad name. 
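# A toy restatement (not the framework implementation) of the two removal
# conditions listed above: a grad op goes away when all of its outputs, or all
# of its gradient-suffixed inputs, already sit in no_grad_set.
def op_can_be_removed(output_names, input_names, no_grad_set):
    if not output_names or set(output_names) <= set(no_grad_set):
        return True
    grad_inputs = [name for name in input_names if "@GRAD" in name]
    return bool(grad_inputs) and set(grad_inputs) <= set(no_grad_set)

print(op_can_be_removed(["x@GRAD"], ["out@GRAD"], {"x@GRAD"}))    # True: outputs skipped
print(op_can_be_removed(["x@GRAD"], ["out@GRAD"], {"out@GRAD"}))  # True: grad inputs skipped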
- """ - - def _op_can_be_removed_(op_desc, no_grad_set): - out_arg_names = op_desc.output_arg_names() - if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set): - return True - if _all_in_set_( - [ - name - for name in op_desc.input_arg_names() - if name.find(core.grad_var_suffix()) != -1 - ], - no_grad_set, - ): - no_grad_set.update(set(out_arg_names) - target_grad_var_names) - return True - return False - - # Remove ops whose outputs are all in no_grad_dict - target_grad_var_names = set( - [var.name + core.grad_var_suffix() for var in target_vars] - ) - op_descs = [ - op_desc - for op_desc in op_descs - if not _op_can_be_removed_(op_desc, no_grad_set) - ] - # Insert fill_any_like_op with value 0 - to_insert = [] - if not core._is_bwd_prim_enabled(): - for idx, op_desc in enumerate(op_descs): - for arg in op_desc.input_arg_names(): - # arg is a gradient var name and arg should not have gradient - if core.grad_var_suffix() in arg and arg in no_grad_set: - x_in = _strip_grad_suffix_(arg) - # the reason should be: arg can be input of another grad op - # and the op is a not-to-remove op - new_op_desc = _create_op_desc_( - "fill_any_like", - {"X": [x_in]}, - {"Out": [arg]}, - {'value': 0, 'dtype': -1}, - ) - # update the mapping between fwd and bwd - if ( - grad_op_id_to_fwd_op is not None - and grad_op_id_to_fwd_op.get( - op_desc.original_id(), None - ) - is not None - ): - grad_op_id_to_fwd_op[ - new_op_desc.original_id() - ] = grad_op_id_to_fwd_op[op_desc.original_id()] - to_insert.append((new_op_desc, idx)) - - list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)]) - - return op_descs - - -def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set): - """ - Pruning Program with Structural Analysis Method of Computational Graph. - The nodes of the computational graph composed of backward OPS should be - interconnected. If there are unconnected sub-graphs in the computational graph, - these sub-graphs should be cut off. - - Args: - grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs. - forward_ops(list[Operator]): The forward ops. - input_grad_names_set(set): this set is used to store the gradients' name - which is generated by backward ops, and input_grad_names_set can help - to prune the unnecessary backward ops. - - Return: - (set[core.OpDesc]): A set of OpDescs which should be pruned. 
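# A toy model (ordinary Python, not framework code) of the pruning rule in the
# docstring above: an op survives only if its inputs eventually become
# available from the known forward variables; chains that never connect are
# returned as the ops to cut off.
def prune_unreachable(ops, known_vars):
    # ops: list of (input_names, output_names) pairs
    known, kept, remaining, changed = set(known_vars), [], list(ops), True
    while changed:
        changed = False
        for op in list(remaining):
            inputs, outputs = op
            if set(inputs) <= known:
                known.update(outputs)
                kept.append(op)
                remaining.remove(op)
                changed = True
    return kept, remaining  # `remaining` plays the role of not_need_op_descs

ops = [(["x@GRAD"], ["y@GRAD"]), (["u"], ["v"]), (["v"], ["w"])]
print(prune_unreachable(ops, ["x@GRAD"])[1])  # [(['u'], ['v']), (['v'], ['w'])]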
- """ - - class Var: - def __init__(self, var_name): - self.var_name = var_name - self.gen_op = None - self.pendding_ops = [] - - def set_gen_op(self, gen_op): - assert isinstance(gen_op, Op) - assert self.gen_op is None - self.gen_op = gen_op - - def add_pending_op(self, op): - assert isinstance(op, Op) - self.pendding_ops.append(op) - - class Op: - def __init__(self, op_desc): - self.op_desc = op_desc - self.inputs = [] - self.outputs = [] - - def insert_input(self, var): - assert isinstance(var, Var) - self.inputs.append(var) - - def insert_output(self, var): - assert isinstance(var, Var) - self.outputs.append(var) - - var_versions = dict() - - def _create_node(name): - if name not in var_versions.keys(): - var_versions[name] = [Var(name)] - else: - var_versions[name].append(Var(name)) - return var_versions[name][-1] - - def _create_or_get_last_version_node(name): - if name not in var_versions.keys(): - var_versions[name] = [Var(name)] - return var_versions[name][-1] - - def _create_op_node(op_desc): - op_node = Op(op_desc) - for input in op_desc.input_arg_names(): - var = _create_or_get_last_version_node(name=input) - var.add_pending_op(op_node) - op_node.insert_input(var) - for output in op_desc.output_arg_names(): - var = _create_node(name=output) - var.set_gen_op(op_node) - op_node.insert_output(var) - return op_node - - # Record the forward vars - forward_vars_set = ( - set() if input_grad_names_set is None else set(input_grad_names_set) - ) - for op in forward_ops: - forward_vars_set.update(op.desc.input_arg_names()) - forward_vars_set.update(op.desc.output_arg_names()) - - # Record the vars which are created during backward and is not generated by op. - backward_vars_set = set() - # special_op_nodes is the candidate sub-graph head node. - special_op_nodes = set() - for op_desc in grad_op_descs: - input_set = set(op_desc.input_arg_names()) - # The new_vars are created during backward and is not generated by op. - new_vars = input_set - forward_vars_set - backward_vars_set - backward_vars_set.update(op_desc.output_arg_names()) - - op_node = _create_op_node(op_desc) - if len(new_vars) == len(input_set): - special_op_nodes.add(op_node) - - not_need_op_descs = [] - # Start traversing all candidate sub-graph headers to check whether - # they are connected to backward computational graphs, and if they are - # not, list them in not_need_op_descs - for special_op_node in special_op_nodes: - op_list = [special_op_node] - ready_vars = set(special_op_node.inputs) - remove_ops = True - candidate_ops = [special_op_node] - while len(candidate_ops) > 0: - op_node = candidate_ops.pop(0) - if _all_in_set_(op_node.inputs, ready_vars): - for out_var in op_node.outputs: - candidate_ops.extend(out_var.pendding_ops) - op_list.extend(out_var.pendding_ops) - ready_vars.update(op_node.outputs) - else: - remove_ops = False - break - if remove_ops: - not_need_op_descs.extend([node.op_desc for node in op_list]) - not_need_op_descs_set = set(not_need_op_descs) - grad_op_descs_set = set(grad_op_descs) - # If a backward computational graph is simply one sub-graph header, the - # not_need_op_descs will be whole graph, this IF clause avoids it. 
- if grad_op_descs_set == not_need_op_descs_set: - return set() - return not_need_op_descs_set - - -def serialize_op_decs(op_desc): - protostr = op_desc.serialize_to_string() - proto = framework_pb2.OpDesc.FromString(bytes(protostr)) - return proto.__str__() - - -def _append_backward_ops_with_checkpoints_( - block, - ops, - target_vars, - target_block, - no_grad_dict, - grad_to_var, - checkpoints, - grad_op_id_to_fwd_op=None, -): - """ - Create grad ops with forward ops, and insert them into given block - - Args: - block(Block): the block where forward ops are - ops(Op): the forward operators whose forward recomputation backward ops need to be added - target_vars(list[Tensor]): the loss vars we want to calculate gradient. - target_block(Block): the block which is going to hold new generated grad ops - no_grad_dict(dict): - key(int) block index - val(str): corresponding forward variable name - checkpoints: variables that a user defined as checkpoint for forward recomputation - - Algorithms: - 0) deal with forward recomputing program descs - 1) find ops between checkpoints, i.e. recompute_segments - 2) go through all forward ops and induct all variables that will be hold in memory - a. variables that are used across segments will be held in memory - b. output of dropout op will be held in memory - c. input variables will be held in memory - 3) go through each recompute_segments, add backward ops with forward recomputation - a. add ops in current recompute_segment as forward recomputation ops - b. rename all non-checkpoint variables in recomputation ops - c. add backward ops of current recomputation ops - d. add sum op for repetitive_outputs - 4) remove no grad branch as it is in _remove_no_grad_branch_ - 5) Note1: all appended ops' OpRole are Backward - 6) Note2: all variables with new name should be returned so that _append_backward_vars_ can be called - 7) Note3: current forward recomputation backpropagation does not handle programs with subblock - """ + backward_ops = [] + for i, grad in enumerate(grad_outputs): + output = outputs[i] + # fwd : op1 -> op2 -> op3 -> output + # bwd : op1G <- op2G <- op3G <- outputG <- fillop/feedop + if grad is None: + output_grad = paddle.full( + output.shape, + 1.0, + dtype=output.dtype, + ) + fillop = output_grad.get_defining_op() - checkpoints_name = [x.name for x in checkpoints] - checkpoints_name = list(set(checkpoints_name)) - local_block = block.program._create_block() - buffer_block = block.program._create_block() - # 0) deal with forward recomputing program descs - program_stat = ProgramStats(block, ops) - program_stat.modify_forward_desc_for_recompute() - program_stat.build_stats() - - # 1) find ops between checkpoints, i.e. 
recompute_segments - checkpoints_name = program_stat.sort_checkpoints(checkpoints_name) - segments = [] - - if len(checkpoints_name) == 1: - # only one checkpoint - max_op_idx = -1 - var_group = [checkpoints_name[0]] - for name in var_group: - if name not in program_stat.var_op_deps: - break - op_idx = program_stat.var_op_deps[name]["var_as_output_ops"] - # only count the last generate op - for idx in op_idx: - max_op_idx = max(max_op_idx, idx) - if max_op_idx > 0: - segments.append([0, max_op_idx + 1]) - else: - start_idx = 0 - pre_segment_end_idx = -1 - while True: - if start_idx >= len(checkpoints_name) - 1: - break - # min_idx: checkpoint_1' s input op - # max_idx: checkpoint_2' s output op - flag, min_idx, max_idx = program_stat.is_subgraph( - [checkpoints_name[start_idx]], [checkpoints_name[start_idx + 1]] + update_bwdop_structure( + backward_ops, + op_to_opgrad[output.get_defining_op()], + fillop, ) - if flag: - # max_idx + 1 since the exact and used segment end idx is max_idx - min_idx = program_stat._update_segment_start( - min_idx, pre_segment_end_idx + value_to_valuegrad[output] = [[output_grad]] + else: + if output.shape != grad.shape: + raise ValueError( + "The shape of grad_output[%d] should be the same as the shape of output[%d]" + % (i, i) ) - segments.append([min_idx, max_idx + 1]) - else: - _logger.info( - "Could not recompute op range [{}] - [{}] ".format( - min_idx, max_idx + 1 - ) + if output.dtype != grad.dtype: + raise ValueError( + "The dtype of grad_output[%d] should be the same as the dtype of output[%d]" + % (i, i) ) - - start_idx += 1 - - if segments != [] and segments[0][0] != 0: - recompute_segments = [[0, segments[0][0]]] + segments - else: - recompute_segments = segments - - for i, (idx1, idx2) in enumerate(recompute_segments): - _logger.info("recompute segment[{}]".format(i)) - _logger.info( - "segment start op: [{}]: [{}]".format( - ops[idx1].desc.type(), ops[idx1].desc.input_arg_names() - ) - ) - _logger.info( - "segment end op: [{}]: [{}]".format( - ops[idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names() - ) - ) - _logger.info("recompute segment[{}]".format(i)) - _logger.info( - "segment start op: [{}]: [{}]".format( - ops[idx1].desc.type(), ops[idx1].desc.input_arg_names() - ) - ) - _logger.info( - "segment end op: [{}]: [{}]".format( - ops[idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names() + feedop = grad.get_defining_op() + update_bwdop_structure( + backward_ops, op_to_opgrad[output.get_defining_op()], feedop ) - ) - - # 2) go through all forward ops and induct all variables that will be hold in memory - vars_should_be_hold = [] - # a. variables that are used across segments will be held in memory - for segment in recompute_segments: - vars_should_be_hold.extend( - program_stat.get_out_of_subgraph_vars(segment[0], segment[1]) - ) + value_to_valuegrad[output] = [[grad]] - cross_vars = set(vars_should_be_hold) - set(checkpoints_name) - _logger.info( - "found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( - len(cross_vars), cross_vars - ) - ) - - # b. output of seed op should be kept in memory - vars_should_be_hold.extend(program_stat.get_reserved_vars()) - # c. 
input variables are checkpoints - vars_should_be_hold.extend(program_stat.get_input_nodes()) - vars_should_be_hold = list(set(vars_should_be_hold)) - - # 3) go through each recompute_segments, add backward ops with forward recomputation - grad_op_descs = [] - var_name_dict = {} - - vars_in_memory = vars_should_be_hold + checkpoints_name - - max_calculated_op_position = len(ops) - device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() - if recompute_segments == []: - gap_ops = ops[0:max_calculated_op_position] - for op in reversed(gap_ops): - if op.has_attr("sub_block"): - raise Exception( - "Recompute don't support ops with sub_block" - "invoke op: %s" - % _pretty_op_desc_(op.desc, "with_sub_block") - ) - grad_op_desc, op_grad_to_var = core.get_grad_op_desc( - op.desc, no_grad_dict[block.idx], [] - ) + # add input for bwd first op + complete_outputs = outputs + complete_gradoutputs = grad_outputs - # record the mapping between fwd and bwd - if grad_op_id_to_fwd_op is not None: - for op_desc in grad_op_desc: - grad_op_id_to_fwd_op[op_desc.original_id()] = op - - # Set device for grad_op according to forward Op - if op.desc.has_attr(device_attr_name): - op_device = op.desc.attr(device_attr_name) - for op_desc in grad_op_desc: - op_desc._set_attr(device_attr_name, op_device) - added_descs = _add_descs_to_block( - grad_op_desc, local_block, grad_op_id_to_fwd_op - ) - grad_op_descs.extend(added_descs) - grad_to_var.update(op_grad_to_var) - - for i, segment in enumerate(recompute_segments[::-1]): - gap_ops = ops[segment[1] : max_calculated_op_position] - max_calculated_op_position = segment[0] - for op in reversed(gap_ops): - if op.has_attr("sub_block"): - raise Exception( - "Recompute don't support ops with sub_block" - "invoke op: %s" - % _pretty_op_desc_(op.desc, "with_sub_block") + visited_output = set() + for output in outputs: + if output in visited_output: + continue + for opresult in output.get_defining_op().results(): + if opresult in value_to_valuegrad: + visited_output.add(opresult) + continue + else: + grad_value = paddle.full( + opresult.shape, + 0.0, + opresult.dtype, ) - grad_op_desc, op_grad_to_var = core.get_grad_op_desc( - op.desc, no_grad_dict[block.idx], [] - ) + fillop = grad.get_defining_op() - # record the mapping between fwd and bwd - if grad_op_id_to_fwd_op is not None: - for op_desc in grad_op_desc: - grad_op_id_to_fwd_op[op_desc.original_id()] = op - - # Set device for grad_op according to forward Op - if op.desc.has_attr(device_attr_name): - op_device = op.desc.attr(device_attr_name) - for op_desc in grad_op_desc: - op_desc._set_attr(device_attr_name, op_device) - added_descs = _add_descs_to_block( - grad_op_desc, local_block, grad_op_id_to_fwd_op - ) - grad_op_descs.extend(added_descs) - grad_to_var.update(op_grad_to_var) - - ff_ops = ops[segment[0] : segment[1]] - var_suffix = ".subprog_%d" % i - - for op in ff_ops: - if op.has_attr("sub_block"): - raise Exception( - "Recompute don't support ops with sub_block" - "invoke op: %s" - % _pretty_op_desc_(op.desc, "with_sub_block") + update_bwdop_structure( + backward_ops, + op_to_opgrad[opresult.get_defining_op()], + fillop, ) - input_and_output_names = [] - input_and_output_names.extend(op.desc.input_arg_names()) - input_and_output_names.extend(op.desc.output_arg_names()) - for name in input_and_output_names: - if block.var(name).persistable or name in checkpoints_name: - continue - if name in vars_should_be_hold: - continue - if name not in var_name_dict: - var_name_dict[name] = name + var_suffix - - # 
we should create the rename var in subprog, otherwise its VarType will be BOOL - ref_var = block.program.global_block().var(name) - block.create_var( - name=var_name_dict[name], - shape=ref_var.shape, - dtype=ref_var.dtype, - type=ref_var.type, - persistable=ref_var.persistable, - stop_gradient=ref_var.stop_gradient, - ) - - # 3.a. add ops in current recompute_segment as forward recomputation ops - buffer_descs = _add_needed_descs_to_block( - ff_ops, buffer_block, block, vars_in_memory, grad_op_id_to_fwd_op - ) - added_descs = _add_descs_to_block( - ff_ops, local_block, grad_op_id_to_fwd_op - ) - - # 3.b. rename all non-checkpoint variables in recomputation ops - for key in var_name_dict: - _rename_arg_(buffer_descs, key, var_name_dict[key]) - - # added_descs should be in grad_op_descs because it is backward op desc - grad_op_descs.extend(buffer_descs) - - # 3.c. add backward ops for all ops in current segment - for op_desc in reversed(added_descs): - grad_op_desc, op_grad_to_var = core.get_grad_op_desc( - op_desc, no_grad_dict[block.idx], [] - ) - - # record the mapping between fwd and bwd - if grad_op_id_to_fwd_op is not None: - for g_op_desc in grad_op_desc: - grad_op_id_to_fwd_op[ - g_op_desc.original_id() - ] = grad_op_id_to_fwd_op[op_desc.original_id()] - - # Set device for grad_op according to forward Op - if op_desc.has_attr(device_attr_name): - op_device = op_desc.attr(device_attr_name) - for g_op_desc in grad_op_desc: - g_op_desc._set_attr(device_attr_name, op_device) - - for key in var_name_dict: - _rename_arg_(grad_op_desc, key, var_name_dict[key]) - grad_op_descs.extend(grad_op_desc) - grad_to_var.update(op_grad_to_var) - - # 3.d. add sum op for repetitive_outputs - grad_op_descs = _addup_repetitive_outputs_( - grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op - ) - # 4) remove no grad branch as it is in _remove_no_grad_branch_ - grad_op_descs = _remove_no_grad_branch_( - grad_op_descs, - no_grad_dict[block.idx], - grad_op_id_to_fwd_op, - target_vars, - ) - added_descs = _add_descs_to_block( - grad_op_descs, target_block, grad_op_id_to_fwd_op - ) - return ( - program_stat, - checkpoints_name, - vars_should_be_hold, - recompute_segments, - ) + value_to_valuegrad[opresult] = [grad_value] + visited_output.add(opresult) -def _get_sub_block_path( - sub_block, - sub_block_op_desc, - no_grad_set, - op_path_dict, - sub_block_target_names=None, -): - """ - Get output vars in subblock which will be assigned to parent block. - It is used to find the grad path in subblock. + complete_outputs.append(opresult) + complete_gradoutputs.append(grad_value) - Args: - sub_block(Block): The sub-block in which to get op path. - sub_block_op_desc: The op desc of the sub-block op such as 'while', 'conditional_block' and 'recurrent'. - no_grad_set(set): The set of no grad var name. no_grad_set will be changed. - op_path_dict(dict): op_path_dict will be changed. - key(int) block index - val(list) the op path of block(index) - sub_block_target_names(set): Target var names of sub-block. - Return: - The forward op path of sub-block corresponding to backward op. - """ + return complete_outputs, complete_gradoutputs, backward_ops - assert sub_block_op_desc.has_attr( - "sub_block" - ) and sub_block.idx == sub_block_op_desc._block_attr_id("sub_block") - assert isinstance(sub_block_target_names, (set, type(None))) - - if sub_block_target_names is None: - sub_block_target_names = sub_block_op_desc.output_arg_names - - # TODO(huihuangzheng): add support for recurrent op. 
- if sub_block_op_desc.type in ["conditional_block", "while"]: - # Step1: get the output vars in sub-block - sub_outputs = [ - sub_block._var_recursive(var) for var in sub_block_target_names - ] - for var in sub_block_target_names: - for op_desc in sub_block.ops: - if var in op_desc.output_arg_names: - for name in op_desc.input_arg_names: - sub_outputs.append(sub_block._var_recursive(name)) - - # Step2: find op path of sub-block - is_while = sub_block_op_desc.type in ["while"] - sub_block_op_path = _find_op_path_( - sub_block, sub_outputs, [], no_grad_set, op_path_dict, is_while - ) - return sub_block_op_path - return sub_block.ops +def some_in_set(value_list, value_set): + def operand2value(values): + value_set = set() + for item in values: + if isinstance(item, paddle.ir.OpOperand): + value_set.add(item.source()) + else: + value_set.add(item) + return value_set -def _is_grad_op_(op): - op_maker = core.op_proto_and_checker_maker - backward = core.op_proto_and_checker_maker.OpRole.Backward - if op_maker.kOpRoleVarAttrName() in op.attr_names and int( - op.all_attrs()[op_maker.kOpRoleAttrName()] - ) == int(backward): + if operand2value(value_list) & operand2value(value_set): return True - return False - - -def _rename_grad_name_(name, grad_order): - return 'grad/' * grad_order + name - - -def _append_backward_ops_( - block, - ops, - target_vars, - target_block, - no_grad_dict, - grad_to_var, - callbacks=None, - input_grad_names_set=None, - op_path_dict=None, - distop_context=None, - rename_var_map=None, - grad_op_id_to_fwd_op=None, -): - """ - Create all grad ops, and insert them into given block - - Args: - block(Block): the block where forward ops are - ops(Op): the forward operators whose backward ops need to be added - target_vars(list[Tensor]): the loss vars we want to calculate gradient. - target_block(Block): the block which is going to hold new generated grad ops - no_grad_dict(dict): - key(int) block index - val(set) a set of variable names. These variables have no gradient - grad_to_var(dict)(output argument): - key(str): grad variable name - val(str): corresponding forward variable name - callbacks(callable object): a callable object used to decorate new generated grad ops - input_grad_names_set(set): this set is used to store the gradients' name which is - generated by backward ops, and input_grad_names_set can help to prune the unnecessary - backward ops. - op_path_dict(dict): op_path_dict will be changed. - key(int) block index - val(list) the op path of block(index) - rename_var_map(dict): used to associate target_grad var name with first grad_op input name. - Only used in for high order gradient. 
- """ - - # Build the mapping between the forward op and backward op (Only for auto parallel) - def update_distop_context( - distop_context, op_grad_to_var, appending_grad_times - ): - distop_context.grad_var_to_var[appending_grad_times].update( - op_grad_to_var - ) - for op_desc in grad_op_desc: - assert ( - op_desc.original_id() not in distop_context.grad_op_id_to_op_id - ) - distop_context.grad_op_id_to_op_id[ - op_desc.original_id() - ] = op.desc.original_id() - - if callbacks is not None: - assert isinstance(callbacks, (list, tuple)) - for cb in callbacks: - if not hasattr(cb, '__call__'): - raise ValueError("'callback' must be a callable object.") - - # grad_op_descs holds created grad_op, and will be appended to target_block - grad_op_descs = [] - program = block.program - - if rename_var_map is None: - rename_var_map = {} - assert isinstance(rename_var_map, dict) - - if core._is_bwd_prim_enabled(): - composite_block = program.clone().current_block() - # Create output and infer shape for operators whose output haven't - # been created. - for op in composite_block.ops: - for name in op.output_arg_names: - if not ( - composite_block.desc.has_var_recursive(name.encode()) - or name == core.empty_var_name() - ): - composite_block.create_var(name=name) - op.desc.infer_var_type(composite_block.desc) - op.desc.infer_shape(composite_block.desc) - - # add grad_op_desc by reversed ops - for op in reversed(ops): - grad_sub_block_list = [] - # If the op has its own sub-block, deal with the sub-block first - if op.has_attr("sub_block"): - sub_block = program.block(op._block_attr_id("sub_block")) - grad_sub_block = program._create_block() - grad_sub_block._set_forward_block_idx(sub_block.idx) - # see following comments for why set None here. - pre_input_grad_names_set = copy.copy(input_grad_names_set) - input_grad_names_set = None - sub_block_path = op_path_dict[op._block_attr_id("sub_block")] - _append_backward_ops_( - sub_block, - sub_block_path, - target_vars, - grad_sub_block, - no_grad_dict, - grad_to_var, - callbacks, - input_grad_names_set, - op_path_dict, - grad_op_id_to_fwd_op=grad_op_id_to_fwd_op, - ) - input_grad_names_set = pre_input_grad_names_set - - program._rollback() - grad_sub_block_list.append(grad_sub_block.desc) - # In primitive mode, raw phi GradOp will be split into multiple small - # primitive operators, and the split rules are defined in c++ level, - # see details: paddle/fluid/prim/api/manual/backward/composite_backward_api.h - # It means that the output's shape and dtype of previous operators which - # maybe used as the input of next operators must be known. Therefore, - # we infer shape and dtype in a sandbox block(named composite_block) for - # used in c++ level. 
- # For example: - # forward: - # z = multiply(x, y) //maybe broadcast in kernel - # backward: - # x_grad_unreduce = z_grad * y // maybe unreduce - # reduced_axes = get_reduced_axes(x_grad.shape, x.shape) // need known shape - # x_grad = reduce_sum(x_grad_unreduce) - grad_op_desc = [] - op_grad_to_var = {} - if core._is_bwd_prim_enabled(): - - def find_op_index(block_desc, cur_op_desc): - for idx in range(block_desc.op_size()): - if cur_op_desc == block_desc.op(idx): - return idx - return -1 - - grad_op_desc, op_grad_to_var = core.get_grad_op_desc( - composite_block.desc.op(find_op_index(block.desc, op.desc)), - no_grad_dict[composite_block.idx], - grad_sub_block_list, - ) - for desc in grad_op_desc: - infershape_for_composite(composite_block, desc) - else: - # Getting op's corresponding grad_op - grad_op_desc, op_grad_to_var = core.get_grad_op_desc( - op.desc, no_grad_dict[block.idx], grad_sub_block_list - ) + else: + return False - # record the mapping between fwd and bwd - if grad_op_id_to_fwd_op is not None: - for op_desc in grad_op_desc: - grad_op_id_to_fwd_op[op_desc.original_id()] = op - # Build the mapping between the forward op and backward op (Only for auto parallel) - if distop_context is not None: - update_distop_context( - distop_context, op_grad_to_var, program._appending_grad_times - ) - else: - default_ctx = getattr( - paddle.distributed.auto_parallel.static.dist_context, - '_g_default_distributed_context', - None, - ) - if default_ctx is not None: - distop_context = default_ctx.dist_op_context - update_distop_context( - distop_context, - op_grad_to_var, - program._appending_grad_times, - ) +def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): + ''' + prune ops which do not in the path from inputs_set to outputs_set, + prune ops which do not in the path from outputs_set to inputs_set, - # Set device for grad_op according to forward Op - device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() - if op.desc.has_attr(device_attr_name): - op_device = op.desc.attr(device_attr_name) - for op_desc in grad_op_desc: - op_desc._set_attr(device_attr_name, op_device) - - # Rename internal gradient variables in multiple backward - # so that they have different names with previous backward. - # For example: - # y = x * x, grad = fluid.gradients(fluid.gradients(y, x) + y * y, x) - # In second-time backward, gradient variable names of partial - # forward network (y * y) may be have same names with first-time - # fluid.gradients(y, x). - # So rename here before _addup_repetitive_outputs_. 
- if program._appending_grad_times > 1: - for op_desc in grad_op_desc: - forward_op_inputs = op.desc.input_arg_names() - for name in op_desc.input_arg_names(): - if name in rename_var_map and name not in forward_op_inputs: - op_desc._rename_input(name, rename_var_map[name]) - for name in op_desc.output_arg_names(): - if "@GRAD" not in name: - continue - if block.desc.find_var(name.encode("ascii")): - new_name = _rename_grad_name_( - name, program._appending_grad_times - ) - op_desc._rename_output(name, new_name) - rename_var_map[name] = new_name - - if name in op_grad_to_var: - # Build the mapping between the grad var name and var name (Only for auto parallel) - if distop_context is not None: - distop_context.grad_var_to_var[ - program._appending_grad_times - ][new_name] = op_grad_to_var[name] - op_grad_to_var[new_name] = op_grad_to_var[name] - op_grad_to_var.pop(name) - - # If input_grad_names_set is not None, extend grad_op_descs only when - # any input grad in outputs of previous grad ops. - # But this strategy is not suited for while op for some control flow, - # for example, for while op, the grads maybe generated in next loop. - if input_grad_names_set is not None: - is_grad_name = ( - lambda name: name.find(core.grad_var_suffix()) != -1 - or name in input_grad_names_set - ) - is_append_grad = False - - # NOTE: In primitive mode, the intermediate variable generated by - # decompositing raw grad op are not satisfied the rule of 'XX@GRAD', - # which will cause it be pruned according to current pruning logic. - # For simplicity, we treate all prmitive operators as one raw - # operator, and keep the pruning logic consistent with currently - # logic. The drawback of this solution is may lead to some primitive - # operators are not pruned, which is needed to fixed. - # FIXME: Optimize pruning logic from the perspective of whole graph. - input_grad_names = [] - for op_desc in grad_op_desc: - input_grad_names += [ - name - for name in op_desc.input_arg_names() - if is_grad_name(name) - ] + pruned op in total_ops is uneffective_ops, else is effective_ops - # some code of gradient ops, like increment, are not very - # standard, there is no @GRAD in these ops' inputs. 
- if len(input_grad_names) == 0: - is_append_grad = True + ''' + relevant_op_flags = [True] * len(total_ops) + # from input to output + if inputs_set: + for i, op in enumerate(total_ops): + if some_in_set(op.results(), inputs_set): continue - if _some_in_set_(input_grad_names, input_grad_names_set): - is_append_grad = True - for op_desc in grad_op_desc: - grad_op_descs.append(op_desc) - for name in op_desc.output_arg_names(): - input_grad_names_set.add(name) + if some_in_set(op.operands_source(), inputs_set): + for value in op.results(): + if value not in no_grad_set: + inputs_set.add(value) + else: + relevant_op_flags[i] = False - if is_append_grad: - grad_to_var.update(op_grad_to_var) + # from output to input + for i, op in reversed(list(enumerate(total_ops))): + # while op support + if some_in_set(op.results(), outputs_set): + for operand in op.operands_source(): + if operand not in no_grad_set: + outputs_set.add(operand) else: - grad_op_descs.extend(grad_op_desc) - grad_to_var.update(op_grad_to_var) - - # record mapping between grad var name and var name (Only for auto parallel) - grad_var_to_var = None - if distop_context is not None: - grad_var_to_var = distop_context.grad_var_to_var[ - program._appending_grad_times - ] - # sum parameter's gradients' var given multiple var gradient - grad_op_descs = _addup_repetitive_outputs_( - grad_op_descs, - block.idx, - grad_var_to_var, - grad_op_id_to_fwd_op=grad_op_id_to_fwd_op, - ) - - # if all outputs of the grad op are in no_grad_set, then just remove and fill zero - # if all inputs of the grad op are in no_grad_set, just remove this op - grad_op_descs = _remove_no_grad_branch_( - grad_op_descs, - no_grad_dict[block.idx], - grad_op_id_to_fwd_op, - target_vars, - ) - - # remove some backward ops - # TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem - if not core._is_bwd_prim_enabled(): - not_need_ops = _find_not_need_ops( - grad_op_descs, ops, input_grad_names_set - ) - grad_op_descs = [ - op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops - ] - else: - logging.debug( - "Running backward composite and disable find_not_need_ops" - ) - - # append op_desc in grad_op_descs to target_block - op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() - backward = core.op_proto_and_checker_maker.OpRole.Backward - for op_desc in grad_op_descs: - new_op_desc = target_block.desc.append_op() - new_op_desc.copy_from(op_desc) - new_op_desc._set_attr(op_role_attr_name, backward) - grad_to_var["__current_op_desc__"] = new_op_desc - if callbacks is not None: - assert isinstance(callbacks, (list, tuple)) - for cb in callbacks: - cb(block=target_block, context=grad_to_var) - - -def _is_grad_var_(var_name): - return core.grad_var_suffix() in var_name - - -# Find the op who holds the sub_block as its "sub_block" attr -def _find_parent_op_(sub_block): - sub_block_id = sub_block.idx - - if sub_block_id == 0: - return None - - program = sub_block.program - for block_id in range(program.num_blocks): - block_desc = program.block(block_id).desc - for op_idx in range(block_desc.op_size()): - op = block_desc.op(op_idx) - if ( - op.has_attr("sub_block") - and op._block_attr_id("sub_block") == sub_block_id - ): - return op + relevant_op_flags[i] = False - # NOTE(paddle-dev): When optimizer is added in conditional block, - # sub_block may not be found. 
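# --- Illustrative sketch only (not part of the patch) ---
# The two passes in prune_ops above keep only ops reachable forward from
# inputs_set and backward from outputs_set, using a per-op keep flag.
# The toy version below mimics that flag-based pruning with plain tuples and
# string value names instead of paddle.ir values; all names are invented.
def prune_toy_ops(ops, inputs, outputs):
    # ops: list of (name, input_values, output_values); values are strings.
    keep = [True] * len(ops)
    reach_fwd = set(inputs)
    for i, (_, ins, outs) in enumerate(ops):          # forward pass
        if reach_fwd & set(ins):
            reach_fwd.update(outs)
        else:
            keep[i] = False
    reach_bwd = set(outputs)
    for i, (_, ins, outs) in reversed(list(enumerate(ops))):  # backward pass
        if reach_bwd & set(outs):
            reach_bwd.update(ins)
        else:
            keep[i] = False
    return [op for i, op in enumerate(ops) if keep[i]]

toy = [("op1", ["x"], ["v1"]), ("op2", ["v1"], ["v2"]), ("side", ["x"], ["v3"])]
print(prune_toy_ops(toy, inputs=["x"], outputs=["v2"]))   # "side" is pruned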
- return None + effective_ops = [ + total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i] + ] + uneffective_ops = [ + total_ops[i] + for i in reversed(range(len(total_ops))) + if not relevant_op_flags[i] + ] + return effective_ops, uneffective_ops -def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): - """ - Create new variables required by backward pass. - Args: - block(Block): the block where new variables will be created - start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created - grad_to_var(dict): - key(str): grad variable name - val(str): corresponding forward variable name - In most cases, this dict is generated by _append_backward_ops_() - grad_info_map(dict)(output argument): - key(str): forward variable name - val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable - """ - ops_to_remove = [] +def update_no_grad_set_after_prune( + block, effective_forward_op, no_grad_set, inputs, outputs +): ''' - NOTE(paddle-dev): while_grad op may hold some inputs which are not found - in the parent/forward block, and they are also the outputs of while_grad - op. These kinds of inputs are the recursive outputs inside while_grad op. - They should be considered as "already created" when scanning the inner - ops of while_grad ops. + update no_grad_set after forward prune + + from inputs to outputs add value not in the path to no_grad_set, + from outputs to inputs add value not in the path to no_grad_set, ''' - parent_op = _find_parent_op_(block) - parent_op_vars = [] - if parent_op is not None: - input_args = parent_op.input_arg_names() - output_args = parent_op.output_arg_names() - for in_arg in input_args: - if in_arg in output_args: - parent_op_vars.append(in_arg) - - for op_idx in range(start_op_idx, block.desc.op_size()): - op_desc = block.desc.op(op_idx) - if op_desc.has_attr("sub_block"): - sub_block = block.program.block(op_desc._block_attr_id("sub_block")) - _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map) - - grad_var_ins = [ - var for var in op_desc.input_arg_names() if _is_grad_var_(var) - ] - grad_var_outs = [ - var for var in op_desc.output_arg_names() if _is_grad_var_(var) - ] - - inputs = [ - var - for var in op_desc.input_arg_names() - if var != core.empty_var_name() - ] - outputs = [ - var - for var in op_desc.output_arg_names() - if var != core.empty_var_name() - ] - - # If the outputs of grad op is empty, just remove it - if not outputs: - ops_to_remove.append(op_idx) - continue - else: - ''' - If the output is not empty and there is any grad input, find - whether there is any existing input. If not, just remove it. - ''' - if grad_var_ins: - existing_grad_var_ins = [ - var - for var in grad_var_ins - if block.desc.has_var_recursive(var.encode()) - or var in parent_op_vars - ] - if not existing_grad_var_ins: - ''' - FIXME(paddle-dev, zengjinle): rnn_memory_helper_grad is used - in recurrent op. The input of this op does not even exist in - the program! Therefore, any dependency analysis would not - work to this op! If I do not add the following code, this op - would be pruned, and the calculation result would be wrong. - Maybe we should re-design this op later... - ''' - if op_desc.type() not in ['rnn_memory_helper_grad']: - ops_to_remove.append(op_idx) - continue - - # sum may create invalid variable, here to deal with it. 
- if op_desc.type() == 'sum': - new_inputs = [] - for grad_var_name in op_desc.input_arg_names(): - if block.desc.has_var_recursive(grad_var_name.encode()): - # meet invalid sum variables, remove the invalid operand. - new_inputs.append(grad_var_name) - assert ( - len(new_inputs) > 0 - ), "After remove invalid variables, sum op have no inputs." - op_desc.set_input("X", new_inputs) - - new_vars = set() - # create new gradient variables - for grad_var_name in op_desc.output_arg_names(): - if ( - block.desc.has_var_recursive(grad_var_name.encode()) - or grad_var_name == core.empty_var_name() + inputs_set = set(inputs) + if inputs_set: + for op in block.ops: + if some_in_set(op.operands_source(), inputs_set): + for value in op.results(): + if value not in no_grad_set: + inputs_set.add(value) + + for op in effective_forward_op: + for value in op.operands_source(): + if value not in inputs_set: # and value.get_stopgradient(): + no_grad_set.add(value) + + outputs_set = set(outputs) + no_grad_set_tmp = set() + for op in reversed(effective_forward_op): + for output in op.results(): + if output not in outputs_set and not some_in_set( + [output], set(op.operands_source()) ): - continue - block.desc.var(grad_var_name.encode()) - new_vars.add(grad_var_name) - if grad_var_name not in grad_to_var: - continue - grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block) - # infer_shape and infer_type - op_desc.check_attrs() - op_desc.infer_var_type(block.desc) - op_desc.infer_shape(block.desc) - - for arg in op_desc.output_arg_names(): - if arg in new_vars: - _infer_var_data_type_shape_(arg, block) - - for op_idx in reversed(ops_to_remove): - block.desc._remove_op(op_idx, op_idx + 1) - - -def infershape_for_composite(block, grad_op_desc): - # NOTE: why pruning the operator with empty output here ? - # Some backward operator will output empty var, which will cause infer - # shape error, such assign with input's stop_gradient=True - if len(grad_op_desc.output_arg_names()) == 0: - return - - # create output variable - new_vars = set() - for grad_var_name in grad_op_desc.output_arg_names(): - if not ( - block.desc.has_var_recursive(grad_var_name.encode()) - or grad_var_name == core.empty_var_name() - ): - # NOTE: stop_gradient will be set in append_op - desc = block.desc.var(grad_var_name.encode()) - block.create_var(name=grad_var_name, desc=desc, type=desc.type()) - new_vars.add(grad_var_name) - - # NOTE For the primitive operator generated by decompositing phi grad kernel, - # we Operator to reconstruct the op_desc for reusing some complex logic, such - # as processing dispensable input, intermediate output, extra attrs, etc... - if framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()): - op = block.append_op( - type=grad_op_desc.type(), - inputs={ - name: [block._find_var_recursive(arg) for arg in args] - for name, args in grad_op_desc.inputs().items() - }, - outputs={ - name: [block._find_var_recursive(arg) for arg in args] - for name, args in grad_op_desc.outputs().items() - }, - # NOTE Runtime attr will be ignore as the c++ GetRuntimeAttr - # interface cann't be exported to python. 
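# --- Illustrative sketch only (not part of the patch) ---
# After pruning, operands of the surviving forward ops that are not reachable
# from `inputs` cannot receive a gradient, so update_no_grad_set_after_prune
# folds them into no_grad_set. This toy covers only the inputs-to-outputs
# half of that logic, with plain string value names invented for illustration.
def toy_update_no_grad(kept_ops, inputs, no_grad):
    reachable = set(inputs)
    for _, ins, outs in kept_ops:            # values derived from inputs
        if reachable & set(ins):
            reachable.update(outs)
    for _, ins, _ in kept_ops:               # operands off the input path
        for v in ins:
            if v not in reachable:
                no_grad.add(v)
    return no_grad

kept = [("matmul", ["x", "w"], ["v1"]), ("add", ["v1", "b"], ["y"])]
print(toy_update_no_grad(kept, inputs=["x"], no_grad=set()))  # {'w', 'b'}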
Please note the WARNING - # message logged in RuntimeAttrs of composite_grad_desc_maker.h - attrs=grad_op_desc.get_attr_map(), - ) - op.desc._set_attr( - core.op_proto_and_checker_maker.kOpRoleAttrName(), - core.op_proto_and_checker_maker.OpRole.Backward, - ) - grad_op_desc.copy_from(op.desc) - # For the backward operator, we reuse the logic of _append_backward_var - else: - op_desc = block.desc.append_op() - op_desc.copy_from(grad_op_desc) - op_desc._set_attr( - core.op_proto_and_checker_maker.kOpRoleAttrName(), - core.op_proto_and_checker_maker.OpRole.Backward, - ) - op_desc.check_attrs() - op_desc.infer_var_type(block.desc) - op_desc.infer_shape(block.desc) - grad_op_desc.copy_from(op_desc) - - if not framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()): - # NOTE: Some raw fluid grad operators which hadn't been decomposed may not - # implement InferVarType method, such as elementwise_xx_grad, and it will - # cause the dtype or shape of corresponding cotangent incorrect. This - # patch set the cotangent dtype and shape same with corresponding - # forward variable. For primitive operators, we have ensure all - # InferVarType method to be executed correctly in PR#52818, we skip - # this patch for primitive operators. - for arg in grad_op_desc.output_arg_names(): - if arg in new_vars: - _infer_var_data_type_shape_(arg, block) - - -def _rename_grad_( - block, start_op_idx, grad_to_var, target_grad_map, skip_rename_var_list -): - var_map = copy.copy(target_grad_map) - for op_idx in range(start_op_idx, block.desc.op_size()): - op_desc = block.desc.op(op_idx) - for name in op_desc.input_arg_names(): - if name in var_map: - op_desc._rename_input(name, var_map[name]) - - for name in op_desc.output_arg_names(): - if "@GRAD" not in name: - continue - if block.desc.find_var(name.encode("ascii")): - if name in skip_rename_var_list: - continue - new_name = unique_name.generate(name) - op_desc._rename_output(name, new_name) - var_map[name] = new_name - - for g, ng in var_map.items(): - if g in grad_to_var: - grad_to_var[ng] = grad_to_var[g] - grad_to_var.pop(g) - - -def _get_stop_gradients_(program): - no_grad_dict = dict() - assert isinstance(program, framework.Program) - for block in program.blocks: - assert isinstance(block, framework.Block) - block_no_grad_set = set() - for var in list(block.vars.values()): - assert isinstance(var, framework.Variable) - if var.stop_gradient: - block_no_grad_set.add(_append_grad_suffix_(var.name)) - no_grad_dict[block.idx] = block_no_grad_set - return no_grad_dict - - -def _get_son_parent_block_idx_dict(program, current_block_idx): - son_parent_block_idx_dict = collections.OrderedDict() - while current_block_idx >= 0: - parent_block_idx = program.block(current_block_idx).parent_idx - son_parent_block_idx_dict[current_block_idx] = parent_block_idx - current_block_idx = parent_block_idx - - return son_parent_block_idx_dict - - -def _get_no_grad_set_name(no_grad_set): - no_grad_set_name = set() - if no_grad_set is not None: - if isinstance(no_grad_set, (set, list, tuple)): - for i, no_grad_var in enumerate(no_grad_set): - if isinstance(no_grad_var, framework.Variable): - no_grad_set_name.add(no_grad_var.name) - elif isinstance(no_grad_var, str): - no_grad_set_name.add(no_grad_var) - else: - raise TypeError( - "The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s." 
- % (type(no_grad_var)) - ) - else: - raise TypeError( - "The type of no_grad_set should be set or list or tuple, but received {}".format( - type(no_grad_set) - ) - ) - return no_grad_set_name - - -@framework.static_only -def append_backward( - loss, - parameter_list=None, - no_grad_set=None, - callbacks=None, - checkpoints=None, - distop_context=None, -): - """ - :api_attr: Static Graph + no_grad_set_tmp.add(output) - This function appends backward part to main_program. + for input in op.operands_source(): + if input not in no_grad_set: + outputs_set.add(input) - A complete neural network training is made up of forward and backward - propagation. However, when we configure a network, we only need to - specify its forward part. This function uses the chain rule to automatically - generate the backward part according to the forward part. + no_grad_set.update(no_grad_set_tmp) - In most cases, users do not need to invoke this function manually. - It will be automatically invoked by the optimizer's `minimize` function. - Parameters: - loss(Tensor): The loss Tensor of the network. - parameter_list(list[Tensor|str]|tuple[Tensor|str], optional): List/Tuple of Parameters or Parameter.names - that need to be updated by optimizers. - If it is None, all parameters - will be updated. - Default: None. - no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients - should be ignored. All Tensors with - `stop_gradient=True` from all blocks will - be automatically added into this set. - If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set. - Default: None. - callbacks(list[callable object]|tuple[callable object], optional): List/Tuple of callback functions. - The callbacks are used for - doing some custom jobs during - backward part building. All - callable objects in it will - be invoked once each time a - new gradient operator is added - into the program. The callable - object must have two input - parameters: ``block`` and ``context`` . - The ``block`` is the :ref:`api_guide_Block_en` which - the new gradient operator will - be added to. The ``context`` is a - map, whose keys are gradient - Tensor names and values are - corresponding original :ref:`api_guide_tensor_en` . - In addition to this, the ``context`` - has another special key-value pair: - the key is string ``__current_op_desc__`` - and the value is the op_desc of the - gradient operator who has just - triggered the callable object. - Default: None. - - Returns: - list of tuple ( :ref:`api_guide_tensor_en` , :ref:`api_guide_tensor_en` ): Pairs of parameter and its corresponding gradients. - The key is the parameter and the value is gradient Tensor. - - Raises: - AssertionError: If ``loss`` is not an instance of Tensor. - - Examples: - .. code-block:: python - - import paddle - import paddle.nn.functional as F - - paddle.enable_static() - - x = paddle.static.data(name='x', shape=[None, 13], dtype='int64') - y = paddle.static.data(name='y', shape=[None, 1], dtype='float32') - x_emb = paddle.static.nn.embedding(x, size=[100, 256]) - y_predict = paddle.static.nn.fc(x=x_emb, size=1, activation=None, name='my_fc') - loss = F.square_error_cost(input=y_predict, label=y) - avg_loss = paddle.mean(loss) - - # Get all weights in main_program, not include bias. 
- all_weights = [param for param in paddle.static.default_main_program().block(0).all_parameters() if 'w_' in param.name] - all_weights_name = [w.name for w in all_weights] - - # return all param_grads needed to be updated if parameter_list set default None. - p_g_list1 = paddle.static.append_backward(loss=avg_loss) - # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)] - - # return the param_grads corresponding to parameter_list that can be list of param (Tensor). - p_g_list2 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights) - # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)] - - # parameter_list can be list of param.name (str). - p_g_list3 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights_name) - # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)] +def inverse_sort_op(ops): + ''' + if topo graph is op1 -> op2 -> op3 + return [op3, op2, op1] - # no_grad_set can be set of Tensors that means grad will be cut off from these Tensors. - p_g_list4 = paddle.static.append_backward(loss=avg_loss, no_grad_set=set([x_emb])) - # output: [(my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)] + ''' - # no_grad_set can be set of Tensor.name when the Tensor is created inside layers and can't be specified explicitly. - p_g_list5 = paddle.static.append_backward(loss=avg_loss, no_grad_set=set(['my_fc.b_0'])) - # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)] + # init pending_count[op] which descibes number of + # pending edges for its grad_op - # return [] because all param_grads are filtered by no_grad_set. - p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights)) + pending_count = collections.defaultdict(int) + ops_set = set(ops) + sorted_list = [] + for op in ops: + for x in op.operands(): + if x.source().get_defining_op() in ops_set: + pending_count[x.source().get_defining_op()] += 1 - """ - grad_op_id_to_fwd_op = ( - {} - ) # for cuda graph usage, recording the mapping between grad op original id to fwd op + queue = collections.deque() - check_type( - loss, 'loss', framework.Variable, 'paddle.static.append_backward' - ) + for op in ops: + if pending_count[op] == 0: + queue.append(op) - if loss.op is None: - # the loss is from a cloned program. Find loss op manually. 
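# --- Illustrative sketch only (not part of the patch) ---
# inverse_sort_op above is a Kahn-style topological sort run from the
# consumers backwards: an op is emitted only after every op that uses its
# outputs has been emitted. Toy version over a {op: producers} map; the op
# names are made up and mirror the docstring example.
import collections

def toy_inverse_sort(uses):
    # uses: {op: [ops whose outputs it consumes]}
    pending = collections.defaultdict(int)
    for op, producers in uses.items():
        for p in producers:
            pending[p] += 1          # p still has an unvisited consumer
    queue = collections.deque(op for op in uses if pending[op] == 0)
    order = []
    while queue:
        op = queue.popleft()
        order.append(op)
        for p in uses[op]:
            pending[p] -= 1
            if pending[p] == 0:
                queue.append(p)
    return order

print(toy_inverse_sort({"op1": [], "op2": ["op1"], "op3": ["op2"]}))
# ['op3', 'op2', 'op1']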
- _find_loss_op_(loss) + while queue: + op = queue.popleft() + sorted_list.append(op) - loss.op._set_attr( - core.op_proto_and_checker_maker.kOpRoleAttrName(), - int(core.op_proto_and_checker_maker.OpRole.Forward) - | int(core.op_proto_and_checker_maker.OpRole.Loss), - ) + for x in op.operands(): + x_op = x.source().get_defining_op() + pending_count[x_op] -= 1 + if pending_count[x_op] == 0: + queue.append(x_op) - if callbacks is not None: - check_type( - callbacks, - 'callbacks', - (list, tuple), - 'paddle.static.append_backward', + if len(sorted_list) != len(ops): + raise ValueError( + "inverse_sort_op wrong, sorted_list size is not equal to origin_list size" ) - program = loss.block.program - root_block = program.block(0) - current_block_idx = program.current_block_idx - current_block = program.block(current_block_idx) - - is_in_control_flow = current_block_idx != 0 + return sorted_list - # Double grad is not supported in sub-block (control flow) - if not is_in_control_flow: - # _appending_grad_times used for double grad - program._appending_grad_times += 1 - if no_grad_set is None: - no_grad_set = set() - else: - no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set)) - no_grad_dict = _get_stop_gradients_(program) - # no_grad_set only contains vars in block 0 - # Todo(liym27): support vars in sub block - no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set))) - - # Currently it is only to support the optimizer.minimize - # in a switch branch, which can append_backward in a sub_block. - # Note: while_loop is in control flow, but it makes no sense to call optimizer in while. - # Todo: report error when it is in while_loop - if is_in_control_flow: - # create grad block if in switch control flow. - target_grad_block = program._create_block( - parent_idx=current_block.parent_idx - ) - target_grad_block._set_forward_block_idx(current_block_idx) - # after _create_block, program.current_block changes - else: - target_grad_block = root_block +def append_backward_ops( + block, effective_forward_op, no_grad_set, backward_ops, state +): + ''' + add grad_op in order of topological inverse sort + eg: + from :op1 -> v1 -> op2 -> v2 -> op3 -> v3 + to: og1_g <- v1_g <- op2_g <- v2_g <- op3_g <- v3_g - son_parent_block_idx_dict = _get_son_parent_block_idx_dict( - program, current_block_idx - ) + if op has grad_op, prepare its grad_op's inputs by value_to_valuegrad, + eg: + value_to_valuegrad[v3] = [[v3_g]]; + v2_g = call_vjp(op3, [v3_g], [v2_stopgradient]) - block_fwd_op_num_dict = {} # block_id: fwd_op_num - for idx in son_parent_block_idx_dict: - block_fwd_op_num_dict[idx] = program.block(idx).desc.op_size() - grad_to_var = dict() + special pattern 1: + v11 -> combine_op -> v1 -> op -> v3 + v12 -> + v2 -> + value_to_valuegrad[v3] = [[v3_g]] - # pass the cuda_graph_attr to the fill_constant which generates the loss_grad - op_desc = _create_loss_op_desc_(loss) - grad_op_id_to_fwd_op[op_desc.original_id()] = loss.op - target_grad_block.desc.append_op().copy_from(op_desc) + v1 is inside python api, we don't describe it in backward process(state) + so v1_grad is inside vjp, we don't describe it in backward process(state) + [[v11_g, v12_g], v2_g] = call_vjp(combine_op, [v3_g], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient) - for block_idx in son_parent_block_idx_dict: - block = program.block(block_idx) - block_no_grad_set = set( - map(_strip_grad_suffix_, no_grad_dict[block_idx]) - ) + op_vjp is: + v11_g <- split_op <- v1_g <- op_g <- v3_g + v12_g <- + v2_g <- - op_path_dict = 
dict() - op_path = _find_op_path_( - block, [loss], [], block_no_grad_set, op_path_dict - ) + value_to_valuegrad[v11] = [[v11_g]] + value_to_valuegrad[v12] = [[v12_g]] + value_to_valuegrad[v2] = [[v2_g]] - no_grad_vars = _find_no_grad_vars( - block, op_path, [loss], block_no_grad_set - ) + if op don't has grad_op, if it don't has input and it's output has more than + one output_grad, add sumop for grad aggregation. + (eg: full op and get_parameter op etc.) - block_no_grad_set.update(no_grad_vars) - no_grad_dict[block_idx].update( - list(map(_append_grad_suffix_, block_no_grad_set)) - ) + else continue to next op. + ''' - input_grad_names_set = None - # For double backward, input_grad_names is used for filtering - # some non-used gradients op(s). - - # TODO(liym27): need a better design. - # not support double grad in control flow sub-block now. - if not is_in_control_flow: - if program._appending_grad_times > 1: - input_grad_names_set = set([_append_grad_suffix_(loss.name)]) - - # TODO: support _append_backward_ops_with_checkpoints_ in - # sub-block (control flow) - is_recompute = False - if ( - checkpoints is not None - and isinstance(checkpoints, list) - and len(checkpoints) > 0 - ): - is_recompute = True - ( - program_stat, - checkpoint_names, - vars_should_be_hold, - recompute_segments, - ) = _append_backward_ops_with_checkpoints_( - root_block, - op_path, - [loss], - root_block, - no_grad_dict, - grad_to_var, - checkpoints, - grad_op_id_to_fwd_op, - ) - else: - _append_backward_ops_( - block, # the block where forward ops are in - op_path, - [loss], - target_grad_block, - no_grad_dict, - grad_to_var, - callbacks, - input_grad_names_set=input_grad_names_set, - op_path_dict=op_path_dict, - distop_context=distop_context, - grad_op_id_to_fwd_op=grad_op_id_to_fwd_op, - ) + def make_output_grad(op): + zero_flag = [False] * op.num_results() + output_grads = [] + for i, value in enumerate(op.results()): + if ( + value not in state.value_to_valuegrad + or state.value_to_valuegrad[value] is None + ): + if ( + not value.use_empty() + and value.first_use().owner().name() == "builtin.split" + ): + # pattern case: + # this fwd_op's output is vectorType, it will split to + # Type by builtin.split op, so need get from split op's ouput + split_zero_flag, split_output_grad = make_output_grad( + value.first_use().owner() + ) + zero_flag[i] = all(split_zero_flag) + state.value_to_valuegrad[value] = [split_output_grad] + else: + # first case: + # this fwd_op's output didn't used by other fwd_op, + # so no output_grad created. + + # second case: + # last bwd_op return None because input in no_grad_set, + # but this bwd_op need a input. + grad_value = paddle.full( + value.shape, + 0.0, + dtype=value.dtype, + ) + fillop = grad_value.get_defining_op() - grad_info_map = dict() + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], fillop + ) + zero_flag[i] = True - # if in control flow, target_grad_block is a created new block which only contains grad ops, - # so fwd_op_num is set to 0. - fwd_op_num = ( - block_fwd_op_num_dict[current_block_idx] - if not is_in_control_flow - else 0 - ) + state.value_to_valuegrad[value] = [[grad_value]] - # Because append_backward may be called multiple times, - # we need rename the internal gradient variables so that they have - # different names. 
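# --- Illustrative sketch only (not part of the patch) ---
# When a forward output has no recorded gradient (it was never consumed, or
# its consumer returned None because of no_grad_set), make_output_grad above
# substitutes a zero tensor of the same shape/dtype and records zero_flag so
# the op can later be skipped if all of its output grads are zero. A simple
# eager-mode analogue of that substitution:
import paddle

y = paddle.rand([2, 3])                   # stands in for a forward output
recorded_grad = None                      # no gradient was produced for y
zero_flag = False
if recorded_grad is None:
    recorded_grad = paddle.full(y.shape, 0.0, dtype=y.dtype)
    zero_flag = True
print(zero_flag, recorded_grad.shape)     # True [2, 3]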
- _rename_grad_(target_grad_block, fwd_op_num, grad_to_var, {}, []) + if len(state.value_to_valuegrad[value]) > 1: + # one value is input of more than one fwd_op, + # so more than one bwd_op create input_grad, + # need add sum op to accumulate gradient - _append_backward_vars_( - target_grad_block, fwd_op_num, grad_to_var, grad_info_map - ) + paddle.add_n( + [item[0] for item in state.value_to_valuegrad[value]] + ) + combineop = block.ops[len(block.ops) - 2] + sumop = block.ops[len(block.ops) - 1] + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], combineop + ) + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], sumop + ) + state.value_to_valuegrad[value] = [[sumop.result(0)]] + state.value_to_sumvaluegrad[value] = state.value_to_valuegrad[ + value + ] - program.current_block_idx = current_block_idx - program._sync_with_cpp() - - # for cuda graph, copy the cuda graph attr from forward op to backward op - for op in target_grad_block.ops: - if grad_op_id_to_fwd_op.get(op.desc.original_id(), None) is not None: - fwd_op = grad_op_id_to_fwd_op[op.desc.original_id()] - op._cuda_graph_attr = fwd_op._cuda_graph_attr - - if parameter_list is not None: - check_type( - parameter_list, - 'parameter_list', - (list, tuple, set), - 'fluid.backward.append_backward', - ) - parameters = [] - for i, param in enumerate(parameter_list): - check_type( - param, - 'parameter_list[%s]' % i, - (framework.Variable, str), - 'fluid.backward.append_backward', - ) - if isinstance(param, framework.Variable): - parameters.append(param.name) - elif isinstance(param, str): - parameters.append(param) - else: - params = program.global_block().all_parameters() - parameters = [param.name for param in params if param.trainable] + output_grads.append(state.value_to_valuegrad[value][0][0]) + return zero_flag, output_grads - params_and_grads = [] - op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName() - for param in parameters: - if param not in grad_info_map: - continue - grad_info = grad_info_map[param] - grad_block = grad_info[1] - if not grad_block.has_var(grad_info[0]): - raise ValueError( - "grad block[{0}] did not have grad var {1}".format( - grad_info[1], grad_info[0] + def make_input_stopgradient(op): + input_grad_stopgradient_list = [] + for input in op.operands_source(): + if input.get_defining_op().name() == "builtin.combine": + stop_gradient = make_input_stopgradient(input.get_defining_op()) + input_grad_stopgradient_list.append( + [info[0] for info in stop_gradient] ) - ) - # Get the param var from the global block - param_var = program.global_block().var(param) - grad_var = grad_block.var(grad_info[0]) - if not is_in_control_flow: - if loss.block.has_var(grad_info[0]): - params_and_grads.append((param_var, grad_var)) else: - params_and_grads.append((param_var, None)) - else: - params_and_grads.append((param_var, grad_var)) - - for p, g in params_and_grads: - if g is None: - continue - ops = ( - grad_block.ops if is_in_control_flow else program.global_block().ops - ) - for op in reversed(ops): - assert isinstance(op, framework.Operator) - if g.name in op.output_arg_names: - g.op = op - break - - if g.op is None: - raise ValueError("Unexpected branch") - attr_val = [p.name, g.name] - if g.op.has_attr(op_role_var_attr_name): - attr_val.extend(g.op.attr(op_role_var_attr_name)) - g.op._set_attr(op_role_var_attr_name, attr_val) - - if is_recompute: - return params_and_grads, checkpoint_names - else: - return params_and_grads - - -def _as_list(x): - if x is None: - return 
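# --- Illustrative sketch only (not part of the patch) ---
# When one forward value is consumed by several ops, each consumer's vjp
# contributes one gradient, and the branch above inserts an add_n so the
# entries collected in value_to_valuegrad collapse to a single grad value.
# Eager-mode analogue of that accumulation:
import paddle

contributions = [paddle.full([2, 2], 1.0), paddle.full([2, 2], 2.0)]
grad = (
    contributions[0]
    if len(contributions) == 1
    else paddle.add_n(contributions)
)
print(grad)    # every entry is 3.0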
[] - return list(x) if isinstance(x, Sequence) else [x] - - -def _is_ancestor_block(ancestor_block, block): - prog = block.program - ancestor_idx = ancestor_block.idx - parent_idx = block.parent_idx - - while parent_idx != -1: - if parent_idx == ancestor_idx: - return True - parent_idx = prog.block(parent_idx).parent_idx - - return False - - -def _get_output_names(cur_block, targets): - """ - In `cur_block`, get output names those linked to targets. - NOTE: - 1. `targets` can be in `cur_block`; - Usually, `targets` is in `cur_block`. However, considering control flow, - 2. `targets` may be in sub-block but `cur_block` is an ancestor of `targets[0].block`; - 3. `targets` may be in the block which is ancestor of `cur_block`. - """ - - block = targets[0].block if targets else cur_block - current_output_names = set([out.name for out in targets]) - - # 1. If `targets` in cur_block or the ancestral block of `cur_block` - if block.idx == cur_block.idx or _is_ancestor_block(block, cur_block): - return current_output_names - - # 2. If `cur_block` is an ancestor of `targets[0].block`, run while loop - prog = cur_block.program - while block.idx != cur_block.idx: - assert block.parent_idx != -1 - parent_block = prog.block(block.parent_idx) - - parent_block_output_names = set() - for op in reversed(block.ops): - if _some_in_set_(op.desc.output_arg_names(), current_output_names): - for name in op.desc.input_arg_names(): - current_output_names.add(name) - if not block.desc.find_var( - name.encode() - ) and parent_block.desc.find_var(name.encode()): - parent_block_output_names.add(name) - - block = parent_block - current_output_names = parent_block_output_names - - return current_output_names - - -def _find_no_grad_vars(block, op_path, targets, no_grad_set): - """ - Find the vars which is not used in the program, and - those vars belong to no_grad_var. - """ - output_names = _get_output_names(block, targets) - no_grad_var = [] - for i, op in reversed(list(enumerate(op_path))): - # If the op has sub_block, it is too complicated to find the correct no_grad_var. - if not op.has_attr("sub_block"): - for out_var in op.desc.output_arg_names(): - if ( - out_var not in output_names - and out_var not in op.desc.input_arg_names() - and not block.vars[out_var].stop_gradient - ): - no_grad_var.append(out_var) - for name in op.desc.input_arg_names(): - if name not in no_grad_set: - output_names.add(name) - return set(no_grad_var) - - -def _find_op_path_( - block, targets, inputs, no_grad_set, op_path_dict=None, is_while=False -): - """ - It is used to find the grad path in `block`. - - Args: - block(Block): The block in which to get op path. - targets(list[Variable]): The target variables. - inputs(list[Variable]): The input variables. - no_grad_set(set): The set of no grad var name. no_grad_set will be changed. - op_path_dict(dict): op_path_dict will be changed. op_path_dict will be changed. - key(int) block index - val(list) the op path of block(index) - is_while(bool): Whether or not `block` is while block - Return: - The forward op path of block corresponding to backward op. 
- """ - - input_names = set([inp.name for inp in inputs]) - output_names = _get_output_names(block, targets) - if op_path_dict is None: - op_path_dict = dict() - - relevant_op_flags = [True] * len(block.ops) - - # All the inputs of the block are used if inputs is empty, - if inputs: - for i, op in enumerate(block.ops): - if _some_in_set_( - op.desc.input_arg_names(), input_names - ) and not core.has_empty_grad_op_maker(op.type): - for name in op.desc.output_arg_names(): - if name not in no_grad_set: - input_names.add(name) + if input in no_grad_set: + input_grad_stopgradient_list.append([True]) + else: + input_grad_stopgradient_list.append([False]) + return input_grad_stopgradient_list + + def update_input_grad_map(op, input_grad_list): + for i, input in enumerate(op.operands_source()): + if input.get_defining_op().name() == "builtin.combine": + update_input_grad_map( + input.get_defining_op(), input_grad_list[i] + ) else: - relevant_op_flags[i] = False + input_grad = input_grad_list[i] + if isinstance(input_grad, list): + state.value_to_valuegrad[input].append(input_grad) + else: + state.value_to_valuegrad[input].append([input_grad]) + + # there are four patterns: + # [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType) + # [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType) + # [builtin.combine , op3 , buitin.split] (op3's one input and one output are vectorType) + # [op4] (op4's inputs and outputs are not vectorType) + # einsum has twp vectorType outputs, special pattern + + clear_effective_forward_op = [] + + for op in effective_forward_op: + if op.name() != "builtin.combine" and op.name() != "builtin.split": + clear_effective_forward_op.append(op) + + for op in clear_effective_forward_op: + if paddle.framework.core.has_vjp(op): + # prepare output_grad + output_grad_list = [] # (opresult) + zero_flag, output_grad = make_output_grad(op) + output_grad_list.append(output_grad) + + # all(zero_flag) support this op has no contribution for grad + # should be delete (prune sub_graph) + if len(output_grad_list) == 0 or all(zero_flag): + continue - for i, op in reversed(list(enumerate(block.ops))): - if op.has_attr("sub_block"): - sub_block_id = op._block_attr_id("sub_block") - sub_block = block.program.block(sub_block_id) - sub_block_target_names = output_names & set(op.output_arg_names) - sub_block_path = _get_sub_block_path( - sub_block, op, set(), op_path_dict, sub_block_target_names + # prepare input_grad stop_gradient info. + input_grad_stopgradient_list = make_input_stopgradient(op) + + # create grad_op + before_ops_num = len(block.ops) + input_grad_list = paddle.framework.core.call_vjp( + op, output_grad_list, input_grad_stopgradient_list ) - op_path_dict[sub_block_id] = sub_block_path - - if _some_in_set_( - op.desc.output_arg_names(), output_names - ) and not core.has_empty_grad_op_maker(op.type): - for name in op.desc.input_arg_names(): - if name not in no_grad_set: - output_names.add(name) - else: - relevant_op_flags[i] = False + after_ops_num = len(block.ops) - if is_while: - # If block is while block, dealing with op specifically again. - # TODO(liym27): Consider special types of ops. 
- for i, op in reversed(list(enumerate(block.ops))): - if relevant_op_flags[i] == False and _some_in_set_( - op.desc.output_arg_names(), output_names - ): - relevant_op_flags[i] = True - if not core.has_empty_grad_op_maker(op.type): - for name in op.desc.input_arg_names(): - if name not in no_grad_set: - output_names.add(name) - - op_path = [ - block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i] - ] + # update grad_op structure + for i in range(before_ops_num, after_ops_num): + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], block.ops[i] + ) - if inputs: - for op in op_path: - for name in op.desc.input_arg_names(): - if name not in input_names and block.vars[name].stop_gradient: - no_grad_set.add(name) + # update input_grad map + update_input_grad_map(op, input_grad_list) - return op_path + else: + if op.num_operands() == 0 and op.num_results() != 0: + for value in op.results(): + if len(state.value_to_valuegrad[value]) > 1: + # need add sum op + paddle.add_n( + [ + item[0] + for item in state.value_to_valuegrad[value] + ] + ) + combineop = block.ops[len(block.ops) - 2] + sumop = block.ops[len(block.ops) - 1] + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], combineop + ) + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], sumop + ) + state.value_to_valuegrad[value] = [[sumop.result(0)]] + state.value_to_sumvaluegrad[ + value + ] = state.value_to_valuegrad[value] + else: + state.op_to_opgrad[op] = [] + else: + state.op_to_opgrad[op] = [] -def calc_gradient_helper( - targets, inputs, target_gradients=None, no_grad_set=None -): +def create_backward_prune_set(inputs, outputs, no_grad_set, state): + outputs_set = set() + for input in inputs: + for item in input.first_use().owner().operands_source(): + if state.value_to_valuegrad[item] != []: + outputs_set.add(state.value_to_valuegrad[item][0][0]) + inputs_set = set() + for output in outputs: + if state.value_to_valuegrad[output] != []: + inputs_set.add(state.value_to_valuegrad[output][0][0]) + + inputs_set_tmp = set() + for out_grad in inputs_set: + if not out_grad.use_empty(): + for item in out_grad.first_use().owner().operands_source(): + inputs_set_tmp.add(item) + inputs_set.update(inputs_set_tmp) + + no_gradvar_set = set() # grad_value of value in no_grad_set + for key in state.value_to_valuegrad: + if key in no_grad_set: + no_gradvar_set.add(state.value_to_valuegrad[key][0][0]) + + for key in state.value_to_sumvaluegrad: + if key in no_grad_set: + for item in state.value_to_sumvaluegrad[key][0]: + no_gradvar_set.add(item) + + return outputs_set, inputs_set, no_gradvar_set + + +def remove_op(block, op, state): ''' - Calculate gradient and return grad_info_map + remove op from block ''' - targets = _as_list(targets) - inputs = _as_list(inputs) - target_gradients = _as_list(target_gradients) - - block = targets[0].block - prog = block.program - # increase appending gradients times - prog._appending_grad_times += 1 - block_idx = block.idx - - if not target_gradients: - target_gradients = [None] * len(targets) - - if len(targets) != len(target_gradients): - raise ValueError( - "Should have the same number of target_gradients as targets" - ) - - if no_grad_set is None: - no_grad_set = set() - else: - no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set)) - no_grad_dict = _get_stop_gradients_(prog) - no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set))) + block.remove_op(op) + if state.opgrad_to_op[op] != []: + fwd_op = state.opgrad_to_op[op][0] + 
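# --- Illustrative sketch only (not part of the patch) ---
# The driver loop above walks the forward ops in reverse topological order
# and, for each op that has a vjp rule, calls it with the collected output
# grads and per-input stop-gradient flags, then maps the returned grads back
# onto the op's inputs. Toy stand-in with a hand-written vjp table; the real
# path goes through paddle.framework.core.has_vjp / call_vjp.
def toy_backward(ops_reversed, vjp_table, value_to_grad):
    for name, ins, outs in ops_reversed:
        rule = vjp_table.get(name)
        if rule is None:
            continue                                  # op contributes no grads
        out_grads = [value_to_grad.get(o, 0.0) for o in outs]
        in_grads = rule(out_grads)                    # one grad per input
        for v, g in zip(ins, in_grads):
            value_to_grad[v] = value_to_grad.get(v, 0.0) + g   # accumulate
    return value_to_grad

# y = 3 * x  =>  dy/dx = 3
vjp_table = {"scale": lambda out_grads: [3.0 * out_grads[0]]}
ops = [("scale", ["x"], ["y"])]
print(toy_backward(list(reversed(ops)), vjp_table, {"y": 1.0}))  # {'y': 1.0, 'x': 3.0}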
state.op_to_opgrad[fwd_op].remove(op) - fwd_op_num = block.desc.op_size() + for valuegrad in op.results(): + if state.valuegrad_to_value[valuegrad] != []: + value = state.valuegrad_to_value[valuegrad][0] + state.value_to_valuegrad[value] = [] - input_grad_names_set = set() - - target_grad_map = {} - rename_var_map = {} - skip_rename_var_list = [] - grad_name_set = set() - for i, grad in enumerate(target_gradients): - target = targets[i] - grad_name = _append_grad_suffix_(target.name) - if grad is None: - op_desc = _create_op_desc_( - "fill_any_like", - {"X": [target.name]}, - {"Out": [grad_name]}, - { - "value": 1.0, - "dtype": target.dtype, - }, - ) - block.desc.append_op().copy_from(op_desc) - block.program._sync_with_cpp() - input_grad_names_set.add(grad_name) - skip_rename_var_list.append(grad_name) - else: - if target.block.idx != block_idx or target.block.program != prog: - raise ValueError("all targets must be in the same block") - if target.shape != grad.shape: + if value in state.sumvaluegrad_to_value: raise ValueError( - "The shapes of target and grad are different: %s %s" - % (target.name, grad.name) + 'input_grad in [%s] is value which need to sum ', op.name() ) - target_grad_map[_append_grad_suffix_(target.name)] = grad.name - input_grad_names_set.add(grad.name) - rename_var_map[grad_name] = grad.name - - grad_name_set.add(grad_name) - if core._is_bwd_prim_enabled(): - core._set_prim_target_grad_name(target_grad_map) - # For double backward, input_grad_names is used for filter - # some non-used gradients op. rename_var_map is used to - # associate target_grad var name with first grad_op input name. - if prog._appending_grad_times == 1: - input_grad_names_set = None - rename_var_map = {} - for input in inputs: - if input.block.program != prog: - raise "input must be in the same program as targets" - block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0])) +def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): + block = outputs[0].get_defining_op().get_parent_block() + state = State(block.get_parent_program()) + # check all inputs and outputs in the same block + check_all_puts(block, inputs, outputs) + # update no_grad_set if some value stop_gradient=True + update_no_grad_set_by_stopgradient(block, no_grad_set) + complete_outputs, _, backward_ops = prepare_grad_outputs( + block, + grad_outputs, + outputs, + state.value_to_valuegrad, + state.op_to_opgrad, + ) - op_path_dict = dict() - op_path = _find_op_path_( - block, targets, inputs, block_no_grad_set, op_path_dict + inputs_set = set(inputs) + outputs_set = set(complete_outputs) + effective_forward_op, _ = prune_ops( + block.ops, inputs_set, outputs_set, no_grad_set + ) + update_no_grad_set_after_prune( + block, effective_forward_op, no_grad_set, inputs, complete_outputs ) - # only for composite to add grad_var of the last forward op - # who has more than one output, but targets only has one, - # so targets_gradients only add one grad_var, - # eg: op1 -> op2 -> var1 / var2 targets = var1, - # targets_gradients = var1_grad, need to add var2_grad here. - tmp_targets = targets + inverse_effective_forward_op = inverse_sort_op(effective_forward_op) - if core._is_bwd_prim_enabled(): - for op in reversed(block.ops): - if op.type == "fill_any_like": - continue - # Some outputs of composite op are not needed and will be removed. - # Thus, those vars should not be added with another op. 
- keep_var_list = [] - if op.type in core.ops_contain_none.keys(): - values = core.ops_contain_none[op.type] - if isinstance(values, list): - none_vars = values - else: - none_vars = values(op) - for none_var_name in none_vars: - keep_var_list.append(op.output(none_var_name)[0]) - - for var_name in op.desc.output_arg_names(): - if keep_var_list and (var_name in keep_var_list): - continue - grad_var_name = _append_grad_suffix_(var_name) - if grad_var_name not in grad_name_set: - op_desc = _create_op_desc_( - "fill_any_like", - {"X": [var_name]}, - {"Out": [grad_var_name]}, - {'value': 0, 'dtype': targets[0].dtype}, - ) - block.desc.append_op().copy_from(op_desc) - tmp_targets.append(block.var(var_name)) - break - block.program._sync_with_cpp() - - # find no grad var by op_path - no_grad_vars = _find_no_grad_vars( - block, op_path, tmp_targets, block_no_grad_set + append_backward_ops( + block, inverse_effective_forward_op, no_grad_set, backward_ops, state ) - block_no_grad_set.update(no_grad_vars) + # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) - no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set))) - grad_to_var = dict() - grad_info_map = dict() - _append_backward_ops_( - block, - op_path, - targets, - block, - no_grad_dict, - grad_to_var, - input_grad_names_set=input_grad_names_set, - op_path_dict=op_path_dict, - rename_var_map=rename_var_map, + outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( + inputs, complete_outputs, no_grad_set, state ) - - # Because calc_gradient may be called multiple times, - # we need rename the internal gradient variables so that they have - # different names. - _rename_grad_( - block, fwd_op_num, grad_to_var, target_grad_map, skip_rename_var_list + _, remove_ops = prune_ops( + backward_ops, inputs_set, outputs_set, no_gradvar_set ) - _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map) - prog._sync_with_cpp() - return grad_info_map + state.turn_map() + for bwd_op in inverse_sort_op(remove_ops): + remove_op(block, bwd_op, state) + state.turn_map() + input_grad_map = state.value_to_valuegrad -def _get_grad_vars(grad_info_map, inputs): - inputs = _as_list(inputs) - grad_vars = [] - for input_var in inputs: - if input_var.name not in grad_info_map: - grad_vars.append(None) - else: - grad_info = grad_info_map[input_var.name] - grad_block = grad_info[1] - grad_var = grad_block.var(grad_info[0]) - grad_vars.append(grad_var) - return grad_vars + return input_grad_map -def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): +def calc_gradient(outputs, inputs, grad_outputs, no_grad_set): """ - Backpropagate the gradients of targets to inputs. + caclulate gradient of input Args: - targets(Tensor|list[Tensor]|tuple[Tensor]): The target Tensors - inputs(Tensor|list[Tensor]|tuple[Tensor]): The input Tensors - target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensors - of targets which has the same shape with targets, If None, ones will - be created for them. - no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients - should be ignored. All Tensors with - `stop_gradient=True` from all blocks will - be automatically added into this set. - If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set. - Default: None. 
+ outputs (Value|list(Value)|tuple(Value)): the output Value or + Value list/tuple of the graph to compute gradients. + inputs (Value|list(Value)|tuple(Value)): the input Value or + Value list/tuple of the graph to compute gradients. The returned + values of this API are the gradients of `inputs` . + grad_outputs (Value|list(Value|None)|tuple(Value|None), optional): + initial gradient values of `outputs` . If `grad_outputs` is None, + the initial gradient values of `outputs` would be Values filled with 1; + if `grad_outputs` is not None, it must have the same length as `outputs` , + and in this case, the initial gradient value of the i-th `outputs` would + be: (1) a Value filled with 1 when the i-th element of `grad_outputs` + is None; (2) the i-th element of `grad_outputs` when the i-th element of + `grad_outputs` is a Value. Default None. + no_grad_set (set(Value), optional): + the Values whose gradients are not needed to compute. Default None. Return: - (list[Tensor]): A list of gradients for inputs + list[Value]:A list of gradients for inputs If an input does not affect targets, the corresponding gradient Tensor will be None + TODO if allow_unused=False raise TypeError() if input_grad has None """ - - # NOTE: If you want to modify the logic of calc_gradient, please modify - # it inside the calc_gradient_helper and _get_grad_vars functions - # to ensure the correctness of dy2st mode. - grad_info_map = calc_gradient_helper( - targets, - inputs, - target_gradients=target_gradients, - no_grad_set=no_grad_set, + # record input value and its gradient (Value to Value) + input_to_inputgrad_map = calc_gradient_helper( + outputs, inputs, grad_outputs=grad_outputs, no_grad_set=no_grad_set ) - grad_vars = _get_grad_vars(grad_info_map, inputs) - - if len(grad_vars) == 1: - return grad_vars[0] - else: - return grad_vars - - -@framework.static_only -def gradients(targets, inputs, target_gradients=None, no_grad_set=None): - """ - - Backpropagate the gradients of targets to inputs. - - Args: - targets (Tensor|list[Tensor]|tuple[Tensor]): The target Tensors. - inputs (Tensor|list[Tensor]|tuple[Tensor]): The input Tensors. - target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensor - of targets which has the same shape with targets, If None, ones will - be created for them. - no_grad_set (set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients - should be ignored. All Tensors with ``stop_gradient=True`` from all blocks will - be automatically added into this set. If this parameter is not None, the Tensors or Tensor.names - in this set will be added to the default set. Default: None. - - Return: - (list[Tensor]): A list of gradients for inputs - If an input does not affect targets, the corresponding gradient Tensor - will be None. - - Examples: + inputgrad = [] + for input in inputs: + inputgrad.append( + input_to_inputgrad_map[input][0][0] + if input_to_inputgrad_map[input] != [] + else None + ) + return inputgrad + + +def grad( + outputs, + inputs, + grad_outputs=None, + retain_graph=None, + create_graph=False, + only_inputs=True, + allow_unused=False, + no_grad_vars=None, +): + ''' + .. note:: + **This API is ONLY available in imperative mode.** - .. code-block:: python - :name: code-example - import paddle - import paddle.nn.functional as F + This API computes the sum of gradients of `outputs` with respect to each `inputs` . 
- paddle.enable_static() + Parameters: + outputs (Value|list(Value)|tuple(Value)): the output Value or + Value list/tuple of the graph to compute gradients. + inputs (Value|list(Value)|tuple(Value)): the input Value or + Value list/tuple of the graph to compute gradients. The returned + values of this API are the gradients of `inputs` . + grad_outputs (Value|list(Value|None)|tuple(Value|None), optional): + initial gradient values of `outputs` . If `grad_outputs` is None, + the initial gradient values of `outputs` would be Values filled with 1; + if `grad_outputs` is not None, it must have the same length as `outputs` , + and in this case, the initial gradient value of the i-th `outputs` would + be: (1) a Value filled with 1 when the i-th element of `grad_outputs` + is None; (2) the i-th element of `grad_outputs` when the i-th element of + `grad_outputs` is a Value. Default None. + retain_graph (bool, optional): whether to retain the forward graph which + is used to calculate the gradient. When it is True, the graph would + be retained, in which way users can calculate backward twice for the + same graph. When it is False, the graph would be freed. Default None, + which means it is equal to `create_graph` . + create_graph (bool, optional): whether to create the gradient graphs of + the computing process. When it is True, higher order derivatives are + supported to compute; when it is False, the gradient graphs of the + computing process would be discarded. Default False. + only_inputs (bool, optional): whether to only compute the gradients of + `inputs` . If it is False, the gradients of all remaining leaf + Values in the graph would be also computed and accumulated. + If it is True, only the gradients of `inputs` would be computed. + Default True. only_inputs=False is under development, and it is + not supported yet. + allow_unused (bool, optional): whether to raise error or return None if some + Values of `inputs` are unreachable in the graph. If some Values of + `inputs` are unreachable in the graph (i.e., their gradients are None), + error would be raised if allow_unused=False, or None would be returned as + their gradients if allow_unused=True. Default False. + no_grad_vars (Value|list(Value)|tuple(Value)|set(Value), optional): + the Values whose gradients are not needed to compute. Default None. - x = paddle.static.data(name='x', shape=[None, 2, 8, 8], dtype='float32') - x.stop_gradient=False - y = paddle.static.nn.conv2d(x, 4, 1, bias_attr=False) - y = F.relu(y) - z = paddle.static.gradients([y], x) - print(z) # [var x@GRAD : LOD_TENSOR.shape(-1, 2, 8, 8).dtype(float32).stop_gradient(False)] - """ + Returns: + list: a list of Values, whose length is the same as the Value number + inside `inputs`, and the i-th returned Value is the sum of gradients of + `outputs` with respect to the i-th `inputs`. 
+ ''' check_type( - targets, - 'targets', - (framework.Variable, list, tuple), - 'paddle.static.gradients', + outputs, + 'outputs', + ((paddle.ir.Value, paddle.ir.OpResult), list, tuple), + 'paddle.ir.grad', ) check_type( inputs, 'inputs', - (framework.Variable, list, tuple), - 'paddle.static.gradients', + ((paddle.ir.Value, paddle.ir.OpResult), list, tuple), + 'paddle.ir.grad', ) check_type( - target_gradients, - 'target_gradients', - (framework.Variable, list, tuple, type(None)), - 'paddle.static.gradients', + grad_outputs, + 'grad_outputs', + ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, type(None)), + 'paddle.ir.grad', ) - outs = calc_gradient(targets, inputs, target_gradients, no_grad_set) - return _as_list(outs) - - -@framework.static_only -def gradients_with_optimizer(program, optimizer, inputs=None, outputs=None): - """ - :api_attr: Static Graph - - Backpropagate the gradients of the program and apply the gradients with the given optimizer. - - Args: - program (Program): The input program. - optimizer (Optimizer): The optimizer to apply the gradients. - inputs (Tensor|list[Tensor]|tuple[Tensor], optional): The input Tensors. - If None, the inputs will be created from the input variables in the given program. Default:None. - outputs (Tensor|list[Tensor]|tuple[Tensor], optional): The output Tensors. - If None, the outputs will be created from the output variables in the given program. Default: None. - Return: - tuple: tuple (optimize_ops, params_grads), A list of operators appended - by gradients_with_optimizer and a list of (param, grad) variable pairs, param is - ``Parameter``, grad is the gradient value corresponding to the parameter. - The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to - indicate program pruning. If so, the program will be pruned by ``feed`` and - ``fetch_list`` before run, see details in ``Executor``. - - Examples: - .. 
code-block:: python - - import paddle - import paddle.static as static - - paddle.enable_static() - - img = static.data(name='image', shape=[None, 784]) - pred = static.nn.fc(x=img, size=10, activation='relu') - loss = paddle.mean(pred) - opt_ops, pram_grads = paddle.fluid.backward.gradients_with_optimizer(static.default_main_program(), opt) - print(opt_ops) - - """ check_type( - program, - 'program', - paddle.fluid.Program, - 'paddle.static.gradients_with_optimizer', - ) - check_type( - optimizer, - 'optimizer', - paddle.optimizer.Optimizer, - 'paddle.static.gradients_with_optimizer', + no_grad_vars, + 'no_grad_vars', + ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, set, type(None)), + 'paddle.ir.grad', ) + outputs = _as_list(outputs) + inputs = _as_list(inputs) + grad_outputs = _as_list(grad_outputs) + if no_grad_vars is None: + no_grad_set = set() + elif no_grad_vars is not set: + no_grad_set = set(no_grad_vars) + else: + no_grad_set = no_grad_vars + + input_grad = calc_gradient(outputs, inputs, grad_outputs, no_grad_set) - if inputs is None or outputs is None: - in_set = set() - out_set = set() - for block in program.blocks: - for op in block.ops: - for name in op.input_arg_names: - in_set.add(block.vars[name]) - for name in op.output_arg_names: - out_set.add(block.vars[name]) - if inputs is None: - inputs = list(in_set.difference(out_set)) - if outputs is None: - outputs = list(out_set.difference(in_set)) - - grads = gradients(outputs, inputs) - - with program_guard(program, None): - pram_grads = [ - (pram, grad) - for pram, grad in zip(inputs, grads) - if isinstance(pram, paddle.fluid.framework.Parameter) - and grad is not None - ] - - optimize_ops = optimizer.apply_gradients(pram_grads) - - return optimize_ops, pram_grads + return input_grad diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 0a6246e74e8b2..cdb464bc5c35d 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1977,7 +1977,12 @@ def split(x, num_or_sections, axis=0, name=None): return _C_ops.split(input, num_or_sections, dim) else: if paddle.ir.core._use_new_ir_api(): - return paddle._ir_ops.split(input, num_or_sections, dim) + if not isinstance(num_or_sections, int): + return paddle._ir_ops.split(input, num_or_sections, dim) + else: + raise NotImplementedError( + "_ir_ops.split_with_num is not implemented, please change sections as list" + ) check_variable_and_dtype( input, diff --git a/test/ir/new_ir/test_build_op.py b/test/ir/new_ir/test_build_op.py index d7b8b1150e27b..16bc1adb0628e 100644 --- a/test/ir/new_ir/test_build_op.py +++ b/test/ir/new_ir/test_build_op.py @@ -127,8 +127,7 @@ def test_build_split_op(self): tanh_out = newir_program.block().ops[-1].result(0) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) with paddle.ir.core.program_guard(newir_program): - out = paddle.split(tanh_out, [1, 1], 0) - print(newir_program) + out = paddle.split(tanh_out, [2, 2], 0) self.assertEqual(out[0].get_defining_op().name(), "builtin.split") self.assertEqual( out[0] diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py index 2b293f64bb8b8..84b6975434621 100644 --- a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/new_ir/test_ir_backward.py @@ -59,7 +59,6 @@ def test_grad(self): ) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) - ''' def test_full(self): # test create output_grad in backward use full op newir_program = get_ir_program_0() @@ -182,8 +181,6 @@ def 
test_concat(self): paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) -''' - if __name__ == "__main__": unittest.main() From a6b5af3298f22db04c5118bd807b779e92323fe0 Mon Sep 17 00:00:00 2001 From: wangruting Date: Thu, 24 Aug 2023 06:52:02 +0000 Subject: [PATCH 26/30] fix conflict --- python/paddle/autograd/backward.py | 35 ++++++++++-------------------- test/ir/new_ir/test_ir_backward.py | 18 ++++++++++++--- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py index e631c02b0bd3e..5bf723be06c1b 100644 --- a/python/paddle/autograd/backward.py +++ b/python/paddle/autograd/backward.py @@ -204,22 +204,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): else: relevant_op_flags[i] = False - # recover full op or full_Intarray op created by mutable attribute. - total_ops_list = list(total_ops) - for i, op in enumerate(total_ops_list): - if relevant_op_flags[i] is False: - for result in op.results(): - if result.has_one_use(): - next_op = result.first_use().owner() - if ( - next_op in total_ops - and relevant_op_flags[total_ops_list.index(next_op)] - is True - ): - relevant_op_flags[i] = True - else: - continue - effective_ops = [ total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i] ] @@ -356,12 +340,16 @@ def append_backward_ops( def make_output_grad(op): zero_flag = [False] * op.num_results() + output_grads = [] for i, value in enumerate(op.results()): if ( value not in state.value_to_valuegrad or state.value_to_valuegrad[value] is None ): - if value.first_use().owner().name() == "builtin.split": + if ( + not value.use_empty() + and value.first_use().owner().name() == "builtin.split" + ): # pattern case: # this fwd_op's output is vectorType, it will split to # Type by builtin.split op, so need get from split op's ouput @@ -369,7 +357,7 @@ def make_output_grad(op): value.first_use().owner() ) zero_flag[i] = all(split_zero_flag) - grad_value = [op_list[0] for op_list in split_output_grad] + state.value_to_valuegrad[value] = [split_output_grad] else: # first case: # this fwd_op's output didn't used by other fwd_op, @@ -390,7 +378,7 @@ def make_output_grad(op): ) zero_flag[i] = True - state.value_to_valuegrad[value] = [[grad_value]] + state.value_to_valuegrad[value] = [[grad_value]] if len(state.value_to_valuegrad[value]) > 1: # one value is input of more than one fwd_op, @@ -413,8 +401,8 @@ def make_output_grad(op): value ] - output_grad = state.value_to_valuegrad[value][0] - return zero_flag, output_grad + output_grads.append(state.value_to_valuegrad[value][0][0]) + return zero_flag, output_grads def make_input_stopgradient(op): input_grad_stopgradient_list = [] @@ -530,8 +518,9 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state): inputs_set_tmp = set() for out_grad in inputs_set: - for item in out_grad.first_use().owner().operands_source(): - inputs_set_tmp.add(item) + if not out_grad.use_empty(): + for item in out_grad.first_use().owner().operands_source(): + inputs_set_tmp.add(item) inputs_set.update(inputs_set_tmp) no_gradvar_set = set() # grad_value of value in no_grad_set diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py index 59a901f14f5e8..be29baa1069d2 100644 --- a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/new_ir/test_ir_backward.py @@ -31,7 +31,6 @@ def get_ir_program_0(): x_s = paddle.static.data('x', [4, 4], x.dtype) x_s.stop_gradient = False k_s = paddle.tanh(x_s) - print("old ir prog: ", 
main_program) newir_program = ir.translate_to_new_ir(main_program.desc) return newir_program @@ -102,10 +101,23 @@ def test_split(self): tanh_out = newir_program.block().ops[-1].result(0) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) with paddle.ir.core.program_guard(newir_program): - out = paddle.split(tanh_out, [1, 1], 0) + out = paddle.split(tanh_out, [2, 2], 0) input_grad = grad(out, input) - print(newir_program) + ops_name = [ + "pd.data", + "pd.tanh", + "pd.full_int_array", + "pd.full", + "pd.split", + "builtin.split", + "pd.full", + "builtin.combine", + "pd.split_grad", + "pd.tanh_grad", + ] + for i, op in enumerate(newir_program.block().ops): + self.assertEqual(op.name(), ops_name[i]) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) From 2fc525f5429d2ec9eb548cdf228f07eeb6b52597 Mon Sep 17 00:00:00 2001 From: wangruting Date: Thu, 24 Aug 2023 06:55:16 +0000 Subject: [PATCH 27/30] fluid backward recover --- python/paddle/fluid/backward.py | 3180 +++++++++++++++++++++++++------ 1 file changed, 2596 insertions(+), 584 deletions(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 5bf723be06c1b..9b09ec11cd3ab 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,733 +12,2745 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .proto import framework_pb2 + +from paddle.fluid import framework as framework +from paddle.fluid import program_guard +from . import core import collections +import copy +import logging +from . import unique_name +from . import log_helper +import paddle.fluid +from .data_feeder import check_type +import warnings + from collections.abc import Sequence -import paddle.ir -from paddle.autograd.backward_utils import State +import re -""" - grad: for templete test, will combine in paddle.grad . - calc_gradient: for internal use, optest, parallel etc . - calc_gradient_helper: for dygraph to static . -""" -__all__ = ['grad', 'calc_gradient', 'calc_gradient_helper'] +__all__ = [ + 'append_backward', + 'gradients', +] +_logger = log_helper.get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) -def check_type(input, input_name, expected_type, op_name, extra_message=''): - if not isinstance(input, expected_type): - raise TypeError( - f"The type of '{input_name}' in {op_name} must be {expected_type}, but received {type(input)}. 
{extra_message}" - ) +class ProgramStats: + def __init__(self, block, ops): + self.block = block + self.ops = ops + self.op_deps = {} # op-> in_ops, out_ops + self.var_op_deps = {} # var as input op, var as output op -def _as_list(x): - if x is None: + def get_input_nodes(self): + input_names = [] + for name in self.var_op_deps: + if ( + len(self.var_op_deps[name]["var_as_output_ops"]) == 0 + and len(self.var_op_deps[name]["var_as_input_ops"]) > 0 + ): + if self.block.var(name).persistable: + continue + input_names.append(name) + for op in self.ops: + if op.desc.type() == "read": + input_names.extend(op.desc.output_arg_names()) + return input_names + + def get_reserved_vars(self): + var_name = [] + for op in self.ops: + if op.desc.type() == "seed": + var_name.extend(op.desc.output_arg_names()) + return var_name + + def get_out_of_subgraph_vars(self, begin_op_idx, end_op_idx): + var_name = [] + for i in range(begin_op_idx, end_op_idx, 1): + for name in self.ops[i].desc.output_arg_names(): + if name in self.var_op_deps: + for idx in self.var_op_deps[name]["var_as_input_ops"]: + if idx >= end_op_idx: + var_name.append(name) + for name in self.ops[i].desc.input_arg_names(): + if name in self.var_op_deps: + for idx in self.var_op_deps[name]["var_as_output_ops"]: + if idx < begin_op_idx: + var_name.append(name) + return var_name + + def is_subgraph(self, var_group1, var_group2): + # should traverse from var_group1 to var_group2 + # max op idx in var_group2 + # min op idx in var_group1 + min_op_idx = len(self.ops) + max_op_idx = -1 + for name in var_group1: + if name not in self.var_op_deps: + return False, min_op_idx, max_op_idx + for name in var_group2: + if name not in self.var_op_deps: + return False, min_op_idx, max_op_idx + for name in var_group1: + op_idx = self.var_op_deps[name]["var_as_input_ops"] + for idx in op_idx: + min_op_idx = min(min_op_idx, idx) + for name in var_group2: + op_idx = self.var_op_deps[name]["var_as_output_ops"] + for idx in op_idx: + max_op_idx = max(max_op_idx, idx) + if min_op_idx >= max_op_idx: + return False, min_op_idx, max_op_idx + + return True, min_op_idx, max_op_idx + + def _update_segment_start(self, min_idx, pre_segment_end_idx): + """ + persist vars of amp-related cast should be included in recompute segment + """ + + def is_amp_cast(op): + return ( + op.desc.type() == 'cast' + and self.block.var(op.desc.input_arg_names()[0]).persistable + ) + + idx_ = min_idx - 1 + updated_min_idx = min_idx + while idx_ > pre_segment_end_idx: + if is_amp_cast(self.ops[idx_]): + _logger.info( + "found amp-cast op: {}, : {}".format( + self.ops[idx_].desc.type(), + self.ops[idx_].desc.input_arg_names()[0], + ) + ) + updated_min_idx = idx_ + idx_ -= 1 + else: + break + + return updated_min_idx + + def build_stats(self): + for i, op in enumerate(self.ops): + self.op_deps[i] = {"in_ops": [], "out_ops": []} + for j, name in enumerate(op.desc.input_arg_names()): + if name in self.var_op_deps: + self.op_deps[i]["in_ops"].extend( + self.var_op_deps[name]["var_as_output_ops"] + ) + for j, name in enumerate(op.desc.input_arg_names()): + if name in self.var_op_deps: + self.var_op_deps[name]["var_as_input_ops"].extend([i]) + else: + self.var_op_deps[name] = {} + self.var_op_deps[name]["var_as_input_ops"] = [i] + self.var_op_deps[name]["var_as_output_ops"] = [] + + for j, name in enumerate(op.desc.output_arg_names()): + if name in self.var_op_deps: + self.var_op_deps[name]["var_as_output_ops"].extend([i]) + else: + self.var_op_deps[name] = {} + 
self.var_op_deps[name]["var_as_input_ops"] = [] + self.var_op_deps[name]["var_as_output_ops"] = [i] + + for op_idx in self.op_deps[i]["in_ops"]: + self.op_deps[op_idx]["out_ops"].extend([i]) + + def sort_checkpoints(self, checkpoints_name): + sorted_checkpoints = [] + for name in checkpoints_name: + if name not in self.var_op_deps: + _logger.info( + "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." + % name + ) + elif self.var_op_deps[name]["var_as_output_ops"] == []: + # input nodes + sorted_checkpoints.append((name, -1)) + else: + sorted_checkpoints.append( + (name, max(self.var_op_deps[name]["var_as_output_ops"])) + ) + sorted_checkpoints = sorted(sorted_checkpoints, key=lambda x: x[1]) + return [x[0] for x in sorted_checkpoints] + + def modify_forward_desc_for_recompute(self): + op_types = [op.desc.type() for op in self.ops] + if "dropout" not in op_types: + return + + op_idx = 0 + while op_idx < len(self.ops): + op = self.ops[op_idx] + if op.desc.type() != "dropout": + op_idx += 1 + continue + # already insert seed op before dropout + if op.input('Seed') is not None and len(op.input('Seed')) == 1: + op_idx += 1 + continue + # add a seed op so that the two dropout op can generate same output + op_unique_name = unique_name.generate("seed") + var_unique_name = unique_name.generate_with_ignorable_key( + ".".join([op_unique_name, 'tmp']) + ) + added_var = self.block.create_var( + name=var_unique_name, + dtype='int32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False, + ) + seed = 0 if op.attr("fix_seed") is False else int(op.attr("seed")) + + op_device_attr_name = ( + core.op_proto_and_checker_maker.kOpDeviceAttrName() + ) + op_device = "" + if op.desc.has_attr(op_device_attr_name): + op_device = op.desc.attr(op_device_attr_name) + + # Setting the force_cpu of seed to true will make the output of seed in cpu memory, + # reduce the synchronous copy from GPU to CPU in dropout, and reduce the communication hang + added_op = self.block._insert_op( + index=op.idx, + type='seed', + inputs={}, + outputs={'Out': [added_var]}, + attrs={'seed': seed, 'op_device': op_device, 'force_cpu': True}, + ) + self.ops.insert(op_idx, added_op) + # modify dropout op desc so that it accept a seed var as input + op.desc.set_input("Seed", [var_unique_name]) + op.desc.remove_attr("fix_seed") + op.desc.remove_attr("seed") + self.block._sync_with_cpp() + op_idx += 2 + + +def _pretty_op_desc_(op_desc, prefix): + out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % ( + prefix + "_op", + str(op_desc.type()), + prefix + "_input", + " ".join(op_desc.input_arg_names()), + prefix + "_output", + " ".join(op_desc.output_arg_names()), + ) + return out_s + + +def _add_needed_descs_to_block( + descs, block, main_block, in_memory_vars, grad_op_id_to_fwd_op=None +): + if len(descs) == 0: return [] - return list(x) if isinstance(x, Sequence) else [x] + result_descs = [] + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() + backward = core.op_proto_and_checker_maker.OpRole.Backward + for desc in descs: + origin_desc = desc + origin_is_operator = False + if isinstance(desc, framework.Operator): + desc = desc.desc + origin_is_operator = True + if isinstance(desc, tuple): + desc = desc[0] + is_needed = False + for name in desc.output_arg_names(): + if main_block.has_var(name) and main_block.var(name).persistable: + continue + if name not in in_memory_vars: + is_needed = True + if is_needed: + if origin_is_operator and 
grad_op_id_to_fwd_op is not None: + grad_op_id_to_fwd_op[desc.original_id()] = origin_desc + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(desc) + new_op_desc._set_attr(op_role_attr_name, backward) + if desc.has_attr('op_device'): + new_op_desc._set_attr('op_device', desc.attr('op_device')) + result_descs.append(new_op_desc) + return result_descs + + +def _add_descs_to_block(descs, block, grad_op_id_to_fwd_op=None): + if len(descs) == 0: + return [] + result_descs = [] + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() + backward = core.op_proto_and_checker_maker.OpRole.Backward + for desc in descs: + if isinstance(desc, framework.Operator): + # for recompute, should record recompute ops + if grad_op_id_to_fwd_op is not None: + grad_op_id_to_fwd_op[desc.desc.original_id()] = desc + desc = desc.desc + if isinstance(desc, tuple): + desc = desc[0] + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(desc) + new_op_desc._set_attr(op_role_attr_name, backward) + if desc.has_attr('op_device'): + new_op_desc._set_attr('op_device', desc.attr('op_device')) + result_descs.append(new_op_desc) + return result_descs + + +def _find_loss_op_(loss): + for op in reversed(loss.block.ops): + assert isinstance(op, framework.Operator) + if ( + len(op.output_arg_names) == 1 + and op.output_arg_names[0] == loss.name + ): + loss.op = op + break + if loss.op is None: + raise ValueError("loss.op is None. Should not happen") + + +def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): + """ + Traverse all ops in op_descs[begin_idx : end_idx], + if any op has inputs/outputs named "old_name", rename it as 'new_name' + """ + if begin_idx is None: + begin_idx = 0 + if end_idx is None: + end_idx = len(op_descs) + if isinstance(op_descs, (list, tuple)): + for i in range(begin_idx, end_idx): + op_desc = op_descs[i] + if isinstance(op_desc, tuple): + op_desc = op_desc[0] + op_desc._rename_input(old_name, new_name) + op_desc._rename_output(old_name, new_name) + if isinstance(op_descs, collections.OrderedDict): + for key, value in op_descs.items(): + if isinstance(value, (list, tuple)): + for op_desc in value: + op_desc._rename_input(old_name, new_name) + op_desc._rename_output(old_name, new_name) + + +def _create_op_desc_(op_type, inputs, outputs, attrs): + """ + Create a C++ OpDesc object with specified inputs, outputs and attributes. 
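A hedged illustration of this helper's calling convention (not part of the patch): building the desc of the `sum` accumulation op that `_accumulate_gradients_by_sum_op_` further down emits. The gradient names are made up, and a Paddle build is required since the call goes through `core.OpDesc`.

    from paddle.fluid.backward import _create_op_desc_

    sum_desc = _create_op_desc_(
        "sum",
        {"X": ["x@GRAD@RENAME@block0@0", "x@GRAD@RENAME@block0@1"]},
        {"Out": ["x@GRAD"]},
        {"use_mkldnn": False, "op_device": ""},
    )
    print(sum_desc.type())          # sum
    print(sum_desc.output("Out"))   # ['x@GRAD']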
+ """ + op_desc = core.OpDesc() + op_desc.set_type(op_type) + for para, args in inputs.items(): + op_desc.set_input( + para, + list( + map( + lambda arg: arg.decode() if isinstance(arg, bytes) else arg, + args, + ) + ), + ) + for para, args in outputs.items(): + op_desc.set_output( + para, + list( + map( + lambda arg: arg.decode() if isinstance(arg, bytes) else arg, + args, + ) + ), + ) + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() + op_device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + + if op_role_attr_name not in attrs: + attrs[ + op_role_attr_name + ] = core.op_proto_and_checker_maker.OpRole.Backward + if op_device_attr_name not in attrs: + attrs[op_device_attr_name] = "" + for name, val in attrs.items(): + if isinstance(val, framework.Block): + op_desc.set_block_attr(name, val.desc) + else: + op_desc._set_attr(name, val) + return op_desc -def check_all_puts(block, inputs, outputs): - for output in outputs: - if output.get_defining_op().get_parent_block() != block: - raise ValueError("all outputs must be in the same block") - for input in inputs: - if input.get_defining_op().get_parent_block() != block: - raise ValueError( - "all inputs must be in the same block with outputs" +def _create_loss_op_desc_(loss): + # 0-D Tensor or 0-Size Tensor + if len(loss.shape) == 0 or 0 in loss.shape: + create_shape = loss.shape + else: + create_shape = [1] + op_desc = _create_op_desc_( + "fill_constant", + {}, + {"Out": [_append_grad_suffix_(loss.name)]}, + { + "shape": create_shape, + "value": 1.0, + "dtype": loss.dtype, + "force_cpu": False, + core.op_proto_and_checker_maker.kOpRoleAttrName(): int( + core.op_proto_and_checker_maker.OpRole.Backward ) + | int(core.op_proto_and_checker_maker.OpRole.Loss), + core.op_proto_and_checker_maker.kOpDeviceAttrName(): loss.op.attr( + core.op_proto_and_checker_maker.kOpDeviceAttrName() + ), + }, + ) + return op_desc -def update_no_grad_set_by_stopgradient(block, no_grad_set): - for op in block.ops: - for opresult_idx in range(op.num_results()): - value = op.result(opresult_idx) - if value.stop_gradient and value not in no_grad_set: - no_grad_set.add(value) +def _infer_var_data_type_shape_(grad_var_name, block): + """ + Infer the data type and shape of given grad variable + """ + grad_var = block.desc.find_var(grad_var_name.encode()) + fwd_name = _strip_grad_suffix_(grad_var_name) + if block.desc.has_var_recursive(fwd_name.encode()): + fwd_var = block.desc.find_var_recursive(fwd_name.encode()) + grad_var.set_dtype(fwd_var.dtype()) + grad_var.set_shape(fwd_var.shape()) + else: + # TODO(jiabin): Maybe we should not to this to cause some unexpected error on dtype + warnings.warn( + "Set grad var: {} dtype to default FP32, since we can't find its related forward var".format( + grad_var_name + ) + ) + grad_var.set_dtype(core.VarDesc.VarType.FP32) -def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op): - backward_ops.append(grad_op) - op_to_opgrad_list.append(grad_op) +def _all_in_set_(cands, s): + """ + Test if all elements of 'cands' are in set 's' + """ + if len(cands) == 0: + return False + for c in cands: + if not c in s: + return False + return True -def prepare_grad_outputs( - block, grad_outputs, outputs, value_to_valuegrad, op_to_opgrad -): +def _some_in_set_(cands, s): + """ + Test if some elements of 'cands' are in set 's' """ - if grad_outputs is none, add fill_1 op to create grad_outputs, - else check whether outputs shape and dtype is same to grad_outputs, otherwise raise error. 
+ if len(cands) == 0: + return False + for c in cands: + if c in s: + return True + return False - if only part of op's outputs in outputs, add fill_0 op to create other grad_outputs. - eg: split. - update value_to_valuegrad and op_to_opgrad. +def _strip_grad_suffix_(name): + """ + Strip the grad suffix from the given variable name + e.g. x@GRAD ==> x + x@GRAD@GRAD ==> x + y@GRAD@RENAME@1 ==> y + z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0 + grad/grad/z@GRAD@RENAME@block0@1@GRAD ==> z + """ + pos = re.search(f'{core.grad_var_suffix()}+@', name) or re.search( + f'{core.grad_var_suffix()}$', name + ) + new_name = name[: pos.start()] if pos is not None else name + new_pos = name.rfind('grad/') + return new_name[new_pos + 5 :] if new_pos != -1 else new_name - return complete_outputs and complete_gradoutputs, backward_ops. +def _append_grad_suffix_(name): + """ + Append grad suffix to the given variable name + e.g. x ==> x@GRAD """ - if not grad_outputs: - grad_outputs = [None] * len(outputs) + return name + core.grad_var_suffix() - if len(grad_outputs) != len(outputs): - raise ValueError( - "grad_outputs should have the same length of as outputs." + +def _accumulate_gradients_by_sum_op_( + var_name, renamed_vars, pending_sum_ops, op_idx, op_device="" +): + """ + Use sum op to accumulate_gradients, the gradients are stored in renamed_vars. + """ + if op_idx not in pending_sum_ops.keys(): + pending_sum_ops[op_idx] = [] + pending_sum_ops[op_idx].append( + _create_op_desc_( + "sum", + {"X": renamed_vars[var_name]}, + {"Out": [var_name]}, + {"use_mkldnn": False, "op_device": op_device}, ) - backward_ops = [] - for i, grad in enumerate(grad_outputs): - output = outputs[i] - # fwd : op1 -> op2 -> op3 -> output - # bwd : op1G <- op2G <- op3G <- outputG <- fillop/feedop - if grad is None: - output_grad = paddle.full( - output.shape, - 1.0, - dtype=output.dtype, - ) - fillop = output_grad.get_defining_op() + ) + renamed_vars[var_name] = [var_name] - update_bwdop_structure( - backward_ops, - op_to_opgrad[output.get_defining_op()], - fillop, - ) - value_to_valuegrad[output] = [[output_grad]] + +def _accumulate_gradients_by_add_ops_( + var_name, renamed_vars, pending_sum_ops, op_idx, op_device="" +): + """ + Use several inplace add op to accumulate_gradients, the gradients are stored in renamed_vars. 
+ """ + if op_idx not in pending_sum_ops.keys(): + pending_sum_ops[op_idx] = [] + out_name = renamed_vars[var_name][0] + for i in range(1, len(renamed_vars[var_name])): + x_name = out_name + y_name = renamed_vars[var_name][i] + if i != len(renamed_vars[var_name]) - 1: + out_name = var_name + '@ADD@' + str(i) else: - if output.shape != grad.shape: - raise ValueError( - "The shape of grad_output[%d] should be the same as the shape of output[%d]" - % (i, i) - ) - if output.dtype != grad.dtype: - raise ValueError( - "The dtype of grad_output[%d] should be the same as the dtype of output[%d]" - % (i, i) - ) - feedop = grad.get_defining_op() - update_bwdop_structure( - backward_ops, op_to_opgrad[output.get_defining_op()], feedop + out_name = var_name + pending_sum_ops[op_idx].append( + _create_op_desc_( + "grad_add", + {"X": [x_name], "Y": [y_name]}, + {"Out": [out_name]}, + {"use_mkldnn": False, "op_device": op_device}, ) - value_to_valuegrad[output] = [[grad]] + ) + renamed_vars[var_name] = [var_name] - # add input for bwd first op - complete_outputs = outputs - complete_gradoutputs = grad_outputs - visited_output = set() - for output in outputs: - if output in visited_output: - continue - for opresult in output.get_defining_op().results(): - if opresult in value_to_valuegrad: - visited_output.add(opresult) +def _addup_repetitive_outputs_( + op_descs, block_idx, grad_var_to_var=None, grad_op_id_to_fwd_op=None +): + """ + In backward part, an variable may be the output of more than one ops. + And one op may yield its multiple outputs to the same variable. + In these cases, the variable should be the accumulation of all the outputs. + `sum_op`s are added to implement the accumulate. + + Args: + grad_var_to_var(dict): used to build the mapping between grad var name and forward var name. + Only for auto parallel. 
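A small standalone sketch (not part of the patch) of the `grad_add` chaining path: each duplicate writer of a gradient gets a `@RENAME@block<idx>@<n>` alias, and the aliases are folded back through intermediate `@ADD@<i>` names; when the number of duplicates exceeds `FLAGS_max_inplace_grad_add`, a single `sum` op over all aliases is emitted instead. `toy_add_chain` is a hypothetical helper that only reproduces the naming, not real OpDescs.

    def toy_add_chain(var_name, renamed):
        """Return (op_type, X, Y, Out) tuples in the order they would be emitted."""
        ops = []
        out_name = renamed[0]
        for i in range(1, len(renamed)):
            x_name, y_name = out_name, renamed[i]
            out_name = var_name if i == len(renamed) - 1 else var_name + '@ADD@' + str(i)
            ops.append(('grad_add', x_name, y_name, out_name))
        return ops

    dups = ['x@GRAD@RENAME@block0@0', 'x@GRAD@RENAME@block0@1', 'x@GRAD@RENAME@block0@2']
    for op in toy_add_chain('x@GRAD', dups):
        print(op)
    # ('grad_add', 'x@GRAD@RENAME@block0@0', 'x@GRAD@RENAME@block0@1', 'x@GRAD@ADD@1')
    # ('grad_add', 'x@GRAD@ADD@1', 'x@GRAD@RENAME@block0@2', 'x@GRAD')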
+ """ + + _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add'] + # pending_sum_ops = [] + pending_sum_ops = collections.OrderedDict() + var_rename_count = collections.defaultdict(int) + renamed_vars = collections.defaultdict(list) + renamed_var_start_idx = collections.defaultdict(list) + var_device = collections.defaultdict(str) + for idx, op_desc in enumerate(op_descs): + op_device_attr_name = ( + core.op_proto_and_checker_maker.kOpDeviceAttrName() + ) + op_device = "" + if op_desc.has_attr(op_device_attr_name): + op_device = op_desc.attr(op_device_attr_name) + for var_name in op_desc.input_arg_names(): + if "@GRAD" not in var_name: continue + if len(renamed_vars[var_name]) > 1: + if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: + _accumulate_gradients_by_sum_op_( + var_name, + renamed_vars, + pending_sum_ops, + idx, + var_device[var_name], + ) + else: + _accumulate_gradients_by_add_ops_( + var_name, + renamed_vars, + pending_sum_ops, + idx, + var_device[var_name], + ) + + for param_idx, param_name in enumerate(op_desc.output_names()): + arg_names = op_desc.output(param_name) + for arg_idx, var_name in enumerate(arg_names): + if "@GRAD" not in var_name: + continue + # if "@RENAME@" in var_name: + # continue + if ( + var_name == core.empty_var_name() + or var_name in op_desc.input_arg_names() + ): + # empty variable or inplace op + continue + if len(renamed_vars[var_name]) == 0: + # it's the first time we get the variable + renamed_vars[var_name] = [var_name] + renamed_var_start_idx[var_name] = idx + else: + if len(renamed_vars[var_name]) == 1: + new_name = ( + var_name + + "@RENAME@block" + + str(block_idx) + + "@" + + str(var_rename_count[var_name]) + ) + var_rename_count[var_name] += 1 + # Build the mapping between the new_name and var_name (Only for auto parallel) + if grad_var_to_var is not None: + if var_name in grad_var_to_var: + grad_var_to_var[new_name] = grad_var_to_var[ + var_name + ] + else: + grad_var_to_var[new_name] = var_name + # rename original var_name + renamed_vars[var_name][0] = new_name + # before change: _rename_arg_(op_descs, var_name, + # new_name, 0, idx) + # rename arg from idx of the first appearance + # in backward, not always from 0 + _rename_arg_( + op_descs, + var_name, + new_name, + renamed_var_start_idx[var_name], + idx, + ) + _rename_arg_(pending_sum_ops, var_name, new_name) + + for p in op_desc.output_names()[:param_idx]: + p_arg_names = op_desc.output(p) + if var_name in p_arg_names: + op_desc.set_output( + p, + [ + new_name if x == var_name else x + for x in p_arg_names + ], + ) + + arg_names = [ + new_name if x == var_name else x + for x in arg_names[:arg_idx] + ] + arg_names[arg_idx:] + + new_name = ( + var_name + + "@RENAME@block" + + str(block_idx) + + "@" + + str(var_rename_count[var_name]) + ) + var_rename_count[var_name] += 1 + # Build the mapping between the new_name and var_name (Only for auto parallel) + if grad_var_to_var is not None: + if var_name in grad_var_to_var: + grad_var_to_var[new_name] = grad_var_to_var[ + var_name + ] + else: + grad_var_to_var[new_name] = var_name + arg_names[arg_idx] = new_name + op_desc.set_output(param_name, arg_names) + renamed_vars[var_name].append(new_name) + # record the latest device + var_device[var_name] = op_device + + for var_name, inputs in renamed_vars.items(): + if len(renamed_vars[var_name]) > 1: + if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: + _accumulate_gradients_by_sum_op_( + var_name, + renamed_vars, + pending_sum_ops, + len(op_descs), + var_device[var_name], + ) else: - 
grad_value = paddle.full( - opresult.shape, - 0.0, - opresult.dtype, + _accumulate_gradients_by_add_ops_( + var_name, + renamed_vars, + pending_sum_ops, + len(op_descs), + var_device[var_name], ) - fillop = grad.get_defining_op() - update_bwdop_structure( - backward_ops, - op_to_opgrad[opresult.get_defining_op()], - fillop, + op_descs_len = len(op_descs) + # sum_op descs are sorted according to their insert position + for key, value in collections.OrderedDict( + reversed(list(pending_sum_ops.items())) + ).items(): + # NOTE(zhiqiu): Since reversed, the idx of op_descs to be inserted will remains correct. + # For example, [0, 1, 2], and we want to insert 'a' at idx 1, 'b' at idx 2, and the expected result is [0, 1, 'a', 2, 'b']. + # If reversed, we first insert 'b' at idx 2, it becomes [0, 1, 2, 'b'], and then insert 'a' at idx 1, it becomes [0, 1, 'a', 2, 'b']. + # If not reverse, we first insert 'a' at idx 1, it becomes [0, 1, 'a', 2], and then insert 'b' at idx 2, it becomes [0, 1, 'a', 'b', 2]. + idx = key + for i, op in enumerate(value): + # update the mapping between fwd and bwd + target_idx = idx - 1 if idx == op_descs_len else idx + i + if ( + grad_op_id_to_fwd_op is not None + and grad_op_id_to_fwd_op.get( + op_descs[target_idx].original_id(), None ) - value_to_valuegrad[opresult] = [grad_value] + is not None + ): + grad_op_id_to_fwd_op[op.original_id()] = grad_op_id_to_fwd_op[ + op_descs[target_idx].original_id() + ] + op_descs.insert(idx + i, op) + + return op_descs + + +def _remove_no_grad_branch_( + op_descs, no_grad_set, grad_op_id_to_fwd_op=None, target_vars=[] +): + """ + Remove unnecessary grad ops + A grad op can be removed in two cases: + 1. all outputs of the grad op are in 'no_grad_set' + 2. all grad inputs of the grad op are in 'no_grad_set' + NOTE: we will skip target_vars's grad name. 
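A toy, pure-Python restatement of the two removal rules (a hypothetical, simplified helper: the real `_op_can_be_removed_` works on `core.OpDesc` objects and also extends `no_grad_set` under rule 2).

    GRAD_SUFFIX = '@GRAD'   # value of core.grad_var_suffix() in practice

    def toy_can_remove(op_inputs, op_outputs, no_grad_set):
        # rule 1: every output gradient of this grad op is unwanted
        if not op_outputs or all(name in no_grad_set for name in op_outputs):
            return True
        # rule 2: every *gradient* input is unwanted, so only zeros could be produced
        grad_inputs = [n for n in op_inputs if GRAD_SUFFIX in n]
        return bool(grad_inputs) and all(n in no_grad_set for n in grad_inputs)

    no_grad = {'w@GRAD'}
    assert toy_can_remove(['x', 'w@GRAD'], ['b@GRAD'], no_grad)        # rule 2
    assert toy_can_remove(['x', 'y@GRAD'], ['w@GRAD'], no_grad)        # rule 1
    assert not toy_can_remove(['x', 'y@GRAD'], ['x@GRAD'], no_grad)    # kept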
+ """ + + def _op_can_be_removed_(op_desc, no_grad_set): + out_arg_names = op_desc.output_arg_names() + if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set): + return True + if _all_in_set_( + [ + name + for name in op_desc.input_arg_names() + if name.find(core.grad_var_suffix()) != -1 + ], + no_grad_set, + ): + no_grad_set.update(set(out_arg_names) - target_grad_var_names) + return True + return False - visited_output.add(opresult) + # Remove ops whose outputs are all in no_grad_dict + target_grad_var_names = set( + [var.name + core.grad_var_suffix() for var in target_vars] + ) + op_descs = [ + op_desc + for op_desc in op_descs + if not _op_can_be_removed_(op_desc, no_grad_set) + ] + # Insert fill_any_like_op with value 0 + to_insert = [] + if not core._is_bwd_prim_enabled(): + for idx, op_desc in enumerate(op_descs): + for arg in op_desc.input_arg_names(): + # arg is a gradient var name and arg should not have gradient + if core.grad_var_suffix() in arg and arg in no_grad_set: + x_in = _strip_grad_suffix_(arg) + # the reason should be: arg can be input of another grad op + # and the op is a not-to-remove op + new_op_desc = _create_op_desc_( + "fill_any_like", + {"X": [x_in]}, + {"Out": [arg]}, + {'value': 0, 'dtype': -1}, + ) + # update the mapping between fwd and bwd + if ( + grad_op_id_to_fwd_op is not None + and grad_op_id_to_fwd_op.get( + op_desc.original_id(), None + ) + is not None + ): + grad_op_id_to_fwd_op[ + new_op_desc.original_id() + ] = grad_op_id_to_fwd_op[op_desc.original_id()] + to_insert.append((new_op_desc, idx)) + + list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)]) + + return op_descs + + +def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set): + """ + Pruning Program with Structural Analysis Method of Computational Graph. + The nodes of the computational graph composed of backward OPS should be + interconnected. If there are unconnected sub-graphs in the computational graph, + these sub-graphs should be cut off. + + Args: + grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs. + forward_ops(list[Operator]): The forward ops. + input_grad_names_set(set): this set is used to store the gradients' name + which is generated by backward ops, and input_grad_names_set can help + to prune the unnecessary backward ops. + + Return: + (set[core.OpDesc]): A set of OpDescs which should be pruned. 
+ """ - complete_outputs.append(opresult) - complete_gradoutputs.append(grad_value) + class Var: + def __init__(self, var_name): + self.var_name = var_name + self.gen_op = None + self.pendding_ops = [] - return complete_outputs, complete_gradoutputs, backward_ops + def set_gen_op(self, gen_op): + assert isinstance(gen_op, Op) + assert self.gen_op is None + self.gen_op = gen_op + def add_pending_op(self, op): + assert isinstance(op, Op) + self.pendding_ops.append(op) -def some_in_set(value_list, value_set): - def operand2value(values): - value_set = set() - for item in values: - if isinstance(item, paddle.ir.OpOperand): - value_set.add(item.source()) + class Op: + def __init__(self, op_desc): + self.op_desc = op_desc + self.inputs = [] + self.outputs = [] + + def insert_input(self, var): + assert isinstance(var, Var) + self.inputs.append(var) + + def insert_output(self, var): + assert isinstance(var, Var) + self.outputs.append(var) + + var_versions = dict() + + def _create_node(name): + if name not in var_versions.keys(): + var_versions[name] = [Var(name)] + else: + var_versions[name].append(Var(name)) + return var_versions[name][-1] + + def _create_or_get_last_version_node(name): + if name not in var_versions.keys(): + var_versions[name] = [Var(name)] + return var_versions[name][-1] + + def _create_op_node(op_desc): + op_node = Op(op_desc) + for input in op_desc.input_arg_names(): + var = _create_or_get_last_version_node(name=input) + var.add_pending_op(op_node) + op_node.insert_input(var) + for output in op_desc.output_arg_names(): + var = _create_node(name=output) + var.set_gen_op(op_node) + op_node.insert_output(var) + return op_node + + # Record the forward vars + forward_vars_set = ( + set() if input_grad_names_set is None else set(input_grad_names_set) + ) + for op in forward_ops: + forward_vars_set.update(op.desc.input_arg_names()) + forward_vars_set.update(op.desc.output_arg_names()) + + # Record the vars which are created during backward and is not generated by op. + backward_vars_set = set() + # special_op_nodes is the candidate sub-graph head node. + special_op_nodes = set() + for op_desc in grad_op_descs: + input_set = set(op_desc.input_arg_names()) + # The new_vars are created during backward and is not generated by op. + new_vars = input_set - forward_vars_set - backward_vars_set + backward_vars_set.update(op_desc.output_arg_names()) + + op_node = _create_op_node(op_desc) + if len(new_vars) == len(input_set): + special_op_nodes.add(op_node) + + not_need_op_descs = [] + # Start traversing all candidate sub-graph headers to check whether + # they are connected to backward computational graphs, and if they are + # not, list them in not_need_op_descs + for special_op_node in special_op_nodes: + op_list = [special_op_node] + ready_vars = set(special_op_node.inputs) + remove_ops = True + candidate_ops = [special_op_node] + while len(candidate_ops) > 0: + op_node = candidate_ops.pop(0) + if _all_in_set_(op_node.inputs, ready_vars): + for out_var in op_node.outputs: + candidate_ops.extend(out_var.pendding_ops) + op_list.extend(out_var.pendding_ops) + ready_vars.update(op_node.outputs) else: - value_set.add(item) - return value_set + remove_ops = False + break + if remove_ops: + not_need_op_descs.extend([node.op_desc for node in op_list]) + not_need_op_descs_set = set(not_need_op_descs) + grad_op_descs_set = set(grad_op_descs) + # If a backward computational graph is simply one sub-graph header, the + # not_need_op_descs will be whole graph, this IF clause avoids it. 
+ if grad_op_descs_set == not_need_op_descs_set: + return set() + return not_need_op_descs_set + + +def serialize_op_decs(op_desc): + protostr = op_desc.serialize_to_string() + proto = framework_pb2.OpDesc.FromString(bytes(protostr)) + return proto.__str__() + + +def _append_backward_ops_with_checkpoints_( + block, + ops, + target_vars, + target_block, + no_grad_dict, + grad_to_var, + checkpoints, + grad_op_id_to_fwd_op=None, +): + """ + Create grad ops with forward ops, and insert them into given block - if operand2value(value_list) & operand2value(value_set): - return True + Args: + block(Block): the block where forward ops are + ops(Op): the forward operators whose forward recomputation backward ops need to be added + target_vars(list[Tensor]): the loss vars we want to calculate gradient. + target_block(Block): the block which is going to hold new generated grad ops + no_grad_dict(dict): + key(int) block index + val(str): corresponding forward variable name + checkpoints: variables that a user defined as checkpoint for forward recomputation + + Algorithms: + 0) deal with forward recomputing program descs + 1) find ops between checkpoints, i.e. recompute_segments + 2) go through all forward ops and induct all variables that will be hold in memory + a. variables that are used across segments will be held in memory + b. output of dropout op will be held in memory + c. input variables will be held in memory + 3) go through each recompute_segments, add backward ops with forward recomputation + a. add ops in current recompute_segment as forward recomputation ops + b. rename all non-checkpoint variables in recomputation ops + c. add backward ops of current recomputation ops + d. add sum op for repetitive_outputs + 4) remove no grad branch as it is in _remove_no_grad_branch_ + 5) Note1: all appended ops' OpRole are Backward + 6) Note2: all variables with new name should be returned so that _append_backward_vars_ can be called + 7) Note3: current forward recomputation backpropagation does not handle programs with subblock + """ + + checkpoints_name = [x.name for x in checkpoints] + checkpoints_name = list(set(checkpoints_name)) + local_block = block.program._create_block() + buffer_block = block.program._create_block() + # 0) deal with forward recomputing program descs + program_stat = ProgramStats(block, ops) + program_stat.modify_forward_desc_for_recompute() + program_stat.build_stats() + + # 1) find ops between checkpoints, i.e. 
recompute_segments + checkpoints_name = program_stat.sort_checkpoints(checkpoints_name) + segments = [] + + if len(checkpoints_name) == 1: + # only one checkpoint + max_op_idx = -1 + var_group = [checkpoints_name[0]] + for name in var_group: + if name not in program_stat.var_op_deps: + break + op_idx = program_stat.var_op_deps[name]["var_as_output_ops"] + # only count the last generate op + for idx in op_idx: + max_op_idx = max(max_op_idx, idx) + if max_op_idx > 0: + segments.append([0, max_op_idx + 1]) else: - return False + start_idx = 0 + pre_segment_end_idx = -1 + while True: + if start_idx >= len(checkpoints_name) - 1: + break + # min_idx: checkpoint_1' s input op + # max_idx: checkpoint_2' s output op + flag, min_idx, max_idx = program_stat.is_subgraph( + [checkpoints_name[start_idx]], [checkpoints_name[start_idx + 1]] + ) + if flag: + # max_idx + 1 since the exact and used segment end idx is max_idx + min_idx = program_stat._update_segment_start( + min_idx, pre_segment_end_idx + ) + segments.append([min_idx, max_idx + 1]) + else: + _logger.info( + "Could not recompute op range [{}] - [{}] ".format( + min_idx, max_idx + 1 + ) + ) + start_idx += 1 -def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): - ''' - prune ops which do not in the path from inputs_set to outputs_set, - prune ops which do not in the path from outputs_set to inputs_set, + if segments != [] and segments[0][0] != 0: + recompute_segments = [[0, segments[0][0]]] + segments + else: + recompute_segments = segments - pruned op in total_ops is uneffective_ops, else is effective_ops + for i, (idx1, idx2) in enumerate(recompute_segments): + _logger.info("recompute segment[{}]".format(i)) + _logger.info( + "segment start op: [{}]: [{}]".format( + ops[idx1].desc.type(), ops[idx1].desc.input_arg_names() + ) + ) + _logger.info( + "segment end op: [{}]: [{}]".format( + ops[idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names() + ) + ) + _logger.info("recompute segment[{}]".format(i)) + _logger.info( + "segment start op: [{}]: [{}]".format( + ops[idx1].desc.type(), ops[idx1].desc.input_arg_names() + ) + ) + _logger.info( + "segment end op: [{}]: [{}]".format( + ops[idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names() + ) + ) - ''' - relevant_op_flags = [True] * len(total_ops) - # from input to output - if inputs_set: - for i, op in enumerate(total_ops): - if some_in_set(op.results(), inputs_set): - continue + # 2) go through all forward ops and induct all variables that will be hold in memory + vars_should_be_hold = [] + # a. variables that are used across segments will be held in memory + for segment in recompute_segments: + vars_should_be_hold.extend( + program_stat.get_out_of_subgraph_vars(segment[0], segment[1]) + ) - if some_in_set(op.operands_source(), inputs_set): - for value in op.results(): - if value not in no_grad_set: - inputs_set.add(value) - else: - relevant_op_flags[i] = False + cross_vars = set(vars_should_be_hold) - set(checkpoints_name) + _logger.info( + "found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( + len(cross_vars), cross_vars + ) + ) - # from output to input - for i, op in reversed(list(enumerate(total_ops))): - # while op support - if some_in_set(op.results(), outputs_set): - for operand in op.operands_source(): - if operand not in no_grad_set: - outputs_set.add(operand) - else: - relevant_op_flags[i] = False + # b. 
output of seed op should be kept in memory + vars_should_be_hold.extend(program_stat.get_reserved_vars()) + # c. input variables are checkpoints + vars_should_be_hold.extend(program_stat.get_input_nodes()) + vars_should_be_hold = list(set(vars_should_be_hold)) + + # 3) go through each recompute_segments, add backward ops with forward recomputation + grad_op_descs = [] + var_name_dict = {} + + vars_in_memory = vars_should_be_hold + checkpoints_name + + max_calculated_op_position = len(ops) + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + if recompute_segments == []: + gap_ops = ops[0:max_calculated_op_position] + for op in reversed(gap_ops): + if op.has_attr("sub_block"): + raise Exception( + "Recompute don't support ops with sub_block" + "invoke op: %s" + % _pretty_op_desc_(op.desc, "with_sub_block") + ) + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + op.desc, no_grad_dict[block.idx], [] + ) - effective_ops = [ - total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i] - ] - uneffective_ops = [ - total_ops[i] - for i in reversed(range(len(total_ops))) - if not relevant_op_flags[i] - ] + # record the mapping between fwd and bwd + if grad_op_id_to_fwd_op is not None: + for op_desc in grad_op_desc: + grad_op_id_to_fwd_op[op_desc.original_id()] = op + + # Set device for grad_op according to forward Op + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) + added_descs = _add_descs_to_block( + grad_op_desc, local_block, grad_op_id_to_fwd_op + ) + grad_op_descs.extend(added_descs) + grad_to_var.update(op_grad_to_var) + + for i, segment in enumerate(recompute_segments[::-1]): + gap_ops = ops[segment[1] : max_calculated_op_position] + max_calculated_op_position = segment[0] + for op in reversed(gap_ops): + if op.has_attr("sub_block"): + raise Exception( + "Recompute don't support ops with sub_block" + "invoke op: %s" + % _pretty_op_desc_(op.desc, "with_sub_block") + ) + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + op.desc, no_grad_dict[block.idx], [] + ) + + # record the mapping between fwd and bwd + if grad_op_id_to_fwd_op is not None: + for op_desc in grad_op_desc: + grad_op_id_to_fwd_op[op_desc.original_id()] = op + + # Set device for grad_op according to forward Op + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) + added_descs = _add_descs_to_block( + grad_op_desc, local_block, grad_op_id_to_fwd_op + ) + grad_op_descs.extend(added_descs) + grad_to_var.update(op_grad_to_var) + + ff_ops = ops[segment[0] : segment[1]] + var_suffix = ".subprog_%d" % i + + for op in ff_ops: + if op.has_attr("sub_block"): + raise Exception( + "Recompute don't support ops with sub_block" + "invoke op: %s" + % _pretty_op_desc_(op.desc, "with_sub_block") + ) + input_and_output_names = [] + input_and_output_names.extend(op.desc.input_arg_names()) + input_and_output_names.extend(op.desc.output_arg_names()) + for name in input_and_output_names: + if block.var(name).persistable or name in checkpoints_name: + continue + if name in vars_should_be_hold: + continue + if name not in var_name_dict: + var_name_dict[name] = name + var_suffix + + # we should create the rename var in subprog, otherwise its VarType will be BOOL + ref_var = block.program.global_block().var(name) + block.create_var( + name=var_name_dict[name], + shape=ref_var.shape, + 
dtype=ref_var.dtype, + type=ref_var.type, + persistable=ref_var.persistable, + stop_gradient=ref_var.stop_gradient, + ) + + # 3.a. add ops in current recompute_segment as forward recomputation ops + buffer_descs = _add_needed_descs_to_block( + ff_ops, buffer_block, block, vars_in_memory, grad_op_id_to_fwd_op + ) + added_descs = _add_descs_to_block( + ff_ops, local_block, grad_op_id_to_fwd_op + ) + + # 3.b. rename all non-checkpoint variables in recomputation ops + for key in var_name_dict: + _rename_arg_(buffer_descs, key, var_name_dict[key]) + + # added_descs should be in grad_op_descs because it is backward op desc + grad_op_descs.extend(buffer_descs) - return effective_ops, uneffective_ops + # 3.c. add backward ops for all ops in current segment + for op_desc in reversed(added_descs): + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + op_desc, no_grad_dict[block.idx], [] + ) + + # record the mapping between fwd and bwd + if grad_op_id_to_fwd_op is not None: + for g_op_desc in grad_op_desc: + grad_op_id_to_fwd_op[ + g_op_desc.original_id() + ] = grad_op_id_to_fwd_op[op_desc.original_id()] + + # Set device for grad_op according to forward Op + if op_desc.has_attr(device_attr_name): + op_device = op_desc.attr(device_attr_name) + for g_op_desc in grad_op_desc: + g_op_desc._set_attr(device_attr_name, op_device) + + for key in var_name_dict: + _rename_arg_(grad_op_desc, key, var_name_dict[key]) + grad_op_descs.extend(grad_op_desc) + grad_to_var.update(op_grad_to_var) + + # 3.d. add sum op for repetitive_outputs + grad_op_descs = _addup_repetitive_outputs_( + grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op + ) + # 4) remove no grad branch as it is in _remove_no_grad_branch_ + grad_op_descs = _remove_no_grad_branch_( + grad_op_descs, + no_grad_dict[block.idx], + grad_op_id_to_fwd_op, + target_vars, + ) + added_descs = _add_descs_to_block( + grad_op_descs, target_block, grad_op_id_to_fwd_op + ) + return ( + program_stat, + checkpoints_name, + vars_should_be_hold, + recompute_segments, + ) -def update_no_grad_set_after_prune( - block, effective_forward_op, no_grad_set, inputs, outputs +def _get_sub_block_path( + sub_block, + sub_block_op_desc, + no_grad_set, + op_path_dict, + sub_block_target_names=None, ): - ''' - update no_grad_set after forward prune + """ + Get output vars in subblock which will be assigned to parent block. + It is used to find the grad path in subblock. - from inputs to outputs add value not in the path to no_grad_set, - from outputs to inputs add value not in the path to no_grad_set, - ''' - inputs_set = set(inputs) - if inputs_set: - for op in block.ops: - if some_in_set(op.operands_source(), inputs_set): - for value in op.results(): - if value not in no_grad_set: - inputs_set.add(value) - - for op in effective_forward_op: - for value in op.operands_source(): - if value not in inputs_set: # and value.get_stopgradient(): - no_grad_set.add(value) - - outputs_set = set(outputs) - no_grad_set_tmp = set() - for op in reversed(effective_forward_op): - for output in op.results(): - if output not in outputs_set and not some_in_set( - [output], set(op.operands_source()) - ): - no_grad_set_tmp.add(output) + Args: + sub_block(Block): The sub-block in which to get op path. + sub_block_op_desc: The op desc of the sub-block op such as 'while', 'conditional_block' and 'recurrent'. + no_grad_set(set): The set of no grad var name. no_grad_set will be changed. + op_path_dict(dict): op_path_dict will be changed. 
+ key(int) block index + val(list) the op path of block(index) + sub_block_target_names(set): Target var names of sub-block. + Return: + The forward op path of sub-block corresponding to backward op. + """ - for input in op.operands_source(): - if input not in no_grad_set: - outputs_set.add(input) + assert sub_block_op_desc.has_attr( + "sub_block" + ) and sub_block.idx == sub_block_op_desc._block_attr_id("sub_block") + assert isinstance(sub_block_target_names, (set, type(None))) + + if sub_block_target_names is None: + sub_block_target_names = sub_block_op_desc.output_arg_names + + # TODO(huihuangzheng): add support for recurrent op. + if sub_block_op_desc.type in ["conditional_block", "while"]: + # Step1: get the output vars in sub-block + sub_outputs = [ + sub_block._var_recursive(var) for var in sub_block_target_names + ] + for var in sub_block_target_names: + for op_desc in sub_block.ops: + if var in op_desc.output_arg_names: + for name in op_desc.input_arg_names: + sub_outputs.append(sub_block._var_recursive(name)) + + # Step2: find op path of sub-block + is_while = sub_block_op_desc.type in ["while"] + sub_block_op_path = _find_op_path_( + sub_block, sub_outputs, [], no_grad_set, op_path_dict, is_while + ) + return sub_block_op_path + return sub_block.ops - no_grad_set.update(no_grad_set_tmp) +def _is_grad_op_(op): + op_maker = core.op_proto_and_checker_maker + backward = core.op_proto_and_checker_maker.OpRole.Backward + if op_maker.kOpRoleVarAttrName() in op.attr_names and int( + op.all_attrs()[op_maker.kOpRoleAttrName()] + ) == int(backward): + return True + return False + + +def _rename_grad_name_(name, grad_order): + return 'grad/' * grad_order + name + + +def _append_backward_ops_( + block, + ops, + target_vars, + target_block, + no_grad_dict, + grad_to_var, + callbacks=None, + input_grad_names_set=None, + op_path_dict=None, + distop_context=None, + rename_var_map=None, + grad_op_id_to_fwd_op=None, +): + """ + Create all grad ops, and insert them into given block + + Args: + block(Block): the block where forward ops are + ops(Op): the forward operators whose backward ops need to be added + target_vars(list[Tensor]): the loss vars we want to calculate gradient. + target_block(Block): the block which is going to hold new generated grad ops + no_grad_dict(dict): + key(int) block index + val(set) a set of variable names. These variables have no gradient + grad_to_var(dict)(output argument): + key(str): grad variable name + val(str): corresponding forward variable name + callbacks(callable object): a callable object used to decorate new generated grad ops + input_grad_names_set(set): this set is used to store the gradients' name which is + generated by backward ops, and input_grad_names_set can help to prune the unnecessary + backward ops. + op_path_dict(dict): op_path_dict will be changed. + key(int) block index + val(list) the op path of block(index) + rename_var_map(dict): used to associate target_grad var name with first grad_op input name. + Only used in for high order gradient. 
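For a concrete sense of how the higher-order renaming tracked by `rename_var_map` interacts with the grad-name utilities defined earlier in this file, here is a standalone sketch that re-implements `_rename_grad_name_` and `_strip_grad_suffix_` locally so it runs without Paddle; it assumes `'@GRAD'` as the value of `core.grad_var_suffix()`, which is what the rest of this file also assumes.

    import re

    GRAD_SUFFIX = '@GRAD'

    def rename_grad_name(name, grad_order):          # mirrors _rename_grad_name_
        return 'grad/' * grad_order + name

    def strip_grad_suffix(name):                     # mirrors _strip_grad_suffix_
        pos = re.search(f'{GRAD_SUFFIX}+@', name) or re.search(f'{GRAD_SUFFIX}$', name)
        new_name = name[: pos.start()] if pos is not None else name
        new_pos = name.rfind('grad/')
        return new_name[new_pos + 5:] if new_pos != -1 else new_name

    # second-time backward renames an internal gradient ...
    assert rename_grad_name('x@GRAD', 2) == 'grad/grad/x@GRAD'
    # ... and the forward variable name can still be recovered from it
    assert strip_grad_suffix('grad/grad/x@GRAD') == 'x'
    assert strip_grad_suffix('grad/grad/z@GRAD@RENAME@block0@1@GRAD') == 'z'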
+ """ -def inverse_sort_op(ops): - ''' - if topo graph is op1 -> op2 -> op3 - return [op3, op2, op1] + # Build the mapping between the forward op and backward op (Only for auto parallel) + def update_distop_context( + distop_context, op_grad_to_var, appending_grad_times + ): + distop_context.grad_var_to_var[appending_grad_times].update( + op_grad_to_var + ) + for op_desc in grad_op_desc: + assert ( + op_desc.original_id() not in distop_context.grad_op_id_to_op_id + ) + distop_context.grad_op_id_to_op_id[ + op_desc.original_id() + ] = op.desc.original_id() + + if callbacks is not None: + assert isinstance(callbacks, (list, tuple)) + for cb in callbacks: + if not hasattr(cb, '__call__'): + raise ValueError("'callback' must be a callable object.") + + # grad_op_descs holds created grad_op, and will be appended to target_block + grad_op_descs = [] + program = block.program + + if rename_var_map is None: + rename_var_map = {} + assert isinstance(rename_var_map, dict) + + if core._is_bwd_prim_enabled(): + composite_block = program.clone().current_block() + # Create output and infer shape for operators whose output haven't + # been created. + for op in composite_block.ops: + for name in op.output_arg_names: + if not ( + composite_block.desc.has_var_recursive(name.encode()) + or name == core.empty_var_name() + ): + composite_block.create_var(name=name) + op.desc.infer_var_type(composite_block.desc) + op.desc.infer_shape(composite_block.desc) + + # add grad_op_desc by reversed ops + for op in reversed(ops): + grad_sub_block_list = [] + # If the op has its own sub-block, deal with the sub-block first + if op.has_attr("sub_block"): + sub_block = program.block(op._block_attr_id("sub_block")) + grad_sub_block = program._create_block() + grad_sub_block._set_forward_block_idx(sub_block.idx) + # see following comments for why set None here. + pre_input_grad_names_set = copy.copy(input_grad_names_set) + input_grad_names_set = None + sub_block_path = op_path_dict[op._block_attr_id("sub_block")] + _append_backward_ops_( + sub_block, + sub_block_path, + target_vars, + grad_sub_block, + no_grad_dict, + grad_to_var, + callbacks, + input_grad_names_set, + op_path_dict, + grad_op_id_to_fwd_op=grad_op_id_to_fwd_op, + ) + input_grad_names_set = pre_input_grad_names_set + + program._rollback() + grad_sub_block_list.append(grad_sub_block.desc) + # In primitive mode, raw phi GradOp will be split into multiple small + # primitive operators, and the split rules are defined in c++ level, + # see details: paddle/fluid/prim/api/manual/backward/composite_backward_api.h + # It means that the output's shape and dtype of previous operators which + # maybe used as the input of next operators must be known. Therefore, + # we infer shape and dtype in a sandbox block(named composite_block) for + # used in c++ level. 
+ # For example: + # forward: + # z = multiply(x, y) //maybe broadcast in kernel + # backward: + # x_grad_unreduce = z_grad * y // maybe unreduce + # reduced_axes = get_reduced_axes(x_grad.shape, x.shape) // need known shape + # x_grad = reduce_sum(x_grad_unreduce) + grad_op_desc = [] + op_grad_to_var = {} + if core._is_bwd_prim_enabled(): + + def find_op_index(block_desc, cur_op_desc): + for idx in range(block_desc.op_size()): + if cur_op_desc == block_desc.op(idx): + return idx + return -1 + + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + composite_block.desc.op(find_op_index(block.desc, op.desc)), + no_grad_dict[composite_block.idx], + grad_sub_block_list, + ) + for desc in grad_op_desc: + infershape_for_composite(composite_block, desc) + else: + # Getting op's corresponding grad_op + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + op.desc, no_grad_dict[block.idx], grad_sub_block_list + ) - ''' + # record the mapping between fwd and bwd + if grad_op_id_to_fwd_op is not None: + for op_desc in grad_op_desc: + grad_op_id_to_fwd_op[op_desc.original_id()] = op - # init pending_count[op] which descibes number of - # pending edges for its grad_op + # Build the mapping between the forward op and backward op (Only for auto parallel) + if distop_context is not None: + update_distop_context( + distop_context, op_grad_to_var, program._appending_grad_times + ) + else: + default_ctx = getattr( + paddle.distributed.auto_parallel.static.dist_context, + '_g_default_distributed_context', + None, + ) + if default_ctx is not None: + distop_context = default_ctx.dist_op_context + update_distop_context( + distop_context, + op_grad_to_var, + program._appending_grad_times, + ) - pending_count = collections.defaultdict(int) - ops_set = set(ops) - sorted_list = [] - for op in ops: - for x in op.operands(): - if x.source().get_defining_op() in ops_set: - pending_count[x.source().get_defining_op()] += 1 + # Set device for grad_op according to forward Op + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) + + # Rename internal gradient variables in multiple backward + # so that they have different names with previous backward. + # For example: + # y = x * x, grad = fluid.gradients(fluid.gradients(y, x) + y * y, x) + # In second-time backward, gradient variable names of partial + # forward network (y * y) may be have same names with first-time + # fluid.gradients(y, x). + # So rename here before _addup_repetitive_outputs_. 
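A minimal standalone sketch of the renaming rule applied by the block below (illustrative only, not part of the patch): _rename_grad_name_, defined earlier in this file, prefixes 'grad/' once per gradient-appending pass, so gradient names created by a later backward pass cannot collide with those of the first.

def rename_grad_name(name, grad_order):
    # mirrors _rename_grad_name_; grad_order is program._appending_grad_times
    return 'grad/' * grad_order + name

print(rename_grad_name('x@GRAD', 2))  # grad/grad/x@GRAD
print(rename_grad_name('x@GRAD', 3))  # grad/grad/grad/x@GRAD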
+ if program._appending_grad_times > 1: + for op_desc in grad_op_desc: + forward_op_inputs = op.desc.input_arg_names() + for name in op_desc.input_arg_names(): + if name in rename_var_map and name not in forward_op_inputs: + op_desc._rename_input(name, rename_var_map[name]) + for name in op_desc.output_arg_names(): + if "@GRAD" not in name: + continue + if block.desc.find_var(name.encode("ascii")): + new_name = _rename_grad_name_( + name, program._appending_grad_times + ) + op_desc._rename_output(name, new_name) + rename_var_map[name] = new_name + + if name in op_grad_to_var: + # Build the mapping between the grad var name and var name (Only for auto parallel) + if distop_context is not None: + distop_context.grad_var_to_var[ + program._appending_grad_times + ][new_name] = op_grad_to_var[name] + op_grad_to_var[new_name] = op_grad_to_var[name] + op_grad_to_var.pop(name) + + # If input_grad_names_set is not None, extend grad_op_descs only when + # any input grad in outputs of previous grad ops. + # But this strategy is not suited for while op for some control flow, + # for example, for while op, the grads maybe generated in next loop. + if input_grad_names_set is not None: + is_grad_name = ( + lambda name: name.find(core.grad_var_suffix()) != -1 + or name in input_grad_names_set + ) + is_append_grad = False + + # NOTE: In primitive mode, the intermediate variable generated by + # decompositing raw grad op are not satisfied the rule of 'XX@GRAD', + # which will cause it be pruned according to current pruning logic. + # For simplicity, we treate all prmitive operators as one raw + # operator, and keep the pruning logic consistent with currently + # logic. The drawback of this solution is may lead to some primitive + # operators are not pruned, which is needed to fixed. + # FIXME: Optimize pruning logic from the perspective of whole graph. + input_grad_names = [] + for op_desc in grad_op_desc: + input_grad_names += [ + name + for name in op_desc.input_arg_names() + if is_grad_name(name) + ] - queue = collections.deque() + # some code of gradient ops, like increment, are not very + # standard, there is no @GRAD in these ops' inputs. 
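A standalone sketch of the is_grad_name predicate above (illustrative only; '@GRAD' is assumed to be the value returned by core.grad_var_suffix()): a name is treated as a gradient if it carries the grad suffix or was recorded in input_grad_names_set by a previously kept grad op. The branch just below then handles grad ops, such as increment, whose inputs carry no such name at all.

input_grad_names_set = {'mean_0.tmp_0@GRAD'}

def is_grad_name(name):
    # same predicate as the lambda above, with the assumed suffix written out
    return name.find('@GRAD') != -1 or name in input_grad_names_set

print(is_grad_name('fc_0.w_0@GRAD'))  # True: carries the grad suffix
print(is_grad_name('increment_out'))  # False: no @GRAD anywhere in the name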
+ if len(input_grad_names) == 0: + is_append_grad = True + continue - for op in ops: - if pending_count[op] == 0: - queue.append(op) + if _some_in_set_(input_grad_names, input_grad_names_set): + is_append_grad = True + for op_desc in grad_op_desc: + grad_op_descs.append(op_desc) + for name in op_desc.output_arg_names(): + input_grad_names_set.add(name) - while queue: - op = queue.popleft() - sorted_list.append(op) + if is_append_grad: + grad_to_var.update(op_grad_to_var) + else: + grad_op_descs.extend(grad_op_desc) + grad_to_var.update(op_grad_to_var) + + # record mapping between grad var name and var name (Only for auto parallel) + grad_var_to_var = None + if distop_context is not None: + grad_var_to_var = distop_context.grad_var_to_var[ + program._appending_grad_times + ] + # sum parameter's gradients' var given multiple var gradient + grad_op_descs = _addup_repetitive_outputs_( + grad_op_descs, + block.idx, + grad_var_to_var, + grad_op_id_to_fwd_op=grad_op_id_to_fwd_op, + ) - for x in op.operands(): - x_op = x.source().get_defining_op() - pending_count[x_op] -= 1 - if pending_count[x_op] == 0: - queue.append(x_op) + # if all outputs of the grad op are in no_grad_set, then just remove and fill zero + # if all inputs of the grad op are in no_grad_set, just remove this op + grad_op_descs = _remove_no_grad_branch_( + grad_op_descs, + no_grad_dict[block.idx], + grad_op_id_to_fwd_op, + target_vars, + ) - if len(sorted_list) != len(ops): - raise ValueError( - "inverse_sort_op wrong, sorted_list size is not equal to origin_list size" + # remove some backward ops + # TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem + if not core._is_bwd_prim_enabled(): + not_need_ops = _find_not_need_ops( + grad_op_descs, ops, input_grad_names_set + ) + grad_op_descs = [ + op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops + ] + else: + logging.debug( + "Running backward composite and disable find_not_need_ops" ) - return sorted_list + # append op_desc in grad_op_descs to target_block + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() + backward = core.op_proto_and_checker_maker.OpRole.Backward + for op_desc in grad_op_descs: + new_op_desc = target_block.desc.append_op() + new_op_desc.copy_from(op_desc) + new_op_desc._set_attr(op_role_attr_name, backward) + grad_to_var["__current_op_desc__"] = new_op_desc + if callbacks is not None: + assert isinstance(callbacks, (list, tuple)) + for cb in callbacks: + cb(block=target_block, context=grad_to_var) -def append_backward_ops( - block, effective_forward_op, no_grad_set, backward_ops, state -): - ''' - add grad_op in order of topological inverse sort - eg: - from :op1 -> v1 -> op2 -> v2 -> op3 -> v3 - to: og1_g <- v1_g <- op2_g <- v2_g <- op3_g <- v3_g +def _is_grad_var_(var_name): + return core.grad_var_suffix() in var_name - if op has grad_op, prepare its grad_op's inputs by value_to_valuegrad, - eg: - value_to_valuegrad[v3] = [[v3_g]]; - v2_g = call_vjp(op3, [v3_g], [v2_stopgradient]) +# Find the op who holds the sub_block as its "sub_block" attr +def _find_parent_op_(sub_block): + sub_block_id = sub_block.idx - special pattern 1: - v11 -> combine_op -> v1 -> op -> v3 - v12 -> - v2 -> - value_to_valuegrad[v3] = [[v3_g]] - - v1 is inside python api, we don't describe it in backward process(state) - so v1_grad is inside vjp, we don't describe it in backward process(state) - [[v11_g, v12_g], v2_g] = call_vjp(combine_op, [v3_g], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient) + if 
sub_block_id == 0: + return None + program = sub_block.program + for block_id in range(program.num_blocks): + block_desc = program.block(block_id).desc + for op_idx in range(block_desc.op_size()): + op = block_desc.op(op_idx) + if ( + op.has_attr("sub_block") + and op._block_attr_id("sub_block") == sub_block_id + ): + return op - op_vjp is: - v11_g <- split_op <- v1_g <- op_g <- v3_g - v12_g <- - v2_g <- + # NOTE(paddle-dev): When optimizer is added in conditional block, + # sub_block may not be found. + return None - value_to_valuegrad[v11] = [[v11_g]] - value_to_valuegrad[v12] = [[v12_g]] - value_to_valuegrad[v2] = [[v2_g]] - if op don't has grad_op, if it don't has input and it's output has more than - one output_grad, add sumop for grad aggregation. - (eg: full op and get_parameter op etc.) +def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): + """ + Create new variables required by backward pass. - else continue to next op. + Args: + block(Block): the block where new variables will be created + start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created + grad_to_var(dict): + key(str): grad variable name + val(str): corresponding forward variable name + In most cases, this dict is generated by _append_backward_ops_() + grad_info_map(dict)(output argument): + key(str): forward variable name + val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable + """ + ops_to_remove = [] ''' - - def make_output_grad(op): - zero_flag = [False] * op.num_results() - output_grads = [] - for i, value in enumerate(op.results()): + NOTE(paddle-dev): while_grad op may hold some inputs which are not found + in the parent/forward block, and they are also the outputs of while_grad + op. These kinds of inputs are the recursive outputs inside while_grad op. + They should be considered as "already created" when scanning the inner + ops of while_grad ops. + ''' + parent_op = _find_parent_op_(block) + parent_op_vars = [] + if parent_op is not None: + input_args = parent_op.input_arg_names() + output_args = parent_op.output_arg_names() + for in_arg in input_args: + if in_arg in output_args: + parent_op_vars.append(in_arg) + + for op_idx in range(start_op_idx, block.desc.op_size()): + op_desc = block.desc.op(op_idx) + if op_desc.has_attr("sub_block"): + sub_block = block.program.block(op_desc._block_attr_id("sub_block")) + _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map) + + grad_var_ins = [ + var for var in op_desc.input_arg_names() if _is_grad_var_(var) + ] + grad_var_outs = [ + var for var in op_desc.output_arg_names() if _is_grad_var_(var) + ] + + inputs = [ + var + for var in op_desc.input_arg_names() + if var != core.empty_var_name() + ] + outputs = [ + var + for var in op_desc.output_arg_names() + if var != core.empty_var_name() + ] + + # If the outputs of grad op is empty, just remove it + if not outputs: + ops_to_remove.append(op_idx) + continue + else: + ''' + If the output is not empty and there is any grad input, find + whether there is any existing input. If not, just remove it. + ''' + if grad_var_ins: + existing_grad_var_ins = [ + var + for var in grad_var_ins + if block.desc.has_var_recursive(var.encode()) + or var in parent_op_vars + ] + if not existing_grad_var_ins: + ''' + FIXME(paddle-dev, zengjinle): rnn_memory_helper_grad is used + in recurrent op. The input of this op does not even exist in + the program! 
Therefore, any dependency analysis would not + work to this op! If I do not add the following code, this op + would be pruned, and the calculation result would be wrong. + Maybe we should re-design this op later... + ''' + if op_desc.type() not in ['rnn_memory_helper_grad']: + ops_to_remove.append(op_idx) + continue + + # sum may create invalid variable, here to deal with it. + if op_desc.type() == 'sum': + new_inputs = [] + for grad_var_name in op_desc.input_arg_names(): + if block.desc.has_var_recursive(grad_var_name.encode()): + # meet invalid sum variables, remove the invalid operand. + new_inputs.append(grad_var_name) + assert ( + len(new_inputs) > 0 + ), "After remove invalid variables, sum op have no inputs." + op_desc.set_input("X", new_inputs) + + new_vars = set() + # create new gradient variables + for grad_var_name in op_desc.output_arg_names(): if ( - value not in state.value_to_valuegrad - or state.value_to_valuegrad[value] is None + block.desc.has_var_recursive(grad_var_name.encode()) + or grad_var_name == core.empty_var_name() ): - if ( - not value.use_empty() - and value.first_use().owner().name() == "builtin.split" - ): - # pattern case: - # this fwd_op's output is vectorType, it will split to - # Type by builtin.split op, so need get from split op's ouput - split_zero_flag, split_output_grad = make_output_grad( - value.first_use().owner() - ) - zero_flag[i] = all(split_zero_flag) - state.value_to_valuegrad[value] = [split_output_grad] + continue + block.desc.var(grad_var_name.encode()) + new_vars.add(grad_var_name) + if grad_var_name not in grad_to_var: + continue + grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block) + # infer_shape and infer_type + op_desc.check_attrs() + op_desc.infer_var_type(block.desc) + op_desc.infer_shape(block.desc) + + for arg in op_desc.output_arg_names(): + if arg in new_vars: + _infer_var_data_type_shape_(arg, block) + + for op_idx in reversed(ops_to_remove): + block.desc._remove_op(op_idx, op_idx + 1) + + +def infershape_for_composite(block, grad_op_desc): + # NOTE: why pruning the operator with empty output here ? + # Some backward operator will output empty var, which will cause infer + # shape error, such assign with input's stop_gradient=True + if len(grad_op_desc.output_arg_names()) == 0: + return + + # create output variable + new_vars = set() + for grad_var_name in grad_op_desc.output_arg_names(): + if not ( + block.desc.has_var_recursive(grad_var_name.encode()) + or grad_var_name == core.empty_var_name() + ): + # NOTE: stop_gradient will be set in append_op + desc = block.desc.var(grad_var_name.encode()) + block.create_var(name=grad_var_name, desc=desc, type=desc.type()) + new_vars.add(grad_var_name) + + # NOTE For the primitive operator generated by decompositing phi grad kernel, + # we Operator to reconstruct the op_desc for reusing some complex logic, such + # as processing dispensable input, intermediate output, extra attrs, etc... + if framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()): + op = block.append_op( + type=grad_op_desc.type(), + inputs={ + name: [block._find_var_recursive(arg) for arg in args] + for name, args in grad_op_desc.inputs().items() + }, + outputs={ + name: [block._find_var_recursive(arg) for arg in args] + for name, args in grad_op_desc.outputs().items() + }, + # NOTE Runtime attr will be ignore as the c++ GetRuntimeAttr + # interface cann't be exported to python. 
Please note the WARNING + # message logged in RuntimeAttrs of composite_grad_desc_maker.h + attrs=grad_op_desc.get_attr_map(), + ) + op.desc._set_attr( + core.op_proto_and_checker_maker.kOpRoleAttrName(), + core.op_proto_and_checker_maker.OpRole.Backward, + ) + grad_op_desc.copy_from(op.desc) + # For the backward operator, we reuse the logic of _append_backward_var + else: + op_desc = block.desc.append_op() + op_desc.copy_from(grad_op_desc) + op_desc._set_attr( + core.op_proto_and_checker_maker.kOpRoleAttrName(), + core.op_proto_and_checker_maker.OpRole.Backward, + ) + op_desc.check_attrs() + op_desc.infer_var_type(block.desc) + op_desc.infer_shape(block.desc) + grad_op_desc.copy_from(op_desc) + + if not framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()): + # NOTE: Some raw fluid grad operators which hadn't been decomposed may not + # implement InferVarType method, such as elementwise_xx_grad, and it will + # cause the dtype or shape of corresponding cotangent incorrect. This + # patch set the cotangent dtype and shape same with corresponding + # forward variable. For primitive operators, we have ensure all + # InferVarType method to be executed correctly in PR#52818, we skip + # this patch for primitive operators. + for arg in grad_op_desc.output_arg_names(): + if arg in new_vars: + _infer_var_data_type_shape_(arg, block) + + +def _rename_grad_( + block, start_op_idx, grad_to_var, target_grad_map, skip_rename_var_list +): + var_map = copy.copy(target_grad_map) + for op_idx in range(start_op_idx, block.desc.op_size()): + op_desc = block.desc.op(op_idx) + for name in op_desc.input_arg_names(): + if name in var_map: + op_desc._rename_input(name, var_map[name]) + + for name in op_desc.output_arg_names(): + if "@GRAD" not in name: + continue + if block.desc.find_var(name.encode("ascii")): + if name in skip_rename_var_list: + continue + new_name = unique_name.generate(name) + op_desc._rename_output(name, new_name) + var_map[name] = new_name + + for g, ng in var_map.items(): + if g in grad_to_var: + grad_to_var[ng] = grad_to_var[g] + grad_to_var.pop(g) + + +def _get_stop_gradients_(program): + no_grad_dict = dict() + assert isinstance(program, framework.Program) + for block in program.blocks: + assert isinstance(block, framework.Block) + block_no_grad_set = set() + for var in list(block.vars.values()): + assert isinstance(var, framework.Variable) + if var.stop_gradient: + block_no_grad_set.add(_append_grad_suffix_(var.name)) + no_grad_dict[block.idx] = block_no_grad_set + return no_grad_dict + + +def _get_son_parent_block_idx_dict(program, current_block_idx): + son_parent_block_idx_dict = collections.OrderedDict() + while current_block_idx >= 0: + parent_block_idx = program.block(current_block_idx).parent_idx + son_parent_block_idx_dict[current_block_idx] = parent_block_idx + current_block_idx = parent_block_idx + + return son_parent_block_idx_dict + + +def _get_no_grad_set_name(no_grad_set): + no_grad_set_name = set() + if no_grad_set is not None: + if isinstance(no_grad_set, (set, list, tuple)): + for i, no_grad_var in enumerate(no_grad_set): + if isinstance(no_grad_var, framework.Variable): + no_grad_set_name.add(no_grad_var.name) + elif isinstance(no_grad_var, str): + no_grad_set_name.add(no_grad_var) else: - # first case: - # this fwd_op's output didn't used by other fwd_op, - # so no output_grad created. - - # second case: - # last bwd_op return None because input in no_grad_set, - # but this bwd_op need a input. 
- grad_value = paddle.full( - value.shape, - 0.0, - dtype=value.dtype, + raise TypeError( + "The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s." + % (type(no_grad_var)) ) - fillop = grad_value.get_defining_op() + else: + raise TypeError( + "The type of no_grad_set should be set or list or tuple, but received {}".format( + type(no_grad_set) + ) + ) + return no_grad_set_name - update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], fillop - ) - zero_flag[i] = True - state.value_to_valuegrad[value] = [[grad_value]] +@framework.static_only +def append_backward( + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None, + checkpoints=None, + distop_context=None, +): + """ + :api_attr: Static Graph + + This function appends backward part to main_program. - if len(state.value_to_valuegrad[value]) > 1: - # one value is input of more than one fwd_op, - # so more than one bwd_op create input_grad, - # need add sum op to accumulate gradient + A complete neural network training is made up of forward and backward + propagation. However, when we configure a network, we only need to + specify its forward part. This function uses the chain rule to automatically + generate the backward part according to the forward part. - paddle.add_n( - [item[0] for item in state.value_to_valuegrad[value]] - ) - combineop = block.ops[len(block.ops) - 2] - sumop = block.ops[len(block.ops) - 1] - update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], combineop - ) - update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], sumop - ) - state.value_to_valuegrad[value] = [[sumop.result(0)]] - state.value_to_sumvaluegrad[value] = state.value_to_valuegrad[ - value - ] + In most cases, users do not need to invoke this function manually. + It will be automatically invoked by the optimizer's `minimize` function. - output_grads.append(state.value_to_valuegrad[value][0][0]) - return zero_flag, output_grads + Parameters: + loss(Tensor): The loss Tensor of the network. + parameter_list(list[Tensor|str]|tuple[Tensor|str], optional): List/Tuple of Parameters or Parameter.names + that need to be updated by optimizers. + If it is None, all parameters + will be updated. + Default: None. + no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients + should be ignored. All Tensors with + `stop_gradient=True` from all blocks will + be automatically added into this set. + If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set. + Default: None. + callbacks(list[callable object]|tuple[callable object], optional): List/Tuple of callback functions. + The callbacks are used for + doing some custom jobs during + backward part building. All + callable objects in it will + be invoked once each time a + new gradient operator is added + into the program. The callable + object must have two input + parameters: ``block`` and ``context`` . + The ``block`` is the :ref:`api_guide_Block_en` which + the new gradient operator will + be added to. The ``context`` is a + map, whose keys are gradient + Tensor names and values are + corresponding original :ref:`api_guide_tensor_en` . + In addition to this, the ``context`` + has another special key-value pair: + the key is string ``__current_op_desc__`` + and the value is the op_desc of the + gradient operator who has just + triggered the callable object. + Default: None. 
- def make_input_stopgradient(op): - input_grad_stopgradient_list = [] - for input in op.operands_source(): - if input.get_defining_op().name() == "builtin.combine": - stop_gradient = make_input_stopgradient(input.get_defining_op()) - input_grad_stopgradient_list.append( - [info[0] for info in stop_gradient] - ) - else: - if input in no_grad_set: - input_grad_stopgradient_list.append([True]) - else: - input_grad_stopgradient_list.append([False]) - return input_grad_stopgradient_list - - def update_input_grad_map(op, input_grad_list): - for i, input in enumerate(op.operands_source()): - if input.get_defining_op().name() == "builtin.combine": - update_input_grad_map( - input.get_defining_op(), input_grad_list[i] - ) - else: - input_grad = input_grad_list[i] - if isinstance(input_grad, list): - state.value_to_valuegrad[input].append(input_grad) - else: - state.value_to_valuegrad[input].append([input_grad]) - - # there are four patterns: - # [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType) - # [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType) - # [builtin.combine , op3 , buitin.split] (op3's one input and one output are vectorType) - # [op4] (op4's inputs and outputs are not vectorType) - # einsum has twp vectorType outputs, special pattern - - clear_effective_forward_op = [] - - for op in effective_forward_op: - if op.name() != "builtin.combine" and op.name() != "builtin.split": - clear_effective_forward_op.append(op) - - for op in clear_effective_forward_op: - if paddle.framework.core.has_vjp(op): - # prepare output_grad - output_grad_list = [] # (opresult) - zero_flag, output_grad = make_output_grad(op) - output_grad_list.append(output_grad) - - # all(zero_flag) support this op has no contribution for grad - # should be delete (prune sub_graph) - if len(output_grad_list) == 0 or all(zero_flag): - continue + Returns: + list of tuple ( :ref:`api_guide_tensor_en` , :ref:`api_guide_tensor_en` ): Pairs of parameter and its corresponding gradients. + The key is the parameter and the value is gradient Tensor. + + Raises: + AssertionError: If ``loss`` is not an instance of Tensor. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + paddle.enable_static() + + x = paddle.static.data(name='x', shape=[None, 13], dtype='int64') + y = paddle.static.data(name='y', shape=[None, 1], dtype='float32') + x_emb = paddle.static.nn.embedding(x, size=[100, 256]) + y_predict = paddle.static.nn.fc(x=x_emb, size=1, activation=None, name='my_fc') + loss = F.square_error_cost(input=y_predict, label=y) + avg_loss = paddle.mean(loss) + + # Get all weights in main_program, not include bias. + all_weights = [param for param in paddle.static.default_main_program().block(0).all_parameters() if 'w_' in param.name] + all_weights_name = [w.name for w in all_weights] + + # return all param_grads needed to be updated if parameter_list set default None. + p_g_list1 = paddle.static.append_backward(loss=avg_loss) + # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)] + + # return the param_grads corresponding to parameter_list that can be list of param (Tensor). + p_g_list2 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights) + # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)] + + # parameter_list can be list of param.name (str). 
+ p_g_list3 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights_name) + # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)] + + # no_grad_set can be set of Tensors that means grad will be cut off from these Tensors. + p_g_list4 = paddle.static.append_backward(loss=avg_loss, no_grad_set=set([x_emb])) + # output: [(my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)] + + # no_grad_set can be set of Tensor.name when the Tensor is created inside layers and can't be specified explicitly. + p_g_list5 = paddle.static.append_backward(loss=avg_loss, no_grad_set=set(['my_fc.b_0'])) + # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)] + + # return [] because all param_grads are filtered by no_grad_set. + p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights)) + + """ + grad_op_id_to_fwd_op = ( + {} + ) # for cuda graph usage, recording the mapping between grad op original id to fwd op - # prepare input_grad stop_gradient info. - input_grad_stopgradient_list = make_input_stopgradient(op) + check_type( + loss, 'loss', framework.Variable, 'paddle.static.append_backward' + ) + + if loss.op is None: + # the loss is from a cloned program. Find loss op manually. + _find_loss_op_(loss) + + loss.op._set_attr( + core.op_proto_and_checker_maker.kOpRoleAttrName(), + int(core.op_proto_and_checker_maker.OpRole.Forward) + | int(core.op_proto_and_checker_maker.OpRole.Loss), + ) + + if callbacks is not None: + check_type( + callbacks, + 'callbacks', + (list, tuple), + 'paddle.static.append_backward', + ) + + program = loss.block.program + root_block = program.block(0) + current_block_idx = program.current_block_idx + current_block = program.block(current_block_idx) + + is_in_control_flow = current_block_idx != 0 + + # Double grad is not supported in sub-block (control flow) + if not is_in_control_flow: + # _appending_grad_times used for double grad + program._appending_grad_times += 1 + + if no_grad_set is None: + no_grad_set = set() + else: + no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set)) + no_grad_dict = _get_stop_gradients_(program) + # no_grad_set only contains vars in block 0 + # Todo(liym27): support vars in sub block + no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set))) + + # Currently it is only to support the optimizer.minimize + # in a switch branch, which can append_backward in a sub_block. + # Note: while_loop is in control flow, but it makes no sense to call optimizer in while. + # Todo: report error when it is in while_loop + if is_in_control_flow: + # create grad block if in switch control flow. 
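A standalone sketch of how the no_grad bookkeeping above is keyed (illustrative only; '@GRAD' is the assumed value of core.grad_var_suffix(), and these two helpers merely mirror _append_grad_suffix_ and _strip_grad_suffix_): stop-gradient variables are stored in no_grad_dict under their gradient names and stripped back to forward names wherever a per-block forward set is needed. The control-flow branch below then continues by creating a dedicated gradient block.

GRAD_SUFFIX = '@GRAD'  # assumed value of core.grad_var_suffix()

def append_grad_suffix(name):
    return name + GRAD_SUFFIX

def strip_grad_suffix(name):
    pos = name.find(GRAD_SUFFIX)
    return name[:pos] if pos != -1 else name

stop_gradient_vars = ['embedding_0.w_0', 'my_fc.b_0']
no_grad_dict = {0: {append_grad_suffix(v) for v in stop_gradient_vars}}
print(sorted(no_grad_dict[0]))              # ['embedding_0.w_0@GRAD', 'my_fc.b_0@GRAD']
print(strip_grad_suffix('my_fc.b_0@GRAD'))  # my_fc.b_0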
+ target_grad_block = program._create_block( + parent_idx=current_block.parent_idx + ) + target_grad_block._set_forward_block_idx(current_block_idx) + # after _create_block, program.current_block changes + else: + target_grad_block = root_block + + son_parent_block_idx_dict = _get_son_parent_block_idx_dict( + program, current_block_idx + ) + + block_fwd_op_num_dict = {} # block_id: fwd_op_num + for idx in son_parent_block_idx_dict: + block_fwd_op_num_dict[idx] = program.block(idx).desc.op_size() + + grad_to_var = dict() + + # pass the cuda_graph_attr to the fill_constant which generates the loss_grad + op_desc = _create_loss_op_desc_(loss) + grad_op_id_to_fwd_op[op_desc.original_id()] = loss.op + target_grad_block.desc.append_op().copy_from(op_desc) + + for block_idx in son_parent_block_idx_dict: + block = program.block(block_idx) + + block_no_grad_set = set( + map(_strip_grad_suffix_, no_grad_dict[block_idx]) + ) + + op_path_dict = dict() + op_path = _find_op_path_( + block, [loss], [], block_no_grad_set, op_path_dict + ) + + no_grad_vars = _find_no_grad_vars( + block, op_path, [loss], block_no_grad_set + ) + + block_no_grad_set.update(no_grad_vars) + no_grad_dict[block_idx].update( + list(map(_append_grad_suffix_, block_no_grad_set)) + ) + + input_grad_names_set = None + # For double backward, input_grad_names is used for filtering + # some non-used gradients op(s). + + # TODO(liym27): need a better design. + # not support double grad in control flow sub-block now. + if not is_in_control_flow: + if program._appending_grad_times > 1: + input_grad_names_set = set([_append_grad_suffix_(loss.name)]) + + # TODO: support _append_backward_ops_with_checkpoints_ in + # sub-block (control flow) + is_recompute = False + if ( + checkpoints is not None + and isinstance(checkpoints, list) + and len(checkpoints) > 0 + ): + is_recompute = True + ( + program_stat, + checkpoint_names, + vars_should_be_hold, + recompute_segments, + ) = _append_backward_ops_with_checkpoints_( + root_block, + op_path, + [loss], + root_block, + no_grad_dict, + grad_to_var, + checkpoints, + grad_op_id_to_fwd_op, + ) + else: + _append_backward_ops_( + block, # the block where forward ops are in + op_path, + [loss], + target_grad_block, + no_grad_dict, + grad_to_var, + callbacks, + input_grad_names_set=input_grad_names_set, + op_path_dict=op_path_dict, + distop_context=distop_context, + grad_op_id_to_fwd_op=grad_op_id_to_fwd_op, + ) + + grad_info_map = dict() + + # if in control flow, target_grad_block is a created new block which only contains grad ops, + # so fwd_op_num is set to 0. + fwd_op_num = ( + block_fwd_op_num_dict[current_block_idx] + if not is_in_control_flow + else 0 + ) + + # Because append_backward may be called multiple times, + # we need rename the internal gradient variables so that they have + # different names. 
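A standalone toy of what _rename_grad_, invoked on the next line, does to the freshly appended gradient ops (illustrative only; the real implementation uses unique_name.generate and honors target_grad_map and skip_rename_var_list): every @GRAD output gets a fresh name, and later inputs are rewired through the same mapping, so repeated append_backward calls do not clash.

ops = [
    {'in': ['x@GRAD'], 'out': ['y@GRAD']},
    {'in': ['y@GRAD'], 'out': ['z@GRAD']},
]
var_map = {}
for op in ops:
    op['in'] = [var_map.get(n, n) for n in op['in']]
    renamed = []
    for n in op['out']:
        if '@GRAD' in n:
            var_map[n] = n + '_renamed'  # unique_name.generate(n) in the real code
            renamed.append(var_map[n])
        else:
            renamed.append(n)
    op['out'] = renamed
print(ops)
# [{'in': ['x@GRAD'], 'out': ['y@GRAD_renamed']},
#  {'in': ['y@GRAD_renamed'], 'out': ['z@GRAD_renamed']}]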
+ _rename_grad_(target_grad_block, fwd_op_num, grad_to_var, {}, []) + + _append_backward_vars_( + target_grad_block, fwd_op_num, grad_to_var, grad_info_map + ) - # create grad_op - before_ops_num = len(block.ops) - input_grad_list = paddle.framework.core.call_vjp( - op, output_grad_list, input_grad_stopgradient_list + program.current_block_idx = current_block_idx + program._sync_with_cpp() + + # for cuda graph, copy the cuda graph attr from forward op to backward op + for op in target_grad_block.ops: + if grad_op_id_to_fwd_op.get(op.desc.original_id(), None) is not None: + fwd_op = grad_op_id_to_fwd_op[op.desc.original_id()] + op._cuda_graph_attr = fwd_op._cuda_graph_attr + + if parameter_list is not None: + check_type( + parameter_list, + 'parameter_list', + (list, tuple, set), + 'fluid.backward.append_backward', + ) + parameters = [] + for i, param in enumerate(parameter_list): + check_type( + param, + 'parameter_list[%s]' % i, + (framework.Variable, str), + 'fluid.backward.append_backward', ) - after_ops_num = len(block.ops) + if isinstance(param, framework.Variable): + parameters.append(param.name) + elif isinstance(param, str): + parameters.append(param) + else: + params = program.global_block().all_parameters() + parameters = [param.name for param in params if param.trainable] - # update grad_op structure - for i in range(before_ops_num, after_ops_num): - update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], block.ops[i] + params_and_grads = [] + op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName() + for param in parameters: + if param not in grad_info_map: + continue + grad_info = grad_info_map[param] + grad_block = grad_info[1] + if not grad_block.has_var(grad_info[0]): + raise ValueError( + "grad block[{0}] did not have grad var {1}".format( + grad_info[1], grad_info[0] ) + ) + # Get the param var from the global block + param_var = program.global_block().var(param) + grad_var = grad_block.var(grad_info[0]) + if not is_in_control_flow: + if loss.block.has_var(grad_info[0]): + params_and_grads.append((param_var, grad_var)) + else: + params_and_grads.append((param_var, None)) + else: + params_and_grads.append((param_var, grad_var)) + + for p, g in params_and_grads: + if g is None: + continue + ops = ( + grad_block.ops if is_in_control_flow else program.global_block().ops + ) + for op in reversed(ops): + assert isinstance(op, framework.Operator) + if g.name in op.output_arg_names: + g.op = op + break + + if g.op is None: + raise ValueError("Unexpected branch") + attr_val = [p.name, g.name] + if g.op.has_attr(op_role_var_attr_name): + attr_val.extend(g.op.attr(op_role_var_attr_name)) + g.op._set_attr(op_role_var_attr_name, attr_val) + + if is_recompute: + return params_and_grads, checkpoint_names + else: + return params_and_grads + + +def _as_list(x): + if x is None: + return [] + return list(x) if isinstance(x, Sequence) else [x] - # update input_grad map - update_input_grad_map(op, input_grad_list) +def _is_ancestor_block(ancestor_block, block): + prog = block.program + ancestor_idx = ancestor_block.idx + parent_idx = block.parent_idx + + while parent_idx != -1: + if parent_idx == ancestor_idx: + return True + parent_idx = prog.block(parent_idx).parent_idx + + return False + + +def _get_output_names(cur_block, targets): + """ + In `cur_block`, get output names those linked to targets. + NOTE: + 1. `targets` can be in `cur_block`; + Usually, `targets` is in `cur_block`. However, considering control flow, + 2. 
`targets` may be in sub-block but `cur_block` is an ancestor of `targets[0].block`; + 3. `targets` may be in the block which is ancestor of `cur_block`. + """ + + block = targets[0].block if targets else cur_block + current_output_names = set([out.name for out in targets]) + + # 1. If `targets` in cur_block or the ancestral block of `cur_block` + if block.idx == cur_block.idx or _is_ancestor_block(block, cur_block): + return current_output_names + + # 2. If `cur_block` is an ancestor of `targets[0].block`, run while loop + prog = cur_block.program + while block.idx != cur_block.idx: + assert block.parent_idx != -1 + parent_block = prog.block(block.parent_idx) + + parent_block_output_names = set() + for op in reversed(block.ops): + if _some_in_set_(op.desc.output_arg_names(), current_output_names): + for name in op.desc.input_arg_names(): + current_output_names.add(name) + if not block.desc.find_var( + name.encode() + ) and parent_block.desc.find_var(name.encode()): + parent_block_output_names.add(name) + + block = parent_block + current_output_names = parent_block_output_names + + return current_output_names + + +def _find_no_grad_vars(block, op_path, targets, no_grad_set): + """ + Find the vars which is not used in the program, and + those vars belong to no_grad_var. + """ + output_names = _get_output_names(block, targets) + no_grad_var = [] + for i, op in reversed(list(enumerate(op_path))): + # If the op has sub_block, it is too complicated to find the correct no_grad_var. + if not op.has_attr("sub_block"): + for out_var in op.desc.output_arg_names(): + if ( + out_var not in output_names + and out_var not in op.desc.input_arg_names() + and not block.vars[out_var].stop_gradient + ): + no_grad_var.append(out_var) + for name in op.desc.input_arg_names(): + if name not in no_grad_set: + output_names.add(name) + return set(no_grad_var) + + +def _find_op_path_( + block, targets, inputs, no_grad_set, op_path_dict=None, is_while=False +): + """ + It is used to find the grad path in `block`. + + Args: + block(Block): The block in which to get op path. + targets(list[Variable]): The target variables. + inputs(list[Variable]): The input variables. + no_grad_set(set): The set of no grad var name. no_grad_set will be changed. + op_path_dict(dict): op_path_dict will be changed. op_path_dict will be changed. + key(int) block index + val(list) the op path of block(index) + is_while(bool): Whether or not `block` is while block + Return: + The forward op path of block corresponding to backward op. 
+ """ + + input_names = set([inp.name for inp in inputs]) + output_names = _get_output_names(block, targets) + if op_path_dict is None: + op_path_dict = dict() + + relevant_op_flags = [True] * len(block.ops) + + # All the inputs of the block are used if inputs is empty, + if inputs: + for i, op in enumerate(block.ops): + if _some_in_set_( + op.desc.input_arg_names(), input_names + ) and not core.has_empty_grad_op_maker(op.type): + for name in op.desc.output_arg_names(): + if name not in no_grad_set: + input_names.add(name) + else: + relevant_op_flags[i] = False + + for i, op in reversed(list(enumerate(block.ops))): + if op.has_attr("sub_block"): + sub_block_id = op._block_attr_id("sub_block") + sub_block = block.program.block(sub_block_id) + sub_block_target_names = output_names & set(op.output_arg_names) + sub_block_path = _get_sub_block_path( + sub_block, op, set(), op_path_dict, sub_block_target_names + ) + op_path_dict[sub_block_id] = sub_block_path + + if _some_in_set_( + op.desc.output_arg_names(), output_names + ) and not core.has_empty_grad_op_maker(op.type): + for name in op.desc.input_arg_names(): + if name not in no_grad_set: + output_names.add(name) else: - if op.num_operands() == 0 and op.num_results() != 0: - for value in op.results(): - if len(state.value_to_valuegrad[value]) > 1: - # need add sum op - paddle.add_n( - [ - item[0] - for item in state.value_to_valuegrad[value] - ] - ) - combineop = block.ops[len(block.ops) - 2] - sumop = block.ops[len(block.ops) - 1] - update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], combineop - ) - update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], sumop - ) - state.value_to_valuegrad[value] = [[sumop.result(0)]] - state.value_to_sumvaluegrad[ - value - ] = state.value_to_valuegrad[value] - else: - state.op_to_opgrad[op] = [] - else: - state.op_to_opgrad[op] = [] + relevant_op_flags[i] = False + if is_while: + # If block is while block, dealing with op specifically again. + # TODO(liym27): Consider special types of ops. 
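A toy version of the backward sweep above (illustrative only; it drops the no_grad_set and core.has_empty_grad_op_maker checks): an op stays on the op path only if one of its outputs is transitively needed by the targets, in which case its inputs are added to the needed set. The loop that follows repeats this marking once more, specifically for while blocks.

ops = [
    {'name': 'matmul', 'inputs': {'x', 'w'}, 'outputs': {'tmp'}},
    {'name': 'relu', 'inputs': {'tmp'}, 'outputs': {'y'}},
    {'name': 'side', 'inputs': {'x'}, 'outputs': {'unused'}},
]
output_names = {'y'}  # target names
relevant_op_flags = [True] * len(ops)
for i, op in reversed(list(enumerate(ops))):
    if op['outputs'] & output_names:
        output_names |= op['inputs']
    else:
        relevant_op_flags[i] = False
op_path = [op['name'] for i, op in enumerate(ops) if relevant_op_flags[i]]
print(op_path)  # ['matmul', 'relu'] -- 'side' never reaches the targets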
+ for i, op in reversed(list(enumerate(block.ops))): + if relevant_op_flags[i] == False and _some_in_set_( + op.desc.output_arg_names(), output_names + ): + relevant_op_flags[i] = True + if not core.has_empty_grad_op_maker(op.type): + for name in op.desc.input_arg_names(): + if name not in no_grad_set: + output_names.add(name) + + op_path = [ + block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i] + ] -def create_backward_prune_set(inputs, outputs, no_grad_set, state): - outputs_set = set() - for input in inputs: - for item in input.first_use().owner().operands_source(): - if state.value_to_valuegrad[item] != []: - outputs_set.add(state.value_to_valuegrad[item][0][0]) - inputs_set = set() - for output in outputs: - if state.value_to_valuegrad[output] != []: - inputs_set.add(state.value_to_valuegrad[output][0][0]) - - inputs_set_tmp = set() - for out_grad in inputs_set: - if not out_grad.use_empty(): - for item in out_grad.first_use().owner().operands_source(): - inputs_set_tmp.add(item) - inputs_set.update(inputs_set_tmp) - - no_gradvar_set = set() # grad_value of value in no_grad_set - for key in state.value_to_valuegrad: - if key in no_grad_set: - no_gradvar_set.add(state.value_to_valuegrad[key][0][0]) - - for key in state.value_to_sumvaluegrad: - if key in no_grad_set: - for item in state.value_to_sumvaluegrad[key][0]: - no_gradvar_set.add(item) - - return outputs_set, inputs_set, no_gradvar_set - - -def remove_op(block, op, state): + if inputs: + for op in op_path: + for name in op.desc.input_arg_names(): + if name not in input_names and block.vars[name].stop_gradient: + no_grad_set.add(name) + + return op_path + + +def calc_gradient_helper( + targets, inputs, target_gradients=None, no_grad_set=None +): ''' - remove op from block + Calculate gradient and return grad_info_map ''' - block.remove_op(op) - if state.opgrad_to_op[op] != []: - fwd_op = state.opgrad_to_op[op][0] - state.op_to_opgrad[fwd_op].remove(op) + targets = _as_list(targets) + inputs = _as_list(inputs) + target_gradients = _as_list(target_gradients) + + block = targets[0].block + prog = block.program + # increase appending gradients times + prog._appending_grad_times += 1 + block_idx = block.idx + + if not target_gradients: + target_gradients = [None] * len(targets) + + if len(targets) != len(target_gradients): + raise ValueError( + "Should have the same number of target_gradients as targets" + ) + + if no_grad_set is None: + no_grad_set = set() + else: + no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set)) + no_grad_dict = _get_stop_gradients_(prog) + no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set))) - for valuegrad in op.results(): - if state.valuegrad_to_value[valuegrad] != []: - value = state.valuegrad_to_value[valuegrad][0] - state.value_to_valuegrad[value] = [] + fwd_op_num = block.desc.op_size() - if value in state.sumvaluegrad_to_value: + input_grad_names_set = set() + + target_grad_map = {} + rename_var_map = {} + skip_rename_var_list = [] + grad_name_set = set() + for i, grad in enumerate(target_gradients): + target = targets[i] + grad_name = _append_grad_suffix_(target.name) + if grad is None: + op_desc = _create_op_desc_( + "fill_any_like", + {"X": [target.name]}, + {"Out": [grad_name]}, + { + "value": 1.0, + "dtype": target.dtype, + }, + ) + block.desc.append_op().copy_from(op_desc) + block.program._sync_with_cpp() + input_grad_names_set.add(grad_name) + skip_rename_var_list.append(grad_name) + else: + if target.block.idx != block_idx or target.block.program != prog: + 
raise ValueError("all targets must be in the same block") + if target.shape != grad.shape: raise ValueError( - 'input_grad in [%s] is value which need to sum ', op.name() + "The shapes of target and grad are different: %s %s" + % (target.name, grad.name) ) + target_grad_map[_append_grad_suffix_(target.name)] = grad.name + input_grad_names_set.add(grad.name) + rename_var_map[grad_name] = grad.name + grad_name_set.add(grad_name) -def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): - block = outputs[0].get_defining_op().get_parent_block() - state = State(block.get_parent_program()) - # check all inputs and outputs in the same block - check_all_puts(block, inputs, outputs) - # update no_grad_set if some value stop_gradient=True - update_no_grad_set_by_stopgradient(block, no_grad_set) - complete_outputs, _, backward_ops = prepare_grad_outputs( - block, - grad_outputs, - outputs, - state.value_to_valuegrad, - state.op_to_opgrad, - ) + if core._is_bwd_prim_enabled(): + core._set_prim_target_grad_name(target_grad_map) + # For double backward, input_grad_names is used for filter + # some non-used gradients op. rename_var_map is used to + # associate target_grad var name with first grad_op input name. + if prog._appending_grad_times == 1: + input_grad_names_set = None + rename_var_map = {} - inputs_set = set(inputs) - outputs_set = set(complete_outputs) - effective_forward_op, _ = prune_ops( - block.ops, inputs_set, outputs_set, no_grad_set - ) - update_no_grad_set_after_prune( - block, effective_forward_op, no_grad_set, inputs, complete_outputs + for input in inputs: + if input.block.program != prog: + raise "input must be in the same program as targets" + block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0])) + + op_path_dict = dict() + op_path = _find_op_path_( + block, targets, inputs, block_no_grad_set, op_path_dict ) - inverse_effective_forward_op = inverse_sort_op(effective_forward_op) + # only for composite to add grad_var of the last forward op + # who has more than one output, but targets only has one, + # so targets_gradients only add one grad_var, + # eg: op1 -> op2 -> var1 / var2 targets = var1, + # targets_gradients = var1_grad, need to add var2_grad here. + tmp_targets = targets - append_backward_ops( - block, inverse_effective_forward_op, no_grad_set, backward_ops, state + if core._is_bwd_prim_enabled(): + for op in reversed(block.ops): + if op.type == "fill_any_like": + continue + # Some outputs of composite op are not needed and will be removed. + # Thus, those vars should not be added with another op. 
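A standalone sketch of the padding performed by this composite-only branch (illustrative only; the variable names are made up): when the last forward op produces more outputs than appear in targets, the missing cotangents are created as zero-filled @GRAD variables so the decomposed gradient rule receives a gradient for every output. The keep_var_list handling below only skips outputs that the decomposition is known to drop.

targets = ['var1']
last_op_outputs = ['var1', 'var2']  # e.g. op1 -> op2 -> var1 / var2, as in the comment above
grad_name_set = {t + '@GRAD' for t in targets}

extra_zero_grads = []
for var_name in last_op_outputs:
    grad_name = var_name + '@GRAD'
    if grad_name not in grad_name_set:
        # mirrors the fill_any_like op built below with value 0
        extra_zero_grads.append(grad_name)
print(extra_zero_grads)  # ['var2@GRAD']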
+ keep_var_list = [] + if op.type in core.ops_contain_none.keys(): + values = core.ops_contain_none[op.type] + if isinstance(values, list): + none_vars = values + else: + none_vars = values(op) + for none_var_name in none_vars: + keep_var_list.append(op.output(none_var_name)[0]) + + for var_name in op.desc.output_arg_names(): + if keep_var_list and (var_name in keep_var_list): + continue + grad_var_name = _append_grad_suffix_(var_name) + if grad_var_name not in grad_name_set: + op_desc = _create_op_desc_( + "fill_any_like", + {"X": [var_name]}, + {"Out": [grad_var_name]}, + {'value': 0, 'dtype': targets[0].dtype}, + ) + block.desc.append_op().copy_from(op_desc) + tmp_targets.append(block.var(var_name)) + break + block.program._sync_with_cpp() + + # find no grad var by op_path + no_grad_vars = _find_no_grad_vars( + block, op_path, tmp_targets, block_no_grad_set ) - # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) + block_no_grad_set.update(no_grad_vars) - outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( - inputs, complete_outputs, no_grad_set, state + no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set))) + grad_to_var = dict() + grad_info_map = dict() + _append_backward_ops_( + block, + op_path, + targets, + block, + no_grad_dict, + grad_to_var, + input_grad_names_set=input_grad_names_set, + op_path_dict=op_path_dict, + rename_var_map=rename_var_map, ) - _, remove_ops = prune_ops( - backward_ops, inputs_set, outputs_set, no_gradvar_set + + # Because calc_gradient may be called multiple times, + # we need rename the internal gradient variables so that they have + # different names. + _rename_grad_( + block, fwd_op_num, grad_to_var, target_grad_map, skip_rename_var_list ) - state.turn_map() - for bwd_op in inverse_sort_op(remove_ops): - remove_op(block, bwd_op, state) - state.turn_map() + _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map) + prog._sync_with_cpp() + return grad_info_map - input_grad_map = state.value_to_valuegrad - return input_grad_map +def _get_grad_vars(grad_info_map, inputs): + inputs = _as_list(inputs) + grad_vars = [] + for input_var in inputs: + if input_var.name not in grad_info_map: + grad_vars.append(None) + else: + grad_info = grad_info_map[input_var.name] + grad_block = grad_info[1] + grad_var = grad_block.var(grad_info[0]) + grad_vars.append(grad_var) + return grad_vars -def calc_gradient(outputs, inputs, grad_outputs, no_grad_set): +def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): """ - caclulate gradient of input + Backpropagate the gradients of targets to inputs. Args: - outputs (Value|list(Value)|tuple(Value)): the output Value or - Value list/tuple of the graph to compute gradients. - inputs (Value|list(Value)|tuple(Value)): the input Value or - Value list/tuple of the graph to compute gradients. The returned - values of this API are the gradients of `inputs` . - grad_outputs (Value|list(Value|None)|tuple(Value|None), optional): - initial gradient values of `outputs` . If `grad_outputs` is None, - the initial gradient values of `outputs` would be Values filled with 1; - if `grad_outputs` is not None, it must have the same length as `outputs` , - and in this case, the initial gradient value of the i-th `outputs` would - be: (1) a Value filled with 1 when the i-th element of `grad_outputs` - is None; (2) the i-th element of `grad_outputs` when the i-th element of - `grad_outputs` is a Value. Default None. 
- no_grad_set (set(Value), optional): - the Values whose gradients are not needed to compute. Default None. + targets(Tensor|list[Tensor]|tuple[Tensor]): The target Tensors + inputs(Tensor|list[Tensor]|tuple[Tensor]): The input Tensors + target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensors + of targets which has the same shape with targets, If None, ones will + be created for them. + no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients + should be ignored. All Tensors with + `stop_gradient=True` from all blocks will + be automatically added into this set. + If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set. + Default: None. Return: - list[Value]:A list of gradients for inputs + (list[Tensor]): A list of gradients for inputs If an input does not affect targets, the corresponding gradient Tensor will be None - TODO if allow_unused=False raise TypeError() if input_grad has None """ - # record input value and its gradient (Value to Value) - input_to_inputgrad_map = calc_gradient_helper( - outputs, inputs, grad_outputs=grad_outputs, no_grad_set=no_grad_set + + # NOTE: If you want to modify the logic of calc_gradient, please modify + # it inside the calc_gradient_helper and _get_grad_vars functions + # to ensure the correctness of dy2st mode. + grad_info_map = calc_gradient_helper( + targets, + inputs, + target_gradients=target_gradients, + no_grad_set=no_grad_set, ) - inputgrad = [] - for input in inputs: - inputgrad.append( - input_to_inputgrad_map[input][0][0] - if input_to_inputgrad_map[input] != [] - else None - ) - return inputgrad - - -def grad( - outputs, - inputs, - grad_outputs=None, - retain_graph=None, - create_graph=False, - only_inputs=True, - allow_unused=False, - no_grad_vars=None, -): - ''' - .. note:: - **This API is ONLY available in imperative mode.** + grad_vars = _get_grad_vars(grad_info_map, inputs) + + if len(grad_vars) == 1: + return grad_vars[0] + else: + return grad_vars - This API computes the sum of gradients of `outputs` with respect to each `inputs` . - Parameters: - outputs (Value|list(Value)|tuple(Value)): the output Value or - Value list/tuple of the graph to compute gradients. - inputs (Value|list(Value)|tuple(Value)): the input Value or - Value list/tuple of the graph to compute gradients. The returned - values of this API are the gradients of `inputs` . - grad_outputs (Value|list(Value|None)|tuple(Value|None), optional): - initial gradient values of `outputs` . If `grad_outputs` is None, - the initial gradient values of `outputs` would be Values filled with 1; - if `grad_outputs` is not None, it must have the same length as `outputs` , - and in this case, the initial gradient value of the i-th `outputs` would - be: (1) a Value filled with 1 when the i-th element of `grad_outputs` - is None; (2) the i-th element of `grad_outputs` when the i-th element of - `grad_outputs` is a Value. Default None. - retain_graph (bool, optional): whether to retain the forward graph which - is used to calculate the gradient. When it is True, the graph would - be retained, in which way users can calculate backward twice for the - same graph. When it is False, the graph would be freed. Default None, - which means it is equal to `create_graph` . - create_graph (bool, optional): whether to create the gradient graphs of - the computing process. 
When it is True, higher order derivatives are - supported to compute; when it is False, the gradient graphs of the - computing process would be discarded. Default False. - only_inputs (bool, optional): whether to only compute the gradients of - `inputs` . If it is False, the gradients of all remaining leaf - Values in the graph would be also computed and accumulated. - If it is True, only the gradients of `inputs` would be computed. - Default True. only_inputs=False is under development, and it is - not supported yet. - allow_unused (bool, optional): whether to raise error or return None if some - Values of `inputs` are unreachable in the graph. If some Values of - `inputs` are unreachable in the graph (i.e., their gradients are None), - error would be raised if allow_unused=False, or None would be returned as - their gradients if allow_unused=True. Default False. - no_grad_vars (Value|list(Value)|tuple(Value)|set(Value), optional): - the Values whose gradients are not needed to compute. Default None. +@framework.static_only +def gradients(targets, inputs, target_gradients=None, no_grad_set=None): + """ - Returns: - list: a list of Values, whose length is the same as the Value number - inside `inputs`, and the i-th returned Value is the sum of gradients of - `outputs` with respect to the i-th `inputs`. - ''' + Backpropagate the gradients of targets to inputs. + + Args: + targets (Tensor|list[Tensor]|tuple[Tensor]): The target Tensors. + inputs (Tensor|list[Tensor]|tuple[Tensor]): The input Tensors. + target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensor + of targets which has the same shape with targets, If None, ones will + be created for them. + no_grad_set (set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients + should be ignored. All Tensors with ``stop_gradient=True`` from all blocks will + be automatically added into this set. If this parameter is not None, the Tensors or Tensor.names + in this set will be added to the default set. Default: None. + + Return: + (list[Tensor]): A list of gradients for inputs + If an input does not affect targets, the corresponding gradient Tensor + will be None. + + Examples: + + .. 
code-block:: python + :name: code-example + import paddle + import paddle.nn.functional as F + + paddle.enable_static() + + x = paddle.static.data(name='x', shape=[None, 2, 8, 8], dtype='float32') + x.stop_gradient=False + y = paddle.static.nn.conv2d(x, 4, 1, bias_attr=False) + y = F.relu(y) + z = paddle.static.gradients([y], x) + print(z) # [var x@GRAD : LOD_TENSOR.shape(-1, 2, 8, 8).dtype(float32).stop_gradient(False)] + """ check_type( - outputs, - 'outputs', - ((paddle.ir.Value, paddle.ir.OpResult), list, tuple), - 'paddle.ir.grad', + targets, + 'targets', + (framework.Variable, list, tuple), + 'paddle.static.gradients', ) check_type( inputs, 'inputs', - ((paddle.ir.Value, paddle.ir.OpResult), list, tuple), - 'paddle.ir.grad', + (framework.Variable, list, tuple), + 'paddle.static.gradients', ) check_type( - grad_outputs, - 'grad_outputs', - ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, type(None)), - 'paddle.ir.grad', + target_gradients, + 'target_gradients', + (framework.Variable, list, tuple, type(None)), + 'paddle.static.gradients', ) + outs = calc_gradient(targets, inputs, target_gradients, no_grad_set) + return _as_list(outs) + + +@framework.static_only +def gradients_with_optimizer(program, optimizer, inputs=None, outputs=None): + """ + :api_attr: Static Graph + + Backpropagate the gradients of the program and apply the gradients with the given optimizer. + + Args: + program (Program): The input program. + optimizer (Optimizer): The optimizer to apply the gradients. + inputs (Tensor|list[Tensor]|tuple[Tensor], optional): The input Tensors. + If None, the inputs will be created from the input variables in the given program. Default:None. + outputs (Tensor|list[Tensor]|tuple[Tensor], optional): The output Tensors. + If None, the outputs will be created from the output variables in the given program. Default: None. + Return: + tuple: tuple (optimize_ops, params_grads), A list of operators appended + by gradients_with_optimizer and a list of (param, grad) variable pairs, param is + ``Parameter``, grad is the gradient value corresponding to the parameter. + The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to + indicate program pruning. If so, the program will be pruned by ``feed`` and + ``fetch_list`` before run, see details in ``Executor``. + + Examples: + .. 
+@framework.static_only
+def gradients_with_optimizer(program, optimizer, inputs=None, outputs=None):
+    """
+    :api_attr: Static Graph
+
+    Backpropagate the gradients of the program and apply the gradients with the given optimizer.
+
+    Args:
+        program (Program): The input program.
+        optimizer (Optimizer): The optimizer to apply the gradients.
+        inputs (Tensor|list[Tensor]|tuple[Tensor], optional): The input Tensors.
+            If None, the inputs will be created from the input variables in the given program. Default: None.
+        outputs (Tensor|list[Tensor]|tuple[Tensor], optional): The output Tensors.
+            If None, the outputs will be created from the output variables in the given program. Default: None.
+
+    Returns:
+        tuple: A tuple ``(optimize_ops, params_grads)``: a list of operators appended by
+            gradients_with_optimizer and a list of ``(param, grad)`` variable pairs, where ``param``
+            is a ``Parameter`` and ``grad`` is the gradient value corresponding to the parameter.
+            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
+            indicate program pruning. If so, the program will be pruned by ``feed`` and
+            ``fetch_list`` before run; see details in ``Executor``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.static as static
+
+            paddle.enable_static()
+
+            img = static.data(name='image', shape=[None, 784])
+            pred = static.nn.fc(x=img, size=10, activation='relu')
+            loss = paddle.mean(pred)
+            opt = paddle.optimizer.SGD(learning_rate=0.01)
+            opt_ops, pram_grads = paddle.fluid.backward.gradients_with_optimizer(static.default_main_program(), opt)
+            print(opt_ops)
+
+    """
    check_type(
-        no_grad_vars,
-        'no_grad_vars',
-        ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, set, type(None)),
-        'paddle.ir.grad',
+        program,
+        'program',
+        paddle.fluid.Program,
+        'paddle.static.gradients_with_optimizer',
+    )
+    check_type(
+        optimizer,
+        'optimizer',
+        paddle.optimizer.Optimizer,
+        'paddle.static.gradients_with_optimizer',
    )
-    outputs = _as_list(outputs)
-    inputs = _as_list(inputs)
-    grad_outputs = _as_list(grad_outputs)
-    if no_grad_vars is None:
-        no_grad_set = set()
-    elif no_grad_vars is not set:
-        no_grad_set = set(no_grad_vars)
-    else:
-        no_grad_set = no_grad_vars
-
-    input_grad = calc_gradient(outputs, inputs, grad_outputs, no_grad_set)
-    return input_grad
+    if inputs is None or outputs is None:
+        in_set = set()
+        out_set = set()
+        for block in program.blocks:
+            for op in block.ops:
+                for name in op.input_arg_names:
+                    in_set.add(block.vars[name])
+                for name in op.output_arg_names:
+                    out_set.add(block.vars[name])
+        if inputs is None:
+            inputs = list(in_set.difference(out_set))
+        if outputs is None:
+            outputs = list(out_set.difference(in_set))
+
+    grads = gradients(outputs, inputs)
+
+    with program_guard(program, None):
+        pram_grads = [
+            (pram, grad)
+            for pram, grad in zip(inputs, grads)
+            if isinstance(pram, paddle.fluid.framework.Parameter)
+            and grad is not None
+        ]
+
+        optimize_ops = optimizer.apply_gradients(pram_grads)
+
+    return optimize_ops, pram_grads

From 1072cddbe1ba24d96720988ee7887ef7d91a3cd6 Mon Sep 17 00:00:00 2001
From: wangruting
Date: Thu, 24 Aug 2023 08:35:39 +0000
Subject: [PATCH 28/30] recover conflict

---
 python/paddle/autograd/backward.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py
index 5bf723be06c1b..7b88394713b75 100644
--- a/python/paddle/autograd/backward.py
+++ b/python/paddle/autograd/backward.py
@@ -203,6 +203,21 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
                 outputs_set.add(operand)
         else:
             relevant_op_flags[i] = False
+    # recover full op or full_int_array op created by mutable attribute.
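+    # A full / full_int_array op that only feeds a mutable attribute of a
+    # surviving op was flagged as irrelevant above; if its single user is still
+    # marked relevant, flag it back to True so the attribute input is kept.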
+ total_ops_list = list(total_ops) + for i, op in enumerate(total_ops_list): + if relevant_op_flags[i] is False: + for result in op.results(): + if result.has_one_use(): + next_op = result.first_use().owner() + if ( + next_op in total_ops + and relevant_op_flags[total_ops_list.index(next_op)] + is True + ): + relevant_op_flags[i] = True + else: + continue effective_ops = [ total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i] From 0ded8f3b9b51b4aa53f4440ff6944a4ec5062aac Mon Sep 17 00:00:00 2001 From: wangruting Date: Fri, 25 Aug 2023 03:41:25 +0000 Subject: [PATCH 29/30] reply review comments --- .../dialect/paddle_dialect/ir/pd_manual_op.cc | 122 +++++++++++++++++- .../dialect/paddle_dialect/ir/pd_manual_op.h | 5 + 2 files changed, 124 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc index c34bed7c1f622..5e1116863ab77 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc @@ -173,6 +173,69 @@ OpInfoTuple SplitGradOp::GetOpInfo() { inputs, attributes, outputs, run_time_info, "split_grad"); } +void SplitGradOp::Build(ir::Builder &builder, + ir::OperationArgument &argument, + ir::OpResult out_grad_, + float axis) { + // Generate scalar mutable attribute: axis + paddle::dialect::FullOp full_axis_op = builder.Build( + std::vector{1}, axis, phi::DataType::FLOAT32, phi::CPUPlace()); + ir::OpResult axis_ = full_axis_op->result(0); + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {out_grad_, axis_}; + argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); + + VLOG(4) << "Builder construction attributes"; + + VLOG(4) << "Builder construction outputs"; + ir::VectorType out_grad = out_grad_.type().dyn_cast(); + std::vector vec_dense_out_grad; + for (size_t i = 0; i < static_cast(out_grad.size()); i++) { + vec_dense_out_grad.push_back(phi::DenseTensor( + std::make_unique( + paddle::platform::CPUPlace()) + .get(), + phi::DenseTensorMeta( + paddle::dialect::TransToPhiDataType( + out_grad[i] + .dyn_cast() + .dtype()), + out_grad[i].dyn_cast().dims(), + out_grad[i] + .dyn_cast() + .data_layout(), + out_grad[i].dyn_cast().lod(), + out_grad[i] + .dyn_cast() + .offset()))); + } + std::vector vec_meta_out_grad; + for (size_t i = 0; i < vec_dense_out_grad.size(); i++) { + vec_meta_out_grad.push_back(phi::MetaTensor(&vec_dense_out_grad[i])); + } + + std::vector meta_out_grad; + for (size_t i = 0; i < static_cast(vec_meta_out_grad.size()); i++) { + meta_out_grad.push_back(&vec_meta_out_grad[i]); + } + phi::DenseTensor dense_x_grad; + phi::MetaTensor meta_x_grad(&dense_x_grad); + + phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad); + + std::vector argument_outputs; + ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + ir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_x_grad.dtype()), + dense_x_grad.dims(), + dense_x_grad.layout(), + dense_x_grad.lod(), + dense_x_grad.offset()); + argument_outputs.push_back(x_grad_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); +} + void SplitGradOp::Build(ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult out_grad_, @@ -185,7 +248,6 @@ void SplitGradOp::Build(ir::Builder &builder, VLOG(4) << "Builder construction outputs"; ir::VectorType out_grad = out_grad_.type().dyn_cast(); - (void)out_grad; int axis = axis_.owner() 
->dyn_cast() .attributes() @@ -193,7 +255,6 @@ void SplitGradOp::Build(ir::Builder &builder, .dyn_cast() .data() .to(); - (void)axis; std::vector vec_dense_out_grad; for (size_t i = 0; i < static_cast(out_grad.size()); i++) { @@ -240,7 +301,62 @@ void SplitGradOp::Build(ir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void SplitGradOp::Verify() {} +void SplitGradOp::Verify() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ( + input_size, + 2u, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 2.", input_size)); + if (auto vec_type = + (*this)->operand_source(0).type().dyn_cast()) { + for (size_t i = 0; i < vec_type.size(); ++i) { + PADDLE_ENFORCE(vec_type[i].isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th input.")); + } + } else { + PADDLE_ENFORCE((*this) + ->operand_source(0) + .type() + .isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th input.")); + } + PADDLE_ENFORCE((*this) + ->operand_source(1) + .type() + .isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 1th input.")); + } + VLOG(4) << "Verifying attributes:"; + { + // Attributes num is 0, not need to check attributes type. + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + PADDLE_ENFORCE_EQ( + output_size, + 1u, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", output_size)); + PADDLE_ENFORCE( + (*this)->result(0).type().isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th output.")); + } + VLOG(4) << "End Verifying for: SplitGradOp."; +} + +void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::ConcatInferMeta); + fn(infer_meta); +} } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h index 72c16a4ac8b9d..fe9beb46012ed 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h @@ -58,6 +58,10 @@ class SplitGradOp : public ir::Op { static const char *attributes_name[1]; static constexpr uint32_t attributes_num = 1; static OpInfoTuple GetOpInfo(); + static void Build(ir::Builder &builder, // NOLINT + ir::OperationArgument &argument, // NOLINT + ir::OpResult x_, + float axis = 0); static void Build(ir::Builder &builder, // NOLINT ir::OperationArgument &argument, // NOLINT ir::OpResult out_grad_, @@ -67,6 +71,7 @@ class SplitGradOp : public ir::Op { ir::Value out_grad() { return operand_source(0); } ir::Value axis() { return operand_source(1); } ir::OpResult x_grad() { return result(0); } + static void InferMeta(phi::InferMetaContext *infer_meta); }; } // namespace dialect From 5dcecd95aace93b5949189175288a339b2f5f8fc Mon Sep 17 00:00:00 2001 From: wangruting Date: Fri, 25 Aug 2023 06:16:47 +0000 Subject: [PATCH 30/30] modify opruntimeinfo num --- paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc index 7f5ced9f06b19..64cb1d69b210a 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc +++ 
b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc @@ -167,6 +167,7 @@ OpInfoTuple SplitGradOp::GetOpInfo() { {"out_grad", "axis"}, {"out_grad"}, {}, + {}, {}); return std::make_tuple(