From bcfb82d33e431d621317f97d3c0703d9b002a8ee Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 15 Jan 2018 20:55:48 +0800 Subject: [PATCH 1/5] dist train support split selectedrows --- .../paddle/v2/fluid/distribute_transpiler.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index d17f9815cca5e..00fe3e68c9008 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -59,6 +59,51 @@ def split_dense_variable(var_list, return blocks +def split_selected_rows(var, + pserver_count, + min_block_size=1024, + max_block_size=1048576): + assert ((len(var.shape)) <= 1) + + split_count = pserver_count + indices = var.desc.selected_rows().dims() + var_width = reduce(lambda x, y: x * y, var.shape[1:]) + row_count = len(indices) + rows_per_block = 1 + if var_width < min_block_size: + rows_per_block = 1 + split_count = row_count + else: + rows_per_block = row_count / pserver_count + if not rows_per_block % pserver_count: + rows_per_block += 1 + split_count = row_count / rows_per_block + if not row_count % rows_per_block: + split_count += 1 + blocks = [] + for block_id in xrange(split_count): + curr_block_rows = min(rows_per_block, + row_count - (block_id * rows_per_block)) + block = VarBlock(var.name, block_id, curr_block_rows) + blocks.append(block) + return blocks + + +def split_variable(var_list, + pserver_count, + min_block_size=1024, + max_block_size=1048576): + for var in var_list: + if var.type == core.VarDesc.VarType.LOD_TENSOR: + split_dense_variable(var_list, pserver_count, min_block_size, + max_block_size) + elif var.type == core.VarDesc.VarType.SELECTED_ROWS: + split_selected_rows(var_list, pserver_count, min_block_size, + max_block_size) + else: + raise TypeError("variable must be lodtensor or selected rows") + + class DistributeTranspiler: def transpile(self, optimize_ops, From 02ea349101662e5ad5199dac47b48f1835eda361 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 17 Jan 2018 18:02:45 +0800 Subject: [PATCH 2/5] enhance dist train performance --- paddle/operators/detail/grpc_client.cc | 5 +- paddle/operators/detail/grpc_client.h | 2 +- paddle/operators/recv_op.cc | 66 ++++++++----------- paddle/operators/send_op.cc | 6 +- .../paddle/v2/fluid/distribute_transpiler.py | 15 ++++- .../notest_recognize_digits_conv_dist.py | 17 ++--- 6 files changed, 55 insertions(+), 56 deletions(-) diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc index 5a4db2d7e686c..521760228b5d7 100644 --- a/paddle/operators/detail/grpc_client.cc +++ b/paddle/operators/detail/grpc_client.cc @@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, sendrecv::VariableMessage req; req.set_varname(var_name); - auto* var = scope.FindVar(var_name); - SerializeToMessage(var_name, var, ctx, &req); - // varhandle VarHandle var_h; var_h.ep = ep; @@ -87,7 +84,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool RPCClient::wait() { +bool RPCClient::Wait() { bool ok = true; while (true) { diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h index d27b5ced9ece6..a62e70a2533ae 100644 --- a/paddle/operators/detail/grpc_client.h +++ b/paddle/operators/detail/grpc_client.h @@ -130,7 +130,7 @@ class RPCClient { const framework::Scope& scope, const std::string& var_name, int64_t time_out = 600 * 1000); - bool wait(); + bool Wait(); 
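  // Note on the renamed Wait() (previously wait()): it blocks until every
  // queued Async{Send,Get}Variable request has finished. In the send_op.cc
  // change further down, the calling pattern becomes: queue all
  // AsyncSendVariable calls, Wait(), then queue all AsyncGetVariable calls
  // and Wait() again, so the parameter fetch only starts after the gradient
  // sends have completed. AsyncGetVariable itself (grpc_client.cc above) now
  // only puts the variable name into the request instead of serializing the
  // local tensor.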
private: bool Proceed(); diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 55b33343af438..dea7db391cf56 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -27,6 +27,7 @@ limitations under the License. */ #include "paddle/operators/detail/grpc_server.h" #include "paddle/operators/detail/sendrecvop_utils.h" #include "paddle/operators/detail/simple_block_queue.h" +#include "paddle/string/printf.h" #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" @@ -77,35 +78,37 @@ class RecvOp : public framework::OperatorBase { if (grads_counter_.find(varname) == grads_counter_.end()) { grads_counter_[varname] = 0; } - char ret[256]; - snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(), - grads_counter_[varname]++); - return std::string(ret); + return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++); } void Run(const framework::Scope &scope, const platform::Place &dev_place) const override { - // FIXME(typhoonzero): no new scopes for every run. + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); framework::Scope &recv_scope = scope.NewScope(); rpc_service_->SetScope(&recv_scope); auto param_list = Attr>("ParamList"); auto grad_list = Attr>("GradList"); - auto trainer_count = Attr("Trainers"); + auto fan_in = Attr("Fanin"); size_t param_count = param_list.size(); + std::string program_str = Attr("OptimizeProgram"); + framework::proto::ProgramDesc program_desc; + program_desc.ParseFromString(program_str); + framework::ProgramDesc program(program_desc); + framework::Executor executor(dev_place); + rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; while (!exit_flag) { - // TODO(gognwb): simply this loop. - // Get from multiple trainers, we don't care about order in which - // the gradient arrives, just add suffix 0~n then average the gradient. - for (size_t i = 0; i < param_count * trainer_count; ++i) { - // blocking get one var from client. + // Get from multiple trainers, we don't care about the order in which + // the gradients arrives, just add suffix 0~n and merge the gradient. + for (size_t i = 0; i < param_count * fan_in; ++i) { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { - VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit"; + LOG(INFO) << "received terminate message and exit"; exit_flag = true; break; } @@ -114,44 +117,27 @@ class RecvOp : public framework::OperatorBase { if (it != grad_list.end()) { param_var_name = param_list[it - grad_list.begin()]; } else { - LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name - << "\""; + LOG(ERROR) << "grad have no paired param:" << grad_var_name; } VLOG(3) << "recved grad: " << grad_var_name << " updating param: " << param_var_name; - - auto *merged_grad = recv_scope.FindVar(grad_var_name); - if (merged_grad == nullptr) { - auto *ptr = recv_scope.Var(grad_var_name); - CreateTensorFromMessageType(ptr, v.second.type()); - VLOG(3) << "Create Variable " << grad_var_name - << " on recv scope, which pointer is " << ptr << " type is " - << v.second.type(); + // Assume grad_var_name must appear in global scope. 
+ std::string grad_var_name_trainer; + if (fan_in > 1) { + grad_var_name_trainer = this->GetGradVarNameForTrainer(grad_var_name); } - - if (trainer_count > 1) { - grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); + auto *var = recv_scope.FindVar(grad_var_name_trainer); + if (var == nullptr) { + LOG(ERROR) << "can not find server side var: " + << grad_var_name_trainer; + PADDLE_THROW("can not find server side var"); } - - auto *var = recv_scope.Var(grad_var_name); - platform::DeviceContextPool &pool = - platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); detail::DeserializeFromMessage(v.second, dev_ctx, var); } - if (exit_flag) { break; } - rpc_service_->Reset(); - - std::string program_str = Attr("OptimizeProgram"); - framework::proto::ProgramDesc program_desc; - program_desc.ParseFromString(program_str); - framework::ProgramDesc program(program_desc); - framework::Executor executor(dev_place); - // Run sub graph to get optimized tensor try { executor.Run(program, &recv_scope, 0, /*global_block*/ false /*create_local_scope*/, false /*create_vars*/); @@ -195,7 +181,7 @@ This operator will recv tensor from send_op "GradList", "type list of string", "grad->param name mapping to find which param to optimize.") .SetDefault({}); - AddAttr("Trainers", "type int", + AddAttr("Fanin", "type int", "Number of trainers in the current cluster job") .SetDefault(1); } diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 4d145250bdc73..d65153c1fdb5b 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -41,14 +41,16 @@ class SendOp : public framework::OperatorBase { // FIXME(gongwb): DeviceContext? auto ctx = platform::CPUDeviceContext(); for (size_t i = 0; i < ins.size(); i++) { + VLOG(3) << "sending " << ins[i]; client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); } + client_.Wait(); for (size_t i = 0; i < outs.size(); i++) { + VLOG(3) << "getting " << outs[i]; client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - - client_.wait(); + client_.Wait(); } private: diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 00fe3e68c9008..9876296a37ae1 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -452,6 +452,19 @@ def get_pserver_program(self, endpoint, optimize_ops): pserver_program = Program() for v in self.param_grad_ep_mapping[endpoint]["params"]: self._clone_var(pserver_program.global_block(), v) + for v in self.param_grad_ep_mapping[endpoint]["grads"]: + # create vars for each trainer in global scope, so + # we don't need to create them when grad arrives. 
+ pserver_program.global_block().create_var( + name=v.name, persistable=True, dtype=v.dtype, shape=v.shape) + for trainer_id in xrange(self.trainers): + print("create variable for program: %s.trainer_%d" % + (v.name, trainer_id)) + pserver_program.global_block().create_var( + name="%s.trainer_%d" % (v.name, trainer_id), + persistable=True, + dtype=v.dtype, + shape=v.shape) # step6 optimize_sub_program = Program() for idx, opt_op in enumerate(optimize_ops): @@ -481,7 +494,7 @@ def get_pserver_program(self, endpoint, optimize_ops): p.name for p in self.param_grad_ep_mapping[endpoint]["grads"] ], - "Trainers": self.trainers + "Fanin": self.trainers }) pserver_program.sync_with_cpp() return pserver_program diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py index 20b4a8b34cd08..e563e0ddc5d79 100644 --- a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py +++ b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py @@ -39,26 +39,27 @@ place = fluid.CPUPlace() exe = fluid.Executor(place) -t = fluid.DistributeTranspiler() -# all parameter server endpoints list for spliting parameters -pserver_endpoints = os.getenv("PSERVERS") -# server endpoint for current node -current_endpoint = os.getenv("SERVER_ENDPOINT") -# run as trainer or parameter server +pserver_endpoints = os.getenv("PSERVERS") # all pserver endpoints +trainers = int(os.getenv("TRAINERS")) # total trainer count +current_endpoint = os.getenv("SERVER_ENDPOINT") # current pserver endpoint training_role = os.getenv("TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver -t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2) +t = fluid.DistributeTranspiler() +t.transpile( + optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers) if training_role == "PSERVER": if not current_endpoint: print("need env SERVER_ENDPOINT") exit(1) pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) - exe.run(fluid.default_startup_program()) + pserver_startup = t.get_startup_program(current_endpoint, pserver_prog) + exe.run(pserver_startup) exe.run(pserver_prog) elif training_role == "TRAINER": trainer_prog = t.get_trainer_program() feeder = fluid.DataFeeder(feed_list=[images, label], place=place) + # TODO(typhoonzero): change trainer startup program to fetch parameters from pserver exe.run(fluid.default_startup_program()) for pass_id in range(PASS_NUM): From ae19d2ea1ecd28db7f5704da4cb07c59e038e195 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 18 Jan 2018 18:27:32 +0800 Subject: [PATCH 3/5] fix comm issues --- paddle/operators/detail/grpc_server.cc | 47 +++++++++++++++----------- paddle/operators/detail/grpc_server.h | 15 ++++---- paddle/operators/recv_op.cc | 15 +++++--- 3 files changed, 48 insertions(+), 29 deletions(-) diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index c0b94746a0b7f..42d3cc57584d9 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -36,7 +36,10 @@ class RequestBase { CallStatus Status() { return status_; } void SetStatus(CallStatus status) { status_ = status; } - virtual std::string GetReqName() { assert(false); } + virtual std::string GetReqName() { + assert(false); + return ""; + } protected: grpc::ServerContext ctx_; @@ -80,11 +83,13 @@ class RequestGet final : public RequestBase { 
public: explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, grpc::ServerCompletionQueue* cq, framework::Scope* scope, - const platform::DeviceContext* dev_ctx) + const platform::DeviceContext* dev_ctx, + SimpleBlockQueue* queue) : RequestBase(service, cq), responder_(&ctx_), scope_(scope), - dev_ctx_(dev_ctx) { + dev_ctx_(dev_ctx), + queue_(queue) { service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); } @@ -100,6 +105,7 @@ class RequestGet final : public RequestBase { // TODO(gongwb): check var's info. responder_.Finish(reply_, grpc::Status::OK, this); status_ = FINISH; + queue_->Push('c'); } protected: @@ -108,8 +114,15 @@ class RequestGet final : public RequestBase { ServerAsyncResponseWriter responder_; framework::Scope* scope_; const platform::DeviceContext* dev_ctx_; + SimpleBlockQueue* queue_; }; +void AsyncGRPCServer::WaitClientGet(int count) { + for (int i = 0; i < count; ++i) { + var_get_queue_.Pop(); + } +} + void AsyncGRPCServer::RunSyncUpdate() { grpc::ServerBuilder builder; builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); @@ -170,7 +183,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { if (is_shut_down_) { return; } - RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_); + RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, + &var_get_queue_); VLOG(4) << "create Requestget status:" << get->Status(); } @@ -188,9 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, } PADDLE_ENFORCE(tag); - if (wait && !done_) { - Wait(); - } + if (cq_name == "cq_get") WaitCond(2); + if (cq_name == "cq_send") WaitCond(0); RequestBase* base = (RequestBase*)tag; // reference: @@ -222,22 +235,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, } } -void AsyncGRPCServer::Wait() { - std::unique_lock lock(this->mutex_); - condition_.wait(lock, [=] { return this->done_ == true; }); -} - -void AsyncGRPCServer::Reset() { - std::lock_guard lock(this->mutex_); - done_ = false; +void AsyncGRPCServer::WaitCond(int cond) { + std::unique_lock lock(this->barrier_mutex_); + barrier_condition_.wait(lock, + [=] { return this->barrier_cond_step_ == cond; }); } -void AsyncGRPCServer::Done() { +void AsyncGRPCServer::SetCond(int cond) { { - std::lock_guard lock(this->mutex_); - done_ = true; + std::lock_guard lock(this->barrier_mutex_); + barrier_cond_step_ = cond; } - condition_.notify_all(); + barrier_condition_.notify_all(); } } // namespace detail diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 2c078b7777165..5c7be5f5bd256 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -41,9 +41,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void RunSyncUpdate(); - void Reset(); - + // functions to sync server barrier status. 
+ void WaitStart(); + void WaitDone(); + void Start(); void Done(); + void WaitClientGet(int count); void SetScope(framework::Scope *scope) { scope_ = scope; } @@ -56,7 +59,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void ShutDown(); protected: - void Wait(); void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, std::string cq_name, std::function TryToRegisterNewOne); @@ -78,11 +80,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { const platform::DeviceContext *dev_ctx_; // received variable from RPC, operators fetch variable from this queue. SimpleBlockQueue var_recv_queue_; + SimpleBlockQueue var_get_queue_; // condition of the sub program - std::mutex mutex_; - volatile mutable bool done_; - std::condition_variable condition_; + std::mutex barrier_mutex_; + mutable int barrier_cond_step_; + std::condition_variable barrier_condition_; std::unique_ptr t_send_; std::unique_ptr t_get_; diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index b77d150dccfbe..2ecd56671f1c4 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -34,6 +34,10 @@ limitations under the License. */ namespace paddle { namespace operators { +constexpr int kCondStart = 0; +constexpr int kCondRunning = 1; +constexpr int kCondDone = 2; + void RunServer(std::shared_ptr service) { service->RunSyncUpdate(); VLOG(4) << "RunServer thread end"; @@ -101,12 +105,14 @@ class RecvOp : public framework::OperatorBase { framework::ProgramDesc program(program_desc); framework::Executor executor(dev_place); - rpc_service_->Reset(); + // rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; while (!exit_flag) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. 
+ rpc_service_->SetCond(kCondStart); + VLOG(3) << "================ start get from service ==========="; for (size_t i = 0; i < param_count * fan_in; ++i) { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; @@ -139,15 +145,16 @@ class RecvOp : public framework::OperatorBase { if (exit_flag) { break; } - rpc_service_->Reset(); + // rpc_service_->Reset(); try { executor.Run(program, &recv_scope, 0, /*global_block*/ false /*create_local_scope*/, false /*create_vars*/); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } - - rpc_service_->Done(); + VLOG(3) << "================ run sub program end ==========="; + rpc_service_->SetCond(kCondDone); + rpc_service_->WaitClientGet(param_count * fan_in); grads_counter_.clear(); } // while(true) } From 5f4d9130f01833dfef44dac2eadb7089accbe0ba Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 18 Jan 2018 19:27:20 +0800 Subject: [PATCH 4/5] merge codes --- paddle/operators/detail/grpc_server.cc | 5 +++-- paddle/operators/detail/grpc_server.h | 6 ++---- paddle/operators/recv_op.cc | 15 +++++---------- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index 42d3cc57584d9..3ddcd839bdd23 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -162,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() { } // This URL explains why shutdown is complicate: -// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c void AsyncGRPCServer::ShutDown() { server_->Shutdown(); ShutdownQueue(); @@ -188,6 +187,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { VLOG(4) << "create Requestget status:" << get->Status(); } +// FIXME(typhoonzero): remove wait argument and change cq_name to enum. void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, std::string cq_name, std::function TryToRegisterNewOne) { @@ -202,7 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, } PADDLE_ENFORCE(tag); - if (cq_name == "cq_get") WaitCond(2); + // FIXME(typhoonzero): de-couple the barriers with recv_op + if (cq_name == "cq_get") WaitCond(1); if (cq_name == "cq_send") WaitCond(0); RequestBase* base = (RequestBase*)tag; diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 5c7be5f5bd256..1ca9086c744c5 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -42,10 +42,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void RunSyncUpdate(); // functions to sync server barrier status. - void WaitStart(); - void WaitDone(); - void Start(); - void Done(); + void WaitCond(int cond); + void SetCond(int cond); void WaitClientGet(int count); void SetScope(framework::Scope *scope) { scope_ = scope; } diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 2ecd56671f1c4..8d1479bdd6311 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase { framework::ProgramDesc program(program_desc); framework::Executor executor(dev_place); - // rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. 
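    // Shape of one pserver iteration of the loop below, as wired up by this
    // patch series (SetCond/WaitCond/WaitClientGet are in grpc_server.cc):
    //   SetCond(0)                  - release the "send" completion queue so
    //                                 trainers' gradient sends get served;
    //   Get() x barrier_size        - pop param_count * fan_in gradients,
    //                                 renamed per trainer when fan_in > 1;
    //   executor.Run(program)       - run the OptimizeProgram on the merged
    //                                 gradients;
    //   SetCond(1)                  - release the "get" completion queue so
    //                                 trainers can fetch updated parameters;
    //   WaitClientGet(barrier_size) - block until that many GetVariable
    //                                 responses have been served.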
bool exit_flag = false; + int64_t barrier_size = param_count * fan_in; while (!exit_flag) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. - rpc_service_->SetCond(kCondStart); - VLOG(3) << "================ start get from service ==========="; - for (size_t i = 0; i < param_count * fan_in; ++i) { + rpc_service_->SetCond(0); + for (size_t i = 0; i < barrier_size; ++i) { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { @@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase { } VLOG(3) << "recved grad: " << grad_var_name << " updating param: " << param_var_name; - // Assume grad_var_name must appear in global scope. - std::string grad_var_name_trainer; if (fan_in > 1) { grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); } @@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase { if (exit_flag) { break; } - // rpc_service_->Reset(); try { executor.Run(program, &recv_scope, 0, /*global_block*/ false /*create_local_scope*/, false /*create_vars*/); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } - VLOG(3) << "================ run sub program end ==========="; - rpc_service_->SetCond(kCondDone); - rpc_service_->WaitClientGet(param_count * fan_in); + rpc_service_->SetCond(1); + rpc_service_->WaitClientGet(barrier_size); grads_counter_.clear(); } // while(true) } From 30529e314e7e9bdce78aa0adf9667da3fe9977cb Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 18 Jan 2018 20:02:26 +0800 Subject: [PATCH 5/5] delete debug transpiler code --- .../paddle/v2/fluid/distribute_transpiler.py | 45 ------------------- 1 file changed, 45 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 3cba015fc5250..13d2bb8325200 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -72,51 +72,6 @@ def split_dense_variable(var_list, return blocks -def split_selected_rows(var, - pserver_count, - min_block_size=1024, - max_block_size=1048576): - assert ((len(var.shape)) <= 1) - - split_count = pserver_count - indices = var.desc.selected_rows().dims() - var_width = reduce(lambda x, y: x * y, var.shape[1:]) - row_count = len(indices) - rows_per_block = 1 - if var_width < min_block_size: - rows_per_block = 1 - split_count = row_count - else: - rows_per_block = row_count / pserver_count - if not rows_per_block % pserver_count: - rows_per_block += 1 - split_count = row_count / rows_per_block - if not row_count % rows_per_block: - split_count += 1 - blocks = [] - for block_id in xrange(split_count): - curr_block_rows = min(rows_per_block, - row_count - (block_id * rows_per_block)) - block = VarBlock(var.name, block_id, curr_block_rows) - blocks.append(block) - return blocks - - -def split_variable(var_list, - pserver_count, - min_block_size=1024, - max_block_size=1048576): - for var in var_list: - if var.type == core.VarDesc.VarType.LOD_TENSOR: - split_dense_variable(var_list, pserver_count, min_block_size, - max_block_size) - elif var.type == core.VarDesc.VarType.SELECTED_ROWS: - split_selected_rows(var_list, pserver_count, min_block_size, - max_block_size) - else: - raise TypeError("variable must be lodtensor or selected rows") - - class DistributeTranspiler: def transpile(self, optimize_ops,
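
# A minimal sketch of the per-trainer variable naming that [PATCH 2/5] sets up
# on each pserver: for every gradient variable the transpiler creates one
# shared copy plus one "<name>.trainer_<k>" copy per trainer, matching the
# names recv_op builds with GetGradVarNameForTrainer. The "@GRAD" names below
# are made up for the example and are not taken from the patches.
def pserver_grad_var_names(grad_names, trainers):
    names = []
    for g in grad_names:
        names.append(g)  # shared variable, same name as the trainer-side grad
        for k in range(trainers):
            names.append("%s.trainer_%d" % (g, k))  # per-trainer copy
    return names

# pserver_grad_var_names(["fc_0.w_0@GRAD"], 2) ->
# ['fc_0.w_0@GRAD', 'fc_0.w_0@GRAD.trainer_0', 'fc_0.w_0@GRAD.trainer_1']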