-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
listen_and_serv_op support async update #10042
Changes from 29 commits
79a1a7c
1a43828
e84f353
a39e607
d002aa7
1e30c41
3608301
0a881a1
f997c9b
e2ace03
63fbdcf
260bf5a
dc3d2dc
0763ae9
1d75674
c6937ab
4b86b49
5d32008
1b5de9d
39892fe
63055a3
42a15a4
34f2818
a0ced3d
a29e352
0881d80
3503c47
8081e15
63bf82d
63bd38b
0264ec3
46342a2
3295f31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH }; | |
class RequestBase { | ||
public: | ||
explicit RequestBase(GrpcService::AsyncService* service, | ||
::grpc::ServerCompletionQueue* cq, | ||
::grpc::ServerCompletionQueue* cq, bool sync_mode, | ||
const platform::DeviceContext* dev_ctx) | ||
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { | ||
: service_(service), | ||
cq_(cq), | ||
sync_mode_(sync_mode), | ||
status_(PROCESS), | ||
dev_ctx_(dev_ctx) { | ||
PADDLE_ENFORCE(cq_); | ||
} | ||
virtual ~RequestBase() {} | ||
|
@@ -49,18 +53,25 @@ class RequestBase { | |
::grpc::ServerContext ctx_; | ||
GrpcService::AsyncService* service_; | ||
::grpc::ServerCompletionQueue* cq_; | ||
const bool sync_mode_; | ||
CallStatus status_; | ||
const platform::DeviceContext* dev_ctx_; | ||
}; | ||
|
||
class RequestSend final : public RequestBase { | ||
public: | ||
explicit RequestSend(GrpcService::AsyncService* service, | ||
::grpc::ServerCompletionQueue* cq, | ||
::grpc::ServerCompletionQueue* cq, bool sync_mode, | ||
framework::Scope* scope, ReceivedQueue* queue, | ||
const platform::DeviceContext* dev_ctx) | ||
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { | ||
request_.reset(new VariableResponse(scope, dev_ctx_)); | ||
: RequestBase(service, cq, sync_mode, dev_ctx), | ||
queue_(queue), | ||
responder_(&ctx_) { | ||
if (sync_mode_) { | ||
request_.reset(new VariableResponse(scope, dev_ctx_, false)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. request_.reset(new VariableResponse( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought a while here, and think the current code is easier for user understand the intent. |
||
} else { | ||
request_.reset(new VariableResponse(scope, dev_ctx_, true)); | ||
} | ||
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); | ||
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, | ||
cq_, cq_, this); | ||
|
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase { | |
class RequestGet final : public RequestBase { | ||
public: | ||
explicit RequestGet(GrpcService::AsyncService* service, | ||
::grpc::ServerCompletionQueue* cq, | ||
::grpc::ServerCompletionQueue* cq, bool sync_mode, | ||
framework::Scope* scope, | ||
const platform::DeviceContext* dev_ctx, | ||
framework::BlockingQueue<MessageWithName>* queue) | ||
: RequestBase(service, cq, dev_ctx), | ||
: RequestBase(service, cq, sync_mode, dev_ctx), | ||
responder_(&ctx_), | ||
scope_(scope), | ||
queue_(queue) { | ||
|
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase { | |
class RequestPrefetch final : public RequestBase { | ||
public: | ||
explicit RequestPrefetch(GrpcService::AsyncService* service, | ||
::grpc::ServerCompletionQueue* cq, | ||
::grpc::ServerCompletionQueue* cq, bool sync_mode, | ||
framework::Scope* scope, | ||
const platform::DeviceContext* dev_ctx, | ||
framework::Executor* executor, | ||
framework::ProgramDesc* program, | ||
framework::ExecutorPrepareContext* prefetch_ctx) | ||
: RequestBase(service, cq, dev_ctx), | ||
: RequestBase(service, cq, sync_mode, dev_ctx), | ||
responder_(&ctx_), | ||
scope_(scope), | ||
executor_(executor), | ||
program_(program), | ||
prefetch_ctx_(prefetch_ctx) { | ||
request_.reset(new VariableResponse(scope, dev_ctx_)); | ||
if (sync_mode_) { | ||
request_.reset(new VariableResponse(scope, dev_ctx_, false)); | ||
} else { | ||
request_.reset(new VariableResponse(scope, dev_ctx_, true)); | ||
} | ||
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); | ||
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, | ||
cq_, cq_, this); | ||
|
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase { | |
framework::Executor* executor_; | ||
framework::ProgramDesc* program_; | ||
framework::ExecutorPrepareContext* prefetch_ctx_; | ||
int blkid_; | ||
}; | ||
|
||
void AsyncGRPCServer::WaitClientGet(int count) { | ||
|
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { | |
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; | ||
return; | ||
} | ||
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, | ||
&var_recv_queue_, dev_ctx_); | ||
RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_, | ||
scope_, &var_recv_queue_, dev_ctx_); | ||
VLOG(4) << "Create RequestSend status:" << send->Status(); | ||
} | ||
|
||
|
@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { | |
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; | ||
return; | ||
} | ||
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, | ||
&var_get_queue_); | ||
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, | ||
dev_ctx_, &var_get_queue_); | ||
VLOG(4) << "Create RequestGet status:" << get->Status(); | ||
} | ||
|
||
|
@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { | |
return; | ||
} | ||
RequestPrefetch* prefetch = | ||
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, | ||
executor_, program_, prefetch_ctx_); | ||
new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_, | ||
dev_ctx_, executor_, program_, prefetch_ctx_); | ||
|
||
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); | ||
} | ||
|
@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, | |
VLOG(3) << "HandleRequest for " << cq_name << " while after Next"; | ||
|
||
PADDLE_ENFORCE(tag); | ||
// FIXME(typhoonzero): de-couple the barriers with recv_op | ||
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); | ||
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); | ||
if (sync_mode_) { | ||
// FIXME(typhoonzero): de-couple the barriers with recv_op | ||
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); | ||
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); | ||
} | ||
|
||
RequestBase* base = reinterpret_cast<RequestBase*>(tag); | ||
// reference: | ||
|
@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, | |
|
||
switch (base->Status()) { | ||
case PROCESS: { | ||
VLOG(4) << cq_name << " status:" << base->Status(); | ||
VLOG(4) << cq_name << " PROCESS status:" << base->Status(); | ||
TryToRegisterNewOne(); | ||
base->Process(); | ||
break; | ||
} | ||
case FINISH: { | ||
VLOG(4) << cq_name << " status:" << base->Status(); | ||
VLOG(4) << cq_name << " FINISH status:" << base->Status(); | ||
delete base; | ||
break; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,7 +46,9 @@ class VariableResponse { | |
} | ||
|
||
virtual ~VariableResponse() { | ||
if (create_scope_) scope_->DeleteScope(local_scope_); | ||
if (create_scope_) { | ||
scope_->DeleteScope(local_scope_); | ||
} | ||
} | ||
|
||
// return: | ||
|
@@ -61,7 +63,7 @@ class VariableResponse { | |
// other: number of error field. | ||
int Parse(const ::grpc::ByteBuffer& byte_buffer); | ||
|
||
const framework::Scope& GetLocalScope() const { return *local_scope_; } | ||
framework::Scope& GetLocalScope() const { return *local_scope_; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider GetMutableLocalScope that returns a pointer and avoid removing the const? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
inline std::string Varname() { return meta_.varname(); } | ||
inline std::string OutVarname() { return meta_.out_varname(); } | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) { | |
VLOG(4) << "RunServer thread end"; | ||
} | ||
|
||
static void split(const std::string &str, char sep, | ||
std::vector<std::string> *pieces) { | ||
pieces->clear(); | ||
if (str.empty()) { | ||
return; | ||
} | ||
size_t pos = 0; | ||
size_t next = str.find(sep, pos); | ||
while (next != std::string::npos) { | ||
pieces->push_back(str.substr(pos, next - pos)); | ||
pos = next + 1; | ||
next = str.find(sep, pos); | ||
} | ||
if (!str.substr(pos).empty()) { | ||
pieces->push_back(str.substr(pos)); | ||
} | ||
} | ||
|
||
static void AsyncExecuteBlock(framework::Executor *executor, | ||
framework::ExecutorPrepareContext *prepared, | ||
framework::Scope *scope) { | ||
std::future<void> future = framework::Async([&executor, &prepared, &scope]() { | ||
try { | ||
executor->RunPreparedContext(prepared, scope, false, false); | ||
} catch (std::exception &e) { | ||
LOG(ERROR) << "run sub program error " << e.what(); | ||
} | ||
}); | ||
// TODO(qiao) maybe we can remove this | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removing this means a more "async" mode; the trainer doesn't even know whether the sent gradient is updated to the server-side weights before it gets the latest weights. Or do you mean letting updates to different weights become parallel? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation will update gradients in sequence if we keep this wait. This may influence the effect; I will do some tests on it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After discussing with @typhoonzero , we think that each gradient should be put into an independent blocking queue to ensure that they are updated without conflict. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Do you mean that, for the gradients of one parameter — such as grad_w1(trainer0), grad_w1(trainer1), grad_w2(trainer0) — we put grad_w1(trainer0) and grad_w1(trainer1) into one queue, and grad_w2(trainer0) into another one? According to the design doc, maybe we need multiple BlockingQueues so that each parameter can own one of them to implement a lock for updating the parameter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we need multiple blocking queues; each will store the gradients for one parameter, but we do not need to add a lock, because the queue will block until the optimize block is finished. |
||
future.wait(); | ||
} | ||
|
||
static void ParallelExecuteBlocks( | ||
const std::vector<size_t> ¶llel_blkids, framework::Executor *executor, | ||
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>> | ||
|
@@ -169,15 +201,85 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, | |
} // while(true) | ||
} | ||
|
||
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, | ||
framework::ProgramDesc *program, | ||
framework::Scope *recv_scope, | ||
framework::BlockDesc *prefetch_block) const { | ||
VLOG(3) << "RunAsyncLoop in"; | ||
// grad name to block id | ||
std::unordered_map<std::string, int32_t> grad_to_id; | ||
std::unordered_map<int32_t, std::string> id_to_grad; | ||
|
||
auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think grad_to_id_str should be created in Python by the transpiler, because the transpile logic knows how to split the operators and blocks; it is fine for listen_and_serv_op to just use the result, otherwise it would have to understand the detailed logic of the transpiler. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Firstly we want to make In that case, for Async Execution, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After discussing with @typhoonzero, I get the point, and I totally agree with the idea that listen_and_serv_op should be a general operator! We will find a better way to implement async update in future PRs. |
||
for (auto &grad_and_id : grad_to_id_str) { | ||
std::vector<std::string> pieces; | ||
split(grad_and_id, ':', &pieces); | ||
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1]; | ||
PADDLE_ENFORCE_EQ(pieces.size(), 2); | ||
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0); | ||
int block_id = std::stoi(pieces[1]); | ||
grad_to_id[pieces[0]] = block_id; | ||
id_to_grad[block_id] = pieces[0]; | ||
} | ||
size_t num_blocks = program->Size(); | ||
PADDLE_ENFORCE_GE(num_blocks, 2, | ||
"server program should have at least 2 blocks"); | ||
|
||
std::vector<int> block_list; | ||
for (size_t blkid = 1; blkid < num_blocks; ++blkid) { | ||
block_list.push_back(blkid); | ||
} | ||
auto optimize_prepared = executor->Prepare(*program, block_list); | ||
std::unordered_map<std::string, | ||
std::shared_ptr<framework::ExecutorPrepareContext>> | ||
grad_to_prepared; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. grad_to_prepared_block There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
for (size_t i = 0; i < block_list.size(); ++i) { | ||
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i]; | ||
} | ||
|
||
VLOG(3) << "RunAsyncLoop into while"; | ||
bool exit_flag = false; | ||
while (!exit_flag) { | ||
const detail::ReceivedMessage v = rpc_service_->Get(); | ||
auto recv_var_name = v.first; | ||
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { | ||
LOG(INFO) << "received terminate message and exit"; | ||
exit_flag = true; | ||
break; | ||
} else { | ||
VLOG(3) << "received grad: " << recv_var_name; | ||
auto var = v.second->GetVar(); | ||
if (var == nullptr) { | ||
LOG(ERROR) << "Can not find server side var: " << recv_var_name; | ||
PADDLE_THROW("Can not find server side var"); | ||
} | ||
AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(), | ||
&(v.second->GetLocalScope())); | ||
// TODO(qiao): explain why | ||
if (var->IsType<framework::SelectedRows>()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we don't need to clear the rows, because each gradient var is in a new scope. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great suggestion! removed. |
||
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); | ||
} | ||
} | ||
|
||
if (exit_flag) { | ||
rpc_service_->ShutDown(); | ||
break; | ||
} | ||
} // while(true) | ||
} | ||
|
||
void ListenAndServOp::RunImpl(const framework::Scope &scope, | ||
const platform::Place &dev_place) const { | ||
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); | ||
auto &dev_ctx = *pool.Get(dev_place); | ||
framework::Scope &recv_scope = scope.NewScope(); | ||
|
||
bool sync_mode = Attr<bool>("sync_mode"); | ||
|
||
PADDLE_ENFORCE(!rpc_service_); | ||
std::string endpoint = Attr<std::string>("endpoint"); | ||
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); | ||
|
||
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode)); | ||
|
||
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); | ||
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock); | ||
|
@@ -202,7 +304,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, | |
sleep(5); | ||
// Write to a file of server selected port for python use. | ||
SavePort(rpc_service_); | ||
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); | ||
if (sync_mode) { | ||
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); | ||
} else { | ||
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block); | ||
} | ||
} | ||
|
||
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { | ||
|
@@ -221,6 +327,12 @@ from send_op and send back variables to recv_op. | |
"IP address to listen on.") | ||
.SetDefault("127.0.0.1:6164") | ||
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); | ||
AddAttr<std::vector<std::string>>( | ||
"grad_to_id", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. grad_to_block_id? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] " | ||
"a map from grad name to it's optimize block id") | ||
.SetDefault({}); | ||
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true); | ||
AddAttr<framework::BlockDesc *>(kOptimizeBlock, | ||
"BlockID to run on server side."); | ||
AddAttr<framework::BlockDesc *>(kPrefetchBlock, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe
is_sync
or justsync
can tell the meaning?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
sync_mode
means it works in a mode, butis_sync
means itself is async. So I thinksync_mode
is better.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see.