Skip to content

Commit

Permalink
[AutoParallel] Support gradient merge. (PaddlePaddle#58339)
Browse files Browse the repository at this point in the history
  • Loading branch information
FeixLiu authored Oct 25, 2023
1 parent 166e33e commit 5d5eeec
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 0 deletions.
18 changes: 18 additions & 0 deletions paddle/fluid/eager/accumulation/accumulation_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
&tensor_values);
}
}
} else if (LIKELY(t.is_dist_tensor())) {
PADDLE_ENFORCE(
tensor->is_dist_tensor(),
paddle::platform::errors::Fatal("A DistTensor can only do gradient "
"merge with another DistTensor."));
PADDLE_ENFORCE(!t.is_custom_device(),
paddle::platform::errors::Fatal(
"DistTensor doesn't support custom device."));
auto t_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl());
paddle::Tensor t_values(
std::make_shared<phi::DenseTensor>(t_dist->value()));
auto tensor_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor->impl());
paddle::Tensor tensor_values(
std::make_shared<phi::DenseTensor>(tensor_dist->value()));
paddle::imperative::TensorAdd<paddle::Tensor>(t_values, &tensor_values);
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with
Expand Down
93 changes: 93 additions & 0 deletions test/auto_parallel/semi_auto_parallel_simple_net_gradient_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from semi_auto_parallel_simple_net import (
DemoNet,
DPDemoNet,
MPDemoNet,
TestSimpleNetForSemiAutoParallel,
)

import paddle
import paddle.distributed as dist
from paddle import nn


class TestSimpleNetWithGradientMergeForSemiAutoParallel(
    TestSimpleNetForSemiAutoParallel
):
    """Gradient-merge (gradient accumulation) test for semi-auto parallel.

    Runs two forward/backward passes without clearing gradients so that
    gradients accumulate, then checks that the DP- and MP-sharded nets
    produce the same loss and accumulated gradients as the single-card
    baseline net.
    """

    def __init__(self):
        # Test configuration is passed in through environment variables by
        # the distributed test launcher (see the companion test driver).
        self._dtype = os.getenv("dtype")
        self._backend = os.getenv("backend")
        # Parse with int() rather than eval(): the seed is a plain integer
        # and eval() on environment input is an unnecessary code-execution
        # risk.
        self._seed = int(os.getenv("seed"))
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

        paddle.set_device(self._backend)
        self.init_input_data()
        self.init_single_card_net_result()

    def run_dynamic_gradient_merge(self, layer):
        """Run two accumulated forward/backward passes on *layer*.

        Returns the last loss plus the (accumulated) gradients of the
        layer's two weight tensors.
        """
        # create loss
        loss_fn = nn.MSELoss()
        # run forward and backward
        image = paddle.to_tensor(self.image)

        # Two iterations with no clear_grad() in between: the second
        # backward merges its gradients into the first's, which is the
        # behavior under test.
        for _ in range(2):
            out = layer(image)
            label = paddle.to_tensor(self.label)
            loss = loss_fn(out, label)
            loss.backward()

        return loss, layer.w0.grad, layer.w1.grad

    def init_single_card_net_result(self):
        """Compute the single-card baseline loss and gradients."""
        (
            self.base_loss,
            self.base_w0_grad,
            self.base_w1_grad,
        ) = self.run_dynamic_gradient_merge(DemoNet(self.w0, self.w1))

    def test_dp_demo_net(self):
        """Data-parallel net must match the single-card baseline."""
        (
            self.dp_loss,
            self.dp_w0_grad,
            self.dp_w1_grad,
        ) = self.run_dynamic_gradient_merge(
            DPDemoNet(self.w0, self.w1, self._mesh)
        )
        self.check_tensor_eq(self.dp_loss, self.base_loss)
        self.check_tensor_eq(self.dp_w0_grad, self.base_w0_grad)
        self.check_tensor_eq(self.dp_w1_grad, self.base_w1_grad)

    def test_mp_demo_net(self):
        """Model-parallel net must match the single-card baseline."""
        (
            self.mp_loss,
            self.mp_w0_grad,
            self.mp_w1_grad,
        ) = self.run_dynamic_gradient_merge(
            MPDemoNet(self.w0, self.w1, self._mesh)
        )
        self.check_tensor_eq(self.mp_loss, self.base_loss)
        self.check_tensor_eq(self.mp_w0_grad, self.base_w0_grad)
        self.check_tensor_eq(self.mp_w1_grad, self.base_w1_grad)

    def run_test_case(self):
        """Entry point invoked on each rank by the test harness."""
        self.test_dp_demo_net()
        self.test_mp_demo_net()


# Entry point: this script is launched per-rank by the distributed test
# launcher (see test_semi_auto_parallel_single_strategy.py).
if __name__ == '__main__':
    TestSimpleNetWithGradientMergeForSemiAutoParallel().run_test_case()
11 changes: 11 additions & 0 deletions test/auto_parallel/test_semi_auto_parallel_single_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def test_simple_net_single_strategy_with_amp(self):
user_defined_envs=envs,
)

def test_simple_net_single_strategy_with_gradient_merge(self):
self._changeable_envs = {"backend": ["gpu"]}
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_auto_parallel_simple_net_gradient_merge.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5d5eeec

Please sign in to comment.