[AutoParallel] Support shard parameter #57278

Merged: 4 commits, Sep 15, 2023
1 change: 1 addition & 0 deletions paddle/fluid/pybind/eager.cc
@@ -276,6 +276,7 @@ void InitDistTensorWithTensor(TensorObject* self,
paddle::platform::errors::InvalidArgument(
"DistTensor can only initialize by DenseTensor"));
self->tensor.set_name(name);
VLOG(4) << "Do TensorCopy from DenseTensor to DistTensor.";
Contributor Author commented:

This log line should be safe to remove; we will remove it in a follow-up PR.

if (place == src.place()) {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(src.impl());
16 changes: 16 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
@@ -2904,6 +2904,18 @@ static PyObject* tensor_is_contiguous(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor_method__set_impl(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(4) << "Running in tensor_method__set_impl: set Tensor impl form the "
"other Tensor.";
auto tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
self->tensor.set_impl(tensor.impl());
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}

#if defined(PADDLE_WITH_CUDA)
static PyObject* tensor_method__uva(TensorObject* self,
PyObject* args,
@@ -3199,6 +3211,10 @@ PyMethodDef variable_methods[] = { // NOLINT
(PyCFunction)(void (*)(void))tensor_method_strides,
METH_VARARGS | METH_KEYWORDS,
tensor_get_strides__doc__},
{"_set_impl",
(PyCFunction)(void (*)(void))tensor_method__set_impl,
METH_VARARGS | METH_KEYWORDS,
nullptr},
#if defined(PADDLE_WITH_CUDA)
{"_tensor_uva",
(PyCFunction)(void (*)())tensor_method__uva,
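A minimal usage sketch of the `_set_impl` binding added above (an editor's illustration, not part of the diff; the tensor names `src` and `dst` are placeholders). The call rebinds the destination tensor's underlying impl to the source's, so no data is copied and the destination keeps its Python-side identity.

import paddle

# Two ordinary dense tensors; names are illustrative only.
src = paddle.ones([2, 3], dtype="float32")
dst = paddle.zeros([2, 3], dtype="float32")

# After this call, `dst` shares `src`'s underlying DenseTensor impl.
# This is how EagerParamBase.from_tensor (below) reuses a distributed
# impl for a freshly constructed parameter.
dst._set_impl(src)

assert (dst.numpy() == src.numpy()).all()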
16 changes: 16 additions & 0 deletions python/paddle/base/framework.py
@@ -7376,6 +7376,22 @@ def __init__(self, shape, dtype, **kwargs):
self._init_func = None
self._init_op_creator = None

@classmethod
def from_tensor(cls, tensor, **kwargs):
# 1. construct EagerParamBase
param = cls(tensor.shape, tensor.dtype, **kwargs)

# 2. transform data if needed
dist_attr = kwargs.get('dist_attr', None)
src_tensor = tensor
if dist_attr is not None:
src_tensor = core.eager.Tensor(tensor, dist_attr=dist_attr)

# 3. set param data
param._set_impl(src_tensor)

return param

def set_init_func(self, obj):
self._init_func = obj

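A hedged sketch of how `EagerParamBase.from_tensor` is meant to be used; it mirrors the call made by `shard_tensor` in the next file. The mesh and sharding values are placeholders taken from the style of the test added below, not prescribed values.

import paddle
import paddle.distributed as dist
from paddle.base.framework import EagerParamBase

# Placeholder mesh / dist_attr; real values depend on the launched ranks.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])

weight = paddle.nn.Linear(4, 4).weight  # an existing EagerParamBase

# A new parameter is built with the same shape/dtype and attributes; because a
# dist_attr is passed, its impl is set to the DistTensor constructed from it.
dist_weight = EagerParamBase.from_tensor(
    weight, dist_attr=dist_attr, **weight.__dict__
)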
15 changes: 11 additions & 4 deletions python/paddle/distributed/auto_parallel/api.py
@@ -14,6 +14,7 @@


import paddle
from paddle.base.framework import EagerParamBase
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
)
@@ -127,19 +128,25 @@ def shard_tensor(
"""
# 1. create dense tensor
# `paddle.to_tensor` supports both dynamic and static mode
data = paddle.to_tensor(data)
tensor = paddle.to_tensor(data)

# 2. create dist tensor
assert len(dist_attr.dims_mapping) == len(
list(data.shape)
list(tensor.shape)
), "The length of sharding_specs must be same as the shape of the input tensor."

if paddle.in_dynamic_mode():
return paddle.Tensor(data, dist_attr=dist_attr)
# here the dist tensor is constructed with a deep copy of `tensor`
if isinstance(data, EagerParamBase):
return EagerParamBase.from_tensor(
tensor, dist_attr=dist_attr, **tensor.__dict__
)
else:
return paddle.Tensor(tensor, dist_attr=dist_attr)
else:
# TODO(zhiqiu): we need to refine the static shard_tensor
return shard_tensor_static(
data, dist_attr.process_mesh, dist_attr.sharding_specs
tensor, dist_attr.process_mesh, dist_attr.sharding_specs
)


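A short sketch of the behavioral change in `shard_tensor` above (editor illustration with a placeholder mesh, complementing the test added below): a parameter input now stays a parameter after sharding instead of degrading to a plain `paddle.Tensor`, while non-parameter inputs follow the original path.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])

layer = paddle.nn.Linear(8, 8)

# Non-parameter input: handled by the `paddle.Tensor(tensor, dist_attr=...)` branch.
dist_x = dist.shard_tensor(paddle.rand([4, 8]), dist_attr=dist_attr)

# Parameter input: handled by the new `EagerParamBase.from_tensor` branch,
# so the result can still be registered and updated as a parameter.
dist_w = dist.shard_tensor(layer.weight, dist_attr=dist_attr)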
30 changes: 30 additions & 0 deletions test/auto_parallel/test_shard_tensor_api.py
@@ -14,6 +14,8 @@

import unittest

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.base.dygraph.base import switch_to_static_graph
@@ -131,5 +133,33 @@ def func():
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))


class DemoNet(paddle.nn.Layer):
def __init__(self, dist_attr):
super().__init__()
self.w0 = dist.shard_tensor(
self.create_parameter(shape=[784, 784]), dist_attr=dist_attr
)

def forward(self, x):
return paddle.matmul(x, self.w0)


class TestShardTensorParameter(unittest.TestCase):
def setUp(self):
self.mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
self.dist_attr = dist.DistAttr(
mesh=self.mesh, sharding_specs=[None, None]
)

def test_shard_parameter(self):
x = np.random.random(size=[16, 784]).astype("float32")
dist_x = dist.shard_tensor(x, dist_attr=self.dist_attr)
net = DemoNet(self.dist_attr)
out = net(dist_x)
self.assertEqual(out.shape, [16, 784])
self.assertEqual(out.is_dist(), True)
self.assertEqual(out.dist_attr, self.dist_attr)


if __name__ == "__main__":
unittest.main()