Skip to content

Commit

Permalink
[Dygraph] Refactoring of reducer in DataParallel (#40389)
Browse files Browse the repository at this point in the history
* refactor reducer

* modify cmakelists

* solve conflicts

* rename group and update process_group

* fix bugs of ProcessGroupNCCL

* modify for CIs

* refactoring reducer
  • Loading branch information
haohongxiang authored Mar 15, 2022
1 parent af6ef88 commit 1a32391
Show file tree
Hide file tree
Showing 9 changed files with 736 additions and 23 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/collective/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api)
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api)

if (WITH_DISTRIBUTE)
cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper)
endif()
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup)

if(WITH_NCCL)
cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api)
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ void SyncDefaultStream(
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
ncclEvents[i].Record(*dev_ctx[i]);
ncclEvents[i].Block(*default_ctx);
ncclEvents[i].Record(*default_ctx);
ncclEvents[i].Block(*dev_ctx[i]);
}
}

Expand Down
Loading

0 comments on commit 1a32391

Please sign in to comment.