Enhance distributed train performance #7608

Merged
3 changes: 0 additions & 3 deletions paddle/operators/detail/grpc_client.cc
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(var_name);

auto* var = scope.FindVar(var_name);
SerializeToMessage(var_name, var, ctx, &req);

// varhandle
VarHandle var_h;
var_h.ep = ep;
50 changes: 30 additions & 20 deletions paddle/operators/detail/grpc_server.cc
@@ -36,7 +36,10 @@ class RequestBase {

CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); }
virtual std::string GetReqName() {
assert(false);
return "";
}

protected:
grpc::ServerContext ctx_;
@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<char>* queue)
: RequestBase(service, cq),
responder_(&ctx_),
scope_(scope),
dev_ctx_(dev_ctx) {
dev_ctx_(dev_ctx),
queue_(queue) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
}

@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
queue_->Push('c');
}

protected:
@@ -108,8 +114,15 @@
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<char>* queue_;
};

void AsyncGRPCServer::WaitClientGet(int count) {
for (int i = 0; i < count; ++i) {
var_get_queue_.Pop();
}
}

void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
@@ -149,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
}

// This URL explains why shutdown is complicate:
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
void AsyncGRPCServer::ShutDown() {
server_->Shutdown();
ShutdownQueue();
@@ -170,10 +182,12 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if (is_shut_down_) {
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status();
}

// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
@@ -188,9 +202,9 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}

PADDLE_ENFORCE(tag);
if (wait && !done_) {
Wait();
}
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (cq_name == "cq_get") WaitCond(1);
if (cq_name == "cq_send") WaitCond(0);

RequestBase* base = (RequestBase*)tag;
// reference:
@@ -222,22 +236,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
}

void AsyncGRPCServer::Wait() {
std::unique_lock<std::mutex> lock(this->mutex_);
condition_.wait(lock, [=] { return this->done_ == true; });
}

void AsyncGRPCServer::Reset() {
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = false;
void AsyncGRPCServer::WaitCond(int cond) {
std::unique_lock<std::mutex> lock(this->barrier_mutex_);
barrier_condition_.wait(lock,
[=] { return this->barrier_cond_step_ == cond; });
}

void AsyncGRPCServer::Done() {
void AsyncGRPCServer::SetCond(int cond) {
{
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = true;
std::lock_guard<std::mutex> lock(this->barrier_mutex_);
barrier_cond_step_ = cond;
}
condition_.notify_all();
barrier_condition_.notify_all();
}

} // namespace detail
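The barrier introduced in this file has two parts: a condition step guarded by SetCond/WaitCond, which tells the completion-queue threads whether send requests (step 0) or get requests (step 1) may proceed, and the blocking queue var_get_queue_, into which every finished RequestGet pushes one token so that WaitClientGet(count) can block until all expected parameter fetches have completed. Below is a minimal, self-contained C++ sketch of that pattern, independent of Paddle's SimpleBlockQueue and gRPC types; all class and member names here are illustrative, not part of the actual code.

#include <condition_variable>
#include <mutex>
#include <queue>

// Stand-in for SimpleBlockQueue<char>: Push records one finished client Get,
// Pop blocks until such a token is available.
template <typename T>
class BlockingQueue {
 public:
  void Push(const T& v) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      q_.push(v);
    }
    cv_.notify_one();
  }
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !q_.empty(); });
    T v = q_.front();
    q_.pop();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
};

// Stand-in for the barrier_cond_step_ logic: SetCond moves the server to a
// new phase and wakes all waiters, WaitCond blocks until the phase matches.
class BarrierStep {
 public:
  void SetCond(int step) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      step_ = step;
    }
    cv_.notify_all();
  }
  void WaitCond(int step) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this, step] { return step_ == step; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int step_ = 0;
};

// Tiny single-threaded demo of the two primitives; in the server, SetCond is
// called by recv_op while WaitCond/Push/Pop run on the RPC handler threads.
int main() {
  BarrierStep barrier;
  BlockingQueue<char> get_queue;
  barrier.SetCond(0);
  barrier.WaitCond(0);  // returns immediately, phase already 0
  get_queue.Push('c');  // what RequestGet::Process does after Finish
  get_queue.Pop();      // what WaitClientGet does, once per expected Get
  return 0;
}

Compared with the old Reset()/Done() boolean flag, the integer step lets the same primitive gate both directions of the exchange instead of only signalling that the optimize pass finished.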
15 changes: 8 additions & 7 deletions paddle/operators/detail/grpc_server.h
@@ -41,9 +41,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {

void RunSyncUpdate();

void Reset();

void Done();
// functions to sync server barrier status.
void WaitCond(int cond);
void SetCond(int cond);
void WaitClientGet(int count);

void SetScope(framework::Scope *scope) { scope_ = scope; }

@@ -56,7 +57,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void ShutDown();

protected:
void Wait();
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne);
@@ -78,11 +78,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<char> var_get_queue_;

// condition of the sub program
std::mutex mutex_;
volatile mutable bool done_;
std::condition_variable condition_;
std::mutex barrier_mutex_;
mutable int barrier_cond_step_;
std::condition_variable barrier_condition_;

std::unique_ptr<std::thread> t_send_;
std::unique_ptr<std::thread> t_get_;
74 changes: 30 additions & 44 deletions paddle/operators/recv_op.cc
@@ -27,12 +27,17 @@ limitations under the License. */
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"

namespace paddle {
namespace operators {

constexpr int kCondStart = 0;
constexpr int kCondRunning = 1;
constexpr int kCondDone = 2;

void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
VLOG(4) << "RunServer thread end";
@@ -77,42 +82,41 @@ class RecvOp : public framework::OperatorBase {
if (grads_counter_.find(varname) == grads_counter_.end()) {
grads_counter_[varname] = 0;
}
char ret[256];
snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
grads_counter_[varname]++);
return std::string(ret);
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
}

void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
// FIXME(typhoonzero): no new scopes for every run.
framework::Scope &recv_scope = scope.NewScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();

// FIXME(Yancey1989): initialize rpc server with laze mode.
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers");
auto fan_in = Attr<int>("Fanin");
size_t param_count = param_list.size();

rpc_service_->Reset();
std::string program_str = Attr<std::string>("OptimizeProgram");
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(program_str);
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place);

// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
VLOG(4) << "param_count:" << param_count
<< " trainer_count:" << trainer_count;
int64_t barrier_size = param_count * fan_in;
while (!exit_flag) {
// TODO(gognwb): simply this loop.
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
for (size_t i = 0; i < param_count * trainer_count; ++i) {
// blocking get one var from client.
// Get from multiple trainers; we don't care about the order in which
// the gradients arrive, just add suffix 0~n and merge the gradients.
rpc_service_->SetCond(0);
for (size_t i = 0; i < barrier_size; ++i) {
Reviewer comment (Contributor): this block of code is too long and makes the main loop complex and error-prone; it would be better to extract it into a separate function.

const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
}
@@ -121,49 +125,31 @@
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
} else {
LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
<< "\"";
LOG(ERROR) << "grad have no paired param:" << grad_var_name;
}
VLOG(3) << "recved grad: " << grad_var_name
<< " updating param: " << param_var_name;

auto *merged_grad = recv_scope.FindVar(grad_var_name);
if (merged_grad == nullptr) {
auto *ptr = recv_scope.Var(grad_var_name);
CreateTensorFromMessageType(ptr, v.second.type());
VLOG(3) << "Create Variable " << grad_var_name
<< " on recv scope, which pointer is " << ptr << " type is "
<< v.second.type();
}

if (trainer_count > 1) {
if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}

auto *var = recv_scope.Var(grad_var_name);
auto *var = recv_scope.FindVar(grad_var_name);
if (var == nullptr) {
LOG(ERROR) << "can not find server side var: " << grad_var_name;
PADDLE_THROW("can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}

if (exit_flag) {
break;
}

rpc_service_->Reset();

std::string program_str = Attr<std::string>("OptimizeProgram");
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(program_str);
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place);
// Run sub graph to get optimized tensor
try {
executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}

rpc_service_->Done();
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(barrier_size);
grads_counter_.clear();
} // while(true)
}
@@ -199,7 +185,7 @@ This operator will recv tensor from send_op
"GradList", "type list of string",
"grad->param name mapping to find which param to optimize.")
.SetDefault({});
AddAttr<int>("Trainers", "type int",
AddAttr<int>("Fanin", "type int",
"Number of trainers in the current cluster job")
.SetDefault(1);
}
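Taken together, one pserver iteration of the rewritten Run() follows a fixed order: open the send phase, collect fan_in gradients per parameter (renaming duplicates with a ".trainer_<k>" suffix), run the optimize sub-program once, open the get phase, and then drain one queue token per expected Get RPC before starting over. The compilable outline below mirrors only that control flow; the helper names (CollectGradient, RunOptimizeProgram, the Fake* types) are hypothetical stand-ins for the Paddle machinery.

#include <cstdio>

// Hypothetical stand-ins so the outline compiles on its own.
struct FakeBarrier {
  void SetCond(int step) { std::printf("barrier step -> %d\n", step); }
};
struct FakeGetQueue {
  void Pop() { /* would block until one trainer finished a Get RPC */ }
};

static FakeBarrier rpc_barrier;
static FakeGetQueue get_queue;

static void CollectGradient(int i) {
  // In the real operator: rpc_service_->Get(), optional ".trainer_<k>"
  // renaming when fan_in > 1, then DeserializeFromMessage into the scope.
  std::printf("received gradient %d\n", i);
}

static void RunOptimizeProgram() {
  // In the real operator: executor.Run(OptimizeProgram) on the recv scope.
}

void ServeOneBarrier(int param_count, int fan_in) {
  const int barrier_size = param_count * fan_in;

  rpc_barrier.SetCond(0);  // phase 0: accept gradient sends
  for (int i = 0; i < barrier_size; ++i) {
    CollectGradient(i);
  }

  RunOptimizeProgram();  // merge gradients and update parameters

  rpc_barrier.SetCond(1);  // phase 1: serve updated parameters
  for (int i = 0; i < barrier_size; ++i) {
    get_queue.Pop();  // WaitClientGet: one token per completed Get
  }
}

int main() {
  ServeOneBarrier(/*param_count=*/2, /*fan_in=*/2);
  return 0;
}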
3 changes: 3 additions & 0 deletions paddle/operators/send_op.cc
@@ -41,10 +41,13 @@ class SendOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
PADDLE_ENFORCE(client_.Wait());

for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
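On the trainer side the matching protocol is two phases per batch: send every gradient asynchronously and wait, then issue one asynchronous Get per fetched parameter; each completed Get is what pushes a token into the server's var_get_queue_. A minimal sketch with a hypothetical client interface follows, only to make the phase ordering explicit; the second Wait is assumed, since the corresponding line falls outside the visible hunk.

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical asynchronous client; the real one is detail::RPCClient.
class FakeRPCClient {
 public:
  void AsyncSend(const std::string& ep, const std::string& var) {
    std::printf("send %s -> %s\n", var.c_str(), ep.c_str());
  }
  void AsyncGet(const std::string& ep, const std::string& var) {
    std::printf("get  %s <- %s\n", var.c_str(), ep.c_str());
  }
  bool Wait() { return true; }  // would block until all queued RPCs finish
};

void TrainerStep(FakeRPCClient* client, const std::string& ep,
                 const std::vector<std::string>& grads,
                 const std::vector<std::string>& params) {
  for (const auto& g : grads) client->AsyncSend(ep, g);  // phase 0: sends
  client->Wait();                                        // all sends done
  for (const auto& p : params) client->AsyncGet(ep, p);  // phase 1: gets
  client->Wait();                                        // all gets done (assumed)
}

int main() {
  FakeRPCClient client;
  TrainerStep(&client, "127.0.0.1:6174", {"w@GRAD", "b@GRAD"}, {"w", "b"});
  return 0;
}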

15 changes: 14 additions & 1 deletion python/paddle/v2/fluid/distribute_transpiler.py
@@ -420,6 +420,19 @@ def get_pserver_program(self, endpoint):
pserver_program = Program()
for v in self.param_grad_ep_mapping[endpoint]["params"]:
self._clone_var(pserver_program.global_block(), v)
for v in self.param_grad_ep_mapping[endpoint]["grads"]:
# create vars for each trainer in global scope, so
# we don't need to create them when grad arrives.
pserver_program.global_block().create_var(
name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
for trainer_id in xrange(self.trainers):
print("create variable for program: %s.trainer_%d" %
(v.name, trainer_id))
pserver_program.global_block().create_var(
name="%s.trainer_%d" % (v.name, trainer_id),
persistable=True,
dtype=v.dtype,
shape=v.shape)
# step6
optimize_sub_program = Program()
for idx, opt_op in enumerate(self.optimize_ops):
@@ -449,7 +462,7 @@ def get_pserver_program(self, endpoint):
p.name
for p in self.param_grad_ep_mapping[endpoint]["grads"]
],
"Trainers": self.trainers
"Fanin": self.trainers
})
pserver_program.sync_with_cpp()
return pserver_program
@@ -52,26 +52,27 @@
place = fluid.CPUPlace()
exe = fluid.Executor(place)

t = fluid.DistributeTranspiler()
# all parameter server endpoints list for spliting parameters
pserver_endpoints = os.getenv("PSERVERS")
# server endpoint for current node
current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
pserver_endpoints = os.getenv("PSERVERS") # all pserver endpoints
trainers = int(os.getenv("TRAINERS")) # total trainer count
current_endpoint = os.getenv("SERVER_ENDPOINT") # current pserver endpoint
training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver
t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
t = fluid.DistributeTranspiler()
t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)

if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
trainer_prog = t.get_trainer_program()
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
# TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
exe.run(fluid.default_startup_program())

for pass_id in range(PASS_NUM):
Expand Down