Skip to content

Commit

Permalink
cpu: x64: ip: brgemm: adjust number of threads to use in parallel
Browse files: browse the repository at this point in the history.
  • Loading branch information
dzarukin committed Dec 20, 2021
1 parent ca1eb77 commit ba2e5a9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/cpu/x64/jit_brgemm_inner_product.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -286,7 +286,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
// not create parallel section at all. We do not limit num_threads
// for 1 < work_amount < dnnl_get_max_threads() case to avoid potential
// overhead on spawning different number of OMP threads from layer to layer.
const int num_threads = (work_amount == 1 ? 1 : 0);
const int num_threads = (work_amount == 1 ? 1 : jbgp.nthr);
parallel(num_threads, [&](const int ithr, const int nthr) {
int nthr_ic {1}, nthr_oc_mb {1}, ithr_ic {0}, ithr_oc_mb {0};
bool ok = init_thr_groups(
Expand Down Expand Up @@ -730,10 +730,11 @@ void brgemm_inner_product_bwd_data_t<isa>::execute_backward_data(

const int os_chunks = div_up(jbgp.nb_os, jbgp.nb_os_blocking);
const int work_amount = jbgp.nb_ic * os_chunks;
const int num_threads = (work_amount == 1 ? 1 : jbgp.nthr);
if (jbgp.ip_bwd_d_global_b_transpose && jbgp.use_buffer_b) {
assert(IMPLICATION(
jbgp.ip_bwd_d_global_b_transpose, jbgp.nthr_oc_b == 1));
parallel(0, [&](const int ithr, const int nthr) {
parallel(num_threads, [&](const int ithr, const int nthr) {
int start {0}, end {0};
int max_ch_block = nstl::max(jbgp.ic_block, jbgp.oc_block);
int ic_chunk_sz = max_ch_block / jbgp.ic_block;
Expand Down Expand Up @@ -773,7 +774,7 @@ void brgemm_inner_product_bwd_data_t<isa>::execute_backward_data(
});
}

parallel(0, [&](const int ithr, const int nthr) {
parallel(num_threads, [&](const int ithr, const int nthr) {
const int nthr_oc = jbgp.nthr_oc_b <= nthr ? jbgp.nthr_oc_b : 1;
const int nthr_ic_mb = nthr / nthr_oc;
const int ithr_ic_mb = ithr % nthr_ic_mb;
Expand Down Expand Up @@ -820,7 +821,7 @@ void brgemm_inner_product_bwd_data_t<isa>::execute_backward_data(
});

if (jbgp.nthr_oc_b > 1) {
parallel(0, [&](const int ithr, const int nthr) {
parallel(num_threads, [&](const int ithr, const int nthr) {
const int nthr_oc = jbgp.nthr_oc_b <= nthr ? jbgp.nthr_oc_b : 1;
if (nthr_oc <= 1) return;

Expand Down

0 comments on commit ba2e5a9

Please sign in to comment.