Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add _init_parallel_env and _new_group #40579

Merged
merged 22 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 229 additions & 2 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import os
from datetime import timedelta
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable
from ..fluid.framework import OpProtoHolder
Expand Down Expand Up @@ -73,18 +74,21 @@ class ReduceOp:
MAX = 1
MIN = 2
PROD = 3
AVG = 4


class Group():
"""
The abstract representation of group.
"""

def __init__(self, rank, rank_num, id=0, ranks=[]):
def __init__(self, rank, rank_num, id=0, ranks=[], pg=None, name=None):
self.rank = rank
self.nranks = rank_num
self.id = id
self.ranks = ranks
self.pg = pg
self.name = name

def is_member(self):
if self.rank < 0:
Expand All @@ -99,11 +103,16 @@ def get_group_rank(self, rank):
else:
return -1

@property
def process_group(self):
return self.pg

def __repr__(self):
debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
self.rank, self.nranks, self.id)
debug_str += ", ".join(map(str, self.ranks))
debug_str += ". "
debug_str += "; name: "
debug_str += self.name if self.name else "None"
return debug_str


Expand All @@ -121,6 +130,17 @@ def _get_global_env():
# Dict[int, Group]
_group_map = {}

# group map by name : the map of all groups from their names
# Dict[name, Group]
_group_map_by_name = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl']
_default_store = None # the default tcp store
_default_backend = None


def _get_group_map():
global _group_map
Expand All @@ -135,10 +155,29 @@ def _get_global_group():
return _get_group_map()[0]


def _get_group_map_by_name():
global _group_map_by_name
assert _default_group_name in _group_map_by_name, (
"Call paddle.distributed.init_parallel_env first "
"to initialize the distributed environment.")
return _group_map_by_name


def _get_default_group():
assert _default_group_name in _group_map_by_name, (
"Call paddle.distributed.init_parallel_env first "
"to initialize the distributed environment.")
return _get_group_map_by_name()[_default_group_name]


def _new_ring_id():
return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _new_group_name_id():
return len(_get_group_map_by_name()) + max(_get_global_env().nrings, 9)


def get_group(id=0):
"""

Expand All @@ -163,6 +202,194 @@ def get_group(id=0):
return gm[id] if id in gm else None


def _new_process_group_impl(backend, store, rank, world_size, group_name,
pg_options):
if backend == "gloo":
gloo_store = core.GlooStore(store)

pg = None
if backend == "gloo":
pg = core.ProcessGroupGloo(gloo_store, rank, world_size)
elif backend == "nccl":
pg = core.ProcessGroupNCCL(store, rank, world_size)
elif backend == "hccl":
pg = core.ProcessGroupHCCL(store, rank, world_size)

return pg


def _init_parallel_env(rank=None,
world_size=None,
backend="nccl",
timeout=timedelta(0),
pg_options=None):
"""

Initializes the default distributed environment.

Args:
rank (int, optional): the rank of the current process or device from 0 to world_size (exclusive).
If you launch your training with paddle.distributed.run or
paddle.distributed.launch module, None can be given. Default: None.
world_size (int, optional): total number of processes or devices.
If you launch your training with paddle.distributed.run or
paddle.distributed.launch module, None can be given. Default: None.
backend (str, optional): the name of the backend used to initialize
the distributed environment. The value can be one of 'nccl' for
GPU, 'gloo' for CPU or 'hccl' for NPU. Default: 'nccl'.
timeout (datetime.timedelta, optional): timeout used for operations of
the group. Default: datetime.timedelta(0) which means no timeout.
pg_options (dict, optional): options for the group. Default: None.

Returns:
Group: a group.

Examples:

.. code-block:: python

# filename: train.py
import paddle
paddle.distributed.init_parallel_env(0, 1)

# how to start
# python paddle.distributed.run --gpus="0,1" train.py

"""

global _group_map_by_name
global _default_group_name
assert _default_group_name not in _group_map_by_name, (
"The default distributed environment has been initialized.")

assert backend in _valid_backend_list, (
"Backend must be one of {}, but the given one is: {}".format(
_valid_backend_list, backend))
_default_backend = backend

assert isinstance(timeout, timedelta), (
"timeout must be of the type datetime.timedelta.")

if rank is None or world_size is None:
assert rank is None and world_size is None, (
"rank and world_size should be unset at the same time.")
trainer_id = os.getenv("PADDLE_TRAINER_ID", None)
trainer_num = os.getenv("PADDLE_TRAINERS_NUM", None)
if trainer_id is None or trainer_num is None:
warnings.warn("If rank and world_size are both None, please start "
"your training with paddle.distributed.run or "
"paddle.distributed.launch module. Otherwise, "
"init_parallel_env will do nothing.")
return None
rank = int(trainer_id)
world_size = int(trainer_num)

assert rank >= 0 and world_size > rank and world_size > 1, (
"rank must be non-negative and world_size must be the "
"maximum rank plus one. Moreover, at least two processes are "
"required to create a process group.")

master_addr = os.getenv("MASTER_ADDR", None)
master_port = os.getenv("MASTER_PORT", None)
if not master_addr or not master_port:
endpoints = os.getenv("PADDLE_MASTER", None)
if endpoints is None:
endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
if not endpoints:
raise ValueError(
"The environment variable 'MASTER_ADDR' and 'MASTER_PORT' "
"must be specified, for example 'export MASTER_ADDR=127.0.0.1' "
"and 'export MASTER_ADDR=54612'. Or you can start your training"
"with paddle.distributed.run or "
"paddle.distributed.luanch module.")
if ',' in endpoints:
endpoints = endpoints.split(',')[0]
master_addr, master_port = endpoints.split(":")

master_port = int(master_port)

is_master = rank == 0
global _default_store
_default_store = core.TCPStore(master_addr, master_port, is_master,
world_size, timeout)

pg = _new_process_group_impl(backend, _default_store, rank, world_size,
_default_group_name, pg_options)
ranks = list(range(world_size))
group = Group(
rank, world_size, id=0, ranks=ranks, pg=pg, name=_default_group_name)

paddle.fluid.dygraph.parallel_helper._set_parallel_ctx(True)
_group_map_by_name[_default_group_name] = group
return group


def _new_group(ranks=None,
backend=None,
group_name=None,
timeout=timedelta(0),
pg_options=None):
"""
Create a new process group.

Args:
ranks (list, optional): list of ranks for the new group. If None is given,
all processes is used. Default: None.
backend (str, optional): the name of the backend used to initialize
the distributed environment. Default: the one for init_parallel_env.
timeout (datetime.timedelta, optional): timeout used for operations of
the group. Default: datetime.timedelta(0).
pg_options (dict, optional): options for the group. Default: None.

Examples:

.. code-block:: python

import paddle
paddle.distributed.init_parallel_env(0, 1)
paddle.distributed.new_group([0, 1])

# how to start
# python paddle.distributed.run --gpus="0,1" train.py

"""
global _default_group_name
if group_name is None:
group_name = _default_group_name + str(_new_group_name_id())
if group_name == _default_group_name:
raise ValueError("group_name must be specified and it cannot be '{}' "
"which is used for the default process group created "
"by init_parallel_env.".format(_default_group_name))
global_group = _get_default_group()
global_rank = global_group.rank
global_ranks = global_group.ranks
if ranks is None:
ranks = global_ranks
assert len(ranks) <= len(global_ranks), (
"Size of new group must be less than or "
"equal to that of the default global group.")
size = len(ranks)
assert size > 1, "A group must have at least two memebers."
ranks = sorted(ranks)
if global_rank in ranks:
rank = ranks.index(global_rank)
pg = _new_process_group_impl(backend, _default_store, rank, size,
group_name, pg_options)
else:
rank = -1
pg = None
group = Group(
rank,
size,
id=_new_group_name_id(),
ranks=ranks,
pg=pg,
name=group_name)
_group_map_by_name[group_name] = group

return group


def barrier(group=None):
"""

Expand Down
49 changes: 49 additions & 0 deletions python/paddle/fluid/tests/unittests/init_process_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import random
import numpy as np
import os
import shutil

import paddle
from paddle.fluid import core
import datetime
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv


class TestProcessGroupFp32(unittest.TestCase):
def setUp(self):
self.config()

def config(self):
pass

def test_init_process_group(self):
paddle.distributed.collective._init_parallel_env()
paddle.distributed.collective._new_group()
with self.assertRaises(ValueError):
paddle.distributed.collective._new_group(
backend="gloo", group_name="_default_pg")
print("test ok\n")


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def test_process_group_nccl(self):
def test_process_group_gloo(self):
self.run_mnist_2gpu('process_group_gloo.py')

def test_init_process_group(self):
self.run_mnist_2gpu('init_process_group.py')


if __name__ == "__main__":
unittest.main()