Skip to content

Commit

Permalink
Refactor corresponding to requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
almightyvats committed Apr 11, 2022
1 parent 8387b34 commit cb0b74e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 86 deletions.
2 changes: 1 addition & 1 deletion include/celerity.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace runtime {
*
* @param device The device to be used on the current node. This can vary between nodes.
*/
[[deprecated("Use the variant with device selector instead, this will be removed in future release")]] inline void init(
[[deprecated("Use the overload with device selector instead, this will be removed in future release")]] inline void init(
int* argc, char** argv[], cl::sycl::device& device) {
detail::runtime::init(argc, argv, device);
}
Expand Down
186 changes: 102 additions & 84 deletions include/device_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,106 @@ namespace detail {
void handle_async_exceptions(cl::sycl::exception_list el) const;
};

template <typename DeviceT, typename PlatformT, typename SelectorT>
bool try_find_device_per_node(std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg,
SelectorT selector, cl::sycl::info::device_type type) {
// Try to find a platform that can provide a unique device for each node.
std::vector<std::tuple<DeviceT, size_t>> devices_with_platform_idx;
for(size_t i = 0; i < platforms.size(); ++i) {
auto&& platform = platforms[i];
if(platform.get_devices().size() >= host_cfg.node_count) {
for(auto device : platform.get_devices()) {
if(selector(device) == -1) { continue; }
devices_with_platform_idx.emplace_back(device, i);
}
}
}

std::sort(devices_with_platform_idx.begin(), devices_with_platform_idx.end(),
[selector](const auto& a, const auto& b) { return selector(std::get<0>(a)) > selector(std::get<0>(b)); });
bool same_platform = true;
if(devices_with_platform_idx.size() >= host_cfg.node_count) {
auto [device_from_platform, idx] = devices_with_platform_idx[0];
auto platform_name = device_from_platform.get_platform().template get_info<cl::sycl::info::platform::name>();

for(size_t i = 1; i < host_cfg.node_count; ++i) {
auto [device_from_platform, idx] = devices_with_platform_idx[i];
if(device_from_platform.get_platform().template get_info<cl::sycl::info::platform::name>() != platform_name) { same_platform = false; }
}

if(same_platform) {
auto [device_from_platform, idx] = devices_with_platform_idx[host_cfg.local_rank];
how_selected = fmt::format("device selector specified: platform {}, device {}", idx, host_cfg.local_rank);
device = device_from_platform;
return true;
}
}

return false;
}

template <typename DeviceT, typename PlatformT>
bool try_find_device_per_node(
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, cl::sycl::info::device_type type) {
// Try to find a platform that can provide a unique device for each node.
std::vector<std::tuple<DeviceT, size_t>> devices_with_platform_idx;
for(size_t i = 0; i < platforms.size(); ++i) {
auto&& platform = platforms[i];
std::vector<DeviceT> platform_devices;

platform_devices = platform.get_devices(type);
if(platform_devices.size() >= host_cfg.node_count) {
how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
device = platform_devices[host_cfg.local_rank];
return true;
}
}

return false;
}

template <typename DeviceT, typename PlatformT, typename SelectorT>
bool try_find_one_device(std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg,
SelectorT selector, cl::sycl::info::device_type type) {
std::vector<DeviceT> platform_devices;
for(auto& p : platforms) {
auto p_devices = p.get_devices();
platform_devices.insert(platform_devices.end(), p_devices.begin(), p_devices.end());
}

std::sort(platform_devices.begin(), platform_devices.end(), [selector](const auto& a, const auto& b) { return selector(a) > selector(b); });
if(!platform_devices.empty()) {
if(selector(platform_devices[0]) == -1) { return false; }
device = platform_devices[0];
return true;
}

return false;
};

template <typename DeviceT, typename PlatformT>
bool try_find_one_device(
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, cl::sycl::info::device_type type) {
for(auto& p : platforms) {
for(auto& d : p.get_devices(type)) {
device = d;
return true;
}
}

return false;
};


template <typename DevicePtrOrSelector, typename PlatformT>
auto pick_device(const config& cfg, const DevicePtrOrSelector& user_device_or_selector, const std::vector<PlatformT>& platforms) {
using DeviceT = typename std::result_of<decltype (&PlatformT::get_devices)(PlatformT, cl::sycl::info::device_type)>::type::value_type;
using DeviceT = typename decltype(std::declval<PlatformT&>().get_devices())::value_type;

constexpr bool user_device_provided = std::is_same_v<DevicePtrOrSelector, DeviceT>;
constexpr bool device_selector_provided = std::is_invocable_r_v<int, DevicePtrOrSelector, DeviceT>;
constexpr bool auto_select = std::is_same_v<auto_select_device, DevicePtrOrSelector>;
static_assert(user_device_provided ^ device_selector_provided ^ auto_select);
static_assert(
user_device_provided ^ device_selector_provided ^ auto_select, "pick_device requires either a device, a selector, or the auto_select_device tag");

DeviceT device;
std::string how_selected = "automatically selected";
Expand All @@ -98,102 +190,28 @@ namespace detail {
} else {
const auto host_cfg = cfg.get_host_config();

const auto try_find_device_per_node = [&](cl::sycl::info::device_type type) {
// Try to find a platform that can provide a unique device for each node.
std::vector<std::tuple<DeviceT, size_t>> devices_with_platform_idx;
for(size_t i = 0; i < platforms.size(); ++i) {
auto&& platform = platforms[i];
std::vector<DeviceT> platform_devices;
if constexpr(device_selector_provided) {
platform_devices = platform.get_devices();
if(platform_devices.size() >= host_cfg.node_count) {
for(auto device : platform_devices) {
if(user_device_or_selector(device) == -1) { continue; }
devices_with_platform_idx.emplace_back(device, i);
}
}
} else {
platform_devices = platform.get_devices(type);
if(platform_devices.size() >= host_cfg.node_count) {
how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
device = platform_devices[host_cfg.local_rank];
return true;
}
}
}

if constexpr(device_selector_provided) {
std::sort(devices_with_platform_idx.begin(), devices_with_platform_idx.end(), [&user_device_or_selector](const auto& a, const auto& b) {
return user_device_or_selector(std::get<0>(a)) > user_device_or_selector(std::get<0>(b));
});
bool same_platform = true;
if(devices_with_platform_idx.size() >= host_cfg.node_count) {
auto [device_from_platform, idx] = devices_with_platform_idx[0];
auto platform_name = device_from_platform.get_platform().template get_info<cl::sycl::info::platform::name>();

for(size_t i = 1; i < host_cfg.node_count; ++i) {
auto [device_from_platform, idx] = devices_with_platform_idx[i];
if(device_from_platform.get_platform().template get_info<cl::sycl::info::platform::name>() != platform_name) {
same_platform = false;
}
}

if(same_platform) {
auto [device_from_platform, idx] = devices_with_platform_idx[host_cfg.local_rank];
how_selected = fmt::format("device selector specified: platform {}, device {}", idx, host_cfg.local_rank);
device = device_from_platform;
return true;
}
}
}

return false;
};

const auto try_find_one_device = [&](cl::sycl::info::device_type type) {
std::vector<DeviceT> platform_devices;
for(auto& p : platforms) {
if constexpr(device_selector_provided) {
auto p_devices = p.get_devices();
platform_devices.insert(platform_devices.end(), p_devices.begin(), p_devices.end());
} else
for(auto& d : p.get_devices(type)) {
device = d;
return true;
}
}
if constexpr(device_selector_provided) {
std::sort(platform_devices.begin(), platform_devices.end(),
[&user_device_or_selector](const auto& a, const auto& b) { return user_device_or_selector(a) > user_device_or_selector(b); });
if(!platform_devices.empty()) {
if(user_device_or_selector(platform_devices[0]) == -1) { return false; }
device = platform_devices[0];
return true;
}
}
return false;
};

if constexpr(!device_selector_provided) {
// Try to find a unique GPU per node.
if(!try_find_device_per_node(cl::sycl::info::device_type::gpu)) {
// Try to find a unique device (of any type) per node.
if(try_find_device_per_node(cl::sycl::info::device_type::all)) {
if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, cl::sycl::info::device_type::gpu)) {
if(try_find_device_per_node(how_selected, device, platforms, host_cfg, cl::sycl::info::device_type::all)) {
CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
} else {
CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
// Just use the first available device. Prefer GPUs, but settle for anything.
if(!try_find_one_device(cl::sycl::info::device_type::gpu) && !try_find_one_device(cl::sycl::info::device_type::all)) {
if(!try_find_one_device(how_selected, device, platforms, host_cfg, cl::sycl::info::device_type::gpu)
&& !try_find_one_device(how_selected, device, platforms, host_cfg, cl::sycl::info::device_type::all)) {
throw std::runtime_error("Automatic device selection failed: No device available");
}
}
}
} else {
if(!try_find_device_per_node(cl::sycl::info::device_type::all)) {
// Try to find a unique device per node using a selector.
if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, user_device_or_selector, cl::sycl::info::device_type::all)) {
CELERITY_WARN("No suitable platform found that can provide {} devices that matches the specified device selector, and "
"CELERITY_DEVICES not set",
host_cfg.node_count);
if(!try_find_one_device(cl::sycl::info::device_type::all)) {
// Use the first available device according to the selector, but fails if no such device is found.
if(!try_find_one_device(how_selected, device, platforms, host_cfg, user_device_or_selector, cl::sycl::info::device_type::all)) {
throw std::runtime_error("Device selection with device selector failed: No device available");
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ namespace detail {
CELERITY_INFO(
"Celerity runtime version {} running on {}. PID = {}, build type = {}", get_version_string(), get_sycl_version(), get_pid(), get_build_type());
d_queue->init(*cfg, user_device_or_selector);
} // namespace celerity
}

runtime::~runtime() {
if(is_master_node()) {
Expand Down

0 comments on commit cb0b74e

Please sign in to comment.