Skip to content

Commit

Permalink
Merge pull request TESSEorg#283 from devreal/new_parsec_device_code
Browse files Browse the repository at this point in the history
Add device hints and return to PaRSEC main repo
  • Loading branch information
devreal authored Jun 4, 2024
2 parents d083245 + a9c33d4 commit 754c7d7
Show file tree
Hide file tree
Showing 22 changed files with 359 additions and 130 deletions.
23 changes: 21 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ option(TTG_ENABLE_LEVEL_ZERO "Whether to TTG will look for Intel oneAPI Level Ze
option(TTG_EXAMPLES "Whether to build examples" OFF)
option(TTG_ENABLE_ASAN "Whether to enable address sanitizer" OFF)

option(TTG_ENABLE_COROUTINES "Whether to enable C++ coroutines, needed for accelerator device support" ON)
option(TTG_FETCH_BOOST "Whether to fetch+build Boost, if missing" OFF)
option(TTG_IGNORE_BUNDLED_EXTERNALS "Whether to skip installation and use of bundled external dependencies (Boost.CallableTraits)" OFF)
option(TTG_ENABLE_TRACE "Whether to enable ttg::trace() output" OFF)
Expand Down Expand Up @@ -94,8 +95,26 @@ endif (BUILD_TESTING)
###########################
# Boost
include("${PROJECT_SOURCE_DIR}/cmake/modules/FindOrFetchBoost.cmake")
# C++ coroutines
find_package(CXXStdCoroutine MODULE REQUIRED COMPONENTS Final Experimental)

# C++ coroutine detection.
#
# TTG_HAVE_COROUTINE is a detection *result*, not a user option: drop any value
# cached by a previous configure run so that it cannot go stale when
# TTG_ENABLE_COROUTINES is switched off or detection is skipped below.
unset(TTG_HAVE_COROUTINE CACHE)

if (TTG_ENABLE_COROUTINES)
  set(SKIP_COROUTINE_DETECTION FALSE)
  # Check for broken GCC releases and skip detection if one is found.
  # Variable names are passed unexpanded so that if() dereferences them exactly
  # once and an empty value cannot produce a malformed condition.
  if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
    if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 11.4.0)
      set(SKIP_COROUTINE_DETECTION TRUE)
    elseif (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.1.0 AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 12.3.0)
      set(SKIP_COROUTINE_DETECTION TRUE)
    endif()
    if (SKIP_COROUTINE_DETECTION)
      message(WARNING "GCC with broken Coroutine support detected, disabling Coroutine support. At least GCC 11.4, 12.3, or 13.1 required.")
    endif()
  endif()

  if (NOT SKIP_COROUTINE_DETECTION)
    find_package(CXXStdCoroutine MODULE REQUIRED COMPONENTS Final Experimental)
    # Store the *value* of CXXStdCoroutine_FOUND (not the variable's name) so the
    # cache entry is a genuine boolean.
    set(TTG_HAVE_COROUTINE "${CXXStdCoroutine_FOUND}" CACHE BOOL "True if the compiler has coroutine support")
  endif()
endif()


##########################
Expand Down
2 changes: 1 addition & 1 deletion cmake/modules/ExternalDependenciesVersions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
set(TTG_TRACKED_VG_CMAKE_KIT_TAG 7ea2d4d3f8854b9e417f297fd74d6fc49aa13fd5) # used to provide "real" FindOrFetchBoost
set(TTG_TRACKED_CATCH2_VERSION 3.5.0)
set(TTG_TRACKED_MADNESS_TAG 2eb3bcf0138127ee2dbc651f1aabd3e9b0def4e3)
set(TTG_TRACKED_PARSEC_TAG 0b3140f58ad9dc78a3d64da9fd73ecc7f443ece7)
set(TTG_TRACKED_PARSEC_TAG 58f8f3089ecad2e8ee50e80a9586e05ce8873b1c)
set(TTG_TRACKED_BTAS_TAG 4e8f5233aa7881dccdfcc37ce07128833926d3c2)
set(TTG_TRACKED_TILEDARRAY_TAG 493c109379a1b64ddd5ef59f7e33b95633b68d73)

Expand Down
2 changes: 1 addition & 1 deletion cmake/modules/FindOrFetchPARSEC.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ if (NOT TARGET PaRSEC::parsec)

FetchContent_Declare(
PARSEC
GIT_REPOSITORY https://github.com/devreal/parsec-1.git
GIT_REPOSITORY https://github.com/ICLDisco/parsec.git
GIT_TAG ${TTG_TRACKED_PARSEC_TAG}
)
FetchContent_MakeAvailable(PARSEC)
Expand Down
10 changes: 9 additions & 1 deletion examples/potrf/pmw.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ class PaRSECMatrixWrapper {
(pm->uplo == PARSEC_MATRIX_UPPER && col >= row);
}

/// Returns the `rows` field of the wrapped PaRSEC matrix's grid (`pm->grid.rows`).
int P() const { return pm->grid.rows; }

/// Returns the `cols` field of the wrapped PaRSEC matrix's grid (`pm->grid.cols`).
int Q() const { return pm->grid.cols; }

PaRSECMatrixT* parsec() {
return pm;
}
Expand Down Expand Up @@ -132,7 +140,7 @@ class PaRSECMatrixWrapper {
};

template<typename ValueT>
using MatrixT = PaRSECMatrixWrapper<sym_two_dim_block_cyclic_t, ValueT>;
using MatrixT = PaRSECMatrixWrapper<parsec_matrix_sym_block_cyclic_t, ValueT>;

static auto make_load_tt(MatrixT<double> &A, ttg::Edge<Key2, MatrixTile<double>> &toop, bool defer_write)
{
Expand Down
26 changes: 25 additions & 1 deletion examples/potrf/potrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -674,10 +674,22 @@ namespace potrf {
auto keymap1 = [&](const Key1& key) { return A.rank_of(key[0], key[0]); };

auto keymap2a = [&](const Key2& key) { return A.rank_of(key[0], key[1]); };
auto keymap2b = [&](const Key2& key) { return A.rank_of(key[0], key[0]); };
auto keymap2b = [&](const Key2& key) { return A.rank_of(key[1], key[1]); };

auto keymap3 = [&](const Key3& key) { return A.rank_of(key[0], key[1]); };

/**
* Device map hints: we try to keep tiles on one row on the same device to minimize
* data movement between devices. This provides hints for load-balancing up front
* and avoids movement of the TRSM result to GEMM tasks.
*/
auto devmap1 = [&](const Key1& key) { return (key[0] / A.P()) % ttg::device::num_devices(); };

auto devmap2a = [&](const Key2& key) { return (key[0] / A.P()) % ttg::device::num_devices(); };
auto devmap2b = [&](const Key2& key) { return (key[1] / A.P()) % ttg::device::num_devices(); };

auto devmap3 = [&](const Key3& key) { return (key[0] / A.P()) % ttg::device::num_devices(); };

ttg::Edge<Key1, MatrixTile<T>> syrk_potrf("syrk_potrf"), disp_potrf("disp_potrf");

ttg::Edge<Key2, MatrixTile<T>> potrf_trsm("potrf_trsm"), trsm_syrk("trsm_syrk"), gemm_trsm("gemm_trsm"),
Expand All @@ -692,18 +704,30 @@ namespace potrf {
// POTRF template task: process placement via keymap1; the device-map hint is
// only installed when device kernels are compiled in.
auto tt_potrf = make_potrf(A, disp_potrf, syrk_potrf, potrf_trsm, output);
tt_potrf->set_keymap(keymap1);
tt_potrf->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_potrf->set_devicemap(devmap1);
#endif // ENABLE_DEVICE_KERNEL

// TRSM template task: process placement via keymap2a, device hint via devmap2a.
auto tt_trsm = make_trsm(A, disp_trsm, potrf_trsm, gemm_trsm, trsm_syrk, trsm_gemm_row, trsm_gemm_col, output);
tt_trsm->set_keymap(keymap2a);
tt_trsm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_trsm->set_devicemap(devmap2a);
#endif // ENABLE_DEVICE_KERNEL

// SYRK template task: process placement via keymap2b, device hint via devmap2b.
auto tt_syrk = make_syrk(A, disp_syrk, trsm_syrk, syrk_syrk, syrk_potrf, syrk_syrk);
tt_syrk->set_keymap(keymap2b);
tt_syrk->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_syrk->set_devicemap(devmap2b);
#endif // ENABLE_DEVICE_KERNEL

// GEMM template task: process placement via keymap3, device hint via devmap3.
auto tt_gemm = make_gemm(A, disp_gemm, trsm_gemm_row, trsm_gemm_col, gemm_gemm, gemm_trsm, gemm_gemm);
tt_gemm->set_keymap(keymap3);
tt_gemm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_gemm->set_devicemap(devmap3);
#endif // ENABLE_DEVICE_KERNEL

/* Priorities taken from DPLASMA */
auto nt = A.cols();
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ set(ut_libs Catch2::Catch2)

# coroutine tests
# only built when coroutine support was detected (CXXStdCoroutine_FOUND)
list(APPEND ut_src fibonacci-coro.cc)
list(APPEND ut_src device_coro.cc)
if (TTG_HAVE_CUDA)
list(APPEND ut_src cuda_kernel.cu)
endif(TTG_HAVE_CUDA)
list(APPEND ut_libs std::coroutine)
# Add the coroutine-based unit tests (and the CUDA kernel they use, when CUDA
# is available) only if the coroutine package was found.
if (CXXStdCoroutine_FOUND)
  list(APPEND ut_src fibonacci-coro.cc device_coro.cc)
  if (TTG_HAVE_CUDA)
    list(APPEND ut_src cuda_kernel.cu)
  endif()
  list(APPEND ut_libs std::coroutine)
endif()

add_ttg_executable(core-unittests-ttg "${ut_src}" LINK_LIBRARIES "${ut_libs}" COMPILE_DEFINITIONS "CATCH_CONFIG_NO_POSIX_SIGNALS=1" )

Expand Down
2 changes: 1 addition & 1 deletion ttg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ if (TTG_ENABLE_TRACE)
endif (TTG_ENABLE_TRACE)
if (TARGET std::coroutine)
list(APPEND ttg-deps std::coroutine)
list(APPEND ttg-defs "TTG_HAS_COROUTINE=1")
list(APPEND ttg-util-headers
${CMAKE_CURRENT_SOURCE_DIR}/ttg/coroutine.h
)
Expand Down Expand Up @@ -208,6 +207,7 @@ endif(TARGET Boost::serialization)
if (TARGET MADworld)
set(ttg-mad-headers
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/buffer.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/device.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/fwd.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/import.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/madness/ttg.h
Expand Down
3 changes: 3 additions & 0 deletions ttg/ttg/config.in.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
/** the C++ namespace containing the coroutine API */
#define TTG_CXX_COROUTINE_NAMESPACE @CXX_COROUTINE_NAMESPACE@

/** whether the compiler supports C++ coroutines */
#cmakedefine TTG_HAVE_COROUTINE

/** whether TTG has CUDA language support */
#cmakedefine TTG_HAVE_CUDA

Expand Down
5 changes: 5 additions & 0 deletions ttg/ttg/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
#define TTG_COROUTINE_H

#include "ttg/config.h"

#ifdef TTG_HAVE_COROUTINE
#include TTG_CXX_COROUTINE_HEADER

#include <algorithm>
#include <array>


namespace ttg {

// import std coroutine API into ttg namespace
Expand Down Expand Up @@ -227,4 +230,6 @@ namespace ttg {

} // namespace ttg

#endif // TTG_HAVE_COROUTINE

#endif // TTG_COROUTINE_H
8 changes: 8 additions & 0 deletions ttg/ttg/device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include "ttg/config.h"
#include "ttg/execution.h"
#include "ttg/impl_selector.h"
#include "ttg/fwd.h"



Expand Down Expand Up @@ -180,3 +182,9 @@ namespace ttg::device {
}
} // namespace ttg
#endif // defined(TTG_HAVE_HIP)

namespace ttg::device {

  /// Returns whatever `num_devices()` in the `TTG_IMPL_NS` namespace reports.
  inline int num_devices() { return TTG_IMPL_NS::num_devices(); }

}  // namespace ttg::device
6 changes: 5 additions & 1 deletion ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "ttg/impl_selector.h"
#include "ttg/ptr.h"

#ifdef TTG_HAVE_COROUTINE

namespace ttg::device {

namespace detail {
Expand Down Expand Up @@ -632,6 +634,8 @@ namespace ttg::device {
bool device_reducer::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
#endif // 0

} // namespace ttg::devie
} // namespace ttg::device

#endif // TTG_HAVE_COROUTINE

#endif // TTG_DEVICE_TASK_H
9 changes: 9 additions & 0 deletions ttg/ttg/madness/device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef TTG_MADNESS_DEVICE_H
#define TTG_MADNESS_DEVICE_H

namespace ttg_madness {

  /// Number of devices available to the MADNESS backend.
  /// MADNESS has no device support, so this is always zero.
  inline int num_devices() {
    return 0;
  }

}  // namespace ttg_madness

#endif // TTG_MADNESS_DEVICE_H
2 changes: 2 additions & 0 deletions ttg/ttg/madness/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ namespace ttg_madness {
template<typename... Buffer>
inline void mark_device_out(std::tuple<Buffer&...> &b);

inline int num_devices();

} // namespace ttg_madness

#endif // TTG_MADNESS_FWD_H
23 changes: 11 additions & 12 deletions ttg/ttg/madness/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "ttg/base/keymap.h"
#include "ttg/base/tt.h"
#include "ttg/func.h"
#include "ttg/madness/device.h"
#include "ttg/runtimes.h"
#include "ttg/tt.h"
#include "ttg/util/bug.h"
Expand All @@ -23,9 +24,7 @@
#include "ttg/util/meta/callable.h"
#include "ttg/util/void.h"
#include "ttg/world.h"
#ifdef TTG_HAS_COROUTINE
#include "ttg/coroutine.h"
#endif

#include <array>
#include <cassert>
Expand Down Expand Up @@ -302,10 +301,10 @@ namespace ttg_madness {
derivedT *derived; // Pointer to derived class instance
bool pull_terminals_invoked = false;
std::conditional_t<ttg::meta::is_void_v<keyT>, ttg::Void, keyT> key; // Task key
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
void *suspended_task_address = nullptr; // if not null the function is suspended
ttg::TaskCoroutineID coroutine_id = ttg::TaskCoroutineID::Invalid;
#endif
#endif // TTG_HAVE_COROUTINE

/// makes a tuple of references out of tuple of
template <typename Tuple, std::size_t... Is>
Expand Down Expand Up @@ -335,11 +334,11 @@ namespace ttg_madness {
ttT::threaddata.call_depth++;

void *suspended_task_address =
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
this->suspended_task_address; // non-null = need to resume the task
#else
#else // TTG_HAVE_COROUTINE
nullptr;
#endif
#endif // TTG_HAVE_COROUTINE
if (suspended_task_address == nullptr) { // task is a coroutine that has not started or an ordinary function
// ttg::print("starting task");
if constexpr (!ttg::meta::is_void_v<keyT> && !ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
Expand All @@ -361,7 +360,7 @@ namespace ttg_madness {
} else // unreachable
ttg::abort();
} else { // resume suspended coroutine
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
auto ret = static_cast<ttg::resumable_task>(ttg::coroutine_handle<ttg::resumable_task_state>::from_address(suspended_task_address));
assert(ret.ready());
ret.resume();
Expand All @@ -372,9 +371,9 @@ namespace ttg_madness {
// leave suspended_task_address as is
}
this->suspended_task_address = suspended_task_address;
#else
#else // TTG_HAVE_COROUTINE
ttg::abort(); // should not happen
#endif
#endif // TTG_HAVE_COROUTINE
}

ttT::threaddata.call_depth--;
Expand All @@ -383,7 +382,7 @@ namespace ttg_madness {
// ttg::print("finishing task",ttT::threaddata.call_depth);
// }

#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
if (suspended_task_address) {
// TODO implement handling of suspended coroutines properly

Expand Down Expand Up @@ -411,7 +410,7 @@ namespace ttg_madness {
ttg::abort();
}
}
#endif // TTG_HAS_COROUTINE
#endif // TTG_HAVE_COROUTINE
}

virtual ~TTArgs() {} // Will be deleted via TaskInterface*
Expand Down
8 changes: 4 additions & 4 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class CallableWrapTTArgs
std::conditional_t<std::is_function_v<noref_funcT>, std::add_pointer_t<noref_funcT>, noref_funcT> func;

using op_return_type =
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
std::conditional_t<std::is_same_v<returnT, ttg::resumable_task>,
ttg::coroutine_handle<ttg::resumable_task_state>,
#ifdef TTG_HAVE_DEVICE
Expand All @@ -160,9 +160,9 @@ class CallableWrapTTArgs
void
#endif // TTG_HAVE_DEVICE
>;
#else // TTG_HAS_COROUTINE
#else // TTG_HAVE_COROUTINE
void;
#endif // TTG_HAS_COROUTINE
#endif // TTG_HAVE_COROUTINE

public:
static constexpr bool have_cuda_op = (space == ttg::ExecutionSpace::CUDA);
Expand All @@ -176,7 +176,7 @@ class CallableWrapTTArgs
static_assert(std::is_same_v<std::remove_reference_t<decltype(ret)>, returnT>,
"CallableWrapTTArgs<funcT,returnT,...>: returnT does not match the actual return type of funcT");
if constexpr (!std::is_void_v<returnT>) { // protect from compiling for void returnT
#ifdef TTG_HAS_COROUTINE
#ifdef TTG_HAVE_COROUTINE
if constexpr (std::is_same_v<returnT, ttg::resumable_task>) {
ttg::coroutine_handle<ttg::resumable_task_state> coro_handle;
// if task completed destroy it
Expand Down
8 changes: 8 additions & 0 deletions ttg/ttg/parsec/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,14 @@ struct Buffer : public detail::ttg_parsec_data_wrapper_t
// << " parsec_data " << m_data.get() << std::endl;
}

/// Advise PaRSEC that `dev` is the preferred device for this buffer's data.
/// Does nothing unless `dev` is a device and the data's `owner_device` is 0.
void prefer_device(ttg::device::Device dev) {
  if (!dev.is_device()) {
    return;
  }
  /* only set device if the host has the latest copy as otherwise we might end up with a stale copy */
  if (this->parsec_data()->owner_device != 0) {
    return;
  }
  parsec_advise_data_on_device(this->parsec_data(), detail::ttg_device_to_parsec_device(dev),
                               PARSEC_DEV_DATA_ADVICE_PREFERRED_DEVICE);
}

/* serialization support */

#ifdef TTG_SERIALIZATION_SUPPORTS_BOOST
Expand Down
Loading

0 comments on commit 754c7d7

Please sign in to comment.