Support shard to partial reshard function.
FeixLiu committed Nov 23, 2023
1 parent e1f3e75 commit 6bd87bd
Showing 8 changed files with 232 additions and 2 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -47,6 +47,7 @@
#include "paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h"
@@ -215,6 +216,10 @@ void BindAutoParallel(py::module *m) {
      *m, "SToSReshardFunction", ReshardFunction)
      .def(py::init<>());

  py::class_<phi::distributed::SToPReshardFunction>(
      *m, "SToPReshardFunction", ReshardFunction)
      .def(py::init<>());

  py::class_<phi::distributed::PToSReshardFunction>(
      *m, "PToSReshardFunction", ReshardFunction)
      .def(py::init<>());
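For orientation, here is a minimal usage sketch of the class this binding exposes, modeled on the reshard_s_to_p.py test added later in this commit; the mesh, shapes, and attribute calls are taken from that test and are not a guaranteed public API.

# Minimal usage sketch (see assumptions above); run one process per mesh rank.
import paddle
import paddle.distributed as dist
from paddle.base import core

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
a = paddle.ones((10, 20))

in_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])    # shard along dim 0
out_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
out_attr._set_partial_dims([0])                                   # partial over mesh dim "x"

x = dist.shard_tensor(a, dist_attr=in_attr)
dev_ctx = core.DeviceContext.create(paddle.CUDAPlace(dist.get_rank()))

func = core.SToPReshardFunction()                                 # exposed by this commit
if func.is_suitable(x, out_attr):
    out = func.eval(dev_ctx, x, out_attr)   # rank 0 keeps the values, other ranks hold zeros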
1 change: 1 addition & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard/CMakeLists.txt
@@ -9,5 +9,6 @@ collect_srcs(
  p_to_r_reshard_function.cc
  s_to_s_reshard_function.cc
  p_to_s_reshard_function.cc
  s_to_p_reshard_function.cc
  nd_mesh_reshard_function.cc
  same_status_reshard_function.cc)
69 changes: 69 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.cc
@@ -0,0 +1,69 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#include "paddle/phi/kernels/reduce_scatter_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {
namespace distributed {

bool SToPReshardFunction::IsSuitable(const DistTensor& in,
                                     const TensorDistAttr& out_dist_attr) {
  const auto& in_dist_attr = in.dist_attr();

  RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard());
  RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_partial());

  const auto& in_process_mesh = in_dist_attr.process_mesh();
  const auto& out_process_mesh = out_dist_attr.process_mesh();

  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh);

  return true;
}

void SToPReshardFunction::Eval(DeviceContext* dev_ctx,
                               const DistTensor& in,
                               const TensorDistAttr& out_dist_attr,
                               DistTensor* out) {
  VLOG(3) << "Call SToPReshardFunction Eval";

  // step 1, create tmp dist attr and tmp dist tensor
  TensorDistAttr tmp_attr(out_dist_attr);
  DistTensor tmp_tensor;
  tmp_attr.clean_partial_status();

  // step 2, do s to r reshard on `in` to `tmp`
  SToRReshardFunction s_to_r;
  s_to_r.Eval(dev_ctx, in, tmp_attr, &tmp_tensor);

  // step 3, do r to p reshard on `tmp` to `out`
  RToPReshardFunction r_to_p;
  r_to_p.Eval(dev_ctx, tmp_tensor, out_dist_attr, out);
}

REGISTER_RESHARD_FUNC(SToPReshardFunction);

} // namespace distributed
} // namespace phi
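The function above does not move data itself; it composes the existing S-to-R and R-to-P reshards. A small NumPy sketch (illustrative only, not code from this commit) of the resulting placement on a two-rank, 1-D mesh:

# Editorial sketch of the data movement on 2 ranks (not the kernel code).
import numpy as np

full = np.arange(8, dtype=np.float32).reshape(4, 2)
shards = {0: full[:2], 1: full[2:]}                      # shard(0): half the rows per rank

replicated = {r: np.concatenate([shards[0], shards[1]]) for r in (0, 1)}   # after s_to_r
partial = {0: replicated[0], 1: np.zeros_like(replicated[1])}              # after r_to_p

assert np.array_equal(partial[0] + partial[1], full)     # partial addends sum to the global value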
35 changes: 35 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h
@@ -0,0 +1,35 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h"

namespace phi {
namespace distributed {

class SToPReshardFunction final : public ReshardFunction {
 public:
  bool IsSuitable(const DistTensor& in,
                  const TensorDistAttr& out_dist_attr) override;

  void Eval(DeviceContext* dev_ctx,
            const DistTensor& in,
            const TensorDistAttr& out_dist_attr,
            DistTensor* out) override;

  std::string Name() override { return "SToPReshard"; }
};

} // namespace distributed
} // namespace phi
3 changes: 3 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -96,6 +96,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
  set_tests_properties(test_reshard_s_to_r
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
  py_test_modules(test_reshard_s_to_p MODULES test_reshard_s_to_p)
  set_tests_properties(test_reshard_s_to_p
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
  py_test_modules(test_reshard_p_to_s MODULES test_reshard_p_to_s)
  set_tests_properties(test_reshard_p_to_s
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
74 changes: 74 additions & 0 deletions test/auto_parallel/reshard_s_to_p.py
@@ -0,0 +1,74 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.base import core


class TestReshardSToP:
    def __init__(self):
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        self._backend = os.getenv("backend")
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    def run_test_case(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
            place = paddle.CPUPlace()
        elif self._backend == "gpu":
            place = paddle.CUDAPlace(dist.get_rank())

        dev_ctx = core.DeviceContext.create(place)
        a = paddle.ones(self._shape)

        in_shard_specs = [None for i in range(len(self._shape))]
        in_shard_specs[0] = "x"
        out_shard_specs = [None for i in range(len(self._shape))]

        dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=in_shard_specs
        )
        out_dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=out_shard_specs
        )
        out_dist_attr._set_partial_dims([0])

        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)

        assert input_tensor._local_shape[0] == self._shape[0] // 2

        reshard_func = core.SToPReshardFunction()
        assert reshard_func.is_suitable(input_tensor, out_dist_attr)

        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)

        if dist.get_rank() == 0:
            np.testing.assert_equal(out._local_value().numpy(), a.numpy())
        else:
            zeros = paddle.zeros(self._shape)
            np.testing.assert_equal(out._local_value().numpy(), zeros.numpy())

        assert np.equal(out.shape, input_tensor.shape).all()
        assert np.equal(out._local_shape, input_tensor.shape).all()


if __name__ == '__main__':
    TestReshardSToP().run_test_case()
3 changes: 1 addition & 2 deletions test/auto_parallel/semi_auto_parallel_shard_optimizer.py
@@ -171,8 +171,7 @@ def run_test_case(self):
         if self._backend == "gpu":
             self.test_adamw_mp()
             self.test_adamw_shard_optimizer(stage1=True)
-            # A problem has to be addressed if not shard batch.
-            # self.test_adamw_shard_optimizer(stage1=False)
+            self.test_adamw_shard_optimizer(stage1=False)


if __name__ == '__main__':
44 changes: 44 additions & 0 deletions test/auto_parallel/test_reshard_s_to_p.py
@@ -0,0 +1,44 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import collective.test_communication_api_base as test_base


class TestReshardSToP(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "shape": "(10, 20)",
            "dtype": "float32",
            "seeds": "1234",
        }
        self._changeable_envs = {
            "backend": ["cpu", "gpu"],
        }

    def test_reshard_s_to_p(self):
        envs_list = test_base.gen_product_envs_list(
            self._default_envs, self._changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "reshard_s_to_p.py",
                user_defined_envs=envs,
            )


if __name__ == "__main__":
    unittest.main()
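The test case drives reshard_s_to_p.py through environment variables supplied by CommunicationTestDistBase. A rough hand-run equivalent is sketched below; the launcher invocation is an assumption and not part of this commit.

# Hand-run approximation of what CommunicationTestDistBase.run_test_case does;
# the launcher command is assumed, not taken from this commit.
import os
import subprocess

env = dict(
    os.environ,
    shape="(10, 20)",
    dtype="float32",
    seeds="1234",
    backend="gpu",
)
subprocess.run(
    ["python", "-m", "paddle.distributed.launch", "--devices", "0,1", "reshard_s_to_p.py"],
    env=env,
    check=True,
)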
