Skip to content

Commit

Permalink
cpu: threadpool: add setters/getters for max_concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
mgouicem committed Oct 18, 2022
1 parent 25ccee3 commit 8a1e959
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 20 deletions.
22 changes: 21 additions & 1 deletion include/oneapi/dnnl/dnnl_threadpool.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020 Intel Corporation
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,6 +59,26 @@ dnnl_status_t DNNL_API dnnl_threadpool_interop_stream_create(
dnnl_status_t DNNL_API dnnl_threadpool_interop_stream_get_threadpool(
dnnl_stream_t astream, void **threadpool);

/// Sets the maximum concurrency assumed by oneDNN when outside a
/// parallel call.
///
/// @param max_concurrency The maximum concurrency assumed by oneDNN
/// when outside a parallel call. This is a thread-local setting.
/// @returns #dnnl_success on success and a status describing the
/// error otherwise.
dnnl_status_t DNNL_API dnnl_threadpool_interop_set_max_concurrency(
int max_concurrency);

/// Gets the maximum concurrency assumed by oneDNN when outside a
/// parallel call.
///
/// @param max_concurrency Output pointer that receives the maximum
/// concurrency assumed by oneDNN when outside a parallel call. This is
/// a thread-local setting. A null pointer is rejected with an error
/// status.
/// @returns #dnnl_success on success and a status describing the
/// error otherwise.
dnnl_status_t DNNL_API dnnl_threadpool_interop_get_max_concurrency(
int *max_concurrency);

/// @copydoc dnnl_sgemm()
/// @param threadpool A pointer to a threadpool interface (only when built with
/// the THREADPOOL CPU runtime).
Expand Down
28 changes: 10 additions & 18 deletions src/common/dnnl_thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,31 +116,23 @@ void deactivate_threadpool();
// Returns the active threadpool for the calling thread.
dnnl::threadpool_interop::threadpool_iface *get_active_threadpool();

// Returns the maximum concurrency oneDNN assumes by default, i.e. the value
// dnnl_get_max_threads() falls back to when no threadpool is active.
int get_max_concurrency();

} // namespace threadpool_utils
} // namespace impl
} // namespace dnnl

// Returns the maximum number of threads oneDNN may use from the calling
// thread: the active threadpool's thread count (clamped to at least 1) when a
// threadpool is active, otherwise a library-chosen default.
// NOTE(review): this listing comes from a rendered diff, so both the previous
// implementation (static def_max_threads + std::call_once + first return) and
// the new one (get_max_concurrency() + second return) appear below. As
// written, everything after the first return is unreachable -- confirm
// against the actual header which lines remain.
inline int dnnl_get_max_threads() {
using namespace dnnl::impl::threadpool_utils;
dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
// This is the maximum number of threads oneDNN would use
static int def_max_threads = 0;
// get_max_threads_to_use() will return the number of physical cores in a
// socket. If running in a VM, a limited number of cores will be used (e.g.,
// 4 or 8) depending on the configuration of the cpuid mask. It is expected
// that the number of threads in user's threadpool will not exceed this
// value.
static std::once_flag initialization_flag_;
std::call_once(initialization_flag_, [&] {
def_max_threads
= (int)dnnl::impl::cpu::platform::get_max_threads_to_use();
assert(def_max_threads > 0);
});

// Make user responsible for number of threads provided at execution time.
// This relates to the fact that the library may identify `def_max_threads`
// incorrectly for a platform.
return tp ? std::max(1, tp->get_num_threads()) : def_max_threads;

// This is the maximum number of threads oneDNN would use by default
int max_concurrency = dnnl::impl::threadpool_utils::get_max_concurrency();

// Use the default max_concurrency only when no tp is passed by
// user (e.g. primitive creation).
return tp ? std::max(1, tp->get_num_threads()) : max_concurrency;
}
inline int dnnl_in_parallel() {
using namespace dnnl::impl::threadpool_utils;
Expand Down
63 changes: 63 additions & 0 deletions src/common/dnnl_threadpool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dnnl/dnnl_config.h"

#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL

#include <mutex>
#include "dnnl_threadpool.h"

#include "c_types_map.hpp"
#include "cpu/platform.hpp"
#include "dnnl_thread.hpp"
#include "utils.hpp"

namespace dnnl {
namespace impl {
namespace threadpool_utils {

// Gives mutable access to the calling thread's max-concurrency value.
// The value is lazily initialized, once per thread, from
// cpu::platform::get_max_threads_to_use(); callers may overwrite it through
// the returned reference.
int DNNL_API &get_threadlocal_max_concurrency() {
    thread_local int tls_max_concurrency
            = static_cast<int>(cpu::platform::get_max_threads_to_use());
    // Checked on every access, not only at initialization.
    assert(tls_max_concurrency > 0);
    return tls_max_concurrency;
}

// Read-only view of the calling thread's max-concurrency value.
int DNNL_API get_max_concurrency() {
    const int current_max_concurrency = get_threadlocal_max_concurrency();
    return current_max_concurrency;
}

} // namespace threadpool_utils
} // namespace impl
} // namespace dnnl

// C API entry point: stores `max_concurrency` as the calling thread's
// max-concurrency value. The value is stored as-is and the call always
// reports success.
dnnl_status_t dnnl_threadpool_interop_set_max_concurrency(int max_concurrency) {
    int &tls_value
            = dnnl::impl::threadpool_utils::get_threadlocal_max_concurrency();
    tls_value = max_concurrency;
    return dnnl::impl::status::success;
}

// C API entry point: writes the calling thread's max-concurrency value to
// `*max_concurrency`. Returns invalid_arguments when the output pointer is
// null, success otherwise.
dnnl_status_t dnnl_threadpool_interop_get_max_concurrency(
        int *max_concurrency) {
    if (!max_concurrency) return dnnl::impl::status::invalid_arguments;

    const int current
            = dnnl::impl::threadpool_utils::get_threadlocal_max_concurrency();
    *max_concurrency = current;
    return dnnl::impl::status::success;
}

#endif
3 changes: 3 additions & 0 deletions tests/gtests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ if(NOT DNNL_CPU_RUNTIME STREQUAL "NONE")
test_convolution_format_any.cpp
test_global_scratchpad.cpp
)
if(DNNL_CPU_RUNTIME STREQUAL "THREADPOOL")
list(APPEND CPU_SPECIFIC_TESTS test_iface_threadpool.cpp)
endif()
foreach(TEST_FILE ${CPU_SPECIFIC_TESTS})
list(APPEND PRIM_TEST_CASES_SRC "${TEST_FILE}")
endforeach()
Expand Down
74 changes: 74 additions & 0 deletions tests/gtests/test_iface_threadpool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_threadpool.h"
#include "tests/test_isa_common.hpp"

namespace dnnl {

// Test fixture for the threadpool interop API tests. No per-test setup is
// performed, so SetUp() is intentionally empty.
class threadpool_test_t : public ::testing::Test {
protected:
void SetUp() override {}
};

// Exercises the max-concurrency setter/getter pair from the calling thread.
//
// Several thread-dependent values are written with
// dnnl_threadpool_interop_set_max_concurrency() and read back with
// dnnl_threadpool_interop_get_max_concurrency().
//
// @param res Receives dnnl_success when every value round-trips, the failing
//     status when one of the API calls fails, or dnnl_runtime_error when the
//     value read back differs from the value written.
void test_threadpool_maxconcurrency_st(dnnl_status_t &res) {
    // Per-thread seed in [1, 23] so that concurrent threads tend to write
    // different values.
    int tid = std::hash<std::thread::id> {}(std::this_thread::get_id()) % 23;
    tid++; // to avoid zeros.

    auto multipliers = {1, 5, 7, 12, 24, 56};
    for (auto m : multipliers) {
        // Keep the value in [1, 21]. A plain `% 21` can yield 0 (e.g. tid = 3,
        // m = 7), and the library asserts that the stored max concurrency is
        // strictly positive the next time it is read.
        const int expected = tid * m % 21 + 1;

        dnnl_status_t st
                = dnnl_threadpool_interop_set_max_concurrency(expected);
        if (st != dnnl_success) {
            res = st;
            return;
        }

        int obtained = 0;
        st = dnnl_threadpool_interop_get_max_concurrency(&obtained);
        if (st != dnnl_success) {
            res = st;
            return;
        }

        if (expected != obtained) {
            res = dnnl_runtime_error;
            return;
        }
    }
    res = dnnl_success;
}

// Runs the setter/getter round-trip check concurrently on many threads. Each
// thread reports into its own slot of `results`, and every slot must end up
// as dnnl_success.
TEST_F(threadpool_test_t, TestMaxConcurrencyConcurrent) {
    const int nthreads = 100;
    std::vector<std::thread> threads;
    threads.reserve(nthreads);
    std::vector<dnnl_status_t> results(nthreads);
    // Strictly less-than: `results` has exactly `nthreads` slots, so index
    // `nthreads` would be out of bounds.
    for (int i = 0; i < nthreads; i++)
        threads.emplace_back(
                test_threadpool_maxconcurrency_st, std::ref(results[i]));
    for (auto &t : threads)
        t.join();
    for (auto &r : results)
        ASSERT_EQ(r, dnnl_success);
}

} // namespace dnnl
7 changes: 7 additions & 0 deletions tests/test_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,13 @@ void deactivate_threadpool() {}
dnnl::threadpool_interop::threadpool_iface *get_active_threadpool() {
return testing::get_threadpool();
}

// Testing override: report 0 so that parallel* calls use the default number
// of threads in the threadpool.
int get_max_concurrency() {
    const int use_threadpool_default = 0;
    return use_threadpool_default;
}

} // namespace testing_threadpool_utils

} // namespace impl
Expand Down
3 changes: 2 additions & 1 deletion tests/test_thread.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -73,6 +73,7 @@ namespace threadpool_utils {
void activate_threadpool(dnnl::threadpool_interop::threadpool_iface *tp);
void deactivate_threadpool();
dnnl::threadpool_interop::threadpool_iface *get_active_threadpool();
int get_max_concurrency();
} // namespace threadpool_utils
} // namespace impl

Expand Down

0 comments on commit 8a1e959

Please sign in to comment.