[AutoParallel] Support disttensor for Tensor.copy_ (PaddlePaddle#58369)
* support disttensor for tensor.copy_
wanghuancoder authored Oct 26, 2023
1 parent e65abf5 commit 1350d64
Showing 4 changed files with 164 additions and 4 deletions.
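For context, a minimal Python-side sketch of the behavior this commit enables under dygraph semi-auto parallel. It is illustrative only: the mesh size, the shapes, and the dygraph `copy_` binding used at the end are assumptions, not part of this diff.

    # Illustrative sketch, not part of the commit (assumes a two-rank launch).
    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])

    src = dist.shard_tensor(paddle.randn([16, 784]), dist_attr=dist_attr)
    dst = dist.shard_tensor(paddle.zeros([16, 784]), dist_attr=dist_attr)

    # Before this commit the copy kernel path only handled dense tensors; the new
    # DistTensor branch in tensor_method.cc copies the local shard on ranks that
    # belong to the mesh and gives ranks outside the mesh an empty placeholder.
    dst.copy_(src, True)  # in-place, blocking copy of a dist tensor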
66 changes: 62 additions & 4 deletions paddle/phi/api/lib/tensor_method.cc
@@ -27,7 +27,11 @@ limitations under the License. */
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/infermeta/unary.h"
// clang-format off

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#endif
namespace paddle {
namespace experimental {
// declare cast api
@@ -87,9 +91,7 @@ void Tensor::copy_(const Tensor &src,
VLOG(8) << "Src is empty, skip copy";
return;
}
// Prepare copy kernel key and outputs
auto kernel_key_set = ParseKernelKeyByInputArgs(src);
KernelType kernel_type = ParseKernelTypeByInputArgs(src);

VLOG(3) << "Deep copy Tensor from " << src.name() << " to " << name();
if (initialized()) {
PADDLE_ENFORCE_EQ(dtype(),
@@ -114,6 +116,12 @@ void Tensor::copy_(const Tensor &src,
"Copy cannot be performed!",
target_place,
place()));
}

// Prepare copy kernel key and outputs
auto kernel_key_set = ParseKernelKeyByInputArgs(src);
KernelType kernel_type = ParseKernelTypeByInputArgs(src);
if (initialized()) {
kernel_key_set.backend_set = kernel_key_set.backend_set |
BackendSet(phi::TransToPhiBackend(place()));
} else {
@@ -129,6 +137,56 @@ void Tensor::copy_(const Tensor &src,
place.GetType() == target_place.GetType() ? target_place : place);

if (kernel_type == KernelType::DENSE_TENSOR_KENREL) {
#ifdef PADDLE_WITH_DISTRIBUTE
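// Auto-parallel path: taken only when `src` holds a DistTensor; otherwise this
// block is skipped and the regular dense-tensor copy below runs.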
bool run_auto_parallel = AllInputsAreDistTensor(src);
bool rank_is_in_current_mesh = false;
if (run_auto_parallel) {
auto mesh = std::static_pointer_cast<phi::distributed::DistTensor>(
src.impl())->dist_attr().process_mesh();
rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh);

// 1. InferSpmd (Infer DistAttr of Inputs&Outputs)
auto meta_dist_input_x = MakeDistMetaTensor(*src.impl());

// 2. Create API Output & Prepare Dist and Dense Output
auto dist_out = SetKernelDistOutput(this, meta_dist_input_x.dist_attr());
auto dense_out = dist_out->unsafe_mutable_value();
if (!rank_is_in_current_mesh) {
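// This rank is outside `src`'s process mesh, so it owns no local shard; give
// the output an empty placeholder DenseTensor instead of running the copy.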
*dense_out = phi::DenseTensor(
std::make_shared<phi::Allocation>(nullptr,
0, phi::distributed::GetDefaultPlace()),
phi::DenseTensorMeta());
}

// 3. Infer DistTensor's Global Shape
phi::MetaTensor meta_dist_out(dist_out);
phi::UnchangedInferMeta(MakeMetaTensor(*(src.impl_)), &meta_dist_out);

if (rank_is_in_current_mesh) {
// 4. Select Kernel

// 5. Reshard Input
auto dist_input_x = static_cast<phi::distributed::DistTensor*>(
src.impl().get());

// 6. PrepareData (DataTransform & Prepare Dense Input)
auto input_x = &dist_input_x->value();

// 7. Infer Local DenseTensor Meta
phi::MetaTensor meta_dense_out(dense_out);
phi::UnchangedInferMeta(MakeMetaTensor(*input_x), &meta_dense_out);

// 8. DenseTensor Kernel Call
phi::Copy(*dev_ctx, *input_x, target_place, blocking, dense_out);

// 9. Reshard Partial Output to Replicated (Temporary)
}

// 10. Set Output Dist Attr For Default Impl
// API `copy_` does not need to set DistAttr for output.
return;
}
#endif
SetKernelOutput(this);
phi::MetaTensor meta_out(impl_.get());
phi::UnchangedInferMeta(
57 changes: 57 additions & 0 deletions test/auto_parallel/semi_auto_parallel_recompute.py
@@ -0,0 +1,57 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from semi_auto_parallel_simple_net import MPDemoNetRecompute

import paddle
import paddle.distributed as dist
from paddle import nn

BATCH_SIZE = 16
BATCH_NUM = 4
IMAGE_SIZE = 784
CLASS_NUM = 10


def run_dynamic(layer, image, label):
# create loss
loss_fn = nn.MSELoss()
# run forward and backward
image = paddle.to_tensor(image)
image.stop_gradient = False
out = layer(image)

label = paddle.to_tensor(label)
loss = loss_fn(out, label)

loss.backward()
return loss, layer.w0.grad, layer.w1.grad


class TestSemiAutoParallelRecompute:
def test_recompute():
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype('float32')
label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32')
w0 = np.random.random([IMAGE_SIZE, IMAGE_SIZE]).astype('float32')
w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32')
run_dynamic(
layer=MPDemoNetRecompute(w0, w1, mesh), image=image, label=label
)


if __name__ == "__main__":
TestSemiAutoParallelRecompute.test_recompute()
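Note: in CI this script is launched per environment combination by test_semi_auto_parallel_single_strategy.py (see below); run on its own it needs a two-rank launch, e.g. something like `python -m paddle.distributed.launch --devices=0,1 semi_auto_parallel_recompute.py` (launcher flags vary by Paddle version, so treat that command as an assumption).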
35 changes: 35 additions & 0 deletions test/auto_parallel/semi_auto_parallel_simple_net.py
@@ -20,6 +20,7 @@
import paddle.distributed as dist
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed.fleet.utils import recompute

BATCH_SIZE = 16
BATCH_NUM = 4
@@ -120,6 +121,40 @@ def forward(self, x):
return out


class MPDemoNetRecompute(nn.Layer):
def __init__(self, np_w0, np_w1, mesh, param_suffix=""):
super().__init__()
self.w0 = dist.shard_tensor(
self.create_parameter(
shape=[IMAGE_SIZE, IMAGE_SIZE],
attr=paddle.framework.ParamAttr(
name="mp_demo_weight_1" + param_suffix,
initializer=paddle.nn.initializer.Assign(np_w0),
),
),
dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']),
)
self.w1 = dist.shard_tensor(
self.create_parameter(
shape=[IMAGE_SIZE, CLASS_NUM],
attr=paddle.framework.ParamAttr(
name="mp_nemo_weight_2" + param_suffix,
initializer=paddle.nn.initializer.Assign(np_w1),
),
),
dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['x', None]),
)

def _inner_forward_fn(self, x):
y = paddle.matmul(x, self.w0)
z = paddle.matmul(y, self.w1)
return z

def forward(self, x):
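# Checkpoint the inner forward: its intermediate activations are not stored and
# are recomputed during backward, re-running the matmuls on the sharded weights.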
z = recompute(self._inner_forward_fn, x)
return z


class PPDemoNet(nn.Layer):
def __init__(self, np_w0, np_w1, mesh0, mesh1):
super().__init__()
10 changes: 10 additions & 0 deletions test/auto_parallel/test_semi_auto_parallel_single_strategy.py
@@ -61,6 +61,16 @@ def test_simple_net_single_strategy_with_gradient_merge(self):
user_defined_envs=envs,
)

def test_simple_net_recompute(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_auto_parallel_recompute.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()
