diff --git a/src/common/threading_utils.cc b/src/common/threading_utils.cc
index 349cc0ba7348..5e730e96d34e 100644
--- a/src/common/threading_utils.cc
+++ b/src/common/threading_utils.cc
@@ -3,14 +3,26 @@
  */
 #include "threading_utils.h"
 
-#include <fstream>
-#include <string>
+#include <algorithm>     // for max
+#include <exception>     // for exception
+#include <filesystem>    // for path, exists
+#include <fstream>       // for ifstream
+#include <string>        // for string
+#include <system_error>  // for error_code
 
-#include "xgboost/logging.h"
+#include "common.h"  // for DivRoundUp
 
-namespace xgboost {
-namespace common {
-int32_t GetCfsCPUCount() noexcept {
+namespace xgboost::common {
+/**
+ * \brief Get the CPU limit from the cgroup v1 CFS quota and period files.
+ *
+ * Modified from
+ * github.com/psiha/sweater/blob/master/include/boost/sweater/hardware_concurrency.hpp
+ *
+ * MIT License: Copyright (c) 2016 Domagoj Šarić
+ */
+std::int32_t GetCGroupV1Count(std::filesystem::path const& quota_path,
+                              std::filesystem::path const& period_path) {
 #if defined(__linux__)
   // https://bugs.openjdk.java.net/browse/JDK-8146115
   // http://hg.openjdk.java.net/jdk/hs/rev/7f22774a5f42
@@ -31,8 +43,8 @@ int32_t GetCfsCPUCount() noexcept {
     }
   };
   // complete fair scheduler from Linux
-  auto const cfs_quota(read_int("/sys/fs/cgroup/cpu/cpu.cfs_quota_us"));
-  auto const cfs_period(read_int("/sys/fs/cgroup/cpu/cpu.cfs_period_us"));
+  auto const cfs_quota(read_int(quota_path.c_str()));
+  auto const cfs_period(read_int(period_path.c_str()));
   if ((cfs_quota > 0) && (cfs_period > 0)) {
     return std::max(cfs_quota / cfs_period, 1);
   }
@@ -40,6 +52,60 @@ int32_t GetCfsCPUCount() noexcept {
   return -1;
 }
 
+/**
+ * \brief Get the CPU limit from a cgroup v2 `cpu.max` file.
+ *
+ * The file holds "<quota> <period>"; the quota is the literal "max" when unlimited.
+ *
+ * \return The quota divided by the period (via DivRoundUp) and at least 1; -1 if no
+ *         valid limit could be read or the platform is not Linux.
+ */
+std::int32_t GetCGroupV2Count(std::filesystem::path const& bandwidth_path) noexcept(true) {
+  std::int32_t cnt{-1};
+#if defined(__linux__)
+  std::int32_t a{0}, b{0};
+
+  auto warn = [] { LOG(WARNING) << "Invalid cgroupv2 file."; };
+  try {
+    std::ifstream fin{bandwidth_path, std::ios::in};
+    // Streams do not throw by default: a missing file or a non-numeric quota ("max")
+    // leaves a value of 0, which the check below rejects.
+    fin >> a;
+    fin >> b;
+  } catch (std::exception const&) {
+    warn();
+    return cnt;
+  }
+  if (a > 0 && b > 0) {
+    cnt = std::max(common::DivRoundUp(a, b), 1);
+  }
+#endif  //  defined(__linux__)
+  return cnt;
+}
+
+std::int32_t GetCfsCPUCount() noexcept {
+  namespace fs = std::filesystem;
+  // This function is noexcept, so use the non-throwing overload of exists(): the
+  // throwing one would terminate the process on e.g. a permission error.
+  auto exists = [](fs::path const& path) {
+    std::error_code ec;
+    return fs::exists(path, ec);
+  };
+
+  fs::path const bandwidth_path{"/sys/fs/cgroup/cpu.max"};
+  if (exists(bandwidth_path)) {
+    return GetCGroupV2Count(bandwidth_path);
+  }
+
+  fs::path const quota_path{"/sys/fs/cgroup/cpu/cpu.cfs_quota_us"};
+  fs::path const period_path{"/sys/fs/cgroup/cpu/cpu.cfs_period_us"};
+  if (exists(quota_path) && exists(period_path)) {
+    return GetCGroupV1Count(quota_path, period_path);
+  }
+
+  return -1;
+}
+
 std::int32_t OmpGetNumThreads(std::int32_t n_threads) {
   // Don't use parallel if we are in a parallel region.
   if (omp_in_parallel()) {
@@ -54,5 +120,4 @@ std::int32_t OmpGetNumThreads(std::int32_t n_threads) {
   n_threads = std::max(n_threads, 1);
   return n_threads;
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h
index 4ca4ca0707d9..ac71190353a7 100644
--- a/src/common/threading_utils.h
+++ b/src/common/threading_utils.h
@@ -253,11 +253,6 @@ inline std::int32_t OmpGetThreadLimit() {
  * \brief Get thread limit from CFS.
  *
- * This function has non-trivial overhead and should not be called repeatly.
- *
- * Modified from
- * github.com/psiha/sweater/blob/master/include/boost/sweater/hardware_concurrency.hpp
- *
- * MIT License: Copyright (c) 2016 Domagoj Šarić
+ * This function has non-trivial overhead and should not be called repeatedly.
  */
 std::int32_t GetCfsCPUCount() noexcept;
 