Commit

Refactor and simplify hook design & add Tensor.register_hook API (#31775)

* refactor and simplify hook design

* fix reducer add hook error

* add Tensor.register_hook basic impl

* refine prepare data impl

* revert prepare data change

* support register_hook for Tensor

* add hook test in model

* polish tests and doc example

* fix double grad test failed

* remove reduce hook func

* fix set empty error

* polish code by comments

* change reduce_hook to mutable_hook

* remove useless tmp_ins

* fix shape code format error

* fix shape code format error
chenwhql authored Apr 1, 2021
1 parent 6b74486 commit dbeb3ea
Showing 13 changed files with 863 additions and 356 deletions.
79 changes: 57 additions & 22 deletions paddle/fluid/imperative/basic_engine.cc
@@ -141,17 +141,6 @@ void BasicEngine::PrepareGradAccumulators(
<< var.get()
<< ") that don't have grad node with reference count "
<< accumulator->RefCnt();

if (var->HasLeafHooks()) {
VLOG(3) << "Grad variable wrapper (" << var->Name()
<< ") has leaf grad hooks.";
PADDLE_ENFORCE_NE(
var->HasGradNode(), true,
platform::errors::PermissionDenied(
"Only leaf Tensor's gradient can append hook to "
"Gradientaccumulator."));
accumulator->SetPostHooks(var->GetLeafHooks());
}
} else {
// Because Inplace op overwrites the grad_node of the input grad_var. So
// only the information of grad_pending_node can be used to find the
@@ -262,6 +251,30 @@ void BasicEngine::PrepareDeps() {
}
}

static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
const NameVarMap<VariableWrapper>& bwd_ins, const std::string& op_type) {
std::shared_ptr<NameVarMap<VariableWrapper>> tmp_ins_ptr = nullptr;
for (const auto& pair : bwd_ins) {
for (size_t i = 0; i < pair.second.size(); ++i) {
auto& var = pair.second[i];
if (var->HasHook()) {
if (tmp_ins_ptr == nullptr) {
tmp_ins_ptr = std::make_shared<NameVarMap<VariableWrapper>>(bwd_ins);
}
VLOG(3) << "Call " << var->GetHooks().size() << " hooks of " << op_type
<< "'s input `" << pair.first << "`'s var `" << var->Name()
<< "`.";
auto tmp_var = var;
for (const auto& hook_pair : var->GetHooks()) {
tmp_var = (*hook_pair.second)(tmp_var);
}
(*tmp_ins_ptr)[pair.first][i] = tmp_var;
}
}
}
return tmp_ins_ptr;
}

void BasicEngine::Execute() {
if (init_node_ == nullptr) {
return;
@@ -292,10 +305,15 @@ void BasicEngine::Execute() {
auto& bwd_ins = cur_op.GetInsMap();
auto& bwd_outs = cur_op.GetOutsMap();

/**
* [ Why do we need temporary outputs here? ]
*
* - construct the temp output map to avoid disrupting the graph
* - replace the elements in the map with temp vars, because a
* var may correspond to several grad vars in one op
*/
NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
// 1. construct the temp output map, avoid to disrupt graph
// 2. replace the element in the map by temp var, because a
// var may be coresponding to several grad var in one op

for (auto& pair : tmp_outs) {
if (!pair.second.IsGrad()) {
continue;
@@ -408,10 +426,28 @@ void BasicEngine::Execute() {
}
}

/**
* [ Why do we need temporary inputs here? ]
*
* - Hook execution should not change the original input tensors.
* A user can register a hook for a Tensor's gradient; the hook is
* expected to affect only the gradient that continues to flow
* backward, not the gradient value that is passed into the hook.
* - `tmp_ins_ptr` is used so that bwd_ins is copied only when a var
* in bwd_ins holds hooks
*/
auto tmp_ins_ptr = CallGradientHooks(bwd_ins, cur_op.Type());

{
VLOG(3) << "Start to execute grad op " << cur_op.Type();
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
if (tmp_ins_ptr == nullptr) {
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
} else {
OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs, cur_op.Attrs(),
cur_op.place());
}
}

for (auto& pair : inplace_output_grad_var_list_) {
@@ -428,15 +464,14 @@ void BasicEngine::Execute() {
if (!accumulator->SumGradCompleted()) {
continue;
}
// 1. Call Hooks for **inner_var_**
// 1. Call Hooks for `inner_var_`
accumulator->CallGradientHooks();

// 2. Sum Gradient with Previous Graph
// 2. Sum gradient `inner_var_` into `var_` of the current or previous graph
accumulator->AccumulateGrad();

// 3. Call backward Hooks for **var_**
if (accumulator->HasPostHooks()) {
accumulator->CallBackwardPostHooks();
}
// 3. Call backward Hooks for `var_`
accumulator->CallReduceHooks();
}

need_accu_var_list_.clear();
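
The CallGradientHooks helper added above copies bwd_ins only when some input variable actually holds hooks, so hook results never overwrite the original graph inputs. Below is a minimal, self-contained C++ sketch of that copy-on-write idiom; the Var, VarMap, and hook types are simplified stand-ins for VariableWrapper, NameVarMap, and VariableWrapperHook, not Paddle's real declarations.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Stand-in for Paddle's VariableWrapper: a named gradient value plus the
// value-returning hooks registered on it (names here are illustrative only).
struct Var {
  std::string name;
  float grad = 0.f;
  std::vector<std::function<std::shared_ptr<Var>(std::shared_ptr<Var>)>> hooks;
};

using VarMap = std::map<std::string, std::vector<std::shared_ptr<Var>>>;

// Copy-on-write hook application: return nullptr when no input holds hooks,
// so the caller keeps using the original map; otherwise copy the map once and
// rewrite only the slots whose vars have hooks.
std::shared_ptr<VarMap> CallGradientHooks(const VarMap& bwd_ins) {
  std::shared_ptr<VarMap> tmp_ins = nullptr;
  for (const auto& pair : bwd_ins) {
    for (size_t i = 0; i < pair.second.size(); ++i) {
      const auto& var = pair.second[i];
      if (var->hooks.empty()) continue;
      if (tmp_ins == nullptr) {
        tmp_ins = std::make_shared<VarMap>(bwd_ins);  // first hook found: copy
      }
      auto tmp_var = var;
      for (const auto& hook : var->hooks) {
        tmp_var = hook(tmp_var);  // each hook sees the previous hook's output
      }
      (*tmp_ins)[pair.first][i] = tmp_var;
    }
  }
  return tmp_ins;
}

int main() {
  auto x = std::make_shared<Var>();
  x->name = "x@GRAD";
  x->grad = 1.f;
  x->hooks.push_back([](std::shared_ptr<Var> v) {
    auto out = std::make_shared<Var>(*v);  // never mutate the original grad
    out->grad *= 2.f;
    return out;
  });

  VarMap bwd_ins{{"Out@GRAD", {x}}};
  auto tmp_ins = CallGradientHooks(bwd_ins);

  std::cout << bwd_ins["Out@GRAD"][0]->grad << " "       // 1 (original intact)
            << (*tmp_ins)["Out@GRAD"][0]->grad << "\n";  // 2 (hooked copy)
  return 0;
}

Returning nullptr when nothing is hooked is what lets BasicEngine::Execute keep passing the original bwd_ins to OpBase::Run in the common case.
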
61 changes: 58 additions & 3 deletions paddle/fluid/imperative/gradient_accumulator.cc
@@ -384,8 +384,8 @@ static platform::Place GetPlaceOfVar(

void GradientAccumulator::AccumulateGrad() {
/**
* If the gradient has been calculated by previous graph,
* it should be added to the previous graph result.
* Once the leaf gradient has been fully calculated, inner_var_
* should be added to var_.
*/
if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
return;
@@ -396,7 +396,7 @@ void GradientAccumulator::AccumulateGrad() {
"this auto-grad"));
PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(), true,
platform::errors::InvalidArgument(
"Interior var of Leaf tensor should be initialized."));
"Interior var of Leaf tensor should be initialized."));
auto* src = inner_var_->MutableVar();
auto* dst = var_->MutableVar();
if (!var_->IsEmpty()) {
@@ -427,10 +427,65 @@ void GradientAccumulator::AccumulateGrad() {
*(dst) = std::move(*src);
var_->SetType(inner_var_->Type());
var_->SetDataType(inner_var_->DataType());
var_->SetIsEmpty(false);
}
inner_var_.reset();
}

void GradientAccumulator::CallGradientHooks() {
PADDLE_ENFORCE_EQ(var_->IsLeafGrad(), true,
platform::errors::Unavailable(
"Only leaf gradient Tensor can deal with by gradient "
"hook in gradient accumulator."));
PADDLE_ENFORCE_EQ(
SumGradCompleted(), true,
platform::errors::PreconditionNotMet(
"Only can call gradient hooks after sum gradient completed."));
PADDLE_ENFORCE_EQ(
HasInnerVar(), true,
platform::errors::PreconditionNotMet(
"Leaf Tensor's inner var is nullptr when call gradient hook."));
PADDLE_ENFORCE_EQ(
inner_var_->Var().IsInitialized(), true,
platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
"is not initialized when "
"call gradient hook."));
if (var_->HasHook()) {
VLOG(3) << "Call " << var_->GetHooks().size()
<< " hooks of leaf gradient accumulator's inner var `"
<< var_->Name() << "`.";
auto tmp_var = inner_var_;
VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
<< var_->GetHooks().size();
for (const auto& hook_pair : var_->GetHooks()) {
tmp_var = (*hook_pair.second)(tmp_var);
}
inner_var_ = tmp_var;
}
}

void GradientAccumulator::CallReduceHooks() {
PADDLE_ENFORCE_EQ(
var_->IsLeafGrad(), true,
platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
"by reduce hook in gradient accumulator."));
PADDLE_ENFORCE_EQ(SumGradCompleted(), true,
platform::errors::PreconditionNotMet(
"Only can call reduce hooks after the gradient "
"summation is completed in current batch."));
PADDLE_ENFORCE_EQ(HasInnerVar(), false,
platform::errors::PreconditionNotMet(
"Only can call reduce hooks after the "
"gradient accumulation is completed in "
"current batch or across batchs."));
if (var_->HasMutableHook()) {
for (const auto& hook : var_->GetMutableHooks()) {
VLOG(3) << "call gradient accumulator backward hooks.";
(*hook)(var_);
}
}
}

void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
size_t trace_id, bool unchange_input) {
/**
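
For a completed leaf gradient, BasicEngine::Execute drives the accumulator methods above in a fixed order: CallGradientHooks rewrites the freshly summed inner_var_, AccumulateGrad folds it into var_, and CallReduceHooks finally runs the in-place hooks on var_. The sketch below reproduces that ordering with a simplified accumulator; all type and member names are illustrative stand-ins, and the hook bodies are made up for the example.

#include <functional>
#include <iostream>
#include <memory>
#include <vector>

struct Var {
  float value = 0.f;
};

// Simplified stand-in for GradientAccumulator (names are illustrative).
class Accumulator {
 public:
  explicit Accumulator(std::shared_ptr<Var> var) : var_(std::move(var)) {}

  void SumGrad(float g) {
    if (!inner_var_) inner_var_ = std::make_shared<Var>();
    inner_var_->value += g;  // gradient accumulation within the current batch
  }

  // 1. Value-returning gradient hooks run on the summed inner var.
  void CallGradientHooks() {
    auto tmp = inner_var_;
    for (const auto& hook : gradient_hooks_) tmp = hook(tmp);
    inner_var_ = tmp;
  }

  // 2. Fold the hooked inner var into the persistent var (across batches).
  void AccumulateGrad() {
    var_->value += inner_var_->value;
    inner_var_.reset();
  }

  // 3. In-place reduce/mutable hooks run on the final var.
  void CallReduceHooks() {
    for (const auto& hook : reduce_hooks_) hook(var_.get());
  }

  void AddGradientHook(
      std::function<std::shared_ptr<Var>(std::shared_ptr<Var>)> h) {
    gradient_hooks_.push_back(std::move(h));
  }
  void AddReduceHook(std::function<void(Var*)> h) {
    reduce_hooks_.push_back(std::move(h));
  }

 private:
  std::shared_ptr<Var> var_;
  std::shared_ptr<Var> inner_var_;
  std::vector<std::function<std::shared_ptr<Var>(std::shared_ptr<Var>)>>
      gradient_hooks_;
  std::vector<std::function<void(Var*)>> reduce_hooks_;
};

int main() {
  auto leaf_grad = std::make_shared<Var>();
  Accumulator acc(leaf_grad);
  acc.AddGradientHook([](std::shared_ptr<Var> v) {
    auto out = std::make_shared<Var>(*v);
    out->value *= 2.f;  // e.g. a user hook that doubles the gradient
    return out;
  });
  acc.AddReduceHook([](Var* v) { std::cout << "reduce sees " << v->value << "\n"; });

  acc.SumGrad(1.f);
  acc.SumGrad(2.f);          // same-batch accumulation: inner var = 3
  acc.CallGradientHooks();   // gradient hooks: inner var = 6
  acc.AccumulateGrad();      // cross-batch accumulation: var = 6
  acc.CallReduceHooks();     // prints "reduce sees 6"
  std::cout << leaf_grad->value << "\n";  // 6
  return 0;
}
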
70 changes: 31 additions & 39 deletions paddle/fluid/imperative/gradient_accumulator.h
@@ -40,8 +40,8 @@ class GradientAccumulator {
}

// inner_var_ record the grad of this auto-grad.
// Only need to generate inner var for non-empty leaf-tensor.
if (var->IsLeafGrad() && !var->IsEmpty()) {
// Only need to generate inner var for leaf-tensor.
if (var->IsLeafGrad()) {
inner_var_ = std::make_shared<VariableWrapper>(var->Name());
inner_var_->SetType(var->Type());
inner_var_->SetDataType(var->DataType());
@@ -52,9 +52,6 @@ class GradientAccumulator {
<< ") to store result of this Graph";
}

// TODO(zhouwei): fix Tensor.clear_gradient() bug, remove this hard flag
var->SetIsEmpty(false);

// var_ is the final grad, processed by hooks and grad accumulation
var_ = var;
}
@@ -93,42 +90,38 @@ class GradientAccumulator {

inline bool HasInnerVar() const { return inner_var_ != nullptr; }

/* Hook related methods */
inline bool HasPostHooks() const { return !post_hooks_.expired(); }

void SetPostHooks(const std::shared_ptr<LeafVarHookPipeline>& hooks) {
PADDLE_ENFORCE_NOT_NULL(
hooks, platform::errors::InvalidArgument(
"The hook set to GradientAccumulator is nullptr."));

auto shared_hooks = post_hooks_.lock();
if (shared_hooks != hooks) {
PADDLE_ENFORCE_EQ(
shared_hooks, nullptr,
platform::errors::PermissionDenied(
"Cannot set post hooks twice to GradientAccumulator."));
post_hooks_ = hooks;
}
}
// void CallHooks(){}
// ** inner_var_ **

// Function that sums the gradient with the previous graph's result
void AccumulateGrad();

// call backward post hooks, such as reduce hook
void CallBackwardPostHooks() {
PADDLE_ENFORCE_NE(
post_hooks_.expired(), true,
platform::errors::NotFound(
"The post hooks of GradientAccumulator for Tensor `%s` expired.",
var_->Name()));
auto shared_hooks = post_hooks_.lock();
for (const auto& hook : shared_hooks->backward_hooks()) {
VLOG(3) << "call gradient accumulator backward hooks.";
(*hook)(var_);
}
}
/** [ Hook related methods ]
*
* [ Why do we need two types of VariableWrapperHook? ]
*
* There are two types of gradient accumulation:
* 1. Gradient accumulation in the same batch
* 2. Gradient accumulation across batches
* The order of execution between hooks and gradient accumulation is:
* [ Gradient accumulation in the same batch ]
* |
* [ leaf GradVarBase hooks ]
* |
* [ Gradient accumulation across batches ]
* |
* [ Gradient reduce / allreduce hooks ]
* Because we currently handle both kinds of gradient accumulation in
* one GradientAccumulator, we must distinguish between the two types
* of hooks.
* InplaceVariableWrapperHook cannot be registered directly by users;
* it is currently only used to support the reduce strategy of
* parallel multi-card training.
*/

void CallGradientHooks();

void CallReduceHooks();

protected:
VariableWrapper* var_;
Expand All @@ -137,7 +130,6 @@ class GradientAccumulator {
std::shared_ptr<VariableWrapper> inner_var_;
size_t ref_cnt_{0};
size_t cur_cnt_{0};
std::weak_ptr<LeafVarHookPipeline> post_hooks_;
};

class EagerGradientAccumulator : public GradientAccumulator {
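
The [ Hook related methods ] comment in GradientAccumulator above distinguishes the two hook kinds only in prose: user-registered gradient hooks return a (possibly new) variable, while the framework-internal inplace/reduce hooks mutate the variable they receive. A minimal C++ rendering of that distinction is sketched below; the class names and the fake allreduce body are assumptions made for illustration, not Paddle's actual declarations.

#include <iostream>
#include <memory>

struct VarWrapper {  // simplified stand-in for imperative::VariableWrapper
  float grad = 0.f;
};

// User-registrable gradient hook: consumes a var and returns a (possibly new)
// var, so the gradient value passed in stays untouched.
class GradVarHook {
 public:
  virtual ~GradVarHook() = default;
  virtual std::shared_ptr<VarWrapper> operator()(
      const std::shared_ptr<VarWrapper>& var) = 0;
};

// Framework-internal in-place hook: mutates the var directly; used here to
// model a reduce/allreduce step after cross-batch accumulation finishes.
class InplaceVarHook {
 public:
  virtual ~InplaceVarHook() = default;
  virtual void operator()(VarWrapper* var) = 0;
};

class ScaleGradHook : public GradVarHook {
 public:
  explicit ScaleGradHook(float s) : scale_(s) {}
  std::shared_ptr<VarWrapper> operator()(
      const std::shared_ptr<VarWrapper>& var) override {
    auto out = std::make_shared<VarWrapper>(*var);
    out->grad *= scale_;
    return out;
  }

 private:
  float scale_;
};

class FakeAllReduceHook : public InplaceVarHook {
 public:
  void operator()(VarWrapper* var) override {
    var->grad /= 2.f;  // pretend to average the gradient over 2 trainers
  }
};

int main() {
  auto g = std::make_shared<VarWrapper>();
  g->grad = 3.f;

  ScaleGradHook scale(2.f);
  auto hooked = scale(g);    // returns a new var; g->grad is still 3
  FakeAllReduceHook reduce;
  reduce(hooked.get());      // mutates the hooked var in place

  std::cout << g->grad << " " << hooked->grad << "\n";  // 3 3
  return 0;
}

Keeping the in-place variant out of the public registration path matches the comment's note that it only backs the reduce strategy of parallel multi-card training.
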
