Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[14_0_X] Introduce edm::Async service, and use it in CUDA and Alpaka modules #45143

Merged
merged 8 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion FWCore/Concurrency/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<use name="FWCore/Utilities" source_only="1"/>
<use name="FWCore/Utilities"/>
<use name="tbb"/>
<export>
<lib name="1"/>
Expand Down
34 changes: 34 additions & 0 deletions FWCore/Concurrency/interface/Async.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef FWCore_Concurrency_Async_h
#define FWCore_Concurrency_Async_h

#include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"
#include "FWCore/Concurrency/interface/WaitingThreadPool.h"

namespace edm {
// All member functions are thread safe
// Abstract base of the asynchronous-call service. It owns a
// WaitingThreadPool and hands work to it, but only after the derived
// class has approved the call through ensureAllowed().
class Async {
public:
  Async() = default;
  virtual ~Async() noexcept;

  // Neither copyable nor movable
  Async(Async const&) = delete;
  Async& operator=(Async const&) = delete;
  Async(Async&&) = delete;
  Async& operator=(Async&&) = delete;

  /**
   * Runs 'func' on a thread of the internal pool.
   *
   * ensureAllowed() is called first; if it throws, the exception
   * propagates to the caller and nothing is handed to the pool.
   *
   * \param holder            Passed on to WaitingThreadPool::runAsync()
   * \param func              Callable passed on to the pool
   * \param errorContextFunc  Callable passed on to the pool
   */
  template <typename F, typename G>
  void runAsync(WaitingTaskWithArenaHolder holder, F&& func, G&& errorContextFunc) {
    this->ensureAllowed();
    pool_.runAsync(std::move(holder), std::forward<F>(func), std::forward<G>(errorContextFunc));
  }

protected:
  // Customisation point invoked at the beginning of every runAsync() call
  virtual void ensureAllowed() const = 0;

private:
  WaitingThreadPool pool_;
};
} // namespace edm

#endif
106 changes: 106 additions & 0 deletions FWCore/Concurrency/interface/WaitingThreadPool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#ifndef FWCore_Concurrency_WaitingThreadPool_h
#define FWCore_Concurrency_WaitingThreadPool_h

#include "FWCore/Utilities/interface/ConvertException.h"
#include "FWCore/Utilities/interface/ReusableObjectHolder.h"
#include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"

#include <condition_variable>
#include <mutex>
#include <thread>

namespace edm {
namespace impl {
class WaitingThread {
public:
WaitingThread();
~WaitingThread() noexcept;

WaitingThread(WaitingThread const&) = delete;
WaitingThread& operator=(WaitingThread&&) = delete;
WaitingThread(WaitingThread&&) = delete;
WaitingThread& operator=(WaitingThread const&) = delete;

template <typename F, typename G>
void run(WaitingTaskWithArenaHolder holder,
F&& func,
G&& errorContextFunc,
std::shared_ptr<WaitingThread> thisPtr) {
std::unique_lock lk(mutex_);
func_ = [holder = std::move(holder),
func = std::forward<F>(func),
errorContext = std::forward<G>(errorContextFunc)]() mutable {
try {
convertException::wrap([&func]() { func(); });
} catch (cms::Exception& e) {
e.addContext(errorContext());
holder.doneWaiting(std::current_exception());
}
};
thisPtr_ = std::move(thisPtr);
cond_.notify_one();
}

private:
void stopThread() {
std::unique_lock lk(mutex_);
stopThread_ = true;
cond_.notify_one();
}

void threadLoop() noexcept;

std::thread thread_;
std::mutex mutex_;
std::condition_variable cond_;
CMS_THREAD_GUARD(mutex_) std::function<void()> func_;
// The purpose of thisPtr_ is to keep the WaitingThread object
// outside of the WaitingThreadPool until the func_ has returned.
CMS_THREAD_GUARD(mutex_) std::shared_ptr<WaitingThread> thisPtr_;
CMS_THREAD_GUARD(mutex_) bool stopThread_ = false;
};
} // namespace impl

// Provides a mechanism to run the function 'func' asynchronously,
// i.e. without the calling thread waiting for func() to return.
// The func should do as little work (outside of the TBB thread pool)
// as possible. The func must terminate eventually. The intended use
// case is blocking synchronization calls with external entities,
// where the calling thread is suspended while waiting.
//
// The func() is run in a thread that belongs to a different pool of
// threads than the calling thread. Remotely similar to
// std::async(), but instead of dealing with std::futures, takes an
// edm::WaitingTaskWithArenaHolder object that is signaled upon
// func() returning or throwing an exception.
//
// The caller is responsible for keeping the WaitingThreadPool
// object alive at least until all asynchronous calls have finished.
class WaitingThreadPool {
public:
WaitingThreadPool() = default;
WaitingThreadPool(WaitingThreadPool const&) = delete;
WaitingThreadPool& operator=(WaitingThreadPool const&) = delete;
WaitingThreadPool(WaitingThreadPool&&) = delete;
WaitingThreadPool& operator=(WaitingThreadPool&&) = delete;

/**
* \param holder WaitingTaskWithArenaHolder object to signal the completion of 'func'
* \param func Function to run in a separate thread
* \param errorContextFunc Function returning a string-like object
* that is added to the context of
* cms::Exception in case 'func' throws an
* exception
*/
template <typename F, typename G>
void runAsync(WaitingTaskWithArenaHolder holder, F&& func, G&& errorContextFunc) {
auto thread = pool_.makeOrGet([]() { return std::make_unique<impl::WaitingThread>(); });
thread->run(std::move(holder), std::forward<F>(func), std::forward<G>(errorContextFunc), std::move(thread));
}

private:
edm::ReusableObjectHolder<impl::WaitingThread> pool_;
};
} // namespace edm

#endif
5 changes: 5 additions & 0 deletions FWCore/Concurrency/src/Async.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "FWCore/Concurrency/interface/Async.h"

namespace edm {
// Out-of-line definition of the virtual destructor declared in
// Async.h; it is defaulted, so it adds no logic of its own.
Async::~Async() noexcept = default;
}  // namespace edm
58 changes: 58 additions & 0 deletions FWCore/Concurrency/src/WaitingThreadPool.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "FWCore/Concurrency/interface/WaitingThreadPool.h"

#include <cassert>
#include <string_view>

#include <pthread.h>

namespace edm::impl {
// Starts the worker thread, which runs threadLoop() on this object, and
// gives that thread the name "edm async pool".
WaitingThread::WaitingThread() {
  thread_ = std::thread(&WaitingThread::threadLoop, this);
  static constexpr auto poolName = "edm async pool";
  // pthread_setname_np() string length is limited to 16 characters,
  // including the null termination
  static_assert(std::string_view(poolName).size() < 16);

  // 'err' is read only by the assert() below. [[maybe_unused]] keeps
  // builds with NDEBUG (where assert() expands to nothing) free of an
  // unused-variable diagnostic.
  [[maybe_unused]] int const err = pthread_setname_np(thread_.native_handle(), poolName);
  // According to the glibc documentation, the only error
  // pthread_setname_np() can return is about the argument C-string
  // being too long. We already check above the C-string is shorter
  // than the limit was at the time of writing. In order to capture
  // if the limit shortens, or other error conditions get added,
  // let's assert() anyway (exception feels overkill)
  assert(err == 0);
}

// Asks the worker thread to leave threadLoop() and waits for it to exit.
WaitingThread::~WaitingThread() noexcept {
// When we are shutting down, we don't care about any possible
// system errors anymore
CMS_SA_ALLOW try {
// Order matters: raise the stop flag first, then block until the
// thread has actually finished.
stopThread();
thread_.join();
} catch (...) {
// Deliberately ignored, see the comment at the top of the function
}
}

// Body of the worker thread started by the constructor. Repeatedly
// waits for a callable to be stored in func_, runs it, and then makes
// this object available again; leaves the loop once stopThread_ is set.
// The mutex is held the whole time, except while blocked in cond_.wait().
void WaitingThread::threadLoop() noexcept {
std::unique_lock lk(mutex_);

while (true) {
// Block until there is either work to run or a request to stop
cond_.wait(lk, [this]() { return static_cast<bool>(func_) or stopThread_; });
if (stopThread_) {
// There should be no way to stop the thread when it has the
// func_ assigned, but let's make sure
assert(not thisPtr_);
break;
}
func_();
// Must return this WaitingThread to the ReusableObjectHolder in
// the WaitingThreadPool before resetting func_ (that holds the
// WaitingTaskWithArenaHolder, that enables the progress in the
// TBB thread pool) in order to meet the requirement of
// ReusableObjectHolder destructor that there are no outstanding
// objects.
thisPtr_.reset();
// Swap with an empty temporary: func_ becomes empty and the callable
// it held is destroyed together with the temporary.
decltype(func_)().swap(func_);
}
}
} // namespace edm::impl
101 changes: 101 additions & 0 deletions FWCore/Concurrency/test/test_catch2_Async.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#include "catch.hpp"

#include <atomic>

#include "oneapi/tbb/global_control.h"

#include "FWCore/Concurrency/interface/chain_first.h"
#include "FWCore/Concurrency/interface/FinalWaitingTask.h"
#include "FWCore/Concurrency/interface/Async.h"

namespace {
// Supplies the error-context string passed to runAsync() in these tests.
constexpr auto errorContext() -> char const* {
  return "AsyncServiceTest";
}

// Minimal concrete edm::Async for the tests: whether a call is
// accepted is controlled by a flag that the test can flip at any time.
class AsyncServiceTest : public edm::Async {
public:
  enum class State { kAllowed, kDisallowed, kShutdown };

  AsyncServiceTest() = default;

  // Enables or disables the acceptance of subsequent runAsync() calls
  void setAllowed(bool allowed) noexcept { allowed_.store(allowed); }

private:
  // Throws std::runtime_error unless calls are currently allowed
  void ensureAllowed() const final {
    if (allowed_.load()) {
      return;
    }
    throw std::runtime_error("Calling run in this context is not allowed");
  }

  std::atomic<bool> allowed_{true};
};
} // namespace

// Exercises edm::Async via AsyncServiceTest: once with all calls
// allowed, and once with the service switched to "disallowed" between
// two chained calls.
// The tag must be a complete "[...]" token for tag-based test selection.
TEST_CASE("Test Async", "[edm::Async]") {
  // Using parallelism 2 here because otherwise the
  // tbb::task_arena::enqueue() in WaitingTaskWithArenaHolder will
  // start a new TBB thread that "inherits" the name from the
  // WaitingThreadPool thread.
  oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2);

  // Two independent chains each schedule one asynchronous increment;
  // both must have run by the time the final task completes.
  SECTION("Normal operation") {
    AsyncServiceTest service;
    std::atomic<int> count{0};

    oneapi::tbb::task_group group;
    edm::FinalWaitingTask waitTask{group};

    {
      using namespace edm::waiting_task::chain;
      auto h1 = first([&service, &count](edm::WaitingTaskHolder h) {
                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
                  service.runAsync(
                      h2, [&count]() { ++count; }, errorContext);
                }) |
                lastTask(edm::WaitingTaskHolder(group, &waitTask));

      auto h2 = first([&service, &count](edm::WaitingTaskHolder h) {
                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
                  service.runAsync(
                      h2, [&count]() { ++count; }, errorContext);
                }) |
                lastTask(edm::WaitingTaskHolder(group, &waitTask));
      h2.doneWaiting(std::exception_ptr());
      h1.doneWaiting(std::exception_ptr());
    }
    waitTask.waitNoThrow();
    REQUIRE(count.load() == 2);
    REQUIRE(waitTask.done());
    REQUIRE(not waitTask.exceptionPtr());
  }

  // The first step schedules one increment and then disallows further
  // calls, so the runAsync() of the second step is expected to be
  // rejected: only one increment happens and the final task carries the
  // exception thrown by ensureAllowed().
  SECTION("Disallowed") {
    AsyncServiceTest service;
    std::atomic<int> count{0};

    oneapi::tbb::task_group group;
    edm::FinalWaitingTask waitTask{group};

    {
      using namespace edm::waiting_task::chain;
      auto h = first([&service, &count](edm::WaitingTaskHolder h) {
                 edm::WaitingTaskWithArenaHolder h2(std::move(h));
                 service.runAsync(
                     h2, [&count]() { ++count; }, errorContext);
                 service.setAllowed(false);
               }) |
               then([&service, &count](edm::WaitingTaskHolder h) {
                 edm::WaitingTaskWithArenaHolder h2(std::move(h));
                 service.runAsync(
                     h2, [&count]() { ++count; }, errorContext);
               }) |
               lastTask(edm::WaitingTaskHolder(group, &waitTask));
      h.doneWaiting(std::exception_ptr());
    }
    waitTask.waitNoThrow();
    REQUIRE(count.load() == 1);
    REQUIRE(waitTask.done());
    REQUIRE(waitTask.exceptionPtr());
    REQUIRE_THROWS_WITH(std::rethrow_exception(waitTask.exceptionPtr()),
                        Catch::Contains("Calling run in this context is not allowed"));
  }
}
Loading