Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Update ThreadedVar #67

Merged
merged 1 commit into from
Sep 14, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/engine/engine_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

namespace mxnet {
namespace engine {

/*! \brief base class of engine variables, used for type checking */
struct Var {
#if ENGINE_DEBUG
Expand Down
9 changes: 2 additions & 7 deletions src/engine/stream_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
auto&& counter = gpu_cnt_.at(ctx.dev_id);
if (counter == -1) {
for (auto&& i : gpu_streams_.at(ctx.dev_id)) {
#if MXNET_USE_CUDNN == 1
i = mshadow::NewStream<gpu>(true, true);
#else
i = mshadow::NewStream<gpu>(true, false);
#endif // MXNET_USE_CUDNN
i = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0);
}
counter = 0;
}
Expand Down Expand Up @@ -88,7 +84,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
{
std::lock_guard<std::mutex> lock{m_};
if (gpu_io_streams_.at(ctx.dev_id) == nullptr) {
gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(true, false);
gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false);
}
}
return {gpu_io_streams_.at(ctx.dev_id)};
Expand Down Expand Up @@ -126,7 +122,6 @@ StreamManager<kNumGpus, kStreams>::~StreamManager() {
}

} // namespace engine

} // namespace mxnet

#endif // MXNET_ENGINE_STREAM_MANAGER_H_
101 changes: 57 additions & 44 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) {
head_->trigger = opr_block;
head_->write = true;
if (ready_to_read_) {
/*!
* Raise `num_pending_reads_` temporarily to avoid premature triggering.
*/
// Raise `num_pending_reads_` temporarily to avoid premature triggering.
++num_pending_reads_;
pending_write_ = head_;
if (--num_pending_reads_ == 0) {
Expand All @@ -65,51 +63,74 @@ void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) {

/*!
 * \brief Signal that one pending read of this variable has finished.
 *
 * Decrements `num_pending_reads_`. When that count reaches zero and a write
 * is pending, the pending write's `trigger->wait` is decremented as well; once
 * it also reaches zero the trigger is handed to `dispatcher`.
 *
 * NOTE(review): this span is a diff rendering without +/- markers, so two
 * variants of the body are superimposed and the braces do not balance. The
 * leading statements (function-wide lock, direct `dispatcher` call) look like
 * the pre-change body; the `bool trigger` variant that follows looks like its
 * replacement — confirm against the repository before relying on either.
 *
 * \param dispatcher callable invoked with `pending_write_->trigger` once that
 *        operation has nothing left to wait for.
 */
template <typename Dispatcher>
void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) {
// Variant A (presumably the removed lines): `m_` is held for the whole
// function, so `dispatcher` runs while the lock is still held.
std::lock_guard<std::mutex> lock{m_};
if (--num_pending_reads_ == 0) {
if (pending_write_ != nullptr && --pending_write_->trigger->wait == 0) {
dispatcher(pending_write_->trigger);
// Variant B (presumably the added lines): only record the decision under the
// lock, and call `dispatcher` after the inner scope has released `m_`.
bool trigger = false;
{
// this is lock scope
std::lock_guard<std::mutex> lock{m_};
// Last outstanding read finished?
if (--num_pending_reads_ == 0) {
// A write is queued behind the reads; drop one of its wait counts and
// remember whether it became runnable.
if (pending_write_ != nullptr && --pending_write_->trigger->wait == 0) {
trigger = true;
}
}
}
// `m_` is no longer held by the inner scope here; `pending_write_` is read
// again without re-acquiring it.
if (trigger) {
dispatcher(pending_write_->trigger);
}
}

/*!
 * \brief Signal that the currently pending write on this variable has finished.
 *
 * Releases the finished write's `VersionedVarBlock`, then walks the blocks
 * queued behind it: every block up to the next write has its `trigger->wait`
 * decremented and is passed to `dispatcher` when that count reaches zero. If
 * the walk reaches a block whose `next` is null, the variable is marked
 * `ready_to_read_`; otherwise the next write block becomes `pending_write_`.
 *
 * NOTE(review): this span is a diff rendering without +/- markers, so an
 * older and a newer variant of the body are superimposed (the braces do not
 * balance and a `return false;` sits inside the dispatch loop). The comments
 * below mark which statements appear to belong to which variant — confirm
 * against the repository.
 *
 * \param dispatcher callable invoked with each `trigger` whose `wait` count
 *        dropped to zero.
 * \return true if `to_delete_` was set (the remaining block is deleted here),
 *         false otherwise.
 */
template <typename Dispatcher>
bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
// Older variant (presumably removed): hold `m_` for the whole function and
// free the finished write block immediately.
std::lock_guard<std::mutex> lock{m_};
assert(ready_to_read_ == false);
auto cur_head = pending_write_->next;
VersionedVarBlock::Delete(pending_write_);
pending_write_ = nullptr;
// Newer variant (presumably added): results of the locked section that are
// consumed after the lock scope below has ended.
VersionedVarBlock *old_pending_write, *end_of_dispatch_chain;
int num_reads;
{
// this is lock scope
std::lock_guard<std::mutex> lock{m_};
assert(ready_to_read_ == false);
// detach pending write
old_pending_write = pending_write_;
pending_write_ = nullptr;
// search for chains to trigger
VersionedVarBlock *p = old_pending_write->next;
assert(num_pending_reads_ == 0);
num_reads = 0;
// Advance over consecutive non-write blocks, counting them as reads; stop at
// the first write block or at the block that has no successor.
while (p->next != nullptr && p->write == false) {
p = p->next;
++num_reads;
}
// Publish the read count before any of those reads is dispatched.
num_pending_reads_ = num_reads;
end_of_dispatch_chain = p;
if (p->next == nullptr) {
// Nothing but reads was queued: the variable becomes readable again.
ready_to_read_ = true;
} else {
// Stopped on a write block; it is the new pending write.
assert(p->write == true);
pending_write_ = p;
}
}
// this is outside of lock scope
// the linked list is detached from variable
VersionedVarBlock *cur_head = old_pending_write->next;
VersionedVarBlock::Delete(old_pending_write);
if (to_delete_) {
// Deletion was requested: nothing may be queued behind the finished write.
assert(cur_head->next == nullptr);
VersionedVarBlock::Delete(cur_head);
return true;
} else {
// Older variant of the dispatch walk (presumably removed).
while (true) {
if (cur_head->write == true) {
// Bump the read count around the assignment, then undo it; the write is
// only triggered when no reads are outstanding.
++num_pending_reads_;
pending_write_ = cur_head;
if (--num_pending_reads_ == 0) {
if (--cur_head->trigger->wait == 0) {
dispatcher(cur_head->trigger);
}
}
break;
} else if (cur_head->next == nullptr) {
// Reached the block with no successor: variable is readable.
ready_to_read_ = true;
break;
} else {
// A read block: count it, dispatch it if runnable, then free the block.
++num_pending_reads_;
if (--cur_head->trigger->wait == 0) {
dispatcher(cur_head->trigger);
}
auto prev = cur_head;
cur_head = cur_head->next;
VersionedVarBlock::Delete(prev);
}
}
// dispatch all the events
// Newer variant of the walk: visits the blocks before `end_of_dispatch_chain`,
// i.e. the reads counted in the lock scope above.
while (cur_head != end_of_dispatch_chain) {
if (--cur_head->trigger->wait == 0) {
dispatcher(cur_head->trigger);
}
// NOTE(review): this `return false;` looks like a leftover line of the older
// variant interleaved by the diff rendering — confirm.
return false;
auto prev = cur_head;
cur_head = cur_head->next;
VersionedVarBlock::Delete(prev);
}
// trigger pending write, if any
// Only done here when no reads were dispatched (`num_reads == 0`).
if (pending_write_ != nullptr && num_reads == 0) {
if (--pending_write_->trigger->wait == 0) {
dispatcher(pending_write_->trigger);
}
}
return false;
}

void ThreadedVar::SetToDelete() {
Expand All @@ -122,14 +143,6 @@ bool ThreadedVar::ready_to_read() {
return ready_to_read_;
}

/*!
 * \brief Convert a base engine variable pointer to a ThreadedVar pointer.
 * \param v base-class pointer; the conversion is delegated to
 *        `v->Cast<ThreadedVar>()`.
 * \return the value produced by that cast.
 */
ThreadedVar* ThreadedVar::CastFromBase(Var* v) {
  ThreadedVar* const threaded_var = v->Cast<ThreadedVar>();
  return threaded_var;
}

/*!
 * \brief Convert a base engine operator pointer to a ThreadedOpr pointer.
 * \param o base-class pointer; the conversion is delegated to
 *        `o->Cast<ThreadedOpr>()`.
 * \return the value produced by that cast.
 */
ThreadedOpr* ThreadedOpr::CastFromBase(Opr* o) {
  ThreadedOpr* const threaded_opr = o->Cast<ThreadedOpr>();
  return threaded_opr;
}

ThreadedEngine::ThreadedEngine()
: pending_{0},
thread_pool_{[this]() { ThreadWorker(&task_queue_); }},
Expand Down
Loading