Update ncclCommInit function in fix version (#60947)
SylarTiaNII authored Feb 2, 2024
1 parent 0a55857 commit ec5ca16
Showing 7 changed files with 80 additions and 13 deletions.
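In brief, the change threads a new boolean nccl_comm_init_option through ProcessGroupNCCL (constructor, CreateProcessGroupNCCL factory, and a GetNCCLCommInitOption() accessor), declares a dynamically loaded ncclCommInitRank2 symbol behind an NCCL_FIX_CODE guard, exposes the option to Python as the enable_nccl_comm_init_option argument of paddle.distributed.new_group, and adds the FLAGS_pipeline_enable_nccl_comm_init_option environment flag so the option can be enabled only for pipeline-parallel communication groups.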
30 changes: 23 additions & 7 deletions paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -121,11 +121,15 @@ ProcessGroupNCCL::ProcessGroupNCCL(
int rank,
int size,
int gid,
int64_t timeout)
int64_t timeout,
bool nccl_comm_init_option)
: ProcessGroupWithStream(rank, size, gid),
store_(store),
pg_timeout_(timeout) {
pg_timeout_(timeout),
nccl_comm_init_option_(nccl_comm_init_option) {
LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_;
LOG(INFO) << "ProcessGroupNCCL nccl_comm_init_option_ "
<< nccl_comm_init_option_;
}
ProcessGroupNCCL::~ProcessGroupNCCL() {
LOG(INFO) << "ProcessGroupNCCL destruct ";
@@ -851,8 +855,19 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,

NCCL_CHECK(phi::dynload::ncclGroupStart());
ncclComm_t nccl_comm;
NCCL_CHECK(
phi::dynload::ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));
if (nccl_comm_init_option_) {
#if NCCL_FIX_CODE > 0
NCCL_CHECK(phi::dynload::ncclCommInitRank2(
&nccl_comm, num_ranks, nccl_id, rank, 1));
#else
LOG(WARNING) << "ncclCommInitRank2 is not supported.";
NCCL_CHECK(
phi::dynload::ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));
#endif
} else {
NCCL_CHECK(
phi::dynload::ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));
}
NCCL_CHECK(phi::dynload::ncclGroupEnd());

VLOG(3) << "Get nccl comm: " << nccl_comm << " for place_key: " << place_key
@@ -1112,9 +1127,10 @@ std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
int rank,
int size,
int gid,
int64_t timeout) {
auto process_group =
std::make_shared<ProcessGroupNCCL>(store, rank, size, gid, timeout);
int64_t timeout,
bool nccl_comm_init_option) {
auto process_group = std::make_shared<ProcessGroupNCCL>(
store, rank, size, gid, timeout, nccl_comm_init_option);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
return process_group;
}
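In short: when the process group was created with the new option and the wrapper layer was built with NCCL_FIX_CODE > 0, the communicator is initialized via ncclCommInitRank2 (with an extra trailing argument of 1); otherwise a warning is logged and the code falls back to the standard ncclCommInitRank.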
9 changes: 7 additions & 2 deletions paddle/fluid/distributed/collective/process_group_nccl.h
@@ -75,13 +75,15 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
int rank,
int size,
int gid,
int64_t timeout);
int64_t timeout,
bool nccl_comm_init_option);

ProcessGroupNCCL(const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid,
int64_t timeout = 20 * 1000);
int64_t timeout = 20 * 1000,
bool nccl_comm_init_option = false);
~ProcessGroupNCCL();

std::string GetBackendName() const override { return "NCCL"; }
@@ -176,6 +178,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {

ncclComm_t NCCLComm(const Place& place) const;

const bool GetNCCLCommInitOption() { return nccl_comm_init_option_; }

private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
@@ -242,6 +246,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
static uint64_t s_group_call_counter;
// default 30 minutes
int64_t pg_timeout_;
bool nccl_comm_init_option_;

// optimize memory for process_group
std::vector<std::pair<std::weak_ptr<phi::Allocation>, gpuStream_t>>
9 changes: 9 additions & 0 deletions paddle/fluid/platform/dynload/nccl.h
@@ -72,6 +72,15 @@ NCCL_RAND_ROUTINE_EACH_AFTER_2703(PLATFORM_DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
NCCL_RAND_ROUTINE_EACH_AFTER_21100(PLATFORM_DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
#endif

#ifndef NCCL_FIX_CODE
#define NCCL_FIX_CODE 0
#endif

#if NCCL_FIX_CODE > 0
#define NCCL_RAND_ROUTINE_EACH_WITH_FIX(__macro) __macro(ncclCommInitRank2);
NCCL_RAND_ROUTINE_EACH_WITH_FIX(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
#endif

} // namespace dynload
} // namespace platform
} // namespace paddle
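Note (an inference from the guards above, not stated elsewhere in the commit): NCCL_FIX_CODE defaults to 0 when the build does not define it, so the ncclCommInitRank2 wrapper is only declared when a positive NCCL_FIX_CODE is supplied as a compile definition, presumably when Paddle is built against an NCCL that actually exports ncclCommInitRank2. The same guard is mirrored in paddle/phi/backends/dynload/nccl.h below.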
6 changes: 5 additions & 1 deletion paddle/fluid/pybind/distributed_py.cc
@@ -1237,9 +1237,13 @@ void BindDistributed(py::module *m) {
py::arg("world_size"),
py::arg("group_id") = 0,
py::arg("timeout") = 30 * 60 * 1000,
py::arg("enable_nccl_comm_init_option") = false,
py::call_guard<py::gil_scoped_release>())
.def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
.def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd);
.def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd)
.def("get_nccl_comm_init_option",
&paddle::distributed::ProcessGroupNCCL::GetNCCLCommInitOption,
py::call_guard<py::gil_scoped_release>());

#endif

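On the Python side this adds two visible pieces: ProcessGroupNCCL.create now accepts enable_nccl_comm_init_option (default False), and the bound get_nccl_comm_init_option() reports the value a group was created with; a usage sketch follows the collective.py hunk below.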
9 changes: 9 additions & 0 deletions paddle/phi/backends/dynload/nccl.h
@@ -85,5 +85,14 @@ NCCL_RAND_ROUTINE_EACH_AFTER_2703(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
NCCL_RAND_ROUTINE_EACH_AFTER_21100(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
#endif

#ifndef NCCL_FIX_CODE
#define NCCL_FIX_CODE 0
#endif

#if NCCL_FIX_CODE > 0
#define NCCL_RAND_ROUTINE_EACH_WITH_FIX(__macro) __macro(ncclCommInitRank2);
NCCL_RAND_ROUTINE_EACH_WITH_FIX(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
#endif

} // namespace dynload
} // namespace phi
16 changes: 14 additions & 2 deletions python/paddle/distributed/collective.py
@@ -144,6 +144,7 @@ def _new_process_group_impl(
group_name,
pg_options,
group_id=0,
enable_nccl_comm_init_option=False,
):
pg = None
genv = _get_global_env()
@@ -152,7 +153,12 @@
pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
elif backend == "nccl":
pg = core.ProcessGroupNCCL.create(
store, rank, world_size, group_id, genv.pg_timeout
store,
rank,
world_size,
group_id,
genv.pg_timeout,
enable_nccl_comm_init_option,
)
elif backend == "xccl":
pg = core.ProcessGroupCustom.create(
@@ -174,7 +180,12 @@ def _set_custom_gid(gid):
_custom_gid = gid


def new_group(ranks=None, backend=None, timeout=_default_timeout):
def new_group(
ranks=None,
backend=None,
timeout=_default_timeout,
enable_nccl_comm_init_option=False,
):
"""
Creates a new distributed communication group.
@@ -227,6 +238,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
group_name,
pg_options=None,
group_id=gid,
enable_nccl_comm_init_option=enable_nccl_comm_init_option,
)
else:
rank = -1
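A minimal usage sketch of the new keyword, assuming a two-rank job started with paddle.distributed.launch and assuming the Group object returned by new_group exposes its underlying process group as process_group (everything else matches the signatures added above):

import paddle.distributed as dist

# Hedged sketch: run with e.g. `python -m paddle.distributed.launch --nproc_per_node=2 demo.py`.
dist.init_parallel_env()
group = dist.new_group(ranks=[0, 1], enable_nccl_comm_init_option=True)
if dist.get_rank() in [0, 1]:
    # Assumption: Group.process_group exposes the bound ProcessGroupNCCL object,
    # so this should print True for the option requested at creation time.
    print(group.process_group.get_nccl_comm_init_option())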
14 changes: 13 additions & 1 deletion python/paddle/distributed/fleet/base/topology.py
@@ -13,6 +13,7 @@
# limitations under the License.

import collections
import os
from functools import reduce
from itertools import product

@@ -25,6 +26,10 @@

_HYBRID_PARALLEL_GROUP = None

g_pipeline_enable_nccl_comm_init_option = bool(
os.environ.get("FLAGS_pipeline_enable_nccl_comm_init_option", 0)
)


class ParallelMode:
"""
@@ -274,8 +279,15 @@ def _set_comm_group(self, parallel_method="data"):
parallel_comm_group = None
parallel_groups = self._topo.get_comm_list(parallel_method)

group_enable_nccl_comm_init_option = (
g_pipeline_enable_nccl_comm_init_option
and (parallel_method == "pipe")
)
for group in parallel_groups:
comm_group = paddle.distributed.new_group(ranks=group)
comm_group = paddle.distributed.new_group(
ranks=group,
enable_nccl_comm_init_option=group_enable_nccl_comm_init_option,
)
if self.global_rank in group:
parallel_group = group
parallel_comm_group = comm_group
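And a hedged sketch of turning the option on only for pipeline-parallel groups: the flag name comes from the hunk above, and because g_pipeline_enable_nccl_comm_init_option is captured when topology.py is imported, the environment variable has to be set before the fleet modules are imported; the hybrid-parallel degrees are placeholders.

import os

# Must be set before paddle.distributed.fleet (and hence topology.py) is imported,
# since the module reads the flag at import time.
os.environ["FLAGS_pipeline_enable_nccl_comm_init_option"] = "1"

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# Placeholder layout: 2-way pipeline parallelism, no data/model parallelism.
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
fleet.init(is_collective=True, strategy=strategy)
# fleet.init builds the hybrid communicate group; with the flag set, the "pipe"
# communication groups are created with enable_nccl_comm_init_option=True,
# while all other groups keep the default of False.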
