From 4f86f7be6ccf670fafd6b2865db97cb7d2adc438 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Wed, 15 May 2024 17:57:37 -0400 Subject: [PATCH] Add device hint to TT and buffer For POTRF, we want to provide a hint that tasks on the same column should be executed on the same device, to reduce data movement and provide a hint on load balancing up front. Signed-off-by: Joseph Schuchart --- examples/potrf/pmw.h | 10 ++++++- examples/potrf/potrf.h | 24 ++++++++++++++++ ttg/ttg/device/device.h | 7 +++++ ttg/ttg/madness/device.h | 9 ++++++ ttg/ttg/madness/ttg.h | 1 + ttg/ttg/parsec/buffer.h | 8 ++++++ ttg/ttg/parsec/device.h | 7 +++++ ttg/ttg/parsec/fwd.h | 2 ++ ttg/ttg/parsec/ttg.h | 61 ++++++++++++++++++++++++++++++++++++++++ ttg/ttg/util/meta.h | 14 ++++----- 10 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 ttg/ttg/madness/device.h diff --git a/examples/potrf/pmw.h b/examples/potrf/pmw.h index 0f8d75d7d..25f01933c 100644 --- a/examples/potrf/pmw.h +++ b/examples/potrf/pmw.h @@ -102,6 +102,14 @@ class PaRSECMatrixWrapper { (pm->uplo == PARSEC_MATRIX_UPPER && col >= row); } + int P() const { + return pm->grid.rows; + } + + int Q() const { + return pm->grid.cols; + } + PaRSECMatrixT* parsec() { return pm; } @@ -132,7 +140,7 @@ class PaRSECMatrixWrapper { }; template -using MatrixT = PaRSECMatrixWrapper; +using MatrixT = PaRSECMatrixWrapper; static auto make_load_tt(MatrixT &A, ttg::Edge> &toop, bool defer_write) { diff --git a/examples/potrf/potrf.h b/examples/potrf/potrf.h index 1ecb2c23f..4adbff106 100644 --- a/examples/potrf/potrf.h +++ b/examples/potrf/potrf.h @@ -678,6 +678,18 @@ namespace potrf { auto keymap3 = [&](const Key3& key) { return A.rank_of(key[0], key[1]); }; + /** + * Device map hints: we try to keep tiles on one row on the same device to minimize + * data movement between devices. This provides hints for load-balancing up front + * and avoids movement of the TRSM result to GEMM tasks. + */ + auto devmap1 = [&](const key1& key) { return (key[0] / A.P()) % ttg::device::num_devices(); } + + auto devmap2a = [&](const key2& key) { return (key[0] / A.P()) % ttg::device::num_devices(); } + auto devmap2b = [&](const key2& key) { return (key[1] / A.P()) % ttg::device::num_devices(); } + + auto devmap3 = [&](const key3& key) { return (key[0] / A.P()) % ttg::device::num_devices(); } + ttg::Edge> syrk_potrf("syrk_potrf"), disp_potrf("disp_potrf"); ttg::Edge> potrf_trsm("potrf_trsm"), trsm_syrk("trsm_syrk"), gemm_trsm("gemm_trsm"), @@ -692,18 +704,30 @@ namespace potrf { auto tt_potrf = make_potrf(A, disp_potrf, syrk_potrf, potrf_trsm, output); tt_potrf->set_keymap(keymap1); tt_potrf->set_defer_writer(defer_write); +#ifdef ENABLE_DEVICE_KERNEL + tt_potrf->set_devmap(devmap1); +#endif // 0 auto tt_trsm = make_trsm(A, disp_trsm, potrf_trsm, gemm_trsm, trsm_syrk, trsm_gemm_row, trsm_gemm_col, output); tt_trsm->set_keymap(keymap2a); tt_trsm->set_defer_writer(defer_write); +#ifdef ENABLE_DEVICE_KERNEL + tt_trsm->set_devmap(devmap2a); +#endif // 0 auto tt_syrk = make_syrk(A, disp_syrk, trsm_syrk, syrk_syrk, syrk_potrf, syrk_syrk); tt_syrk->set_keymap(keymap2b); tt_syrk->set_defer_writer(defer_write); +#ifdef ENABLE_DEVICE_KERNEL + tt_syrk->set_devmap(devmap2b); +#endif // 0 auto tt_gemm = make_gemm(A, disp_gemm, trsm_gemm_row, trsm_gemm_col, gemm_gemm, gemm_trsm, gemm_gemm); tt_gemm->set_keymap(keymap3); tt_gemm->set_defer_writer(defer_write); +#ifdef ENABLE_DEVICE_KERNEL + tt_gemm->set_devmap(devmap3); +#endif // 0 /* Priorities taken from DPLASMA */ auto nt = A.cols(); diff --git a/ttg/ttg/device/device.h b/ttg/ttg/device/device.h index 6690982f6..e815aaf87 100644 --- a/ttg/ttg/device/device.h +++ b/ttg/ttg/device/device.h @@ -2,6 +2,7 @@ #include "ttg/config.h" #include "ttg/execution.h" +#include "ttg/impl_selector.h" @@ -180,3 +181,9 @@ namespace ttg::device { } } // namespace ttg #endif // defined(TTG_HAVE_HIP) + +namespace ttg::device { + inline int num_devices() { + return TTG_IMPL_NS::num_devices(); + } +} diff --git a/ttg/ttg/madness/device.h b/ttg/ttg/madness/device.h new file mode 100644 index 000000000..e13321194 --- /dev/null +++ b/ttg/ttg/madness/device.h @@ -0,0 +1,9 @@ +#ifndef TTG_MADNESS_DEVICE_H +#define TTG_MADNESS_DEVICE_H + +namespace ttg_madness { + /* no device support in MADNESS */ + inline int num_devices() { return 0; } +} + +#endif // TTG_MADNESS_DEVICE_H \ No newline at end of file diff --git a/ttg/ttg/madness/ttg.h b/ttg/ttg/madness/ttg.h index 5fab6f6dd..34c576f56 100644 --- a/ttg/ttg/madness/ttg.h +++ b/ttg/ttg/madness/ttg.h @@ -13,6 +13,7 @@ #include "ttg/base/keymap.h" #include "ttg/base/tt.h" #include "ttg/func.h" +#include "ttg/madness/device.h" #include "ttg/runtimes.h" #include "ttg/tt.h" #include "ttg/util/bug.h" diff --git a/ttg/ttg/parsec/buffer.h b/ttg/ttg/parsec/buffer.h index 98b14eb12..6e609a158 100644 --- a/ttg/ttg/parsec/buffer.h +++ b/ttg/ttg/parsec/buffer.h @@ -329,6 +329,14 @@ struct Buffer : public detail::ttg_parsec_data_wrapper_t // << " parsec_data " << m_data.get() << std::endl; } + void prefer_device(ttg::device::Device dev) { + /* only set device if the host has the latest copy as otherwise we might end up with a stale copy */ + if (dev.is_device() && this->parsec_data()->owner_device == 0) { + parsec_advise_data_on_device(this->parsec_data(), detail::ttg_device_to_parsec_device(dev), + PARSEC_DEV_DATA_ADVICE_PREFERRED_DEVICE); + } + } + /* serialization support */ #ifdef TTG_SERIALIZATION_SUPPORTS_BOOST diff --git a/ttg/ttg/parsec/device.h b/ttg/ttg/parsec/device.h index 77722b1c1..9f8ada05c 100644 --- a/ttg/ttg/parsec/device.h +++ b/ttg/ttg/parsec/device.h @@ -2,6 +2,7 @@ #define TTG_PARSEC_DEVICE_H #include "ttg/device/device.h" +#include namespace ttg_parsec { @@ -35,6 +36,12 @@ namespace ttg_parsec { } } // namespace detail + + inline + int num_devices() { + return parsec_nb_devices - detail::first_device_id; + } + } // namespace ttg_parsec #endif // TTG_PARSEC_DEVICE_H \ No newline at end of file diff --git a/ttg/ttg/parsec/fwd.h b/ttg/ttg/parsec/fwd.h index d5bc8931e..0cd798e87 100644 --- a/ttg/ttg/parsec/fwd.h +++ b/ttg/ttg/parsec/fwd.h @@ -82,6 +82,8 @@ namespace ttg_parsec { template inline void mark_device_out(std::tuple &b); + inline int num_devices(); + #if 0 template inline std::pair>...>> get_ptr(Args&&... args); diff --git a/ttg/ttg/parsec/ttg.h b/ttg/ttg/parsec/ttg.h index 740c8fd07..1c795677c 100644 --- a/ttg/ttg/parsec/ttg.h +++ b/ttg/ttg/parsec/ttg.h @@ -1296,6 +1296,7 @@ namespace ttg_parsec { ttg::World world; ttg::meta::detail::keymap_t keymap; ttg::meta::detail::keymap_t priomap; + ttg::meta::detail::keymap_t devicemap; // For now use same type for unary/streaming input terminals, and stream reducers assigned at runtime ttg::meta::detail::input_reducers_t input_reducers; //!< Reducers for the input terminals (empty = expect single value) @@ -1502,6 +1503,12 @@ namespace ttg_parsec { gpu_task->pushout = 0; gpu_task->submit = &TT::device_static_submit; + // one way to force the task device + // currently this will probably break all of PaRSEC if this hint + // does not match where the data is located, not really useful for us + // instead we set a hint on the data if there is no hint set yet + //parsec_task->selected_device = ...; + /* set the gpu_task so it's available in register_device_memory */ task->dev_ptr->gpu_task = gpu_task; @@ -1525,6 +1532,29 @@ namespace ttg_parsec { } tc.nb_flows = MAX_PARAM_COUNT; + /* set the device hint on the data */ + TT *tt = task->tt; + if (tt->devicemap) { + int parsec_dev; + if constexpr (std::is_void_v) { + parsec_dev = ttg::device::ttg_device_to_parsec_device(tt->devicemap()); + } else { + parsec_dev = ttg::device::ttg_device_to_parsec_device(tt->devicemap(tt->key)); + } + for (int i = 0; i < MAX_PARAM_COUNT; ++i) { + /* only set on mutable data since we have exclusive access */ + if (tc.in[i].flow_flags & PARSEC_FLOW_ACCESS_WRITE) { + parsec_data_t *data = parsec_task->data[i].data_in->original; + /* only set the preferred device if the host has the latest copy + * as otherwise we may end up with the wrong data if there is a newer + * version on a different device. Also, keep fingers crossed. */ + if (data->owner_device == 0) { + parsec_advise_data_on_device(data, parsec_dev, PARSEC_DEV_DATA_ADVICE_PREFERRED_DEVICE); + } + } + } + } + /* set the new task class that contains the flows */ task->parsec_task.task_class = &task->dev_ptr->task_class; @@ -4195,6 +4225,37 @@ ttg::abort(); // should not happen priomap = std::forward(pm); } + /// device map setter + /// The device map provides a hint on which device a task should execute. + /// TTG may not be able to honor the request and the corresponding task + /// may execute on a different device. + /// @arg pm a function that provides a hint on which device the task should execute. + template + void set_devicemap(Devicemap&& dm) { + static_assert(derived_has_device_op(), "Device map only allowed on device-enabled TT!"); + if constexpr (std::is_same_v()))>) { + // dm returns a Device + devicemap = std::forward(dm); + } else { + // convert dm return into a Device + devicemap = [=](const keyT& key) { + if constexpr (derived_has_cuda_op()) { + return ttg::device::Device(dm(key), ttg::ExecutionSpace::CUDA); + } else if constexpr (derived_has_hip_op()) { + return ttg::device::Device(dm(key), ttg::ExecutionSpace::HIP); + } else if constexpr (derived_has_level_zero_op()) { + return ttg::device::Device(dm(key), ttg::ExecutionSpace::L0); + } else { + throw std::runtime_error("Unknown device type!"); + } + }; + } + } + + /// device map accessor + /// @return the device map + auto get_devicemap() { return devicemap; } + // Register the static_op function to associate it to instance_id void register_static_op_function(void) { int rank; diff --git a/ttg/ttg/util/meta.h b/ttg/ttg/util/meta.h index b7bb31690..5322fa577 100644 --- a/ttg/ttg/util/meta.h +++ b/ttg/ttg/util/meta.h @@ -848,18 +848,18 @@ namespace ttg { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // keymap_t = std::function, protected against void key //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template + template struct keymap; - template + template struct keymap>> { - using type = std::function; + using type = std::function; }; - template + template struct keymap>> { - using type = std::function; + using type = std::function; }; - template - using keymap_t = typename keymap::type; + template + using keymap_t = typename keymap::type; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // input_reducers_t = std::tuple<