diff --git a/paddle/fluid/distributed/collective/process_group.cc b/paddle/fluid/distributed/collective/process_group.cc index 279e06ebb0faa..f151c041c7412 100644 --- a/paddle/fluid/distributed/collective/process_group.cc +++ b/paddle/fluid/distributed/collective/process_group.cc @@ -28,6 +28,12 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid) auto map = ProcessGroupMapFromGid::getInstance(); map->insert(gid_, this); } + const char* global_rank = std::getenv("PADDLE_TRAINER_ID"); + PADDLE_ENFORCE_NOT_NULL( + global_rank, + phi::errors::NotFound( + "The environment variable 'PADDLE_TRAINER_ID' cannot be found.")); + global_rank_ = std::atoi(global_rank); } // TODO(sunyilun): methods below will be removed later diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h index 8767dfa60cf18..e2b31950bd51b 100644 --- a/paddle/fluid/distributed/collective/process_group.h +++ b/paddle/fluid/distributed/collective/process_group.h @@ -490,6 +490,7 @@ class ProcessGroup { } protected: + int global_rank_{-1}; int rank_; int size_; int gid_; diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 2a12bb764f7a3..7733e217f757e 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -13,20 +13,19 @@ // limitations under the License. 
#include "paddle/fluid/distributed/collective/process_group_nccl.h" - #include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" #include "paddle/phi/core/distributed/check/static_check.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/comm_task_manager.h" #include "paddle/phi/core/distributed/nccl_comm_task.h" #include "paddle/phi/core/distributed/nccl_tools.h" -#include "paddle/phi/core/distributed/trace_utils.h" #include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/flags.h" @@ -46,8 +45,6 @@ namespace paddle { namespace distributed { using phi::distributed::CheckSizeOnEachRank; -using phi::distributed::GetTraceEndKey; -using phi::distributed::GetTraceStartKey; using phi::distributed::IsP2POP; using phi::distributed::NCCLDTypeToString; using phi::distributed::NCCLRedTypeToString; @@ -119,6 +116,13 @@ ProcessGroupNCCL::ProcessGroupNCCL( pg_timeout_(timeout) { LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_; } +ProcessGroupNCCL::~ProcessGroupNCCL() { + LOG(INFO) << "ProcessGroupNCCL destruct "; + if (FLAGS_enable_async_trace) { + auto& comm_task_manager = phi::distributed::CommTaskManager::GetInstance(); + comm_task_manager.Stop(); + } +} void ProcessGroupNCCL::GroupStart() { NCCL_CHECK(phi::dynload::ncclGroupStart()); @@ -674,6 +678,7 @@ void ProcessGroupNCCL::GetStoreKey(const std::string& place_key, } else { *store_key = "nccl_ids/" + std::to_string(gid_) + "/" + place_key; } + place_to_group_key_[place_key] = *store_key; } void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, @@ 
-711,6 +716,50 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, auto comm_ctx = std::make_unique(place); comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm()); + if (FLAGS_enable_async_trace) { + // gather global ranks in current group + size_t gpu_global_rank_size = sizeof(int); + auto gpu_global_rank = phi::memory_utils::Alloc( + phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId()), + gpu_global_rank_size); + + phi::memory_utils::Copy(phi::GPUPlace(), + gpu_global_rank->ptr(), + phi::CPUPlace(), + &global_rank_, + gpu_global_rank_size); + + size_t gpu_global_ranks_size = num_ranks * sizeof(int); + auto gpu_global_ranks = phi::memory_utils::Alloc( + phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId()), + gpu_global_ranks_size); + + NCCL_CHECK(phi::dynload::ncclAllGather(gpu_global_rank->ptr(), + gpu_global_ranks->ptr(), + 1, + ncclInt, + nccl_comm_ctx->GetNcclComm(), + comm_ctx->stream())); + + std::vector global_ranks(num_ranks); + phi::memory_utils::Copy(phi::CPUPlace(), + global_ranks.data(), + phi::GPUPlace(), + gpu_global_ranks->ptr(), + gpu_global_ranks_size); + + // store global_ranks in current group_key + std::once_flag flag; + std::call_once(flag, [this]() { + phi::distributed::CommContextManager::GetInstance().SetStore(store_); + phi::distributed::CommTaskManager::GetInstance().SetTimeout(pg_timeout_); + }); + + std::string group_key = place_to_group_key_.at(place_key); + phi::distributed::CommContextManager::GetInstance().AddGroupRanks( + group_key, global_ranks); + } + auto* calc_ctx = static_cast( platform::DeviceContextPool::Instance().Get(place)); @@ -771,8 +820,10 @@ std::shared_ptr ProcessGroupNCCL::Collective( if (!FLAGS_enable_async_trace) { fn(nccl_comm_ctx, nccl_stream); } else { + std::string group_key = place_to_group_key_.at(key); auto comm_task = std::make_shared(place, + group_key, rank_, size_, gid_, @@ -837,16 +888,19 @@ std::shared_ptr ProcessGroupNCCL::Point2Point( bool is_batch_p2p = s_group_call_counter > 0; 
std::string key = ""; + int p2p_nrank = 0; if (is_batch_p2p) { key = GetKeyFromPlace(place); p2p_rank = rank_; p2p_target_rank = peer; + p2p_nrank = GetSize(); } else { int low_rank = rank_ < peer ? rank_ : peer; int high_rank = rank_ < peer ? peer : rank_; key = std::to_string(low_rank) + "->" + std::to_string(high_rank); p2p_rank = rank_ < peer ? 0 : 1; p2p_target_rank = 1 - p2p_rank; + p2p_nrank = 2; } platform::CUDADeviceGuard cuda_guard(place); @@ -857,6 +911,10 @@ std::shared_ptr ProcessGroupNCCL::Point2Point( if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { CreateNCCLEnvCache(place, key, store_key, comm_type, p2p_rank); } + if (p2p_comm_seq_.find(key) == p2p_comm_seq_.end()) { + p2p_comm_seq_[key] = 0; + } + p2p_comm_seq_[key]++; if (!use_calc_stream) { SyncCalcStream(place, key); @@ -869,18 +927,21 @@ std::shared_ptr ProcessGroupNCCL::Point2Point( auto nccl_comm = comm_ctx->nccl_comm(); auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); + std::string group_key = place_to_group_key_.at(key); auto comm_task = std::make_shared(place, - rank_, - size_, + group_key, + p2p_rank, + p2p_nrank, gid_, - comm_seq_, + p2p_comm_seq_[key], tensor_tmp.numel(), sync_op, use_calc_stream, nccl_comm, nccl_stream, - comm_type); + comm_type, + pg_timeout_); auto nccl_comm_ctx = this->GetCommContext(&store_key); diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index 96c907e622b17..f923f1ddbdbf8 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -79,6 +79,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { int size, int gid, int64_t timeout = 30 * 60 * 1000); + ~ProcessGroupNCCL(); std::string GetBackendName() const override { return "NCCL"; } @@ -220,6 +221,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { place_to_comm_ctx_; uint64_t 
comm_seq_{0}; + std::unordered_map p2p_comm_seq_; + std::unordered_map place_to_group_key_; // TODO(sunyilun): attrs below will be removed later std::mutex mutex_; diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 2a5b336f34e25..e728681f16251 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -208,5 +208,32 @@ bool CommContextManager::Has(const std::string& unique_comm_key) const { return id_to_comm_context_.find(unique_comm_key) != id_to_comm_context_.end(); } +// Records the size of the process group identified by pg_key, overwriting any +// value stored earlier for the same key. +void CommContextManager::SetGroupSize(const std::string& pg_key, int size) { + pg_key_size_[pg_key] = size; +} + +// Registers the global ranks of the process group identified by pg_key. The +// first registration wins: a later call with the same key changes nothing. +void CommContextManager::AddGroupRanks(const std::string& pg_key, + std::vector global_ranks) { + // emplace() inserts only when pg_key is absent and needs a single lookup. + pg_key_ranks_.emplace(pg_key, global_ranks); +} + +// Returns a copy of the global ranks registered for pg_key. The enforce check +// fails with a NotFound error when pg_key has not been registered. +std::vector CommContextManager::GetGroupRanks( + const std::string& pg_key) const { + auto iter = pg_key_ranks_.find(pg_key); + // pg_key is a std::string, hence %s (the previous %d did not match). + PADDLE_ENFORCE_NE( + iter, + pg_key_ranks_.end(), + errors::NotFound("Can not find pg_key %s in GroupRanks.", pg_key)); + return iter->second; +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index 2229786db3855..132f9e4f52cd1 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "paddle/phi/common/place.h" #include "paddle/phi/core/distributed/comm_context.h" @@ -64,6 +65,12 @@ class CommContextManager { static void SetDeviceId(int dev_id); + void SetGroupSize(const std::string& pg_key, int size); + + void AddGroupRanks(const std::string& pg_key, std::vector global_ranks); + + std::vector GetGroupRanks(const std::string& pg_key) const; + #if defined(PADDLE_WITH_NCCL) ||
defined(PADDLE_WITH_RCCL) static void CreateNCCLCommContext(const std::shared_ptr& store, const std::string& unique_comm_key, @@ -96,6 +103,11 @@ class CommContextManager { id_to_comm_context_; std::shared_ptr store_; static int device_id; + + // process group key to global ranks map + std::unordered_map> pg_key_ranks_; + // process group key to group size map + std::unordered_map pg_key_size_; }; } // namespace distributed diff --git a/paddle/phi/core/distributed/comm_task.h b/paddle/phi/core/distributed/comm_task.h index 3673c7a9e21aa..05560eb67dafc 100644 --- a/paddle/phi/core/distributed/comm_task.h +++ b/paddle/phi/core/distributed/comm_task.h @@ -37,6 +37,7 @@ class CommTask { public: CommTask(const std::string& backend = "", const phi::Place& place = phi::Place(), + const std::string& group_key = "", int rank = -1, int size = 0, int gid = 0, @@ -47,6 +48,7 @@ class CommTask { CommType comm_type = CommType::UNKNOWN) : backend_(backend), place_(place), + group_key_(group_key), rank_(rank), size_(size), gid_(gid), @@ -65,10 +67,11 @@ class CommTask { virtual ~CommTask() = default; std::string UniqueKey() { - return "op:" + CommTypeToString(comm_type_) + + return "group_key:" + group_key_ + ",op:" + CommTypeToString(comm_type_) + ",gid:" + std::to_string(gid_) + ",seq:" + std::to_string(seq_); } + std::string GroupKey() { return group_key_; } std::string GetBackend() { return backend_; } phi::Place GetPlace() { return place_; } int GetGlobalRank() { return global_rank_; } @@ -105,6 +108,12 @@ class CommTask { return; } + virtual void ClearRecord() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + virtual std::string GetCommErrors() { PADDLE_THROW( phi::errors::Unimplemented("%s is not implemented.", __func__)); @@ -125,6 +134,16 @@ class CommTask { phi::errors::Unimplemented("%s is not implemented.", __func__)); return false; } + virtual void SetUpdated(bool updated) { + PADDLE_THROW( + 
phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + virtual bool IsUpdated() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return false; + } virtual void AbortComm() { PADDLE_THROW( phi::errors::Unimplemented("%s is not implemented.", __func__)); @@ -134,6 +153,7 @@ class CommTask { protected: std::string backend_; phi::Place place_; + std::string group_key_; int global_rank_; int rank_; int size_; @@ -145,7 +165,11 @@ class CommTask { CommType comm_type_; bool start_trace_updated_{false}; + // task status + bool started_ = false; bool completed_ = false; + // task status changed + bool updated_ = true; bool aborted_{false}; std::chrono::time_point start_time_; std::shared_ptr store_; diff --git a/paddle/phi/core/distributed/comm_task_manager.cc b/paddle/phi/core/distributed/comm_task_manager.cc index 37083119b59f5..ae7de42291358 100644 --- a/paddle/phi/core/distributed/comm_task_manager.cc +++ b/paddle/phi/core/distributed/comm_task_manager.cc @@ -22,6 +22,7 @@ #include "paddle/phi/core/distributed/comm_context_manager.h" +#include #include #include @@ -34,35 +35,51 @@ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/phi/core/distributed/comm_task_manager.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" -#include "paddle/phi/core/distributed/trace_utils.h" #endif namespace phi { namespace distributed { std::thread CommTaskManager::comm_task_loop_thread_; +std::thread CommTaskManager::comm_task_clear_loop_thread_; const int64_t CommTaskManager::loop_thread_sleep_millis = 10000; std::atomic CommTaskManager::terminated_; std::mutex CommTaskManager::comm_task_list_mutex_; std::condition_variable CommTaskManager::comm_task_list_cv_; std::list> CommTaskManager::comm_task_list_; + +std::mutex CommTaskManager::comm_task_clear_list_mutex_; +std::condition_variable CommTaskManager::comm_task_clear_list_cv_; +std::list> CommTaskManager::comm_task_clear_list_; + 
std::unordered_map> CommTaskManager::init_comm_task_map_; std::unordered_map> CommTaskManager::start_comm_task_map_; +std::unordered_map> + CommTaskManager::group_last_comm_task_; +std::chrono::time_point + CommTaskManager::last_update_time_ = std::chrono::steady_clock::now(); CommTaskManager::CommTaskManager() { terminated_.store(false); comm_task_loop_thread_ = std::thread(&CommTaskManager::CommTaskLoop, this); - LOG(INFO) << "CommTaskManager init success"; + comm_task_clear_loop_thread_ = + std::thread(&CommTaskManager::CommTaskClearLoop, this); + LOG(INFO) << "CommTaskManager init success."; } CommTaskManager::~CommTaskManager() { terminated_.store(true); if (comm_task_loop_thread_.joinable()) { - comm_task_loop_thread_.join(); comm_task_list_cv_.notify_one(); + comm_task_loop_thread_.join(); + } + + if (comm_task_clear_loop_thread_.joinable()) { + comm_task_clear_list_cv_.notify_one(); + comm_task_clear_loop_thread_.join(); } LOG(INFO) << "CommTaskManager destruct success."; } @@ -74,33 +91,106 @@ void CommTaskManager::CommTaskEnqueue(std::shared_ptr comm_task) { } } +void CommTaskManager::CommTaskClearEnqueue( + std::shared_ptr comm_task) { + if (!terminated_.load()) { + std::lock_guard lock(comm_task_clear_list_mutex_); + comm_task_clear_list_.emplace_back(comm_task); + } +} + +void CommTaskManager::Stop() { + terminated_.store(true); + + LOG(INFO) << "CommTaskManager stopped begin."; + if (comm_task_loop_thread_.joinable()) { + comm_task_list_cv_.notify_one(); + comm_task_loop_thread_.join(); + } + + if (comm_task_clear_loop_thread_.joinable()) { + comm_task_clear_list_cv_.notify_one(); + comm_task_clear_loop_thread_.join(); + } + + LOG(INFO) << "CommTaskManager stopped."; +} + +inline void LogLongStr(const std::string prefix, const std::string& log) { + size_t max_log_size = 20000; + if (log.size() >= max_log_size) { + int log_count = log.size() / max_log_size + 1; + int index = 0; + int part = 0; + while (index + max_log_size < log.size()) { + LOG(INFO) 
<< prefix << "part:" << part << "/" << log_count << "," + << log.substr(index, max_log_size) << std::endl; + index += max_log_size; + part++; + } + LOG(INFO) << prefix << "part:" << part << "/" << log_count << "," + << log.substr(index) << std::endl; + } else { + LOG(INFO) << prefix << "part:0/1," << log << std::endl; + } +} + void CommTaskManager::CommTaskLoop() { bool done = false; while (!terminated_.load() || !done) { std::unique_lock lock(comm_task_list_mutex_); + VLOG(3) << "IsTimeout: " << IsTimeout() + << ", comm_task_list_ size: " << comm_task_list_.size() + << ", init_comm_task_map_ size: " << init_comm_task_map_.size() + << ", start_comm_task_map_ size: " << start_comm_task_map_.size() + << ", logged_ " << logged_; + comm_task_list_cv_.wait_for( lock, std::chrono::milliseconds(loop_thread_sleep_millis), [&]() -> bool { return terminated_.load(); }); + + if (IsTimeout() && !logged_) { + // case 1: all group is empty, has no task + // report error immediately + if (group_last_comm_task_.empty()) { + LOG(WARNING) << "Find no task started in all group"; + } else { + // case 2: all group is not empty, but all last task is completed + // case 3: all group is not empty, some group task started but not + for (const auto& iter : group_last_comm_task_) { + LogLongStr("Find last group comm task:", iter.second->GetTraceMsg()); + } + } + logged_ = true; + } for (auto iter = comm_task_list_.begin(); iter != comm_task_list_.end();) { auto task = *iter; if (task->IsTimeout()) { if (!task->IsStarted()) { - LOG(ERROR) << "Find timeout init but not start task: " - << task->GetTraceMsg() << ",comm:" << task->nccl_comm() - << ",stream:" << task->nccl_stream(); + LOG(WARNING) << "Find timeout init but not start task:" + << task->GetTraceMsg(); std::string task_key = task->UniqueKey(); init_comm_task_map_[task_key] = task; } else if (!task->IsCompleted()) { - LOG(ERROR) << "Find timeout start but not finish task: " - << task->GetTraceMsg() << ",comm:" << task->nccl_comm() - <<
",stream:" << task->nccl_stream(); + LOG(WARNING) << "Find timeout start but not finish task:" + << task->GetTraceMsg(); std::string task_key = task->UniqueKey(); start_comm_task_map_[task_key] = task; } iter = comm_task_list_.erase(iter); } else { - ++iter; + if (task->IsStarted()) { + if (task->IsCompleted()) { + CommTaskClearEnqueue(task); + iter = comm_task_list_.erase(iter); + } else { + ++iter; + } + UpdateLastCommTask(task); + } else { + ++iter; + } } } @@ -121,6 +211,8 @@ void CommTaskManager::CommTaskLoop() { iter != start_comm_task_map_.end();) { auto task = iter->second; if (task->IsCompleted()) { + CommTaskClearEnqueue(task); + UpdateLastCommTask(task); iter = start_comm_task_map_.erase(iter); LOG(INFO) << "Finish timeout task: " << task->GetTraceMsg(); } else { @@ -131,9 +223,63 @@ void CommTaskManager::CommTaskLoop() { if (comm_task_list_.empty() && init_comm_task_map_.empty() && start_comm_task_map_.empty()) { done = true; + } else { + done = false; } } } +// Background loop that calls ClearRecord() on every task queued through +// CommTaskClearEnqueue(). Each call runs on a helper thread and is given 30 +// seconds; on timeout the task stays queued and is retried on a later pass. +void CommTaskManager::CommTaskClearLoop() { + std::future future; + while (!terminated_.load()) { + if (future.valid()) { + future.wait(); + } + std::unique_lock lock(comm_task_clear_list_mutex_); + comm_task_clear_list_cv_.wait_for( + lock, + std::chrono::milliseconds(loop_thread_sleep_millis), + [&]() -> bool { return terminated_.load(); }); + + VLOG(3) << "comm_task_clear_list_ size: " << comm_task_clear_list_.size(); + for (auto iter = comm_task_clear_list_.begin(); + iter != comm_task_clear_list_.end();) { + auto task = *iter; + VLOG(3) << "start clear task: " << task->GetTraceMsg(); + // Capture the shared_ptr by value: on timeout the loop breaks and this + // iteration's local `task` is destroyed while the helper may still run. + future = std::async(std::launch::async, [task]() { task->ClearRecord(); }); + if (future.wait_for(std::chrono::seconds(30)) == + std::future_status::timeout) { + VLOG(0) << "clear task timeout, detail: " << task->GetTraceMsg(); + break; + } + VLOG(3) << "end clear task: " << task->GetTraceMsg(); + iter = comm_task_clear_list_.erase(iter); + } + } +} + +void CommTaskManager::UpdateLastCommTask(std::shared_ptr task)
{ + if (!task->IsUpdated()) { + return; + } + group_last_comm_task_[task->GroupKey()] = task; + last_update_time_ = std::chrono::steady_clock::now(); + task->SetUpdated(false); +} + +void CommTaskManager::SetTimeout(int64_t timeout) { + timeout_ = std::chrono::milliseconds(timeout); +} + +bool CommTaskManager::IsTimeout() { + auto current_timepoint = std::chrono::steady_clock::now(); + return std::chrono::duration_cast( + current_timepoint - last_update_time_) >= timeout_; +} } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/comm_task_manager.h b/paddle/phi/core/distributed/comm_task_manager.h index 58be0026dd072..bb739d5c6afdb 100644 --- a/paddle/phi/core/distributed/comm_task_manager.h +++ b/paddle/phi/core/distributed/comm_task_manager.h @@ -46,11 +46,18 @@ class CommTaskManager { } void CommTaskEnqueue(std::shared_ptr comm_task); + void CommTaskClearEnqueue(std::shared_ptr comm_task); + void Stop(); + void UpdateLastCommTask(std::shared_ptr comm_task); + void SetTimeout(int64_t timeout); private: void CommTaskLoop(); + void CommTaskClearLoop(); + bool IsTimeout(); static std::thread comm_task_loop_thread_; + static std::thread comm_task_clear_loop_thread_; static const int64_t loop_thread_sleep_millis; static std::atomic terminated_; @@ -58,6 +65,11 @@ class CommTaskManager { static std::mutex comm_task_list_mutex_; static std::condition_variable comm_task_list_cv_; static std::list> comm_task_list_; + + static std::mutex comm_task_clear_list_mutex_; + static std::condition_variable comm_task_clear_list_cv_; + static std::list> comm_task_clear_list_; + // not start task static std::unordered_map> init_comm_task_map_; @@ -65,7 +77,12 @@ class CommTaskManager { static std::unordered_map> start_comm_task_map_; std::shared_ptr store_; - bool store_error_{false}; + // record last comm task in current group, eg: group_key->comm_task + static std::unordered_map> + group_last_comm_task_; + static std::chrono::time_point 
last_update_time_; + std::chrono::milliseconds timeout_; + bool logged_ = false; }; } // namespace distributed diff --git a/paddle/phi/core/distributed/nccl_comm_task.cc b/paddle/phi/core/distributed/nccl_comm_task.cc index f82f39c1954a3..6bc002627a023 100644 --- a/paddle/phi/core/distributed/nccl_comm_task.cc +++ b/paddle/phi/core/distributed/nccl_comm_task.cc @@ -15,17 +15,17 @@ #include "paddle/phi/core/distributed/nccl_comm_task.h" #include "gflags/gflags.h" -#include "glog/logging.h" #include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_tools.h" -#include "paddle/phi/core/distributed/trace_utils.h" #include "paddle/phi/core/utils/data_type.h" namespace phi { namespace distributed { NCCLCommTask::NCCLCommTask(const phi::Place& place, + const std::string& group_key, int rank, int size, int gid, @@ -39,6 +39,7 @@ NCCLCommTask::NCCLCommTask(const phi::Place& place, int64_t timeout) : CommTask("NCCL", place, + group_key, rank, size, gid, @@ -89,7 +90,20 @@ void NCCLCommTask::EndRecord() { #endif } -bool NCCLCommTask::CudaEventQuery(gpuEvent_t event) { +void NCCLCommTask::ClearRecord() { + if (start_event_created_) { + backends::gpu::GPUDeviceGuard guard(place_.device); + CUDA_CHECK(cudaEventDestroy(nccl_start_event_)); + start_event_created_ = false; + } + if (end_event_created_) { + backends::gpu::GPUDeviceGuard guard(place_.device); + CUDA_CHECK(cudaEventDestroy(nccl_end_event_)); + end_event_created_ = false; + } +} + +bool NCCLCommTask::CudaEventQuery(cudaEvent_t event) { #ifdef PADDLE_WITH_CUDA cudaError_t ret = cudaEventQuery(event); if (ret == cudaSuccess) { @@ -175,9 +189,31 @@ std::string NCCLCommTask::GetCommErrors() { return comm_error_; } -bool NCCLCommTask::IsStarted() { return CudaEventQuery(nccl_start_event_); } +bool NCCLCommTask::IsStarted() { + if (started_) { + return true; + } + if (start_event_created_ && CudaEventQuery(nccl_start_event_)) { + 
started_ = true; + updated_ = true; + } + return started_; +} + +bool NCCLCommTask::IsCompleted() { + if (completed_) { + return true; + } + if (end_event_created_ && CudaEventQuery(nccl_end_event_)) { + completed_ = true; + updated_ = true; + } + return completed_; +} + +void NCCLCommTask::SetUpdated(bool updated) { updated_ = updated; } -bool NCCLCommTask::IsCompleted() { return CudaEventQuery(nccl_end_event_); } +bool NCCLCommTask::IsUpdated() { return updated_; } bool NCCLCommTask::IsTimeout() { auto current_timepoint = std::chrono::steady_clock::now(); @@ -201,18 +237,19 @@ std::string NCCLCommTask::GetTraceMsg() { auto current_timepoint = std::chrono::steady_clock::now(); auto time_elapsed = std::chrono::duration_cast( current_timepoint - start_time_); - return "op:" + CommTypeToString(comm_type_) + ",gid:" + std::to_string(gid_) + - ",seq:" + std::to_string(seq_) + - ",started:" + std::to_string(IsStarted()) + - ",completed:" + std::to_string(IsCompleted()) + + auto global_ranks = + phi::distributed::CommContextManager::GetInstance().GetGroupRanks( + group_key_); + return "group_key:" + group_key_ + + ",group_ranks:" + VectorToString(global_ranks) + ",global_rank:" + std::to_string(global_rank_) + ",local_rank:" + std::to_string(rank_) + - ",size:" + std::to_string(size_) + ",numel:" + std::to_string(numel_) + - ",sync_op:" + std::to_string(sync_op_) + - ",use_calc_stream:" + std::to_string(use_calc_stream_) + - ",timeout:" + std::to_string(timeout_.count()) + - ",is_timeout:" + std::to_string(IsTimeout()) + - ",time_elapsed:" + std::to_string(time_elapsed.count()); + ",comm_count:" + std::to_string(seq_) + + ",op:" + CommTypeToString(comm_type_) + + ",started:" + std::to_string(started_) + + ",completed:" + std::to_string(completed_) + + ",numel:" + std::to_string(numel_) + + ",nranks:" + std::to_string(size_); } } // namespace distributed diff --git a/paddle/phi/core/distributed/nccl_comm_task.h b/paddle/phi/core/distributed/nccl_comm_task.h index 
9fe71670c2f88..f9a8f3c250922 100644 --- a/paddle/phi/core/distributed/nccl_comm_task.h +++ b/paddle/phi/core/distributed/nccl_comm_task.h @@ -34,6 +34,7 @@ static int64_t DefaultTimeout = 30 * 60 * 1000; class NCCLCommTask : public CommTask { public: NCCLCommTask(const phi::Place& place = phi::Place(), + const std::string& group_key = "", int rank = -1, int size = 0, int gid = 0, @@ -51,6 +52,8 @@ class NCCLCommTask : public CommTask { bool IsStarted() override; bool IsTimeout() override; bool IsCompleted() override; + void SetUpdated(bool updated) override; + bool IsUpdated() override; std::string GetTraceMsg() override; std::string GetCommErrors() override; @@ -58,6 +61,7 @@ class NCCLCommTask : public CommTask { void StartRecord(); void EndRecord(); + void ClearRecord() override; bool CudaEventQuery(gpuEvent_t event); diff --git a/paddle/phi/core/distributed/trace_utils.h b/paddle/phi/core/distributed/trace_utils.h deleted file mode 100644 index 7a34055a987bc..0000000000000 --- a/paddle/phi/core/distributed/trace_utils.h +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/phi/core/distributed/store/store.h" -#include "paddle/utils/string/split.h" - -namespace phi { -namespace distributed { - -enum TraceEventType { - TraceEventStart, - TraceEventEnd, -}; - -using TraceMap = - std::map>>; - -inline std::string GetTraceStartKey(const std::string& backend, - int rank, - int gid) { - return backend + "_" + std::to_string(rank) + "_" + std::to_string(gid) + - "_trace_start"; -} - -inline std::string GetTraceEndKey(const std::string& backend, - int rank, - int gid) { - return backend + "_" + std::to_string(rank) + "_" + std::to_string(gid) + - "_trace_end"; -} - -inline std::string GetExceptionMsgFromExceptionPtr( - const std::exception_ptr& exception_ptr) { - if (exception_ptr == nullptr) { - return "No exception found"; - } - try { - std::rethrow_exception(exception_ptr); - } catch (const std::exception& e) { - return e.what(); - } catch (...) { - return "Unknown exception type"; - } -} - -inline bool UpdateTraceMsg(std::shared_ptr store, - const std::string& key, - uint64_t seq, - const std::string& comm_type) { - std::vector value(comm_type.size() + sizeof(seq) + 1); - memcpy(value.data(), &seq, sizeof(seq)); - memcpy(value.data() + sizeof(seq), comm_type.data(), comm_type.size()); - try { - store->set(key, value); - return true; - } catch (...) { - LOG(ERROR) << "Store is down while updating trace msg, with seq: " << seq - << ", key " << key; - return false; - } -} - -inline bool ParseTraceValue(std::shared_ptr store, - const std::string& key, - uint64_t* seq, - std::string* comm_type) { - try { - std::vector value = store->get(key); - memcpy(seq, value.data(), sizeof(*seq)); - std::string type_value( - reinterpret_cast(value.data() + sizeof(*seq))); - *comm_type = type_value; - return true; - } catch (...) 
{ - LOG(ERROR) << "Store is down while parsing trace value, with key: " << key; - return false; - } -} - -inline std::string RanksToString(const std::vector& ranks) { - std::string result; - for (int rank : ranks) { - if (result.empty()) { - result += std::to_string(rank); - } else { - result += ", " + std::to_string(rank); - } - } - return result; -} - -inline std::string AnalyzeTraceMsg(const TraceMap& trace_map, int gid) { - uint64_t lag_seq = trace_map.begin()->first; - std::vector start_ranks; - std::vector end_ranks; - for (auto& p : trace_map.begin()->second) { - if (p.second.second == TraceEventStart) { - start_ranks.emplace_back(p.first); - } else { - end_ranks.emplace_back(p.first); - } - } - - std::string result = "\n\t The ranks that has desync problem are: "; - if (start_ranks.size()) { - result += "[" + RanksToString(start_ranks) + - "] joined but do not finish collective seq: " + - std::to_string(lag_seq) + " in group_id: " + std::to_string(gid); - } - if (end_ranks.size()) { - result += ", ranks [" + RanksToString(end_ranks) + - "] finished collective seq: " + std::to_string(lag_seq) + - ", but didnt join seq: " + std::to_string(lag_seq + 1) + - " in group_id: " + std::to_string(gid); - } - return result; -} - -inline std::string GenerateTraceMsg(std::shared_ptr store, - const std::string& backend, - int curr_rank, - int group_id, - int world_size) { - std::string result; - TraceMap trace_map; - - uint64_t curr_seq; - std::string curr_comm_type; - - for (int rank = 0; rank < world_size; ++rank) { - uint64_t seq_start = 0; - { - std::string trace_start_key = GetTraceStartKey(backend, rank, group_id); - if (!store->check(trace_start_key)) { - continue; - } - - std::string comm_type; - if (!ParseTraceValue(store, trace_start_key, &seq_start, &comm_type)) { - return result; - } - trace_map[seq_start].emplace(rank, - std::make_pair(comm_type, TraceEventStart)); - if (rank == curr_rank) { - curr_seq = seq_start; - curr_comm_type = std::move(comm_type); - 
} - } - { - std::string trace_end_key = GetTraceEndKey(backend, rank, group_id); - if (!store->check(trace_end_key)) { - continue; - } - - uint64_t seq = 0; - std::string comm_type; - if (!ParseTraceValue(store, trace_end_key, &seq, &comm_type)) { - return result; - } - if (seq == seq_start) { - trace_map[seq][rank].second = TraceEventEnd; - } - } - } - result += "\n\t Problem summary: rank: " + std::to_string(curr_rank) + - " timeout at collective: " + curr_comm_type + - ", group_id: " + std::to_string(group_id) + - ", seq: " + std::to_string(curr_seq); - result += AnalyzeTraceMsg(trace_map, group_id); - return result; -} - -} // namespace distributed -} // namespace phi diff --git a/paddle/phi/core/distributed/utils.h b/paddle/phi/core/distributed/utils.h index 40b28bb2a3e6f..79cd1861da9dd 100644 --- a/paddle/phi/core/distributed/utils.h +++ b/paddle/phi/core/distributed/utils.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -141,5 +142,42 @@ inline std::string CommTypeToString(CommType CommType) { return "Unknown"; } +// Convert a list of ranks to a compact string. Each run of consecutive values +// is written as `first~last` (a single value as itself) and runs are joined +// with `#`, e.g. [1,2,3,4,5,7,8,9] => 1~5#7~9. An empty vector yields "". +inline std::string VectorToString(const std::vector& vec) { + if (vec.empty()) { + return ""; + } + + std::stringstream ss; + // Track the separator with an explicit flag: whether in_avail() observes + // freshly written characters is not guaranteed by the standard. + bool first_run = true; + auto append_run = [&ss, &first_run](int first, int last) { + if (!first_run) { + ss << "#"; + } + first_run = false; + ss << first; + if (first != last) { + ss << "~" << last; + } + }; + + int start_rank = vec[0]; + for (size_t i = 0; i + 1 < vec.size(); ++i) { + if (vec[i] + 1 == vec[i + 1]) { + continue; + } + // vec[i] closes the current run; the next run starts at vec[i + 1]. + append_run(start_rank, vec[i]); + start_rank = vec[i + 1]; + } + append_run(start_rank, vec.back()); + + return ss.str(); +} + } // namespace distributed } // namespace phi diff --git
a/test/collective/test_collective_allgather_api.py b/test/collective/test_collective_allgather_api.py index f53165d3fbd96..2edb3540b16fe 100644 --- a/test/collective/test_collective_allgather_api.py +++ b/test/collective/test_collective_allgather_api.py @@ -149,6 +149,20 @@ def test_allgather_nccl_dygraph(self): dtype=dtype, ) + def test_allgather_nccl_dygraph_with_trace_hang(self): + dtypes_to_test = [ + "float32", + ] + for dtype in dtypes_to_test: + self.check_with_place( + "collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype=dtype, + need_envs={"FLAGS_enable_async_trace": "True"}, + ) + def test_allgather_gloo_dygraph(self): dtypes_to_test = [ "float16",