[AutoParallel] add create_mesh api (PaddlePaddle#58659)
* [AutoParallel] add create_mesh api

* fix completion and partitioner

* revert

* format

* add ut
zhaoyinglia authored Nov 10, 2023
1 parent 203cd6a commit 80ceec3
Showing 6 changed files with 89 additions and 23 deletions.
7 changes: 1 addition & 6 deletions paddle/fluid/distributed/auto_parallel/dist_attr.cc
@@ -50,19 +50,14 @@ std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
"stream_priority",
"scheduling_priority"};

OperatorDistAttr::OperatorDistAttr(const OpDesc& op) {
  VLOG(4) << "[OperatorDistAttr constructor] op type: " << op.Type();
  initialize(&op);
}
OperatorDistAttr::OperatorDistAttr(const OpDesc& op) { initialize(&op); }

OperatorDistAttr::OperatorDistAttr(const OperatorDistAttr& dist_attr) {
  VLOG(4) << "[OperatorDistAttr copy constructor]";
  copy_from(dist_attr);
}

OperatorDistAttr& OperatorDistAttr::operator=(
    const OperatorDistAttr& dist_attr) {
  VLOG(4) << "[OperatorDistAttr assign constructor]";
  if (this == &dist_attr) return *this;
  OperatorDistAttr tmp(dist_attr);
  std::swap(this->input_dist_attrs_, tmp.input_dist_attrs_);
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/__init__.py
@@ -20,6 +20,8 @@
from .interface import recompute
from .interface import exclude_ops_in_recompute
from .interface import fetch
from .interface import create_mesh
from .interface import get_mesh
from .random import parallel_manual_seed

__all__ = []
30 changes: 30 additions & 0 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
from typing import List, Tuple

import numpy as np

import paddle
from paddle.framework import core

@@ -315,3 +320,28 @@ def fetch(tensor, name=None, logging=False):
    add_to_collection(CollectionNames.FETCHES, tensor, name)
    if logging:
        add_to_collection(CollectionNames.LOGGING, tensor, name)


_g_mesh = None


def get_mesh():
    global _g_mesh
    return _g_mesh


def create_mesh(mesh_dims: List[Tuple[str, int]]):
    """
    Create a global process_mesh for auto parallel.
    Args:
        mesh_dims (list[tuple[str, int]]): A list of tuples; each element is (dim_name, dim_degree).
    """
    global _g_mesh
    dim_names = [mesh_dim[0] for mesh_dim in mesh_dims]
    mesh_shape = [mesh_dim[1] for mesh_dim in mesh_dims]
    mesh_arr = np.arange(0, reduce(lambda x, y: x * y, mesh_shape, 1)).reshape(
        mesh_shape
    )
    _g_mesh = ProcessMesh(mesh_arr, dim_names)
    return _g_mesh
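
For reference, a minimal usage sketch of the new API (not part of the diff). The import path is assumed from the `__init__.py` export above, and the 'dp'/'pp'/'mp' mesh dims mirror the unit test added in this commit:

from paddle.distributed import auto_parallel as auto  # assumed import path

# Build a 2 x 4 x 4 global mesh over 32 ranks, as in the new unit test.
mesh = auto.create_mesh([('dp', 2), ('pp', 4), ('mp', 4)])
assert auto.get_mesh() is mesh
assert auto.get_mesh().shape == [2, 4, 4]
assert auto.get_mesh().get_dim_size('pp') == 4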
18 changes: 18 additions & 0 deletions python/paddle/distributed/auto_parallel/process_mesh.py
@@ -204,6 +204,24 @@ def __getitem__(self, index):
        else:
            return ProcessMesh([new_mesh])

    def get_dim_size(self, dim_name):
        assert dim_name in self._dim_names
        return self._shape[self._dim_names.index(dim_name)]

    def get_mesh_with_dim(self, dim_name):
        assert (
            dim_name in self._dim_names
        ), f'{dim_name} is not a valid dim name.'
        index_axis = self._dim_names.index(dim_name)
        new_order = [index_axis] + [
            i for i in range(len(self._dim_names)) if i != index_axis
        ]
        new_dim_names = [dim_name] + [
            dim for dim in self._dim_names if dim != dim_name
        ]
        new_mesh = self._mesh.transpose(new_order)
        return ProcessMesh(new_mesh, new_dim_names)

    def __enter__(self):
        set_current_process_mesh(self)
        default_prog = paddle.static.default_main_program()
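
As a sanity check of the reordering that get_mesh_with_dim performs, here is a standalone NumPy sketch (not part of the diff) reproducing the numbers asserted in the new test:

import numpy as np

# Mesh laid out as ('dp', 'pp', 'mp') = (2, 4, 4) over ranks 0..31.
arr = np.arange(32).reshape([2, 4, 4])

# get_mesh_with_dim('pp') moves 'pp' to the front: ('pp', 'dp', 'mp').
pp_first = arr.transpose([1, 0, 2])    # shape (4, 2, 4)

# Indexing the leading 'pp' axis gives the ranks of one pipeline stage.
print(pp_first[0].flatten().tolist())  # [0, 1, 2, 3, 16, 17, 18, 19]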
32 changes: 15 additions & 17 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -938,11 +938,10 @@ def fit(
        """
        self._mode = 'train'

        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._inputs_spec, self._labels_spec = self._prepare_data_spec(
                train_data, train_sample_split, batch_size
            )
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)
@@ -1123,11 +1122,11 @@ def evaluate(
        """
        self._mode = 'eval'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._inputs_spec, self._labels_spec = self._prepare_data_spec(
                valid_data, valid_sample_split, batch_size
            )
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)
@@ -1257,11 +1256,11 @@ def predict(
            >>> engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._inputs_spec, self._labels_spec = self._prepare_data_spec(
                test_data, test_sample_split, batch_size
            )
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)
@@ -1346,11 +1345,10 @@ def dataloader(
        if mode is not None:
            self.to_mode(mode)

        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._inputs_spec, self._labels_spec = self._prepare_data_spec(
                dataset, sample_split, batch_size
            )
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)
@@ -1391,11 +1389,11 @@ def dataloader_from_generator(
    ):
        if mode is not None:
            self.to_mode(mode)
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._inputs_spec, self._labels_spec = self._prepare_data_spec(
                dataset, sample_split, batch_size
            )
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)
23 changes: 23 additions & 0 deletions test/auto_parallel/test_interface.py
Expand Up @@ -14,6 +14,8 @@

import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle import nn, static
@@ -241,6 +243,27 @@ def test_api(self):
        self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
        self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))

    def test_create_mesh(self):
        arr = np.arange(32).reshape([2, 4, 4])
        auto.create_mesh([('dp', 2), ('pp', 4), ('mp', 4)])
        self.assertEqual(auto.get_mesh().shape, [2, 4, 4])
        self.assertEqual(auto.get_mesh().get_dim_size('dp'), 2)
        self.assertEqual(auto.get_mesh().get_dim_size('pp'), 4)
        self.assertEqual(auto.get_mesh().get_dim_size('mp'), 4)
        self.assertEqual(auto.get_mesh().process_ids, list(np.arange(32)))

        first_pp_mesh = auto.get_mesh().get_mesh_with_dim("pp")
        self.assertEqual(first_pp_mesh.shape, [4, 2, 4])
        self.assertEqual(
            first_pp_mesh.process_ids, list(arr.transpose([1, 0, 2]).flatten())
        )

        pp_stage_0_mesh = first_pp_mesh[0]
        self.assertEqual(pp_stage_0_mesh.shape, [2, 4])
        self.assertEqual(
            pp_stage_0_mesh.process_ids, [0, 1, 2, 3, 16, 17, 18, 19]
        )


if __name__ == '__main__':
unittest.main()
