diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc
index faf29add30d91..bd49f0cff1708 100644
--- a/paddle/phi/core/distributed/nccl_comm_context.cc
+++ b/paddle/phi/core/distributed/nccl_comm_context.cc
@@ -231,6 +231,7 @@ void NCCLCommContext::GroupStart() {
 }
 void NCCLCommContext::GroupEnd() { NCCL_CHECK(phi::dynload::ncclGroupEnd()); }
 
+#if NCCL_VERSION_CODE >= 21100
 void NCCLCommContext::RedOpCreatePreMulSum(ncclRedOp_t* op,
                                            void* scalar,
                                            ncclDataType_t dtype,
@@ -242,6 +243,7 @@ void NCCLCommContext::RedOpCreatePreMulSum(ncclRedOp_t* op,
 void NCCLCommContext::RedOpDestroy(ncclRedOp_t op) {
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, nccl_comm_));
 }
+#endif
 
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/nccl_comm_context.h b/paddle/phi/core/distributed/nccl_comm_context.h
index 61c3fb06c0e33..b9fdce02f4b5f 100644
--- a/paddle/phi/core/distributed/nccl_comm_context.h
+++ b/paddle/phi/core/distributed/nccl_comm_context.h
@@ -98,16 +98,22 @@ class NCCLCommContext final : public CommContext {
                  int root,
                  gpuStream_t stream);
 
+  void GroupStart();
+
+  void GroupEnd();
+
+#if NCCL_VERSION_CODE >= 21100
+  // Creates a new reduction operator which pre-multiplies input values by a
+  // given scalar locally before reducing them with peer values via summation.
   void RedOpCreatePreMulSum(ncclRedOp_t* op,
                             void* scalar,
                             ncclDataType_t dtype,
                             ncclScalarResidence_t residence);
 
+  // Destroys the reduction operator op. The operator must have been created
+  // by ncclRedOpCreatePreMulSum with the matching communicator comm.
   void RedOpDestroy(ncclRedOp_t op);
-
-  void GroupStart();
-
-  void GroupEnd();
+#endif
 
  private:
  DISABLE_COPY_AND_ASSIGN(NCCLCommContext);
diff --git a/test/legacy_test/distributed_fused_lamb_test_base.py b/test/legacy_test/distributed_fused_lamb_test_base.py
index baffc7dd5e546..ea011becc9090 100644
--- a/test/legacy_test/distributed_fused_lamb_test_base.py
+++ b/test/legacy_test/distributed_fused_lamb_test_base.py
@@ -270,7 +270,10 @@ def setUpClass(cls):
         paddle.enable_static()
         paddle.set_flags({'FLAGS_cudnn_deterministic': True})
         _clip_by_global_norm_using_mp_type(True)
-        fleet.init(role_maker=get_role_maker())
+        if os.environ.get("FLAGS_dynamic_static_unified_comm") == "1":
+            paddle.distributed.collective._init_parallel_env("nccl")
+        else:
+            fleet.init(role_maker=get_role_maker())
 
     def config(self):
         clip_after_allreduce = bool(
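
For context, the guarded methods wrap NCCL's ncclRedOpCreatePreMulSum/ncclRedOpDestroy pair, which is only available from NCCL 2.11 (NCCL_VERSION_CODE 21100) onward, hence the preprocessor guard. Below is a minimal caller-side sketch, not part of this diff, assuming the existing NCCLCommContext::AllReduce(out_tensor, in_tensor, reduce_type, stream) overload declared elsewhere in this header; the ScaledAllReduce helper, comm_ctx, scale, and stream names are illustrative only.

#if NCCL_VERSION_CODE >= 21100
// Sketch: all-reduce where each rank pre-multiplies its input by `scale`
// locally before the cross-rank summation.
void ScaledAllReduce(phi::distributed::NCCLCommContext* comm_ctx,
                     const phi::DenseTensor& in,
                     phi::DenseTensor* out,
                     float scale,
                     gpuStream_t stream) {
  ncclRedOp_t op;
  // Build the pre-mul-sum operator; the scalar lives on the host and is read
  // immediately (ncclScalarHostImmediate).
  comm_ctx->RedOpCreatePreMulSum(
      &op, &scale, ncclFloat32, ncclScalarHostImmediate);
  // AllReduce overload taking a custom ncclRedOp_t is assumed here.
  comm_ctx->AllReduce(out, in, op, stream);
  // The operator is bound to this communicator and must be released once done.
  comm_ctx->RedOpDestroy(op);
}
#endif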