From 8b5687467a7d2ff7081029a8fd6ef246fd4dfc47 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Sat, 12 Sep 2020 18:58:31 -0700 Subject: [PATCH] Fix race condition in NaiveEngine::PushAsync (#19108) (#19122) * Wait for async_fun to complete in NaiveEngine::PushAsync This fixes a race condition in which NaiveEngine::PushAsync was checking if the the async_fun had completed by the end of NaiveEngine::PushAsync. If async_fun hadn't completed yet, NaiveEngine::PushAsync would set an internal error string and deallocate the callback, causing segfault in async_fun once it would attempt calling the callback. * Update naive_engine.cc --- src/engine/naive_engine.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index e76003a8dca9..d98f77d03bc8 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -22,9 +22,11 @@ * \file naive_engine.cc * \brief Implementation of NaiveEngine */ -#include #include +#include +#include #include +#include #include "./engine_impl.h" #include "../profiler/profiler.h" #include "./openmp.h" @@ -156,9 +158,10 @@ class NaiveEngine final : public Engine { int priority = 0, const char* opr_name = nullptr, bool wait = false) override { - bool req_completed = false; + std::promise promise; + std::future future = promise.get_future(); CallbackOnComplete callback = CreateCallback( - NaiveEngine::OnComplete, &req_completed); + NaiveEngine::OnComplete, &promise); profiler::Profiler *profiler = profiler::Profiler::Get(); auto opr_deleter = [this](NaiveOpr* p) { this->DeleteOperator(p); @@ -199,12 +202,11 @@ class NaiveEngine final : public Engine { } else { exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr, false}, callback); } + future.wait(); // increment mutable var version for (auto var : mutable_vars) { ++var->version_; } - CHECK(req_completed) - << "NaiveEngine only support synchronize Push so far"; if (profiling) { opr->opr_profile->stop(); } @@ -236,8 +238,7 @@ class NaiveEngine final : public Engine { // callback to oncomplete static void OnComplete(Engine *engine, void *param, const dmlc::Error* error) { - bool *req_completed = static_cast(param); - *req_completed = true; + static_cast*>(param)->set_value(); } /*! \brief whether it is during shutdown phase*/ std::atomic shutdown_phase_{false};