recompute support tuple #56793

Merged · 6 commits · Aug 31, 2023
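For context, a minimal usage sketch of what this PR enables, modeled on the unit test added below (the Net class and variable names here are illustrative, not part of the change): a tuple of tensors can now be passed to recompute as a single positional argument, and gradients flow back through every tensor in the tuple.

import paddle
from paddle.distributed.fleet.utils import recompute

# Illustrative layer, similar to the Layer defined in the new unit test.
class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear1 = paddle.nn.Linear(10, 10)
        self.linear2 = paddle.nn.Linear(10, 10)

    def forward(self, pair, y):
        a, b = pair  # pair is a tuple of tensors, supported by this PR
        return self.linear1(a) + self.linear2(b) + y

net = Net()
x1 = paddle.rand([10, 10])
x1.stop_gradient = False
x2 = paddle.rand([10, 10])
x2.stop_gradient = False
y = paddle.rand([10, 10])
y.stop_gradient = False

out = recompute(net, (x1, x2), y)
out.mean().backward()  # x1, x2 and y all receive gradients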
82 changes: 66 additions & 16 deletions python/paddle/distributed/fleet/recompute/recompute.py
@@ -31,22 +31,41 @@
def detach_variable(inputs):
out = []
for inp in inputs:
if not isinstance(inp, core.eager.Tensor) and (
type(inp) is not tuple or not isinstance(inp[0], core.eager.Tensor)
):
# inp is neither a tensor nor a tuple of tensors
out.append(inp)
continue

if type(inp) is tuple:
detach_inp = []
for i in inp:
# detach all tensors in the tuple
assert isinstance(i, core.eager.Tensor)
tmp_i = i.detach()
tmp_i.stop_gradient = i.stop_gradient
detach_inp.append(tmp_i)
out.append(tuple(detach_inp))
continue

x = inp.detach()
x.stop_gradient = inp.stop_gradient
out.append(x)
return tuple(out)
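A hedged illustration (not part of the diff) of how the updated detach_variable treats a tuple of tensors versus other inputs; the import path follows this file's location and the sample values are made up.

import paddle
from paddle.distributed.fleet.recompute.recompute import detach_variable

x = paddle.rand([2, 2])
x.stop_gradient = False
y = paddle.rand([2, 2])

out = detach_variable((x, (x, y), "flag"))
# out[0] is a detached copy of x with stop_gradient preserved (False)
# out[1] is a tuple of detached copies of x and y
# out[2] is the non-tensor value "flag", passed through unchanged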


def check_recompute_necessary(inputs):
necessary_for_each_input = []
for input_ in inputs:
if isinstance(input_, (core.eager.Tensor, paddle.Tensor)):
necessary_for_each_input.append(input_.stop_gradient)
elif type(input_) is tuple:
for i in input_:
# traverse all tensors in the tuple
if isinstance(i, (core.eager.Tensor, paddle.Tensor)):
necessary_for_each_input.append(i.stop_gradient)
if all(necessary_for_each_input):
logger.warning(
"[Recompute]: None of the inputs to current recompute block need grad, "
"therefore there is NO need to recompute this block in backward !"
@@ -81,12 +100,37 @@ def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
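# duplicate_tensor marks, per positional arg, whether it is a tuple of tensors (set below)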
ctx.duplicate_tensor = [False for _ in range(len(args))]
tensor_inputs = []
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
elif type(arg) is tuple:
is_tensors = [paddle.is_tensor(a) for a in arg]
if all(is_tensors):
# the tuple is a tuple of tensors
tensors_stop_gradient = [a.stop_gradient for a in arg]
if not all(tensors_stop_gradient) and any(
tensors_stop_gradient
):
# tensors in the tuple have different stop_gradient values, which PyLayer doesn't support
raise ValueError(
"Recompute received a tuple whose tensors have different stop_gradient values."
)
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
# mark that this argument is a tuple of tensors
ctx.duplicate_tensor[i] = True
ctx.inputs.append(None)
elif any(is_tensors):
# the tuple contains tensors and non-tensor values
raise ValueError(
"Recompute receive a tuple containing tensor and non-tensor at same time."
)
else:
ctx.inputs.append(arg)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
@@ -132,6 +176,7 @@ def backward(ctx, *args):
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
duplicate_tensor = ctx.duplicate_tensor
tensors = ctx.saved_tensor()
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
@@ -198,18 +243,23 @@ def backward(ctx, *args):
forward_outputs_with_grad, backward_inputs_with_grad
)

grads = []
for idx, inp in enumerate(detached_inputs):
if isinstance(inp, core.eager.Tensor):
grads.append(inp._grad_ivar())
elif type(inp) is tuple and duplicate_tensor[idx]:
# the input is a tuple of tensors
if all(i.stop_gradient for i in inp):
# no tensor in the tuple needs grad, so return a single None for the whole tuple
grads.append(None)
else:
# the tensors in the tuple need grad, so return a tuple of grads
grads.append(tuple(i._grad_ivar() for i in inp))

if in_dynamic_mode():
grads = tuple(grads)
else:
grads = list(grads)
return grads


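To make the new gradient layout concrete, a small sketch (illustrative only; block and the tensor names are not from the PR) of the branch above that maps a tuple whose tensors all have stop_gradient=True to a single None gradient, mirroring the test_tuple_input_all_no_gradient case added below.

import paddle
from paddle.distributed.fleet.utils import recompute

def block(pair, y):
    a, b = pair
    return a + b + y

x1 = paddle.rand([10, 10])  # stop_gradient defaults to True
x2 = paddle.rand([10, 10])
y = paddle.rand([10, 10])
y.stop_gradient = False

out = recompute(block, (x1, x2), y)
out.mean().backward()
# y.grad is populated; x1.grad and x2.grad stay None because the whole
# tuple contributed a single None gradient in backward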
89 changes: 89 additions & 0 deletions test/legacy_test/test_recompute_with_tuple_input.py
@@ -0,0 +1,89 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.distributed.fleet.utils import recompute


class Layer(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.linear1 = paddle.nn.Linear(10, 10)
self.linear2 = paddle.nn.Linear(10, 10)
self.linear3 = paddle.nn.Linear(10, 10)
self.silu1 = paddle.nn.Silu()
self.silu2 = paddle.nn.Silu()
self.silu3 = paddle.nn.Silu()

def forward(self, x, y):
assert type(x) is tuple
assert len(x) == 2
o1 = self.silu1(self.linear1(x[0]))
o2 = self.silu2(self.linear2(x[1]))
o3 = self.silu3(self.linear3(y))
o = o1 + o2 + o3
return o


class TestPyLayer(unittest.TestCase):
def test_tuple_input(self):
layer = Layer()
x1 = paddle.rand(shape=[10, 10])
x1.stop_gradient = False
x2 = paddle.rand(shape=[10, 10])
x2.stop_gradient = False
y = paddle.rand(shape=[10, 10])
y.stop_gradient = False
o = recompute(layer, (x1, x2), y)
loss = paddle.mean(o, keepdim=True)
loss.backward()

def test_tuple_input_with_non_tensor(self):
layer = Layer()
x1 = paddle.rand(shape=[10, 10])
x1.stop_gradient = False
y = paddle.rand(shape=[10, 10])
y.stop_gradient = False
with self.assertRaises(ValueError):
recompute(layer, (x1, True), y)

def test_tuple_input_with_different_stop_gradient(self):
layer = Layer()
x1 = paddle.rand(shape=[10, 10])
x1.stop_gradient = False
x2 = paddle.rand(shape=[10, 10])
y = paddle.rand(shape=[10, 10])
y.stop_gradient = False
with self.assertRaises(ValueError):
recompute(layer, (x1, x2), y)

def test_tuple_input_all_no_gradient(self):
layer = Layer()
x1 = paddle.rand(shape=[10, 10])
x2 = paddle.rand(shape=[10, 10])
y = paddle.rand(shape=[10, 10])
y.stop_gradient = False
o = recompute(layer, (x1, x2), y)
loss = paddle.mean(o, keepdim=True)
loss.backward()


if __name__ == '__main__':
unittest.main()