Skip to content

Commit

Permalink
Replace TTG_HAS_COROUTINE with TTG_HAVE_COROUTINE and add to config.h
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <joseph.schuchart@stonybrook.edu>
  • Loading branch information
devreal committed May 30, 2024
1 parent 1421c4f commit df32b1e
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 38 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ if (${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
endif(SKIP_COROUTINE_DETECTION)
endif(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")

if (SKIP_COROUTINE_DETECTION)
if (NOT SKIP_COROUTINE_DETECTION)
find_package(CXXStdCoroutine MODULE REQUIRED COMPONENTS Final Experimental)
set(TTG_HAVE_COROUTINE CXXStdCoroutine_FOUND CACHE BOOL "True if the compiler has coroutine support")
endif(SKIP_COROUTINE_DETECTION)


Expand Down
1 change: 0 additions & 1 deletion ttg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ if (TTG_ENABLE_TRACE)
endif (TTG_ENABLE_TRACE)
if (TARGET std::coroutine)
list(APPEND ttg-deps std::coroutine)
list(APPEND ttg-defs "TTG_HAS_COROUTINE=1")
list(APPEND ttg-util-headers
${CMAKE_CURRENT_SOURCE_DIR}/ttg/coroutine.h
)
Expand Down
3 changes: 3 additions & 0 deletions ttg/ttg/config.in.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
/** the C++ namespace containing the coroutine API */
#define TTG_CXX_COROUTINE_NAMESPACE @CXX_COROUTINE_NAMESPACE@

/** whether the compiler supports C++ coroutines */
#cmakedefine TTG_HAVE_COROUTINE

/** whether TTG has CUDA language support */
#cmakedefine TTG_HAVE_CUDA

Expand Down
4 changes: 4 additions & 0 deletions ttg/ttg/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <algorithm>
#include <array>

#ifdef TTG_HAVE_COROUTINE

namespace ttg {

// import std coroutine API into ttg namespace
Expand Down Expand Up @@ -227,4 +229,6 @@ namespace ttg {

} // namespace ttg

#endif // TTG_HAVE_COROUTINE

#endif // TTG_COROUTINE_H
6 changes: 5 additions & 1 deletion ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "ttg/impl_selector.h"
#include "ttg/ptr.h"

#ifdef TTG_HAVE_COROUTINE

namespace ttg::device {

namespace detail {
Expand Down Expand Up @@ -632,6 +634,8 @@ namespace ttg::device {
bool device_reducer::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
#endif // 0

} // namespace ttg::devie
} // namespace ttg::device

#endif // TTG_HAVE_COROUTINE

#endif // TTG_DEVICE_TASK_H
22 changes: 10 additions & 12 deletions ttg/ttg/madness/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
#include "ttg/util/meta/callable.h"
#include "ttg/util/void.h"
#include "ttg/world.h"
#ifdef TTG_HAS_COROUTINE
#include "ttg/coroutine.h"
#endif

#include <array>
#include <cassert>
Expand Down Expand Up @@ -303,10 +301,10 @@ namespace ttg_madness {
derivedT *derived; // Pointer to derived class instance
bool pull_terminals_invoked = false;
std::conditional_t<ttg::meta::is_void_v<keyT>, ttg::Void, keyT> key; // Task key
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
void *suspended_task_address = nullptr; // if not null the function is suspended
ttg::TaskCoroutineID coroutine_id = ttg::TaskCoroutineID::Invalid;
#endif
#endif // TTG_HAVE_COROUTINE

/// makes a tuple of references out of tuple of
template <typename Tuple, std::size_t... Is>
Expand Down Expand Up @@ -336,11 +334,11 @@ namespace ttg_madness {
ttT::threaddata.call_depth++;

void *suspended_task_address =
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
this->suspended_task_address; // non-null = need to resume the task
#else
#else // TTG_HAVE_COROUTINE
nullptr;
#endif
#endif // TTG_HAVE_COROUTINE
if (suspended_task_address == nullptr) { // task is a coroutine that has not started or an ordinary function
// ttg::print("starting task");
if constexpr (!ttg::meta::is_void_v<keyT> && !ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
Expand All @@ -362,7 +360,7 @@ namespace ttg_madness {
} else // unreachable
ttg::abort();
} else { // resume suspended coroutine
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
auto ret = static_cast<ttg::resumable_task>(ttg::coroutine_handle<ttg::resumable_task_state>::from_address(suspended_task_address));
assert(ret.ready());
ret.resume();
Expand All @@ -373,9 +371,9 @@ namespace ttg_madness {
// leave suspended_task_address as is
}
this->suspended_task_address = suspended_task_address;
#else
#else // TTG_HAVE_COROUTINE
ttg::abort(); // should not happen
#endif
#endif // TTG_HAVE_COROUTINE
}

ttT::threaddata.call_depth--;
Expand All @@ -384,7 +382,7 @@ namespace ttg_madness {
// ttg::print("finishing task",ttT::threaddata.call_depth);
// }

#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
if (suspended_task_address) {
// TODO implement handling of suspended coroutines properly

Expand Down Expand Up @@ -412,7 +410,7 @@ namespace ttg_madness {
ttg::abort();
}
}
#endif // TTG_HAS_COROUTINE
#endif // TTG_HAVE_COROUTINE
}

virtual ~TTArgs() {} // Will be deleted via TaskInterface*
Expand Down
8 changes: 4 additions & 4 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class CallableWrapTTArgs
std::conditional_t<std::is_function_v<noref_funcT>, std::add_pointer_t<noref_funcT>, noref_funcT> func;

using op_return_type =
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
std::conditional_t<std::is_same_v<returnT, ttg::resumable_task>,
ttg::coroutine_handle<ttg::resumable_task_state>,
#ifdef TTG_HAVE_DEVICE
Expand All @@ -160,9 +160,9 @@ class CallableWrapTTArgs
void
#endif // TTG_HAVE_DEVICE
>;
#else // TTG_HAS_COROUTINE
#else // TTG_HAVE_COROUTINE
void;
#endif // TTG_HAS_COROUTINE
#endif // TTG_HAVE_COROUTINE

public:
static constexpr bool have_cuda_op = (space == ttg::ExecutionSpace::CUDA);
Expand All @@ -176,7 +176,7 @@ class CallableWrapTTArgs
static_assert(std::is_same_v<std::remove_reference_t<decltype(ret)>, returnT>,
"CallableWrapTTArgs<funcT,returnT,...>: returnT does not match the actual return type of funcT");
if constexpr (!std::is_void_v<returnT>) { // protect from compiling for void returnT
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
if constexpr (std::is_same_v<returnT, ttg::resumable_task>) {
ttg::coroutine_handle<ttg::resumable_task_state> coro_handle;
// if task completed destroy it
Expand Down
4 changes: 2 additions & 2 deletions ttg/ttg/parsec/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ namespace ttg_parsec {
TT* tt = nullptr;
key_type key;
std::array<stream_info_t, num_streams> streams;
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
void* suspended_task_address = nullptr; // if not null the function is suspended
ttg::TaskCoroutineID coroutine_id = ttg::TaskCoroutineID::Invalid;
#endif
Expand Down Expand Up @@ -268,7 +268,7 @@ namespace ttg_parsec {
static constexpr size_t num_streams = TT::numins;
TT* tt = nullptr;
std::array<stream_info_t, num_streams> streams;
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
void* suspended_task_address = nullptr; // if not null the function is suspended
ttg::TaskCoroutineID coroutine_id = ttg::TaskCoroutineID::Invalid;
#endif
Expand Down
32 changes: 17 additions & 15 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
* This may cause deadlocks, so use with caution. */
#define TTG_PARSEC_DEFER_WRITER false

#include "ttg/config.h"

#include "ttg/impl_selector.h"

/* include ttg header to make symbols available in case this header is included directly */
Expand Down Expand Up @@ -1643,11 +1645,11 @@ namespace ttg_parsec {

task_t *task = (task_t*)parsec_task;
void* suspended_task_address =
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
task->suspended_task_address; // non-null = need to resume the task
#else
#else // TTG_HAVE_COROUTINE
nullptr;
#endif
#endif // TTG_HAVE_COROUTINE
//std::cout << "static_op: suspended_task_address " << suspended_task_address << std::endl;
if (suspended_task_address == nullptr) { // task is a coroutine that has not started or an ordinary function

Expand Down Expand Up @@ -1679,9 +1681,9 @@ namespace ttg_parsec {
}
else { // resume the suspended coroutine

#ifdef TTG_HAVE_COROUTINE
assert(task->coroutine_id != ttg::TaskCoroutineID::Invalid);

#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_DEVICE
if (task->coroutine_id == ttg::TaskCoroutineID::DeviceTask) {
ttg::device::Task coro = ttg::device::detail::device_task_handle_type::from_address(suspended_task_address);
Expand Down Expand Up @@ -1725,9 +1727,9 @@ namespace ttg_parsec {
}
else
ttg::abort(); // unrecognized task id
#else // TTG_HAS_COROUTINE
ttg::abort(); // should not happen
#endif // TTG_HAS_COROUTINE
#else // TTG_HAVE_COROUTINE
ttg::abort(); // should not happen
#endif // TTG_HAVE_COROUTINE
}
task->suspended_task_address = suspended_task_address;

Expand All @@ -1750,11 +1752,11 @@ ttg::abort(); // should not happen
task_t *task = static_cast<task_t*>(parsec_task);

void* suspended_task_address =
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
task->suspended_task_address; // non-null = need to resume the task
#else
#else // TTG_HAVE_COROUTINE
nullptr;
#endif
#endif // TTG_HAVE_COROUTINE
if (suspended_task_address == nullptr) { // task is a coroutine that has not started or an ordinary function
ttT *baseobj = (ttT *)task->object_ptr;
derivedT *obj = (derivedT *)task->object_ptr;
Expand All @@ -1769,7 +1771,7 @@ ttg::abort(); // should not happen
detail::parsec_ttg_caller = NULL;
}
else {
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
auto ret = static_cast<ttg::resumable_task>(ttg::coroutine_handle<ttg::resumable_task_state>::from_address(suspended_task_address));
assert(ret.ready());
ret.resume();
Expand All @@ -1780,9 +1782,9 @@ ttg::abort(); // should not happen
else { // not yet completed
// leave suspended_task_address as is
}
#else
#else // TTG_HAVE_COROUTINE
ttg::abort(); // should not happen
#endif
#endif // TTG_HAVE_COROUTINE
}
task->suspended_task_address = suspended_task_address;

Expand Down Expand Up @@ -3695,7 +3697,7 @@ ttg::abort(); // should not happen

task_t *task = (task_t*)parsec_task;

#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
/* if we still have a coroutine handle we invoke it one more time to get the sends/broadcasts */
if (task->suspended_task_address) {
assert(task->coroutine_id != ttg::TaskCoroutineID::Invalid);
Expand Down Expand Up @@ -3726,7 +3728,7 @@ ttg::abort(); // should not happen
/* the coroutine should have completed and we cannot access the promise anymore */
task->suspended_task_address = nullptr;
}
#endif // TTG_HAS_COROUTINE
#endif // TTG_HAVE_COROUTINE

/* release our data copies */
for (int i = 0; i < task->data_count; i++) {
Expand Down
5 changes: 3 additions & 2 deletions ttg/ttg/tt.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#ifndef TTG_TT_H
#define TTG_TT_H

#include "ttg/config.h"
#include "ttg/fwd.h"

#include "ttg/base/tt.h"
#include "ttg/edge.h"

#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
#include "ttg/coroutine.h"
#endif

Expand Down Expand Up @@ -176,7 +177,7 @@ namespace ttg {
} // namespace ttg

#ifndef TTG_PROCESS_TT_OP_RETURN
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) \
{ \
using return_type = decltype(invoke); \
Expand Down

0 comments on commit df32b1e

Please sign in to comment.