Skip to content

Commit

Permalink
check the generate_op is null or not and add DEPS of broadcast_op_han…
Browse files Browse the repository at this point in the history
…dle and gather_op_handle
  • Loading branch information
chengduoZH committed Apr 18, 2018
1 parent d24ef93 commit 4760ac4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
5 changes: 2 additions & 3 deletions paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)

cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)

cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)

cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle)
Expand Down
11 changes: 7 additions & 4 deletions paddle/fluid/framework/details/broadcast_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ void BroadcastOpHandle::RunImpl() {
"Places must be all on CPU or all on CUDA.");

VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
in_tensor.type());
VariableVisitor::GetMutableTensor(out_var)
.Resize(in_tensor.dims())
.mutable_data(out_p, in_tensor.type());

auto dev_ctx = dev_ctxes_[out_p];
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
Expand All @@ -74,8 +75,10 @@ void BroadcastOpHandle::RunImpl() {
}

void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
for (auto &pair : dev_ctxes_) {
in_var.generated_op_->Wait(pair.second);
if (in_var.generated_op_) {
for (auto &pair : dev_ctxes_) {
in_var.generated_op_->Wait(pair.second);
}
}
}

Expand Down

0 comments on commit 4760ac4

Please sign in to comment.