listen_and_serv_op support async update #10042

Merged

Commits (33; changes shown from 29):
79a1a7c
init async gprc server
jacquesqiao Apr 18, 2018
1a43828
implement main logic
jacquesqiao Apr 18, 2018
e84f353
optimize
jacquesqiao Apr 19, 2018
a39e607
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 20, 2018
d002aa7
update
jacquesqiao Apr 20, 2018
1e30c41
add split string
jacquesqiao Apr 20, 2018
3608301
Merge branch 'refine-listen-and-serve-op' of ssh://github.com/jacques…
jacquesqiao Apr 20, 2018
0a881a1
init RunAsyncUpdate
jacquesqiao Apr 20, 2018
f997c9b
Merge branch 'refine-listen-and-serve-op' of ssh://github.com/jacques…
jacquesqiao Apr 20, 2018
e2ace03
rename RunAsyncUpdate to RunAsyncLoop
jacquesqiao Apr 20, 2018
63fbdcf
update send_recv_op_test
jacquesqiao Apr 20, 2018
260bf5a
add sync_mode
jacquesqiao Apr 20, 2018
dc3d2dc
rename grad_map to grad_to_id
jacquesqiao Apr 20, 2018
0763ae9
remove unused file
jacquesqiao Apr 20, 2018
1d75674
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 21, 2018
c6937ab
tmp
jacquesqiao Apr 21, 2018
4b86b49
Merge branch 'fix-build-activation_op' of ssh://github.com/jacquesqia…
jacquesqiao Apr 21, 2018
5d32008
Merge branch 'split-optimize-op-into-signle-blocks' of ssh://github.c…
jacquesqiao Apr 22, 2018
1b5de9d
Merge branch 'split-optimize-op-into-signle-blocks' of ssh://github.c…
jacquesqiao Apr 22, 2018
39892fe
Merge branch 'split-optimize-op-into-signle-blocks' of ssh://github.c…
jacquesqiao Apr 22, 2018
63055a3
complete grad_to_id
jacquesqiao Apr 23, 2018
42a15a4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 23, 2018
34f2818
distribute transpiler support async config
jacquesqiao Apr 23, 2018
a0ced3d
async update can run
jacquesqiao Apr 23, 2018
a29e352
optimize code
jacquesqiao Apr 23, 2018
0881d80
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 23, 2018
3503c47
listen and serv default sync mode
jacquesqiao Apr 23, 2018
8081e15
fix send_recv_op_test
jacquesqiao Apr 24, 2018
63bf82d
Merge branch 'develop' into add-async-listen-and-serv-op
jacquesqiao Apr 25, 2018
63bd38b
code optimize
jacquesqiao Apr 26, 2018
0264ec3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 26, 2018
46342a2
delete useless code
jacquesqiao Apr 26, 2018
3295f31
optimize naming
jacquesqiao Apr 26, 2018
60 changes: 38 additions & 22 deletions paddle/fluid/operators/detail/grpc_server.cc
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
class RequestBase {
public:
explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
Contributor: Maybe is_sync, or just sync, would convey the meaning?

Member (author): I think sync_mode says the server works in a certain mode, while is_sync would suggest the object itself is synchronous, so sync_mode is better.

Contributor: I see.

const platform::DeviceContext* dev_ctx)
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
: service_(service),
cq_(cq),
sync_mode_(sync_mode),
status_(PROCESS),
dev_ctx_(dev_ctx) {
PADDLE_ENFORCE(cq_);
}
virtual ~RequestBase() {}
@@ -49,18 +53,25 @@ class RequestBase {
::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_;
const bool sync_mode_;
CallStatus status_;
const platform::DeviceContext* dev_ctx_;
};

class RequestSend final : public RequestBase {
public:
explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
request_.reset(new VariableResponse(scope, dev_ctx_));
: RequestBase(service, cq, sync_mode, dev_ctx),
queue_(queue),
responder_(&ctx_) {
if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
Contributor: Consider:

    request_.reset(new VariableResponse(
        scope,
        dev_ctx_,
        !sync_mode_  // create_scope
    ));

Member (author): I thought about this for a while; I think the current code makes the intent easier for readers to understand.

} else {
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
class RequestGet final : public RequestBase {
public:
explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::BlockingQueue<MessageWithName>* queue)
: RequestBase(service, cq, dev_ctx),
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_),
scope_(scope),
queue_(queue) {
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
class RequestPrefetch final : public RequestBase {
public:
explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::Executor* executor,
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx),
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_),
scope_(scope),
executor_(executor),
program_(program),
prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_));
if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else {
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_;
};

void AsyncGRPCServer::WaitClientGet(int count) {
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return;
}
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
&var_recv_queue_, dev_ctx_);
RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
scope_, &var_recv_queue_, dev_ctx_);
VLOG(4) << "Create RequestSend status:" << send->Status();
}

@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
dev_ctx_, &var_get_queue_);
VLOG(4) << "Create RequestGet status:" << get->Status();
}

@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
return;
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_ctx_);
new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
dev_ctx_, executor_, program_, prefetch_ctx_);

VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";

PADDLE_ENFORCE(tag);
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
if (sync_mode_) {
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
}

RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference:
@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,

switch (base->Status()) {
case PROCESS: {
VLOG(4) << cq_name << " status:" << base->Status();
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
TryToRegisterNewOne();
base->Process();
break;
}
case FINISH: {
VLOG(4) << cq_name << " status:" << base->Status();
VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base;
break;
}
4 changes: 3 additions & 1 deletion paddle/fluid/operators/detail/grpc_server.h

@@ -44,7 +44,8 @@ class RequestBase;

class AsyncGRPCServer final {
public:
explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode) {}

void RunSyncUpdate();

@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
std::unique_ptr<::grpc::Server> server_;

std::string address_;
const bool sync_mode_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;

2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/grpc_server_test.cc

@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}

void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
6 changes: 4 additions & 2 deletions paddle/fluid/operators/detail/variable_response.h

@@ -46,7 +46,9 @@ class VariableResponse {
}

virtual ~VariableResponse() {
if (create_scope_) scope_->DeleteScope(local_scope_);
if (create_scope_) {
scope_->DeleteScope(local_scope_);
}
}

// return:
@@ -61,7 +63,7 @@
// other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer);

const framework::Scope& GetLocalScope() const { return *local_scope_; }
framework::Scope& GetLocalScope() const { return *local_scope_; }
Contributor: Consider a GetMutableLocalScope that returns a pointer, and avoid removing the const?

Member (author): done


inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }
116 changes: 114 additions & 2 deletions paddle/fluid/operators/listen_and_serv_op.cc

@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
VLOG(4) << "RunServer thread end";
}

static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
if (str.empty()) {
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}

static void AsyncExecuteBlock(framework::Executor *executor,
framework::ExecutorPrepareContext *prepared,
framework::Scope *scope) {
std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
try {
executor->RunPreparedContext(prepared, scope, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
// TODO(qiao) maybe we can remove this
Contributor: Removing this would make the mode "more async": the trainer would not even know whether the gradient it sent has been applied to the server-side weights before it fetches the latest weights. Or do you mean letting updates to different weights run in parallel?

Member (author): With this wait kept, the current implementation updates gradients in sequence. That may affect convergence; I will run some tests on it.

Member (author, @jacquesqiao, Apr 25, 2018): After discussing with @typhoonzero, we think each gradient should be put into an independent blocking queue to ensure the updates are applied without conflict.

Contributor (@Yancey1989, Apr 25, 2018): "each gradient should be put to an independent block queue" — do you mean the gradients of one parameter go together, e.g. grad_w1(trainer0) and grad_w1(trainer1) into one queue, and grad_w2(trainer0) into another? According to the design doc, maybe we need multiple BlockingQueues so that each parameter owns one, which acts as the lock for updating that parameter.

Member (author): Yes, we need multiple blocking queues, each storing the gradients of one parameter. We do not need an extra lock, because each queue blocks until the optimize block has finished.

future.wait();
}

static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -169,15 +201,85 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
} // while(true)
}

void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_id;
std::unordered_map<int32_t, std::string> id_to_grad;

auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id");
Contributor: grad_to_id_str could be generated by listen_and_serv_op at initialization time: read the ProgramDesc blocks and build the mapping there, so we could drop this attribute.

Member (author): I think grad_to_id_str should be created in Python by the transpiler, because the transpile logic knows how the operators and blocks were split; listen_and_serv_op should just consume the result, otherwise it would have to understand the transpiler's internals.

Contributor: Firstly, we want listen_and_serv_op to be general: it should not know whether its attributes and inputs are "grads" or parameters; it should simply receive data and run a block. In that case, for async execution, listen_and_serv is responsible for deciding which block to run when data arrives. Just open for discussion.

Member (author): After discussing with @typhoonzero, I get the point, and I totally agree that listen_and_serv_op should be a general operator. We will find a better way to implement async update in future PRs.

for (auto &grad_and_id : grad_to_id_str) {
std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_id[pieces[0]] = block_id;
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");

std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
}
auto optimize_prepared = executor->Prepare(*program, block_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared;
Contributor: grad_to_prepared_block?

Member (author): done

for (size_t i = 0; i < block_list.size(); ++i) {
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
}

VLOG(3) << "RunAsyncLoop into while";
bool exit_flag = false;
while (!exit_flag) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
&(v.second->GetLocalScope()));
// TODO(qiao): explain why
if (var->IsType<framework::SelectedRows>()) {
Contributor: Maybe we don't need to clear the rows, because each gradient var lives in a new scope.

Member (author, @jacquesqiao, Apr 26, 2018): Great suggestion! Removed.

var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
}

if (exit_flag) {
rpc_service_->ShutDown();
break;
}
} // while(true)
}

void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();

bool sync_mode = Attr<bool>("sync_mode");

PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));

rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));

auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
@@ -202,7 +304,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sleep(5);
// Write to a file of server selected port for python use.
SavePort(rpc_service_);
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
} else {
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block);
}
}

class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -221,6 +327,12 @@ from send_op and send back variables to recv_op.
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>(
"grad_to_id",
Contributor: grad_to_block_id?

Member (author): done

"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
5 changes: 5 additions & 0 deletions paddle/fluid/operators/listen_and_serv_op.h

@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;

void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;

void Stop() override;

void RunImpl(const framework::Scope& scope,