[Done] API for dist train #6297

Merged (25 commits) on Dec 22, 2017

Changes from all commits
15 changes: 15 additions & 0 deletions paddle/framework/block_desc.cc
@@ -90,6 +90,21 @@ OpDesc *BlockDesc::PrependOp() {
  return ops_.front().get();
}

void BlockDesc::RemoveOp(size_t s, size_t e) {
  if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
    return;
  }
  need_update_ = true;
  for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) {
    auto names = (*it)->InputArgumentNames();
    for (auto n : names) {
      // TODO(typhoonzero): delete vars if no other op use it.
      VLOG(3) << "deleting var " << n;
    }
  }
  ops_.erase(ops_.begin() + s, ops_.begin() + e);
}

std::vector<OpDesc *> BlockDesc::AllOps() const {
  std::vector<OpDesc *> res;
  for (const auto &op : ops_) {
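For context, a minimal usage sketch of the new RemoveOp API (not part of this diff); the caller, the concrete BlockDesc, and the op range [2, 5) are hypothetical.

#include "paddle/framework/block_desc.h"

// Drops the ops at positions [2, 5) from `block` in one call. Note that the
// erased ops' input variables are currently only logged by RemoveOp (see the
// TODO above); they are not yet removed from the block.
void DropOpRange(paddle::framework::BlockDesc *block) {
  block->RemoveOp(2, 5);
}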
2 changes: 2 additions & 0 deletions paddle/framework/block_desc.h
@@ -79,6 +79,8 @@ class BlockDesc {

  OpDesc *PrependOp();

  void RemoveOp(size_t s, size_t e);

  std::vector<OpDesc *> AllOps() const;

  size_t OpSize() const { return ops_.size(); }
50 changes: 26 additions & 24 deletions paddle/framework/executor.cc
@@ -65,7 +65,7 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars) {
  // TODO(tonyyang-svail):
  // - only runs on the first device (i.e. no interdevice communication)
  // - will change to use multiple blocks for RNN op and Cond Op
@@ -74,33 +74,35 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
  auto& device = device_contexts_[0];

  Scope* local_scope = scope;
  if (create_vars) {
    if (create_local_scope) {
      local_scope = &scope->NewScope();
      for (auto& var : block.AllVars()) {
        if (var->Name() == framework::kEmptyVarName) {
          continue;
        }

        if (var->Persistable()) {
          auto* ptr = scope->Var(var->Name());
          CreateTensor(ptr, var->GetType());
          VLOG(3) << "Create Variable " << var->Name()
                  << " global, which pointer is " << ptr;
        } else {
          auto* ptr = local_scope->Var(var->Name());
          CreateTensor(ptr, var->GetType());
          VLOG(3) << "Create Variable " << var->Name()
                  << " locally, which pointer is " << ptr;
        }
      }
    } else {
      for (auto& var : block.AllVars()) {
        auto* ptr = local_scope->Var(var->Name());
        CreateTensor(ptr, var->GetType());
        VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
                << ptr;
      }
    }  // if (create_local_scope)
  }  // if (create_vars)

  for (auto& op_desc : block.AllOps()) {
    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
3 changes: 2 additions & 1 deletion paddle/framework/executor.h
@@ -114,7 +114,8 @@ class Executor {
   * ProgramDesc
   * Scope
   */
  void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true,
           bool create_vars = true);

 private:
  std::vector<const platform::DeviceContext*> device_contexts_;
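A usage sketch of the new create_vars flag (not from this PR): run the same block twice on one scope, creating the variables only on the first pass. The executor, program, and scope are assumed to be set up elsewhere, and the argument order follows the declaration above.

#include "paddle/framework/executor.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"

// The first run creates block 0's variables directly in `scope`; the second
// run reuses them and skips creation (create_vars = false).
void RunTwice(paddle::framework::Executor *executor,
              const paddle::framework::ProgramDesc &program,
              paddle::framework::Scope *scope) {
  executor->Run(program, scope, /*block_id=*/0,
                /*create_local_scope=*/false, /*create_vars=*/true);
  executor->Run(program, scope, /*block_id=*/0,
                /*create_local_scope=*/false, /*create_vars=*/false);
}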
50 changes: 41 additions & 9 deletions paddle/operators/detail/recv_impl.cc
@@ -20,25 +20,57 @@ namespace detail {

Status SendRecvServerImpl::SendVariable(ServerContext *context,
                                        const VariableMessage *in_var,
                                        VoidMessage *out_var) {
  // TODO(typhoonzero): support different variable types.
  std::istringstream iss(in_var->serialized());
  framework::LoDTensor t;
  framework::DeserializeFromStream(iss, &t);
  TensorWithName tensor_with_name =
      std::make_pair(in_var->varname(), std::move(t));

  var_recv_queue_.Push(std::move(tensor_with_name));
  return Status::OK;
}

Status SendRecvServerImpl::GetVariable(ServerContext *context,
                                       const VariableMessage *in_var,
                                       VariableMessage *out_var) {
  std::string get_var_name = in_var->varname();
  auto *var = scope_->FindVar(get_var_name);
  auto tensor = var->Get<framework::LoDTensor>();
  std::ostringstream oss;
  // FIXME(typhoonzero): get context from op.
  framework::SerializeToStream(oss, tensor, platform::CPUDeviceContext());

  std::string *varname = out_var->mutable_varname();
  *varname = get_var_name;
  std::string *serialized = out_var->mutable_serialized();
  *serialized = oss.str();
  return Status::OK;
}

Status SendRecvServerImpl::Wait(ServerContext *context,
                                const VoidMessage *in_var,
                                VoidMessage *out_var) {
  {
    std::unique_lock<std::mutex> lock(this->mutex_);
    condition_.wait(lock, [=] { return this->done_ == true; });
  }
  return Status::OK;
}

void SendRecvServerImpl::Reset() {
  std::lock_guard<std::mutex> lock(this->mutex_);
  done_ = false;
}

void SendRecvServerImpl::Done() {
  {
    std::lock_guard<std::mutex> lock(this->mutex_);
    done_ = true;
  }
  condition_.notify_all();
}

} // namespace detail
} // namespace operators
} // namespace paddle
31 changes: 27 additions & 4 deletions paddle/operators/detail/send_impl.cc
@@ -19,10 +19,10 @@ namespace operators {
namespace detail {

bool RPCClient::SendVariable(const framework::Scope& scope,
                             const std::string& inname) {
  ClientContext context;
  VariableMessage msg;
  VoidMessage out_msg;
  // FIXME(typhoonzero): pass device context to here.
  auto ctx = platform::CPUDeviceContext();
  auto* var = scope.FindVar(inname);
@@ -37,9 +37,26 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
  msg.set_serialized(oss.str());
  Status status = stub_->SendVariable(&context, msg, &out_msg);
  if (!status.ok()) {
    LOG(ERROR) << "gRPC error: " << status.error_message();
    return false;
  }
  return true;
}

bool RPCClient::GetVariable(const framework::Scope& scope,
                            const std::string& outname) {
  ClientContext context;
  VariableMessage call_msg, ret_msg;
  call_msg.set_varname(outname);
  auto ctx = platform::CPUDeviceContext();
  Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
  if (!status.ok()) {
    LOG(ERROR) << "gRPC error: " << status.error_message();
    return false;
  }

  std::istringstream iss(ret_msg.serialized());

  framework::LoDTensor ret_tensor;
  framework::DeserializeFromStream(iss, &ret_tensor);
  auto* outvar = scope.FindVar(outname);
@@ -49,6 +66,12 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
  return true;
}

void RPCClient::Wait() {
  ClientContext context;
  VoidMessage call_msg, ret_msg;
  stub_->Wait(&context, call_msg, &ret_msg);
}

@helinwang (Contributor) commented on Dec 20, 2017:

It seems that if the remote Wait has already completed but the call failed due to a network issue, then when this call retries, the remote will block forever until the next notify_all is called.
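One possible direction for the retry issue raised above, sketched here only as an illustration and not as part of this PR: wait on a monotonically increasing generation counter instead of a single done_ flag, so that a retried Wait for an already-finished step returns immediately. This would also require the Wait RPC to carry the step number rather than a VoidMessage; the class and member names below are hypothetical.

#include <condition_variable>
#include <cstdint>
#include <mutex>

// Hypothetical server-side barrier keyed by a generation (step) number.
class GenerationBarrier {
 public:
  // Marks the current step as finished (the Done() counterpart).
  void Done() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++finished_generation_;
    condition_.notify_all();
  }

  // Called by the Wait handler with the step the client is waiting for. A
  // retry for an already-finished step returns immediately instead of
  // blocking until the next notify_all.
  void Wait(uint64_t generation) {
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [&] { return finished_generation_ >= generation; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  uint64_t finished_generation_ = 0;
};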

} // namespace detail
} // namespace operators
} // namespace paddle
7 changes: 6 additions & 1 deletion paddle/operators/detail/send_recv.proto
@@ -19,7 +19,12 @@ package sendrecv;
service SendRecvService {
  // For parameter server round-robin like hashing, do not split tensors.
  // Send and recv only one tensor
  // TODO(typhoonzero): add streaming API
  rpc SendVariable(VariableMessage) returns (VoidMessage) {}
  // Argument VariableMessage for GetVariable should only contain varname.
  rpc GetVariable(VariableMessage) returns (VariableMessage) {}
  // wait for one execution of the program
  rpc Wait(VoidMessage) returns (VoidMessage) {}
}

// VariableMessage is serialized paddle variable message.
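To make the intended call pattern of this service concrete, a hedged client-side sketch of one send/wait/get round trip, using the RPCClient wrapper from send_impl.cc above (not part of this diff). The endpoint address and variable names are placeholders, and both variables are assumed to already exist in the scope.

#include <grpc++/grpc++.h>

#include "paddle/framework/scope.h"
#include "paddle/operators/detail/send_recv_impl.h"

// A trainer pushes one gradient, waits for one pserver step, then pulls the
// updated parameter back. "127.0.0.1:6174", "W@GRAD" and "W" are made up.
void OneRoundTrip(const paddle::framework::Scope &scope) {
  auto channel = grpc::CreateChannel("127.0.0.1:6174",
                                     grpc::InsecureChannelCredentials());
  paddle::operators::detail::RPCClient client(channel);

  client.SendVariable(scope, "W@GRAD");  // SendVariable RPC, returns VoidMessage
  client.Wait();                         // blocks until the server calls Done()
  client.GetVariable(scope, "W");        // GetVariable RPC, writes into `scope`
}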
37 changes: 21 additions & 16 deletions paddle/operators/detail/send_recv_impl.h
@@ -20,10 +20,6 @@
#include "paddle/framework/selected_rows.h"
#include "paddle/operators/detail/simple_block_queue.h"

// #include <grpc++/channel.h>
// #include <grpc++/client_context.h>
// #include <grpc++/create_channel.h>
// #include <grpc++/security/credentials.h>
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"

@@ -48,24 +44,32 @@ namespace paddle {
namespace operators {
namespace detail {

typedef std::pair<std::string, framework::LoDTensor> TensorWithName;

class SendRecvServerImpl final : public SendRecvService::Service {
 public:
  explicit SendRecvServerImpl() {}

  Status SendVariable(ServerContext *context, const VariableMessage *in_var,
                      VoidMessage *out_var) override;
  Status GetVariable(ServerContext *context, const VariableMessage *in_var,
                     VariableMessage *out_var) override;
  Status Wait(ServerContext *context, const VoidMessage *in_var,
              VoidMessage *out_var) override;
  void Reset();
  void Done();
  void SetScope(framework::Scope *scope) { scope_ = scope; };

  const TensorWithName Get() { return this->var_recv_queue_.Pop(); }

 private:
  // received variable from RPC, operators fetch variable from this queue.
  SimpleBlockQueue<TensorWithName> var_recv_queue_;
  framework::Scope *scope_;
  // condition of the sub program
  std::mutex mutex_;
  bool done_;
  std::condition_variable condition_;
};

// RPCClient is a class to send tensors to pserver sub-network
@@ -75,8 +79,9 @@ class RPCClient {
  RPCClient(std::shared_ptr<Channel> channel)
      : stub_(SendRecvService::NewStub(channel)) {}

  bool SendVariable(const framework::Scope &scope, const std::string &inname);
  bool GetVariable(const framework::Scope &scope, const std::string &outname);
  void Wait();

 private:
  std::unique_ptr<SendRecvService::Stub> stub_;
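A rough sketch of how the receiving (parameter-server) side might drive SendRecvServerImpl around the Reset/Done barrier, based only on the declarations above and not on the actual recv operator in this PR; the executor, program, and expected-gradient count are assumed to be provided by the caller, and the tensor copy is simplified.

#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
#include "paddle/operators/detail/send_recv_impl.h"

// One pserver step: block Wait() callers, collect the pushed gradients from
// the queue, run the optimization block, then release the waiting clients.
void OnePServerStep(paddle::operators::detail::SendRecvServerImpl *service,
                    paddle::framework::Executor *executor,
                    const paddle::framework::ProgramDesc &program,
                    paddle::framework::Scope *scope, int num_expected_grads) {
  service->SetScope(scope);
  service->Reset();  // clients calling the Wait RPC will now block

  for (int i = 0; i < num_expected_grads; ++i) {
    // Pops a (name, tensor) pair pushed by the SendVariable handler.
    auto named_tensor = service->Get();
    auto *var = scope->Var(named_tensor.first);
    // Simplified: copy-assign the received LoDTensor into the scope variable.
    *var->GetMutable<paddle::framework::LoDTensor>() = named_tensor.second;
  }

  // Run the optimization sub-program over the received gradients.
  executor->Run(program, scope, /*block_id=*/0, /*create_local_scope=*/false);

  service->Done();  // wakes every client blocked in the Wait RPC
}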