Fix stop_gradient in RunProgramOp (#36339)
* Fix stop_gradient in RunProgramOp

* fix reference
Aurelius84 committed Oct 12, 2021
1 parent 31a5829 commit 50f0119
Showing 2 changed files with 67 additions and 7 deletions.
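
The failure mode being fixed: when a model sets stop_gradient = True on an intermediate tensor inside its forward, the parameters that produce that tensor never get @GRAD variables, and RunProgramGradOp previously raised an error while trying to share those missing variables out of its scope. A minimal dygraph-to-static sketch of the pattern (illustrative only; it mirrors the new unit test added below and assumes a Paddle 2.x build with paddle.jit.to_static):

    import paddle


    class StopGradNet(paddle.nn.Layer):
        def __init__(self):
            super(StopGradNet, self).__init__()
            self.fc1 = paddle.nn.Linear(10, 10)
            self.fc2 = paddle.nn.Linear(10, 1)

        def forward(self, x):
            out = self.fc1(x)
            # Cutting the graph here means fc1's parameters get no @GRAD vars
            # in the backward program executed by RunProgramGradOp.
            out.stop_gradient = True
            return self.fc2(out)


    net = paddle.jit.to_static(StopGradNet())
    loss = paddle.mean(net(paddle.rand([4, 10])))
    # Before this fix, the backward pass could fail because fc1's @GRAD vars
    # were missing from the scope (see the NOTE added in ShareVarsFromScope).
    loss.backward()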
26 changes: 19 additions & 7 deletions paddle/fluid/operators/run_program_op.h
@@ -142,10 +142,15 @@ static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
 
 static void ShareVarsFromScope(const std::vector<Variable *> &vars,
                                const std::vector<std::string> &var_names,
+                               const BlockDesc &global_block,
                                framework::Scope *scope) {
   for (size_t i = 0; i < vars.size(); ++i) {
+    // NOTE: In case of setting out_tmp.stop_gradient = True in model code, all
+    // parameters before generating out_tmp have no @GRAD, it will raise error
+    // because we can't find them in scope. So we skip sharing these vars or
+    // var@GRAD if they don't appear in global block.
     if (var_names[i] == framework::kEmptyVarName ||
-        var_names[i] == "Fake_var") {
+        var_names[i] == "Fake_var" || !global_block.HasVar(var_names[i])) {
       VLOG(2) << "find variable name is " << var_names[i] << ", skip it!";
       continue;
     }
@@ -214,8 +219,10 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
     details::ShareVarsIntoScope(input_vars, input_var_names, &scope);
     details::ShareVarsIntoScope(param_vars, param_names, &scope);
 
+    auto *global_block = ctx.Attr<BlockDesc *>("global_block");
+
     if (end_op_index > start_op_index) {
-      auto *program = ctx.Attr<BlockDesc *>("global_block")->Program();
+      auto *program = global_block->Program();
       auto cache_info = framework::GetExecutorInfoFromCache(
           *program, ctx.GetPlace(), start_op_index, end_op_index,
           /*is_grad=*/false, program_id, &scope);
@@ -240,8 +247,10 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
       parallel_executor->RunWithoutFetch(skip_eager_delete_vars);
     }
     // Step 4. Get Output
-    details::ShareVarsFromScope(output_vars, output_var_names, &scope);
-    details::ShareVarsFromScope(dout_vars, dout_var_names, &scope);
+    details::ShareVarsFromScope(output_vars, output_var_names, *global_block,
+                                &scope);
+    details::ShareVarsFromScope(dout_vars, dout_var_names, *global_block,
+                                &scope);
 
     // Debug info: scope info when run end
     VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());
@@ -307,10 +316,11 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
                           "least one sub scope."));
 
     auto &scope = *(global_inner_scope->kids().front());
+    auto *global_block = ctx.Attr<BlockDesc *>("global_block");
 
     if (end_op_index > start_op_index) {
       // Step 2. prepare executor and scope
-      auto *program = ctx.Attr<BlockDesc *>("global_block")->Program();
+      auto *program = global_block->Program();
       auto cache_info = framework::GetExecutorInfoFromCache(
           *program, ctx.GetPlace(), start_op_index, end_op_index,
           /*is_grad*/ true, program_id, &scope);
@@ -341,8 +351,10 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
     }
 
     // Step 4. get outputs
-    details::ShareVarsFromScope(input_grad_vars, input_grad_var_names, &scope);
-    details::ShareVarsFromScope(param_grad_vars, param_grad_names, &scope);
+    details::ShareVarsFromScope(input_grad_vars, input_grad_var_names,
+                                *global_block, &scope);
+    details::ShareVarsFromScope(param_grad_vars, param_grad_names,
+                                *global_block, &scope);
 
     // Step5. drop current scope
     global_inner_scope->DeleteScope(&scope);
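In effect, ShareVarsFromScope now receives the global block and skips any name it cannot find there, in addition to the existing empty and fake placeholders. A rough Python paraphrase of the new guard (illustrative only; the literal value of framework::kEmptyVarName and the has_var method name are assumptions, not part of this commit):

    def should_skip_sharing(var_name, global_block):
        """Mirror of the C++ condition added to ShareVarsFromScope."""
        K_EMPTY_VAR_NAME = "@EMPTY@"  # assumed value of framework::kEmptyVarName
        return (
            var_name == K_EMPTY_VAR_NAME
            or var_name == "Fake_var"
            # e.g. a parameter's @GRAD that the backward block never created
            or not global_block.has_var(var_name)
        )

Output variables, double-grad outputs, input gradients, and parameter gradients are all routed through this check in the forward and backward kernels above.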
48 changes: 48 additions & 0 deletions python/paddle/fluid/tests/unittests/test_run_program_op.py
@@ -343,5 +343,53 @@ def build_model(self):
         return fwd_op_num
 
 
+class Net(paddle.nn.Layer):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.fc1 = paddle.nn.Linear(10, 10)
+        self.fc2 = paddle.nn.Linear(10, 1)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out.stop_gradient = True
+        out = self.fc2(out)
+        return out
+
+
+class TestParametersWithStopGradient(unittest.TestCase):
+    def setUp(self):
+        self.seed = 2021
+        self.iter = 5
+
+    def train(self, to_static):
+        # prepare env
+        paddle.seed(self.seed)
+
+        net = Net()
+        if to_static:
+            net = paddle.jit.to_static(net)
+        sgd = paddle.optimizer.SGD(0.01, parameters=net.parameters())
+
+        for i in range(self.iter):
+            x = paddle.rand([4, 10])
+            out = net(x)
+            loss = paddle.mean(out)
+
+            loss.backward()
+            sgd.minimize(loss)
+            net.clear_gradients()
+
+        return loss
+
+    def test_stop_gradient(self):
+        paddle.disable_static()
+
+        dy_loss = self.train(to_static=False)
+        st_loss = self.train(to_static=True)
+        self.assertEqual(dy_loss[0], st_loss[0])
+
+        paddle.enable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
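
For reference, the dygraph behavior that TestParametersWithStopGradient compares against: once an intermediate output has stop_gradient = True, everything upstream of it receives no gradient, which is exactly why the corresponding @GRAD variables are absent from the static backward block. A small standalone check (a sketch assuming the Paddle 2.x dygraph API; expected values are annotated, not verified here):

    import paddle

    fc1 = paddle.nn.Linear(10, 10)
    fc2 = paddle.nn.Linear(10, 1)

    out = fc1(paddle.rand([4, 10]))
    out.stop_gradient = True        # backprop stops here; fc1 is cut off
    loss = paddle.mean(fc2(out))
    loss.backward()

    print([p.grad is None for p in fc1.parameters()])  # expected: [True, True]
    print([p.grad is None for p in fc2.parameters()])  # expected: [False, False]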

1 comment on commit 50f0119

@paddle-bot-old
Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
