Commit 8114cb5

dash
chengduoZH committed Apr 16, 2018
1 parent 2455fd3 commit 8114cb5
Showing 2 changed files with 19 additions and 18 deletions.
15 changes: 8 additions & 7 deletions paddle/fluid/framework/details/reduce_and_gather.h
@@ -53,6 +53,7 @@ struct GatherSelectedRows {
   void operator()(
       const std::vector<SelectedRows> &src_selecte_rows_,
       const std::vector<platform::Place> &in_places,
+      const platform::Place &out_place,
       const std::unordered_map<platform::Place, platform::DeviceContext *,
                                platform::PlaceHash> &dev_ctxes,
       SelectedRows *dst_selecte_rows) const {
@@ -63,20 +64,20 @@ struct GatherSelectedRows {
 
     for (auto &in_sr : src_selecte_rows_) {
       in_tensors.emplace_back(in_sr.value());
-      out_rows.insert(out_rows.end(), in_sr.rows.begin(), in_sr.rows.end());
+      out_rows.insert(out_rows.end(), in_sr.rows().begin(), in_sr.rows().end());
     }
 
     auto &pre_in = src_selecte_rows_[0];
 
-    dst_tensor_.set_height(pre_in.height());
-    dst_selecte_rows.set_rows(out_rows);
+    dst_selecte_rows->set_height(pre_in.height());
+    dst_selecte_rows->set_rows(out_rows);
     size_t rows = out_rows.size();
     DDim out_dim = pre_in.GetCompleteDims();
     out_dim[0] = static_cast<int64_t>(rows);
-    dst_selecte_rows.mutable_value()->Resize(out_dim);
-    dst_selecte_rows.mutable_value()->mutable_data(out_place,
-                                                   pre_in.value().type());
-    Tensor *out_tensor = dst_selecte_rows.mutable_value();
+    dst_selecte_rows->mutable_value()->Resize(out_dim);
+    dst_selecte_rows->mutable_value()->mutable_data(out_place,
+                                                    pre_in.value().type());
+    Tensor *out_tensor = dst_selecte_rows->mutable_value();
 
     // copy
     int s = 0, e = 0;
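The hunk above fixes GatherSelectedRows so it compiles and does what its signature promises: the destination place is now passed in explicitly as out_place, rows() is called as the accessor it is, and dst_selecte_rows is dereferenced with -> since the parameter is a pointer (the stray dst_tensor_ is gone). A minimal sketch of the same gather pattern, using a hypothetical SimpleSelectedRows stand-in rather than the real framework::SelectedRows and Tensor types:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for framework::SelectedRows: a parameter of `height`
// rows, of which only the indices in `rows` are materialized in `value`.
struct SimpleSelectedRows {
  int64_t height = 0;
  std::vector<int64_t> rows;   // which rows of the parameter are present
  std::vector<float> value;    // flattened [rows.size() x row_width] data
};

// Concatenate every input's row indices and values into *dst, mirroring the
// pointer-based destination used by GatherSelectedRows.
void GatherRows(const std::vector<SimpleSelectedRows> &srcs,
                SimpleSelectedRows *dst) {
  assert(!srcs.empty());
  dst->height = srcs.front().height;  // all inputs describe the same parameter
  dst->rows.clear();
  dst->value.clear();
  for (const auto &src : srcs) {
    // Matches out_rows.insert(..., in_sr.rows().begin(), in_sr.rows().end()).
    dst->rows.insert(dst->rows.end(), src.rows.begin(), src.rows.end());
    // The real helper resizes the destination tensor to the total row count
    // and copies each input slice; a plain append does the same job here.
    dst->value.insert(dst->value.end(), src.value.begin(), src.value.end());
  }
}

As in the real helper, the destination keeps the first input's height and the plain concatenation of all inputs' rows; duplicate row indices from different inputs are simply kept.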
22 changes: 11 additions & 11 deletions paddle/fluid/framework/details/reduce_op_handle.cc
@@ -61,7 +61,7 @@ void ReduceOpHandle::RunImpl() {
                     "The number of output should be one.");
 
   // Wait input done, this Wait is asynchronous operation
-  auto &in_place = in_var_handles[0]->place_;
+  auto in_place = in_var_handles[0]->place_;
   if (in_var_handles[0]->generated_op_) {
     for (auto *in : in_var_handles) {
       auto &in_p = in->place_;
@@ -103,7 +103,8 @@ void ReduceOpHandle::RunImpl() {
       in_selected_rows.emplace_back(in_sr);
     }
     auto trg = out_var->GetMutable<framework::SelectedRows>();
-    gather(in_selected_rows, in_places, dev_ctxes_, trg);
+    gather(in_selected_rows, in_places, out_var_handles[0]->place_, dev_ctxes_,
+           trg);
   } else {
     // reduce tensor
     auto pre_in = pre_in_var->Get<framework::LoDTensor>();
@@ -139,22 +140,21 @@ void ReduceOpHandle::RunImpl() {
       auto &p = in_places[i];
       auto &lod_tensor = lod_tensors[i];
 
-      void *buffer = const_cast<void *>(lod_tensor.data<void>());
-      int gpu_id = static_cast<>
-
       if (dtype == -1) {
         dtype = platform::ToNCCLDataType(lod_tensor.type());
       }
-
-      T *recvbuffer = nullptr;
-      if (root == gpu_id) {
-        recvbuffer = trg->mutable_data(out_var_handles[0]->place_);
-      }
+      void *buffer = const_cast<void *>(lod_tensor.data<void>());
+
       int dev_id = boost::get<platform::CUDAPlace>(p).device;
       auto &nccl_ctx = nccl_ctxs_.at(dev_id);
       auto stream = nccl_ctx.stream();
       auto comm = nccl_ctx.comm_;
-
+      void *recvbuffer = nullptr;
+      if (root == dev_id) {
+        recvbuffer = trg->mutable_data(out_var_handles[0]->place_);
+      }
+
+      // error: get the sizeof of var.type()
       all_reduce_calls.emplace_back([=] {
         PADDLE_ENFORCE(platform::dynload::ncclReduce(
             buffer, recvbuffer, static_cast<size_t>(lod_tensor.numel()),
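The last hunk reorders the per-device setup so that dev_id (and the matching nccl_ctx) is known before deciding who owns the receive buffer: only the root device gets a non-null recvbuffer, every other device passes nullptr to ncclReduce, and each per-device call is packaged into a closure in all_reduce_calls so the whole set can be issued together. A compile-only sketch of that collect-then-launch pattern; DeviceCtx, do_reduce, and ReduceAcrossDevices are hypothetical placeholders standing in for Paddle's NCCL context plumbing, not real Paddle or NCCL APIs:

#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical per-device context; in the handle this role is played by the
// NCCL context looked up through nccl_ctxs_.at(dev_id).
struct DeviceCtx {
  int dev_id = 0;
};

// Placeholder for the actual reduce launch on one device's stream.
void do_reduce(void * /*sendbuffer*/, void * /*recvbuffer*/,
               std::size_t /*count*/, int /*root*/, const DeviceCtx & /*ctx*/) {
  // A real implementation would invoke ncclReduce here.
}

void ReduceAcrossDevices(std::vector<DeviceCtx> &ctxs,
                         const std::vector<void *> &send_buffers,
                         std::size_t count, int root, void *root_buffer) {
  std::vector<std::function<void()>> all_reduce_calls;
  for (std::size_t i = 0; i < ctxs.size(); ++i) {
    DeviceCtx &ctx = ctxs[i];
    void *buffer = send_buffers[i];
    // Only the root device receives the reduced result; the others pass
    // nullptr, exactly like recvbuffer in the hunk above.
    void *recvbuffer = (ctx.dev_id == root) ? root_buffer : nullptr;
    // Capture what the call needs so it can be issued later with its peers.
    all_reduce_calls.emplace_back([buffer, recvbuffer, count, root, &ctx] {
      do_reduce(buffer, recvbuffer, count, root, ctx);
    });
  }
  // Issue every per-device call only after all of them have been collected,
  // matching how the handle above fills all_reduce_calls before launching.
  for (auto &call : all_reduce_calls) {
    call();
  }
}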
