Commit 4f2badd

fix

BeingGod committed Sep 19, 2023
1 parent b6b8773 commit 4f2badd
Showing 3 changed files with 12 additions and 5 deletions.
2 changes: 2 additions & 0 deletions paddle/phi/core/distributed/nccl_comm_context.cc
@@ -231,6 +231,7 @@ void NCCLCommContext::GroupStart() {
 }
 void NCCLCommContext::GroupEnd() { NCCL_CHECK(phi::dynload::ncclGroupEnd()); }
 
+#if NCCL_VERSION_CODE >= 21100
 void NCCLCommContext::RedOpCreatePreMulSum(ncclRedOp_t* op,
                                            void* scalar,
                                            ncclDataType_t dtype,
@@ -242,6 +243,7 @@ void NCCLCommContext::RedOpCreatePreMulSum(ncclRedOp_t* op,
 void NCCLCommContext::RedOpDestroy(ncclRedOp_t op) {
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, nccl_comm_));
 }
+#endif
 
 }  // namespace distributed
 }  // namespace phi
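
For context, ncclRedOpCreatePreMulSum and ncclRedOpDestroy only exist in NCCL 2.11.0 and newer (version code 21100), so both the definitions above and any call sites need the same compile-time guard. Below is a minimal caller-side sketch, not part of this commit: the helper name MakeScaledSum, the ctx/scale parameters, and the fallback behavior are all hypothetical illustrations of the pattern.

#include "paddle/phi/core/distributed/nccl_comm_context.h"

// Hypothetical helper: build a "multiply each input by *scale, then sum"
// reduction when the NCCL headers are new enough; otherwise fall back to a
// plain sum (the scaling would then have to be applied separately).
static ncclRedOp_t MakeScaledSum(phi::distributed::NCCLCommContext* ctx,
                                 float* scale) {
#if NCCL_VERSION_CODE >= 21100
  ncclRedOp_t op;
  // ncclScalarHostImmutable: the scalar is read from host memory when the
  // op is created and captured by value.
  ctx->RedOpCreatePreMulSum(&op, scale, ncclFloat32, ncclScalarHostImmutable);
  return op;
#else
  (void)ctx;
  (void)scale;
  return ncclSum;  // pre-2.11 fallback: built-in sum, no pre-multiplication
#endif
}

On the >= 2.11 path, a created op should later be released with ctx->RedOpDestroy(op) once the collectives using it have been issued; built-in ops such as ncclSum must not be destroyed.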
10 changes: 6 additions & 4 deletions paddle/phi/core/distributed/nccl_comm_context.h
@@ -98,16 +98,18 @@ class NCCLCommContext final : public CommContext {
                  int root,
                  gpuStream_t stream);
 
+  void GroupStart();
+
+  void GroupEnd();
+
+#if NCCL_VERSION_CODE >= 21100
   void RedOpCreatePreMulSum(ncclRedOp_t* op,
                             void* scalar,
                             ncclDataType_t dtype,
                             ncclScalarResidence_t residence);
 
   void RedOpDestroy(ncclRedOp_t op);
-
-  void GroupStart();
-
-  void GroupEnd();
+#endif
 
  private:
   DISABLE_COPY_AND_ASSIGN(NCCLCommContext);
5 changes: 4 additions & 1 deletion test/legacy_test/distributed_fused_lamb_test_base.py
@@ -270,7 +270,10 @@ def setUpClass(cls):
         paddle.enable_static()
         paddle.set_flags({'FLAGS_cudnn_deterministic': True})
         _clip_by_global_norm_using_mp_type(True)
-        fleet.init(role_maker=get_role_maker())
+        if os.environ.get("FLAGS_dynamic_static_unified_comm") == "1":
+            fleet.init(role_maker=get_role_maker())
+        else:
+            paddle.distributed.collective._init_parallel_env("nccl")
 
     def config(self):
         clip_after_allreduce = bool(
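
A note on the test change (an inference from the diff, not wording from the commit): when the FLAGS_dynamic_static_unified_comm environment variable is "1", the fixture initializes through fleet.init(), and otherwise it falls back to the legacy paddle.distributed.collective._init_parallel_env("nccl") path, so the same test can run against either communication stack depending on the environment. The hunk assumes os is already imported in the test module.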
