diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index e4a38d74f9a..3f43ffc7283 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -286,7 +286,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( // not create parallel section at all. We do not limit num_threads // for 1 < work_amount < dnnl_get_max_threads() case to avoid potential // overhead on spawning different number of OMP threads from layer to layer. - const int num_threads = (work_amount == 1 ? 1 : 0); + const int num_threads = (work_amount == 1 ? 1 : jbgp.nthr); parallel(num_threads, [&](const int ithr, const int nthr) { int nthr_ic {1}, nthr_oc_mb {1}, ithr_ic {0}, ithr_oc_mb {0}; bool ok = init_thr_groups( @@ -730,10 +730,11 @@ void brgemm_inner_product_bwd_data_t::execute_backward_data( const int os_chunks = div_up(jbgp.nb_os, jbgp.nb_os_blocking); const int work_amount = jbgp.nb_ic * os_chunks; + const int num_threads = (work_amount == 1 ? 1 : jbgp.nthr); if (jbgp.ip_bwd_d_global_b_transpose && jbgp.use_buffer_b) { assert(IMPLICATION( jbgp.ip_bwd_d_global_b_transpose, jbgp.nthr_oc_b == 1)); - parallel(0, [&](const int ithr, const int nthr) { + parallel(num_threads, [&](const int ithr, const int nthr) { int start {0}, end {0}; int max_ch_block = nstl::max(jbgp.ic_block, jbgp.oc_block); int ic_chunk_sz = max_ch_block / jbgp.ic_block; @@ -773,7 +774,7 @@ void brgemm_inner_product_bwd_data_t::execute_backward_data( }); } - parallel(0, [&](const int ithr, const int nthr) { + parallel(num_threads, [&](const int ithr, const int nthr) { const int nthr_oc = jbgp.nthr_oc_b <= nthr ? jbgp.nthr_oc_b : 1; const int nthr_ic_mb = nthr / nthr_oc; const int ithr_ic_mb = ithr % nthr_ic_mb; @@ -820,7 +821,7 @@ void brgemm_inner_product_bwd_data_t::execute_backward_data( }); if (jbgp.nthr_oc_b > 1) { - parallel(0, [&](const int ithr, const int nthr) { + parallel(num_threads, [&](const int ithr, const int nthr) { const int nthr_oc = jbgp.nthr_oc_b <= nthr ? jbgp.nthr_oc_b : 1; if (nthr_oc <= 1) return;