Skip to content

Commit

Permalink
Fix trace hang (#57536)
Browse files Browse the repository at this point in the history
* fix trace hang

* fix compile error

* fix code style

* tinyfix

* tiny update

* fix code style

---------

Co-authored-by: ForFishes <1422485404@qq.com>
  • Loading branch information
wentaoyu and ForFishes authored Sep 21, 2023
1 parent 7f3a363 commit 798b1d4
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 177 deletions.
8 changes: 3 additions & 5 deletions paddle/fluid/distributed/collective/process_group_nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
<< ", use_calc_stream: " << use_calc_stream << ", "
<< GetGroupMessage();

int64_t numel = in_tensor.numel();
NCCL_CHECK(
phi::dynload::ncclAllReduce(in_tensor.data(),
out_tensor->data(),
// in_tensor.numel(),
numel,
in_tensor.numel(),
phi::ToNCCLDataType(in_tensor.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
Expand Down Expand Up @@ -895,7 +893,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
fn(nccl_comm, nccl_stream);
} else {
auto comm_task =
std::make_unique<phi::distributed::NCCLCommTask>(place,
std::make_shared<phi::distributed::NCCLCommTask>(place,
rank_,
size_,
gid_,
Expand Down Expand Up @@ -982,7 +980,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Point2Point(
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();

auto comm_task =
std::make_unique<phi::distributed::NCCLCommTask>(place,
std::make_shared<phi::distributed::NCCLCommTask>(place,
rank_,
size_,
gid_,
Expand Down
43 changes: 21 additions & 22 deletions paddle/phi/core/distributed/comm_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"

#if defined(PADDLE_WITH_RCCL)
#include "paddle/phi/backends/dynload/rccl.h"
#else
#include "paddle/phi/backends/dynload/nccl.h"
#endif

namespace phi {
namespace distributed {

Expand All @@ -36,6 +42,8 @@ class CommTask {
int gid = 0,
uint64_t seq = 0,
int64_t numel = 0,
ncclComm_t nccl_comm = nullptr,
gpuStream_t nccl_stream = nullptr,
CommType comm_type = CommType::UNKNOWN)
: backend_(backend),
place_(place),
Expand All @@ -44,6 +52,8 @@ class CommTask {
gid_(gid),
seq_(seq),
numel_(numel),
nccl_comm_(nccl_comm),
nccl_stream_(nccl_stream),
comm_type_(comm_type) {
const char* global_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
Expand All @@ -54,6 +64,10 @@ class CommTask {
}
virtual ~CommTask() = default;

// Builds an identifier string of the form "op:<comm type>,gid:<gid>,seq:<seq>"
// from this task's communication type, group id and sequence number.
std::string UniqueKey() {
  std::string key("op:");
  key += CommTypeToString(comm_type_);
  key += ",gid:";
  key += std::to_string(gid_);
  key += ",seq:";
  key += std::to_string(seq_);
  return key;
}
// Returns the backend name this task was constructed with.
std::string GetBackend() { return backend_; }
// Returns the place (device) this task was constructed with.
phi::Place GetPlace() { return place_; }
// Returns global_rank_ (presumably derived from the PADDLE_TRAINER_ID
// environment variable read in the constructor -- TODO confirm).
int GetGlobalRank() { return global_rank_; }
Expand All @@ -71,6 +85,9 @@ class CommTask {
// Returns a shared handle to the store currently attached to this task.
std::shared_ptr<Store> GetStore() { return store_; }
// Replaces the store attached to this task with `store`.
void SetStore(std::shared_ptr<Store> store) { store_ = store; }

// Returns the NCCL communicator handle passed to the constructor
// (nullptr when the default argument was used).
ncclComm_t nccl_comm() { return nccl_comm_; }
// Returns the GPU stream handle passed to the constructor
// (nullptr when the default argument was used).
gpuStream_t nccl_stream() { return nccl_stream_; }

virtual std::string GetTraceMsg() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
Expand All @@ -87,20 +104,10 @@ class CommTask {
return;
}

virtual void SetException(std::exception_ptr exception) {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return;
}
virtual void CheckAndSetException() {
virtual std::string GetCommErrors() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return;
}
virtual std::exception_ptr CheckCommErrors() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return nullptr;
return "";
}
virtual bool IsStarted() {
PADDLE_THROW(
Expand All @@ -117,16 +124,6 @@ class CommTask {
phi::errors::Unimplemented("%s is not implemented.", __func__));
return false;
}
virtual bool IsSuccess() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return false;
}
virtual std::exception_ptr GetException() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return nullptr;
}
virtual void AbortComm() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
Expand All @@ -142,6 +139,8 @@ class CommTask {
int gid_;
uint64_t seq_{0};
int64_t numel_;
ncclComm_t nccl_comm_;
gpuStream_t nccl_stream_;
CommType comm_type_;
bool start_trace_updated_{false};

Expand Down
111 changes: 49 additions & 62 deletions paddle/phi/core/distributed/comm_task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif

DECLARE_int32(async_trace_count);

namespace phi {
namespace distributed {

Expand All @@ -48,14 +46,16 @@ const int64_t CommTaskManager::loop_thread_sleep_millis = 10000;
std::atomic<bool> CommTaskManager::terminated_;
std::mutex CommTaskManager::comm_task_list_mutex_;
std::condition_variable CommTaskManager::comm_task_list_cv_;
std::list<std::unique_ptr<CommTask>> CommTaskManager::comm_task_list_;
int CommTaskManager::check_timeout_count = 0;
std::list<std::shared_ptr<CommTask>> CommTaskManager::comm_task_list_;
std::unordered_map<std::string, std::shared_ptr<CommTask>>
CommTaskManager::init_comm_task_map_;
std::unordered_map<std::string, std::shared_ptr<CommTask>>
CommTaskManager::start_comm_task_map_;

CommTaskManager::CommTaskManager() {
terminated_.store(false);
comm_task_loop_thread_ = std::thread(&CommTaskManager::CommTaskLoop, this);
LOG(INFO) << "CommTaskManager init success. FLAGS_async_trace_count: "
<< FLAGS_async_trace_count;
LOG(INFO) << "CommTaskManager init success.";
}
CommTaskManager::~CommTaskManager() {
terminated_.store(true);
Expand All @@ -67,7 +67,7 @@ CommTaskManager::~CommTaskManager() {
LOG(INFO) << "CommTaskManager destruct success.";
}

void CommTaskManager::CommTaskEnqueue(std::unique_ptr<CommTask> comm_task) {
void CommTaskManager::CommTaskEnqueue(std::shared_ptr<CommTask> comm_task) {
if (!terminated_.load()) {
std::lock_guard<std::mutex> lock(comm_task_list_mutex_);
comm_task_list_.emplace_back(std::move(comm_task));
Expand All @@ -81,69 +81,56 @@ void CommTaskManager::CommTaskLoop() {
comm_task_list_cv_.wait_for(
lock,
std::chrono::milliseconds(loop_thread_sleep_millis),
[&]() -> bool {
return terminated_.load() &&
check_timeout_count <= FLAGS_async_trace_count;
});
for (auto task = comm_task_list_.begin(); task != comm_task_list_.end();) {
(*task)->CheckAndSetException();
if ((*task)->IsTimeout()) {
std::string exception_msg = (*task)->GetTraceMsg();
exception_msg += GenerateTraceMsg((*task)->GetStore(),
(*task)->GetBackend(),
(*task)->GetRank(),
(*task)->GetGid(),
(*task)->GetSize());
LOG(ERROR) << exception_msg;
std::exception_ptr exception_ptr =
std::make_exception_ptr(std::runtime_error(exception_msg));
(*task)->SetException(exception_ptr);
(*task)->AbortComm();

++check_timeout_count;
[&]() -> bool { return terminated_.load(); });
for (auto iter = comm_task_list_.begin(); iter != comm_task_list_.end();) {
auto task = *iter;
if (task->IsTimeout()) {
if (!task->IsStarted()) {
LOG(ERROR) << "Find timeout init but not start task: "
<< task->GetTraceMsg() << ",comm:" << task->nccl_comm()
<< ",stream:" << task->nccl_stream();
std::string task_key = task->UniqueKey();
init_comm_task_map_[task_key] = task;
} else if (!task->IsCompleted()) {
LOG(ERROR) << "Find timeout start but not finish task: "
<< task->GetTraceMsg() << ",comm:" << task->nccl_comm()
<< ",stream:" << task->nccl_stream();
std::string task_key = task->UniqueKey();
start_comm_task_map_[task_key] = task;
}
iter = comm_task_list_.erase(iter);
} else {
++iter;
}
}

if (!(*task)->GetTraceUpdated() && (*task)->IsStarted() &&
!terminated_.load() && !store_error_) {
std::string trace_key = GetTraceStartKey(
(*task)->GetBackend(), (*task)->GetRank(), (*task)->GetGid());
store_error_ =
!UpdateTraceMsg((*task)->GetStore(),
trace_key,
(*task)->GetSeq(),
CommTypeToString((*task)->GetCommType()));
(*task)->SetTraceUpdated();
for (auto iter = init_comm_task_map_.begin();
iter != init_comm_task_map_.end();) {
auto task = iter->second;
if (task->IsStarted()) {
std::string task_key = task->UniqueKey();
start_comm_task_map_[task_key] = task;
iter = init_comm_task_map_.erase(iter);
LOG(INFO) << "Start timeout task: " << task->GetTraceMsg();
} else {
++iter;
}
}

if ((*task)->IsCompleted()) {
if (!(*task)->GetTraceUpdated() && !terminated_.load() &&
!store_error_) {
std::string trace_key = GetTraceStartKey(
(*task)->GetBackend(), (*task)->GetRank(), (*task)->GetGid());
store_error_ =
!UpdateTraceMsg((*task)->GetStore(),
trace_key,
(*task)->GetSeq(),
CommTypeToString((*task)->GetCommType()));
(*task)->SetTraceUpdated();
}
if (!terminated_.load() && !store_error_) {
std::string trace_key = GetTraceEndKey(
(*task)->GetBackend(), (*task)->GetRank(), (*task)->GetGid());
store_error_ =
!UpdateTraceMsg((*task)->GetStore(),
trace_key,
(*task)->GetSeq(),
CommTypeToString((*task)->GetCommType()));
}
task = comm_task_list_.erase(task);
for (auto iter = start_comm_task_map_.begin();
iter != start_comm_task_map_.end();) {
auto task = iter->second;
if (task->IsCompleted()) {
iter = start_comm_task_map_.erase(iter);
LOG(INFO) << "Finish timeout task: " << task->GetTraceMsg();
} else {
++task;
++iter;
}
}
if (comm_task_list_.empty()) {

if (comm_task_list_.empty() && init_comm_task_map_.empty() &&
start_comm_task_map_.empty()) {
done = true;
check_timeout_count = 0;
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions paddle/phi/core/distributed/comm_task_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CommTaskManager {
return instance;
}

void CommTaskEnqueue(std::unique_ptr<CommTask> comm_task);
void CommTaskEnqueue(std::shared_ptr<CommTask> comm_task);

private:
void CommTaskLoop();
Expand All @@ -58,13 +58,15 @@ class CommTaskManager {

static std::mutex comm_task_list_mutex_;
static std::condition_variable comm_task_list_cv_;
static std::list<std::unique_ptr<CommTask>> comm_task_list_;
static std::list<std::shared_ptr<CommTask>> comm_task_list_;
// Tasks that timed out before they were started, keyed by UniqueKey().
static std::unordered_map<std::string, std::shared_ptr<CommTask>>
init_comm_task_map_;
// Tasks that timed out after starting but before completing, keyed by
// UniqueKey().
static std::unordered_map<std::string, std::shared_ptr<CommTask>>
start_comm_task_map_;
std::shared_ptr<Store> store_;
bool store_error_{false};

int comm_seq_;
// timeout count, only check first timeout task
static int check_timeout_count;
};

} // namespace distributed
Expand Down
Loading

0 comments on commit 798b1d4

Please sign in to comment.