Skip to content

Commit

Permalink
revert: threadpool: make user's choice of threads in runtime as primary
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Dec 10, 2021
1 parent f1c2f9f commit 0af92ec
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/common/dnnl_thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ inline int dnnl_get_max_threads() {
assert(def_max_threads > 0);
});

// Make user responsible for number of threads provided at execution time.
// This relates to the fact that the library may identify `def_max_threads`
// incorrectly for a platform.
return tp ? std::max(1, tp->get_num_threads()) : def_max_threads;
// Clamp the threadpool-provided number of threads to the range
// [1, def_max_threads]
return tp ? std::min(std::max(1, tp->get_num_threads()), def_max_threads)
: def_max_threads;
}
inline int dnnl_in_parallel() {
using namespace dnnl::impl::threadpool_utils;
Expand Down
2 changes: 0 additions & 2 deletions src/cpu/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,6 @@ unsigned get_num_cores() {
// order to simulate the number of cores available in such environment, this
// function supports process affinity.
unsigned get_max_threads_to_use() {
// TODO: the logic below should involve number of sockets to provide exact
// number of cores on 2+ socket systems.
int num_cores_per_socket = (int)dnnl::impl::cpu::platform::get_num_cores();
// It may happen that Xbyak fails to identify the number of threads, e.g. on
// AMD. In order to keep the threadpool working, we supply an additional
Expand Down
14 changes: 7 additions & 7 deletions tests/test_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL

#include <algorithm>
#include <mutex>

#ifdef _WIN32
Expand Down Expand Up @@ -48,11 +47,14 @@ inline int read_num_threads_from_env() {
env_num_threads = ::getenv(env_var_name);
#endif

int num_threads = (int)dnnl::impl::cpu::platform::get_max_threads_to_use();
int num_threads = 0;
if (env_num_threads) {
char *endp;
int nt = std::max(1L, strtol(env_num_threads, &endp, 10));
if (*endp == '\0') num_threads = std::min(nt, num_threads);
int nt = strtol(env_num_threads, &endp, 10);
if (*endp == '\0') num_threads = nt;
}
if (num_threads <= 0) {
num_threads = (int)dnnl::impl::cpu::platform::get_max_threads_to_use();
}
return num_threads;
}
Expand Down Expand Up @@ -116,16 +118,14 @@ class threadpool : public dnnl::threadpool_interop::threadpool_iface {
#include "tbb/parallel_for.h"
#include "tbb/task_arena.h"

#include "src/cpu/platform.hpp"
namespace dnnl {
namespace testing {

class threadpool : public dnnl::threadpool_interop::threadpool_iface {
public:
explicit threadpool(int num_threads = 0) { (void)num_threads; }
int get_num_threads() const override {
return std::min(tbb::this_task_arena::max_concurrency(),
(int)dnnl::impl::cpu::platform::get_max_threads_to_use());
return tbb::this_task_arena::max_concurrency();
}
bool get_in_parallel() const override { return 0; }
uint64_t get_flags() const override { return 0; }
Expand Down

0 comments on commit 0af92ec

Please sign in to comment.