[Eager Grad] Support eager grad interface (#40170)
* [Eager] Support eager grad interface, draft version

* Support eager grad interface with allow_unused and multiple startup_ops

* Fix code format

* Fix allow_unused case, return PyNone if tensor is not initialized

* Set outputs' stop_gradient according to create_graph

* Support grad exception case in eager mode, fix coverage CI

* Update ToPyObject, return PyNone if not initialized

* Add FLAGS_retain_grad_for_all_tensor to AccumulationNode

* Fix CI issue

* Fix CI issue

* Fix: use core.eager.Tensor

* Add func SetBufferSlotRankZeros for GradTensorHolder

* Support retain_graph by using ClearTensorWrappers

* Support retain_graph by using ClearTensorWrappers

* Update retain_graph and no_grad_vars related test case

* Update code gen logic for ClearTensorWrappers

* Fix by override statement

* Fix override function args

* Support retain_graph, update unit tests

* Updated ClearTensorWrappers logic

* Fix grad Python interface

* Use deep copy and update unit tests

* Polish code

* Polish code

* Fix CI issue; deep copy is only used when the user sets grad_tensors

* Fix CI, use Backward instead of RunBackward

* Fix CI, Declare kernel explicitly in test file

* Polish, remove vector of TensorWrapper

* Refactor the logic of grad/backward, polish code

* Update code after merge upstream develop

* Polish after merge upstream develop

* Update to adapt to the new GradNodeBase superclass

* Fix error introduced during conflict resolution

* Update purify logic for potential_startup_nodes

* Fix errors

* Polish code

* Remove useless args for ToPyObject

* Remove useless TensorWrappersSet

* Fix code-format, re-install pre-commit

* Fix pre-process logic for potential_startup_ops

* Update unit tests, use eager mode
veyron95 authored Mar 17, 2022
1 parent 1e045ca commit 4db8cf2
Showing 32 changed files with 1,217 additions and 163 deletions.
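The commit messages above describe the Python-facing side of the change: an eager-mode grad call that accepts user-provided gradient tensors, tolerates unused inputs via allow_unused, and returns None (PyNone) for gradients that were never produced. Below is a minimal sketch of how such an interface is typically exercised under dygraph, assuming the public entry point is paddle.grad with its documented arguments; the tensors and values are illustrative only.

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
unused = paddle.to_tensor([4.0], stop_gradient=False)  # never contributes to y
y = x * x

# grad_outputs corresponds to the user-supplied grad_tensors mentioned in the
# commit messages; allow_unused=True makes the call return None for `unused`.
dx, d_unused = paddle.grad(
    outputs=[y],
    inputs=[x, unused],
    grad_outputs=[paddle.ones_like(y)],
    allow_unused=True)

print(dx)        # dy/dx = 2 * x -> [2., 4., 6.]
print(d_unused)  # None, since `unused` is not reachable from `y`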
8 changes: 4 additions & 4 deletions paddle/fluid/eager/accumulation/accumulation_node.cc
@@ -24,7 +24,7 @@
#include "paddle/fluid/platform/errors.h"

#include "glog/logging.h"

DECLARE_bool(retain_grad_for_all_tensor);
namespace egr {

static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
@@ -39,8 +39,8 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
}

std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {
operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph) {
VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
PADDLE_ENFORCE(grads.size() == 1,
paddle::platform::errors::Fatal(
@@ -62,7 +62,7 @@ operator()(
grad_out = grads[0][0];
}

if (!weak_grad_.expired()) {
if (!weak_grad_.expired() && FLAGS_retain_grad_for_all_tensor) {
auto grad = weak_grad_.lock();
CopyOrAddTensor(grad.get(), grad_out);
}
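CopyOrAddTensor above copies the incoming gradient into the leaf tensor's slot on the first visit and adds on later visits, and the new FLAGS_retain_grad_for_all_tensor check gates whether that write happens at all. A small sketch of the accumulate-on-repeat behaviour as seen from Python, assuming Paddle's default dygraph accumulation semantics; the values are illustrative.

import paddle

x = paddle.to_tensor([3.0], stop_gradient=False)
y = x * x

# First backward copies dy/dx = 6 into x.grad.
y.backward(retain_graph=True)
print(x.grad)  # [6.]

# Second backward adds another 6, so the stored gradient becomes 12.
y.backward()
print(x.grad)  # [12.]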
11 changes: 9 additions & 2 deletions paddle/fluid/eager/accumulation/accumulation_node.h
@@ -35,8 +35,8 @@ class GradNodeAccumulation : public GradNodeBase {

// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads)
override;
const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) override;

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }

bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

std::string name() { return "GradNodeAccumulation"; }

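The extra create_graph parameter on operator() is what allows the backward computation itself to be recorded, so gradients of gradients can be taken. A hedged sketch of the user-visible effect through paddle.grad; the function and numbers are illustrative.

import paddle

x = paddle.to_tensor([2.0], stop_gradient=False)
y = x ** 3  # dy/dx = 3x^2, d2y/dx2 = 6x

# create_graph=True records the backward pass, so dx is itself differentiable.
(dx,) = paddle.grad([y], [x], create_graph=True)
print(dx)  # [12.]  (3 * 2^2)

# Differentiating dx with respect to x gives the second derivative.
(d2x,) = paddle.grad([dx], [x])
print(d2x)  # [12.]  (6 * 2)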
@@ -145,8 +145,8 @@ void GradNodeScale::SetTensorWrappers_X(
void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; }

std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale::
operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {
operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph) {
// 1. Check Output Size
PADDLE_ENFORCE(
((grads.size() == 1) && (grads[0].size() == 1)),
@@ -39,8 +39,8 @@ class GradNodeScale : public GradNodeBase {

// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads)
override;
const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) override;

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }

bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

void SetTensorWrappers_X(
const std::vector<paddle::experimental::Tensor>& tensors);
33 changes: 28 additions & 5 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -2074,7 +2074,8 @@ static std::string GenerateGradNodeCCContents(
const char* GRAD_FUNCTION_TEMPLATE =
"std::vector<std::vector<paddle::experimental::Tensor>> "
"GradNode%s::operator()(const "
"std::vector<std::vector<paddle::experimental::Tensor>>& grads) {\n%s\n}";
"std::vector<std::vector<paddle::experimental::Tensor>>& grads, "
"bool create_graph) {\n%s\n}";
std::string grad_function_str = paddle::string::Sprintf(
GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body);

@@ -2109,18 +2110,28 @@ static std::string GenerateGradNodeHeaderContents(
"\n"
" virtual std::vector<std::vector<paddle::experimental::Tensor>> "
"operator()(const "
"std::vector<std::vector<paddle::experimental::Tensor>>& grads) "
"std::vector<std::vector<paddle::experimental::Tensor>>& grads, const "
"bool create_graph = false) "
"override;\n"
"\n"
" void ClearTensorWrappers() override { \n"
"%s\n"
" is_tensor_wrappers_cleared = true;\n"
" }\n"
" std::string name() override { return \" GradNode%s \"; } \n "
"\n"
" // SetX, SetY, ...\n"
"%s\n"
" // SetAttrMap\n"
"%s\n"
" bool IsTensorWrappersCleared() override { \n"
" return is_tensor_wrappers_cleared;\n"
" }\n"
" private:\n"
" // TensorWrappers\n"
"%s\n"
" bool is_tensor_wrappers_cleared = false;\n"
"\n"
" // Attribute Map\n"
"%s\n"
"};";
@@ -2154,6 +2165,7 @@

std::string set_tensor_wrappers_str = "";
std::string tensor_wrapper_members_str = "";
std::string clear_tensor_wrappers_str = "";
for (const auto& iter : op_base_infos) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
iter.GetGradInsFwdSlotnameMap();
@@ -2185,6 +2197,13 @@
SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
struct_tensor_wrapper_name);

const char* CLEAR_TENSOR_WRAPPER_TEMPLATE =
"for (auto tw: %s) {\n"
" tw.clear();\n"
" }\n";
clear_tensor_wrappers_str += paddle::string::Sprintf(
CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);

} else {
const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
"const paddle::experimental::Tensor& %s";
@@ -2197,10 +2216,14 @@
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);

const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"%s = egr::TensorWrapper(%s, %s /*full_reserved*/);";
"%s = egr::TensorWrapper(%s, %s /*full_reserved*/);\n";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
tensor_wrapper_name, full_reserved_str);

const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = " %s.clear();\n";
clear_tensor_wrappers_str += paddle::string::Sprintf(
CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
}
std::string full_reserved_signature_str = "bool full_reserved";
const char* SET_TENSOR_WRAPPER_TEMPLATE =
@@ -2215,8 +2238,8 @@

std::string grad_node_str = paddle::string::Sprintf(
GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, op_type, op_type,
op_type, op_type, set_tensor_wrappers_str, set_attr_map_str,
tensor_wrapper_members_str, attr_members_str);
op_type, clear_tensor_wrappers_str, op_type, set_tensor_wrappers_str,
set_attr_map_str, tensor_wrapper_members_str, attr_members_str);

return grad_node_str;
}
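The generated ClearTensorWrappers body lets a grad node release the forward tensors it captured once they have been consumed, which is the mechanism behind the retain_graph support mentioned in the commit messages: after a normal backward pass the saved state may be freed, while retain_graph=True keeps it for another pass. A sketch of that contract from the Python side, assuming the usual dygraph behaviour that re-running backward on a freed graph raises an error; the values are illustrative.

import paddle

x = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
y = (x * x).sum()

y.backward()      # retain_graph defaults to False; saved tensors can be cleared
try:
    y.backward()  # reusing the freed graph is expected to fail
except Exception as err:
    print("second backward rejected:", type(err).__name__)

x2 = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
y2 = (x2 * x2).sum()
y2.backward(retain_graph=True)  # keep the tensor wrappers alive
y2.backward()                   # second pass now works; gradients accumulate
print(x2.grad)                  # [4., 8.]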
@@ -478,6 +478,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
# SetTensorWrapper Methods & TensorWrapper Members
set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = ""
for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items():
if tname in no_need_buffer_set:
no_need_buffer = "true"
@@ -499,6 +500,13 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
"""
tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name)

CLEAR_TENSOR_WRAPPERS_TEMPLATE = """
{}.clear();
"""
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)

else:
assert IsVectorTensorType(ttype)
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """
@@ -516,6 +524,15 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
"""
tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name)

CLEAR_TENSOR_WRAPPERS_TEMPLATE = """
for (auto tw: {}) {
tw.clear();
};
"""
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)

# End: SetTensorWrapper Methods & TensorWrapper Members

# SetAttributes & Attribute Members
@@ -524,7 +541,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
for aname, atype, default_val, _ in backward_attrs_list:
saved_attr_name = GetSavedName(aname)
SET_ATTR_METHOD_TEMPLATE = """
void SetAttribute{}({} {}) {{
void SetAttribute{}({} {}) {{
{} = {};
}}
"""
@@ -555,25 +572,37 @@ class {} : public egr::GradNodeBase {{
~{}() override = default;
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) override;
const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
}}
// SetTensorWrapperX, SetTensorWrapperY, ...
{}
// SetAttributes
{}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private:
// TensorWrappers
{}
bool is_tensor_wrappers_cleared = false;
// Attributes
{}
}};
"""
node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, set_tensor_wrapper_methods_str,
set_attribute_methods_str, tensor_wrapper_members_str,
attribute_members_str)
grad_node_name, clear_tensor_wrapper_str,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)

return node_declaration_str

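GenerateNodeDeclaration assembles the C++ node class purely with str.format on the templates above, so single {} pairs are substitution slots while doubled {{ and }} produce literal braces in the emitted C++. A toy illustration of that convention, separate from the real generator:

# Toy template following the same convention as the generator above:
# {} is a slot filled by .format(), {{ / }} become literal C++ braces.
CLEAR_WRAPPER_TEMPLATE = """
  for (auto tw : {}) {{
    tw.clear();
  }}
"""

print(CLEAR_WRAPPER_TEMPLATE.format("x_"))
# Emits:
#   for (auto tw : x_) {
#     tw.clear();
#   }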
@@ -637,7 +666,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_api_namespace = f"paddle::experimental"

FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
// Call grad_api function
auto grad_api_returns = {}::{}({});
{}
(Diffs for the remaining changed files are not shown.)