fix inplace bug when the first grad_var (loss_grad) is an inplace var #37420

Merged · 4 commits · Nov 23, 2021
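For context, a minimal repro sketch of the bug this PR fixes, mirroring the TestLossIsInplaceVar case added at the bottom of this diff and assuming Paddle's dynamic-graph (dygraph) mode: when the loss itself is the output of an inplace op (tanh_), backward() should give the leaf tensor the same gradient as the out-of-place equivalent.

```python
import numpy as np
import paddle

# Inplace path: the loss tensor is itself an inplace var (output of tanh_).
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False
loss = (var_a * 2).tanh_()
loss.backward()
inplace_grad = var_a.grad.numpy()

# Out-of-place reference path: identical computation with tanh.
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False
loss = (var_a * 2).tanh()
loss.backward()
reference_grad = var_a.grad.numpy()

# Before this fix the two gradients could diverge; after it they must match.
assert np.array_equal(inplace_grad, reference_grad)
```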
15 changes: 9 additions & 6 deletions paddle/fluid/imperative/basic_engine.cc
@@ -53,6 +53,10 @@ void BasicEngine::Init(
platform::errors::AlreadyExists(
"Accumulators are not empty before preparing it for "
"backward network execution."));
PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
platform::errors::AlreadyExists(
"Accumulators with grad_node as the key are not empty "
"before preparing it for backward network execution."));

for (size_t i = 0; i < tensors.size(); ++i) {
auto var = tensors[i];
@@ -73,7 +77,6 @@ void BasicEngine::Init(
VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
<< " because of retain_graph=False when calling backward";
var->GradVarBase()->SetGraphIsFreed(true);
var->GradVarBase()->ClearGradNode();
}

if (init_node == nullptr || var->OverridedStopGradient()) {
@@ -108,14 +111,18 @@
}

VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get();
auto& accumulator = accumulators_[init_grad_var];
auto& accumulator =
accumulators_with_grad_node_[init_grad_var->GetGradNode()]
[init_grad_var];
if (!accumulator) {
if (FLAGS_sort_sum_gradient) {
accumulator.reset(new SortedGradientAccumulator(init_grad_var));
} else {
accumulator.reset(new EagerGradientAccumulator(init_grad_var));
}
}
accumulator->IncreaseRefCnt();
Contributor:

What is the purpose of these two lines?

Contributor Author:

See PR #34582 for background. In BasicEngine::Init(), if the grad var passed in (with its initial grad value already set) is an intermediate node, gradient accumulation has to be performed on that intermediate node during the backward computation.
Since this PR turns the passed-in grad var from a leaf node into a non-leaf node, the original gradient-accumulation behavior changes, so its ref cnt and cur cnt must be set explicitly for gradient aggregation.
You can think of it as: the passed-in grad var is already the output of some op (an op that does not actually exist); if this grad var later appears again as the output of an op in the network, gradient aggregation can then be performed.
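For readers unfamiliar with the two counters mentioned above, here is a toy sketch in Python (ToyGradientAccumulator is hypothetical, not Paddle's actual GradientAccumulator) of ref-cnt / cur-cnt based gradient aggregation, under the assumption that a variable's accumulated gradient is considered ready only once the number of delivered contributions equals the number of expected ones:

```python
class ToyGradientAccumulator:
    """Toy illustration only; Paddle's real accumulator is implemented in C++."""

    def __init__(self):
        self.ref_cnt = 0   # gradient contributions expected for this var
        self.cur_cnt = 0   # contributions delivered so far
        self.grad = 0.0    # running sum of delivered gradients

    def increase_ref_cnt(self):
        # One more op is expected to write a gradient into this var.
        self.ref_cnt += 1

    def add_grad(self, value):
        # A gradient contribution arrives and is summed in.
        self.grad += value
        self.cur_cnt += 1

    def ready(self):
        return self.cur_cnt == self.ref_cnt


# The initial grad value passed to backward() is treated as the output of a
# non-existent op: it is registered as one expected contribution and counted
# as already delivered, which is the role of the paired IncreaseRefCnt() /
# IncreaseCurCnt() calls in this hunk.
acc = ToyGradientAccumulator()
acc.increase_ref_cnt()
acc.add_grad(1.0)          # the user-provided initial gradient
assert acc.ready()
```

If the same grad var later also appears as the output of a real op in the backward graph, further contributions are summed into it in the same way.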

accumulator->IncreaseCurCnt();

init_nodes_.push_back(init_node);
}
@@ -253,10 +260,6 @@ void BasicEngine::PrepareDeps() {
node_deps_.empty(), true,
platform::errors::AlreadyExists("Op deps are not empty before preparing "
"it for backward network execution."));
PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
platform::errors::AlreadyExists(
"Accumulators with grad_node as the key are not empty "
"before preparing it for backward network execution."));

std::queue<GradOpNode*> q;
std::unordered_set<GradOpNode*> visited;
25 changes: 25 additions & 0 deletions python/paddle/fluid/tests/unittests/test_inplace.py
@@ -409,5 +409,30 @@ def inplace_api_processing(self, var):
return var.subtract_(self.input_var_2)


class TestLossIsInplaceVar(unittest.TestCase):
def test_loss_is_inplace_var(self):
with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False

var_b = var_a * 2
loss = var_b.tanh_()

loss.backward()
inplace_grad_var_a = var_a.grad.numpy()

with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False

var_b = var_a * 2
loss = var_b.tanh()

loss.backward()
grad_var_a = var_a.grad.numpy()

self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))


if __name__ == '__main__':
unittest.main()