Fix trace hang #57536

Merged: 6 commits, Sep 21, 2023
8 changes: 3 additions & 5 deletions paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -246,12 +246,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
<< ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;

int64_t numel = in_tensor.numel();
NCCL_CHECK(
phi::dynload::ncclAllReduce(in_tensor.data(),
out_tensor->data(),
// in_tensor.numel(),
numel,
in_tensor.numel(),
phi::ToNCCLDataType(in_tensor.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
@@ -893,7 +891,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
fn(nccl_comm, nccl_stream);
} else {
auto comm_task =
std::make_unique<phi::distributed::NCCLCommTask>(place,
std::make_shared<phi::distributed::NCCLCommTask>(place,
rank_,
size_,
gid_,
@@ -980,7 +978,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Point2Point(
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();

auto comm_task =
std::make_unique<phi::distributed::NCCLCommTask>(place,
std::make_shared<phi::distributed::NCCLCommTask>(place,
rank_,
size_,
gid_,
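Not part of the PR, for context only: a minimal standalone C++ sketch (hypothetical Task/Manager types, not Paddle's classes) of why the call sites above switch from std::make_unique to std::make_shared. After this change the comm task is handed to the background CommTaskManager for timeout tracking while the process group keeps using the same object, so ownership has to be shared rather than exclusive.

#include <iostream>
#include <list>
#include <memory>
#include <string>

// Hypothetical stand-ins for NCCLCommTask / CommTaskManager (illustration only).
struct Task {
  std::string key;
  bool started = false;
};

struct Manager {
  std::list<std::shared_ptr<Task>> queue;  // shared ownership, as in the PR
  void Enqueue(std::shared_ptr<Task> t) { queue.push_back(std::move(t)); }
};

int main() {
  Manager manager;
  auto task = std::make_shared<Task>();  // a unique_ptr could not be kept *and* enqueued
  task->key = "op:ALLREDUCE,gid:0,seq:1";

  manager.Enqueue(task);  // the manager watches this task in the background
  task->started = true;   // the caller still updates the very same task

  std::cout << manager.queue.front()->key << " started="
            << manager.queue.front()->started << "\n";
  return 0;
}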
43 changes: 21 additions & 22 deletions paddle/phi/core/distributed/comm_task.h
@@ -23,6 +23,12 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"

#if defined(PADDLE_WITH_RCCL)
#include "paddle/phi/backends/dynload/rccl.h"
#else
#include "paddle/phi/backends/dynload/nccl.h"
#endif

namespace phi {
namespace distributed {

@@ -36,6 +42,8 @@ class CommTask {
int gid = 0,
uint64_t seq = 0,
int64_t numel = 0,
ncclComm_t nccl_comm = nullptr,
gpuStream_t nccl_stream = nullptr,
CommType comm_type = CommType::UNKNOWN)
: backend_(backend),
place_(place),
@@ -44,6 +52,8 @@
gid_(gid),
seq_(seq),
numel_(numel),
nccl_comm_(nccl_comm),
nccl_stream_(nccl_stream),
comm_type_(comm_type) {
const char* global_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
@@ -54,6 +64,10 @@
}
virtual ~CommTask() = default;

std::string UniqueKey() {
return "op:" + CommTypeToString(comm_type_) +
",gid:" + std::to_string(gid_) + ",seq:" + std::to_string(seq_);
}
std::string GetBackend() { return backend_; }
phi::Place GetPlace() { return place_; }
int GetGlobalRank() { return global_rank_; }
@@ -71,6 +85,9 @@
std::shared_ptr<Store> GetStore() { return store_; }
void SetStore(std::shared_ptr<Store> store) { store_ = store; }

ncclComm_t nccl_comm() { return nccl_comm_; }
gpuStream_t nccl_stream() { return nccl_stream_; }

virtual std::string GetTraceMsg() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
@@ -87,20 +104,10 @@
return;
}

virtual void SetException(std::exception_ptr exception) {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return;
}
virtual void CheckAndSetException() {
virtual std::string GetCommErrors() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return;
}
virtual std::exception_ptr CheckCommErrors() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return nullptr;
return "";
}
virtual bool IsStarted() {
PADDLE_THROW(
@@ -117,16 +124,6 @@
phi::errors::Unimplemented("%s is not implemented.", __func__));
return false;
}
virtual bool IsSuccess() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return false;
}
virtual std::exception_ptr GetException() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return nullptr;
}
virtual void AbortComm() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
@@ -142,6 +139,8 @@
int gid_;
uint64_t seq_{0};
int64_t numel_;
ncclComm_t nccl_comm_;
gpuStream_t nccl_stream_;
CommType comm_type_;
bool start_trace_updated_{false};

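For reference, not part of the diff: the new CommTask::UniqueKey() above is what the manager later uses as the map key for timed-out tasks. A minimal sketch of the same string construction, with CommTypeToString mocked out (the real helper lives elsewhere in Paddle):

#include <cstdint>
#include <iostream>
#include <string>

// Mocked here; the real CommTypeToString maps a CommType enum to its name.
std::string CommTypeToString(int /*comm_type*/) { return "ALLREDUCE"; }

// Mirrors the key format used by CommTask::UniqueKey() in the diff above.
std::string UniqueKey(int comm_type, int gid, uint64_t seq) {
  return "op:" + CommTypeToString(comm_type) + ",gid:" + std::to_string(gid) +
         ",seq:" + std::to_string(seq);
}

int main() {
  std::cout << UniqueKey(/*comm_type=*/0, /*gid=*/0, /*seq=*/3) << "\n";
  // prints: op:ALLREDUCE,gid:0,seq:3
  return 0;
}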
111 changes: 49 additions & 62 deletions paddle/phi/core/distributed/comm_task_manager.cc
@@ -37,8 +37,6 @@
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif

DECLARE_int32(async_trace_count);

namespace phi {
namespace distributed {

@@ -48,14 +46,16 @@ const int64_t CommTaskManager::loop_thread_sleep_millis = 10000;
std::atomic<bool> CommTaskManager::terminated_;
std::mutex CommTaskManager::comm_task_list_mutex_;
std::condition_variable CommTaskManager::comm_task_list_cv_;
std::list<std::unique_ptr<CommTask>> CommTaskManager::comm_task_list_;
int CommTaskManager::check_timeout_count = 0;
std::list<std::shared_ptr<CommTask>> CommTaskManager::comm_task_list_;
std::unordered_map<std::string, std::shared_ptr<CommTask>>
CommTaskManager::init_comm_task_map_;
std::unordered_map<std::string, std::shared_ptr<CommTask>>
CommTaskManager::start_comm_task_map_;

CommTaskManager::CommTaskManager() {
terminated_.store(false);
comm_task_loop_thread_ = std::thread(&CommTaskManager::CommTaskLoop, this);
LOG(INFO) << "CommTaskManager init success. FLAGS_async_trace_count: "
<< FLAGS_async_trace_count;
LOG(INFO) << "CommTaskManager init success.";
}
CommTaskManager::~CommTaskManager() {
terminated_.store(true);
@@ -67,7 +67,7 @@ CommTaskManager::~CommTaskManager() {
LOG(INFO) << "CommTaskManager destruct success.";
}

void CommTaskManager::CommTaskEnqueue(std::unique_ptr<CommTask> comm_task) {
void CommTaskManager::CommTaskEnqueue(std::shared_ptr<CommTask> comm_task) {
if (!terminated_.load()) {
std::lock_guard<std::mutex> lock(comm_task_list_mutex_);
comm_task_list_.emplace_back(std::move(comm_task));
@@ -81,69 +81,56 @@ void CommTaskManager::CommTaskLoop() {
comm_task_list_cv_.wait_for(
lock,
std::chrono::milliseconds(loop_thread_sleep_millis),
[&]() -> bool {
return terminated_.load() &&
check_timeout_count <= FLAGS_async_trace_count;
});
for (auto task = comm_task_list_.begin(); task != comm_task_list_.end();) {
(*task)->CheckAndSetException();
if ((*task)->IsTimeout()) {
std::string exception_msg = (*task)->GetTraceMsg();
exception_msg += GenerateTraceMsg((*task)->GetStore(),
(*task)->GetBackend(),
(*task)->GetRank(),
(*task)->GetGid(),
(*task)->GetSize());
LOG(ERROR) << exception_msg;
std::exception_ptr exception_ptr =
std::make_exception_ptr(std::runtime_error(exception_msg));
(*task)->SetException(exception_ptr);
(*task)->AbortComm();

++check_timeout_count;
[&]() -> bool { return terminated_.load(); });
for (auto iter = comm_task_list_.begin(); iter != comm_task_list_.end();) {
auto task = *iter;
if (task->IsTimeout()) {
if (!task->IsStarted()) {
LOG(ERROR) << "Find timeout init but not start task: "
<< task->GetTraceMsg() << ",comm:" << task->nccl_comm()
<< ",stream:" << task->nccl_stream();
std::string task_key = task->UniqueKey();
init_comm_task_map_[task_key] = task;
} else if (!task->IsCompleted()) {
LOG(ERROR) << "Find timeout start but not finish task: "
<< task->GetTraceMsg() << ",comm:" << task->nccl_comm()
<< ",stream:" << task->nccl_stream();
std::string task_key = task->UniqueKey();
start_comm_task_map_[task_key] = task;
}
iter = comm_task_list_.erase(iter);
} else {
++iter;
}
}

if (!(*task)->GetTraceUpdated() && (*task)->IsStarted() &&
!terminated_.load() && !store_error_) {
std::string trace_key = GetTraceStartKey(
(*task)->GetBackend(), (*task)->GetRank(), (*task)->GetGid());
store_error_ =
!UpdateTraceMsg((*task)->GetStore(),
trace_key,
(*task)->GetSeq(),
CommTypeToString((*task)->GetCommType()));
(*task)->SetTraceUpdated();
for (auto iter = init_comm_task_map_.begin();
iter != init_comm_task_map_.end();) {
auto task = iter->second;
if (task->IsStarted()) {
std::string task_key = task->UniqueKey();
start_comm_task_map_[task_key] = task;
iter = init_comm_task_map_.erase(iter);
LOG(INFO) << "Start timeout task: " << task->GetTraceMsg();
} else {
++iter;
}
}

if ((*task)->IsCompleted()) {
if (!(*task)->GetTraceUpdated() && !terminated_.load() &&
!store_error_) {
std::string trace_key = GetTraceStartKey(
(*task)->GetBackend(), (*task)->GetRank(), (*task)->GetGid());
store_error_ =
!UpdateTraceMsg((*task)->GetStore(),
trace_key,
(*task)->GetSeq(),
CommTypeToString((*task)->GetCommType()));
(*task)->SetTraceUpdated();
}
if (!terminated_.load() && !store_error_) {
std::string trace_key = GetTraceEndKey(
(*task)->GetBackend(), (*task)->GetRank(), (*task)->GetGid());
store_error_ =
!UpdateTraceMsg((*task)->GetStore(),
trace_key,
(*task)->GetSeq(),
CommTypeToString((*task)->GetCommType()));
}
task = comm_task_list_.erase(task);
for (auto iter = start_comm_task_map_.begin();
iter != start_comm_task_map_.end();) {
auto task = iter->second;
if (task->IsCompleted()) {
iter = start_comm_task_map_.erase(iter);
LOG(INFO) << "Finish timeout task: " << task->GetTraceMsg();
} else {
++task;
++iter;
}
}
if (comm_task_list_.empty()) {

if (comm_task_list_.empty() && init_comm_task_map_.empty() &&
start_comm_task_map_.empty()) {
done = true;
check_timeout_count = 0;
}
}
}
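A standalone sketch (hypothetical Task type, not Paddle's) of the triage the rewritten CommTaskLoop above performs: instead of aborting the communicator and counting timeouts, a timed-out task is moved out of comm_task_list_ into init_comm_task_map_ (timed out before it ever started) or start_comm_task_map_ (started but not finished), then promoted or dropped as it makes progress, so the watchdog only logs and never blocks.

#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for CommTask: just enough state for the triage logic.
struct Task {
  std::string key;
  bool timeout = false, started = false, completed = false;
};

int main() {
  std::list<std::shared_ptr<Task>> pending;                        // comm_task_list_
  std::unordered_map<std::string, std::shared_ptr<Task>> init;     // timed out, not started
  std::unordered_map<std::string, std::shared_ptr<Task>> running;  // started, not finished

  auto t = std::make_shared<Task>();
  t->key = "op:ALLREDUCE,gid:0,seq:1";
  t->timeout = true;  // pretend the watchdog saw it exceed its deadline
  pending.push_back(t);

  // Pass 1: move timed-out tasks out of the pending list, bucketed by state.
  for (auto it = pending.begin(); it != pending.end();) {
    auto task = *it;
    if (task->timeout) {
      (task->started ? running : init)[task->key] = task;
      it = pending.erase(it);
    } else {
      ++it;
    }
  }

  // Pass 2: promote init -> running once a task finally starts.
  t->started = true;
  for (auto it = init.begin(); it != init.end();) {
    if (it->second->started) {
      running[it->first] = it->second;
      it = init.erase(it);
    } else {
      ++it;
    }
  }

  // Pass 3: drop running tasks once they complete.
  t->completed = true;
  for (auto it = running.begin(); it != running.end();) {
    if (it->second->completed) {
      it = running.erase(it);
    } else {
      ++it;
    }
  }

  std::cout << "pending=" << pending.size() << " init=" << init.size()
            << " running=" << running.size() << "\n";  // 0 0 0
  return 0;
}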
14 changes: 8 additions & 6 deletions paddle/phi/core/distributed/comm_task_manager.h
@@ -46,7 +46,7 @@ class CommTaskManager {
return instance;
}

void CommTaskEnqueue(std::unique_ptr<CommTask> comm_task);
void CommTaskEnqueue(std::shared_ptr<CommTask> comm_task);

private:
void CommTaskLoop();
@@ -58,13 +58,15 @@

static std::mutex comm_task_list_mutex_;
static std::condition_variable comm_task_list_cv_;
static std::list<std::unique_ptr<CommTask>> comm_task_list_;
static std::list<std::shared_ptr<CommTask>> comm_task_list_;
// not start task
static std::unordered_map<std::string, std::shared_ptr<CommTask>>
init_comm_task_map_;
// start but not finish task
static std::unordered_map<std::string, std::shared_ptr<CommTask>>
start_comm_task_map_;
std::shared_ptr<Store> store_;
bool store_error_{false};

int comm_seq_;
// timeout count, only check first timeout task
static int check_timeout_count;
};

} // namespace distributed