From f5e0a1709f8e71a5b3a849dd1960436c2ff26f47 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Thu, 13 Jun 2024 12:50:31 +0100 Subject: [PATCH 1/5] Add workgroup-level tile --- examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 5 +++-- include/cutlass/gemm/collective/intel_pvc_mma.hpp | 2 +- include/cutlass/gemm/kernel/intel_pvc_gemm.hpp | 15 ++++++--------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index 542bafb346..731edfa15f 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -353,11 +353,12 @@ int main(int argc, const char** argv) using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; - using TileShape = Shape<_1, _1, _1>; + // Workgroup-level tile + using TileShape = Shape<_32, _256, _32>; using TiledMma = TiledMMA, Layout>, - Tile<_32,_64,_32>>; + Tile<_32,_64,_32>>; // Subgroup-level tile using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index d587fbcd9d..8f2e4cb34e 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -103,7 +103,7 @@ struct CollectiveMma< using DpasShape = typename TiledMma::Shape_MNK; using TileDpasShape = decltype(tile_shape(TiledMma())); - static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape()); + static constexpr uint32_t MaxThreadsPerBlock = cute::size(TileShape{}) / cute::size(TileDpasShape{}) * SubgroupSize; static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp 
b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 1a91854374..db9aa03514 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -101,14 +101,10 @@ class GemmUniversal< static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; - static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor; - - static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group using DpasShape = typename CollectiveMainloop::DpasShape; using TileDpasShape = typename CollectiveMainloop::TileDpasShape; - static constexpr int FragsM = CollectiveMainloop::FragsM; static constexpr int FragsN = CollectiveMainloop::FragsN; @@ -182,9 +178,9 @@ class GemmUniversal< const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments return dim3( - sg_m, - cute::ceil_div(sg_n, num_sg), - batch_count + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), + batch_count ); } @@ -218,9 +214,10 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); - auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K) + constexpr auto workgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K) + constexpr auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K) const int m_coord = BlockIdxX() * get<0>(subgroup_shape); - const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape); + const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); Tensor tAi = 
params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0), From 8c72fd5da61baa19a88d4dd948256fdb03fcb886 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Thu, 13 Jun 2024 13:45:55 +0100 Subject: [PATCH 2/5] Rename tile shapes --- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 18 +++++++++------- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 21 ++++++++++--------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 8f2e4cb34e..0ee08a0abc 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -81,7 +81,7 @@ struct CollectiveMma< // Type Aliases // using DispatchPolicy = MainloopIntelPVCUnpredicated; - using TileShape = TileShape_; + using WorkgroupTileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; using ElementB = ElementB_; @@ -101,13 +101,14 @@ struct CollectiveMma< static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; using DpasShape = typename TiledMma::Shape_MNK; - using TileDpasShape = decltype(tile_shape(TiledMma())); + using SubgroupTileShape = decltype(tile_shape(TiledMma())); - static constexpr uint32_t MaxThreadsPerBlock = cute::size(TileShape{}) / cute::size(TileDpasShape{}) * SubgroupSize; + static constexpr uint32_t MaxThreadsPerBlock = + cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize; - static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group - static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group - static constexpr int FragsK = get<2>(TileDpasShape{}) / get<2>(DpasShape()); + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(DpasShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(DpasShape()); // B frags per 
sub_group + static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(DpasShape()); // Calculate the vector width based on the amount of registers // required per work item by dividing the total fragment size by @@ -186,8 +187,9 @@ struct CollectiveMma< static_assert(is_rmem::value, "C tensor must be rmem resident."); // Tensor to hold input data - Tensor tAr = make_tensor(Shape(TileDpasShape{}) * FragsK>, Int<1>>{}); - Tensor tBr = make_tensor(Shape(TileDpasShape{}) / FragsN>, Int>{}); + Tensor tAr = make_tensor(Shape(SubgroupTileShape{}) * FragsK>, Int<1>>{}); + Tensor tBr = make_tensor( + Shape(SubgroupTileShape{}) / FragsN>, Int>{}); Tensor tAr_view = make_tensor(static_cast(tAr).data(), Shape, Int, Int>{}); diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index db9aa03514..1f823297de 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -65,7 +65,8 @@ class GemmUniversal< // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; + using TileShape = typename CollectiveMainloop::WorkgroupTileShape; + using WorkgroupTileShape = TileShape; using TiledMma = typename CollectiveMainloop::TiledMma; using ArchTag = typename CollectiveMainloop::ArchTag; using ElementA = typename CollectiveMainloop::ElementA; @@ -81,7 +82,7 @@ class GemmUniversal< "Intel PVC does not support specializing the tile scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, + TileScheduler_, ArchTag, WorkgroupTileShape, cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; @@ -103,7 +104,7 @@ class GemmUniversal< static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; using DpasShape = typename 
CollectiveMainloop::DpasShape; - using TileDpasShape = typename CollectiveMainloop::TileDpasShape; + using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; static constexpr int FragsM = CollectiveMainloop::FragsM; static constexpr int FragsN = CollectiveMainloop::FragsN; @@ -174,12 +175,12 @@ class GemmUniversal< auto M = get<0>(params.problem_shape); auto N = get<1>(params.problem_shape); - const int sg_m = (M - 1) / get<0>(TileDpasShape{}) + 1; // sub_groups required to process A fragments - const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments + const int sg_m = (M - 1) / get<0>(SubgroupTileShape{}) + 1; // sub_groups required to process A fragments + const int sg_n = (N - 1) / get<1>(SubgroupTileShape{}) + 1; // sub_groups required to process B fragments return dim3( - cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), - cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), batch_count ); } @@ -196,7 +197,7 @@ class GemmUniversal< (void)smem_buf; // Preconditions - CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) @@ -214,8 +215,8 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); - constexpr auto workgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K) - constexpr auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K) + constexpr auto workgroup_shape = WorkgroupTileShape{}; // (WG_M,WG_N,WG_K) + constexpr auto subgroup_shape = 
SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) const int m_coord = BlockIdxX() * get<0>(subgroup_shape); const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); From f18cb1f542f9dc785f6dfc8f3758f967d4b80320 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Thu, 13 Jun 2024 15:36:11 +0100 Subject: [PATCH 3/5] Rename mma shape --- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 16 ++++++++-------- include/cutlass/gemm/kernel/intel_pvc_gemm.hpp | 18 ++++++++++-------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 0ee08a0abc..c552ee8616 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -100,22 +100,22 @@ struct CollectiveMma< static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - using DpasShape = typename TiledMma::Shape_MNK; + using MmaAtomShape = typename TiledMma::AtomShape_MNK; using SubgroupTileShape = decltype(tile_shape(TiledMma())); static constexpr uint32_t MaxThreadsPerBlock = cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize; - static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(DpasShape()); // A frags per sub_group - static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(DpasShape()); // B frags per sub_group - static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(DpasShape()); + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(MmaAtomShape()); // Calculate the vector width based on the amount of registers // required per work item by dividing the total 
fragment size by // the sub_group size. - static constexpr int VecC = (get<1>(DpasShape()) * get<0>(DpasShape())) / SubgroupSize; - static constexpr int VecA = (get<0>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize; - static constexpr int VecB = (get<1>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize; + static constexpr int VecC = (get<1>(MmaAtomShape()) * get<0>(MmaAtomShape())) / SubgroupSize; + static constexpr int VecA = (get<0>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize; + static constexpr int VecB = (get<1>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize; // Host side kernel arguments struct Arguments { @@ -202,7 +202,7 @@ struct CollectiveMma< // // Mainloop // - for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(DpasShape()) * FragsK) + for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(MmaAtomShape()) * FragsK) { // Copy gmem to rmem for the first k_tile copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr); diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 1f823297de..4f43d99c14 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -103,7 +103,7 @@ class GemmUniversal< static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; - using DpasShape = typename CollectiveMainloop::DpasShape; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; static constexpr int FragsM = CollectiveMainloop::FragsM; @@ -221,13 +221,15 @@ class GemmUniversal< const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); - Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 
0), - make_shape(_1{}, K, L), - make_stride(Int{} * get<0>(DpasShape()), _1{})); + Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( + make_coord(m_coord, 0, 0), + make_shape(_1{}, K, L), + make_stride(Int{} * get<0>(MmaAtomShape()),_1{})); - Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, 0), - make_shape(K, Int{}, L), - make_stride(_1{}, get<1>(DpasShape()))); + Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( + make_coord(0, n_coord, 0), + make_shape(K, Int{}, L), + make_stride(_1{}, get<1>(MmaAtomShape()))); // Compute tile residues for predication auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord @@ -261,7 +263,7 @@ class GemmUniversal< Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, 0), make_shape(Int{}, Int{}, L), - make_stride(get<0>(DpasShape()), get<1>(DpasShape()))); + make_stride(get<0>(MmaAtomShape()), get<1>(MmaAtomShape()))); copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord)); } From c96164e5592e33ff14e0600ab90038bc89bab0bf Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Thu, 13 Jun 2024 15:40:42 +0100 Subject: [PATCH 4/5] Remove unused code --- include/cutlass/gemm/kernel/intel_pvc_gemm.hpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 4f43d99c14..5c9b6d019e 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -172,12 +172,6 @@ class GemmUniversal< batch_count = cute::size<3>(params.problem_shape); } - auto M = get<0>(params.problem_shape); - auto N = get<1>(params.problem_shape); - - const int sg_m = (M - 1) / get<0>(SubgroupTileShape{}) + 1; // sub_groups required to process A fragments - const int sg_n = (N - 1) / get<1>(SubgroupTileShape{}) + 1; // sub_groups required to process B fragments - return dim3( 
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), From 64342df33920822b8cdf7e8bef730986ee97cc77 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Tue, 18 Jun 2024 13:58:10 +0100 Subject: [PATCH 5/5] Update benchmark --- ...ench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp index 67b76929db..6d36bb4d40 100644 --- a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp +++ b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp @@ -67,10 +67,8 @@ int main(int argc, const char** argv) // to use a GPU other than that with device ID 0. hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - bool passed; - - // The code section below describes datatype for input, output matrices and computation between - // elements in input matrices. +// The code section below describes datatype for input, output matrices and computation between +// elements in input matrices. 
using ElementAccumulator = float; // <- data type of accumulator using ElementComputeEpilogue = float; // <- data type of epilogue operations using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A @@ -82,16 +80,20 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; + // Workgroup-level tile + using TileShape = Shape<_32, _256, _32>; - using TileShape = Shape<_32, _64, _32>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, + Tile<_32,_64,_32>>; // Subgroup-level tile - using TiledMma = TiledMMA, - Layout>>; + using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; + // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // <- data type of output matrix 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized