diff --git a/include/dlaf/eigensolver/internal/get_red2band_panel_nworkers.h b/include/dlaf/eigensolver/internal/get_red2band_panel_nworkers.h index f269232d18..06b2585b1f 100644 --- a/include/dlaf/eigensolver/internal/get_red2band_panel_nworkers.h +++ b/include/dlaf/eigensolver/internal/get_red2band_panel_nworkers.h @@ -19,7 +19,7 @@ namespace dlaf::eigensolver::internal { -inline size_t getReductionToBandPanelNWorkers() noexcept { +inline size_t get_red2band_panel_nworkers() noexcept { // Note: precautionarily we leave at least 1 thread "free" to do other stuff (if possible) const std::size_t available_workers = pika::resource::get_thread_pool("default").get_os_thread_count(); const std::size_t min_workers = 1; diff --git a/include/dlaf/eigensolver/reduction_to_band/impl.h b/include/dlaf/eigensolver/reduction_to_band/impl.h index e3f4befb0e..8e48b8981b 100644 --- a/include/dlaf/eigensolver/reduction_to_band/impl.h +++ b/include/dlaf/eigensolver/reduction_to_band/impl.h @@ -311,23 +311,28 @@ void computePanelReflectors(MatrixLikeA& mat_a, MatrixLikeTaus& mat_taus, const panel_tiles.emplace_back(matrix::splitTile(mat_a.readwrite(i), spec)); } - const std::size_t nthreads = getReductionToBandPanelNWorkers(); - auto s = - ex::when_all(ex::just(std::make_unique>(nthreads), + const std::size_t nworkers = [nrtiles = panel_tiles.size()]() { + const std::size_t min_workers = 1; + const std::size_t available_workers = get_red2band_panel_nworkers(); + const std::size_t ideal_workers = to_sizet(nrtiles); + return std::clamp(ideal_workers, min_workers, available_workers); + }(); + ex::start_detached( + ex::when_all(ex::just(std::make_unique>(nworkers), std::vector>{}), // w (internally required) mat_taus.readwrite(LocalTileIndex(j_sub, 0)), ex::when_all_vector(std::move(panel_tiles))) | di::continues_on(di::getBackendScheduler(thread_priority::high)) | - ex::bulk(nthreads, [nthreads, cols = panel_view.cols()](const std::size_t index, auto& barrier_ptr, + ex::bulk(nworkers, [nworkers, cols = panel_view.cols()](const std::size_t index, auto& barrier_ptr, auto& w, auto& taus, auto& tiles) { const auto barrier_busy_wait = getReductionToBandBarrierBusyWait(); - const std::size_t batch_size = util::ceilDiv(tiles.size(), nthreads); + const std::size_t batch_size = util::ceilDiv(tiles.size(), nworkers); const std::size_t begin = index * batch_size; const std::size_t end = std::min(index * batch_size + batch_size, tiles.size()); const SizeType nrefls = taus.size().rows(); if (index == 0) { - w.resize(nthreads); + w.resize(nworkers); } for (SizeType j = 0; j < nrefls; ++j) { @@ -357,8 +362,7 @@ void computePanelReflectors(MatrixLikeA& mat_a, MatrixLikeTaus& mat_taus, const updateTrailingPanel(has_head, tiles, j, w[0], taus({j, 0}), begin, end); barrier_ptr->arrive_and_wait(barrier_busy_wait); } - }); - ex::start_detached(std::move(s)); + })); } template @@ -632,27 +636,33 @@ void computePanelReflectors(TriggerSender&& trigger, comm::IndexT_MPI rank_v0, panel_tiles.emplace_back(matrix::splitTile(mat_a.readwrite(i), spec)); } - const std::size_t nthreads = getReductionToBandPanelNWorkers(); - auto s = - ex::when_all(ex::just(std::make_unique>(nthreads), + const std::size_t nworkers = [nrtiles = panel_tiles.size()]() { + const std::size_t min_workers = 1; + const std::size_t available_workers = get_red2band_panel_nworkers(); + const std::size_t ideal_workers = util::ceilDiv(to_sizet(nrtiles), to_sizet(2)); + return std::clamp(ideal_workers, min_workers, available_workers); + }(); + + ex::start_detached( + ex::when_all(ex::just(std::make_unique>(nworkers), std::vector>{}), // w (internally required) mat_taus.readwrite(GlobalTileIndex(j_sub, 0)), ex::when_all_vector(std::move(panel_tiles)), std::forward(mpi_col_chain_panel), std::forward(trigger)) | di::continues_on(di::getBackendScheduler(pika::execution::thread_priority::high)) | - ex::bulk(nthreads, [nthreads, rank_v0, + ex::bulk(nworkers, [nworkers, rank_v0, cols = panel_view.cols()](const std::size_t index, auto& barrier_ptr, auto& w, auto& taus, auto& tiles, auto&& pcomm) { const bool rankHasHead = rank_v0 == pcomm.get().rank(); const auto barrier_busy_wait = getReductionToBandBarrierBusyWait(); - const std::size_t batch_size = util::ceilDiv(tiles.size(), nthreads); + const std::size_t batch_size = util::ceilDiv(tiles.size(), nworkers); const std::size_t begin = index * batch_size; const std::size_t end = std::min(index * batch_size + batch_size, tiles.size()); const SizeType nrefls = taus.size().rows(); if (index == 0) { - w.resize(nthreads); + w.resize(nworkers); } for (SizeType j = 0; j < nrefls; ++j) { @@ -685,8 +695,7 @@ void computePanelReflectors(TriggerSender&& trigger, comm::IndexT_MPI rank_v0, updateTrailingPanel(has_head, tiles, j, w[0], taus({j, 0}), begin, end); barrier_ptr->arrive_and_wait(barrier_busy_wait); } - }); - ex::start_detached(std::move(s)); + })); } template