Skip to content

Commit

Permalink
Merge pull request #45143 from makortel/edmAsync_140x
Browse files Browse the repository at this point in the history
[14_0_X] Introduce edm::Async service, and use it in CUDA and Alpaka modules
  • Loading branch information
cmsbuild authored Jun 24, 2024
2 parents 61dacc5 + 189d1f8 commit e551de3
Show file tree
Hide file tree
Showing 24 changed files with 1,052 additions and 41 deletions.
2 changes: 1 addition & 1 deletion FWCore/Concurrency/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<use name="FWCore/Utilities" source_only="1"/>
<use name="FWCore/Utilities"/>
<use name="tbb"/>
<export>
<lib name="1"/>
Expand Down
34 changes: 34 additions & 0 deletions FWCore/Concurrency/interface/Async.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef FWCore_Concurrency_Async_h
#define FWCore_Concurrency_Async_h

#include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"
#include "FWCore/Concurrency/interface/WaitingThreadPool.h"

namespace edm {
// All member functions are thread safe
class Async {
public:
Async() = default;
virtual ~Async() noexcept;

// prevent copying and moving
Async(Async const&) = delete;
Async(Async&&) = delete;
Async& operator=(Async const&) = delete;
Async& operator=(Async&&) = delete;

template <typename F, typename G>
void runAsync(WaitingTaskWithArenaHolder holder, F&& func, G&& errorContextFunc) {
ensureAllowed();
pool_.runAsync(std::move(holder), std::forward<F>(func), std::forward<G>(errorContextFunc));
}

protected:
virtual void ensureAllowed() const = 0;

private:
WaitingThreadPool pool_;
};
} // namespace edm

#endif
106 changes: 106 additions & 0 deletions FWCore/Concurrency/interface/WaitingThreadPool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#ifndef FWCore_Concurrency_WaitingThreadPool_h
#define FWCore_Concurrency_WaitingThreadPool_h

#include "FWCore/Utilities/interface/ConvertException.h"
#include "FWCore/Utilities/interface/ReusableObjectHolder.h"
#include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"

#include <condition_variable>
#include <mutex>
#include <thread>

namespace edm {
namespace impl {
class WaitingThread {
public:
WaitingThread();
~WaitingThread() noexcept;

WaitingThread(WaitingThread const&) = delete;
WaitingThread& operator=(WaitingThread&&) = delete;
WaitingThread(WaitingThread&&) = delete;
WaitingThread& operator=(WaitingThread const&) = delete;

template <typename F, typename G>
void run(WaitingTaskWithArenaHolder holder,
F&& func,
G&& errorContextFunc,
std::shared_ptr<WaitingThread> thisPtr) {
std::unique_lock lk(mutex_);
func_ = [holder = std::move(holder),
func = std::forward<F>(func),
errorContext = std::forward<G>(errorContextFunc)]() mutable {
try {
convertException::wrap([&func]() { func(); });
} catch (cms::Exception& e) {
e.addContext(errorContext());
holder.doneWaiting(std::current_exception());
}
};
thisPtr_ = std::move(thisPtr);
cond_.notify_one();
}

private:
void stopThread() {
std::unique_lock lk(mutex_);
stopThread_ = true;
cond_.notify_one();
}

void threadLoop() noexcept;

std::thread thread_;
std::mutex mutex_;
std::condition_variable cond_;
CMS_THREAD_GUARD(mutex_) std::function<void()> func_;
// The purpose of thisPtr_ is to keep the WaitingThread object
// outside of the WaitingThreadPool until the func_ has returned.
CMS_THREAD_GUARD(mutex_) std::shared_ptr<WaitingThread> thisPtr_;
CMS_THREAD_GUARD(mutex_) bool stopThread_ = false;
};
} // namespace impl

// Provides a mechanism to run the function 'func' asynchronously,
// i.e. without the calling thread to wait for the func() to return.
// The func should do as little work (outside of the TBB threadpool)
// as possible. The func must terminate eventually. The intended use
// case are blocking synchronization calls with external entities,
// where the calling thread is suspended while waiting.
//
// The func() is run in a thread that belongs to a separate pool of
// threads than the calling thread. Remotely similar to
// std::async(), but instead of dealing with std::futures, takes an
// edm::WaitingTaskWithArenaHolder object, that is signaled upon the
// func() returning or throwing an exception.
//
// The caller is responsible for keeping the WaitingThreadPool
// object alive at least as long as all asynchronous calls finish.
class WaitingThreadPool {
public:
WaitingThreadPool() = default;
WaitingThreadPool(WaitingThreadPool const&) = delete;
WaitingThreadPool& operator=(WaitingThreadPool const&) = delete;
WaitingThreadPool(WaitingThreadPool&&) = delete;
WaitingThreadPool& operator=(WaitingThreadPool&&) = delete;

/**
* \param holder WaitingTaskWithArenaHolder object to signal the completion of 'func'
* \param func Function to run in a separate thread
* \param errorContextFunc Function returning a string-like object
* that is added to the context of
* cms::Exception in case 'func' throws an
* exception
*/
template <typename F, typename G>
void runAsync(WaitingTaskWithArenaHolder holder, F&& func, G&& errorContextFunc) {
auto thread = pool_.makeOrGet([]() { return std::make_unique<impl::WaitingThread>(); });
thread->run(std::move(holder), std::forward<F>(func), std::forward<G>(errorContextFunc), std::move(thread));
}

private:
edm::ReusableObjectHolder<impl::WaitingThread> pool_;
};
} // namespace edm

#endif
5 changes: 5 additions & 0 deletions FWCore/Concurrency/src/Async.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "FWCore/Concurrency/interface/Async.h"

namespace edm {
Async::~Async() noexcept = default;
}
58 changes: 58 additions & 0 deletions FWCore/Concurrency/src/WaitingThreadPool.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "FWCore/Concurrency/interface/WaitingThreadPool.h"

#include <cassert>
#include <string_view>

#include <pthread.h>

namespace edm::impl {
WaitingThread::WaitingThread() {
thread_ = std::thread(&WaitingThread::threadLoop, this);
static constexpr auto poolName = "edm async pool";
// pthread_setname_np() string length is limited to 16 characters,
// including the null termination
static_assert(std::string_view(poolName).size() < 16);

int err = pthread_setname_np(thread_.native_handle(), poolName);
// According to the glibc documentation, the only error
// pthread_setname_np() can return is about the argument C-string
// being too long. We already check above the C-string is shorter
// than the limit was at the time of writing. In order to capture
// if the limit shortens, or other error conditions get added,
// let's assert() anyway (exception feels overkill)
assert(err == 0);
}

WaitingThread::~WaitingThread() noexcept {
// When we are shutting down, we don't care about any possible
// system errors anymore
CMS_SA_ALLOW try {
stopThread();
thread_.join();
} catch (...) {
}
}

void WaitingThread::threadLoop() noexcept {
std::unique_lock lk(mutex_);

while (true) {
cond_.wait(lk, [this]() { return static_cast<bool>(func_) or stopThread_; });
if (stopThread_) {
// There should be no way to stop the thread when it as the
// func_ assigned, but let's make sure
assert(not thisPtr_);
break;
}
func_();
// Must return this WaitingThread to the ReusableObjectHolder in
// the WaitingThreadPool before resettting func_ (that holds the
// WaitingTaskWithArenaHolder, that enables the progress in the
// TBB thread pool) in order to meet the requirement of
// ReusableObjectHolder destructor that there are no outstanding
// objects.
thisPtr_.reset();
decltype(func_)().swap(func_);
}
}
} // namespace edm::impl
101 changes: 101 additions & 0 deletions FWCore/Concurrency/test/test_catch2_Async.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#include "catch.hpp"

#include <atomic>

#include "oneapi/tbb/global_control.h"

#include "FWCore/Concurrency/interface/chain_first.h"
#include "FWCore/Concurrency/interface/FinalWaitingTask.h"
#include "FWCore/Concurrency/interface/Async.h"

namespace {
constexpr char const* errorContext() { return "AsyncServiceTest"; }

class AsyncServiceTest : public edm::Async {
public:
enum class State { kAllowed, kDisallowed, kShutdown };

AsyncServiceTest() = default;

void setAllowed(bool allowed) noexcept { allowed_ = allowed; }

private:
void ensureAllowed() const final {
if (not allowed_) {
throw std::runtime_error("Calling run in this context is not allowed");
}
}

std::atomic<bool> allowed_ = true;
};
} // namespace

TEST_CASE("Test Async", "[edm::Async") {
// Using parallelism 2 here because otherwise the
// tbb::task_arena::enqueue() in WaitingTaskWithArenaHolder will
// start a new TBB thread that "inherits" the name from the
// WaitingThreadPool thread.
oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2);

SECTION("Normal operation") {
AsyncServiceTest service;
std::atomic<int> count{0};

oneapi::tbb::task_group group;
edm::FinalWaitingTask waitTask{group};

{
using namespace edm::waiting_task::chain;
auto h1 = first([&service, &count](edm::WaitingTaskHolder h) {
edm::WaitingTaskWithArenaHolder h2(std::move(h));
service.runAsync(
h2, [&count]() { ++count; }, errorContext);
}) |
lastTask(edm::WaitingTaskHolder(group, &waitTask));

auto h2 = first([&service, &count](edm::WaitingTaskHolder h) {
edm::WaitingTaskWithArenaHolder h2(std::move(h));
service.runAsync(
h2, [&count]() { ++count; }, errorContext);
}) |
lastTask(edm::WaitingTaskHolder(group, &waitTask));
h2.doneWaiting(std::exception_ptr());
h1.doneWaiting(std::exception_ptr());
}
waitTask.waitNoThrow();
REQUIRE(count.load() == 2);
REQUIRE(waitTask.done());
REQUIRE(not waitTask.exceptionPtr());
}

SECTION("Disallowed") {
AsyncServiceTest service;
std::atomic<int> count{0};

oneapi::tbb::task_group group;
edm::FinalWaitingTask waitTask{group};

{
using namespace edm::waiting_task::chain;
auto h = first([&service, &count](edm::WaitingTaskHolder h) {
edm::WaitingTaskWithArenaHolder h2(std::move(h));
service.runAsync(
h2, [&count]() { ++count; }, errorContext);
service.setAllowed(false);
}) |
then([&service, &count](edm::WaitingTaskHolder h) {
edm::WaitingTaskWithArenaHolder h2(std::move(h));
service.runAsync(
h2, [&count]() { ++count; }, errorContext);
}) |
lastTask(edm::WaitingTaskHolder(group, &waitTask));
h.doneWaiting(std::exception_ptr());
}
waitTask.waitNoThrow();
REQUIRE(count.load() == 1);
REQUIRE(waitTask.done());
REQUIRE(waitTask.exceptionPtr());
REQUIRE_THROWS_WITH(std::rethrow_exception(waitTask.exceptionPtr()),
Catch::Contains("Calling run in this context is not allowed"));
}
}
Loading

0 comments on commit e551de3

Please sign in to comment.