Skip to content

Commit

Permalink
Merge branch 'ttg-device-support-master-coro-with-stream-tasks' into …
Browse files Browse the repository at this point in the history
…potrf-cuda-wip
  • Loading branch information
devreal authored Oct 11, 2023
2 parents 9d21225 + 366e60e commit 94f3a81
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 48 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ if (TTG_ENABLE_HIP)
add_compile_definitions(${TTG_HIP_PLATFORM})
endif(TTG_ENABLE_HIP)

# Derive one switch that is TRUE when any device programming model
# (CUDA or HIP) is available.
set(_ttg_have_device FALSE)
if (TTG_HAVE_CUDA OR TTG_HAVE_HIP)
  set(_ttg_have_device TRUE)
endif()
# FORCE is needed because this value is derived, not user-chosen: without it an
# existing cache entry wins and a re-configure that enables or disables CUDA/HIP
# would leave a stale TTG_HAVE_DEVICE behind.
set(TTG_HAVE_DEVICE ${_ttg_have_device} CACHE BOOL "True if TTG has support for any device programming model" FORCE)

##########################
#### prerequisite runtimes
##########################
Expand Down
2 changes: 1 addition & 1 deletion cmake/modules/ExternalDependenciesVersions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ set(TTG_TRACKED_CATCH2_VERSION 2.13.1)
set(TTG_TRACKED_CEREAL_VERSION 1.3.0)
set(TTG_TRACKED_MADNESS_TAG e4ee892717e7f91f824a9652d53d4da2acf6920e)
set(TTG_TRACKED_PARSEC_TAG e59bed9d4934775792a58980cf5500bd41381bc4)
set(TTG_TRACKED_BTAS_TAG d73153ad9bc41a177e441ef04eceff7fab0c766d)
set(TTG_TRACKED_BTAS_TAG a02be0d29fb4a788ecef43de711dcd6d6f1cb6b8)
set(TTG_TRACKED_TILEDARRAY_TAG dfa3f76ebd58c05c64acf6e57cbdc85e90c880a8)
16 changes: 11 additions & 5 deletions cmake/modules/FindTBB.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,17 @@ findpkg_finish(TBB_MALLOC_PROXY)
#parse all the version numbers from tbb
if(NOT TBB_VERSION)

#only read the start of the file
file(READ
"${TBB_INCLUDE_DIR}/tbb/tbb_stddef.h"
TBB_VERSION_CONTENTS
LIMIT 2048)
# Choose the header to scan for version macros: use oneapi/tbb/version.h when it
# exists, otherwise fall back to tbb/tbb_stddef.h.
set(_tbb_version_header "${TBB_INCLUDE_DIR}/tbb/tbb_stddef.h")
if (EXISTS "${TBB_INCLUDE_DIR}/oneapi/tbb/version.h")
  set(_tbb_version_header "${TBB_INCLUDE_DIR}/oneapi/tbb/version.h")
endif()

# Keep only the lines that mention VERSION.
file(STRINGS "${_tbb_version_header}" TBB_VERSION_CONTENTS REGEX "VERSION")
unset(_tbb_version_header)

string(REGEX REPLACE
".*#define TBB_VERSION_MAJOR ([0-9]+).*" "\\1"
Expand Down
41 changes: 26 additions & 15 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <vector>

#if __has_include(<btas/features.h>)
#pragma message("C Preprocessor got here!")
#include <btas/features.h>
#ifdef BTAS_IS_USABLE
#include <btas/btas.h>
Expand Down Expand Up @@ -41,7 +40,8 @@ using namespace ttg;
#include "ttg/util/bug.h"

#if defined(BLOCK_SPARSE_GEMM) && defined(BTAS_IS_USABLE)
using blk_t = btas::Tensor<double, btas::DEFAULT::range, btas::mohndle<btas::varray<double>, btas::Handle::shared_ptr>>;
using scalar_t = double;
using blk_t = btas::Tensor<scalar_t, btas::DEFAULT::range, btas::mohndle<btas::varray<scalar_t>, btas::Handle::shared_ptr>>;

#if defined(TTG_USE_PARSEC)
namespace ttg {
Expand Down Expand Up @@ -695,32 +695,43 @@ class Control : public TT<void, std::tuple<Out<Key<2>>>, Control> {
}
};

/// Returns (t * t, |t|) for a single-precision scalar: its contribution to the
/// squared 2-norm and to the infinity norm, in that order.
std::tuple<float, float> norms(float t) {
  const float squared = t * t;
  const float magnitude = std::abs(t);
  return std::make_tuple(squared, magnitude);
}
/// Returns (t * t, |t|) for a double-precision scalar: its contribution to the
/// squared 2-norm and to the infinity norm, in that order.
std::tuple<double, double> norms(double t) {
  const double squared = t * t;
  const double magnitude = std::abs(t);
  return std::make_tuple(squared, magnitude);
}

/// Returns (|t|^2, |t|) for a complex scalar: its contribution to the squared
/// 2-norm and to the infinity norm, in that order.
///
/// @tparam T real type underlying the complex value
/// @param t complex value to measure
/// @return tuple of (squared magnitude, magnitude), both of type T
template <typename T>
std::tuple<T, T> norms(std::complex<T> t) {
  // std::norm gives the squared magnitude directly. Squaring std::abs(t) would
  // take a square root and then undo it, adding rounding error
  // (e.g. |1+i|^2 comes out as 2.0000000000000004 instead of 2).
  return std::make_tuple(std::norm(t), std::abs(t));
}

#ifdef BTAS_IS_USABLE
template <typename T_, class Range_, class Store_>
std::tuple<T_, T_> norms(const btas::Tensor<T_, Range_, Store_> &t) {
T_ norm_2_square = 0.0;
T_ norm_inf = 0.0;
for (auto k : t) {
norm_2_square += k * k;
norm_inf = std::max(norm_inf, std::abs(k));
auto norms(const btas::Tensor<T_, Range_, Store_> &t) {
using T = decltype(std::abs(std::declval<T_>()));
T norm_2_square = 0.0;
T norm_inf = 0.0;
for (auto elem : t) {
T elem_norm_2_square, elem_norm_inf;
std::tie(elem_norm_2_square, elem_norm_inf) = norms(elem);
norm_2_square += elem_norm_2_square;
norm_inf = std::max(norm_inf, elem_norm_inf);
}
return std::make_tuple(norm_2_square, norm_inf);
}
#endif

/// Scalar overload: returns (t * t, |t|), i.e. the squared 2-norm and the
/// infinity norm of a single double value.
std::tuple<double, double> norms(double t) {
  const double abs_value = std::abs(t);
  const double square = t * t;
  return std::make_tuple(square, abs_value);
}

template <typename Blk = blk_t>
std::tuple<double, double> norms(const SpMatrix<Blk> &A) {
double norm_2_square = 0.0;
double norm_inf = 0.0;
template <typename Blk>
auto norms(const SpMatrix<Blk> &A) {
using T = decltype(std::abs(std::declval<typename Blk::value_type>()));
T norm_2_square = 0.0;
T norm_inf = 0.0;
for (int i = 0; i < A.outerSize(); ++i) {
for (typename SpMatrix<Blk>::InnerIterator it(A, i); it; ++it) {
// cout << 1+it.row() << "\t"; // row index
// cout << 1+it.col() << "\t"; // col index (here it is equal to k)
// cout << it.value() << endl;
auto elem = it.value();
double elem_norm_2_square, elem_norm_inf;
T elem_norm_2_square, elem_norm_inf;
std::tie(elem_norm_2_square, elem_norm_inf) = norms(elem);
norm_2_square += elem_norm_2_square;
norm_inf = std::max(norm_inf, elem_norm_inf);
Expand Down
61 changes: 40 additions & 21 deletions examples/spmm/spmm_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,14 @@ struct DeviceTensor : public ttg::TTValue<DeviceTensor<_T, _Range, _Storage>>

};

using scalar_t = double;
#if defined(TTG_HAVE_CUDA) || defined(TTG_HAVE_HIPBLAS)
using blk_t = DeviceTensor<double, btas::DEFAULT::range,
btas::mohndle<btas::varray<double, TiledArray::device_pinned_allocator<double>>,
using blk_t = DeviceTensor<scalar_t, btas::DEFAULT::range,
btas::mohndle<btas::varray<scalar_t, TiledArray::device_pinned_allocator<scalar_t>>,
btas::Handle::shared_ptr>>;
#else
using blk_t = DeviceTensor<double, btas::DEFAULT::range,
btas::mohndle<btas::varray<double>, btas::Handle::shared_ptr>>;
using blk_t = DeviceTensor<scalar_t, btas::DEFAULT::range,
btas::mohndle<btas::varray<scalar_t>, btas::Handle::shared_ptr>>;
#endif


Expand All @@ -265,33 +266,50 @@ using blk_t = DeviceTensor<double, btas::DEFAULT::range,
//}

/* TODO: call CUDA gemm here */
static void device_gemm(blk_t &C, const blk_t &A, const blk_t &B) {
static const double alpha = 1.0;
static const double beta = 1.0;
template <typename Blk>
static void device_gemm(Blk &C, const Blk &A, const Blk &B) {
using blk_t = Blk;
using T = typename blk_t::value_type;
static_assert(std::is_same_v<T,double> || std::is_same_v<T,float>);
static const T alpha = 1.0;
static const T beta = 1.0;
// make sure all memory is on the device
// TODO: A and B are read-only so the owner device will be 0. How to fix?
//assert(A.b.get_current_device() != 0);
//assert(B.b.get_current_device() != 0);
int device = C.b.get_current_device();
assert(device != 0);
#if defined(TTG_HAVE_CUDA)
cublasDgemm(cublas_handle(),
CUBLAS_OP_N, CUBLAS_OP_N, C.extent(0), C.extent(1), A.extent(1), &alpha,
A.b.device_ptr_on(device), A.extent(0),
B.b.device_ptr_on(device), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
if constexpr (std::is_same_v<T,double>) {
cublasDgemm(ttg::detail::cublas_get_handle(), CUBLAS_OP_N, CUBLAS_OP_N, C.extent(0), C.extent(1), A.extent(1),
&alpha, A.b.device_ptr_on(device), A.extent(0), B.b.device_ptr_on(device), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
}
else if constexpr (std::is_same_v<T,float>) {
cublasSgemm(ttg::detail::cublas_get_handle(), CUBLAS_OP_N, CUBLAS_OP_N, C.extent(0), C.extent(1), A.extent(1),
&alpha, A.b.device_ptr_on(device), A.extent(0), B.b.device_ptr_on(device), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
}
#elif defined(TTG_HAVE_HIPBLAS)
hipblasDgemm(hipblas_handle(),
HIPBLAS_OP_N, HIPBLAS_OP_N,
C.extent(0), C.extent(1), A.extent(1), &alpha,
A.b.device_ptr_on(device), A.extent(0),
B.b.device_ptr_on(device), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
if constexpr (std::is_same_v<T,double>) {
hipblasDgemm(hipblas_handle(),
HIPBLAS_OP_N, HIPBLAS_OP_N,
C.extent(0), C.extent(1), A.extent(1), &alpha,
A.b.device_ptr_on(device), A.extent(0),
B.b.device_ptr_on(device), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
} else if constexpr (std::is_same_v<T,float>) {
hipblasSgemm(hipblas_handle(),
HIPBLAS_OP_N, HIPBLAS_OP_N,
C.extent(0), C.extent(1), A.extent(1), &alpha,
A.b.device_ptr_on(device), A.extent(0),
B.b.device_ptr_on(device), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
}

#endif
}

//using blk_t = btas::Tensor<double, btas::DEFAULT::range, btas::mohndle<btas::varray<double>, btas::Handle::shared_ptr>>;

#if defined(TTG_USE_PARSEC)
namespace ttg {
template <>
Expand All @@ -313,8 +331,9 @@ namespace ttg {
return dim;
}
static auto get_data(blk_t &b) {
using T = typename blk_t::value_type;
if (!b.empty())
return boost::container::small_vector<iovec, 1>(1, iovec{b.size() * sizeof(double), b.data()});
return boost::container::small_vector<iovec, 1>(1, iovec{b.size() * sizeof(T), b.data()});
else
return boost::container::small_vector<iovec, 1>{};
}
Expand Down
2 changes: 2 additions & 0 deletions ttg/ttg/config.in.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#cmakedefine TTG_HAVE_HIPBLAS

#cmakedefine TTG_HAVE_DEVICE

#cmakedefine TTG_HAVE_MPI
#cmakedefine TTG_HAVE_MPIEXT

Expand Down
20 changes: 17 additions & 3 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,14 @@ class CallableWrapTTArgs
#ifdef TTG_HAS_COROUTINE
std::conditional_t<std::is_same_v<returnT, ttg::resumable_task>,
ttg::coroutine_handle<>,
#ifdef TTG_HAVE_DEVICE
std::conditional_t<std::is_same_v<returnT, ttg::device_task>,
ttg::device_task::base_type,
void>>;
void>
#else // TTG_HAVE_DEVICE
void
#endif // TTG_HAVE_DEVICE
>;
#else // TTG_HAS_COROUTINE
void;
#endif // TTG_HAS_COROUTINE
Expand All @@ -180,11 +185,20 @@ class CallableWrapTTArgs
coro_handle = ret;
}
return coro_handle;
} else if constexpr (std::is_same_v<returnT, ttg::device_task>) {
} else
#ifdef TTG_HAVE_DEVICE
if constexpr (std::is_same_v<returnT, ttg::device_task>) {
ttg::device_task::base_type coro_handle = ret;
return coro_handle;
}
if constexpr (!(std::is_same_v<returnT, ttg::resumable_task> || std::is_same_v<returnT, ttg::device_task>))
#else // TTG_HAVE_DEVICE
ttg::abort(); // should not happen
#endif // TTG_HAVE_DEVICE
if constexpr (!(std::is_same_v<returnT, ttg::resumable_task>
#ifdef TTG_HAVE_DEVICE
|| std::is_same_v<returnT, ttg::device_task>
#endif // TTG_HAVE_DEVICE
))
#endif
{
static_assert(std::tuple_size_v<std::remove_reference_t<decltype(out)>> == 1,
Expand Down
17 changes: 14 additions & 3 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
#include "ttg/util/print.h"
#include "ttg/util/trace.h"
#include "ttg/util/typelist.h"
#ifdef TTG_HAVE_DEVICE
#include "ttg/device/task.h"
#endif // TTG_HAVE_DEVICE

#include "ttg/serialization/data_descriptor.h"

Expand Down Expand Up @@ -1300,6 +1302,7 @@ namespace ttg_parsec {
task->copies[IS]->get_ptr()))...};
}

#ifdef TTG_HAVE_DEVICE
/**
* Submit callback called by PaRSEC once all input transfers have completed.
*/
Expand Down Expand Up @@ -1491,6 +1494,7 @@ namespace ttg_parsec {
ttg::abort();
return PARSEC_HOOK_RETURN_DONE; // will not be reached
}
#endif // TTG_HAVE_DEVICE

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t static_op(parsec_task_t *parsec_task) {
Expand Down Expand Up @@ -1532,6 +1536,7 @@ namespace ttg_parsec {
detail::parsec_ttg_caller = nullptr;
}
else { // resume the suspended coroutine
#ifdef TTG_HAVE_DEVICE
auto coro = static_cast<ttg::device_task>(ttg::device_task_handle_type::from_address(suspended_task_address));
assert(detail::parsec_ttg_caller == nullptr);
detail::parsec_ttg_caller = static_cast<detail::parsec_ttg_task_base_t*>(task);
Expand Down Expand Up @@ -1559,9 +1564,12 @@ namespace ttg_parsec {
}
task->suspended_task_address = suspended_task_address;
#else
#endif // 0
#endif // TTG_HAS_COROUTINE
ttg::abort(); // should not happen
#endif
#endif // 0
#else // TTG_HAVE_DEVICE
ttg::abort(); // should not happen
#endif // TTG_HAVE_DEVICE
}
task->suspended_task_address = suspended_task_address;

Expand Down Expand Up @@ -3444,6 +3452,7 @@ namespace ttg_parsec {

/* if we still have a coroutine handle we invoke it one more time to get the sends/broadcasts */
if (task->suspended_task_address) {
#ifdef TTG_HAVE_DEVICE
// get the device task from the coroutine handle
auto dev_task = ttg::device_task_handle_type::from_address(task->suspended_task_address);

Expand All @@ -3460,7 +3469,9 @@ namespace ttg_parsec {
dev_data.do_sends();
detail::parsec_ttg_caller = nullptr;
}

#else // TTG_HAVE_DEVICE
ttg::abort(); // should not happen
#endif // TTG_HAVE_DEVICE
/* the coroutine should have completed and we cannot access the promise anymore */
task->suspended_task_address = nullptr;
}
Expand Down

0 comments on commit 94f3a81

Please sign in to comment.