Skip to content

Commit

Permalink
Add support for passing device selector to distr_queue ctor
Browse files Browse the repository at this point in the history
... and runtime::init
  • Loading branch information
almightyvats authored and psalz committed May 17, 2022
1 parent 2a56b50 commit 556b6f2
Show file tree
Hide file tree
Showing 11 changed files with 730 additions and 243 deletions.
15 changes: 13 additions & 2 deletions include/celerity.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef RUNTIME_INCLUDE_ENTRY_CELERITY
#define RUNTIME_INCLUDE_ENTRY_CELERITY

#include "device_queue.h"
#include "runtime.h"

#include "accessor.h"
Expand All @@ -15,14 +16,24 @@ namespace runtime {
/**
* @brief Initializes the Celerity runtime.
*/
inline void init(int* argc, char** argv[]) { detail::runtime::init(argc, argv, nullptr); }
inline void init(int* argc, char** argv[]) { detail::runtime::init(argc, argv, detail::auto_select_device{}); }

/**
* @brief Initializes the Celerity runtime and instructs it to use a particular device.
*
* @param device The device to be used on the current node. This can vary between nodes.
*/
inline void init(int* argc, char** argv[], cl::sycl::device& device) { detail::runtime::init(argc, argv, &device); }
[[deprecated("Use the overload with device selector instead, this will be removed in future release")]] inline void init(
int* argc, char** argv[], sycl::device& device) {
detail::runtime::init(argc, argv, device);
}

/**
* @brief Initializes the Celerity runtime and instructs it to use a particular device.
*
* @param device_selector The device selector to be used on the current node. This can vary between nodes.
*/
inline void init(int* argc, char** argv[], const detail::device_selector& device_selector) { detail::runtime::init(argc, argv, device_selector); }
} // namespace runtime
} // namespace celerity

Expand Down
3 changes: 2 additions & 1 deletion include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ namespace detail {
};

class config {
friend struct config_testspy;

public:
/**
* Initializes the @p config by parsing environment variables and passed arguments.
Expand Down Expand Up @@ -48,7 +50,6 @@ namespace detail {
std::optional<device_config> device_cfg;
std::optional<bool> enable_device_profiling;
size_t graph_print_max_verts = 200;
friend struct config_testspy;
};

} // namespace detail
Expand Down
172 changes: 133 additions & 39 deletions include/device_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
#include <memory>

#include <CL/sycl.hpp>
#include <type_traits>
#include <variant>

#include "config.h"
#include "workaround.h"

namespace celerity {
namespace detail {

struct auto_select_device {};
using device_selector = std::function<int(const sycl::device&)>;
using device_or_selector = std::variant<auto_select_device, sycl::device, device_selector>;

class task;

/**
Expand All @@ -21,9 +27,9 @@ namespace detail {
* @brief Initializes the @p device_queue, selecting an appropriate device in the process.
*
* @param cfg The configuration is used to select the appropriate SYCL device.
* @param user_device Optionally a device can be provided, which will take precedence over any configuration.
* @param user_device_or_selector Optionally a device (which will take precedence over any configuration) or a device selector can be provided.
*/
void init(const config& cfg, cl::sycl::device* user_device);
void init(const config& cfg, const device_or_selector& user_device_or_selector);

/**
* @brief Executes the kernel associated with task @p ctsk over the chunk @p chnk.
Expand Down Expand Up @@ -62,12 +68,111 @@ namespace detail {
void handle_async_exceptions(cl::sycl::exception_list el) const;
};

// Try to find a platform that can provide a unique device for each node using a device selector.
template <typename DeviceT, typename PlatformT, typename SelectorT>
bool try_find_device_per_node(
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
std::vector<std::tuple<DeviceT, size_t>> devices_with_platform_idx;
for(size_t i = 0; i < platforms.size(); ++i) {
auto&& platform = platforms[i];
for(auto device : platform.get_devices()) {
if(selector(device) == -1) { continue; }
devices_with_platform_idx.emplace_back(device, i);
}
}

std::stable_sort(devices_with_platform_idx.begin(), devices_with_platform_idx.end(),
[selector](const auto& a, const auto& b) { return selector(std::get<0>(a)) > selector(std::get<0>(b)); });
bool same_platform = true;
bool same_device_type = true;
if(devices_with_platform_idx.size() >= host_cfg.node_count) {
auto [device_from_platform, idx] = devices_with_platform_idx[0];
const auto platform = device_from_platform.get_platform();
const auto device_type = device_from_platform.template get_info<sycl::info::device::device_type>();

for(size_t i = 1; i < host_cfg.node_count; ++i) {
auto [device_from_platform, idx] = devices_with_platform_idx[i];
if(device_from_platform.get_platform() != platform) { same_platform = false; }
if(device_from_platform.template get_info<sycl::info::device::device_type>() != device_type) { same_device_type = false; }
}

if(!same_platform || !same_device_type) { CELERITY_WARN("Selected devices are of different type and/or do not belong to the same platform"); }

auto [selected_device_from_platform, selected_idx] = devices_with_platform_idx[host_cfg.local_rank];
how_selected = fmt::format("device selector specified: platform {}, device {}", selected_idx, host_cfg.local_rank);
device = selected_device_from_platform;
return true;
}

return false;
}

// Try to find a platform that can provide a unique device for each node.
template <typename DeviceT, typename PlatformT>
DeviceT pick_device(const config& cfg, DeviceT* user_device, const std::vector<PlatformT>& platforms) {
bool try_find_device_per_node(
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
for(size_t i = 0; i < platforms.size(); ++i) {
auto&& platform = platforms[i];
std::vector<DeviceT> platform_devices;

platform_devices = platform.get_devices(type);
if(platform_devices.size() >= host_cfg.node_count) {
how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
device = platform_devices[host_cfg.local_rank];
return true;
}
}

return false;
}

template <typename DeviceT, typename PlatformT, typename SelectorT>
bool try_find_one_device(
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
std::vector<DeviceT> platform_devices;
for(auto& p : platforms) {
auto p_devices = p.get_devices();
platform_devices.insert(platform_devices.end(), p_devices.begin(), p_devices.end());
}

std::stable_sort(platform_devices.begin(), platform_devices.end(), [selector](const auto& a, const auto& b) { return selector(a) > selector(b); });
if(!platform_devices.empty()) {
if(selector(platform_devices[0]) == -1) { return false; }
device = platform_devices[0];
return true;
}

return false;
};

template <typename DeviceT, typename PlatformT>
bool try_find_one_device(
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
for(auto& p : platforms) {
for(auto& d : p.get_devices(type)) {
device = d;
return true;
}
}

return false;
};


template <typename DevicePtrOrSelector, typename PlatformT>
auto pick_device(const config& cfg, const DevicePtrOrSelector& user_device_or_selector, const std::vector<PlatformT>& platforms) {
using DeviceT = typename decltype(std::declval<PlatformT&>().get_devices())::value_type;

constexpr bool user_device_provided = std::is_same_v<DevicePtrOrSelector, DeviceT>;
constexpr bool device_selector_provided = std::is_invocable_r_v<int, DevicePtrOrSelector, DeviceT>;
constexpr bool auto_select = std::is_same_v<auto_select_device, DevicePtrOrSelector>;
static_assert(
user_device_provided ^ device_selector_provided ^ auto_select, "pick_device requires either a device, a selector, or the auto_select_device tag");

DeviceT device;
std::string how_selected = "automatically selected";
if(user_device != nullptr) {
device = *user_device;
if constexpr(user_device_provided) {
device = user_device_or_selector;
how_selected = "specified by user";
} else {
const auto device_cfg = cfg.get_device_config();
Expand All @@ -86,48 +191,37 @@ namespace detail {
} else {
const auto host_cfg = cfg.get_host_config();

const auto try_find_device_per_node = [&host_cfg, &device, &how_selected, &platforms](cl::sycl::info::device_type type) {
// Try to find a platform that can provide a unique device for each node.
for(size_t i = 0; i < platforms.size(); ++i) {
auto&& platform = platforms[i];
const auto devices = platform.get_devices(type);
if(devices.size() >= host_cfg.node_count) {
how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
device = devices[host_cfg.local_rank];
return true;
}
}
return false;
};

const auto try_find_one_device = [&device, &platforms](cl::sycl::info::device_type type) {
for(auto& p : platforms) {
for(auto& d : p.get_devices(type)) {
device = d;
return true;
if constexpr(!device_selector_provided) {
// Try to find a unique GPU per node.
if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)) {
if(try_find_device_per_node(how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
} else {
CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
// Just use the first available device. Prefer GPUs, but settle for anything.
if(!try_find_one_device(how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)
&& !try_find_one_device(how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
throw std::runtime_error("Automatic device selection failed: No device available");
}
}
}
return false;
};

// Try to find a unique GPU per node.
if(!try_find_device_per_node(cl::sycl::info::device_type::gpu)) {
// Try to find a unique device (of any type) per node.
if(try_find_device_per_node(cl::sycl::info::device_type::all)) {
CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
} else {
CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
// Just use the first available device. Prefer GPUs, but settle for anything.
if(!try_find_one_device(cl::sycl::info::device_type::gpu) && !try_find_one_device(cl::sycl::info::device_type::all)) {
throw std::runtime_error("Automatic device selection failed: No device available");
} else {
// Try to find a unique device per node using a selector.
if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, user_device_or_selector)) {
CELERITY_WARN("No suitable platform found that can provide {} devices that match the specified device selector, and "
"CELERITY_DEVICES not set",
host_cfg.node_count);
// Use the first available device according to the selector, but fails if no such device is found.
if(!try_find_one_device(how_selected, device, platforms, host_cfg, user_device_or_selector)) {
throw std::runtime_error("Device selection with device selector failed: No device available");
}
}
}
}
}

const auto platform_name = device.get_platform().template get_info<cl::sycl::info::platform::name>();
const auto device_name = device.template get_info<cl::sycl::info::device::name>();
const auto platform_name = device.get_platform().template get_info<sycl::info::platform::name>();
const auto device_name = device.template get_info<sycl::info::device::name>();
CELERITY_INFO("Using platform '{}', device '{}' ({})", platform_name, device_name, how_selected);

return device;
Expand Down
20 changes: 15 additions & 5 deletions include/distr_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <memory>
#include <type_traits>

#include "device_queue.h"
#include "runtime.h"
#include "task_manager.h"

Expand All @@ -25,10 +26,19 @@ inline constexpr allow_by_ref_t allow_by_ref{};

class distr_queue {
public:
distr_queue() { init(nullptr); }
distr_queue(cl::sycl::device& device) {
distr_queue() { init(detail::auto_select_device{}); }

[[deprecated("Use the overload with device selector instead, this will be removed in future release")]] distr_queue(cl::sycl::device& device) {
if(detail::runtime::is_initialized()) { throw std::runtime_error("Passing explicit device not possible, runtime has already been initialized."); }
init(&device);
init(device);
}

template <typename DeviceSelector>
distr_queue(const DeviceSelector& device_selector) {
if(detail::runtime::is_initialized()) {
throw std::runtime_error("Passing explicit device selector not possible, runtime has already been initialized.");
}
init(device_selector);
}

distr_queue(const distr_queue&) = default;
Expand Down Expand Up @@ -77,8 +87,8 @@ class distr_queue {
private:
std::shared_ptr<detail::distr_queue_tracker> tracker;

void init(cl::sycl::device* user_device) {
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr, user_device); }
void init(detail::device_or_selector device_or_selector) {
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr, device_or_selector); }
try {
detail::runtime::get_instance().startup();
} catch(detail::runtime_already_started_error&) {
Expand Down
7 changes: 4 additions & 3 deletions include/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ namespace detail {

public:
/**
* @param user_device This optional device can be provided by the user, overriding any other device selection strategy.
* @param user_device_or_selector This optional device (overriding any other device selection strategy) or device selector can be provided by the user.
*/
static void init(int* argc, char** argv[], cl::sycl::device* user_device = nullptr);
static void init(int* argc, char** argv[], device_or_selector user_device_or_selector = auto_select_device{});

static bool is_initialized() { return instance != nullptr; }
static runtime& get_instance();

Expand Down Expand Up @@ -117,7 +118,7 @@ namespace detail {
};
std::deque<flush_handle> active_flushes;

runtime(int* argc, char** argv[], cl::sycl::device* user_device = nullptr);
runtime(int* argc, char** argv[], device_or_selector user_device_or_selector);
runtime(const runtime&) = delete;
runtime(runtime&&) = delete;

Expand Down
7 changes: 4 additions & 3 deletions src/device_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@
namespace celerity {
namespace detail {

void device_queue::init(const config& cfg, cl::sycl::device* user_device) {
void device_queue::init(const config& cfg, const device_or_selector& user_device_or_selector) {
assert(sycl_queue == nullptr);
const auto profiling_cfg = cfg.get_enable_device_profiling();
device_profiling_enabled = profiling_cfg != std::nullopt && *profiling_cfg;
if(device_profiling_enabled) { CELERITY_INFO("Device profiling enabled."); }

const auto props = device_profiling_enabled ? cl::sycl::property_list{cl::sycl::property::queue::enable_profiling()} : cl::sycl::property_list{};
const auto handle_exceptions = cl::sycl::async_handler{[this](cl::sycl::exception_list el) { this->handle_async_exceptions(el); }};
auto device = pick_device(cfg, user_device, cl::sycl::platform::get_platforms());

auto device = std::visit(
[&cfg](const auto& value) { return ::celerity::detail::pick_device(cfg, value, cl::sycl::platform::get_platforms()); }, user_device_or_selector);
sycl_queue = std::make_unique<cl::sycl::queue>(device, handle_exceptions, props);
}


void device_queue::handle_async_exceptions(cl::sycl::exception_list el) const {
for(auto& e : el) {
try {
Expand Down
8 changes: 4 additions & 4 deletions src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ namespace detail {
mpi_finalized = true;
}

void runtime::init(int* argc, char** argv[], cl::sycl::device* user_device) {
void runtime::init(int* argc, char** argv[], device_or_selector user_device_or_selector) {
assert(!instance);
instance = std::unique_ptr<runtime>(new runtime(argc, argv, user_device));
instance = std::unique_ptr<runtime>(new runtime(argc, argv, user_device_or_selector));
}

runtime& runtime::get_instance() {
Expand Down Expand Up @@ -91,7 +91,7 @@ namespace detail {
#endif
}

runtime::runtime(int* argc, char** argv[], cl::sycl::device* user_device) {
runtime::runtime(int* argc, char** argv[], device_or_selector user_device_or_selector) {
if(test_mode) {
assert(test_active && "initializing the runtime from a test without a runtime_fixture");
} else {
Expand Down Expand Up @@ -145,7 +145,7 @@ namespace detail {

CELERITY_INFO(
"Celerity runtime version {} running on {}. PID = {}, build type = {}", get_version_string(), get_sycl_version(), get_pid(), get_build_type());
d_queue->init(*cfg, user_device);
d_queue->init(*cfg, user_device_or_selector);
}

runtime::~runtime() {
Expand Down
Loading

0 comments on commit 556b6f2

Please sign in to comment.