diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 5e51199e3304..f7b29405377c 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -22,9 +22,11 @@ * \file naive_engine.cc * \brief Implementation of NaiveEngine */ -#include #include +#include +#include #include +#include #include "./engine_impl.h" #include "../profiler/profiler.h" #include "./openmp.h" @@ -156,9 +158,10 @@ class NaiveEngine final : public Engine { int priority = 0, const char* opr_name = nullptr, bool wait = false) override { - bool req_completed = false; + std::promise promise; + std::future future = promise.get_future(); CallbackOnComplete callback = CreateCallback( - NaiveEngine::OnComplete, &req_completed); + NaiveEngine::OnComplete, &promise); profiler::Profiler *profiler = profiler::Profiler::Get(); auto opr_deleter = [this](NaiveOpr* p) { this->DeleteOperator(p); @@ -200,12 +203,11 @@ class NaiveEngine final : public Engine { } else { exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr, false}, callback); } + future.wait(); // increment mutable var version for (auto var : mutable_vars) { ++var->version_; } - CHECK(req_completed) - << "NaiveEngine only support synchronize Push so far"; if (profiling) { opr->opr_profile->stop(); } @@ -237,8 +239,7 @@ class NaiveEngine final : public Engine { // callback to oncomplete static void OnComplete(Engine *engine, void *param, const dmlc::Error* error) { - bool *req_completed = static_cast(param); - *req_completed = true; + static_cast*>(param)->set_value(); } /*! \brief whether it is during shutdown phase*/ std::atomic shutdown_phase_{false};