Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overload for runtime::init/distr_queue ctor that accepts a device selector #113

Merged
merged 3 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions include/celerity.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef RUNTIME_INCLUDE_ENTRY_CELERITY
#define RUNTIME_INCLUDE_ENTRY_CELERITY

#include "device_queue.h"
#include "runtime.h"

#include "accessor.h"
Expand All @@ -15,14 +16,24 @@ namespace runtime {
/**
 * @brief Initializes the Celerity runtime.
 *
 * Forwards to detail::runtime::init with the detail::auto_select_device tag, i.e. no device or
 * device selector is supplied by the caller.
 *
 * @param argc Forwarded unchanged to detail::runtime::init.
 * @param argv Forwarded unchanged to detail::runtime::init.
 */
inline void init(int* argc, char** argv[]) { detail::runtime::init(argc, argv, detail::auto_select_device{}); }

/**
 * @brief Initializes the Celerity runtime and instructs it to use a particular device.
 *
 * @param argc Forwarded unchanged to detail::runtime::init.
 * @param argv Forwarded unchanged to detail::runtime::init.
 * @param device The device to be used on the current node. This can vary between nodes.
 *
 * @deprecated Use the overload that takes a device selector instead.
 */
[[deprecated("Use the overload with device selector instead, this will be removed in future release")]] inline void init(
    int* argc, char** argv[], sycl::device& device) {
    detail::runtime::init(argc, argv, device);
}

/**
 * @brief Initializes the Celerity runtime, letting a user-provided selector decide which device is used.
 *
 * @param argc Forwarded unchanged to detail::runtime::init.
 * @param argv Forwarded unchanged to detail::runtime::init.
 * @param device_selector The device selector to be used on the current node. This can vary between nodes.
 */
inline void init(int* argc, char** argv[], const detail::device_selector& device_selector) {
    detail::runtime::init(argc, argv, device_selector);
}
} // namespace runtime
} // namespace celerity

Expand Down
3 changes: 2 additions & 1 deletion include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ namespace detail {
// Host-related values read by device selection and configuration parsing.
struct host_config {
// Number of devices that device selection tries to provide (one per node); compared against the candidate count.
size_t node_count;
// Index used to pick this process' device out of the list of candidate devices.
size_t local_rank;
// Filled from std::thread::hardware_concurrency(), optionally overridden by the CELERITY_HOST_CPUS environment variable.
size_t local_num_cpus;
};

struct device_config {
Expand All @@ -19,6 +18,8 @@ namespace detail {
};

class config {
friend struct config_testspy;

public:
/**
* Initializes the @p config by parsing environment variables and passed arguments.
Expand Down
170 changes: 167 additions & 3 deletions include/device_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
#include <memory>

#include <CL/sycl.hpp>
#include <type_traits>
#include <variant>

#include "config.h"
#include "workaround.h"

namespace celerity {
namespace detail {

// Tag type passed in place of a device or selector to request automatic device selection.
struct auto_select_device {};
// Callable that assigns an integer score to a SYCL device. Device selection prefers higher scores,
// and a score of -1 excludes the device.
using device_selector = std::function<int(const sycl::device&)>;
// Exactly one of: the automatic-selection tag, a concrete device, or a device selector.
using device_or_selector = std::variant<auto_select_device, sycl::device, device_selector>;

class task;

/**
Expand All @@ -21,9 +27,9 @@ namespace detail {
* @brief Initializes the @p device_queue, selecting an appropriate device in the process.
*
* @param cfg The configuration is used to select the appropriate SYCL device.
* @param user_device Optionally a device can be provided, which will take precedence over any configuration.
* @param user_device_or_selector Optionally a device (which will take precedence over any configuration) or a device selector can be provided.
*/
void init(const config& cfg, cl::sycl::device* user_device);
void init(const config& cfg, const device_or_selector& user_device_or_selector);

/**
* @brief Executes the kernel associated with task @p ctsk over the chunk @p chnk.
Expand Down Expand Up @@ -59,9 +65,167 @@ namespace detail {
std::unique_ptr<cl::sycl::queue> sycl_queue;
bool device_profiling_enabled = false;

cl::sycl::device pick_device(const config& cfg, cl::sycl::device* user_device) const;
void handle_async_exceptions(cl::sycl::exception_list el) const;
};

// Try to find a unique device for each node among all devices accepted by a device selector.
//
// Every device of every platform is scored exactly once with `selector`. Following the SYCL device selector
// convention, devices with a negative score are never chosen. The remaining candidates are ordered by descending
// score; candidates with equal scores keep their enumeration order (platform by platform).
//
// If at least `host_cfg.node_count` candidates remain, the candidate at position `host_cfg.local_rank` is written to
// `device`, a description of the choice is written to `how_selected`, and true is returned. Otherwise false is
// returned and neither output parameter is modified.
//
// A warning is logged if the best `node_count` candidates span more than one platform or device type.
template <typename DeviceT, typename PlatformT, typename SelectorT>
bool try_find_device_per_node(
    std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
    struct candidate {
        DeviceT dev;
        size_t platform_idx;
        int score;
    };

    // Score each device once up front so that the (user-provided) selector is not re-evaluated while sorting.
    std::vector<candidate> candidates;
    for(size_t i = 0; i < platforms.size(); ++i) {
        for(const auto& dev : platforms[i].get_devices()) {
            const int score = selector(dev);
            if(score < 0) { continue; } // negative score: the selector rejects this device
            candidates.push_back(candidate{dev, i, score});
        }
    }

    std::stable_sort(candidates.begin(), candidates.end(), [](const candidate& a, const candidate& b) { return a.score > b.score; });

    // Not enough candidates for every node, or this rank has no candidate to index: report failure.
    if(candidates.size() < host_cfg.node_count || host_cfg.local_rank >= candidates.size()) { return false; }

    const auto& reference = candidates.front().dev;
    const auto reference_platform = reference.get_platform();
    const auto reference_type = reference.template get_info<sycl::info::device::device_type>();

    bool same_platform = true;
    bool same_device_type = true;
    for(size_t i = 1; i < host_cfg.node_count; ++i) {
        const auto& other = candidates[i].dev;
        if(other.get_platform() != reference_platform) { same_platform = false; }
        if(other.template get_info<sycl::info::device::device_type>() != reference_type) { same_device_type = false; }
    }

    if(!same_platform || !same_device_type) { CELERITY_WARN("Selected devices are of different type and/or do not belong to the same platform"); }

    const auto& chosen = candidates[host_cfg.local_rank];
    how_selected = fmt::format("device selector specified: platform {}, device {}", chosen.platform_idx, host_cfg.local_rank);
    device = chosen.dev;
    return true;
}

// Look for a single platform that offers at least `host_cfg.node_count` devices of the requested `type`.
// On the first such platform, the device at index `host_cfg.local_rank` is written to `device`, `how_selected`
// receives a description of the choice, and true is returned. Returns false if no platform qualifies.
template <typename DeviceT, typename PlatformT>
bool try_find_device_per_node(
    std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
    for(size_t platform_idx = 0; platform_idx < platforms.size(); ++platform_idx) {
        const std::vector<DeviceT> candidates = platforms[platform_idx].get_devices(type);
        if(candidates.size() < host_cfg.node_count) { continue; } // this platform cannot serve every node

        how_selected = fmt::format("automatically selected platform {}, device {}", platform_idx, host_cfg.local_rank);
        device = candidates[host_cfg.local_rank];
        return true;
    }

    return false;
}

// Pick the single best device across all platforms according to a device selector.
//
// Each device is scored exactly once. Following the SYCL device selector convention, devices with a negative score
// are never chosen. Among the remaining devices the one with the highest score wins; on equal scores the device
// enumerated first is kept.
//
// Returns true and writes the winner to `device` if any device was accepted. Returns false and leaves `device`
// untouched otherwise. `how_selected` and `host_cfg` are not used by this overload.
template <typename DeviceT, typename PlatformT, typename SelectorT>
bool try_find_one_device(
    std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
    bool found = false;
    int best_score = -1;
    for(const auto& platform : platforms) {
        for(const auto& candidate : platform.get_devices()) {
            const int score = selector(candidate);
            if(score < 0) { continue; } // negative score: the selector rejects this device
            // Strictly-greater comparison keeps the first-enumerated device among equal scores.
            if(!found || score > best_score) {
                device = candidate;
                best_score = score;
                found = true;
            }
        }
    }

    return found;
}

// Pick the first device of the requested `type`, visiting platforms in order.
// Returns true and writes that device to `device`; returns false (leaving `device` untouched) if no platform
// offers such a device. `how_selected` and `host_cfg` are not used by this overload.
template <typename DeviceT, typename PlatformT>
bool try_find_one_device(
    std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
    for(const auto& platform : platforms) {
        const auto& candidates = platform.get_devices(type);
        if(candidates.empty()) { continue; }

        device = candidates.front();
        return true;
    }

    return false;
}


template <typename DevicePtrOrSelector, typename PlatformT>
auto pick_device(const config& cfg, const DevicePtrOrSelector& user_device_or_selector, const std::vector<PlatformT>& platforms) {
almightyvats marked this conversation as resolved.
Show resolved Hide resolved
using DeviceT = typename decltype(std::declval<PlatformT&>().get_devices())::value_type;

constexpr bool user_device_provided = std::is_same_v<DevicePtrOrSelector, DeviceT>;
constexpr bool device_selector_provided = std::is_invocable_r_v<int, DevicePtrOrSelector, DeviceT>;
constexpr bool auto_select = std::is_same_v<auto_select_device, DevicePtrOrSelector>;
static_assert(
user_device_provided ^ device_selector_provided ^ auto_select, "pick_device requires either a device, a selector, or the auto_select_device tag");

DeviceT device;
std::string how_selected = "automatically selected";
if constexpr(user_device_provided) {
device = user_device_or_selector;
how_selected = "specified by user";
} else {
const auto device_cfg = cfg.get_device_config();
if(device_cfg != std::nullopt) {
how_selected = fmt::format("set by CELERITY_DEVICES: platform {}, device {}", device_cfg->platform_id, device_cfg->device_id);
CELERITY_DEBUG("{} platforms available", platforms.size());
if(device_cfg->platform_id >= platforms.size()) {
throw std::runtime_error(fmt::format("Invalid platform id {}: Only {} platforms available", device_cfg->platform_id, platforms.size()));
}
const auto devices = platforms[device_cfg->platform_id].get_devices();
if(device_cfg->device_id >= devices.size()) {
throw std::runtime_error(fmt::format(
"Invalid device id {}: Only {} devices available on platform {}", device_cfg->device_id, devices.size(), device_cfg->platform_id));
}
device = devices[device_cfg->device_id];
} else {
const auto host_cfg = cfg.get_host_config();

if constexpr(!device_selector_provided) {
// Try to find a unique GPU per node.
if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)) {
if(try_find_device_per_node(how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
} else {
CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
// Just use the first available device. Prefer GPUs, but settle for anything.
if(!try_find_one_device(how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)
&& !try_find_one_device(how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
throw std::runtime_error("Automatic device selection failed: No device available");
}
}
}
} else {
// Try to find a unique device per node using a selector.
if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, user_device_or_selector)) {
CELERITY_WARN("No suitable platform found that can provide {} devices that match the specified device selector, and "
"CELERITY_DEVICES not set",
host_cfg.node_count);
// Use the first available device according to the selector, but fails if no such device is found.
if(!try_find_one_device(how_selected, device, platforms, host_cfg, user_device_or_selector)) {
throw std::runtime_error("Device selection with device selector failed: No device available");
}
}
}
}
}

const auto platform_name = device.get_platform().template get_info<sycl::info::platform::name>();
const auto device_name = device.template get_info<sycl::info::device::name>();
CELERITY_INFO("Using platform '{}', device '{}' ({})", platform_name, device_name, how_selected);

return device;
}

} // namespace detail
} // namespace celerity
20 changes: 15 additions & 5 deletions include/distr_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <memory>
#include <type_traits>

#include "device_queue.h"
#include "runtime.h"
#include "task_manager.h"

Expand All @@ -25,10 +26,19 @@ inline constexpr allow_by_ref_t allow_by_ref{};

class distr_queue {
public:
distr_queue() { init(nullptr); }
distr_queue(cl::sycl::device& device) {
distr_queue() { init(detail::auto_select_device{}); }

[[deprecated("Use the overload with device selector instead, this will be removed in future release")]] distr_queue(cl::sycl::device& device) {
if(detail::runtime::is_initialized()) { throw std::runtime_error("Passing explicit device not possible, runtime has already been initialized."); }
init(&device);
init(device);
}

/**
 * Creates a queue whose device is chosen by a user-provided device selector.
 *
 * @param device_selector The selector handed on to init().
 * @throws std::runtime_error if the runtime has already been initialized, since the selector could no longer take effect.
 */
template <typename DeviceSelector>
distr_queue(const DeviceSelector& device_selector) {
    const bool runtime_exists = detail::runtime::is_initialized();
    if(runtime_exists) { throw std::runtime_error("Passing explicit device selector not possible, runtime has already been initialized."); }
    init(device_selector);
}

distr_queue(const distr_queue&) = default;
Expand Down Expand Up @@ -77,8 +87,8 @@ class distr_queue {
private:
std::shared_ptr<detail::distr_queue_tracker> tracker;

void init(cl::sycl::device* user_device) {
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr, user_device); }
void init(detail::device_or_selector device_or_selector) {
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr, device_or_selector); }
try {
detail::runtime::get_instance().startup();
} catch(detail::runtime_already_started_error&) {
Expand Down
7 changes: 4 additions & 3 deletions include/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ namespace detail {

public:
/**
* @param user_device This optional device can be provided by the user, overriding any other device selection strategy.
* @param user_device_or_selector This optional device (overriding any other device selection strategy) or device selector can be provided by the user.
*/
static void init(int* argc, char** argv[], cl::sycl::device* user_device = nullptr);
static void init(int* argc, char** argv[], device_or_selector user_device_or_selector = auto_select_device{});

/// Returns true while `instance` is non-null.
static bool is_initialized() {
    return instance != nullptr;
}
static runtime& get_instance();

Expand Down Expand Up @@ -117,7 +118,7 @@ namespace detail {
};
std::deque<flush_handle> active_flushes;

runtime(int* argc, char** argv[], cl::sycl::device* user_device = nullptr);
runtime(int* argc, char** argv[], device_or_selector user_device_or_selector);
runtime(const runtime&) = delete;
runtime(runtime&&) = delete;

Expand Down
12 changes: 0 additions & 12 deletions src/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,18 +180,6 @@ namespace detail {
const auto result = get_env("CELERITY_FORCE_WG");
if(result.first) { CELERITY_WARN("Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0."); }
}

// -------------------------------- CELERITY_HOST_CPUS --------------------------------

{
host_cfg.local_num_cpus = std::thread::hardware_concurrency();
const auto result = get_env("CELERITY_HOST_CPUS");
if(result.first) {
const auto parsed = parse_uint(result.second.c_str());
if(parsed.first) { host_cfg.local_num_cpus = parsed.second; }
}
}
}

} // namespace detail
} // namespace celerity
Loading