Merge pull request #7608 from typhoonzero/distributed_split_selectedrows
Enhance distributed train performance
typhoonzero authored Jan 19, 2018
2 parents b7b5de7 + 0aff136 commit 58be41f
Showing 7 changed files with 94 additions and 83 deletions.
paddle/operators/detail/grpc_client.cc (3 changes: 0 additions & 3 deletions)
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   sendrecv::VariableMessage req;
   req.set_varname(var_name);
 
-  auto* var = scope.FindVar(var_name);
-  SerializeToMessage(var_name, var, ctx, &req);
-
   // varhandle
   VarHandle var_h;
   var_h.ep = ep;
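Context for the deletion above: AsyncGetVariable used to serialize the local variable into the request, so every Get shipped a stale payload to the server. After this change the request carries only the variable name and the server serializes the fresh value into the reply. A minimal sketch of the slimmer request path, using a stand-in struct rather than the real protobuf sendrecv::VariableMessage:

```cpp
#include <iostream>
#include <string>

// Stand-in for the protobuf-generated sendrecv::VariableMessage (sketch only).
struct VariableMessage {
  std::string varname;
  std::string serialized;  // now left empty on the Get request path
};

int main() {
  VariableMessage req;
  req.varname = "fc_0.w_0";  // hypothetical parameter name
  // Before this commit, FindVar() + SerializeToMessage() filled
  // req.serialized here; the server never needed that copy to answer a Get.
  std::cout << "GetVariable(" << req.varname
            << ") request payload bytes: " << req.serialized.size() << "\n";
  return 0;
}
```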
paddle/operators/detail/grpc_server.cc (50 changes: 30 additions & 20 deletions)
@@ -36,7 +36,10 @@ class RequestBase {

   CallStatus Status() { return status_; }
   void SetStatus(CallStatus status) { status_ = status; }
-  virtual std::string GetReqName() { assert(false); }
+  virtual std::string GetReqName() {
+    assert(false);
+    return "";
+  }
 
  protected:
   grpc::ServerContext ctx_;
@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
  public:
   explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
                       grpc::ServerCompletionQueue* cq, framework::Scope* scope,
-                      const platform::DeviceContext* dev_ctx)
+                      const platform::DeviceContext* dev_ctx,
+                      SimpleBlockQueue<char>* queue)
       : RequestBase(service, cq),
         responder_(&ctx_),
         scope_(scope),
-        dev_ctx_(dev_ctx) {
+        dev_ctx_(dev_ctx),
+        queue_(queue) {
     service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
   }

@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
     // TODO(gongwb): check var's info.
     responder_.Finish(reply_, grpc::Status::OK, this);
     status_ = FINISH;
+    queue_->Push('c');
   }
 
  protected:
@@ -108,8 +114,15 @@
   ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
   framework::Scope* scope_;
   const platform::DeviceContext* dev_ctx_;
+  SimpleBlockQueue<char>* queue_;
 };
 
+void AsyncGRPCServer::WaitClientGet(int count) {
+  for (int i = 0; i < count; ++i) {
+    var_get_queue_.Pop();
+  }
+}
+
 void AsyncGRPCServer::RunSyncUpdate() {
   grpc::ServerBuilder builder;
   builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
@@ -149,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
 }
 
 // This URL explains why shutdown is complicate:
 // https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
-
 void AsyncGRPCServer::ShutDown() {
   server_->Shutdown();
   ShutdownQueue();
@@ -170,10 +182,12 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
   if (is_shut_down_) {
     return;
   }
-  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
+  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
+                                   &var_get_queue_);
   VLOG(4) << "create Requestget status:" << get->Status();
 }
 
+// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
 void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
                                     std::string cq_name,
                                     std::function<void()> TryToRegisterNewOne) {
@@ -188,9 +202,9 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
   }
 
   PADDLE_ENFORCE(tag);
-  if (wait && !done_) {
-    Wait();
-  }
+  // FIXME(typhoonzero): de-couple the barriers with recv_op
+  if (cq_name == "cq_get") WaitCond(1);
+  if (cq_name == "cq_send") WaitCond(0);
 
   RequestBase* base = (RequestBase*)tag;
   // reference:
@@ -222,22 +236,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
   }
 }
 
-void AsyncGRPCServer::Wait() {
-  std::unique_lock<std::mutex> lock(this->mutex_);
-  condition_.wait(lock, [=] { return this->done_ == true; });
-}
-
-void AsyncGRPCServer::Reset() {
-  std::lock_guard<std::mutex> lock(this->mutex_);
-  done_ = false;
+void AsyncGRPCServer::WaitCond(int cond) {
+  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
+  barrier_condition_.wait(lock,
+                          [=] { return this->barrier_cond_step_ == cond; });
 }
 
-void AsyncGRPCServer::Done() {
+void AsyncGRPCServer::SetCond(int cond) {
   {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    done_ = true;
+    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
+    barrier_cond_step_ = cond;
   }
-  condition_.notify_all();
+  barrier_condition_.notify_all();
 }
 
 }  // namespace detail
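The Reset()/Done() flag is replaced here by an integer barrier step: send handlers block until the step is 0, get handlers until it is 1. A self-contained sketch of the same SetCond/WaitCond mechanics (illustrative class; the real AsyncGRPCServer also owns completion queues and RPC state):

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Minimal step barrier in the style of AsyncGRPCServer::SetCond/WaitCond.
class StepBarrier {
 public:
  void WaitCond(int cond) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [=] { return step_ == cond; });
  }
  void SetCond(int cond) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      step_ = cond;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int step_ = 0;  // 0: accepting sends, 1: serving gets (as used above)
};

int main() {
  StepBarrier barrier;
  std::thread get_handler([&barrier] {
    barrier.WaitCond(1);  // a "cq_get" handler parks here until params are fresh
    std::cout << "get handler released\n";
  });
  barrier.SetCond(1);  // recv_op flips the step once the optimizer has run
  get_handler.join();
  return 0;
}
```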
paddle/operators/detail/grpc_server.h (15 changes: 8 additions & 7 deletions)
@@ -41,9 +41,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {

   void RunSyncUpdate();
 
-  void Reset();
-
-  void Done();
+  // functions to sync server barrier status.
+  void WaitCond(int cond);
+  void SetCond(int cond);
+  void WaitClientGet(int count);
 
   void SetScope(framework::Scope *scope) { scope_ = scope; }

@@ -56,7 +57,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   void ShutDown();
 
  protected:
-  void Wait();
   void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
                      std::string cq_name,
                      std::function<void()> TryToRegisterNewOne);
@@ -78,11 +78,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   const platform::DeviceContext *dev_ctx_;
   // received variable from RPC, operators fetch variable from this queue.
   SimpleBlockQueue<MessageWithName> var_recv_queue_;
+  SimpleBlockQueue<char> var_get_queue_;
 
   // condition of the sub program
-  std::mutex mutex_;
-  volatile mutable bool done_;
-  std::condition_variable condition_;
+  std::mutex barrier_mutex_;
+  mutable int barrier_cond_step_;
+  std::condition_variable barrier_condition_;
 
   std::unique_ptr<std::thread> t_send_;
   std::unique_ptr<std::thread> t_get_;
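var_get_queue_ serves purely as a counting channel: every finished Get RPC pushes one token ('c'), and WaitClientGet(n) pops n tokens so the server knows all trainers have fetched the updated parameters before the next round begins. A blocking queue sufficient for that role might look like the sketch below (an assumption about SimpleBlockQueue's shape, not its actual source):

```cpp
#include <condition_variable>
#include <deque>
#include <mutex>

// Minimal blocking queue in the spirit of SimpleBlockQueue<T> (sketch only).
template <typename T>
class BlockQueue {
 public:
  void Push(const T& v) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      q_.push_back(v);
    }
    cv_.notify_one();
  }
  // Blocks until an element is available, then removes and returns it.
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !q_.empty(); });
    T v = q_.front();
    q_.pop_front();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> q_;
};

// Usage mirroring the server: each completed Get pushes one token,
// and WaitClientGet(count) simply pops count of them.
void WaitClientGet(BlockQueue<char>& queue, int count) {
  for (int i = 0; i < count; ++i) queue.Pop();
}
```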
paddle/operators/recv_op.cc (74 changes: 30 additions & 44 deletions)
@@ -27,12 +27,17 @@ limitations under the License. */
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"

namespace paddle {
namespace operators {

constexpr int kCondStart = 0;
constexpr int kCondRunning = 1;
constexpr int kCondDone = 2;

void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
VLOG(4) << "RunServer thread end";
@@ -77,42 +82,41 @@ class RecvOp : public framework::OperatorBase {
     if (grads_counter_.find(varname) == grads_counter_.end()) {
       grads_counter_[varname] = 0;
     }
-    char ret[256];
-    snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
-             grads_counter_[varname]++);
-    return std::string(ret);
+    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
   }
 
   void Run(const framework::Scope &scope,
            const platform::Place &dev_place) const override {
+    // FIXME(typhoonzero): no new scopes for every run.
+    framework::Scope &recv_scope = scope.NewScope();
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
-    framework::Scope &recv_scope = scope.NewScope();
 
     // FIXME(Yancey1989): initialize rpc server with laze mode.
     rpc_service_->SetScope(&recv_scope);
     rpc_service_->SetDevCtx(&dev_ctx);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto trainer_count = Attr<int>("Trainers");
+    auto fan_in = Attr<int>("Fanin");
     size_t param_count = param_list.size();
 
-    rpc_service_->Reset();
+    std::string program_str = Attr<std::string>("OptimizeProgram");
+    framework::proto::ProgramDesc program_desc;
+    program_desc.ParseFromString(program_str);
+    framework::ProgramDesc program(program_desc);
+    framework::Executor executor(dev_place);
+
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
-    VLOG(4) << "param_count:" << param_count
-            << " trainer_count:" << trainer_count;
+    int64_t barrier_size = param_count * fan_in;
     while (!exit_flag) {
-      // TODO(gognwb): simply this loop.
-      // Get from multiple trainers, we don't care about order in which
-      // the gradient arrives, just add suffix 0~n then average the gradient.
-      for (size_t i = 0; i < param_count * trainer_count; ++i) {
-        // blocking get one var from client.
+      // Get from multiple trainers, we don't care about the order in which
+      // the gradients arrives, just add suffix 0~n and merge the gradient.
+      rpc_service_->SetCond(0);
+      for (size_t i = 0; i < barrier_size; ++i) {
         const detail::MessageWithName &v = rpc_service_->Get();
         auto grad_var_name = v.first;
         if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
-          VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
+          LOG(INFO) << "received terminate message and exit";
           exit_flag = true;
           break;
         }
@@ -121,49 +125,31 @@ class RecvOp : public framework::OperatorBase {
         if (it != grad_list.end()) {
           param_var_name = param_list[it - grad_list.begin()];
         } else {
-          LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
-                     << "\"";
+          LOG(ERROR) << "grad have no paired param:" << grad_var_name;
         }
         VLOG(3) << "recved grad: " << grad_var_name
                 << " updating param: " << param_var_name;
 
-        auto *merged_grad = recv_scope.FindVar(grad_var_name);
-        if (merged_grad == nullptr) {
-          auto *ptr = recv_scope.Var(grad_var_name);
-          CreateTensorFromMessageType(ptr, v.second.type());
-          VLOG(3) << "Create Variable " << grad_var_name
-                  << " on recv scope, which pointer is " << ptr << " type is "
-                  << v.second.type();
-        }
-
-        if (trainer_count > 1) {
+        if (fan_in > 1) {
           grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
         }
 
-        auto *var = recv_scope.Var(grad_var_name);
+        auto *var = recv_scope.FindVar(grad_var_name);
+        if (var == nullptr) {
+          LOG(ERROR) << "can not find server side var: " << grad_var_name;
+          PADDLE_THROW("can not find server side var");
+        }
         detail::DeserializeFromMessage(v.second, dev_ctx, var);
       }
 
       if (exit_flag) {
         break;
       }
 
-      rpc_service_->Reset();
-
-      std::string program_str = Attr<std::string>("OptimizeProgram");
-      framework::proto::ProgramDesc program_desc;
-      program_desc.ParseFromString(program_str);
-      framework::ProgramDesc program(program_desc);
-      framework::Executor executor(dev_place);
       // Run sub graph to get optimized tensor
       try {
         executor.Run(program, &recv_scope, 0, /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
       }
 
-      rpc_service_->Done();
+      rpc_service_->SetCond(1);
+      rpc_service_->WaitClientGet(barrier_size);
       grads_counter_.clear();
     }  // while(true)
   }
@@ -199,7 +185,7 @@ This operator will recv tensor from send_op
"GradList", "type list of string",
"grad->param name mapping to find which param to optimize.")
.SetDefault({});
AddAttr<int>("Trainers", "type int",
AddAttr<int>("Fanin", "type int",
"Number of trainers in the current cluster job")
.SetDefault(1);
}
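The per-trainer suffix produced by GetGradVarNameForTrainer is the glue between this operator and the variables the transpiler now pre-creates (see distribute_transpiler.py below): the k-th arrival of a gradient named w@GRAD lands in w@GRAD.trainer_k. A self-contained sketch of the naming scheme, with plain snprintf standing in for paddle::string::Sprintf:

```cpp
#include <cstdio>
#include <map>
#include <string>

// Mirrors RecvOp::GetGradVarNameForTrainer above: each arrival of the same
// gradient name gets the next trainer suffix.
std::string GradVarNameForTrainer(std::map<std::string, int>* counters,
                                  const std::string& varname) {
  char buf[256];
  std::snprintf(buf, sizeof(buf), "%s.trainer_%d", varname.c_str(),
                (*counters)[varname]++);
  return std::string(buf);
}

int main() {
  std::map<std::string, int> counters;
  // Two trainers send the same gradient; each lands in its own slot.
  std::printf("%s\n", GradVarNameForTrainer(&counters, "w@GRAD").c_str());  // w@GRAD.trainer_0
  std::printf("%s\n", GradVarNameForTrainer(&counters, "w@GRAD").c_str());  // w@GRAD.trainer_1
  return 0;
}
```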
paddle/operators/send_op.cc (3 changes: 3 additions & 0 deletions)
@@ -41,10 +41,13 @@ class SendOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
PADDLE_ENFORCE(client_.Wait());

for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}

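The added PADDLE_ENFORCE(client_.Wait()) between the two loops is the trainer-side half of the handshake: every gradient must be delivered before any parameter fetch is issued, matching the server's SetCond(0)/SetCond(1) phases. A sketch of that issue-all-then-wait shape with a stand-in client (FakeRpc is hypothetical; only the ordering matches the real RPCClient):

```cpp
#include <future>
#include <iostream>
#include <string>
#include <vector>

// Stand-in async client; the real RPCClient issues gRPC calls instead.
struct FakeRpc {
  std::vector<std::future<void>> pending;
  void AsyncCall(const std::string& what, const std::string& name) {
    pending.push_back(std::async(std::launch::async, [what, name] {
      std::cout << what << " " << name << "\n";
    }));
  }
  bool Wait() {  // block until every outstanding call completes
    for (auto& f : pending) f.get();
    pending.clear();
    return true;
  }
};

int main() {
  FakeRpc client;
  for (const char* g : {"w@GRAD", "b@GRAD"}) client.AsyncCall("sending", g);
  client.Wait();  // barrier: all gradients delivered first...
  for (const char* p : {"w", "b"}) client.AsyncCall("getting", p);
  client.Wait();  // ...then all fresh parameters fetched
  return 0;
}
```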
python/paddle/v2/fluid/distribute_transpiler.py (15 changes: 14 additions & 1 deletion)
@@ -420,6 +420,19 @@ def get_pserver_program(self, endpoint):
         pserver_program = Program()
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
+        for v in self.param_grad_ep_mapping[endpoint]["grads"]:
+            # create vars for each trainer in global scope, so
+            # we don't need to create them when grad arrives.
+            pserver_program.global_block().create_var(
+                name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
+            for trainer_id in xrange(self.trainers):
+                print("create variable for program: %s.trainer_%d" %
+                      (v.name, trainer_id))
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d" % (v.name, trainer_id),
+                    persistable=True,
+                    dtype=v.dtype,
+                    shape=v.shape)
         # step6
         optimize_sub_program = Program()
         for idx, opt_op in enumerate(self.optimize_ops):
@@ -449,7 +462,7 @@ def get_pserver_program(self, endpoint):
                     p.name
                     for p in self.param_grad_ep_mapping[endpoint]["grads"]
                 ],
-                "Trainers": self.trainers
+                "Fanin": self.trainers
             })
         pserver_program.sync_with_cpp()
         return pserver_program
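Pre-creating the <grad>.trainer_<id> slots here is the counterpart of the recv_op.cc switch from recv_scope.Var() to recv_scope.FindVar(): the server now treats a missing variable as a hard error instead of silently creating it when a gradient arrives. A C++ sketch of that contract, with a std::map standing in for framework::Scope:

```cpp
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> scope;  // stand-in for the pserver-side Scope
  const int trainers = 2;

  // Transpiler side: create one slot per trainer ahead of time.
  for (int id = 0; id < trainers; ++id) {
    scope["w@GRAD.trainer_" + std::to_string(id)] = 0;
  }

  // Server side (RecvOp): a lookup miss is now a hard error,
  // not a silent create-on-arrival.
  const std::string name = "w@GRAD.trainer_1";
  if (scope.find(name) == scope.end()) {
    std::cerr << "can not find server side var: " << name << "\n";
    return 1;
  }
  std::cout << "found pre-created slot: " << name << "\n";
  return 0;
}
```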
@@ -52,26 +52,27 @@
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)
 
-t = fluid.DistributeTranspiler()
-# all parameter server endpoints list for spliting parameters
-pserver_endpoints = os.getenv("PSERVERS")
-# server endpoint for current node
-current_endpoint = os.getenv("SERVER_ENDPOINT")
-# run as trainer or parameter server
+pserver_endpoints = os.getenv("PSERVERS")  # all pserver endpoints
+trainers = int(os.getenv("TRAINERS"))  # total trainer count
+current_endpoint = os.getenv("SERVER_ENDPOINT")  # current pserver endpoint
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
-t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
+t = fluid.DistributeTranspiler()
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
 
 if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
     pserver_prog = t.get_pserver_program(current_endpoint)
-    exe.run(fluid.default_startup_program())
+    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+    exe.run(pserver_startup)
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
     trainer_prog = t.get_trainer_program()
     feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
+    # TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
     exe.run(fluid.default_startup_program())
 
 for pass_id in range(PASS_NUM):
