From 1350d6407227df3e22db50d4f9111c6e7cf061b7 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Thu, 26 Oct 2023 14:10:25 +0800 Subject: [PATCH] [AutoParallel] Support disttensor for Tensor.copy_ (#58369) * support disttensor for tensor.copy_ --- paddle/phi/api/lib/tensor_method.cc | 66 +++++++++++++++++-- .../semi_auto_parallel_recompute.py | 57 ++++++++++++++++ .../semi_auto_parallel_simple_net.py | 35 ++++++++++ ...test_semi_auto_parallel_single_strategy.py | 10 +++ 4 files changed, 164 insertions(+), 4 deletions(-) create mode 100644 test/auto_parallel/semi_auto_parallel_recompute.py diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 74ee1e380dcc4a..deebdbe0019ee7 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -27,7 +27,11 @@ limitations under the License. */ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" // clang-format off - +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/api/lib/data_transform.h" +#endif namespace paddle { namespace experimental { // declare cast api @@ -87,9 +91,7 @@ void Tensor::copy_(const Tensor &src, VLOG(8) << "Src is empty, skip copy"; return; } - // Prepare copy kernel key and outputs - auto kernel_key_set = ParseKernelKeyByInputArgs(src); - KernelType kernel_type = ParseKernelTypeByInputArgs(src); + VLOG(3) << "Deep copy Tensor from " << src.name() << " to " << name(); if (initialized()) { PADDLE_ENFORCE_EQ(dtype(), @@ -114,6 +116,12 @@ void Tensor::copy_(const Tensor &src, "Copy cannot be performed!", target_place, place())); + } + + // Prepare copy kernel key and outputs + auto kernel_key_set = ParseKernelKeyByInputArgs(src); + KernelType kernel_type = ParseKernelTypeByInputArgs(src); + if (initialized()) { kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place())); } else { @@ -129,6 +137,56 @@ void Tensor::copy_(const Tensor &src, place.GetType() == target_place.GetType() ? target_place : place); if (kernel_type == KernelType::DENSE_TENSOR_KENREL) { +#ifdef PADDLE_WITH_DISTRIBUTE + bool run_auto_parallel = AllInputsAreDistTensor(src); + bool rank_is_in_current_mesh = false; + if (run_auto_parallel) { + auto mesh = std::static_pointer_cast( + src.impl())->dist_attr().process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + // 1. InferSpmd (Infer DistAttr of Inputs&Outputs) + auto meta_dist_input_x = MakeDistMetaTensor(*src.impl()); + + // 2. Create API Output & Prepare Dist and Dense Output + auto dist_out = SetKernelDistOutput(this, meta_dist_input_x.dist_attr()); + auto dense_out = dist_out->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) { + *dense_out = phi::DenseTensor( + std::make_shared(nullptr, + 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + } + + // 3. Infer DistTensor's Global Shape + phi::MetaTensor meta_dist_out(dist_out); + phi::UnchangedInferMeta(MakeMetaTensor(*(src.impl_)), &meta_dist_out); + + if (rank_is_in_current_mesh) { + // 4. Select Kernel + + // 5. Reshard Input + auto dist_input_x = static_cast( + src.impl().get());; + + // 6. PrepareData (DataTransform & Prepare Dense Input) + auto input_x = &dist_input_x->value(); + + // 7. Infer Local DenseTensor Meta + phi::MetaTensor meta_dense_out(dense_out); + phi::UnchangedInferMeta(MakeMetaTensor(*input_x), &meta_dense_out); + + // 8. DenseTensor Kernel Call + phi::Copy(*dev_ctx, *input_x, target_place, blocking, dense_out); + + // 9. Reshard Partial Output to Replicated (Temporary) + } + + // 10. Set Output Dist Attr For Default Impl + // API `copy_` does not need to set DistAttr for output. + return; + } +#endif SetKernelOutput(this); phi::MetaTensor meta_out(impl_.get()); phi::UnchangedInferMeta( diff --git a/test/auto_parallel/semi_auto_parallel_recompute.py b/test/auto_parallel/semi_auto_parallel_recompute.py new file mode 100644 index 00000000000000..7329a1f4d0bafb --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_recompute.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +from semi_auto_parallel_simple_net import MPDemoNetRecompute + +import paddle +import paddle.distributed as dist +from paddle import nn + +BATCH_SIZE = 16 +BATCH_NUM = 4 +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +def run_dynamic(layer, image, label): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(image) + image.stop_gradient = False + out = layer(image) + + label = paddle.to_tensor(label) + loss = loss_fn(out, label) + + loss.backward() + return loss, layer.w0.grad, layer.w1.grad + + +class TestSemiAutoParallelRecompute: + def test_recompute(): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype('float32') + label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32') + w0 = np.random.random([IMAGE_SIZE, IMAGE_SIZE]).astype('float32') + w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32') + run_dynamic( + layer=MPDemoNetRecompute(w0, w1, mesh), image=image, label=label + ) + + +if __name__ == "__main__": + TestSemiAutoParallelRecompute.test_recompute() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net.py b/test/auto_parallel/semi_auto_parallel_simple_net.py index df5fbbaa6fae9f..3e37085d3886d8 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net.py @@ -20,6 +20,7 @@ import paddle.distributed as dist import paddle.nn.functional as F from paddle import nn +from paddle.distributed.fleet.utils import recompute BATCH_SIZE = 16 BATCH_NUM = 4 @@ -120,6 +121,40 @@ def forward(self, x): return out +class MPDemoNetRecompute(nn.Layer): + def __init__(self, np_w0, np_w1, mesh, param_suffix=""): + super().__init__() + self.w0 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, IMAGE_SIZE], + attr=paddle.framework.ParamAttr( + name="mp_demo_weight_1" + param_suffix, + initializer=paddle.nn.initializer.Assign(np_w0), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']), + ) + self.w1 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, CLASS_NUM], + attr=paddle.framework.ParamAttr( + name="mp_nemo_weight_2" + param_suffix, + initializer=paddle.nn.initializer.Assign(np_w1), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['x', None]), + ) + + def _inner_forward_fn(self, x): + y = paddle.matmul(x, self.w0) + z = paddle.matmul(y, self.w1) + return z + + def forward(self, x): + z = recompute(self._inner_forward_fn, x) + return z + + class PPDemoNet(nn.Layer): def __init__(self, np_w0, np_w1, mesh0, mesh1): super().__init__() diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py index b74fa38a500bf5..453d2f27963c7e 100644 --- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py +++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py @@ -61,6 +61,16 @@ def test_simple_net_single_strategy_with_gradient_merge(self): user_defined_envs=envs, ) + def test_simple_net_recompute(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_recompute.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main()