diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index f5bcd6d17b6dd9..b1ec85afe1a866 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -33,6 +33,7 @@ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" @@ -173,6 +174,10 @@ void BindAutoParallel(py::module *m) { *m, "SToSReshardFunction", ReshardFunction) .def(py::init<>()); + py::class_( + *m, "SameNdMeshReshardFunction", ReshardFunction) + .def(py::init<>()); + py::class_(*m, "ProcessMesh") .def(py::init<>()) .def(py::init &, diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt index 2d8dff6adb245f..038888196f09f1 100644 --- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt @@ -16,4 +16,5 @@ collect_srcs( s_to_r_reshard_function.cc r_to_p_reshard_function.cc p_to_r_reshard_function.cc - s_to_s_reshard_function.cc) + s_to_s_reshard_function.cc + nd_mesh_reshard_function.cc) diff --git a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc new file mode 100644 index 00000000000000..c28fbe3acbd7de --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc @@ -0,0 +1,269 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "glog/logging.h"
+#include "paddle/phi/common/int_array.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
+
+namespace phi {
+namespace distributed {
+
+namespace {
+// Builds the 1-D sub-mesh that contains the current rank and spans mesh
+// dimension `axis` of `mesh`.
+//
+// Every coordinate of the current rank is kept except the one on `axis`,
+// which is swept over [0, mesh.dim_size(axis)). Each resulting coordinate is
+// converted to a row-major linear index; the sorted indices become the
+// process ids of the returned mesh, whose single dimension keeps the name of
+// `axis`.
+//
+// NOTE(review): the linear index is used directly as the process id, which
+// presumably relies on the mesh's process ids being 0..N-1 in row-major
+// order -- confirm against ProcessMesh before using other id layouts.
+ProcessMesh GetSubProcessMesh(const ProcessMesh& mesh, int64_t axis) {
+  const int64_t shape_of_axis = mesh.dim_size(axis);
+  std::vector<int64_t> shape = {shape_of_axis};
+  std::vector<std::string> dim_names = {mesh.dim_names()[axis]};
+  std::vector<int64_t> coord = GetCurRankCoordInMesh(mesh);
+  const int64_t ndim = static_cast<int64_t>(coord.size());
+
+  std::vector<int64_t> process_ids;
+  process_ids.reserve(static_cast<size_t>(shape_of_axis));
+  for (int64_t i = 0; i < shape_of_axis; ++i) {
+    coord[axis] = i;
+    // Row-major linearization: the stride of dimension j is the product of
+    // the sizes of every dimension after j. Multiplying by dim_size(j + 1)
+    // alone is only equivalent for 2-D meshes and yields wrong ranks for
+    // meshes with three or more dimensions.
+    int64_t rank = 0;
+    int64_t stride = 1;
+    for (int64_t j = ndim - 1; j >= 0; --j) {
+      rank += coord[j] * stride;
+      stride *= mesh.dim_size(j);
+    }
+    process_ids.emplace_back(rank);
+  }
+
+  std::sort(process_ids.begin(), process_ids.end());
+  ProcessMesh out_mesh(shape, process_ids, dim_names);
+  return out_mesh;
+}
+
+// Given the input two dist_attr,
traversing from high-dimension axis to
+// low-dimension. Find and return the first tensor axis whose dims_mapping
+// entry differs between these two. For example, the input two dims_mapping are
+// [-1, 0, -1, -1] and [-1, -1, 0, -1], the first diff shard axis is 2.
+// Returns -1 when the two dims_mapping are identical. Partial status is keyed
+// by mesh axis rather than tensor axis, so it is not compared here; Eval
+// walks the partial status maps itself.
+int64_t FindFirstDiffShardAxis(const TensorDistAttr& in_dist_attr,
+                               const TensorDistAttr& out_dist_attr) {
+  const auto& in_dims_mapping = in_dist_attr.dims_mapping();
+  const auto& out_dims_mapping = out_dist_attr.dims_mapping();
+
+  for (int64_t i = static_cast<int64_t>(in_dims_mapping.size()) - 1; i >= 0;
+       --i) {
+    if (in_dims_mapping[i] != out_dims_mapping[i]) {
+      return i;
+    }
+  }
+
+  return -1;
+}
+
+}  // namespace
+
+// Returns true when `in` and `out_dist_attr` live on the same process mesh,
+// that mesh has more than one dimension, and the two dist_attr differ.
+bool SameNdMeshReshardFunction::IsSuitable(
+    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
+  bool flag = true;
+
+  flag &= (in.dist_attr().process_mesh() == out_dist_attr.process_mesh());
+  flag &= (out_dist_attr.process_mesh().ndim() > 1);
+  // check the input and output dist_attr is not equal
+  flag &= in.dist_attr() != out_dist_attr;
+
+  return flag;
+}
+
+// Reshards `in` to `out_dist_attr` on the same N-D mesh by chaining 1-D
+// reshard functions, one mesh axis at a time:
+//   1. partial    -> replicated for partial axes absent from the target
+//   2. shard      -> replicated for tensor axes [0, first_diff_axis]
+//   3. replicated -> partial    for partial axes required by the target
+//   4. replicated -> shard      for tensor axes [0, first_diff_axis]
+// Before each 1-D step `out` is temporarily relabelled with a dist_attr on the
+// 1-D sub-mesh of the axis being processed, and afterwards it is reset to the
+// N-D dist_attr that describes the state reached so far.
+void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,
+                                     const DistTensor& in,
+                                     const TensorDistAttr& out_dist_attr,
+                                     DistTensor* out) {
+  const auto& in_dist_attr = in.dist_attr();
+  const auto& process_mesh = out_dist_attr.process_mesh();
+
+  int64_t first_diff_axis = FindFirstDiffShardAxis(in_dist_attr, out_dist_attr);
+
+  SetValue(out, in.value());
+  SetDistProps(out, in.dims(), in_dist_attr);
+
+  // 1. change all the partial status to replicated status if needed
+  if (in_dist_attr.is_partial()) {
+    const auto& in_partial_status = in_dist_attr.partial_status();
+    const auto& out_partial_status = out_dist_attr.partial_status();
+    for (const auto& kv : in_partial_status) {
+      // A mesh axis that stays partial in the target needs no conversion.
+      if (out_partial_status.count(kv.first) != 0) {
+        continue;
+      }
+      VLOG(3) << "Step1: partial axis " << kv.first;
+      // 1.1 Calculate the dist_attr after this transform
+      TensorDistAttr real_out_dist_attr(out->dist_attr());
+      real_out_dist_attr.clean_partial_dims({kv.first});
+
+      // 1.2 Calculate the process_mesh on specific axis
+      ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, kv.first);
+      VLOG(3) << "Process mesh " << sub_mesh;
+
+      // 1.3 Calculate the input one dim dist attr
+      TensorDistAttr in_one_dim_dist_attr(vectorize(in.dims()));
+      in_one_dim_dist_attr.set_process_mesh(sub_mesh);
+      in_one_dim_dist_attr.set_partial_status(std::vector<int64_t>{0});
+
+      // 1.4 Calculate the output one dim dist attr
+      TensorDistAttr out_one_dim_dist_attr(vectorize(in.dims()));
+      out_one_dim_dist_attr.set_process_mesh(sub_mesh);
+
+      // 1.5 Change from partial to replicated
+      SetDistProps(out, in_one_dim_dist_attr);
+
+      DistTensor tmp_result;
+      PToRReshardFunction func;
+      func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
+
+      // 1.6 Reset to the right dist attr
+      SetValue(out, tmp_result.value());
+      SetDistProps(out, real_out_dist_attr);
+    }
+  }
+
+  // 2. change all the shard status to replicated status
+  for (int64_t i = first_diff_axis; i >= 0; --i) {
+    int64_t in_mesh_axis = out->dist_attr().dims_mapping()[i];
+    if (in_mesh_axis != -1) {
+      VLOG(3) << "Step2: in_mesh axis " << in_mesh_axis;
+      // 2.1 Calculate the dist_attr after this transform
+      TensorDistAttr real_out_dist_attr(out->dist_attr());
+      std::vector<int64_t> real_dims_mapping =
+          real_out_dist_attr.dims_mapping();
+      real_dims_mapping[i] = -1;
+      real_out_dist_attr.set_dims_mapping(real_dims_mapping);
+
+      // 2.2 Calculate the process_mesh on specific axis
+      ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, in_mesh_axis);
+
+      // 2.3 Calculate the input one dim dist attr
+      TensorDistAttr in_one_dim_dist_attr(vectorize(in.dims()));
+      in_one_dim_dist_attr.set_process_mesh(sub_mesh);
+      std::vector<int64_t> in_one_dims_mapping =
+          in_one_dim_dist_attr.dims_mapping();
+      in_one_dims_mapping[i] = 0;
+      in_one_dim_dist_attr.set_dims_mapping(in_one_dims_mapping);
+
+      // 2.4 Calculate the output one dim dist attr
+      TensorDistAttr out_one_dim_dist_attr(vectorize(in.dims()));
+      out_one_dim_dist_attr.set_process_mesh(sub_mesh);
+
+      // 2.5 Change from shard to replicated
+      SetDistProps(out, in_one_dim_dist_attr);
+      DistTensor tmp_result;
+      SToRReshardFunction func;
+      func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
+
+      // 2.6 Reset to the right dist attr
+      SetValue(out, tmp_result.value());
+      SetDistProps(out, real_out_dist_attr);
+    }
+  }
+
+  // 3. Change replicated to partial
+  if (out_dist_attr.is_partial()) {
+    // Bound to out's own dist_attr, so it reflects the state reached so far
+    // (including partial axes added by earlier iterations of this loop).
+    const auto& in_partial_status = out->dist_attr().partial_status();
+    const auto& out_partial_status = out_dist_attr.partial_status();
+    for (const auto& kv : out_partial_status) {
+      if (in_partial_status.count(kv.first) != 0) {
+        continue;
+      }
+      VLOG(3) << "Step3: Partial status mesh axis " << kv.first;
+      // 3.1 Calculate the dist_attr after this transform
+      TensorDistAttr real_out_dist_attr(out->dist_attr());
+      real_out_dist_attr.set_partial_status(std::vector<int64_t>{kv.first});
+
+      // 3.2 Calculate the process_mesh on specific axis
+      ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, kv.first);
+
+      // 3.3 Calculate the input one dim dist attr
+      TensorDistAttr in_one_dim_dist_attr(vectorize(in.dims()));
+      in_one_dim_dist_attr.set_process_mesh(sub_mesh);
+
+      // 3.4 Calculate the output one dim dist attr
+      TensorDistAttr out_one_dim_dist_attr(vectorize(in.dims()));
+      out_one_dim_dist_attr.set_process_mesh(sub_mesh);
+      out_one_dim_dist_attr.set_partial_status(std::vector<int64_t>{0});
+
+      // 3.5 Change from replicated to partial
+      DistTensor tmp_result;
+      SetDistProps(out, in_one_dim_dist_attr);
+      RToPReshardFunction func;
+      func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
+
+      // 3.6 Reset to the right dist attr
+      SetValue(out, tmp_result.value());
+      SetDistProps(out, real_out_dist_attr);
+    }
+  }
+
+  // 4. Change replicated to shard
+  for (int64_t i = first_diff_axis; i >= 0; --i) {
+    int64_t out_mesh_axis = out_dist_attr.dims_mapping()[i];
+    if (out_mesh_axis != -1) {
+      // 4.1 Calculate the dist_attr after this transform
+      TensorDistAttr real_out_dist_attr(out->dist_attr());
+      std::vector<int64_t> real_dims_mapping =
+          real_out_dist_attr.dims_mapping();
+      real_dims_mapping[i] = out_mesh_axis;
+      real_out_dist_attr.set_dims_mapping(real_dims_mapping);
+
+      // 4.2 Calculate the process_mesh on specific axis
+      ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, out_mesh_axis);
+
+      // 4.3 Calculate the input one dim dist attr
+      TensorDistAttr in_one_dim_dist_attr(vectorize(in.dims()));
+      in_one_dim_dist_attr.set_process_mesh(sub_mesh);
+
+      // 4.4 Calculate the output one dim dist attr
+      TensorDistAttr out_one_dim_dist_attr(vectorize(in.dims()));
+      out_one_dim_dist_attr.set_process_mesh(sub_mesh);
+      std::vector<int64_t> out_one_dims_mapping =
+          out_one_dim_dist_attr.dims_mapping();
+      out_one_dims_mapping[i] = 0;
+      out_one_dim_dist_attr.set_dims_mapping(out_one_dims_mapping);
+
+      // 4.5 Change from replicated to shard
+      DistTensor tmp_result;
+      SetDistProps(out, in_one_dim_dist_attr);
+      RToSReshardFunction func;
+      func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
+
+      // 4.6 Reset to the right dist attr
+      SetValue(out, tmp_result.value());
+      SetDistProps(out, real_out_dist_attr);
+    }
+  }
+}
+
+REGISTER_RESHARD_FUNC(SameNdMeshReshardFunction);
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h
new file mode 100644
index 00000000000000..e47cb46138f7bf
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+
+namespace phi {
+namespace distributed {
+
+// Reshard function for a source tensor and a target dist_attr that share the
+// same multi-dimensional process mesh.
+class SameNdMeshReshardFunction final : public ReshardFunction {
+ public:
+  // True when both sides use the same process mesh, the mesh has more than
+  // one dimension, and the source dist_attr differs from `out_dist_attr`.
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+
+  // Writes into `out` the value of `in` resharded to `out_dist_attr`, built
+  // from a sequence of 1-D reshards over sub-meshes of the shared mesh.
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+};
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc
index d8681218ae68a4..bd2cb4c58a46c1 100644
--- a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc
@@ -54,20 +54,8 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx,
     RESHARD_FUNCTOR(dev_ctx, Full, in.dtype(), shape, 0, GetMutableTensor(out));
   } else {
     // assign the input value to output
-    if (phi::CPUContext::classof(dev_ctx)) {
-      Assign(static_cast(*dev_ctx),
-             in.value(),
-             GetMutableTensor(out));
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    } else if (phi::GPUContext::classof(dev_ctx)) {
-      Assign(static_cast(*dev_ctx),
-             in.value(),
-             GetMutableTensor(out));
-#endif
-    } else {
-      PADDLE_THROW(phi::errors::Unimplemented(
-          "The assign in reshard only supported on CPU and GPU for now."));
-    }
+    RESHARD_FUNCTOR_WITHOUT_DTYPE(
+        dev_ctx, Assign, in.value(), GetMutableTensor(out));
   }
   SetDistProps(out, in.dims(), out_dist_attr);
 }
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc
index 63044549a2e370..b8e355e689caea 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc
@@ -45,6 +45,16 @@ void ReshardFunction::SetDistProps(DistTensor* tensor,
   tensor->dist_attr_ = dist_attr;
 }
 
+// Overload that replaces only the dist_attr of `tensor` and leaves its dims
+// as they are. `dist_attr` must verify against the tensor's current dims,
+// otherwise an InvalidArgument error is raised.
+void ReshardFunction::SetDistProps(DistTensor* tensor,
+                                   const TensorDistAttr& dist_attr) {
+  PADDLE_ENFORCE_EQ(dist_attr.verify(vectorize(tensor->dims())),
+                    true,
+                    phi::errors::InvalidArgument(
+                        "The input dist_attr and dims are improper."));
+
+  tensor->dist_attr_ = dist_attr;
+}
+
 DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) {
   return &tensor->value_;
 }
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard_function.h
index 48d9fe64eabcc8..dd51768053bbf6 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_function.h
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.h
@@ -48,6 +48,7 @@ class ReshardFunction {
   void SetDistProps(DistTensor* tensor,
                     const DDim& dims,
                     const TensorDistAttr& dist_attr);
+  void SetDistProps(DistTensor* tensor, const TensorDistAttr& dist_attr);
   DenseTensor* GetMutableTensor(DistTensor* tensor);
 };
 
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index 8857dc530f9477..344582d2e470f9 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -100,6 +100,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_reshard_p_to_r MODULES test_reshard_p_to_r)
   set_tests_properties(test_reshard_p_to_r
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
+
py_test_modules(test_reshard_nd_mesh MODULES test_reshard_nd_mesh) + set_tests_properties(test_reshard_nd_mesh + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) py_test_modules(test_semi_auto_parallel_basic MODULES test_semi_auto_parallel_basic) set_tests_properties(test_semi_auto_parallel_basic diff --git a/test/auto_parallel/reshard_nd_mesh.py b/test/auto_parallel/reshard_nd_mesh.py new file mode 100644 index 00000000000000..dfd43ccdd644be --- /dev/null +++ b/test/auto_parallel/reshard_nd_mesh.py @@ -0,0 +1,222 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import ast
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.framework import core
+
+
+class TestReshardNdMesh:
+    """Per-rank driver for the SameNdMeshReshardFunction test cases.
+
+    The tensor shape, dtype, random seed and backend are read from the
+    environment variables "shape", "dtype", "seeds" and "backend". Every case
+    runs on the fixed 2 x 4 process mesh with dimension names "x" and "y".
+    """
+
+    def __init__(self):
+        # The values are plain Python literals such as "(12, 20, 8, 16)";
+        # literal_eval parses them without evaluating arbitrary expressions.
+        self._shape = ast.literal_eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = ast.literal_eval(os.getenv("seeds"))
+        self._backend = os.getenv("backend")
+        self._mesh = dist.ProcessMesh(
+            [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
+        )
+
+    def _make_shard_partial_input(self):
+        """Build and validate the input shared by the shard+partial cases.
+
+        The tensor is sharded along tensor axis 0 over mesh dim "y" and is
+        partial on mesh dim 0. Its local shape and local value are asserted
+        before returning.
+
+        Returns:
+            tuple: (value, input_tensor, in_expected_local_tensor_list, index)
+            where `value` is the global tensor, the list holds its slices
+            along axis 0 and `index` selects the slice of the current rank.
+        """
+        paddle.seed(self._seeds)
+        value = paddle.uniform(self._shape, self._dtype)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        in_shard_specs[0] = "y"
+        dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=in_shard_specs
+        )
+        dist_attr._set_partial_dims([0])
+        input_tensor = dist.shard_tensor(value, dist_attr=dist_attr)
+
+        # check the shape of input tensor
+        in_expected_shape = list(self._shape)
+        in_expected_shape[0] = in_expected_shape[0] // self._mesh.shape[1]
+        assert np.equal(input_tensor._local_shape, in_expected_shape).all()
+
+        # check the value of input tensor
+        in_expected_local_tensor_list = paddle.split(
+            value, num_or_sections=self._mesh.shape[1], axis=0
+        )
+        index = dist.get_rank() % self._mesh.shape[1]
+        if dist.get_rank() // self._mesh.shape[1] == 0:
+            # Ranks at coordinate 0 of the partial mesh dim hold the data.
+            np.testing.assert_equal(
+                input_tensor._local_value().numpy(),
+                in_expected_local_tensor_list[index].numpy(),
+            )
+        else:
+            zeros = paddle.zeros(in_expected_shape)
+            np.testing.assert_equal(
+                input_tensor._local_value().numpy(), zeros.numpy()
+            )
+
+        return value, input_tensor, in_expected_local_tensor_list, index
+
+    def test_shard_partial_to_shard_replicated(self, dev_ctx):
+        """[Shard("y") on axis 0, Partial(x)] -> [Shard("y") on axis 0]."""
+        (
+            _,
+            input_tensor,
+            in_expected_local_tensor_list,
+            index,
+        ) = self._make_shard_partial_input()
+
+        out_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs[0] = "y"
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+
+        reshard_func = core.SameNdMeshReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+        np.testing.assert_equal(
+            out._local_value().numpy(),
+            in_expected_local_tensor_list[index].numpy(),
+        )
+
+    def test_shard_partial_to_replicated(self, dev_ctx):
+        """[Shard("y") on axis 0, Partial(x)] -> fully replicated."""
+        value, input_tensor, _, _ = self._make_shard_partial_input()
+
+        out_shard_specs = [None for i in range(len(self._shape))]
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+
+        reshard_func = core.SameNdMeshReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+        np.testing.assert_equal(out._local_value().numpy(), value.numpy())
+
+    def test_nd_p_to_p(self, dev_ctx):
+        """Partial on mesh dim 1 -> partial on mesh dim 0."""
+        a = paddle.ones(self._shape)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs = [None for i in range(len(self._shape))]
+
+        dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=in_shard_specs
+        )
+        dist_attr._set_partial_dims([1])
+
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+        out_dist_attr._set_partial_dims([0])
+
+        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+
+        # Coordinates of this rank in the row-major mesh built in __init__.
+        # The data is expected on the ranks at coordinate 0 of the partial
+        # mesh dim, so the checks follow the mesh shape rather than fixed
+        # rank ids.
+        x_coord = dist.get_rank() // self._mesh.shape[1]
+        y_coord = dist.get_rank() % self._mesh.shape[1]
+
+        if y_coord == 0:
+            np.testing.assert_equal(
+                input_tensor._local_value().numpy(), a.numpy()
+            )
+        else:
+            zeros = paddle.zeros(self._shape)
+            np.testing.assert_equal(
+                input_tensor._local_value().numpy(), zeros.numpy()
+            )
+
+        reshard_func = core.SameNdMeshReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        if x_coord == 0:
+            np.testing.assert_equal(out._local_value().numpy(), a.numpy())
+        else:
+            zeros = paddle.zeros(self._shape)
+            np.testing.assert_equal(out._local_value().numpy(), zeros.numpy())
+
+        assert np.equal(out.shape, input_tensor.shape).all()
+        assert np.equal(out._local_shape, input_tensor._local_shape).all()
+
+    def test_nd_s_to_s(self, dev_ctx):
+        """Shard tensor axis 1 over "y" -> shard tensor axis 0 over "x"."""
+        a = paddle.ones(self._shape)
+        in_shard_axis = 1
+        out_shard_axis = 0
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        in_shard_specs[in_shard_axis] = "y"
+
+        out_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs[out_shard_axis] = "x"
+
+        dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=in_shard_specs
+        )
+
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+
+        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+
+        # "y" is mesh dim 1, so the input axis is split mesh.shape[1] ways.
+        in_expected_shape = list(self._shape)
+        in_expected_shape[in_shard_axis] = (
+            in_expected_shape[in_shard_axis] // self._mesh.shape[1]
+        )
+        assert np.equal(input_tensor._local_shape, in_expected_shape).all()
+
+        reshard_func = core.SameNdMeshReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        # "x" is mesh dim 0, so the output axis is split mesh.shape[0] ways.
+        out_expected_shape = list(self._shape)
+        out_expected_shape[out_shard_axis] = (
+            out_expected_shape[out_shard_axis] // self._mesh.shape[0]
+        )
+        assert np.equal(out._local_shape, out_expected_shape).all()
+
+        assert np.equal(out.shape, input_tensor.shape).all()
+
+    def run_test_case(self):
+        """Create the device context for the backend and run the cases."""
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+        else:
+            raise ValueError(f"Unsupported backend: {self._backend!r}")
+
+        dev_ctx = core.DeviceContext.create(place)
+
+        # NOTE(review): the two cases below are still disabled; confirm they
+        # pass on the 2 x 4 mesh before enabling them.
+        # self.test_nd_p_to_p(dev_ctx)
+        # self.test_nd_s_to_s(dev_ctx)
+        self.test_shard_partial_to_shard_replicated(dev_ctx)
+        self.test_shard_partial_to_replicated(dev_ctx)
+
+
+if __name__ == '__main__':
+    TestReshardNdMesh().run_test_case()
diff --git a/test/auto_parallel/test_reshard_nd_mesh.py b/test/auto_parallel/test_reshard_nd_mesh.py
new file mode 100644
index 00000000000000..8df0af3a112d4e
--- /dev/null
+++ b/test/auto_parallel/test_reshard_nd_mesh.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+
+class TestReshardNdMesh(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=8, timeout=120)
+        self._default_envs = {
+            "shape": "(12, 20, 8, 16)",
+            "dtype": "float32",
+            "seeds": "100",
+        }
+        self._changeable_envs = {
+            "backend": ["gpu"],
+        }
+
+    def test_reshard_nd_mesh(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "reshard_nd_mesh.py",
+                user_defined_envs=envs,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()