Skip to content

Commit

Permalink
Merge pull request #10143 from typhoonzero/fix_multiGPU_dist_train
Browse files Browse the repository at this point in the history
Fix multi gpu dist train
  • Loading branch information
reyoung authored Apr 24, 2018
2 parents e5f2cb8 + f034152 commit 2c8fe4e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
45 changes: 42 additions & 3 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
}
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
OpDesc *send_op) const {
if (send_op == nullptr) {
return false;
}

auto checker = [&](const std::vector<std::string> opvars,
const std::vector<std::string> sendvars) -> bool {
bool is_dist_train_op = false;
for (auto &var : opvars) {
if (var.find(".block") != std::string::npos &&
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
is_dist_train_op = true;
break;
}
}
return is_dist_train_op;
};

if (op.Type() == "split") {
return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
} else if (op.Type() == "concat") {
return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
}
return false;
}

std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
auto graph = new SSAGraph();
Expand All @@ -89,19 +116,30 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());

// Find "send" op first for split is in front of send.
OpDesc *send_op = nullptr;
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
send_op = op;
break;
}
}

bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
// append send op if program is distributed trainer main program.
// always use the first device
CreateSendOp(&result, *op);
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) {
if (!skip_scale_loss_) {
CreateScaleLossGradOp(&result);
}
is_forwarding = false;
} else {
CreateComputationalOps(&result, *op);
CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. But there are no
Expand Down Expand Up @@ -199,8 +237,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
}

void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
const OpDesc &op) const {
for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) {
const OpDesc &op,
size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/details/multi_devices_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

void CreateSendOp(SSAGraph *result, const OpDesc &op) const;

void CreateComputationalOps(SSAGraph *result, const OpDesc &op) const;
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;

void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
size_t num_places) const;

void CreateScaleLossGradOp(SSAGraph *result) const;

Expand Down

0 comments on commit 2c8fe4e

Please sign in to comment.