support r to s cross mesh (#58550)
LiYuRio authored Nov 2, 2023
1 parent bb354dd commit 11c62e6
Showing 5 changed files with 170 additions and 3 deletions.
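In short, the commit adds a reshard path from a tensor that is replicated on one process mesh to a tensor that is sharded on a different process mesh of the same shape. A minimal sketch of the placement pair this covers, using the meshes from the new test added below (illustrative only; a real run needs the two-card launch the test uses):

# Sketch of the input/output placements handled by RToSReshardFunctionCrossMesh.
import paddle.distributed as dist

in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])   # source mesh
out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])  # different mesh, same shape

# Input: replicated on in_mesh (no tensor dim is mapped to the mesh).
in_attr = dist.DistAttr(mesh=in_mesh, sharding_specs=[None, None])
# Output: dim 0 sharded along mesh dim "x" on out_mesh.
out_attr = dist.DistAttr(mesh=out_mesh, sharding_specs=["x", None])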
4 changes: 4 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -183,6 +183,10 @@ void BindAutoParallel(py::module *m) {
      *m, "RToSReshardFunction", ReshardFunction)
      .def(py::init<>());

  py::class_<phi::distributed::RToSReshardFunctionCrossMesh>(
      *m, "RToSReshardFunctionCrossMesh", ReshardFunction)
      .def(py::init<>());

  py::class_<phi::distributed::SToRReshardFunction>(
      *m, "SToRReshardFunction", ReshardFunction)
      .def(py::init<>());
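The extra py::class_ registration above is what exposes the new function to Python; a minimal sketch of how it is reached from there (the is_suitable/eval spelling follows the test added below):

# Sketch: the newly bound cross-mesh r-to-s reshard function, as used from Python.
from paddle.base import core

reshard_func = core.RToSReshardFunctionCrossMesh()
# reshard_func.is_suitable(dist_tensor, out_dist_attr) -> bool
# reshard_func.eval(dev_ctx, dist_tensor, out_dist_attr) -> DistTensor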
@@ -18,6 +18,7 @@
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h"
#include "paddle/phi/kernels/split_kernel.h"

namespace phi {
@@ -85,7 +86,57 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
  SetDistProps(out, in.dims(), out_dist_attr);
}

bool RToSReshardFunctionCrossMesh::IsSuitable(
    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
  const auto& in_dist_attr = in.dist_attr();

  RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_replicated());
  RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard());

  const auto& in_process_mesh = in_dist_attr.process_mesh();
  const auto& out_process_mesh = out_dist_attr.process_mesh();

  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() ==
                            out_process_mesh.shape());
  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh);

  return true;
}
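IsSuitable accepts a pair only when the input is replicated, the output is sharded, both process meshes are 1-D with the same shape, and the two meshes differ. A Python-level illustration of a pair that passes and a same-mesh pair that does not (meant to be read alongside the C++ above; actually running it needs the same two-card launch as the test):

# Illustration of the IsSuitable conditions, assuming the bindings added in this commit.
import paddle
import paddle.distributed as dist
from paddle.base import core

in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])

value = paddle.uniform([10, 20], "float32")
replicated = dist.shard_tensor(
    value, dist_attr=dist.DistAttr(mesh=in_mesh, sharding_specs=[None, None])
)

cross_mesh = core.RToSReshardFunctionCrossMesh()
# Different 1-D meshes of the same shape: handled by the cross-mesh function.
assert cross_mesh.is_suitable(
    replicated, dist.DistAttr(mesh=out_mesh, sharding_specs=["x", None])
)
# Same mesh on both sides: not this function's job (plain RToSReshardFunction covers it).
assert not cross_mesh.is_suitable(
    replicated, dist.DistAttr(mesh=in_mesh, sharding_specs=["x", None])
)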

void RToSReshardFunctionCrossMesh::Eval(phi::DeviceContext* dev_ctx,
                                        const DistTensor& in,
                                        const TensorDistAttr& out_dist_attr,
                                        DistTensor* out) {
  VLOG(3) << "Call RToSReshardFunctionCrossMesh Eval";
  const auto& in_dist_attr = in.dist_attr();

  DistTensor tmp_result;
  TensorDistAttr in_dist_attr_shard = in_dist_attr;
  in_dist_attr_shard.set_dims_mapping(out_dist_attr.dims_mapping());
  RToSReshardFunction r_to_s_func;
  PADDLE_ENFORCE(
      r_to_s_func.IsSuitable(in, in_dist_attr_shard),
      phi::errors::InvalidArgument(
          "Invoke the r to s reshard function is not valid from %s to %s.",
          tmp_result.dist_attr(),
          out_dist_attr));
  r_to_s_func.Eval(dev_ctx, in, in_dist_attr_shard, &tmp_result);

  // Step 2: Same status from the input mesh to output mesh
  SameStatusReshardFunction same_status_func;
  PADDLE_ENFORCE(
      same_status_func.IsSuitable(tmp_result, out_dist_attr),
      phi::errors::InvalidArgument("Invoke the same status reshard function "
                                   "is not valid from %s to %s.",
                                   tmp_result.dist_attr(),
                                   out_dist_attr));
  same_status_func.Eval(dev_ctx, tmp_result, out_dist_attr, out);
}
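Eval composes two existing reshard functions: an ordinary replicated-to-shard reshard on the source mesh (reusing the target dims mapping), followed by a same-status reshard that send/recvs each local shard to its peer rank on the output mesh. A rough Python-level mirror of that composition (the SameStatusReshardFunction binding and attribute spellings such as in_tensor.dist_attr.process_mesh are assumptions here, not part of this diff):

# Rough sketch of the two-step decomposition performed by the C++ Eval above.
import paddle.distributed as dist
from paddle.base import core

def r_to_s_cross_mesh(dev_ctx, in_tensor, out_mesh, out_specs):
    # Step 1: replicated -> shard, staying on the source mesh but using the
    # target sharding specs (the C++ code copies the target dims_mapping).
    mid_attr = dist.DistAttr(
        mesh=in_tensor.dist_attr.process_mesh, sharding_specs=out_specs
    )
    r_to_s = core.RToSReshardFunction()
    assert r_to_s.is_suitable(in_tensor, mid_attr)
    tmp = r_to_s.eval(dev_ctx, in_tensor, mid_attr)

    # Step 2: keep the sharding, move each local shard to the output mesh
    # (same-status reshard across meshes of identical shape).
    out_attr = dist.DistAttr(mesh=out_mesh, sharding_specs=out_specs)
    same_status = core.SameStatusReshardFunction()
    assert same_status.is_suitable(tmp, out_attr)
    return same_status.eval(dev_ctx, tmp, out_attr)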

REGISTER_RESHARD_FUNC(RToSReshardFunction);
REGISTER_RESHARD_FUNC(RToSReshardFunctionCrossMesh);

} // namespace distributed
} // namespace phi
@@ -21,9 +21,17 @@ namespace distributed {

class RToSReshardFunction final : public ReshardFunction {
 public:
  RToSReshardFunction() = default;
  ~RToSReshardFunction() = default;
  bool IsSuitable(const DistTensor& in,
                  const TensorDistAttr& out_dist_attr) override;

  void Eval(DeviceContext* dev_ctx,
            const DistTensor& in,
            const TensorDistAttr& out_dist_attr,
            DistTensor* out) override;
};

class RToSReshardFunctionCrossMesh final : public ReshardFunction {
 public:
  bool IsSuitable(const DistTensor& in,
                  const TensorDistAttr& out_dist_attr) override;

93 changes: 93 additions & 0 deletions test/auto_parallel/reshard_r_to_s_cross_mesh.py
@@ -0,0 +1,93 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.base import core


class TestReshardRToSCrossMesh:
    def __init__(self):
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        self._shard = eval(os.getenv("shard"))
        self._backend = os.getenv("backend")
        self._in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        self._out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])

    def run_test_case(self):
        # cpu does not support send/recv
        if self._backend == "cpu":
            return
        elif self._backend == "gpu":
            place = paddle.CUDAPlace(dist.get_rank())

        dev_ctx = core.DeviceContext.create(place)

        paddle.seed(self._seeds)
        value = paddle.uniform(self._shape, self._dtype)

        in_shard_specs = [None for i in range(len(self._shape))]
        out_shard_specs = [None for i in range(len(self._shape))]
        out_shard_specs[self._shard] = "x"

        dist_attr = dist.DistAttr(
            mesh=self._in_mesh, sharding_specs=in_shard_specs
        )
        out_dist_attr = dist.DistAttr(
            mesh=self._out_mesh, sharding_specs=out_shard_specs
        )

        input_tensor = dist.shard_tensor(value, dist_attr=dist_attr)

        reshard_func = core.RToSReshardFunctionCrossMesh()
        assert reshard_func.is_suitable(input_tensor, out_dist_attr)

        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
        out_shape = list(self._shape)

        if out_shape[self._shard] % 2 == 0:
            out_shape[self._shard] = out_shape[self._shard] // 2
            split_shape = self._in_mesh.shape[0]
        else:
            split_shape = [
                out_shape[self._shard] // 2 + 1,
                out_shape[self._shard] // 2,
            ]
            out_shape[self._shard] = (
                split_shape[0] if dist.get_rank() == 1 else split_shape[1]
            )

        out_expected_local_tensor_list = paddle.split(
            value, num_or_sections=split_shape, axis=self._shard
        )

        np.testing.assert_equal(
            out._local_value().numpy(),
            out_expected_local_tensor_list[0].numpy()
            if dist.get_rank() == 1
            else out_expected_local_tensor_list[1].numpy(),
        )

        assert np.equal(out.shape, input_tensor.shape).all()
        assert np.equal(out._local_shape, out_shape).all()


if __name__ == '__main__':
    TestReshardRToSCrossMesh().run_test_case()
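The shape bookkeeping in the uneven-split branch is easiest to see with concrete numbers; a standalone NumPy check (no distributed launch needed), assuming the (5, 7) shape from the test matrix and shard dim 0:

# Worked example of the uneven-split expectation in the test above.
import numpy as np

rows = 5                                # sharded dim, split across a 2-card mesh
sections = [rows // 2 + 1, rows // 2]   # -> [3, 2]
value = np.arange(rows * 7).reshape(rows, 7)
parts = np.split(value, [sections[0]], axis=0)

assert parts[0].shape == (3, 7)  # out mesh is [1, 0], so rank 1 holds the first piece
assert parts[1].shape == (2, 7)  # rank 0 holds the second piece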
13 changes: 12 additions & 1 deletion test/auto_parallel/test_reshard_r_to_s.py
@@ -22,7 +22,7 @@ def setUp(self):
        super().setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "dtype": "float32",
-            "seeds": str(self._seeds),
+            "seeds": "2023",
        }
        self._changeable_envs = {
            "shape": ["(10, 20)", "(5, 7)"],
@@ -40,6 +40,17 @@ def test_reshard_r_to_s(self):
                user_defined_envs=envs,
            )

    def test_reshard_r_to_s_cross_mesh(self):
        envs_list = test_base.gen_product_envs_list(
            self._default_envs, self._changeable_envs
        )
        for envs in envs_list:
            if envs["backend"] != "cpu":
                self.run_test_case(
                    "reshard_r_to_s_cross_mesh.py",
                    user_defined_envs=envs,
                )


if __name__ == "__main__":
    unittest.main()
