diff --git a/include/gridtools/fn/backend/gpu.hpp b/include/gridtools/fn/backend/gpu.hpp index ee2c8c02f..eca8dd435 100644 --- a/include/gridtools/fn/backend/gpu.hpp +++ b/include/gridtools/fn/backend/gpu.hpp @@ -39,6 +39,9 @@ namespace gridtools::fn::backend { /* * ThreadBlockSizes and LoopBlockSizes must be meta maps, mapping dimensions to integral constant block sizes. * + * ThreadBlockSizes defines how many GPU threads are employed inside a GPU thread block along each dimension. + * LoopBlockSizes defines how many consecutive elements along each dimension a single thread works on. + * * For example, meta::list>, * meta::list>, * meta::list>>; @@ -64,6 +67,7 @@ namespace gridtools::fn::backend { using block_sizes_for_sizes = hymap::from_meta_map::template apply, get_keys>>; + // helper function to compute the initial global index (where all loops start) of a specific GPU thread struct global_thread_index_f { template GT_FUNCTION_DEVICE static constexpr int index_at_dim(Index const &idx) { @@ -78,9 +82,11 @@ namespace gridtools::fn::backend { template GT_FUNCTION_DEVICE constexpr auto operator()(ThreadBlockSize, LoopBlockSize) const { if constexpr (I < 3) { + // use GPU block and thread indices for the first three dimensions return index_at_dim(blockIdx) * (ThreadBlockSize::value * LoopBlockSize::value) + index_at_dim(threadIdx) * LoopBlockSize::value; } else { + // higher dimensions are always fully looped-over, so the loop start index is always zero return integral_constant(); } // disable incorrect warning "missing return statement at end of non-void function" @@ -89,16 +95,21 @@ namespace gridtools::fn::backend { GT_NVCC_DIAG_POP_SUPPRESS(940) }; + // helper function to compute the effective (possibly clamped) block size struct block_size_f { template GT_FUNCTION_DEVICE constexpr auto operator()( GlobalThreadIndex global_thread_index, LoopBlockSize, Size size) const { if constexpr (I < 3) { + // on the first three dimensions, the loops can effectively be blocked if constexpr (LoopBlockSize::value == 1) + // block size is known to be compile-time constant 1 if we have unit block size return integral_constant(); else + // larger block sizes have to be clamped at run time at the end of the domain return std::clamp(size - global_thread_index, 0, int(LoopBlockSize::value)); } else { + // higher dimensions are always fully looped-over, so loop blocking is ignored return size; } // disable incorrect warning "missing return statement at end of non-void function"