Skip to content

Commit

Permalink
Merge pull request #6297 from typhoonzero/simple_dist_train_api
Browse files Browse the repository at this point in the history
[Done] API for dist train
  • Loading branch information
typhoonzero authored Dec 22, 2017
2 parents a1cfc32 + 5913e73 commit 8d6db25
Show file tree
Hide file tree
Showing 19 changed files with 618 additions and 119 deletions.
15 changes: 15 additions & 0 deletions paddle/framework/block_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ OpDesc *BlockDesc::PrependOp() {
return ops_.front().get();
}

void BlockDesc::RemoveOp(size_t s, size_t e) {
  // Removes the ops in the half-open index range [s, e) from this block.
  //
  // Validate the range before forming any iterator: e may legitimately equal
  // ops_.size() when removing through the end of the block, while any
  // e > ops_.size() (or an empty/inverted range) must be rejected, since
  // ops_.begin() + e past the end is undefined behavior.
  if (s >= e || e > ops_.size()) {
    return;
  }
  need_update_ = true;
  for (auto it = ops_.begin() + s; it != ops_.begin() + e; ++it) {
    // TODO(typhoonzero): delete vars if no other op uses them.
    for (const auto &n : (*it)->InputArgumentNames()) {
      VLOG(3) << "deleting var " << n;
    }
  }
  ops_.erase(ops_.begin() + s, ops_.begin() + e);
}

std::vector<OpDesc *> BlockDesc::AllOps() const {
std::vector<OpDesc *> res;
for (const auto &op : ops_) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/block_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class BlockDesc {

OpDesc *PrependOp();

void RemoveOp(size_t s, size_t e);

std::vector<OpDesc *> AllOps() const;

size_t OpSize() const { return ops_.size(); }
Expand Down
50 changes: 26 additions & 24 deletions paddle/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope) {
bool create_local_scope, bool create_vars) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
Expand All @@ -74,33 +74,35 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
auto& device = device_contexts_[0];

Scope* local_scope = scope;
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
if (create_vars) {
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}

if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}

if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
} // if (create_local_scope)
} // if (create_vars)

for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
Expand Down
3 changes: 2 additions & 1 deletion paddle/framework/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ class Executor {
* ProgramDesc
* Scope
*/
void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true);
void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true,
bool create_vars = true);

private:
std::vector<const platform::DeviceContext*> device_contexts_;
Expand Down
50 changes: 41 additions & 9 deletions paddle/operators/detail/recv_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,57 @@ namespace detail {

Status SendRecvServerImpl::SendVariable(ServerContext *context,
const VariableMessage *in_var,
VariableMessage *out_var) {
framework::LoDTensor t;
// TODO(typhoonzero): deserialize in_tensor and run pserver network.
VoidMessage *out_var) {
// TODO(typhoonzero): support different variable types.
std::istringstream iss(in_var->serialized());
framework::LoDTensor t;
framework::DeserializeFromStream(iss, &t);
lodtensor_queue_.Push(std::move(t));
// Block until the sub graph is done.
t = lodtensor_return_queue_.Pop();
TensorWithName tensor_with_name =
std::make_pair(in_var->varname(), std::move(t));

var_recv_queue_.Push(std::move(tensor_with_name));
return Status::OK;
}

Status SendRecvServerImpl::GetVariable(ServerContext *context,
const VariableMessage *in_var,
VariableMessage *out_var) {
std::string get_var_name = in_var->varname();
auto *var = scope_->FindVar(get_var_name);
auto tensor = var->Get<framework::LoDTensor>();
std::ostringstream oss;
// FIXME(typhoonzero): get context from op.
framework::SerializeToStream(oss, t, platform::CPUDeviceContext());
framework::SerializeToStream(oss, tensor, platform::CPUDeviceContext());

std::string *varname = out_var->mutable_varname();
*varname = in_var->varname();
*varname = get_var_name;
std::string *serialized = out_var->mutable_serialized();
*serialized = oss.str();
return Status::OK;
}

Status SendRecvServerImpl::Wait(ServerContext *context,
                                const VoidMessage *in_var,
                                VoidMessage *out_var) {
  // Block the calling RPC until Done() sets done_ and signals the condition.
  std::unique_lock<std::mutex> guard(this->mutex_);
  while (!this->done_) {
    condition_.wait(guard);
  }
  return Status::OK;
}

void SendRecvServerImpl::Reset() {
  // Clear the completion flag under the lock so subsequent Wait() calls
  // block until the next Done().
  std::lock_guard<std::mutex> guard(this->mutex_);
  this->done_ = false;
}

void SendRecvServerImpl::Done() {
  // Set the completion flag inside the critical section, then notify with
  // the lock released so woken waiters can re-acquire it immediately.
  {
    std::lock_guard<std::mutex> guard(this->mutex_);
    this->done_ = true;
  }
  condition_.notify_all();
}

} // namespace detail
} // namespace operators
} // namespace paddle
31 changes: 27 additions & 4 deletions paddle/operators/detail/send_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ namespace operators {
namespace detail {

bool RPCClient::SendVariable(const framework::Scope& scope,
const std::string& inname,
const std::string& outname) {
const std::string& inname) {
ClientContext context;
VariableMessage msg, out_msg;
VariableMessage msg;
VoidMessage out_msg;
// FIXME(typhoonzero): pass device context to here.
auto ctx = platform::CPUDeviceContext();
auto* var = scope.FindVar(inname);
Expand All @@ -37,9 +37,26 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
msg.set_serialized(oss.str());
Status status = stub_->SendVariable(&context, msg, &out_msg);
if (!status.ok()) {
LOG(ERROR) << "gRPC error: " << status.error_message();
return false;
}
std::istringstream iss(out_msg.serialized());
return true;
}

bool RPCClient::GetVariable(const framework::Scope& scope,
const std::string& outname) {
ClientContext context;
VariableMessage call_msg, ret_msg;
call_msg.set_varname(outname);
auto ctx = platform::CPUDeviceContext();
Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
if (!status.ok()) {
LOG(ERROR) << "gRPC error: " << status.error_message();
return false;
}

std::istringstream iss(ret_msg.serialized());

framework::LoDTensor ret_tensor;
framework::DeserializeFromStream(iss, &ret_tensor);
auto* outvar = scope.FindVar(outname);
Expand All @@ -49,6 +66,12 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
return true;
}

void RPCClient::Wait() {
ClientContext context;
VoidMessage call_msg, ret_msg;
stub_->Wait(&context, call_msg, &ret_msg);
}

} // namespace detail
} // namespace operators
} // namespace paddle
7 changes: 6 additions & 1 deletion paddle/operators/detail/send_recv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ package sendrecv;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
rpc SendVariable(VariableMessage) returns (VariableMessage) {}
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// wait for one execution of the program
rpc Wait(VoidMessage) returns (VoidMessage) {}
}

// VariableMessage is serialized paddle variable message.
Expand Down
37 changes: 21 additions & 16 deletions paddle/operators/detail/send_recv_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
#include "paddle/framework/selected_rows.h"
#include "paddle/operators/detail/simple_block_queue.h"

// #include <grpc++/channel.h>
// #include <grpc++/client_context.h>
// #include <grpc++/create_channel.h>
// #include <grpc++/security/credentials.h>
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"

Expand All @@ -48,24 +44,32 @@ namespace paddle {
namespace operators {
namespace detail {

typedef std::pair<std::string, framework::LoDTensor> TensorWithName;

class SendRecvServerImpl final : public SendRecvService::Service {
public:
explicit SendRecvServerImpl() {}

Status SendVariable(ServerContext *context, const VariableMessage *in_var,
VariableMessage *out_var) override;

const framework::LoDTensor Get() { return this->lodtensor_queue_.Pop(); }
VoidMessage *out_var) override;
Status GetVariable(ServerContext *context, const VariableMessage *in_var,
VariableMessage *out_var) override;
Status Wait(ServerContext *context, const VoidMessage *in_var,
VoidMessage *out_var) override;
void Reset();
void Done();
void SetScope(framework::Scope *scope) { scope_ = scope; };

void Push(const framework::LoDTensor &tensor) {
this->lodtensor_return_queue_.Push(tensor);
}
const TensorWithName Get() { return this->var_recv_queue_.Pop(); }

private:
SimpleBlockQueue<framework::LoDTensor> lodtensor_queue_;
SimpleBlockQueue<framework::LoDTensor> lodtensor_return_queue_;
SimpleBlockQueue<framework::SelectedRows> selected_rows_queue_;
SimpleBlockQueue<framework::SelectedRows> selected_rows_return_queue_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<TensorWithName> var_recv_queue_;
framework::Scope *scope_;
// condition of the sub program
std::mutex mutex_;
bool done_;
std::condition_variable condition_;
};

// RPCClient is a class to send tensors to pserver sub-network
Expand All @@ -75,8 +79,9 @@ class RPCClient {
RPCClient(std::shared_ptr<Channel> channel)
: stub_(SendRecvService::NewStub(channel)) {}

bool SendVariable(const framework::Scope &scope, const std::string &inname,
const std::string &outname);
bool SendVariable(const framework::Scope &scope, const std::string &inname);
bool GetVariable(const framework::Scope &scope, const std::string &outname);
void Wait();

private:
std::unique_ptr<SendRecvService::Stub> stub_;
Expand Down
Loading

0 comments on commit 8d6db25

Please sign in to comment.