-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support shard to partial reshard function.
- Loading branch information
Showing
8 changed files
with
232 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
69 changes: 69 additions & 0 deletions
69
paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h" | ||
#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h" | ||
#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h" | ||
|
||
#include "glog/logging.h" | ||
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" | ||
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" | ||
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" | ||
#include "paddle/phi/kernels/reduce_scatter_kernel.h" | ||
#include "paddle/phi/kernels/transpose_kernel.h" | ||
|
||
namespace phi { | ||
namespace distributed { | ||
|
||
bool SToPReshardFunction::IsSuitable(const DistTensor& in, | ||
const TensorDistAttr& out_dist_attr) { | ||
const auto& in_dist_attr = in.dist_attr(); | ||
|
||
RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard()); | ||
RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_partial()); | ||
|
||
const auto& in_process_mesh = in_dist_attr.process_mesh(); | ||
const auto& out_process_mesh = out_dist_attr.process_mesh(); | ||
|
||
RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); | ||
RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); | ||
RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); | ||
|
||
return true; | ||
} | ||
|
||
void SToPReshardFunction::Eval(DeviceContext* dev_ctx, | ||
const DistTensor& in, | ||
const TensorDistAttr& out_dist_attr, | ||
DistTensor* out) { | ||
VLOG(3) << "Call SToPReshardFunction Eval"; | ||
|
||
// step 1, create tmp dist attr and tmp dist tensor | ||
TensorDistAttr tmp_attr(out_dist_attr); | ||
DistTensor tmp_tensor; | ||
tmp_attr.clean_partial_status(); | ||
|
||
// step 2, do s to r reshard on `in` to `tmp` | ||
SToRReshardFunction s_to_r; | ||
s_to_r.Eval(dev_ctx, in, tmp_attr, &tmp_tensor); | ||
|
||
// step 3, do r to p reshard on `tmp` to `out` | ||
RToPReshardFunction r_to_p; | ||
r_to_p.Eval(dev_ctx, tmp_tensor, out_dist_attr, out); | ||
} | ||
|
||
REGISTER_RESHARD_FUNC(SToPReshardFunction); | ||
|
||
} // namespace distributed | ||
} // namespace phi |
35 changes: 35 additions & 0 deletions
35
paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h" | ||
|
||
namespace phi { | ||
namespace distributed { | ||
|
||
class SToPReshardFunction final : public ReshardFunction { | ||
public: | ||
bool IsSuitable(const DistTensor& in, | ||
const TensorDistAttr& out_dist_attr) override; | ||
|
||
void Eval(DeviceContext* dev_ctx, | ||
const DistTensor& in, | ||
const TensorDistAttr& out_dist_attr, | ||
DistTensor* out) override; | ||
|
||
std::string Name() override { return "SToPReshard"; } | ||
}; | ||
|
||
} // namespace distributed | ||
} // namespace phi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
import numpy as np | ||
|
||
import paddle | ||
import paddle.distributed as dist | ||
from paddle.base import core | ||
|
||
|
||
class TestReshardSToR: | ||
def __init__(self): | ||
self._shape = eval(os.getenv("shape")) | ||
self._dtype = os.getenv("dtype") | ||
self._seeds = eval(os.getenv("seeds")) | ||
self._backend = os.getenv("backend") | ||
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) | ||
|
||
def run_test_case(self): | ||
if self._backend == "cpu": | ||
paddle.set_device("cpu") | ||
place = paddle.CPUPlace() | ||
elif self._backend == "gpu": | ||
place = paddle.CUDAPlace(dist.get_rank()) | ||
|
||
dev_ctx = core.DeviceContext.create(place) | ||
a = paddle.ones(self._shape) | ||
|
||
in_shard_specs = [None for i in range(len(self._shape))] | ||
in_shard_specs[0] = "x" | ||
out_shard_specs = [None for i in range(len(self._shape))] | ||
|
||
dist_attr = dist.DistAttr( | ||
mesh=self._mesh, sharding_specs=in_shard_specs | ||
) | ||
out_dist_attr = dist.DistAttr( | ||
mesh=self._mesh, sharding_specs=out_shard_specs | ||
) | ||
out_dist_attr._set_partial_dims([0]) | ||
|
||
input_tensor = dist.shard_tensor(a, dist_attr=dist_attr) | ||
|
||
assert input_tensor._local_shape[0] == self._shape[0] // 2 | ||
|
||
reshard_func = core.SToPReshardFunction() | ||
assert reshard_func.is_suitable(input_tensor, out_dist_attr) | ||
|
||
out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr) | ||
|
||
if dist.get_rank() == 0: | ||
np.testing.assert_equal(out._local_value().numpy(), a.numpy()) | ||
else: | ||
zeros = paddle.zeros(self._shape) | ||
np.testing.assert_equal(out._local_value().numpy(), zeros.numpy()) | ||
|
||
assert np.equal(out.shape, input_tensor.shape).all() | ||
assert np.equal(out._local_shape, input_tensor.shape).all() | ||
|
||
|
||
if __name__ == '__main__': | ||
TestReshardSToR().run_test_case() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
|
||
import collective.test_communication_api_base as test_base | ||
|
||
|
||
class TestReshardSToP(test_base.CommunicationTestDistBase): | ||
def setUp(self): | ||
super().setUp(num_of_devices=2, timeout=120) | ||
self._default_envs = { | ||
"shape": "(10, 20)", | ||
"dtype": "float32", | ||
"seeds": "1234", | ||
} | ||
self._changeable_envs = { | ||
"backend": ["cpu", "gpu"], | ||
} | ||
|
||
def test_reshard_s_to_r(self): | ||
envs_list = test_base.gen_product_envs_list( | ||
self._default_envs, self._changeable_envs | ||
) | ||
for envs in envs_list: | ||
self.run_test_case( | ||
"reshard_s_to_p.py", | ||
user_defined_envs=envs, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |